Skip to content

[Megatron-FSDP] MaxPoolAllocator for double-buffering hybrid architectures.#5462

Open
cspades wants to merge 9 commits into
NVIDIA:mainfrom
cspades:cye/maxpool-dbuf
Open

[Megatron-FSDP] MaxPoolAllocator for double-buffering hybrid architectures.#5462
cspades wants to merge 9 commits into
NVIDIA:mainfrom
cspades:cye/maxpool-dbuf

Conversation

@cspades

@cspades cspades commented Jun 23, 2026

Copy link
Copy Markdown
Member
  • I, the PR author, have personally reviewed every line of this PR.

What does this PR do ?

image

Iterating through all FSDP units, data buckets are categorized by data-type, sorted from small to large, and compared to the current MaxPool. If there are not enough buckets in the pool to support the unit, buckets are added to the pool (with size 0). If the largest buckets of the pool are not large enough to support the buckets in the unit (assigned to the pool from smallest to largest), the buckets in the pool are enlarged. After this process, we arrive at a minimal set of buckets that can symmetrically double-buffer every FSDP unit in the model.

  • Adds hybrid architecture double buffering via FSDP unit max-pooling for Megatron-FSDP. (V1)
    • Opens up CG or NCCL UBR support for hybrid architectures, which will help support users for a while.
  • Adds the strict_assignment state to attempt to assign the same bucket previously assigned to an FSDP unit before warning the user and assigning a different bucket to the unit.
    • If this warning appears during warmup or CUDA graph capture, likely some memory is being orphaned and you will hit numerical errors.
  • Fixes an issue where parameters / buckets that are not members of an FSDP unit will pre-fetch subsequent buckets that aren't subsequently used, exhausting buffers in the double buffer allocator and causing an allocation error.
    • Only necessary for double buffer allocators, which require careful management of the 2 buffers in the pool.
  • Deprecates --grad-reduce-in-bf16 / reduce_grad_in_fp32 for Megatron-FSDP, which has been incredibly confusing to use. Default arguments (auto) assume BF16 for both, so will not OOM any existing user's configs.
  • Adds a call to torch.autograd.graph.set_override_stale_capture_stream(True) (only supported on new PyTorch versions since Detect and fix stale stream references in autograd during CUDA graph capture pytorch/pytorch#180090) to prevent full-iteration CG errors like this:
[rank0]: RuntimeError: During CUDA graph capture, autograd node 'torch::autograd::AccumulateGrad' has a stale reference to the default stream (stream 0) from warmup. This will invalidate the capture because cudaStreamWaitEvent on the default stream pulls a non-capturing stream into the graph.

[rank0]: To fix, either:
[rank0]:   (a) Run warmup on the same stream that capture will use, or
[rank0]:   (b) Delete references to the loss / autograd graph (e.g. `del loss`) before capture, or
[rank0]:   (c) Call torch.autograd.graph.set_override_stale_capture_stream(True) to automatically redirect stale nodes to the capturing stream.

^ (a) is annoying to implement, (b) is dirty, and (c) is EZ-PZ and recommended.

⚠️ For major changes (either in lines of code or in its impact), please make sure to first share a design doc with the team. If you're unsure what's the best way to do so, contact @NVIDIA/mcore-oncall.

Issue tracking

For PRs from open-source community contributors:

  • New features: a linked issue is required. Please open a feature request and reference it here before submitting the PR.
  • Small updates (bug fixes, minor improvements): a linked issue is recommended and will accelerate the PR review process.

Linked issue:

Contribution process

Pre-checks

  • I have added relevant unit tests
  • I have added relevant functional tests
  • I have added proper typing to my code Typing guidelines
  • I have added relevant documentation
  • I have run the autoformatter.sh on my PR

Code review

Feel free to message or comment @NVIDIA/mcore-oncall to help accelerate your merge into main. The less complex your PR is, the faster it will be approved and merged!

All PRs start as draft. If you open a non-draft PR, it will be automatically converted to draft.

Step 1: Mark PR as "Ready for Review"

  1. When your PR is ready, click Ready for Review.
  2. An oncall reviewer is auto-assigned and expert reviewers are notified based on your changes.
    • Some PRs may jump straight to step 2. This is determined by .github/CODEOWNERS.

⚠️ Only mark as ready once merge-conflicts are resolved and the CI is passing.
Final Review might get declined if these requirements are not fulfilled.

Step 2: Final Review

For PRs that change megatron/core, once all expert reviewers have approved, the Final Review label is applied automatically and final reviewers are assigned.

For PRs outside megatron/core, this step is skipped.

Step 3: Approved

Once all required reviewers have approved, the Approved label is applied automatically.

Merge

Any member of mcore-engineers will be able to merge your PR.

@cspades cspades self-assigned this Jun 23, 2026
@copy-pr-bot

copy-pr-bot Bot commented Jun 23, 2026

Copy link
Copy Markdown

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.

Contributors can view more details about this message here.

@cspades cspades force-pushed the cye/maxpool-dbuf branch from af4ad72 to b81af2c Compare June 23, 2026 22:21
@cspades cspades marked this pull request as ready for review June 25, 2026 01:24
@cspades cspades requested review from a team as code owners June 25, 2026 01:24
Comment on lines 4435 to 4448
# Do not release the buckets that are being all-gathered.
no_fsdp_units = True
for bucket_id in ag_buckets:
self.bucket_can_be_released[self.get_bucket_key(bucket_id, bwd)] = False
fsdp_unit_id = parameter_groups[bucket_id].fsdp_unit_id
if fsdp_unit_id is not None and fsdp_unit_id >= 0:
no_fsdp_units = False

# If prefetch is enabled, we will add prefetch buckets to ag_buckets.
if prefetch:
# If there are no FSDP units associated with params, we should not prefetch.
if prefetch and not no_fsdp_units:

def next_bucket_id(ag_buckets):
"""

@cspades cspades Jun 25, 2026

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@shjwudp Please take a look at this code. It is a behavior change in how we pre-fetch.

Without this change, Nemotron will hit an error where a parameter owned by a module that is not an FSDP unit will pre-fetch buckets that will not be used for a long time. In this case, I believe the LanguageModelEmbedding will pre-fetch Layer 2 during the last MTP layer of the model.

This pre-fetch of an irrelevant FSDP unit will cause both buffers in the double buffer pool to be used, preventing our code from allocating the MTP layer because we do not have enough free buckets to support it.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add a precondition for using a double buffer allocator to avoid this modification broadcast to all use cases?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Absolutely, will add.

@cspades cspades Jun 26, 2026

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

        if prefetch and not (
            # When double buffering, if parameters are not members of FSDP units,
            # we should skip pre-fetch to efficiently supply buffers from the pool.
            # Non-unit module pre-fetch can run inside other FSDP unit modules and
            # un-shard irrelevant model components that pointlessly steal buffer
            # allocations from the expected FSDP unit allocation and violating
            # the maximum limit of 2 buffers allocated at any point in time.
            self.buffer.ddp_config.fsdp_double_buffer
            and no_fsdp_units
        ):

So basically, we'll still do a "naive" pre-fetch if we are not using double buffers.

That being said, I feel like sometimes this can increase memory overhead. I think it is a trade-off. If we "naively" pre-fetch the next bucket even though the current bucket is not an FSDP unit, then it means:

  • If the next layer is relevant and is computed after the current layer, then we will have better overlap and performance.
    • LanguageModelEmbedding -> first Transformer layer.
  • If the next layer is not relevant and is not computed after the current layer, then we will un-shard some extra bucket(s) and increase the memory overhead to support the current layer, the actual next compute layer, and the next un-used extra layer.
    • LanguageModelEmbedding tied to MultiTokePrediction so we pre-fetch maybe 1 Transformer, 1 MoETransformer, and some Layer 2 Transformer.

I think the above 2 points are both somewhat common, but one thing we can do is to suggest users use double buffer for weight-tied output layer, otherwise there will be higher memory overhead at the end of the model if we do not skip this pre-fetch.

If this can be customized by the user at a per-module level, the user can decide the pre-fetch graph instead of us.

Comment on lines +2549 to +2553
dtype_attr=(
self.mp_policy.grad_comm_dtype
if isinstance(self.mp_policy.grad_comm_dtype, torch.dtype)
else "grad_dtype"
),

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Translate: If the gradient communication data-type is set, then that is what this allocator will allocate, and the MaxPoolAllocator needs to know the correct dtype to properly plan ahead for the bucket assignments. Otherwise, just check the main gradient data-type for the ParameterGroup.

Before this change, there was no data-type argument, the FixedPoolAllocator just used the dtype to find symmetric buckets, so this doesn't change any behavior besides the case where we use a custom gradient communication data-type.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Translation looks good. Even better if it's a code comment. The comment can go either here or to the allocator's constructor.

Comment on lines +967 to +970
def _build_fixed_max_pool(self):
"""
Compute the maximum double-buffer pool required to support all FSDP units.
"""

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The max pooling algorithm is here. The rest of the code is similar to FixedPoolAllocator.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do max pooling decisions depend on prefetching/overlapping? Conceptually, more aggressive prefetching needs more memory and therefore affects the max pooling algorithm?

else:
if self.ddp_config.data_parallel_sharding_strategy == "optim_grads_params":
self.fsdp_unit_modules = [TransformerLayer]
self.fsdp_unit_modules = [TransformerLayer, MoETransformerLayer, MambaLayer]

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Before, we did not shard MambaLayer at all.

Comment on lines +210 to +218
if hasattr(torch.autograd.graph, 'set_override_stale_capture_stream'):
torch.autograd.graph.set_override_stale_capture_stream(True)
else:
logger.warning(
'torch.autograd.graph.set_override_stale_capture_stream is not '
'available in this PyTorch version; CUDA graph capture may fail '
'if autograd nodes hold stale references to non-capturing streams. '
'Upgrade to a PyTorch build that includes pytorch/pytorch#180090.'
)

@cspades cspades Jun 25, 2026

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should just be something that we should call if we have a new enough PyTorch version: pytorch/pytorch#180090

It harmlessly makes things a lot easier w.r.t. stragglers on the Autograd / accumulate stream. cc @nanz-nv

cspades added 9 commits June 26, 2026 10:13
Signed-off-by: Cory Ye <cye@nvidia.com>
Signed-off-by: Cory Ye <cye@nvidia.com>
… later, and grad_comm_dtype not respected during FixedPool/MaxPool bucket planning.

Signed-off-by: Cory Ye <cye@nvidia.com>
Signed-off-by: Cory Ye <cye@nvidia.com>
…ction.

Signed-off-by: Cory Ye <cye@nvidia.com>
Signed-off-by: Cory Ye <cye@nvidia.com>
Signed-off-by: Cory Ye <cye@nvidia.com>

@wujingyue wujingyue left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Deprecates --grad-reduce-in-bf16 / reduce_grad_in_fp32 for Megatron-FSDP, which has been incredibly confusing to use. Default arguments (auto) assume BF16 for both, so will not OOM any existing user's configs.
Adds a call to torch.autograd.graph.set_override_stale_capture_stream(True) (only supported on new PyTorch versions since pytorch/pytorch#180090) to prevent full-iteration CG errors like this:

Thanks for the PR and the figures!

While I'm still reviewing the rest, can these two changes go to a separate PR(s)? https://google.github.io/eng-practices/review/developer/small-cls.html

dtype: torch.dtype,
device: torch.device,
mem_alloc_context: Optional[Callable] = None,
strict_assignments: bool = True,

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Adds the strict_assignment state to attempt to assign the same bucket previously assigned to an FSDP unit before warning the user and assigning a different bucket to the unit.

What's the downside of strict_assignment=True? I wonder if it should be always on, or always on/off for certain allocators so we have fewer knobs to worry about.

@cspades cspades Jun 26, 2026

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It should always be on, it falls back to the original behavior if it is unsuccessful, but tries to allocate what it has allocated before, the scope of which is defined by this boolean being set to True. So it fixes CG issues while improving the rigor of the double buffer assignment strategy in general.

In V2, we should do this with context managers or something for TracePoolAlloc, but here it is far too messy to implement that when we don't have a use case for multiple model call patterns.

@cspades

cspades commented Jun 26, 2026

Copy link
Copy Markdown
Member Author

While I'm still reviewing the rest, can these two changes go to a separate PR(s)? https://google.github.io/eng-practices/review/developer/small-cls.html

@wujingyue Considering this exact commit needs to be merged for the NeMo release code freeze in a few days, could we make an exception in this case? These three features are all needed for Nemotron benchmarks. I'm concerned that waiting on 3 PR's to be merged in a few work days is not feasible.

Comment on lines +222 to +223
--record-memory-history
--memory-snapshot-path "${NSYS_PROFILE_PATH}/torch_memprof_node${SLURM_NODEID}_rank${SLURM_PROCID}.pickle"

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove?

--eval-interval 100
--save-interval 1000
--log-throughput
--logging-level 20

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove?


if self.megatron_fsdp_max_pool_double_buffer:
# MaxPoolAllocator is a type of double-buffer allocator.
self.fsdp_double_buffer = True

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of quietly overriding fsdp_double_buffer, we may want to assert fsdp_double_buffer to make sure the user understands the contract.

Comment on lines +2549 to +2553
dtype_attr=(
self.mp_policy.grad_comm_dtype
if isinstance(self.mp_policy.grad_comm_dtype, torch.dtype)
else "grad_dtype"
),

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Translation looks good. Even better if it's a code comment. The comment can go either here or to the allocator's constructor.

@wujingyue

Copy link
Copy Markdown
Contributor

I'm concerned that waiting on 3 PR's to be merged in a few work days is not feasible.

In my experience, reviewing three stacked PRs is usually faster than reviewing a single large PR. Stacked PRs can also be reviewed in parallel, though I may be missing something about how the review process works in Megatron-LM.

As a less ideal alternative, you could keep everything in a single PR but split it into three well-structured commits. GitHub's UI supports reviewing commits individually, which provides a similar incremental review experience.

assert (
len(self.fsdp_double_buffer_units) > 0
), "Found no FSDP units to use max-sized buffering."
if torch.distributed.get_rank() == 0:

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider log_single_rank to reduce indentation

), "Found no FSDP units to use max-sized buffering."
if torch.distributed.get_rank() == 0:
if any(
pg.fsdp_unit_id == -1 or pg.fsdp_unit_id is None for pg in self.fsdp_param_groups

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems weird to have both -1 and None to represent the same meaning. But I guess this is likely a pre-existing problem.

self.bucket_alloc_index[bucket_id] = (-1, bucket_offset)

# Log the max pool bucket sizes and bucket IDs responsible.
if torch.distributed.get_rank() == 0:

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Still needed given log_single_rank below?

def __init__(
self,
name: str,
fsdp_param_groups: List["ParameterGroup"],

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note to myself: these parameter groups may span across FSDP units.

Comment on lines +967 to +970
def _build_fixed_max_pool(self):
"""
Compute the maximum double-buffer pool required to support all FSDP units.
"""

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do max pooling decisions depend on prefetching/overlapping? Conceptually, more aggressive prefetching needs more memory and therefore affects the max pooling algorithm?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants