Skip to content

Optimize 10b 2b sdpa moe#11

Open
nscottnichols wants to merge 6 commits into
saforem2:ezpzfrom
nscottnichols:optimize_10b_2b_sdpa_moe_forward
Open

Optimize 10b 2b sdpa moe#11
nscottnichols wants to merge 6 commits into
saforem2:ezpzfrom
nscottnichols:optimize_10b_2b_sdpa_moe_forward

Conversation

@nscottnichols
Copy link
Copy Markdown

@nscottnichols nscottnichols commented May 12, 2026

Added some optimizations for the MoE path, wanted to PR before our forks drift too far apart

Summary by Sourcery

Optimize Mixture-of-Experts routing, dispatch, and expert execution paths for better load balancing and no-grad performance while preserving numerical behavior.

New Features:

  • Add optional force-load-balance MoE token dispatcher path with equal-split all-to-all routing and debug counters controllable via environment variables.
  • Introduce batched no-grad expert execution and cached weight transformations for uniform per-expert token counts.
  • Provide in-place deterministic scatter_add for forward-only paths and integrate it into MoE combine steps.

Enhancements:

  • Extend token dispatch metadata and routing to reuse precomputed per-expert token counts across router, dispatcher, and experts, avoiding redundant histogram computation.
  • Optimize all-to-all token dispatch/combine with equal-padding fast paths and shape-aware permute/unpermute logic, including specialized handling for TorchAO dispatch.
  • Improve MoE scoring paths by avoiding unnecessary dtype upcasts/ casts where possible and by differentiating bf16 scoring behavior for monitoring.
  • Make expert FSDP sharding robust to PyTorch builds lacking per-parameter mesh APIs by falling back to a two-phase sharding strategy.

Tests:

  • Add extensive unit tests for MoE fast-path counters, deterministic scatter_add in-place behavior, token dispatcher correctness, equal-split and force-load-balance routing logic, and equal-count permute/unpermute paths.

@sourcery-ai
Copy link
Copy Markdown

sourcery-ai Bot commented May 12, 2026

Reviewer's Guide

Optimizes Mixture-of-Experts dispatch/combine and expert execution paths, adds fast paths for force-load-balance routing and equal all‑to‑all padding, introduces deterministic in-place scatter_add usage where safe, and extends tests and FSDP sharding logic to cover the new behavior and environments.

Flow diagram for updated MoE forward path and expert fast paths

flowchart TD
  Router[Router.forward]
  TDDispatch[TokenDispatcher.dispatch]
  ExpertsForward[MoEBlock._experts_forward]
  ExpertsLoop[_run_experts_for_loop]
  ExpertsBatched[Batched no-grad equal-count experts]
  TDCombine[TokenDispatcher.combine]

  Router -->|compute scores, selected_experts_indices, num_tokens_per_expert| TDDispatch
  TDDispatch -->|routed_input, num_tokens_local, metadata.num_tokens_per_expert_list| ExpertsForward

  ExpertsForward -->|decide num_tokens_for_experts| ExpertsLoop

  ExpertsLoop -->|check
  equal counts
  and no grad| ExpertsBatched
  ExpertsLoop -->|otherwise
  per-expert matmuls| TDCombine

  ExpertsBatched -->|use cached w13,w2_t
  batched bmm| TDCombine

  TDCombine -->|_scatter_add_forward_or_autograd
  deterministic_scatter_add or deterministic_scatter_add_| Output[MoE output]
Loading

Flow diagram for AllToAllTokenDispatcher equal-split and load-balance fast paths

flowchart TD
  A[AllToAllTokenDispatcher.dispatch]
  CheckEP[ep_mesh is None?]
  LocalDispatch[LocalTokenDispatcher.dispatch]
  ForceLB[force_load_balance fast path]
  Histc[Compute num_tokens_per_expert with histc]
  EqualA2A[Compute input_splits, output_splits]
  EqualSize[Compute equal_a2a_split_size]
  Pad[Pad routed_input to equal splits]
  DirectRoute[_force_load_balance_equal_split_routed_input]
  A2A[all_to_all_single_autograd dispatch]
  Permute[_permute or _permute_equal_counts]
  Meta[Build AllToAllDispatchMetadata]

  A --> CheckEP
  CheckEP -->|yes| LocalDispatch
  CheckEP -->|no| ForceLB

  ForceLB -->|conditions met| EqualA2A
  ForceLB -->|conditions not met| Histc

  Histc --> EqualA2A

  EqualA2A -->|can use equal splits| EqualSize
  EqualA2A -->|cannot use equal splits| A2A

  EqualSize -->|direct equal-split route
  and no score_before_experts| DirectRoute
  EqualSize -->|otherwise| Pad

  DirectRoute --> A2A
  Pad --> A2A

  A2A --> Permute
  Permute --> Meta

  subgraph CombinePath
    B[AllToAllTokenDispatcher.combine]
    Unpermute[_unpermute or equal_count_unpermute]
    PadCombine[_pad_to_equal_splits]
    A2ACombine[all_to_all_single_autograd combine]
    Compact[_compact_equal_splits]
    Scatter[_scatter_add_forward_or_autograd]

    B --> Unpermute
    Unpermute --> PadCombine
    PadCombine --> A2ACombine
    A2ACombine --> Compact
    Compact --> Scatter
  end
Loading

File-Level Changes

Change Details Files
Add MoE fast-path debugging counters and in-place deterministic scatter_add, and route Local/AllToAll combine through a helper that chooses in-place vs autograd-safe paths.
  • Introduce environment-gated MoE fast-path counters and rank selection helpers, printing aggregated stats at process exit when enabled.
  • Add deterministic_scatter_add_ in-place op and wrap scatter_add calls in a helper that picks out-of-place vs in-place depending on grad mode and aliasing with inputs.
  • Update LocalTokenDispatcher.combine and AllToAllTokenDispatcher.combine to apply scores in native dtype, build scatter indices once, and use the new scatter_add helper to avoid mutating aliased tensors in autograd paths.
torchtitan/models/common/token_dispatcher.py
torchtitan/ops/scatter_add.py
tests/unit_tests/test_expert_parallel.py
Extend MoE dispatch metadata and AllToAllTokenDispatcher with force-load-balance and equal-split all-to-all fast paths, including equal-count permute/unpermute.
  • Augment LocalDispatchMetadata/AllToAllDispatchMetadata with optional num_tokens_per_expert_list, rank_major_shape, and equal_a2a padding metadata, and allow permuted_indices to be optional.
  • Add force_load_balance configuration, cached sort indices, and helper utilities for balanced counts, equal A2A padding, pad/compact helpers, and an equal-count permute/unpermute path when expert segments are uniform.
  • Rework AllToAllTokenDispatcher.dispatch/combine to optionally bypass argsort with cached indices, compute splits via either load-balance logic or all-to-all, support equal-split all-to-all with optional global max padding, and use the new equal-count permute/unpermute fast path when applicable.
torchtitan/models/common/token_dispatcher.py
tests/unit_tests/test_expert_parallel.py
Improve LocalTokenDispatcher and TorchAOTokenDispatcher correctness/coverage and add tests for scoring, aliasing, permute, and unpermute behaviors.
  • Allow dispatchers to accept precomputed num_tokens_per_expert, reuse LocalTokenDispatcher logic when EP=1, and recompute counts after SP split when necessary.
  • Adjust TorchAOTokenDispatcher._unpermute signature to accept optional rank_major_shape and explicitly reject rank-major unpermute, and add tests to validate permute/unpermute semantics including TorchAO’s sentinel-row behavior.
  • Add comprehensive unit tests covering MoE fast-path counters, deterministic scatter_add in-place semantics, LocalTokenDispatcher correctness with/without shared experts and bfloat16, aliasing behavior in no-grad paths, and force-load-balance routing and splits utilities.
torchtitan/models/common/token_dispatcher.py
tests/unit_tests/test_expert_parallel.py
Optimize expert computation for non-grouped MoE by batching equal token counts and caching fused expert weights and transposes in no-grad paths.
  • Change expert forward loop to accept either tensor or list counts, avoid split() overhead by using explicit slicing, and add a batched no-grad path when all experts receive the same number of tokens, doing two batched bmm calls with fused w1/w3 and cached w2^T.
  • Introduce per-module caches for concatenated (w1
Extend TokenChoiceTopKRouter to produce num_tokens_per_expert and support a debug force-load-balance mode that keeps counts consistent with assignments.
  • Update debug round-robin routing to also return balanced num_tokens_per_expert and adjust its public contract and tests accordingly.
  • Refactor forward() to share routing code between normal and debug paths, computing num_tokens_per_expert via torch.histc only in the non-debug case and returning it alongside top_scores and selected_experts_indices.
  • Update the high-level MoE forward to pass router-produced num_tokens_per_expert into token dispatchers and experts so downstream fast paths can exploit known token counts.
torchtitan/models/common/moe.py
tests/unit_tests/test_expert_parallel.py
Make ezpz MoE FSDP sharding robust to builds lacking FSDP’s per-param mesh API, and propagate debug force-load-balance into token_dispatcher configuration.
  • Change parallelize logic to try importing FSDPMeshInfo/ShardPlacementResult, falling back to a two-stage fully_shard call (experts on edp_mesh, block on dp_mesh) when unavailable, while preserving existing per-param mesh behavior when supported.
  • Wire debug.moe_force_load_balance to experts.token_dispatcher.force_load_balance so runtime configuration can enable the new AllToAll fast path.
  • Update shard_placement_fn behavior/comments to clarify behavior on XPU or other constrained PyTorch builds.
torchtitan/experiments/ezpz/moe/parallelize.py
torchtitan/experiments/ezpz/moe/model.py

Tips and commands

Interacting with Sourcery

  • Trigger a new review: Comment @sourcery-ai review on the pull request.
  • Continue discussions: Reply directly to Sourcery's review comments.
  • Generate a GitHub issue from a review comment: Ask Sourcery to create an
    issue from a review comment by replying to it. You can also reply to a
    review comment with @sourcery-ai issue to create an issue from it.
  • Generate a pull request title: Write @sourcery-ai anywhere in the pull
    request title to generate a title at any time. You can also comment
    @sourcery-ai title on the pull request to (re-)generate the title at any time.
  • Generate a pull request summary: Write @sourcery-ai summary anywhere in
    the pull request body to generate a PR summary at any time exactly where you
    want it. You can also comment @sourcery-ai summary on the pull request to
    (re-)generate the summary at any time.
  • Generate reviewer's guide: Comment @sourcery-ai guide on the pull
    request to (re-)generate the reviewer's guide at any time.
  • Resolve all Sourcery comments: Comment @sourcery-ai resolve on the
    pull request to resolve all Sourcery comments. Useful if you've already
    addressed all the comments and don't want to see them anymore.
  • Dismiss all Sourcery reviews: Comment @sourcery-ai dismiss on the pull
    request to dismiss all existing Sourcery reviews. Especially useful if you
    want to start fresh with a new review - don't forget to comment
    @sourcery-ai review to trigger a new review!

Customizing Your Experience

Access your dashboard to:

  • Enable or disable review features such as the Sourcery-generated pull request
    summary, the reviewer's guide, and others.
  • Change the review language.
  • Add, remove or edit custom review instructions.
  • Adjust other review settings.

Getting Help

Copy link
Copy Markdown

@sourcery-ai sourcery-ai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey - I've found 5 issues, and left some high level feedback:

  • In _pad_to_equal_splits, the F.pad(piece, (0, 0, 0, equal_split_size - split)) call pads extra rows at the front rather than the end of each segment, which contradicts the new test expectations (and the intended layout of original tokens followed by padding); consider swapping the height pads to (0, 0, equal_split_size - split, 0) or using torch.cat with an explicit zeros tensor instead.
  • The _shares_storage helper only compares untyped_storage().data_ptr() and ignores storage offsets, so two views into the same underlying storage at different offsets will be treated as non-aliasing; this can allow _scatter_add_forward_or_autograd to perform in-place updates on tensors that still overlap x, potentially corrupting inputs—consider also checking storage_offset() or using torch._unsafe_view-based alias checks.
Prompt for AI Agents
Please address the comments from this code review:

## Overall Comments
- In `_pad_to_equal_splits`, the `F.pad(piece, (0, 0, 0, equal_split_size - split))` call pads extra rows at the *front* rather than the *end* of each segment, which contradicts the new test expectations (and the intended layout of original tokens followed by padding); consider swapping the height pads to `(0, 0, equal_split_size - split, 0)` or using `torch.cat` with an explicit zeros tensor instead.
- The `_shares_storage` helper only compares `untyped_storage().data_ptr()` and ignores storage offsets, so two views into the same underlying storage at different offsets will be treated as non-aliasing; this can allow `_scatter_add_forward_or_autograd` to perform in-place updates on tensors that still overlap `x`, potentially corrupting inputs—consider also checking `storage_offset()` or using `torch._unsafe_view`-based alias checks.

## Individual Comments

### Comment 1
<location path="torchtitan/models/common/moe.py" line_range="49" />
<code_context>
+    ):
+        _record_moe_fastpath("batched_no_grad_experts")
+        tokens_per_expert = num_tokens_per_expert_list[0]
+        x_grouped = x.view(len(num_tokens_per_expert_list), tokens_per_expert, x.shape[-1])
+        if w13 is None:
+            w13 = torch.cat((w1, w3), dim=1)
</code_context>
<issue_to_address>
**issue (bug_risk):** The batched experts path assumes `x` is contiguous; using `view` here can misbehave for non-contiguous inputs.

