From 2231d90743fb40511e860db4746d32a6e6e6a4a1 Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Fri, 26 Jun 2026 20:23:42 +0000 Subject: [PATCH] Add FSDP all-gather stream overlap Signed-off-by: Jingyue Wu --- .../megatron_fsdp/experimental/__init__.py | 4 - .../megatron_fsdp/experimental/fully_shard.py | 2 + .../src/megatron_fsdp/experimental/module.py | 123 +++++++++- .../experimental/parameter_group.py | 9 +- .../distributed/megatron_fsdp/test_context.py | 100 ++++++++ .../test_experimental_fully_shard.py | 227 +++++++++++++++++- 6 files changed, 443 insertions(+), 22 deletions(-) create mode 100644 tests/unit_tests/distributed/megatron_fsdp/test_context.py diff --git a/megatron/core/distributed/fsdp/src/megatron_fsdp/experimental/__init__.py b/megatron/core/distributed/fsdp/src/megatron_fsdp/experimental/__init__.py index 87fd3dac52b..397c0b9cbc1 100644 --- a/megatron/core/distributed/fsdp/src/megatron_fsdp/experimental/__init__.py +++ b/megatron/core/distributed/fsdp/src/megatron_fsdp/experimental/__init__.py @@ -16,15 +16,11 @@ from .dbuffer import DBuffer from .fully_shard import fully_shard -from .module import FsdpModule -from .parameter_group import FsdpParameterGroup from .placement import Flat, Partial, Placement, Placements, Replicate __all__ = [ "DBuffer", "Flat", - "FsdpModule", - "FsdpParameterGroup", "Partial", "Placement", "Placements", diff --git a/megatron/core/distributed/fsdp/src/megatron_fsdp/experimental/fully_shard.py b/megatron/core/distributed/fsdp/src/megatron_fsdp/experimental/fully_shard.py index 136b600b84c..c14c9d9bcf3 100644 --- a/megatron/core/distributed/fsdp/src/megatron_fsdp/experimental/fully_shard.py +++ b/megatron/core/distributed/fsdp/src/megatron_fsdp/experimental/fully_shard.py @@ -21,6 +21,8 @@ from .module import FsdpModule from .placement import Placements +__all__ = ["fully_shard"] + def fully_shard( module: nn.Module, diff --git a/megatron/core/distributed/fsdp/src/megatron_fsdp/experimental/module.py b/megatron/core/distributed/fsdp/src/megatron_fsdp/experimental/module.py index 8907f0764b4..892a149f352 100644 --- a/megatron/core/distributed/fsdp/src/megatron_fsdp/experimental/module.py +++ b/megatron/core/distributed/fsdp/src/megatron_fsdp/experimental/module.py @@ -14,6 +14,8 @@ """Module mixin for the minimal Megatron-FSDP path.""" +import dataclasses +from collections import deque from collections.abc import Callable from typing import cast @@ -26,10 +28,56 @@ from .placement import MeshAxis, Placements +@dataclasses.dataclass(frozen=True) +class DelayedRelease: + """A module whose unsharded storage can be released after its consumer event.""" + + consumer_event: torch.cuda.Event | None + module: "FsdpModule" + + +class FsdpContext: + """Runtime stream and release scheduler shared by one FSDP subtree.""" + + allgather_stream: torch.cuda.Stream + delayed_releases: deque[DelayedRelease] + root_module: "FsdpModule" + + def __init__(self, device: torch.device, root_module: "FsdpModule") -> None: + """Create rank-local stream state for a root FSDP subtree. + + Args: + device: Device on which this context schedules communication. + root_module: Outermost module that owns this context. + """ + self.root_module = root_module + self.delayed_releases = deque() + with torch.cuda.device(device): + self.allgather_stream = torch.cuda.Stream() + + def enqueue_release(self, module: "FsdpModule") -> None: + """Queue a module's unsharded storage for delayed release.""" + consumer_event = torch.cuda.current_stream(self.allgather_stream.device).record_event() + self.delayed_releases.append(DelayedRelease(consumer_event=consumer_event, module=module)) + + def drain_delayed_releases(self, target_length: int) -> None: + """Release queued module storages FIFO until the queue reaches ``target_length``.""" + if target_length < 0: + raise ValueError(f"target_length must be non-negative, got {target_length}.") + + while len(self.delayed_releases) > target_length: + delayed_release = self.delayed_releases.popleft() + with torch.cuda.stream(self.allgather_stream): + if delayed_release.consumer_event is not None: + self.allgather_stream.wait_event(delayed_release.consumer_event) + delayed_release.module.release_unsharded_storage() + + class FsdpModule: """Mixin attached to modules managed by the minimal FSDP path.""" _parameter_groups: tuple[FsdpParameterGroup, ...] + _context: FsdpContext | None _ready_grad_parameters: set[nn.Parameter] _num_training_parameters: int @@ -37,6 +85,7 @@ def __init__( self, mesh: DeviceMesh, placements: Placements, mixed_precision_policy: MixedPrecisionPolicy ) -> None: """Initialize FSDP runtime state on an already-constructed module.""" + self._context = None owned_parameters = _collect_owned_parameters(self) axis_indices = tuple(_axis_index(mesh, axis) for axis in placements.dp_axes) assert axis_indices == tuple( @@ -59,6 +108,34 @@ def __init__( ) self._register_hooks() + def _lazy_init_context(self) -> None: + """Initialize the shared runtime context for this FSDP root subtree.""" + if self._context is not None: + return + + fsdp_context = FsdpContext( + device=self._parameter_groups[0].main_weight.device, root_module=self + ) + for submodule in cast(nn.Module, self).modules(): + if not isinstance(submodule, FsdpModule): + continue + if submodule._context is not None: + raise RuntimeError( + "FSDP context is already initialized for a descendant module. " + "Run forward through the root FSDP module first." + ) + submodule._context = fsdp_context + + @property + def context(self) -> FsdpContext: + """Return the initialized runtime context.""" + assert self._context is not None + return self._context + + def is_root(self) -> bool: + """Return whether this module is the outermost FSDP unit in its context.""" + return self.context.root_module is self + def _register_hooks(self) -> None: module = cast(nn.Module, self) module.register_forward_pre_hook(lambda _module, _args: self.pre_forward()) @@ -84,31 +161,63 @@ def grad_hook(_parameter: nn.Parameter) -> None: def pre_forward(self) -> None: """Prepare full parameters for forward compute.""" + self._lazy_init_context() self._ready_grad_parameters.clear() - for group in self._parameter_groups: - group.sync_model_weight_from_main_weight() - group.unshard_parameters() + if self.is_root(): + allgather_stream = self.context.allgather_stream + allgather_stream.wait_stream(torch.cuda.current_stream(allgather_stream.device)) + self._unshard_parameter_groups(sync_model_weight=True) + + def _unshard_parameter_groups(self, *, sync_model_weight: bool) -> None: + """Materialize full parameters for this FSDP unit.""" + self.context.drain_delayed_releases(target_length=1) + + allgather_stream = self.context.allgather_stream + current_stream = torch.cuda.current_stream(allgather_stream.device) + + with torch.cuda.stream(allgather_stream): + for group in self._parameter_groups: + if sync_model_weight: + # TODO: After NVIDIA/Megatron-LM#5411 lands, move this sync to the + # optimizer post-step hook instead of running it every microbatch. + group.sync_model_weight_from_main_weight() + group.unshard_parameters() + current_stream.wait_stream(allgather_stream) def post_forward(self) -> None: """Return parameters to their sharded resting state after forward compute.""" + self._reshard_parameter_groups() + self.context.enqueue_release(self) + if self.is_root(): + self.context.drain_delayed_releases(target_length=0) + + def _reshard_parameter_groups(self) -> None: for group in self._parameter_groups: group.reshard_parameters() def pre_backward(self) -> None: """Prepare full parameters for backward compute.""" - for group in self._parameter_groups: - group.unshard_parameters() + self._unshard_parameter_groups(sync_model_weight=False) def post_backward(self) -> None: """Reduce gradients and return parameters to their sharded resting state.""" for group in self._parameter_groups: if group.requires_grad: group.reduce_gradients() - group.reshard_parameters() + self._reshard_parameter_groups() + self.context.enqueue_release(self) + if self.is_root(): + self.context.drain_delayed_releases(target_length=0) self._ready_grad_parameters.clear() + def release_unsharded_storage(self) -> None: + """Release unsharded storage owned by this FSDP unit.""" + for group in self._parameter_groups: + group.release_unsharded_storage() + + @property def parameter_groups(self) -> tuple[FsdpParameterGroup, ...]: - """Return parameter groups owned by this FSDP unit.""" + """Parameter groups owned by this FSDP unit.""" return self._parameter_groups diff --git a/megatron/core/distributed/fsdp/src/megatron_fsdp/experimental/parameter_group.py b/megatron/core/distributed/fsdp/src/megatron_fsdp/experimental/parameter_group.py index a2c7bd0bccb..4273b7aba2c 100644 --- a/megatron/core/distributed/fsdp/src/megatron_fsdp/experimental/parameter_group.py +++ b/megatron/core/distributed/fsdp/src/megatron_fsdp/experimental/parameter_group.py @@ -206,12 +206,9 @@ def unshard_parameters(self) -> None: def reshard_parameters(self) -> None: """Install sharded DTensor parameters on the owning modules.""" self._switch_to_sharded_parameters() - # At post-backward time, replacing unsharded parameter .data with size-0 - # empty tensors would also be safe: autograd has consumed the saved - # forward views. That alternative is not much cleaner than releasing - # this storage, and splitting post-forward and post-backward reshard - # behavior would make the caller code less clean, so keep the shared - # storage-release path. + + def release_unsharded_storage(self) -> None: + """Release this group's full-parameter storage.""" self._unsharded_model_weight.release_storage() def reduce_gradients(self) -> None: diff --git a/tests/unit_tests/distributed/megatron_fsdp/test_context.py b/tests/unit_tests/distributed/megatron_fsdp/test_context.py new file mode 100644 index 00000000000..9ffd6bd8cdb --- /dev/null +++ b/tests/unit_tests/distributed/megatron_fsdp/test_context.py @@ -0,0 +1,100 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. + +"""Unit tests for experimental Megatron-FSDP runtime contexts.""" + +import torch +from torch import nn +from torch.distributed.device_mesh import init_device_mesh + +from megatron.core.distributed.fsdp.src.megatron_fsdp.experimental import ( + Flat, + Placements, + fully_shard, +) + + +class NestedModel(nn.Module): + """Model with direct and child-owned parameters.""" + + def __init__(self) -> None: + super().__init__() + self.bias = nn.Parameter(torch.ones(4)) + self.inner = nn.Linear(4, 4, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Run the nested model.""" + return self.inner(x) + self.bias + + +class MultiChildModel(nn.Module): + """Model with direct parameters and multiple child FSDP units.""" + + def __init__(self, dim: int, num_children: int) -> None: + super().__init__() + self.bias = nn.Parameter(torch.ones(dim)) + self.layers = nn.ModuleList([nn.Linear(dim, dim, bias=False) for _ in range(num_children)]) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Run through every child layer with a root-owned bias.""" + x = x + self.bias + for layer in self.layers: + x = torch.relu(layer(x)) + return x + + +def _flat_placements() -> Placements: + return Placements(dp_axes=[0], parameter=[Flat()], gradient=[Flat()], optimizer=[Flat()]) + + +def test_child_then_parent_share_one_context(distributed_setup): + """A parent FSDP unit should lazily create one context for its subtree.""" + device = distributed_setup.device + + mesh = init_device_mesh(device.type, (distributed_setup.world_size,)) + model = NestedModel().to(device) + + fully_shard(model.inner, mesh=mesh, placements=_flat_placements()) + fully_shard(model, mesh=mesh, placements=_flat_placements()) + + with torch.no_grad(): + model(torch.ones(2, 4, device=device)) + + assert model.inner.context is model.context + assert model.is_root() + assert not model.inner.is_root() + + +def test_two_child_subtrees_then_parent_collapse_to_one_context(distributed_setup): + """Sharding a parent should lazily assign one context across child subtrees.""" + device = distributed_setup.device + + mesh = init_device_mesh(device.type, (distributed_setup.world_size,)) + model = MultiChildModel(dim=4, num_children=2).to(device) + + fully_shard(model.layers[0], mesh=mesh, placements=_flat_placements()) + fully_shard(model.layers[1], mesh=mesh, placements=_flat_placements()) + fully_shard(model, mesh=mesh, placements=_flat_placements()) + + with torch.no_grad(): + model(torch.ones(2, 4, device=device)) + + assert model.layers[0].context is model.context + assert model.layers[1].context is model.context + + +def test_sibling_roots_without_parent_keep_separate_contexts(distributed_setup): + """Independent FSDP roots should not share runtime scheduling state.""" + device = distributed_setup.device + + mesh = init_device_mesh(device.type, (distributed_setup.world_size,)) + model = MultiChildModel(dim=4, num_children=2).to(device) + + fully_shard(model.layers[0], mesh=mesh, placements=_flat_placements()) + fully_shard(model.layers[1], mesh=mesh, placements=_flat_placements()) + + with torch.no_grad(): + model(torch.ones(2, 4, device=device)) + + assert model.layers[0].context is not model.layers[1].context + assert model.layers[0].is_root() + assert model.layers[1].is_root() diff --git a/tests/unit_tests/distributed/megatron_fsdp/test_experimental_fully_shard.py b/tests/unit_tests/distributed/megatron_fsdp/test_experimental_fully_shard.py index b9735ccd8c9..19aa6e7984d 100644 --- a/tests/unit_tests/distributed/megatron_fsdp/test_experimental_fully_shard.py +++ b/tests/unit_tests/distributed/megatron_fsdp/test_experimental_fully_shard.py @@ -9,6 +9,7 @@ from torch import nn from torch.distributed.device_mesh import init_device_mesh from torch.distributed.tensor import DTensor +from torch.profiler import ProfilerActivity, profile from megatron.core.distributed.fsdp.src.megatron_fsdp.experimental import ( Flat, @@ -47,6 +48,22 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return self.inner(x) + self.bias +class MultiChildModel(nn.Module): + """Model with direct parameters and multiple child FSDP units.""" + + def __init__(self, dim: int, num_children: int) -> None: + super().__init__() + self.bias = nn.Parameter(torch.ones(dim)) + self.layers = nn.ModuleList([nn.Linear(dim, dim, bias=False) for _ in range(num_children)]) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Run through every child layer with a root-owned bias.""" + x = x + self.bias + for layer in self.layers: + x = torch.relu(layer(x)) + return x + + class SaveNonLeafWeightView(torch.autograd.Function): """Autograd function that saves a non-leaf parameter view for backward.""" @@ -86,6 +103,13 @@ def _mb(num_bytes: int) -> str: return f"{num_bytes / 1024**2:.2f} MB" +def _events_overlap(first, second) -> bool: + return ( + first.time_range.start < second.time_range.end + and second.time_range.start < first.time_range.end + ) + + @pytest.mark.parametrize("num_microbatches", [1, 3]) def test_fully_shard_losses_match_baseline(distributed_setup, num_microbatches): """Minimal per-module FSDP training should match single-rank SGD.""" @@ -157,14 +181,207 @@ def test_nested_fully_shard_excludes_child_owned_parameters(distributed_setup): fully_shard(model, mesh=mesh, placements=_flat_placements()) inner_names = [ - name for group in model.inner.parameter_groups() for name in group.parameter_names + name for group in model.inner.parameter_groups for name in group.parameter_names ] - outer_names = [name for group in model.parameter_groups() for name in group.parameter_names] + outer_names = [name for group in model.parameter_groups for name in group.parameter_names] assert inner_names == ["weight"] assert outer_names == ["bias"] +def test_forward_peak_memory_bounds_in_flight_child_all_gathers(distributed_setup): + """Forward peak memory should stay below three live child all-gathers.""" + rank = distributed_setup.rank + world_size = distributed_setup.world_size + device = distributed_setup.device + if world_size < 2: + pytest.skip("This test requires at least 2 ranks.") + + mesh = init_device_mesh(device.type, (world_size,)) + dim = 4096 + dtype = torch.bfloat16 + model = MultiChildModel(dim=dim, num_children=4).to(dtype=dtype, device=device) + placements = _flat_placements() + policy = MixedPrecisionPolicy(main_params_dtype=dtype, main_grads_dtype=dtype) + for layer in model.layers: + fully_shard(layer, mesh=mesh, placements=placements, mixed_precision_policy=policy) + fully_shard(model, mesh=mesh, placements=placements, mixed_precision_policy=policy) + + x = torch.randn(2, dim, device=device, dtype=dtype) + with torch.no_grad(): + model(x) + torch.cuda.synchronize(device) + torch.cuda.empty_cache() + + resting_allocated = torch.cuda.memory_allocated(device) + torch.cuda.reset_peak_memory_stats(device) + with torch.no_grad(): + model(x) + torch.cuda.synchronize(device) + peak_delta = torch.cuda.max_memory_allocated(device) - resting_allocated + + child_weight_nbytes = dim * dim * torch.empty((), dtype=dtype).element_size() + bound_nbytes = 3 * child_weight_nbytes + + # A parent forward should keep one previous child unsharded until its compute + # stream consumer is safe, plus the current child being unsharded. The bound + # is looser than two child weights to avoid coupling this test to CUDA + # allocator granularity and small temporary buffers, while still catching + # delayed releases piling up across the four child layers. + assert peak_delta < bound_nbytes, ( + "FSDP forward peak memory exceeded the in-flight all-gather bound: " + f"rank={rank}, peak_delta={_mb(peak_delta)}, " + f"three_child_weights={_mb(bound_nbytes)}" + ) + + +def test_root_forward_returns_to_resting_memory(distributed_setup): + """Root forward should release child all-gather storage before returning.""" + rank = distributed_setup.rank + world_size = distributed_setup.world_size + device = distributed_setup.device + if world_size < 2: + pytest.skip("This test requires at least 2 ranks.") + + mesh = init_device_mesh(device.type, (world_size,)) + dim = 4096 + dtype = torch.bfloat16 + model = MultiChildModel(dim=dim, num_children=2).to(dtype=dtype, device=device) + placements = _flat_placements() + policy = MixedPrecisionPolicy(main_params_dtype=dtype, main_grads_dtype=dtype) + for layer in model.layers: + fully_shard(layer, mesh=mesh, placements=placements, mixed_precision_policy=policy) + fully_shard(model, mesh=mesh, placements=placements, mixed_precision_policy=policy) + + x = torch.randn(2, dim, device=device, dtype=dtype) + torch.cuda.synchronize(device) + torch.cuda.empty_cache() + resting_allocated = torch.cuda.memory_allocated(device) + + with torch.no_grad(): + output = model(x) + del output + torch.cuda.synchronize(device) + allocated_after_forward = torch.cuda.memory_allocated(device) + extra_allocated = allocated_after_forward - resting_allocated + child_weight_nbytes = dim * dim * torch.empty((), dtype=dtype).element_size() + + assert extra_allocated < child_weight_nbytes, ( + "Root forward did not return to resting memory after draining child releases: " + f"rank={rank}, extra_allocated={_mb(extra_allocated)}, " + f"one_child_weight={_mb(child_weight_nbytes)}" + ) + + +def test_root_backward_returns_to_resting_memory(distributed_setup): + """Root backward should release child all-gather storage before returning.""" + rank = distributed_setup.rank + world_size = distributed_setup.world_size + device = distributed_setup.device + if world_size < 2: + pytest.skip("This test requires at least 2 ranks.") + + mesh = init_device_mesh(device.type, (world_size,)) + dim = 4096 + dtype = torch.bfloat16 + model = MultiChildModel(dim=dim, num_children=2).to(dtype=dtype, device=device) + placements = _flat_placements() + policy = MixedPrecisionPolicy(main_params_dtype=dtype, main_grads_dtype=dtype) + for layer in model.layers: + fully_shard(layer, mesh=mesh, placements=placements, mixed_precision_policy=policy) + fully_shard(model, mesh=mesh, placements=placements, mixed_precision_policy=policy) + + x = torch.randn(2, dim, device=device, dtype=dtype, requires_grad=True) + output = model(x) + loss = output.float().square().mean() + torch.cuda.synchronize(device) + torch.cuda.empty_cache() + allocated_before_backward = torch.cuda.memory_allocated(device) + + loss.backward() + del loss, output + torch.cuda.synchronize(device) + allocated_after_backward = torch.cuda.memory_allocated(device) + extra_allocated = allocated_after_backward - allocated_before_backward + child_weight_nbytes = dim * dim * torch.empty((), dtype=dtype).element_size() + + assert extra_allocated < child_weight_nbytes, ( + "Root backward did not return to resting memory after draining child releases: " + f"rank={rank}, extra_allocated={_mb(extra_allocated)}, " + f"one_child_weight={_mb(child_weight_nbytes)}" + ) + + +def test_overlaps_all_gather_and_compute(distributed_setup): + """A shared root context should let child all-gathers overlap GEMM compute.""" + world_size = distributed_setup.world_size + device = distributed_setup.device + if world_size < 2: + pytest.skip("This test requires at least 2 ranks.") + + mesh = init_device_mesh(device.type, (world_size,)) + dim = 4096 + num_children = 4 + dtype = torch.bfloat16 + model = MultiChildModel(dim=dim, num_children=num_children).to(dtype=dtype) + placements = _flat_placements() + policy = MixedPrecisionPolicy(main_params_dtype=dtype, main_grads_dtype=dtype) + for layer in model.layers: + fully_shard(layer, mesh=mesh, placements=placements, mixed_precision_policy=policy) + fully_shard(model, mesh=mesh, placements=placements, mixed_precision_policy=policy) + + x = torch.randn(4096, dim, device=device, dtype=dtype, requires_grad=True) + + def train_one_iteration() -> None: + model.zero_grad(set_to_none=True) + model(x).sum().backward() + + train_one_iteration() + torch.cuda.synchronize(device) + + with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof: + train_one_iteration() + prof.step() + torch.cuda.synchronize(device) + + cuda_events = [event for event in prof.events() if event.device_type.name == "CUDA"] + all_gather_events = [ + event + for event in cuda_events + if "nccl" in event.name.lower() and "allgather" in event.name.lower() + ] + compute_events = [ + event + for event in cuda_events + if any(token in event.name.lower() for token in ("gemm", "cutlass", "cublas")) + ] + assert all_gather_events, [event.name for event in cuda_events] + assert compute_events, [event.name for event in cuda_events] + + all_gather_streams = {event.device_resource_id for event in all_gather_events} + compute_streams = {event.device_resource_id for event in compute_events} + assert len(all_gather_streams) == 1 + assert all_gather_streams.isdisjoint(compute_streams) + + overlap_count = sum( + any(_events_overlap(all_gather_event, compute_event) for compute_event in compute_events) + for all_gather_event in all_gather_events + ) + # This profiles a full forward/backward iteration, so backward all-gathers are + # included in all_gather_events. The expected overlap count is from the forward + # child pipeline: each child after the first can all-gather while the previous + # child computes, giving num_children - 1 overlaps. Backward does not overlap + # in this all-gather-only path because gradient reduction is not delayed: + # each module synchronously reduces gradients in post_backward before autograd + # reaches the next module's pre_backward all-gather. The next PR addresses + # this by delaying gradient reduction. + expected_overlap_count = num_children - 1 + assert overlap_count >= expected_overlap_count, ( + f"Expected at least {expected_overlap_count} all-gather events to overlap compute, " + f"got {overlap_count}/{len(all_gather_events)}." + ) + + def test_frozen_parameter_group_does_not_allocate_main_grad(distributed_setup): """A non-trainable parameter group should not allocate persistent main gradients.""" world_size = distributed_setup.world_size @@ -178,7 +395,7 @@ def test_frozen_parameter_group_does_not_allocate_main_grad(distributed_setup): fully_shard(model, mesh=mesh, placements=_flat_placements()) - (group,) = model.parameter_groups() + (group,) = model.parameter_groups assert not group.requires_grad assert group.main_grad is None @@ -260,7 +477,7 @@ def test_cpu_initialized_parameters_shard_to_mesh_device(distributed_setup): fully_shard(model, mesh=mesh, placements=_flat_placements()) - (group,) = model.parameter_groups() + (group,) = model.parameter_groups full_weight = group.model_weight.allgather(0).get_local_tensor(0) assert full_weight.device.type == device.type torch.testing.assert_close(full_weight, expected_weight) @@ -277,7 +494,7 @@ def test_non_leaf_parameter_view_survives_storage_resize(distributed_setup): model = NonLeafViewModel().to(device) fully_shard(model, mesh=mesh, placements=_flat_placements()) - group = model.parameter_groups()[0] + group = model.parameter_groups[0] x = torch.randn(8, device=device, requires_grad=True) loss = model(x).sum()