[megatron] reduce R3 router replay memory with per-trajectory traces #6860
tntnnlrw wants to merge 2 commits into
Code Review
This pull request introduces a per-trajectory routing trace fast path ('R3 speedup') to optimize memory and performance in the agent loop and padding utilities, including functions to narrow routed expert dtypes and handle unpadded per-trajectory data. The review feedback highlights critical performance bottlenecks caused by synchronous GPU-to-CPU transfers (device-to-host copies). Specifically, it recommends avoiding .item() calls on GPU tensors by checking .is_cuda in narrow_routed_experts, converting seqlens to a CPU list before looping in _align_r3_per_traj_tensors, and removing a redundant narrowing operation on the concatenated GPU tensor flat.
    def narrow_routed_experts(routed_experts):
        """Convert routed expert ids to the smallest integer dtype that preserves their values."""
        if routed_experts is None:
            return None

        if isinstance(routed_experts, torch.Tensor):
            if routed_experts.numel() == 0:
                return routed_experts.to(torch.uint8)
If routed_experts is a PyTorch tensor on GPU, calling .min().item() and .max().item() will trigger synchronous device-to-host copies, blocking the CPU thread. Add a check to return the tensor immediately if it is already on a CUDA device, as it should have already been narrowed on CPU before being moved to GPU.
Suggested change:

    def narrow_routed_experts(routed_experts):
        """Convert routed expert ids to the smallest integer dtype that preserves their values."""
        if routed_experts is None:
            return None
        if isinstance(routed_experts, torch.Tensor):
            if routed_experts.is_cuda:
                return routed_experts
            if routed_experts.numel() == 0:
                return routed_experts.to(torch.uint8)
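The docstring's rule ("the smallest integer dtype that preserves their values") can be illustrated without torch. The following is only a sketch under stated assumptions: the helper name `smallest_int_dtype`, the candidate dtype list, and the use of dtype names as strings are hypothetical and are not taken from the PR. In the torch version, the bounds would come from `.min()` and `.max()`, which is why the review asks for that step to happen on CPU.

```python
# Hypothetical candidate dtypes, narrowest first: (name, min value, max value).
# This list is an assumption for illustration, not the PR's actual selection.
_CANDIDATES = [
    ("uint8", 0, 2**8 - 1),
    ("int16", -(2**15), 2**15 - 1),
    ("int32", -(2**31), 2**31 - 1),
    ("int64", -(2**63), 2**63 - 1),
]


def smallest_int_dtype(lo, hi):
    """Return the name of the narrowest candidate dtype that can hold [lo, hi]."""
    if lo > hi:
        raise ValueError("lo must not exceed hi")
    for name, dtype_min, dtype_max in _CANDIDATES:
        if dtype_min <= lo and hi <= dtype_max:
            return name
    raise ValueError(f"range [{lo}, {hi}] does not fit any candidate dtype")


# Expert ids for a model with 256 experts span 0..255 and fit in one byte.
print(smallest_int_dtype(0, 255))
# One more expert pushes the range past uint8.
print(smallest_int_dtype(0, 256))
```

Because the bounds are plain Python integers here, the selection itself involves no device synchronization; only computing the bounds from a GPU tensor would.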
    def _align_r3_per_traj_tensors(per_traj_list, seqlens, target_device):
        aligned = []
        for i, item in enumerate(per_traj_list):
            tensor = _r3_per_traj_item_to_tensor(item).to(device=target_device)
            target_len = int(seqlens[i].item())
Calling seqlens[i].item() inside a loop over the batch size triggers a synchronous device-to-host copy for every single iteration when seqlens is on GPU. This introduces a severe performance bottleneck. Convert seqlens to a CPU list once before the loop using seqlens.tolist() to avoid per-iteration GPU synchronization.
Suggested change:

    def _align_r3_per_traj_tensors(per_traj_list, seqlens, target_device):
        aligned = []
        seqlens_cpu = seqlens.tolist()
        for i, item in enumerate(per_traj_list):
            tensor = _r3_per_traj_item_to_tensor(item).to(device=target_device)
            target_len = int(seqlens_cpu[i])

    if r3_speedup_enabled() and "routed_experts_per_traj" in data.keys():
        per_traj_stack = data.pop("routed_experts_per_traj")
        per_traj_list = normalize_r3_per_traj_list(per_traj_stack)
        tensors = _align_r3_per_traj_tensors(per_traj_list, attention_mask.sum(dim=-1), input_ids_rmpad.device)
        flat = narrow_routed_experts(torch.cat(tensors, dim=0))
Since per_traj_list is already normalized and narrowed on CPU via normalize_r3_per_traj_list, the concatenated tensor flat is already narrowed. Calling narrow_routed_experts on the GPU tensor flat is redundant and triggers a GPU synchronization. You can safely concatenate the tensors directly.
Suggested change:

    if r3_speedup_enabled() and "routed_experts_per_traj" in data.keys():
        per_traj_stack = data.pop("routed_experts_per_traj")
        per_traj_list = normalize_r3_per_traj_list(per_traj_stack)
        tensors = _align_r3_per_traj_tensors(per_traj_list, attention_mask.sum(dim=-1), input_ids_rmpad.device)
        flat = torch.cat(tensors, dim=0)
Pushed follow-up commit
Verified locally with:
Hi @wuxibin89, could you take a look at this PR when you have a chance? I kept the R3 per-trajectory trace path opt-in, added CPU-only tests, and pushed

Hi @HollowMan6, sorry to bother. Since you recently reviewed the Megatron router replay fallback fix in #6653, could you take a look at this related opt-in R3 per-trajectory trace path when you have bandwidth? The default path is unchanged and the PR includes CPU-only coverage.
Summary
- Add an opt-in VERL_R3_SPEEDUP=1 path that keeps R3 rollout routing traces as per-trajectory [actual_len, num_layers, topk] arrays instead of materializing dense [batch, seq_len, num_layers, topk] tensors in the agent loop.
- The default path is unchanged when VERL_R3_SPEEDUP is unset or disabled.

Motivation
For long-context R3 rollouts, the dense padded routed-experts tensor can add avoidable memory pressure and copies before the current nested/no-padding Megatron path. This PR keeps the routing traces in their natural per-trajectory form until the no-padding conversion step.
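To make the memory argument concrete, here is a rough back-of-the-envelope sketch comparing the two layouts. It is illustrative only: the helper names and every number (batch of 4, padded length 32768, 48 layers, top-8 routing, 1 byte per id) are assumptions chosen for the example, not measurements from this PR.

```python
def dense_trace_bytes(batch, seq_len, num_layers, topk, itemsize):
    """Bytes for a padded [batch, seq_len, num_layers, topk] tensor."""
    return batch * seq_len * num_layers * topk * itemsize


def per_traj_trace_bytes(actual_lens, num_layers, topk, itemsize):
    """Bytes for a list of unpadded [actual_len, num_layers, topk] arrays."""
    return sum(actual_lens) * num_layers * topk * itemsize


# Hypothetical batch: one long trajectory forces every row to be padded to 32768.
actual_lens = [1024, 4096, 32768, 2048]
seq_len = max(actual_lens)
num_layers, topk, itemsize = 48, 8, 1  # itemsize=1 assumes ids narrowed to uint8

dense = dense_trace_bytes(len(actual_lens), seq_len, num_layers, topk, itemsize)
per_traj = per_traj_trace_bytes(actual_lens, num_layers, topk, itemsize)

print(f"dense:    {dense / 2**20:.1f} MiB")
print(f"per-traj: {per_traj / 2**20:.1f} MiB")
print(f"ratio:    {dense / per_traj:.2f}x")
```

The gap grows with the spread of trajectory lengths: when every trajectory is as long as the padded length the two layouts cost the same, and the saving comes entirely from the padding that is never materialized.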
Tests
    PYTHONPATH=. pytest tests/utils/test_padding_on_cpu.py -q
    PYTHONPATH=. pytest tests/experimental/agent_loop/test_agent_loop_extra_fields_schema_on_cpu.py -q
    python -m ruff check verl/experimental/agent_loop/agent_loop.py verl/workers/utils/padding.py tests/utils/test_padding_on_cpu.py tests/experimental/agent_loop/test_agent_loop_extra_fields_schema_on_cpu.py
    python -m py_compile verl/experimental/agent_loop/agent_loop.py verl/workers/utils/padding.py tests/utils/test_padding_on_cpu.py tests/experimental/agent_loop/test_agent_loop_extra_fields_schema_on_cpu.py
    git diff --check

Checklist