Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,11 @@

from .dbuffer import DBuffer
from .fully_shard import fully_shard
from .module import FsdpModule
from .parameter_group import FsdpParameterGroup
from .placement import Flat, Partial, Placement, Placements, Replicate

__all__ = [
"DBuffer",
"Flat",
"FsdpModule",
"FsdpParameterGroup",
"Partial",
"Placement",
"Placements",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
from .module import FsdpModule
from .placement import Placements

__all__ = ["fully_shard"]


def fully_shard(
module: nn.Module,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

"""Module mixin for the minimal Megatron-FSDP path."""

import dataclasses
from collections import deque
from collections.abc import Callable
from typing import cast

Expand All @@ -26,17 +28,64 @@
from .placement import MeshAxis, Placements


@dataclasses.dataclass(frozen=True)
class DelayedRelease:
    """Immutable queue entry pairing a module with the event that gates its release.

    The module's unsharded storage is released only after ``consumer_event``
    (when present) has been waited on by the releasing stream.
    """

    # Event recorded on the consuming stream at enqueue time; ``None`` means
    # the release does not wait on any event.
    consumer_event: torch.cuda.Event | None
    # FSDP unit whose ``release_unsharded_storage()`` is invoked on drain.
    module: "FsdpModule"


class FsdpContext:
"""Runtime stream and release scheduler shared by one FSDP subtree."""

allgather_stream: torch.cuda.Stream
delayed_releases: deque[DelayedRelease]
root_module: "FsdpModule"

def __init__(self, device: torch.device, root_module: "FsdpModule") -> None:
"""Create rank-local stream state for a root FSDP subtree.

Args:
device: Device on which this context schedules communication.
root_module: Outermost module that owns this context.
"""
self.root_module = root_module
self.delayed_releases = deque()
with torch.cuda.device(device):
self.allgather_stream = torch.cuda.Stream()

def enqueue_release(self, module: "FsdpModule") -> None:
"""Queue a module's unsharded storage for delayed release."""
consumer_event = torch.cuda.current_stream(self.allgather_stream.device).record_event()
self.delayed_releases.append(DelayedRelease(consumer_event=consumer_event, module=module))

def drain_delayed_releases(self, target_length: int) -> None:
"""Release queued module storages FIFO until the queue reaches ``target_length``."""
if target_length < 0:
raise ValueError(f"target_length must be non-negative, got {target_length}.")

while len(self.delayed_releases) > target_length:
delayed_release = self.delayed_releases.popleft()
with torch.cuda.stream(self.allgather_stream):
if delayed_release.consumer_event is not None:
self.allgather_stream.wait_event(delayed_release.consumer_event)
delayed_release.module.release_unsharded_storage()


class FsdpModule:
"""Mixin attached to modules managed by the minimal FSDP path."""

_parameter_groups: tuple[FsdpParameterGroup, ...]
_context: FsdpContext | None
_ready_grad_parameters: set[nn.Parameter]
_num_training_parameters: int

def __init__(
self, mesh: DeviceMesh, placements: Placements, mixed_precision_policy: MixedPrecisionPolicy
) -> None:
"""Initialize FSDP runtime state on an already-constructed module."""
self._context = None
owned_parameters = _collect_owned_parameters(self)
axis_indices = tuple(_axis_index(mesh, axis) for axis in placements.dp_axes)
assert axis_indices == tuple(
Expand All @@ -59,6 +108,34 @@ def __init__(
)
self._register_hooks()

def _lazy_init_context(self) -> None:
"""Initialize the shared runtime context for this FSDP root subtree."""
if self._context is not None:
return

fsdp_context = FsdpContext(
device=self._parameter_groups[0].main_weight.device, root_module=self
)
for submodule in cast(nn.Module, self).modules():
if not isinstance(submodule, FsdpModule):
continue
if submodule._context is not None:
raise RuntimeError(
"FSDP context is already initialized for a descendant module. "
"Run forward through the root FSDP module first."
)
submodule._context = fsdp_context

@property
def context(self) -> FsdpContext:
"""Return the initialized runtime context."""
assert self._context is not None
return self._context

def is_root(self) -> bool:
"""Return whether this module is the outermost FSDP unit in its context."""
return self.context.root_module is self

def _register_hooks(self) -> None:
module = cast(nn.Module, self)
module.register_forward_pre_hook(lambda _module, _args: self.pre_forward())
Expand All @@ -84,31 +161,63 @@ def grad_hook(_parameter: nn.Parameter) -> None:

def pre_forward(self) -> None:
"""Prepare full parameters for forward compute."""
self._lazy_init_context()
self._ready_grad_parameters.clear()
for group in self._parameter_groups:
group.sync_model_weight_from_main_weight()
group.unshard_parameters()
if self.is_root():
allgather_stream = self.context.allgather_stream
allgather_stream.wait_stream(torch.cuda.current_stream(allgather_stream.device))
self._unshard_parameter_groups(sync_model_weight=True)

def _unshard_parameter_groups(self, *, sync_model_weight: bool) -> None:
    """Unshard every parameter group of this unit on the all-gather stream.

    Args:
        sync_model_weight: When true, each group first syncs its model weight
            from its main weight before being unsharded.
    """
    runtime = self.context
    # Release older queued storages first, leaving at most one entry queued.
    runtime.drain_delayed_releases(target_length=1)

    gather_stream = runtime.allgather_stream
    # Capture the caller's stream before switching to the all-gather stream.
    compute_stream = torch.cuda.current_stream(gather_stream.device)

    with torch.cuda.stream(gather_stream):
        for parameter_group in self._parameter_groups:
            if sync_model_weight:
                # TODO: After NVIDIA/Megatron-LM#5411 lands, move this sync to the
                # optimizer post-step hook instead of running it every microbatch.
                parameter_group.sync_model_weight_from_main_weight()
            parameter_group.unshard_parameters()
    # The caller's stream waits for the work issued on the all-gather stream.
    compute_stream.wait_stream(gather_stream)

def post_forward(self) -> None:
"""Return parameters to their sharded resting state after forward compute."""
self._reshard_parameter_groups()
self.context.enqueue_release(self)
if self.is_root():
self.context.drain_delayed_releases(target_length=0)

def _reshard_parameter_groups(self) -> None:
for group in self._parameter_groups:
group.reshard_parameters()

def pre_backward(self) -> None:
"""Prepare full parameters for backward compute."""
for group in self._parameter_groups:
group.unshard_parameters()
self._unshard_parameter_groups(sync_model_weight=False)

def post_backward(self) -> None:
"""Reduce gradients and return parameters to their sharded resting state."""
for group in self._parameter_groups:
if group.requires_grad:
group.reduce_gradients()
group.reshard_parameters()
self._reshard_parameter_groups()
self.context.enqueue_release(self)
if self.is_root():
self.context.drain_delayed_releases(target_length=0)
self._ready_grad_parameters.clear()

def release_unsharded_storage(self) -> None:
"""Release unsharded storage owned by this FSDP unit."""
for group in self._parameter_groups:
group.release_unsharded_storage()

@property
def parameter_groups(self) -> tuple[FsdpParameterGroup, ...]:
"""Return parameter groups owned by this FSDP unit."""
"""Parameter groups owned by this FSDP unit."""
return self._parameter_groups


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -206,12 +206,9 @@ def unshard_parameters(self) -> None:
def reshard_parameters(self) -> None:
    """Install sharded DTensor parameters on the owning modules.

    This only switches the installed parameters; the full-parameter storage
    is released separately through ``release_unsharded_storage``.
    """
    self._switch_to_sharded_parameters()
# At post-backward time, replacing unsharded parameter .data with size-0
# empty tensors would also be safe: autograd has consumed the saved
# forward views. That alternative is not much cleaner than releasing
# this storage, and splitting post-forward and post-backward reshard
# behavior would make the caller code less clean, so keep the shared
# storage-release path.

def release_unsharded_storage(self) -> None:
    """Release this group's full-parameter storage.

    Delegates to ``release_storage()`` on the group's unsharded model weight.
    """
    self._unsharded_model_weight.release_storage()

def reduce_gradients(self) -> None:
Expand Down
100 changes: 100 additions & 0 deletions tests/unit_tests/distributed/megatron_fsdp/test_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.

"""Unit tests for experimental Megatron-FSDP runtime contexts."""

import torch
from torch import nn
from torch.distributed.device_mesh import init_device_mesh

from megatron.core.distributed.fsdp.src.megatron_fsdp.experimental import (
Flat,
Placements,
fully_shard,
)


class NestedModel(nn.Module):
    """Toy model owning one direct parameter plus a child linear layer."""

    def __init__(self) -> None:
        super().__init__()
        # Direct parameter of the parent, registered before the child module.
        self.bias = nn.Parameter(torch.ones(4))
        self.inner = nn.Linear(4, 4, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the child layer, then add the parent-owned bias."""
        projected = self.inner(x)
        return projected + self.bias


class MultiChildModel(nn.Module):
    """Toy model with a parent-owned bias and a stack of bias-free linear children."""

    def __init__(self, dim: int, num_children: int) -> None:
        super().__init__()
        self.bias = nn.Parameter(torch.ones(dim))
        self.layers = nn.ModuleList()
        for _ in range(num_children):
            self.layers.append(nn.Linear(dim, dim, bias=False))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Add the parent bias, then apply each child layer followed by ReLU."""
        hidden = x + self.bias
        for child in self.layers:
            hidden = torch.relu(child(hidden))
        return hidden


def _flat_placements() -> Placements:
    """Build placements using ``Flat`` for parameter, gradient and optimizer on axis 0."""
    return Placements(
        dp_axes=[0],
        parameter=[Flat()],
        gradient=[Flat()],
        optimizer=[Flat()],
    )


def test_child_then_parent_share_one_context(distributed_setup):
    """Sharding a child and then its parent yields one context rooted at the parent."""
    device = distributed_setup.device

    world_mesh = init_device_mesh(device.type, (distributed_setup.world_size,))
    net = NestedModel().to(device)

    # Shard the child first, then the enclosing parent.
    for unit in (net.inner, net):
        fully_shard(unit, mesh=world_mesh, placements=_flat_placements())

    batch = torch.ones(2, 4, device=device)
    with torch.no_grad():
        net(batch)

    assert net.inner.context is net.context
    assert net.is_root()
    assert not net.inner.is_root()


def test_two_child_subtrees_then_parent_collapse_to_one_context(distributed_setup):
    """Both children share the parent's context once the parent is sharded last."""
    device = distributed_setup.device

    world_mesh = init_device_mesh(device.type, (distributed_setup.world_size,))
    net = MultiChildModel(dim=4, num_children=2).to(device)

    # Children first, in index order, then the parent.
    for unit in (net.layers[0], net.layers[1], net):
        fully_shard(unit, mesh=world_mesh, placements=_flat_placements())

    batch = torch.ones(2, 4, device=device)
    with torch.no_grad():
        net(batch)

    assert net.layers[0].context is net.context
    assert net.layers[1].context is net.context


def test_sibling_roots_without_parent_keep_separate_contexts(distributed_setup):
    """Sibling FSDP units under an unsharded parent each become their own root."""
    device = distributed_setup.device

    world_mesh = init_device_mesh(device.type, (distributed_setup.world_size,))
    net = MultiChildModel(dim=4, num_children=2).to(device)

    # Only the children are sharded; the parent stays a plain module.
    for unit in (net.layers[0], net.layers[1]):
        fully_shard(unit, mesh=world_mesh, placements=_flat_placements())

    batch = torch.ones(2, 4, device=device)
    with torch.no_grad():
        net(batch)

    first, second = net.layers[0], net.layers[1]
    assert first.context is not second.context
    assert first.is_root()
    assert second.is_root()
Loading
Loading