For non-contiguous `x` (e.g., after `all_to_all`/permute), `view` can fail or return a wrongly grouped tensor. Since this path depends on an expert-major contiguous layout, consider either `x = x.contiguous().view(...)` or `x_grouped = x.reshape(...)`, and add an assertion like `x.numel() == len(num_tokens_per_expert_list) * tokens_per_expert * x.shape[-1]` to keep the fast path correct under layout changes.
</issue_to_address>

### Comment 2
<location path="torchtitan/models/common/token_dispatcher.py" line_range="430" />
<code_context>
+        """
+        pieces = []
+        for rank, split in enumerate(input_splits):
+            tokens_for_rank = split // num_local_experts
+            piece = x[rank::ep_size].repeat(num_local_experts, 1)
+            if split < equal_split_size:
</code_context>
<issue_to_address>
**nitpick:** `tokens_for_rank` is computed but never used in the equal-split routing fast path.

In `_force_load_balance_equal_split_routed_input`, `tokens_for_rank` is computed but never used. If it’s only for debugging, remove it; if it’s meant to check something like `split % num_local_experts == 0`, consider replacing it with an explicit assertion instead.
</issue_to_address>

### Comment 3
<location path="torchtitan/ops/scatter_add.py" line_range="23-20" />
<code_context>
         torch.use_deterministic_algorithms(prev, warn_only=prev_warn_only)


+def deterministic_scatter_add_(
+    out: torch.Tensor, index: torch.Tensor, src: torch.Tensor
+) -> torch.Tensor:
+    """Deterministic in-place scatter_add for forward-only/no-grad paths."""
+    prev = torch.are_deterministic_algorithms_enabled()
+    prev_warn_only = torch.is_deterministic_algorithms_warn_only_enabled()
+    torch.use_deterministic_algorithms(True, warn_only=False)
+    try:
+        return out.scatter_add_(dim=0, index=index, src=src)
+    finally:
+        torch.use_deterministic_algorithms(prev, warn_only=prev_warn_only)
+
+
</code_context>
<issue_to_address>
**suggestion (bug_risk):** The new in-place `deterministic_scatter_add_` lacks a `register_fake` implementation, unlike the out-of-place version.

The out-of-place op has a `@...register_fake` handler so fake-tensor / export can trace it without real execution. The in-place variant currently lacks this, so tracing/fake tensor modes may fail when they use this path. Please add a matching fake registration for the in-place op to keep behavior consistent (e.g., returning an appropriately shaped empty-like tensor or `out`).

Suggested implementation:

```python
def deterministic_scatter_add_(
    out: torch.Tensor, index: torch.Tensor, src: torch.Tensor
) -> torch.Tensor:
    """Deterministic in-place scatter_add for forward-only/no-grad paths."""
    prev = torch.are_deterministic_algorithms_enabled()
    prev_warn_only = torch.is_deterministic_algorithms_warn_only_enabled()
    torch.use_deterministic_algorithms(True, warn_only=False)
    try:
        return out.scatter_add_(dim=0, index=index, src=src)
    finally:
        torch.use_deterministic_algorithms(prev, warn_only=prev_warn_only)


@torch.library.register_fake("torchtitan::deterministic_scatter_add_")
def deterministic_scatter_add__fake(
    out: torch.Tensor, index: torch.Tensor, src: torch.Tensor
) -> torch.Tensor:
    # In fake / export modes, model the in-place op by returning `out` with the
    # correct shape and dtype, without performing any real computation.
    return out

```

The `@torch.library.register_fake("torchtitan::deterministic_scatter_add_")` line assumes that:
1. You are using `torch.library`-based custom ops, and
2. The schema/name for the in-place op is `torchtitan::deterministic_scatter_add_`.

To keep this consistent with the existing out-of-place op:
- Update the decorator target to match whatever is used for the out-of-place fake registration, e.g. if the out-of-place version uses:
  `@_scatter_add_lib.register_fake("deterministic_scatter_add")`
  then this should instead be:
  `@_scatter_add_lib.register_fake("deterministic_scatter_add_")`
- Ensure the function name (`deterministic_scatter_add__fake`) matches the naming convention you use for other fake handlers in this file (you may want to rename it to mirror the out-of-place fake’s function name plus a trailing underscore).
</issue_to_address>

### Comment 4
<location path="tests/unit_tests/test_expert_parallel.py" line_range="15" />
<code_context>
+from torch import nn
+
+from torchtitan.models.common.linear import Linear
+from torchtitan.models.common.moe import TokenChoiceTopKRouter
+from torchtitan.models.common.token_dispatcher import (
+    AllToAllTokenDispatcher,
</code_context>
<issue_to_address>
**suggestion (testing):** Add unit tests for the new batched no-grad experts path and caching in MoE experts.

The new batched no-grad path in `_run_experts_for_loop` (for equal `num_tokens_per_expert`) and the caching of `w13` / `w2_t` in `MixtureOfExperts` under `torch.no_grad()` are currently untested. Please add tests that:

1. Check numerical equivalence between the batched path and the original per-expert loop when token counts are equal.
2. Validate cache hit/miss behavior, including `_MOE_FASTPATH_COUNTERS` updates for `cached_w13_hit`/`miss` and `cached_w2_t_hit`/`miss` based on tensor identity/version.

A focused `TestMoEExpertsFastPaths` that builds a small MoE layer, runs two no-grad forwards with identical weights and `num_tokens_per_expert`, and asserts both output equality and the expected counter values would cover this well.

Suggested implementation:

```python
from torch import nn

from torchtitan.models.common.linear import Linear
from torchtitan.models.common.moe import MixtureOfExperts, TokenChoiceTopKRouter
from torchtitan.models.common.token_dispatcher import (

```

```python
from torchtitan.ops.scatter_add import (
    deterministic_scatter_add,
    deterministic_scatter_add_,
)


class TestMoEExpertsFastPaths(unittest.TestCase):
    def setUp(self) -> None:
        # Reset fastpath counters before each test to isolate behavior
        for key in [
            "cached_w13_hit",
            "cached_w13_miss",
            "cached_w2_t_hit",
            "cached_w2_t_miss",
        ]:
            if key in _MOE_FASTPATH_COUNTERS:
                _MOE_FASTPATH_COUNTERS[key] = 0

    def _build_small_moe(self, d_model: int = 4, d_hidden: int = 8, num_experts: int = 2):
        # Minimal MoE module; exact signature may need to be aligned with implementation
        moe = MixtureOfExperts(
            model_dim=d_model,
            ffn_hidden_size=d_hidden,
            num_experts=num_experts,
        )
        moe.eval()
        return moe

    def _assert_fastpath_counters(
        self,
        cached_w13_hit: int,
        cached_w13_miss: int,
        cached_w2_t_hit: int,
        cached_w2_t_miss: int,
    ):
        if "cached_w13_hit" in _MOE_FASTPATH_COUNTERS:
            self.assertEqual(_MOE_FASTPATH_COUNTERS["cached_w13_hit"], cached_w13_hit)
        if "cached_w13_miss" in _MOE_FASTPATH_COUNTERS:
            self.assertEqual(_MOE_FASTPATH_COUNTERS["cached_w13_miss"], cached_w13_miss)
        if "cached_w2_t_hit" in _MOE_FASTPATH_COUNTERS:
            self.assertEqual(_MOE_FASTPATH_COUNTERS["cached_w2_t_hit"], cached_w2_t_hit)
        if "cached_w2_t_miss" in _MOE_FASTPATH_COUNTERS:
            self.assertEqual(_MOE_FASTPATH_COUNTERS["cached_w2_t_miss"], cached_w2_t_miss)

    def test_batched_nograd_experts_numerical_equivalence(self):
        """
        Validate that the batched no-grad experts path produces numerically
        equivalent outputs to the original per-expert loop when
        num_tokens_per_expert is equal across experts.
        """
        torch.manual_seed(0)
        device = "cuda" if torch.cuda.is_available() else "cpu"

        d_model = 4
        d_hidden = 8
        num_experts = 2
        tokens_per_expert = 3  # equal token count per expert
        batch_tokens = num_experts * tokens_per_expert

        moe = self._build_small_moe(d_model, d_hidden, num_experts).to(device)

        # Synthetic input tokens
        x = torch.randn(batch_tokens, d_model, device=device)

        # Build routing such that each expert gets the same number of tokens.
        # We assume the MixtureOfExperts forward can take a precomputed
        # expert assignment or router probabilities; adjust as needed.
        expert_indices = torch.arange(num_experts, device=device).repeat_interleave(tokens_per_expert)
        # Shape [T, num_experts] one-hot or logit-style router scores
        router_logits = F.one_hot(expert_indices, num_classes=num_experts).float()

        # Run reference path with autograd on (forces original per-expert loop)
        x_ref = x.detach().clone().requires_grad_(True)
        out_ref = moe(x_ref, router_logits)

        # Run batched no-grad fast path; new code path should be triggered
        with torch.no_grad():
            out_fast = moe(x, router_logits)

        self.assertTrue(
            torch.allclose(out_ref, out_fast, atol=1e-5, rtol=1e-5),
            msg="MoE batched no-grad experts path must be numerically equivalent to per-expert loop",
        )

    def test_moe_expert_cache_counters(self):
        """
        Validate cache hit/miss behavior for the no-grad expert fast path:
        - First no-grad forward should be a cache miss for both w13 and w2_t
        - Second no-grad forward with identical weights should be a cache hit
        - Counter values should be updated via _MOE_FASTPATH_COUNTERS
        """
        torch.manual_seed(1)
        device = "cuda" if torch.cuda.is_available() else "cpu"

        d_model = 4
        d_hidden = 8
        num_experts = 2
        tokens_per_expert = 2
        batch_tokens = num_experts * tokens_per_expert

        moe = self._build_small_moe(d_model, d_hidden, num_experts).to(device)

        # Clone weights so we can later mutate their version to force misses.
        # These attribute names may need to be adjusted to match the implementation.
        w1_before = moe.experts[0].w1.weight
        w3_before = moe.experts[0].w3.weight
        w2_before = moe.experts[0].w2.weight

        x = torch.randn(batch_tokens, d_model, device=device)
        expert_indices = torch.arange(num_experts, device=device).repeat_interleave(tokens_per_expert)
        router_logits = F.one_hot(expert_indices, num_classes=num_experts).float()

        # First no-grad forward: expect cache misses for both w13 and w2_t
        with torch.no_grad():
            moe(x, router_logits)

        # Depending on how many times the fastpath is invoked internally,
        # we assert "at least" one miss.
        if "cached_w13_miss" in _MOE_FASTPATH_COUNTERS:
            self.assertGreaterEqual(_MOE_FASTPATH_COUNTERS["cached_w13_miss"], 1)
        if "cached_w2_t_miss" in _MOE_FASTPATH_COUNTERS:
            self.assertGreaterEqual(_MOE_FASTPATH_COUNTERS["cached_w2_t_miss"], 1)

        # Second no-grad forward with identical weights: expect hits
        with torch.no_grad():
            moe(x, router_logits)

        if "cached_w13_hit" in _MOE_FASTPATH_COUNTERS:
            self.assertGreaterEqual(_MOE_FASTPATH_COUNTERS["cached_w13_hit"], 1)
        if "cached_w2_t_hit" in _MOE_FASTPATH_COUNTERS:
            self.assertGreaterEqual(_MOE_FASTPATH_COUNTERS["cached_w2_t_hit"], 1)

        # Now mutate weights to invalidate caches and force misses based on tensor identity/version
        with torch.no_grad():
            w1_before.add_(0.1)
            w3_before.add_(0.1)
            w2_before.add_(0.1)

        prev_counters = dict(_MOE_FASTPATH_COUNTERS)

        with torch.no_grad():
            moe(x, router_logits)

        # After weight mutation, we expect additional misses
        if "cached_w13_miss" in _MOE_FASTPATH_COUNTERS:
            self.assertGreater(
                _MOE_FASTPATH_COUNTERS["cached_w13_miss"],
                prev_counters.get("cached_w13_miss", 0),
            )
        if "cached_w2_t_miss" in _MOE_FASTPATH_COUNTERS:
            self.assertGreater(
                _MOE_FASTPATH_COUNTERS["cached_w2_t_miss"],
                prev_counters.get("cached_w2_t_miss", 0),
            )

```

The above tests assume:
1. `MixtureOfExperts` exists in `torchtitan.models.common.moe` and has a constructor like:
   `MixtureOfExperts(model_dim: int, ffn_hidden_size: int, num_experts: int, ...)` and a forward signature `forward(x, router_logits)`.
2. A submodule layout like `moe.experts[expert_index].w1`, `.w2`, `.w3` for FFN weights.
3. `_MOE_FASTPATH_COUNTERS` exposes the keys `"cached_w13_hit"`, `"cached_w13_miss"`, `"cached_w2_t_hit"`, and `"cached_w2_t_miss"`.

You may need to:
- Adjust `_build_small_moe` to match the actual constructor signature of `MixtureOfExperts` and any required router/dispatcher configuration.
- Adapt how `router_logits` / expert assignments are passed into `forward` to line up with your implementation (e.g., using `TokenChoiceTopKRouter` directly or passing a routing map).
- Update weight attribute names (`w1`, `w2`, `w3`) and expert access (`moe.experts[...]`) to the actual structure.
- If `_run_experts_for_loop` is exposed differently, you can refine `test_batched_nograd_experts_numerical_equivalence` to call it directly with explicit `num_tokens_per_expert` instead of going through the public `forward`.
</issue_to_address>

### Comment 5
<location path="torchtitan/models/common/token_dispatcher.py" line_range="966" />
<code_context>
         self.pad_multiple = config.pad_multiple

-    def dispatch(self, x, top_scores, selected_experts_indices):
+    def dispatch(
+        self,
+        x,
</code_context>
<issue_to_address>
**issue (complexity):** Consider encapsulating equal-split routing, force-load-balance fast paths, and fastpath-debug bookkeeping into small helpers/strategy objects so AllToAllTokenDispatcher.dispatch/combine stay linear and easier to read.

The main complexity spike is in `AllToAllTokenDispatcher.dispatch/combine`, where multiple orthogonal concerns are interwoven and encoded via booleans/metadata fields. You can keep all functionality while reducing cognitive load by **encapsulating equal-split + fastpath state** and **separating routing strategies**.

### 1. Encapsulate equal A2A split behavior

Right now `equal_a2a_split_size`, `normal_equal_a2a_padding`, `dispatch_input_splits`, `dispatch_output_splits`, `input_splits_list`, `output_splits_list`, and the pad/compact helpers are threaded across dispatch/combine and encoded in metadata. This can be encapsulated into a small helper/struct that owns the state and provides symmetric dispatch/combine transforms.

Example sketch:

```python
@dataclass
class EqualA2ARouting:
    enabled: bool
    equal_split_size: int | None = None
    normal_equal_padding: bool = False

    # These are the "logical" splits; we can expose them if caller needs them.
    input_splits: list[int] | None = None
    output_splits: list[int] | None = None

    def prepare_dispatch(
        self,
        routed_input: torch.Tensor,
        input_splits: list[int],
        output_splits: list[int],
        *,
        device: torch.device,
        dtype: torch.dtype,
        ep_mesh: DeviceMesh,
        sp_size: int,
        score_before_experts: bool,
    ) -> tuple[torch.Tensor, list[int] | None, list[int] | None]:
        """
        Returns:
          dispatched_tensor, dispatch_output_splits, dispatch_input_splits
        """
        if not self.enabled:
            self.input_splits = input_splits
            self.output_splits = output_splits
            return routed_input, output_splits, input_splits

        # decide equal_split_size + normal_equal_padding here
        # self.equal_split_size, self.normal_equal_padding = ...

        padded = self._pad_to_equal_splits(
            routed_input, input_splits, self.equal_split_size
        )
        self.input_splits = input_splits
        self.output_splits = output_splits
        return padded, None, None  # equal splits => None for autograd wrapper

    def prepare_combine(
        self,
        routed_output: torch.Tensor,
    ) -> tuple[torch.Tensor, list[int] | None, list[int] | None]:
        """
        Returns:
          tensor_for_all_to_all, combine_output_splits, combine_input_splits
        """
        if not self.enabled or self.equal_split_size is None:
            return routed_output, self.output_splits, self.input_splits

        padded = self._pad_to_equal_splits(
            routed_output, self.output_splits, self.equal_split_size
        )
        return padded, None, None

    def post_dispatch(self, x: torch.Tensor) -> torch.Tensor:
        if not self.enabled or self.equal_split_size is None:
            return x
        return self._compact_equal_splits(x, self.output_splits, self.equal_split_size)

    def post_combine(self, x: torch.Tensor) -> torch.Tensor:
        if not self.enabled or self.equal_split_size is None:
            return x
        return self._compact_equal_splits(x, self.input_splits, self.equal_split_size)

    @staticmethod
    def _pad_to_equal_splits(x: torch.Tensor, splits: list[int], equal_split: int) -> torch.Tensor:
        # existing padding logic here

    @staticmethod
    def _compact_equal_splits(x: torch.Tensor, splits: list[int], equal_split: int) -> torch.Tensor:
        # existing compact logic here
```

Then `dispatch()` becomes more linear:

```python
def dispatch(...):
    ...
    # compute num_tokens_per_expert, input_splits_list, output_splits_list

    equal_route = EqualA2ARouting(
        enabled=self._should_use_equal_a2a(input_splits_list, ...),
    )
    routed_input = x[token_indices_experts_sorted]
    if self.score_before_experts:
        routed_input = ...

    routed_input, dispatch_out_splits, dispatch_in_splits = equal_route.prepare_dispatch(
        routed_input,
        input_splits_list,
        output_splits_list,
        device=num_tokens_per_expert.device,
        dtype=num_tokens_per_expert.dtype,
        ep_mesh=self.ep_mesh,
        sp_size=self.sp_size,
        score_before_experts=self.score_before_experts,
    )

    routed_input = all_to_all_single_autograd(
        routed_input,
        dispatch_out_splits,
        dispatch_in_splits,
        self.ep_mesh,
    )
    routed_input = equal_route.post_dispatch(routed_input)

    # permute / permute_equal_counts
    ...

    metadata = AllToAllDispatchMetadata(
        ...,
        equal_a2a_split_size=equal_route.equal_split_size,
        normal_equal_a2a_padding=equal_route.normal_equal_padding,
        input_splits=equal_route.input_splits,
        output_splits=equal_route.output_splits,
    )
```

And `combine()` can mirror that with fewer ad-hoc flags:

```python
def combine(...):
    ...
    routed_output = self._unpermute(
        routed_output,
        metadata.input_shape,
        metadata.permuted_indices,
        metadata.rank_major_shape,
    )

    equal_route = EqualA2ARouting(
        enabled=metadata.equal_a2a_split_size is not None,
        equal_split_size=metadata.equal_a2a_split_size,
        normal_equal_padding=metadata.normal_equal_a2a_padding,
        input_splits=metadata.input_splits,
        output_splits=metadata.output_splits,
    )

    routed_output, combine_out_splits, combine_in_splits = equal_route.prepare_combine(
        routed_output
    )

    routed_output = all_to_all_single_autograd(
        routed_output,
        combine_out_splits,
        combine_in_splits,
        self.ep_mesh,
    )
    routed_output = equal_route.post_combine(routed_output)

    ...
```

That removes a lot of local booleans (`direct_equal_split_route`, `normal_equal_a2a_padding`, `dispatch_input_splits`, `dispatch_output_splits`) and makes the symmetry between dispatch/combine explicit and local to one helper.

### 2. Separate force-load-balance routing as a strategy

`use_force_load_balance_fast_path`, `_force_load_balance_splits`, `_force_load_balance_sort_indices`, `_force_load_balance_equal_split_routed_input`, plus the implicit protocol using `uniform_local_count` + `rank_major_shape` introduce another cross-cutting mode. You can keep the optimization but isolate its decision + logic in a small “strategy” object or helper, so `dispatch()` only selects it and asks for results, instead of mixing both paths inline.

Example sketch:

```python
class ForceLoadBalanceRouting:
    def __init__(self, dispatcher: "AllToAllTokenDispatcher"):
        self.dispatcher = dispatcher

    def is_enabled(self, *, sp_size: int) -> bool:
        return (
            self.dispatcher.force_load_balance
            and sp_size == 1
            and not torch.compiler.is_compiling()
        )

    def sort_indices(
        self, selected_experts_indices: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        return self.dispatcher._force_load_balance_sort_indices(
            selected_experts_indices
        )

    def compute_splits(
        self,
        num_tokens_per_expert: torch.Tensor,
        num_tokens: int,
        ep_size: int,
        ep_rank: int,
    ) -> tuple[torch.Tensor, list[int], list[int], int | None]:
        return self.dispatcher._force_load_balance_splits(
            num_tokens,
            ep_size,
            ep_rank,
            num_tokens_per_expert.dtype,
            num_tokens_per_expert.device,
        )

    def build_equal_split_buffer(
        self,
        x: torch.Tensor,
        ep_size: int,
        num_local_experts: int,
        input_splits: list[int],
        equal_split_size: int,
    ) -> torch.Tensor:
        return self.dispatcher._force_load_balance_equal_split_routed_input(
            x, ep_size, num_local_experts, input_splits, equal_split_size
        )
```

Usage in `dispatch()`:

```python
def dispatch(...):
    ...
    flb = ForceLoadBalanceRouting(self)
    use_flb = flb.is_enabled(sp_size=self.sp_size)
    if use_flb:
        assignment_indices_experts_sorted, token_indices_experts_sorted = flb.sort_indices(
            selected_experts_indices
        )
    else:
        assignment_indices_experts_sorted = torch.argsort(
            selected_experts_indices.view(-1), stable=True
        )
        token_indices_experts_sorted = assignment_indices_experts_sorted // self.top_k

    top_scores_experts_sorted = top_scores.view(-1)[assignment_indices_experts_sorted]

    if use_flb:
        (
            num_tokens_per_expert_group,
            input_splits_list,
            output_splits_list,
            uniform_local_count,
        ) = flb.compute_splits(
            num_tokens_per_expert,
            original_num_tokens,
            ep_size,
            self.ep_mesh.get_local_rank(),
        )
    else:
        # existing all_to_all_single path
        ...
```

This keeps the “normal” path readable and clearly separates the specialized force-load-balance behavior, while preserving all current fastpaths.

### 3. Isolate fastpath debug accounting

Environment checks + counter updates are currently sprinkled throughout dispatch/combine, which adds noise on the hot path. You can wrap them behind small helpers that return identity/decorated callables so control flow stays clean.

For example, instead of:

```python
if not self.score_before_experts:
    if routed_output.dtype == torch.bfloat16:
        _record_moe_fastpath("score_after_experts_bf16")
    else:
        _record_moe_fastpath("score_after_experts")
    routed_output = routed_output * ...
```

factor a tiny helper:

```python
def _record_score_path(name: str, tensor: torch.Tensor) -> None:
    if tensor.dtype == torch.bfloat16:
        _record_moe_fastpath(f"{name}_bf16")
    else:
        _record_moe_fastpath(name)

# call site
if not self.score_before_experts:
    _record_score_path("score_after_experts", routed_output)
    routed_output = routed_output * ...
```

or a decorator-style wrapper for the scatter-add path:

```python
def _scatter_add_with_debug(self, out, index, src, protected_input):
    _record_moe_fastpath("scatter_add")
    return _scatter_add_forward_or_autograd(out, index, src, protected_input)
```

and then:

```python
out = self._scatter_add_with_debug(out, scatter_index, routed_output, x)
```

This keeps the mathematical logic (“what is being computed”) separated from telemetry, making the main control flow easier to follow.

---

These changes keep all current functionality and fast paths, but:

* reduce the number of flags and cross-method protocols exposed in `dispatch/combine`,
* make equal-split and force-load-balance behavior self-contained and symmetric, and
* move debug accounting off the critical visual path.
</issue_to_address>

Sourcery is free for open source - if you like our reviews please consider sharing them ✨
Help me be more useful! Please click 👍 or 👎 on each comment and I'll use the feedback to improve your reviews.

Comment thread torchtitan/models/common/moe.py Outdated
Comment thread torchtitan/models/common/token_dispatcher.py Outdated
Comment thread torchtitan/ops/scatter_add.py
Comment thread tests/unit_tests/test_expert_parallel.py Outdated
self.pad_multiple = config.pad_multiple

def dispatch(self, x, top_scores, selected_experts_indices):
def dispatch(
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

issue (complexity): Consider encapsulating equal-split routing, force-load-balance fast paths, and fastpath-debug bookkeeping into small helpers/strategy objects so AllToAllTokenDispatcher.dispatch/combine stay linear and easier to read.

The main complexity spike is in AllToAllTokenDispatcher.dispatch/combine, where multiple orthogonal concerns are interwoven and encoded via booleans/metadata fields. You can keep all functionality while reducing cognitive load by encapsulating equal-split + fastpath state and separating routing strategies.

1. Encapsulate equal A2A split behavior

Right now equal_a2a_split_size, normal_equal_a2a_padding, dispatch_input_splits, dispatch_output_splits, input_splits_list, output_splits_list, and the pad/compact helpers are threaded across dispatch/combine and encoded in metadata. This can be encapsulated into a small helper/struct that owns the state and provides symmetric dispatch/combine transforms.

Example sketch:

@dataclass
class EqualA2ARouting:
    enabled: bool
    equal_split_size: int | None = None
    normal_equal_padding: bool = False

    # These are the "logical" splits; we can expose them if caller needs them.
    input_splits: list[int] | None = None
    output_splits: list[int] | None = None

    def prepare_dispatch(
        self,
        routed_input: torch.Tensor,
        input_splits: list[int],
        output_splits: list[int],
        *,
        device: torch.device,
        dtype: torch.dtype,
        ep_mesh: DeviceMesh,
        sp_size: int,
        score_before_experts: bool,
    ) -> tuple[torch.Tensor, list[int] | None, list[int] | None]:
        """
        Returns:
          dispatched_tensor, dispatch_output_splits, dispatch_input_splits
        """
        if not self.enabled:
            self.input_splits = input_splits
            self.output_splits = output_splits
            return routed_input, output_splits, input_splits

        # decide equal_split_size + normal_equal_padding here
        # self.equal_split_size, self.normal_equal_padding = ...

        padded = self._pad_to_equal_splits(
            routed_input, input_splits, self.equal_split_size
        )
        self.input_splits = input_splits
        self.output_splits = output_splits
        return padded, None, None  # equal splits => None for autograd wrapper

    def prepare_combine(
        self,
        routed_output: torch.Tensor,
    ) -> tuple[torch.Tensor, list[int] | None, list[int] | None]:
        """
        Returns:
          tensor_for_all_to_all, combine_output_splits, combine_input_splits
        """
        if not self.enabled or self.equal_split_size is None:
            return routed_output, self.output_splits, self.input_splits

        padded = self._pad_to_equal_splits(
            routed_output, self.output_splits, self.equal_split_size
        )
        return padded, None, None

    def post_dispatch(self, x: torch.Tensor) -> torch.Tensor:
        if not self.enabled or self.equal_split_size is None:
            return x
        return self._compact_equal_splits(x, self.output_splits, self.equal_split_size)

    def post_combine(self, x: torch.Tensor) -> torch.Tensor:
        if not self.enabled or self.equal_split_size is None:
            return x
        return self._compact_equal_splits(x, self.input_splits, self.equal_split_size)

    @staticmethod
    def _pad_to_equal_splits(x: torch.Tensor, splits: list[int], equal_split: int) -> torch.Tensor:
        # existing padding logic here

    @staticmethod
    def _compact_equal_splits(x: torch.Tensor, splits: list[int], equal_split: int) -> torch.Tensor:
        # existing compact logic here

Then dispatch() becomes more linear:

def dispatch(...):
    ...
    # compute num_tokens_per_expert, input_splits_list, output_splits_list

    equal_route = EqualA2ARouting(
        enabled=self._should_use_equal_a2a(input_splits_list, ...),
    )
    routed_input = x[token_indices_experts_sorted]
    if self.score_before_experts:
        routed_input = ...

    routed_input, dispatch_out_splits, dispatch_in_splits = equal_route.prepare_dispatch(
        routed_input,
        input_splits_list,
        output_splits_list,
        device=num_tokens_per_expert.device,
        dtype=num_tokens_per_expert.dtype,
        ep_mesh=self.ep_mesh,
        sp_size=self.sp_size,
        score_before_experts=self.score_before_experts,
    )

    routed_input = all_to_all_single_autograd(
        routed_input,
        dispatch_out_splits,
        dispatch_in_splits,
        self.ep_mesh,
    )
    routed_input = equal_route.post_dispatch(routed_input)

    # permute / permute_equal_counts
    ...

    metadata = AllToAllDispatchMetadata(
        ...,
        equal_a2a_split_size=equal_route.equal_split_size,
        normal_equal_a2a_padding=equal_route.normal_equal_padding,
        input_splits=equal_route.input_splits,
        output_splits=equal_route.output_splits,
    )

And combine() can mirror that with fewer ad-hoc flags:

def combine(...):
    ...
    routed_output = self._unpermute(
        routed_output,
        metadata.input_shape,
        metadata.permuted_indices,
        metadata.rank_major_shape,
    )

    equal_route = EqualA2ARouting(
        enabled=metadata.equal_a2a_split_size is not None,
        equal_split_size=metadata.equal_a2a_split_size,
        normal_equal_padding=metadata.normal_equal_a2a_padding,
        input_splits=metadata.input_splits,
        output_splits=metadata.output_splits,
    )

    routed_output, combine_out_splits, combine_in_splits = equal_route.prepare_combine(
        routed_output
    )

    routed_output = all_to_all_single_autograd(
        routed_output,
        combine_out_splits,
        combine_in_splits,
        self.ep_mesh,
    )
    routed_output = equal_route.post_combine(routed_output)

    ...

That removes a lot of local booleans (direct_equal_split_route, normal_equal_a2a_padding, dispatch_input_splits, dispatch_output_splits) and makes the symmetry between dispatch/combine explicit and local to one helper.

2. Separate force-load-balance routing as a strategy

use_force_load_balance_fast_path, _force_load_balance_splits, _force_load_balance_sort_indices, _force_load_balance_equal_split_routed_input, plus the implicit protocol using uniform_local_count + rank_major_shape introduce another cross-cutting mode. You can keep the optimization but isolate its decision + logic in a small “strategy” object or helper, so dispatch() only selects it and asks for results, instead of mixing both paths inline.

Example sketch:

class ForceLoadBalanceRouting:
    def __init__(self, dispatcher: "AllToAllTokenDispatcher"):
        self.dispatcher = dispatcher

    def is_enabled(self, *, sp_size: int) -> bool:
        return (
            self.dispatcher.force_load_balance
            and sp_size == 1
            and not torch.compiler.is_compiling()
        )

    def sort_indices(
        self, selected_experts_indices: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        return self.dispatcher._force_load_balance_sort_indices(
            selected_experts_indices
        )

    def compute_splits(
        self,
        num_tokens_per_expert: torch.Tensor,
        num_tokens: int,
        ep_size: int,
        ep_rank: int,
    ) -> tuple[torch.Tensor, list[int], list[int], int | None]:
        return self.dispatcher._force_load_balance_splits(
            num_tokens,
            ep_size,
            ep_rank,
            num_tokens_per_expert.dtype,
            num_tokens_per_expert.device,
        )

    def build_equal_split_buffer(
        self,
        x: torch.Tensor,
        ep_size: int,
        num_local_experts: int,
        input_splits: list[int],
        equal_split_size: int,
    ) -> torch.Tensor:
        return self.dispatcher._force_load_balance_equal_split_routed_input(
            x, ep_size, num_local_experts, input_splits, equal_split_size
        )

Usage in dispatch():

def dispatch(...):
    ...
    flb = ForceLoadBalanceRouting(self)
    use_flb = flb.is_enabled(sp_size=self.sp_size)
    if use_flb:
        assignment_indices_experts_sorted, token_indices_experts_sorted = flb.sort_indices(
            selected_experts_indices
        )
    else:
        assignment_indices_experts_sorted = torch.argsort(
            selected_experts_indices.view(-1), stable=True
        )
        token_indices_experts_sorted = assignment_indices_experts_sorted // self.top_k

    top_scores_experts_sorted = top_scores.view(-1)[assignment_indices_experts_sorted]

    if use_flb:
        (
            num_tokens_per_expert_group,
            input_splits_list,
            output_splits_list,
            uniform_local_count,
        ) = flb.compute_splits(
            num_tokens_per_expert,
            original_num_tokens,
            ep_size,
            self.ep_mesh.get_local_rank(),
        )
    else:
        # existing all_to_all_single path
        ...

This keeps the “normal” path readable and clearly separates the specialized force-load-balance behavior, while preserving all current fastpaths.

3. Isolate fastpath debug accounting

Environment checks + counter updates are currently sprinkled throughout dispatch/combine, which adds noise on the hot path. You can wrap them behind small helpers that return identity/decorated callables so control flow stays clean.

For example, instead of:

if not self.score_before_experts:
    if routed_output.dtype == torch.bfloat16:
        _record_moe_fastpath("score_after_experts_bf16")
    else:
        _record_moe_fastpath("score_after_experts")
    routed_output = routed_output * ...

factor a tiny helper:

def _record_score_path(name: str, tensor: torch.Tensor) -> None:
    if tensor.dtype == torch.bfloat16:
        _record_moe_fastpath(f"{name}_bf16")
    else:
        _record_moe_fastpath(name)

# call site
if not self.score_before_experts:
    _record_score_path("score_after_experts", routed_output)
    routed_output = routed_output * ...

or a decorator-style wrapper for the scatter-add path:

def _scatter_add_with_debug(self, out, index, src, protected_input):
    _record_moe_fastpath("scatter_add")
    return _scatter_add_forward_or_autograd(out, index, src, protected_input)

and then:

out = self._scatter_add_with_debug(out, scatter_index, routed_output, x)

This keeps the mathematical logic (“what is being computed”) separated from telemetry, making the main control flow easier to follow.


These changes keep all current functionality and fast paths, but:

  • reduce the number of flags and cross-method protocols exposed in dispatch/combine,
  • make equal-split and force-load-balance behavior self-contained and symmetric, and
  • move debug accounting off the critical visual path.

@nscottnichols
Copy link
Copy Markdown
Author

nscottnichols commented May 12, 2026

@saforem2 do you want me to fix this sourcery junk or will you handle it? Some of them are good points.

@saforem2
Copy link
Copy Markdown
Owner

😂 @nscottnichols honestly whatever you think; like you said, some of them are good points but some of them are either nitpicky or incorrect

My only actual question would be, can we move these changes to be self-contained in the experiments/ezpz/ directory somehow?

i.e. could we move the changes from maybe:

models/common/moe.py --> experiments/ezpz/moe/model.py
models/common/token_dispatcher.py --> experiments/ezpz/moe/token_dispatcher.py
ops/scatter_add.py --> experiments/ezpz/moe/token_dispatcher.py

I guess this can be a PITA if it continues to grow; but if there are useful changes outside of experiments/ezpz/ we can maybe split them out and submit them as PRs to upstream.

Also, incase you haven't seen this thread in #10, it looks like there are upstream changes that break things:

upstream just landed pytorch#3308 (b301dfa, "Remove MoE expert for-loop fallback") which removes use_grouped_mm from GroupedExperts.Config entirely and inlines _grouped_mm as the only path. [...]

@nscottnichols
Copy link
Copy Markdown
Author

My only actual question would be, can we move these changes to be self-contained in the experiments/ezpz/ directory somehow?

Let me think about how this can be done. There are definitely some parts that can be moved.

@saforem2
Copy link
Copy Markdown
Owner

Expert-side pieces folded into #12 (commit 364fbedf):

  • Equal-counts no-grad bmm fast path inside the for-loop backend
  • Version-keyed w13 / w2_t caches on EzpzGroupedExperts
  • _record_moe_fastpath counters (env-gated)
  • The FSDP try/except ImportError two-phase fallback (commit 09a7c697, layered on top of Fix MoE expert FSDP mesh info for HSDP #9's HSDPMeshInfo fix)

Deferred to a follow-up: the token-dispatcher rewrite in torchtitan/models/common/token_dispatcher.py (+570/-91), torchtitan/ops/scatter_add.py in-place wrapper, and experiments/ezpz/moe/model.py force_load_balance flag-through. Keeping that all in experiments/ezpz/ needs an EzpzLocalTokenDispatcher subclass which is a meaningful chunk on its own. Going to keep this open as the home for that follow-up.

@nscottnichols
Copy link
Copy Markdown
Author

Hey - I've found 5 issues, and left some high level feedback:

  • In _pad_to_equal_splits, the F.pad(piece, (0, 0, 0, equal_split_size - split)) call pads extra rows at the front rather than the end of each segment, which contradicts the new test expectations (and the intended layout of original tokens followed by padding); consider swapping the height pads to (0, 0, equal_split_size - split, 0) or using torch.cat with an explicit zeros tensor instead.
  • The _shares_storage helper only compares untyped_storage().data_ptr() and ignores storage offsets, so two views into the same underlying storage at different offsets will be treated as non-aliasing; this can allow _scatter_add_forward_or_autograd to perform in-place updates on tensors that still overlap x, potentially corrupting inputs—consider also checking storage_offset() or using torch._unsafe_view-based alias checks.

I checked this and did not keep a code change here. For a 2D tensor, F.pad(piece, (0, 0, 0, n)) pads rows at the end, which matches the existing unit test expectations. I also tried replacing it with explicit zero-row concatenation, but that regressed the benchmark from ~29.9 ms to 31.514 ms, so I reverted that experiment.

@nscottnichols
Copy link
Copy Markdown
Author

Oh, I also didn't change the "shares storage" guard. The current guard is intentionally conservative for the known alias case we hit: shared_experts(x) returning x exactly. If we want broader partial-overlap detection for arbitrary views, that is worth a separate patch because it can affect the no-grad in-place scatter fast path.

@saforem2
Copy link
Copy Markdown
Owner

FYI — #13 landed the upstream resync onto ezpz (pulls pytorch#3308 which removed use_grouped_mm and _run_experts_for_loop from core). The replay introduced EzpzGroupedExperts(GroupedExperts) in torchtitan/experiments/ezpz/moe/experts.py with a vendored copy of _run_experts_for_loop.

When rebasing this PR onto merged ezpz, the expert-side optimizations (cached w13 / w2_t + equal-counts no-grad bmm fast path + counters) should layer onto that vendored _run_experts_for_loop in experts.py rather than models/common/moe.py. The token-dispatcher rewrite is unaffected by #13. Your HSDP try/except ImportError import-fallback in parallelize.py is also unaffected and should rebase cleanly.

Withdrawing the prior #12 which had attempted to bundle these in — keeping #11 as the canonical home for the optimizations.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants