Optimize 10b 2b sdpa moe#11
Conversation
Reviewer's GuideOptimizes Mixture-of-Experts dispatch/combine and expert execution paths, adds fast paths for force-load-balance routing and equal all‑to‑all padding, introduces deterministic in-place scatter_add usage where safe, and extends tests and FSDP sharding logic to cover the new behavior and environments. Flow diagram for updated MoE forward path and expert fast pathsflowchart TD
Router[Router.forward]
TDDispatch[TokenDispatcher.dispatch]
ExpertsForward[MoEBlock._experts_forward]
ExpertsLoop[_run_experts_for_loop]
ExpertsBatched[Batched no-grad equal-count experts]
TDCombine[TokenDispatcher.combine]
Router -->|compute scores, selected_experts_indices, num_tokens_per_expert| TDDispatch
TDDispatch -->|routed_input, num_tokens_local, metadata.num_tokens_per_expert_list| ExpertsForward
ExpertsForward -->|decide num_tokens_for_experts| ExpertsLoop
ExpertsLoop -->|check
equal counts
and no grad| ExpertsBatched
ExpertsLoop -->|otherwise
per-expert matmuls| TDCombine
ExpertsBatched -->|use cached w13,w2_t
batched bmm| TDCombine
TDCombine -->|_scatter_add_forward_or_autograd
deterministic_scatter_add or deterministic_scatter_add_| Output[MoE output]
Flow diagram for AllToAllTokenDispatcher equal-split and load-balance fast pathsflowchart TD
A[AllToAllTokenDispatcher.dispatch]
CheckEP[ep_mesh is None?]
LocalDispatch[LocalTokenDispatcher.dispatch]
ForceLB[force_load_balance fast path]
Histc[Compute num_tokens_per_expert with histc]
EqualA2A[Compute input_splits, output_splits]
EqualSize[Compute equal_a2a_split_size]
Pad[Pad routed_input to equal splits]
DirectRoute[_force_load_balance_equal_split_routed_input]
A2A[all_to_all_single_autograd dispatch]
Permute[_permute or _permute_equal_counts]
Meta[Build AllToAllDispatchMetadata]
A --> CheckEP
CheckEP -->|yes| LocalDispatch
CheckEP -->|no| ForceLB
ForceLB -->|conditions met| EqualA2A
ForceLB -->|conditions not met| Histc
Histc --> EqualA2A
EqualA2A -->|can use equal splits| EqualSize
EqualA2A -->|cannot use equal splits| A2A
EqualSize -->|direct equal-split route
and no score_before_experts| DirectRoute
EqualSize -->|otherwise| Pad
DirectRoute --> A2A
Pad --> A2A
A2A --> Permute
Permute --> Meta
subgraph CombinePath
B[AllToAllTokenDispatcher.combine]
Unpermute[_unpermute or equal_count_unpermute]
PadCombine[_pad_to_equal_splits]
A2ACombine[all_to_all_single_autograd combine]
Compact[_compact_equal_splits]
Scatter[_scatter_add_forward_or_autograd]
B --> Unpermute
Unpermute --> PadCombine
PadCombine --> A2ACombine
A2ACombine --> Compact
Compact --> Scatter
end
File-Level Changes
Tips and commandsInteracting with Sourcery
Customizing Your ExperienceAccess your dashboard to:
Getting Help
|
There was a problem hiding this comment.
Hey - I've found 5 issues, and left some high level feedback:
- In
_pad_to_equal_splits, theF.pad(piece, (0, 0, 0, equal_split_size - split))call pads extra rows at the front rather than the end of each segment, which contradicts the new test expectations (and the intended layout of original tokens followed by padding); consider swapping the height pads to(0, 0, equal_split_size - split, 0)or usingtorch.catwith an explicit zeros tensor instead. - The
_shares_storagehelper only comparesuntyped_storage().data_ptr()and ignores storage offsets, so two views into the same underlying storage at different offsets will be treated as non-aliasing; this can allow_scatter_add_forward_or_autogradto perform in-place updates on tensors that still overlapx, potentially corrupting inputs—consider also checkingstorage_offset()or usingtorch._unsafe_view-based alias checks.
Prompt for AI Agents
Please address the comments from this code review:
## Overall Comments
- In `_pad_to_equal_splits`, the `F.pad(piece, (0, 0, 0, equal_split_size - split))` call pads extra rows at the *front* rather than the *end* of each segment, which contradicts the new test expectations (and the intended layout of original tokens followed by padding); consider swapping the height pads to `(0, 0, equal_split_size - split, 0)` or using `torch.cat` with an explicit zeros tensor instead.
- The `_shares_storage` helper only compares `untyped_storage().data_ptr()` and ignores storage offsets, so two views into the same underlying storage at different offsets will be treated as non-aliasing; this can allow `_scatter_add_forward_or_autograd` to perform in-place updates on tensors that still overlap `x`, potentially corrupting inputs—consider also checking `storage_offset()` or using `torch._unsafe_view`-based alias checks.
## Individual Comments
### Comment 1
<location path="torchtitan/models/common/moe.py" line_range="49" />
<code_context>
+ ):
+ _record_moe_fastpath("batched_no_grad_experts")
+ tokens_per_expert = num_tokens_per_expert_list[0]
+ x_grouped = x.view(len(num_tokens_per_expert_list), tokens_per_expert, x.shape[-1])
+ if w13 is None:
+ w13 = torch.cat((w1, w3), dim=1)
</code_context>
<issue_to_address>
**issue (bug_risk):** The batched experts path assumes `x` is contiguous; using `view` here can misbehave for non-contiguous inputs.
For non-contiguous `x` (e.g., after `all_to_all`/permute), `view` can fail or return a wrongly grouped tensor. Since this path depends on an expert-major contiguous layout, consider either `x = x.contiguous().view(...)` or `x_grouped = x.reshape(...)`, and add an assertion like `x.numel() == len(num_tokens_per_expert_list) * tokens_per_expert * x.shape[-1]` to keep the fast path correct under layout changes.
</issue_to_address>
### Comment 2
<location path="torchtitan/models/common/token_dispatcher.py" line_range="430" />
<code_context>
+ """
+ pieces = []
+ for rank, split in enumerate(input_splits):
+ tokens_for_rank = split // num_local_experts
+ piece = x[rank::ep_size].repeat(num_local_experts, 1)
+ if split < equal_split_size:
</code_context>
<issue_to_address>
**nitpick:** `tokens_for_rank` is computed but never used in the equal-split routing fast path.
In `_force_load_balance_equal_split_routed_input`, `tokens_for_rank` is computed but never used. If it’s only for debugging, remove it; if it’s meant to check something like `split % num_local_experts == 0`, consider replacing it with an explicit assertion instead.
</issue_to_address>
### Comment 3
<location path="torchtitan/ops/scatter_add.py" line_range="23-20" />
<code_context>
torch.use_deterministic_algorithms(prev, warn_only=prev_warn_only)
+def deterministic_scatter_add_(
+ out: torch.Tensor, index: torch.Tensor, src: torch.Tensor
+) -> torch.Tensor:
+ """Deterministic in-place scatter_add for forward-only/no-grad paths."""
+ prev = torch.are_deterministic_algorithms_enabled()
+ prev_warn_only = torch.is_deterministic_algorithms_warn_only_enabled()
+ torch.use_deterministic_algorithms(True, warn_only=False)
+ try:
+ return out.scatter_add_(dim=0, index=index, src=src)
+ finally:
+ torch.use_deterministic_algorithms(prev, warn_only=prev_warn_only)
+
+
</code_context>
<issue_to_address>
**suggestion (bug_risk):** The new in-place `deterministic_scatter_add_` lacks a `register_fake` implementation, unlike the out-of-place version.
The out-of-place op has a `@...register_fake` handler so fake-tensor / export can trace it without real execution. The in-place variant currently lacks this, so tracing/fake tensor modes may fail when they use this path. Please add a matching fake registration for the in-place op to keep behavior consistent (e.g., returning an appropriately shaped empty-like tensor or `out`).
Suggested implementation:
```python
def deterministic_scatter_add_(
out: torch.Tensor, index: torch.Tensor, src: torch.Tensor
) -> torch.Tensor:
"""Deterministic in-place scatter_add for forward-only/no-grad paths."""
prev = torch.are_deterministic_algorithms_enabled()
prev_warn_only = torch.is_deterministic_algorithms_warn_only_enabled()
torch.use_deterministic_algorithms(True, warn_only=False)
try:
return out.scatter_add_(dim=0, index=index, src=src)
finally:
torch.use_deterministic_algorithms(prev, warn_only=prev_warn_only)
@torch.library.register_fake("torchtitan::deterministic_scatter_add_")
def deterministic_scatter_add__fake(
out: torch.Tensor, index: torch.Tensor, src: torch.Tensor
) -> torch.Tensor:
# In fake / export modes, model the in-place op by returning `out` with the
# correct shape and dtype, without performing any real computation.
return out
```
The `@torch.library.register_fake("torchtitan::deterministic_scatter_add_")` line assumes that:
1. You are using `torch.library`-based custom ops, and
2. The schema/name for the in-place op is `torchtitan::deterministic_scatter_add_`.
To keep this consistent with the existing out-of-place op:
- Update the decorator target to match whatever is used for the out-of-place fake registration, e.g. if the out-of-place version uses:
`@_scatter_add_lib.register_fake("deterministic_scatter_add")`
then this should instead be:
`@_scatter_add_lib.register_fake("deterministic_scatter_add_")`
- Ensure the function name (`deterministic_scatter_add__fake`) matches the naming convention you use for other fake handlers in this file (you may want to rename it to mirror the out-of-place fake’s function name plus a trailing underscore).
</issue_to_address>
### Comment 4
<location path="tests/unit_tests/test_expert_parallel.py" line_range="15" />
<code_context>
+from torch import nn
+
+from torchtitan.models.common.linear import Linear
+from torchtitan.models.common.moe import TokenChoiceTopKRouter
+from torchtitan.models.common.token_dispatcher import (
+ AllToAllTokenDispatcher,
</code_context>
<issue_to_address>
**suggestion (testing):** Add unit tests for the new batched no-grad experts path and caching in MoE experts.
The new batched no-grad path in `_run_experts_for_loop` (for equal `num_tokens_per_expert`) and the caching of `w13` / `w2_t` in `MixtureOfExperts` under `torch.no_grad()` are currently untested. Please add tests that:
1. Check numerical equivalence between the batched path and the original per-expert loop when token counts are equal.
2. Validate cache hit/miss behavior, including `_MOE_FASTPATH_COUNTERS` updates for `cached_w13_hit`/`miss` and `cached_w2_t_hit`/`miss` based on tensor identity/version.
A focused `TestMoEExpertsFastPaths` that builds a small MoE layer, runs two no-grad forwards with identical weights and `num_tokens_per_expert`, and asserts both output equality and the expected counter values would cover this well.
Suggested implementation:
```python
from torch import nn
from torchtitan.models.common.linear import Linear
from torchtitan.models.common.moe import MixtureOfExperts, TokenChoiceTopKRouter
from torchtitan.models.common.token_dispatcher import (
```
```python
from torchtitan.ops.scatter_add import (
deterministic_scatter_add,
deterministic_scatter_add_,
)
class TestMoEExpertsFastPaths(unittest.TestCase):
def setUp(self) -> None:
# Reset fastpath counters before each test to isolate behavior
for key in [
"cached_w13_hit",
"cached_w13_miss",
"cached_w2_t_hit",
"cached_w2_t_miss",
]:
if key in _MOE_FASTPATH_COUNTERS:
_MOE_FASTPATH_COUNTERS[key] = 0
def _build_small_moe(self, d_model: int = 4, d_hidden: int = 8, num_experts: int = 2):
# Minimal MoE module; exact signature may need to be aligned with implementation
moe = MixtureOfExperts(
model_dim=d_model,
ffn_hidden_size=d_hidden,
num_experts=num_experts,
)
moe.eval()
return moe
def _assert_fastpath_counters(
self,
cached_w13_hit: int,
cached_w13_miss: int,
cached_w2_t_hit: int,
cached_w2_t_miss: int,
):
if "cached_w13_hit" in _MOE_FASTPATH_COUNTERS:
self.assertEqual(_MOE_FASTPATH_COUNTERS["cached_w13_hit"], cached_w13_hit)
if "cached_w13_miss" in _MOE_FASTPATH_COUNTERS:
self.assertEqual(_MOE_FASTPATH_COUNTERS["cached_w13_miss"], cached_w13_miss)
if "cached_w2_t_hit" in _MOE_FASTPATH_COUNTERS:
self.assertEqual(_MOE_FASTPATH_COUNTERS["cached_w2_t_hit"], cached_w2_t_hit)
if "cached_w2_t_miss" in _MOE_FASTPATH_COUNTERS:
self.assertEqual(_MOE_FASTPATH_COUNTERS["cached_w2_t_miss"], cached_w2_t_miss)
def test_batched_nograd_experts_numerical_equivalence(self):
"""
Validate that the batched no-grad experts path produces numerically
equivalent outputs to the original per-expert loop when
num_tokens_per_expert is equal across experts.
"""
torch.manual_seed(0)
device = "cuda" if torch.cuda.is_available() else "cpu"
d_model = 4
d_hidden = 8
num_experts = 2
tokens_per_expert = 3 # equal token count per expert
batch_tokens = num_experts * tokens_per_expert
moe = self._build_small_moe(d_model, d_hidden, num_experts).to(device)
# Synthetic input tokens
x = torch.randn(batch_tokens, d_model, device=device)
# Build routing such that each expert gets the same number of tokens.
# We assume the MixtureOfExperts forward can take a precomputed
# expert assignment or router probabilities; adjust as needed.
expert_indices = torch.arange(num_experts, device=device).repeat_interleave(tokens_per_expert)
# Shape [T, num_experts] one-hot or logit-style router scores
router_logits = F.one_hot(expert_indices, num_classes=num_experts).float()
# Run reference path with autograd on (forces original per-expert loop)
x_ref = x.detach().clone().requires_grad_(True)
out_ref = moe(x_ref, router_logits)
# Run batched no-grad fast path; new code path should be triggered
with torch.no_grad():
out_fast = moe(x, router_logits)
self.assertTrue(
torch.allclose(out_ref, out_fast, atol=1e-5, rtol=1e-5),
msg="MoE batched no-grad experts path must be numerically equivalent to per-expert loop",
)
def test_moe_expert_cache_counters(self):
"""
Validate cache hit/miss behavior for the no-grad expert fast path:
- First no-grad forward should be a cache miss for both w13 and w2_t
- Second no-grad forward with identical weights should be a cache hit
- Counter values should be updated via _MOE_FASTPATH_COUNTERS
"""
torch.manual_seed(1)
device = "cuda" if torch.cuda.is_available() else "cpu"
d_model = 4
d_hidden = 8
num_experts = 2
tokens_per_expert = 2
batch_tokens = num_experts * tokens_per_expert
moe = self._build_small_moe(d_model, d_hidden, num_experts).to(device)
# Clone weights so we can later mutate their version to force misses.
# These attribute names may need to be adjusted to match the implementation.
w1_before = moe.experts[0].w1.weight
w3_before = moe.experts[0].w3.weight
w2_before = moe.experts[0].w2.weight
x = torch.randn(batch_tokens, d_model, device=device)
expert_indices = torch.arange(num_experts, device=device).repeat_interleave(tokens_per_expert)
router_logits = F.one_hot(expert_indices, num_classes=num_experts).float()
# First no-grad forward: expect cache misses for both w13 and w2_t
with torch.no_grad():
moe(x, router_logits)
# Depending on how many times the fastpath is invoked internally,
# we assert "at least" one miss.
if "cached_w13_miss" in _MOE_FASTPATH_COUNTERS:
self.assertGreaterEqual(_MOE_FASTPATH_COUNTERS["cached_w13_miss"], 1)
if "cached_w2_t_miss" in _MOE_FASTPATH_COUNTERS:
self.assertGreaterEqual(_MOE_FASTPATH_COUNTERS["cached_w2_t_miss"], 1)
# Second no-grad forward with identical weights: expect hits
with torch.no_grad():
moe(x, router_logits)
if "cached_w13_hit" in _MOE_FASTPATH_COUNTERS:
self.assertGreaterEqual(_MOE_FASTPATH_COUNTERS["cached_w13_hit"], 1)
if "cached_w2_t_hit" in _MOE_FASTPATH_COUNTERS:
self.assertGreaterEqual(_MOE_FASTPATH_COUNTERS["cached_w2_t_hit"], 1)
# Now mutate weights to invalidate caches and force misses based on tensor identity/version
with torch.no_grad():
w1_before.add_(0.1)
w3_before.add_(0.1)
w2_before.add_(0.1)
prev_counters = dict(_MOE_FASTPATH_COUNTERS)
with torch.no_grad():
moe(x, router_logits)
# After weight mutation, we expect additional misses
if "cached_w13_miss" in _MOE_FASTPATH_COUNTERS:
self.assertGreater(
_MOE_FASTPATH_COUNTERS["cached_w13_miss"],
prev_counters.get("cached_w13_miss", 0),
)
if "cached_w2_t_miss" in _MOE_FASTPATH_COUNTERS:
self.assertGreater(
_MOE_FASTPATH_COUNTERS["cached_w2_t_miss"],
prev_counters.get("cached_w2_t_miss", 0),
)
```
The above tests assume:
1. `MixtureOfExperts` exists in `torchtitan.models.common.moe` and has a constructor like:
`MixtureOfExperts(model_dim: int, ffn_hidden_size: int, num_experts: int, ...)` and a forward signature `forward(x, router_logits)`.
2. A submodule layout like `moe.experts[expert_index].w1`, `.w2`, `.w3` for FFN weights.
3. `_MOE_FASTPATH_COUNTERS` exposes the keys `"cached_w13_hit"`, `"cached_w13_miss"`, `"cached_w2_t_hit"`, and `"cached_w2_t_miss"`.
You may need to:
- Adjust `_build_small_moe` to match the actual constructor signature of `MixtureOfExperts` and any required router/dispatcher configuration.
- Adapt how `router_logits` / expert assignments are passed into `forward` to line up with your implementation (e.g., using `TokenChoiceTopKRouter` directly or passing a routing map).
- Update weight attribute names (`w1`, `w2`, `w3`) and expert access (`moe.experts[...]`) to the actual structure.
- If `_run_experts_for_loop` is exposed differently, you can refine `test_batched_nograd_experts_numerical_equivalence` to call it directly with explicit `num_tokens_per_expert` instead of going through the public `forward`.
</issue_to_address>
### Comment 5
<location path="torchtitan/models/common/token_dispatcher.py" line_range="966" />
<code_context>
self.pad_multiple = config.pad_multiple
- def dispatch(self, x, top_scores, selected_experts_indices):
+ def dispatch(
+ self,
+ x,
</code_context>
<issue_to_address>
**issue (complexity):** Consider encapsulating equal-split routing, force-load-balance fast paths, and fastpath-debug bookkeeping into small helpers/strategy objects so AllToAllTokenDispatcher.dispatch/combine stay linear and easier to read.
The main complexity spike is in `AllToAllTokenDispatcher.dispatch/combine`, where multiple orthogonal concerns are interwoven and encoded via booleans/metadata fields. You can keep all functionality while reducing cognitive load by **encapsulating equal-split + fastpath state** and **separating routing strategies**.
### 1. Encapsulate equal A2A split behavior
Right now `equal_a2a_split_size`, `normal_equal_a2a_padding`, `dispatch_input_splits`, `dispatch_output_splits`, `input_splits_list`, `output_splits_list`, and the pad/compact helpers are threaded across dispatch/combine and encoded in metadata. This can be encapsulated into a small helper/struct that owns the state and provides symmetric dispatch/combine transforms.
Example sketch:
```python
@dataclass
class EqualA2ARouting:
enabled: bool
equal_split_size: int | None = None
normal_equal_padding: bool = False
# These are the "logical" splits; we can expose them if caller needs them.
input_splits: list[int] | None = None
output_splits: list[int] | None = None
def prepare_dispatch(
self,
routed_input: torch.Tensor,
input_splits: list[int],
output_splits: list[int],
*,
device: torch.device,
dtype: torch.dtype,
ep_mesh: DeviceMesh,
sp_size: int,
score_before_experts: bool,
) -> tuple[torch.Tensor, list[int] | None, list[int] | None]:
"""
Returns:
dispatched_tensor, dispatch_output_splits, dispatch_input_splits
"""
if not self.enabled:
self.input_splits = input_splits
self.output_splits = output_splits
return routed_input, output_splits, input_splits
# decide equal_split_size + normal_equal_padding here
# self.equal_split_size, self.normal_equal_padding = ...
padded = self._pad_to_equal_splits(
routed_input, input_splits, self.equal_split_size
)
self.input_splits = input_splits
self.output_splits = output_splits
return padded, None, None # equal splits => None for autograd wrapper
def prepare_combine(
self,
routed_output: torch.Tensor,
) -> tuple[torch.Tensor, list[int] | None, list[int] | None]:
"""
Returns:
tensor_for_all_to_all, combine_output_splits, combine_input_splits
"""
if not self.enabled or self.equal_split_size is None:
return routed_output, self.output_splits, self.input_splits
padded = self._pad_to_equal_splits(
routed_output, self.output_splits, self.equal_split_size
)
return padded, None, None
def post_dispatch(self, x: torch.Tensor) -> torch.Tensor:
if not self.enabled or self.equal_split_size is None:
return x
return self._compact_equal_splits(x, self.output_splits, self.equal_split_size)
def post_combine(self, x: torch.Tensor) -> torch.Tensor:
if not self.enabled or self.equal_split_size is None:
return x
return self._compact_equal_splits(x, self.input_splits, self.equal_split_size)
@staticmethod
def _pad_to_equal_splits(x: torch.Tensor, splits: list[int], equal_split: int) -> torch.Tensor:
# existing padding logic here
@staticmethod
def _compact_equal_splits(x: torch.Tensor, splits: list[int], equal_split: int) -> torch.Tensor:
# existing compact logic here
```
Then `dispatch()` becomes more linear:
```python
def dispatch(...):
...
# compute num_tokens_per_expert, input_splits_list, output_splits_list
equal_route = EqualA2ARouting(
enabled=self._should_use_equal_a2a(input_splits_list, ...),
)
routed_input = x[token_indices_experts_sorted]
if self.score_before_experts:
routed_input = ...
routed_input, dispatch_out_splits, dispatch_in_splits = equal_route.prepare_dispatch(
routed_input,
input_splits_list,
output_splits_list,
device=num_tokens_per_expert.device,
dtype=num_tokens_per_expert.dtype,
ep_mesh=self.ep_mesh,
sp_size=self.sp_size,
score_before_experts=self.score_before_experts,
)
routed_input = all_to_all_single_autograd(
routed_input,
dispatch_out_splits,
dispatch_in_splits,
self.ep_mesh,
)
routed_input = equal_route.post_dispatch(routed_input)
# permute / permute_equal_counts
...
metadata = AllToAllDispatchMetadata(
...,
equal_a2a_split_size=equal_route.equal_split_size,
normal_equal_a2a_padding=equal_route.normal_equal_padding,
input_splits=equal_route.input_splits,
output_splits=equal_route.output_splits,
)
```
And `combine()` can mirror that with fewer ad-hoc flags:
```python
def combine(...):
...
routed_output = self._unpermute(
routed_output,
metadata.input_shape,
metadata.permuted_indices,
metadata.rank_major_shape,
)
equal_route = EqualA2ARouting(
enabled=metadata.equal_a2a_split_size is not None,
equal_split_size=metadata.equal_a2a_split_size,
normal_equal_padding=metadata.normal_equal_a2a_padding,
input_splits=metadata.input_splits,
output_splits=metadata.output_splits,
)
routed_output, combine_out_splits, combine_in_splits = equal_route.prepare_combine(
routed_output
)
routed_output = all_to_all_single_autograd(
routed_output,
combine_out_splits,
combine_in_splits,
self.ep_mesh,
)
routed_output = equal_route.post_combine(routed_output)
...
```
That removes a lot of local booleans (`direct_equal_split_route`, `normal_equal_a2a_padding`, `dispatch_input_splits`, `dispatch_output_splits`) and makes the symmetry between dispatch/combine explicit and local to one helper.
### 2. Separate force-load-balance routing as a strategy
`use_force_load_balance_fast_path`, `_force_load_balance_splits`, `_force_load_balance_sort_indices`, `_force_load_balance_equal_split_routed_input`, plus the implicit protocol using `uniform_local_count` + `rank_major_shape` introduce another cross-cutting mode. You can keep the optimization but isolate its decision + logic in a small “strategy” object or helper, so `dispatch()` only selects it and asks for results, instead of mixing both paths inline.
Example sketch:
```python
class ForceLoadBalanceRouting:
def __init__(self, dispatcher: "AllToAllTokenDispatcher"):
self.dispatcher = dispatcher
def is_enabled(self, *, sp_size: int) -> bool:
return (
self.dispatcher.force_load_balance
and sp_size == 1
and not torch.compiler.is_compiling()
)
def sort_indices(
self, selected_experts_indices: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
return self.dispatcher._force_load_balance_sort_indices(
selected_experts_indices
)
def compute_splits(
self,
num_tokens_per_expert: torch.Tensor,
num_tokens: int,
ep_size: int,
ep_rank: int,
) -> tuple[torch.Tensor, list[int], list[int], int | None]:
return self.dispatcher._force_load_balance_splits(
num_tokens,
ep_size,
ep_rank,
num_tokens_per_expert.dtype,
num_tokens_per_expert.device,
)
def build_equal_split_buffer(
self,
x: torch.Tensor,
ep_size: int,
num_local_experts: int,
input_splits: list[int],
equal_split_size: int,
) -> torch.Tensor:
return self.dispatcher._force_load_balance_equal_split_routed_input(
x, ep_size, num_local_experts, input_splits, equal_split_size
)
```
Usage in `dispatch()`:
```python
def dispatch(...):
...
flb = ForceLoadBalanceRouting(self)
use_flb = flb.is_enabled(sp_size=self.sp_size)
if use_flb:
assignment_indices_experts_sorted, token_indices_experts_sorted = flb.sort_indices(
selected_experts_indices
)
else:
assignment_indices_experts_sorted = torch.argsort(
selected_experts_indices.view(-1), stable=True
)
token_indices_experts_sorted = assignment_indices_experts_sorted // self.top_k
top_scores_experts_sorted = top_scores.view(-1)[assignment_indices_experts_sorted]
if use_flb:
(
num_tokens_per_expert_group,
input_splits_list,
output_splits_list,
uniform_local_count,
) = flb.compute_splits(
num_tokens_per_expert,
original_num_tokens,
ep_size,
self.ep_mesh.get_local_rank(),
)
else:
# existing all_to_all_single path
...
```
This keeps the “normal” path readable and clearly separates the specialized force-load-balance behavior, while preserving all current fastpaths.
### 3. Isolate fastpath debug accounting
Environment checks + counter updates are currently sprinkled throughout dispatch/combine, which adds noise on the hot path. You can wrap them behind small helpers that return identity/decorated callables so control flow stays clean.
For example, instead of:
```python
if not self.score_before_experts:
if routed_output.dtype == torch.bfloat16:
_record_moe_fastpath("score_after_experts_bf16")
else:
_record_moe_fastpath("score_after_experts")
routed_output = routed_output * ...
```
factor a tiny helper:
```python
def _record_score_path(name: str, tensor: torch.Tensor) -> None:
if tensor.dtype == torch.bfloat16:
_record_moe_fastpath(f"{name}_bf16")
else:
_record_moe_fastpath(name)
# call site
if not self.score_before_experts:
_record_score_path("score_after_experts", routed_output)
routed_output = routed_output * ...
```
or a decorator-style wrapper for the scatter-add path:
```python
def _scatter_add_with_debug(self, out, index, src, protected_input):
_record_moe_fastpath("scatter_add")
return _scatter_add_forward_or_autograd(out, index, src, protected_input)
```
and then:
```python
out = self._scatter_add_with_debug(out, scatter_index, routed_output, x)
```
This keeps the mathematical logic (“what is being computed”) separated from telemetry, making the main control flow easier to follow.
---
These changes keep all current functionality and fast paths, but:
* reduce the number of flags and cross-method protocols exposed in `dispatch/combine`,
* make equal-split and force-load-balance behavior self-contained and symmetric, and
* move debug accounting off the critical visual path.
</issue_to_address>Help me be more useful! Please click 👍 or 👎 on each comment and I'll use the feedback to improve your reviews.
| self.pad_multiple = config.pad_multiple | ||
|
|
||
| def dispatch(self, x, top_scores, selected_experts_indices): | ||
| def dispatch( |
There was a problem hiding this comment.
issue (complexity): Consider encapsulating equal-split routing, force-load-balance fast paths, and fastpath-debug bookkeeping into small helpers/strategy objects so AllToAllTokenDispatcher.dispatch/combine stay linear and easier to read.
The main complexity spike is in AllToAllTokenDispatcher.dispatch/combine, where multiple orthogonal concerns are interwoven and encoded via booleans/metadata fields. You can keep all functionality while reducing cognitive load by encapsulating equal-split + fastpath state and separating routing strategies.
1. Encapsulate equal A2A split behavior
Right now equal_a2a_split_size, normal_equal_a2a_padding, dispatch_input_splits, dispatch_output_splits, input_splits_list, output_splits_list, and the pad/compact helpers are threaded across dispatch/combine and encoded in metadata. This can be encapsulated into a small helper/struct that owns the state and provides symmetric dispatch/combine transforms.
Example sketch:
@dataclass
class EqualA2ARouting:
enabled: bool
equal_split_size: int | None = None
normal_equal_padding: bool = False
# These are the "logical" splits; we can expose them if caller needs them.
input_splits: list[int] | None = None
output_splits: list[int] | None = None
def prepare_dispatch(
self,
routed_input: torch.Tensor,
input_splits: list[int],
output_splits: list[int],
*,
device: torch.device,
dtype: torch.dtype,
ep_mesh: DeviceMesh,
sp_size: int,
score_before_experts: bool,
) -> tuple[torch.Tensor, list[int] | None, list[int] | None]:
"""
Returns:
dispatched_tensor, dispatch_output_splits, dispatch_input_splits
"""
if not self.enabled:
self.input_splits = input_splits
self.output_splits = output_splits
return routed_input, output_splits, input_splits
# decide equal_split_size + normal_equal_padding here
# self.equal_split_size, self.normal_equal_padding = ...
padded = self._pad_to_equal_splits(
routed_input, input_splits, self.equal_split_size
)
self.input_splits = input_splits
self.output_splits = output_splits
return padded, None, None # equal splits => None for autograd wrapper
def prepare_combine(
self,
routed_output: torch.Tensor,
) -> tuple[torch.Tensor, list[int] | None, list[int] | None]:
"""
Returns:
tensor_for_all_to_all, combine_output_splits, combine_input_splits
"""
if not self.enabled or self.equal_split_size is None:
return routed_output, self.output_splits, self.input_splits
padded = self._pad_to_equal_splits(
routed_output, self.output_splits, self.equal_split_size
)
return padded, None, None
def post_dispatch(self, x: torch.Tensor) -> torch.Tensor:
if not self.enabled or self.equal_split_size is None:
return x
return self._compact_equal_splits(x, self.output_splits, self.equal_split_size)
def post_combine(self, x: torch.Tensor) -> torch.Tensor:
if not self.enabled or self.equal_split_size is None:
return x
return self._compact_equal_splits(x, self.input_splits, self.equal_split_size)
@staticmethod
def _pad_to_equal_splits(x: torch.Tensor, splits: list[int], equal_split: int) -> torch.Tensor:
# existing padding logic here
@staticmethod
def _compact_equal_splits(x: torch.Tensor, splits: list[int], equal_split: int) -> torch.Tensor:
# existing compact logic hereThen dispatch() becomes more linear:
def dispatch(...):
...
# compute num_tokens_per_expert, input_splits_list, output_splits_list
equal_route = EqualA2ARouting(
enabled=self._should_use_equal_a2a(input_splits_list, ...),
)
routed_input = x[token_indices_experts_sorted]
if self.score_before_experts:
routed_input = ...
routed_input, dispatch_out_splits, dispatch_in_splits = equal_route.prepare_dispatch(
routed_input,
input_splits_list,
output_splits_list,
device=num_tokens_per_expert.device,
dtype=num_tokens_per_expert.dtype,
ep_mesh=self.ep_mesh,
sp_size=self.sp_size,
score_before_experts=self.score_before_experts,
)
routed_input = all_to_all_single_autograd(
routed_input,
dispatch_out_splits,
dispatch_in_splits,
self.ep_mesh,
)
routed_input = equal_route.post_dispatch(routed_input)
# permute / permute_equal_counts
...
metadata = AllToAllDispatchMetadata(
...,
equal_a2a_split_size=equal_route.equal_split_size,
normal_equal_a2a_padding=equal_route.normal_equal_padding,
input_splits=equal_route.input_splits,
output_splits=equal_route.output_splits,
)And combine() can mirror that with fewer ad-hoc flags:
def combine(...):
...
routed_output = self._unpermute(
routed_output,
metadata.input_shape,
metadata.permuted_indices,
metadata.rank_major_shape,
)
equal_route = EqualA2ARouting(
enabled=metadata.equal_a2a_split_size is not None,
equal_split_size=metadata.equal_a2a_split_size,
normal_equal_padding=metadata.normal_equal_a2a_padding,
input_splits=metadata.input_splits,
output_splits=metadata.output_splits,
)
routed_output, combine_out_splits, combine_in_splits = equal_route.prepare_combine(
routed_output
)
routed_output = all_to_all_single_autograd(
routed_output,
combine_out_splits,
combine_in_splits,
self.ep_mesh,
)
routed_output = equal_route.post_combine(routed_output)
...That removes a lot of local booleans (direct_equal_split_route, normal_equal_a2a_padding, dispatch_input_splits, dispatch_output_splits) and makes the symmetry between dispatch/combine explicit and local to one helper.
2. Separate force-load-balance routing as a strategy
use_force_load_balance_fast_path, _force_load_balance_splits, _force_load_balance_sort_indices, _force_load_balance_equal_split_routed_input, plus the implicit protocol using uniform_local_count + rank_major_shape introduce another cross-cutting mode. You can keep the optimization but isolate its decision + logic in a small “strategy” object or helper, so dispatch() only selects it and asks for results, instead of mixing both paths inline.
Example sketch:
class ForceLoadBalanceRouting:
def __init__(self, dispatcher: "AllToAllTokenDispatcher"):
self.dispatcher = dispatcher
def is_enabled(self, *, sp_size: int) -> bool:
return (
self.dispatcher.force_load_balance
and sp_size == 1
and not torch.compiler.is_compiling()
)
def sort_indices(
self, selected_experts_indices: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
return self.dispatcher._force_load_balance_sort_indices(
selected_experts_indices
)
def compute_splits(
self,
num_tokens_per_expert: torch.Tensor,
num_tokens: int,
ep_size: int,
ep_rank: int,
) -> tuple[torch.Tensor, list[int], list[int], int | None]:
return self.dispatcher._force_load_balance_splits(
num_tokens,
ep_size,
ep_rank,
num_tokens_per_expert.dtype,
num_tokens_per_expert.device,
)
def build_equal_split_buffer(
self,
x: torch.Tensor,
ep_size: int,
num_local_experts: int,
input_splits: list[int],
equal_split_size: int,
) -> torch.Tensor:
return self.dispatcher._force_load_balance_equal_split_routed_input(
x, ep_size, num_local_experts, input_splits, equal_split_size
)Usage in dispatch():
def dispatch(...):
...
flb = ForceLoadBalanceRouting(self)
use_flb = flb.is_enabled(sp_size=self.sp_size)
if use_flb:
assignment_indices_experts_sorted, token_indices_experts_sorted = flb.sort_indices(
selected_experts_indices
)
else:
assignment_indices_experts_sorted = torch.argsort(
selected_experts_indices.view(-1), stable=True
)
token_indices_experts_sorted = assignment_indices_experts_sorted // self.top_k
top_scores_experts_sorted = top_scores.view(-1)[assignment_indices_experts_sorted]
if use_flb:
(
num_tokens_per_expert_group,
input_splits_list,
output_splits_list,
uniform_local_count,
) = flb.compute_splits(
num_tokens_per_expert,
original_num_tokens,
ep_size,
self.ep_mesh.get_local_rank(),
)
else:
# existing all_to_all_single path
...This keeps the “normal” path readable and clearly separates the specialized force-load-balance behavior, while preserving all current fastpaths.
3. Isolate fastpath debug accounting
Environment checks + counter updates are currently sprinkled throughout dispatch/combine, which adds noise on the hot path. You can wrap them behind small helpers that return identity/decorated callables so control flow stays clean.
For example, instead of:
if not self.score_before_experts:
if routed_output.dtype == torch.bfloat16:
_record_moe_fastpath("score_after_experts_bf16")
else:
_record_moe_fastpath("score_after_experts")
routed_output = routed_output * ...factor a tiny helper:
def _record_score_path(name: str, tensor: torch.Tensor) -> None:
if tensor.dtype == torch.bfloat16:
_record_moe_fastpath(f"{name}_bf16")
else:
_record_moe_fastpath(name)
# call site
if not self.score_before_experts:
_record_score_path("score_after_experts", routed_output)
routed_output = routed_output * ...or a decorator-style wrapper for the scatter-add path:
def _scatter_add_with_debug(self, out, index, src, protected_input):
_record_moe_fastpath("scatter_add")
return _scatter_add_forward_or_autograd(out, index, src, protected_input)and then:
out = self._scatter_add_with_debug(out, scatter_index, routed_output, x)This keeps the mathematical logic (“what is being computed”) separated from telemetry, making the main control flow easier to follow.
These changes keep all current functionality and fast paths, but:
- reduce the number of flags and cross-method protocols exposed in
dispatch/combine, - make equal-split and force-load-balance behavior self-contained and symmetric, and
- move debug accounting off the critical visual path.
|
@saforem2 do you want me to fix this sourcery junk or will you handle it? Some of them are good points. |
|
😂 @nscottnichols honestly whatever you think; like you said, some of them are good points but some of them are either nitpicky or incorrect My only actual question would be, can we move these changes to be self-contained in the i.e. could we move the changes from maybe: models/common/moe.py --> experiments/ezpz/moe/model.py
models/common/token_dispatcher.py --> experiments/ezpz/moe/token_dispatcher.py
ops/scatter_add.py --> experiments/ezpz/moe/token_dispatcher.pyI guess this can be a PITA if it continues to grow; but if there are useful changes outside of Also, incase you haven't seen this thread in #10, it looks like there are upstream changes that break things:
|
Let me think about how this can be done. There are definitely some parts that can be moved. |
|
Expert-side pieces folded into #12 (commit
Deferred to a follow-up: the token-dispatcher rewrite in |
I checked this and did not keep a code change here. For a 2D tensor, F.pad(piece, (0, 0, 0, n)) pads rows at the end, which matches the existing unit test expectations. I also tried replacing it with explicit zero-row concatenation, but that regressed the benchmark from ~29.9 ms to 31.514 ms, so I reverted that experiment. |
|
Oh, I also didn't change the "shares storage" guard. The current guard is intentionally conservative for the known alias case we hit: shared_experts(x) returning x exactly. If we want broader partial-overlap detection for arbitrary views, that is worth a separate patch because it can affect the no-grad in-place scatter fast path. |
|
FYI — #13 landed the upstream resync onto When rebasing this PR onto merged Withdrawing the prior #12 which had attempted to bundle these in — keeping #11 as the canonical home for the optimizations. |
Added some optimizations for the MoE path, wanted to PR before our forks drift too far apart
Summary by Sourcery
Optimize Mixture-of-Experts routing, dispatch, and expert execution paths for better load balancing and no-grad performance while preserving numerical behavior.
New Features:
Enhancements:
Tests: