perf(vllm/mx_refit): cache NIXL-registered dest buffers across refit cycles #10901
KavinKrishnan wants to merge 2 commits into
Walkthrough
Introduces a new
Changes
MX Refit Worker Extension
Estimated code review effort
🎯 5 (Critical) | ⏱️ ~120 minutes
🚥 Pre-merge checks | ✅ 3 | ❌ 2
❌ Failed checks (2 warnings)
✅ Passed checks (3 passed)
Actionable comments posted: 2
🧹 Nitpick comments (1)
components/src/dynamo/vllm/mx_refit/extension.py (1)
404-410: 📐 Maintainability & Code Quality | 🔵 Trivial | 💤 Low value
Move stdlib imports to module top.
`import time as _time` (Line 404) and `import threading` (Line 777) are standard-library imports inside function bodies. The lazy-import rationale only applies to the optional modelexpress/vllm deps; stdlib imports have no load-cost justification and hide dependencies. Hoist them to the top of the module.
As per coding guidelines: "Keep imports at the top of the file; always flag `import` statements inside function bodies, methods, or classes" (exception only for optional dependencies). As per path instructions: ".ai/python-guidelines.md: all imports are at module top".
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@components/src/dynamo/vllm/mx_refit/extension.py` around lines 404 - 410, The standard library imports `import time as _time` (found near line 404) and `import threading` (found near line 777) are currently located inside function bodies rather than at the module top level. Move both of these stdlib imports to the top of the module file, outside of any function or class definitions, since the lazy-import rationale only applies to optional dependencies like modelexpress and vllm, not standard library modules. This follows the coding guideline that all imports should be at module top and ensures dependencies are clearly visible.
Sources: Coding guidelines, Path instructions
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@components/src/dynamo/vllm/mx_refit/extension.py`:
- Line 1: The file extension.py has formatting and import ordering issues that
are not compliant with the project's code style standards as enforced by black
and isort. To fix this, run the pre-commit hooks on the entire repository using
`pre-commit run --all-files` or alternatively run `ruff format` followed by
`isort` on the specific file to automatically reformat the code and organize
imports correctly. Commit the resulting changes before merging.
- Around line 615-641: The bug is that the cleanup guard at line 635 only
removes freshly allocated buffers from plan_dests when cached_plan_dests is None
(no warm cycle). On warm cycles, freshly allocated row-parallel buffers that
fail the contiguity check remain in plan_dests and get consumed with stale data
instead of being routed to v0. Track which tensor names were newly allocated in
the current cycle (you can use a set or extend the existing
newly_allocated_this_cycle counter to track names), then modify the condition at
line 635 to also check if plan.tensor_name is in the current cycle's newly
allocated set. This ensures freshly allocated non-contiguous buffers are removed
regardless of whether we're in a warm cycle, allowing them to be properly
handled by the v0 path via assemble_into_destination.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: d9e26229-e4e4-4b73-b507-ff3ec5a9ee81
📒 Files selected for processing (2)
components/src/dynamo/vllm/mx_refit/__init__.py
components/src/dynamo/vllm/mx_refit/extension.py
| @@ -0,0 +1,864 @@ | |||
| # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | |||
📐 Maintainability & Code Quality | 🟡 Minor | ⚡ Quick win
Run black and isort — pre-commit failed in CI.
The pre-merge pipeline reports both black (reformatting) and isort (import ordering) modified this file. Run pre-commit run --all-files (or ruff format + isort) and commit the result before merge.
🧰 Tools
🪛 GitHub Actions: Pre Merge / 8_pre-commit.txt
[error] 1-1: pre-commit hook 'isort' failed. Files were modified by this hook (Fixing /home/runner/work/dynamo/dynamo/components/src/dynamo/vllm/mx_refit/extension.py). Run 'pre-commit run isort --all-files' or apply the changes.
[error] 1-1: pre-commit hook 'black' failed. Files were modified by this hook (reformatted components/src/dynamo/vllm/mx_refit/extension.py). Run 'pre-commit run black --all-files' or apply the changes.
🪛 GitHub Actions: Pre Merge / pre-commit
[error] 1-1: pre-commit hook 'isort' failed (files were modified by this hook). Fix applied to the file; re-run the check after committing changes.
[error] 1-1: pre-commit hook 'black' failed (files were modified by this hook). Reformatting was applied to the file; re-run the check after committing changes.
Source: Pipeline failures
    if plan.tensor_name in plan_dests:
        dest = plan_dests[plan.tensor_name]
    else:
        dest = torch.empty(plan.target_shape, dtype=dt, device=device)
        plan_dests[plan.tensor_name] = dest
        newly_allocated_this_cycle += 1
    axis = 1 if plan.assembly == "concat_dim1" else 0
    routed_v1 = True
    for src in plan.sources:
        target_lo, target_hi = src.target_local_range
        dest_view = dest.narrow(axis, target_lo, target_hi - target_lo)
        if not dest_view.is_contiguous():
            routed_v1 = False
            break
        v1_batches[src.mx_source_id].append(
            (plan.tensor_name, src.source_subslice, dest_view)
        )
    if not routed_v1:
        # Don't drop cached entries -- they may be valid for
        # other plans; just route this plan to v0.
        if cached_plan_dests is None:
            plan_dests.pop(plan.tensor_name, None)
        for sid in v1_batches:
            v1_batches[sid] = [
                r for r in v1_batches[sid] if r[0] != plan.tensor_name
            ]
        v0_plans.append(plan)
🗄️ Data Integrity & Integration | 🔴 Critical
Mixed-TP warm-cycle data integrity bug: freshly allocated v0 (row-parallel) destination buffers are left in plan_dests and consumed as empty tensors on subsequent cycles.
On a warm cycle (cached_plan_dests is not None), when a row-parallel plan is non-contiguous and freshly allocated at line 618, the pop guard at line 635 (if cached_plan_dests is None) evaluates to false, so the buffer is not removed. The buffer is then registered and consumed in the assembly loop at line 694, loading garbage/stale weights silently instead of taking the v0 assemble_into_destination path.
The fix must track per-cycle allocations, not just whether a cache existed:
Proposed fix
      plan_dests: dict[str, torch.Tensor] = cached_plan_dests or {}
      v1_batches: dict[str, list] = {c.ref.mx_source_id: [] for c in megatron_cands}
      v0_plans: list = []
      newly_allocated_this_cycle = 0
    + newly_allocated_names: set[str] = set()
      for plan in plans:
          ...
          if plan.tensor_name in plan_dests:
              dest = plan_dests[plan.tensor_name]
          else:
              dest = torch.empty(plan.target_shape, dtype=dt, device=device)
              plan_dests[plan.tensor_name] = dest
              newly_allocated_this_cycle += 1
    +         newly_allocated_names.add(plan.tensor_name)
          ...
          if not routed_v1:
    -         # Don't drop cached entries -- they may be valid for
    -         # other plans; just route this plan to v0.
    -         if cached_plan_dests is None:
    -             plan_dests.pop(plan.tensor_name, None)
    +         # Drop only buffers allocated this cycle; pre-existing
    +         # cached entries valid for other plans must stay.
    +         if plan.tensor_name in newly_allocated_names:
    +             plan_dests.pop(plan.tensor_name, None)
    +             newly_allocated_this_cycle -= 1
    +             newly_allocated_names.discard(plan.tensor_name)
Verify multi-cycle test coverage for row-parallel (axis-1, non-contiguous) tensors across two or more consecutive refit cycles. Current validation appears limited to single-cycle scenarios.
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
    if plan.tensor_name in plan_dests:
        dest = plan_dests[plan.tensor_name]
    else:
        dest = torch.empty(plan.target_shape, dtype=dt, device=device)
        plan_dests[plan.tensor_name] = dest
        newly_allocated_this_cycle += 1
        newly_allocated_names.add(plan.tensor_name)
    axis = 1 if plan.assembly == "concat_dim1" else 0
    routed_v1 = True
    for src in plan.sources:
        target_lo, target_hi = src.target_local_range
        dest_view = dest.narrow(axis, target_lo, target_hi - target_lo)
        if not dest_view.is_contiguous():
            routed_v1 = False
            break
        v1_batches[src.mx_source_id].append(
            (plan.tensor_name, src.source_subslice, dest_view)
        )
    if not routed_v1:
        # Drop only buffers allocated this cycle; pre-existing
        # cached entries valid for other plans must stay.
        if plan.tensor_name in newly_allocated_names:
            plan_dests.pop(plan.tensor_name, None)
            newly_allocated_this_cycle -= 1
            newly_allocated_names.discard(plan.tensor_name)
        for sid in v1_batches:
            v1_batches[sid] = [
                r for r in v1_batches[sid] if r[0] != plan.tensor_name
            ]
        v0_plans.append(plan)
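The warm-cycle case described in this comment can be reproduced without torch or NIXL. The sketch below is a simplified model of the routing loop, not the code in extension.py: `route_plans`, the `is_contiguous` callback, and the use of `bytearray` as a stand-in buffer are all hypothetical, and plans are reduced to tensor names.

```python
from typing import Callable, Optional


def route_plans(
    plan_names: list[str],
    cached_dests: Optional[dict[str, bytearray]],
    is_contiguous: Callable[[str], bool],
) -> tuple[dict[str, bytearray], list[str]]:
    """Return (plan_dests, v0_plans) using per-cycle allocation tracking."""
    plan_dests = cached_dests if cached_dests is not None else {}
    newly_allocated_names: set[str] = set()
    v0_plans: list[str] = []
    for name in plan_names:
        if name not in plan_dests:
            plan_dests[name] = bytearray(8)  # stand-in for torch.empty(...)
            newly_allocated_names.add(name)
        if not is_contiguous(name):
            # Drop only what this cycle allocated; cached entries stay.
            if name in newly_allocated_names:
                plan_dests.pop(name, None)
                newly_allocated_names.discard(name)
            v0_plans.append(name)
    return plan_dests, v0_plans


# Warm cycle: the cache already holds one column-parallel buffer.
warm_cache = {"col.weight": bytearray(8)}
dests, v0 = route_plans(
    ["col.weight", "row.weight"],
    warm_cache,
    is_contiguous=lambda name: name != "row.weight",
)
```

With the original `cached_plan_dests is None` guard, `row.weight` would remain in `dests` on this warm cycle; with name tracking it is dropped and appears only in `v0`.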
    full_shape[spec.shard_axis] = (
        axis_extent if layout.tp_rank == target_tp - 1
        else per_rank
    )
🔴 Last TP rank buffer gets full global axis extent instead of per-rank shard size
In matched-TP buffer allocation (line 524-527), the last TP rank's buffer dimension is set to axis_extent (the entire global dimension) instead of the correct per-rank remainder axis_extent - per_rank * (target_tp - 1). For TP=2, spec.target_shape is the global shape (per_rank × source_tp at extension.py:457-459). The code computes per_rank = axis_extent // target_tp, which is correct for non-last ranks. But for layout.tp_rank == target_tp - 1, it assigns the full axis_extent, which is target_tp times larger than needed.
Concrete example for TP=2
Source per-rank weight shape = [1024, 4096] on shard_axis=0.
Global shape (spec.target_shape) = [2048, 4096].
axis_extent = 2048, per_rank = 1024.
- Rank 0: buffer = [1024, 4096] ✓
- Rank 1 (last): buffer = [2048, 4096] ✗ — should be [1024, 4096]
Rank 1 allocates a buffer 2× the needed size. The RDMA receive_from at line 550 fills only the first 1024 rows from the source's per-rank shard, leaving 1024 rows of uninitialized data. When run_refit_cycle at line 565 processes pre_assembled_buffers=buffers, it may interpret the full buffer shape as valid data, producing incorrect HF weight tensors with garbage.
    - full_shape[spec.shard_axis] = (
    -     axis_extent if layout.tp_rank == target_tp - 1
    -     else per_rank
    - )
    + full_shape[spec.shard_axis] = (
    +     axis_extent - per_rank * (target_tp - 1)
    +     if layout.tp_rank == target_tp - 1
    +     else per_rank
    + )
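The remainder arithmetic in the suggestion can be checked in isolation. This is a minimal sketch of the same formula; `shard_extent` is a hypothetical helper name, not a function in extension.py.

```python
def shard_extent(axis_extent: int, tp_size: int, tp_rank: int) -> int:
    """Size of one rank's shard along the sharded axis.

    Non-last ranks get the floor share; the last rank gets whatever
    remains, so the shard sizes always sum to axis_extent.
    """
    if not 0 <= tp_rank < tp_size:
        raise ValueError(f"tp_rank {tp_rank} out of range for tp_size {tp_size}")
    per_rank = axis_extent // tp_size
    if tp_rank == tp_size - 1:
        return axis_extent - per_rank * (tp_size - 1)
    return per_rank


# The TP=2 example from the comment: global axis extent 2048.
even = [shard_extent(2048, 2, rank) for rank in range(2)]
# An extent that does not divide evenly, to show the last-rank remainder.
uneven = [shard_extent(10, 4, rank) for rank in range(4)]
```

For the TP=2 case both ranks come out at 1024 rows, which is the shape the comment says rank 1 should have had.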
    * synthetic TP=2 → TP=1 mixed-TP target-narrower: 8 / 8
    * synthetic TP=1 → TP=2 mixed-TP target-wider (v1 sliced-pull): 16 / 16
    """
    import time as _time
🟡 Stdlib imports (time, threading) placed inside method bodies instead of at module top
import time as _time at line 404 and import threading at line 777 are standard-library imports placed inside method bodies. The .ai/python-guidelines.md critical rule states: "Always flag any import statement that appears inside a function body, method, or class. Imports inside functions hide dependencies, make the module harder to understand at a glance, and can mask missing packages until a specific code path is hit at runtime." Other files in the same package (e.g. handlers.py:15-16, worker_factory.py:10) correctly import threading and time at the top of the file.
Prompt for agents
Move `import time as _time` (line 404) and `import threading` (line 777) to the top of the file alongside the other stdlib imports (gc, logging, os, traceback) at lines 27-30. These are standard library modules with no reason for lazy loading. The .ai/python-guidelines.md critical rule requires all imports at module level.
    if not getattr(self, "_mx_receiver", None):
        # Import here so workers that never refit via MX don't pay
        # the modelexpress import cost.
        from modelexpress import MxV2RefitReceiver
🟡 modelexpress and vllm sub-module imports inside method bodies instead of using try/except ImportError at module top
Multiple modelexpress imports (lines 95, 237, 405-410, 790) and a vllm sub-module import (lines 177-179) are placed inside method bodies. The .ai/python-guidelines.md critical rule states these must be flagged. For optional dependencies like modelexpress, the approved exception pattern is try/except ImportError at the top of the file (setting fallbacks to None), not lazy imports inside functions. The rule explicitly states: "Using try/except ImportError is the correct pattern when: An optional backend dependency... may not be installed, and the code provides a fallback or sets the import to None."
Prompt for agents
Move all modelexpress and vllm sub-module imports to the top of the file using the approved try/except ImportError pattern for optional dependencies. For example:
    try:
        from modelexpress import MxV2RefitReceiver
        from modelexpress.ucx_utils import apply_nic_pin_for_device
        from modelexpress.megatron_translator import (
            MegatronReceiverContext, ReceiveSpec,
            assemble_into_destination, discover_megatron_context,
            run_refit_cycle, translate_megatron_to_hf,
        )
        from modelexpress.nemo_rl_v2 import MegatronTensorSpec, TargetTpLayout
    except ImportError:
        MxV2RefitReceiver = None
        # other fallbacks...
Similarly move the vllm sub-module import:
from vllm.model_executor.model_loader.utils import process_weights_after_loading
This affects lines 95, 177-179, 237, 405-410, and 790 in extension.py.
    except Exception as exc:  # noqa: BLE001
        logger.error(
            "[mx] update_weights_via_mx failed on rank=%d: %s\n%s",
            getattr(
                getattr(self, "_mx_receiver", None), "worker_rank", -1
            ),
            exc,
            traceback.format_exc(),
        )
        return False
🟡 Multiple except Exception blocks catch-and-suppress without re-raising
The .ai/python-guidelines.md states: "If you must catch broadly, use except Exception: (never bare except:) and always re-raise after logging." The main refit handler at lines 351-360 catches except Exception, logs the error, and returns False without re-raising. Similarly, lines 104-105, 341-346, 740-743, 819-824, and 855-856 all catch Exception and suppress it. While the RPC handler design deliberately returns failure status instead of propagating exceptions, this violates the guidelines' explicit "always re-raise" rule.
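For reference, the log-then-re-raise shape the guideline asks for looks roughly like the sketch below. `refit_or_raise` and its `do_refit` argument are hypothetical; whether re-raising suits an RPC handler that is designed to return a status flag is a separate design question that this sketch does not settle.

```python
import logging
from typing import Callable

logger = logging.getLogger("mx_refit_example")


def refit_or_raise(do_refit: Callable[[], None], rank: int) -> bool:
    """Run a refit step; log any failure with traceback, then re-raise."""
    try:
        do_refit()
    except Exception:
        # logger.exception records the active traceback automatically.
        logger.exception("[mx] refit failed on rank=%d", rank)
        raise
    return True


def _failing_refit() -> None:
    raise RuntimeError("simulated refit failure")


try:
    refit_or_raise(_failing_refit, rank=0)
    propagated = False
except RuntimeError:
    propagated = True
```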
    target_tp_rank = (
        torch.distributed.get_rank()
        if torch.distributed.is_initialized() else 0
    )
    layout = TargetTpLayout(tp_size=target_tp, tp_rank=target_tp_rank)
🚩 target_tp_rank uses global rank — correct for TP-only but wrong for TP+PP
At extension.py:430-433, target_tp_rank is set via torch.distributed.get_rank() which returns the rank in the default (world) process group. For pure TP deployments (the common vLLM case), global_rank == tp_rank, so this works. However, for TP+PP configurations (e.g. TP=2, PP=2 → world_size=4), rank 2 would have global_rank=2 but tp_rank=0. This would cause the matched-TP check at extension.py:486 to fail (no source has tp_rank=2 when tp_size=2), incorrectly falling through to the mixed-TP path. Since vLLM does support PP, this could be a latent issue if PP>1 is ever used with this refit extension. Not flagged as a bug because PP is uncommon for inference and the code explicitly targets TP-only workflows, but worth noting for future generalization.
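The TP=2, PP=2 example can be made concrete with a little rank arithmetic. The sketch below assumes a rank layout in which tensor-parallel ranks are contiguous (TP is the innermost dimension) and there is no data parallelism; that layout is an assumption for illustration, and the helper names are hypothetical. Production code would more likely query the framework's tensor-parallel process group than derive the rank by hand.

```python
def tp_rank_from_global(global_rank: int, tp_size: int) -> int:
    """TP rank under the assumed layout: TP innermost, ranks contiguous."""
    return global_rank % tp_size


def pp_rank_from_global(global_rank: int, tp_size: int) -> int:
    """PP stage under the same assumed layout (no data parallelism)."""
    return global_rank // tp_size


# TP=2, PP=2 -> world_size=4, as in the comment above.
tp_ranks = [tp_rank_from_global(rank, 2) for rank in range(4)]
pp_ranks = [pp_rank_from_global(rank, 2) for rank in range(4)]
```

Under this assumption global rank 2 maps to TP rank 0 in the second pipeline stage, which is the value the matched-TP check would need to see.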
    ok = self.update_weights_via_mx(
        version=latest, mx_config=cfg
    )
🚩 Poller thread calls update_weights_via_mx outside collective_rpc — potential concurrency with inference
The poller thread at extension.py:842 calls self.update_weights_via_mx(version=latest, mx_config=cfg) directly from a background thread. Unlike the explicit RPC path (which goes through collective_rpc which pauses generation), the poller does not coordinate with the inference loop. This means model.load_weights() at extension.py:156 could execute concurrently with model.forward() in the main worker thread. While PyTorch's default CUDA stream is shared across threads (serializing GPU ops), CPU-side tensor metadata modifications during load_weights are not protected. In practice, CUDA stream serialization and the GIL may prevent observable corruption, but this is architecturally risky. The poller docstring doesn't mention this limitation.
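One way to picture the missing coordination is a lock shared by the weight update and the forward pass. The sketch below is a toy model only: `GuardedModel` is hypothetical, it is not how collective_rpc pauses generation, and a real fix might instead route the poller through the same pause mechanism as the explicit RPC path.

```python
import threading


class GuardedModel:
    """Toy model whose weights are never swapped in the middle of a forward pass."""

    def __init__(self) -> None:
        self._lock = threading.Lock()
        self.version = 0
        self.weights = [0.0, 0.0]

    def forward(self, x: float) -> tuple[float, int]:
        # Hold the lock so weights and version are read consistently.
        with self._lock:
            return sum(w * x for w in self.weights), self.version

    def update_weights(self, version: int, new_weights: list[float]) -> None:
        # Writers take the same lock, so an update waits for any running forward.
        with self._lock:
            self.weights = list(new_weights)
            self.version = version


model = GuardedModel()
poller = threading.Thread(target=model.update_weights, args=(1, [1.0, 2.0]))
poller.start()
poller.join()
output, seen_version = model.forward(2.0)
```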
    SHARD_AXIS_BY_ROLE = {
        "column": 0, "qkv_column": 0, "gated_mlp_column": 0,
        "vocab_parallel": 0, "row": 1,
        "expert_column": 0, "expert_row": 0, "replicated": 0,
🚩 expert_row shard_axis set to 0 while regular row uses 1 — intentional for MoE passthrough?
At extension.py:441, SHARD_AXIS_BY_ROLE maps "expert_row": 0 while "row": 1. This seems inconsistent — regular row-parallel shards along axis 1. The shard_axis is used at line 457-459 to compute global_shape: global_shape[shard_axis] *= source_tp_size. For expert_row with axis=0, this multiplies the wrong dimension relative to regular row. However, lines 528-531 treat expert tensors as passthrough (pass) and lines 611-612 route per_expert plans to v0 scratch. The shard_axis for experts may represent the expert-grouping dimension rather than the TP dimension, making axis=0 correct in context. Without access to the Megatron-Core MoE conventions, I couldn't definitively determine whether this is correct.
Adds a new `dynamo.vllm.mx_refit` module that ports NeMo-RL's
`VllmInternalWorkerExtension.update_weights_via_mx` into Dynamo's
vLLM worker so external trainers can publish per-rank tensor shards
via NIXL RDMA into a running AsyncLLM mid-training without going
through HF-format round-tripping.
Two paths are supported:
1. DTensor / v1 path: trainer publishes torch.distributed DTensors;
receiver does a full-tensor pull (`.full_tensor()`-equivalent over
NIXL) and lets vLLM's standard load_weights re-shard.
2. Megatron-Core / v2 path: trainer publishes per-rank Megatron-native
shards with a sidecar describing the model's transformer config and
a name_map. Receiver builds a slice plan against its own vLLM TP/EP
layout, pulls each rank's contribution directly into a pre-allocated
global tensor, then applies role-aware translation (QKV
un-interleave, gated-MLP split, name remap) via the vendored Bridge
helpers shipped with the modelexpress Python client. Supports both
matched-TP (source_tp == target_tp) and mixed-TP (source_tp != target_tp,
with v1 sliced-pull for contiguous narrows + v0 host-scratch+slice
for strided / per-expert dict assembly).
The Megatron client-side machinery (slice planner, translator, vendored
Bridge helpers, v1 NIXL sliced-pull primitive) lives in
ai-dynamo/modelexpress PR ai-dynamo#429 (merged). The trainer-side publisher
lives in NeMo-RL PR ai-dynamo#2 (merged into the dynamo-k8s-integration branch).
This file has been carried as a private kubectl-overlay against
ai-dynamo's container image and validated end-to-end on GB200 against:
- Qwen3-4B-Thinking-2507: 398/398 HF tensors byte-identical
(matched-TP).
- Qwen3-MoE-30B-A3B: 18,867/18,867 HF tensors byte-identical
(matched-TP, grouped-MoE).
- Mixed-TP target-wider (trainer TP=1 -> vLLM TP=2): 16/16
byte-identical via v1 NIXL sliced-pull.
No public Dynamo APIs change; the extension is opt-in via the standard
`mx_refit.MxRefitWorkerExtension` registration on the vLLM worker.
A separate follow-up PR adds buffer-caching across refit cycles to
avoid re-allocating + re-registering NIXL buffers on every cycle.
…cycles
Audit of refit-cycle wall time on multi-receiver setups surfaced a
receiver-side bug: both the matched-TP and mixed-TP branches in
`_update_weights_via_mx_megatron` re-allocated + re-registered per-rank
buffers with NIXL on every refit cycle, paying ~0.15 s of `ibv_reg_mr`
per cycle on Qwen3-4B (proportionally more on larger models).
Plan shapes (and buffer shapes more generally) are deterministic for a
fixed `(source_tp, target_tp)` layout, so cycle-N's allocations are
identical to cycle-1's. Cache the buffers / plan_dests dicts on
`self._mx_megatron_buffers` and `self._mx_megatron_plan_dests`
respectively, and skip the `register_tensors` call when nothing was
newly allocated this cycle.
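
The caching idea can be sketched without NIXL or torch. In the minimal model below, `DestBufferCache`, the `register` callback, and the `bytearray` buffers are hypothetical stand-ins for the cached dicts on the worker and the `register_tensors` call; deregistration of replaced buffers is left out.

```python
from typing import Callable


class DestBufferCache:
    """Keep destination buffers across cycles; register only new ones."""

    def __init__(self, register: Callable[[dict[str, bytearray]], None]) -> None:
        self._buffers: dict[str, bytearray] = {}
        self._register = register

    def prepare(self, sizes: dict[str, int]) -> dict[str, bytearray]:
        newly_allocated: dict[str, bytearray] = {}
        for name, nbytes in sizes.items():
            cached = self._buffers.get(name)
            if cached is None or len(cached) != nbytes:
                buf = bytearray(nbytes)  # stand-in for a device tensor
                self._buffers[name] = buf
                newly_allocated[name] = buf
        # Skip registration entirely when nothing was allocated this cycle.
        if newly_allocated:
            self._register(newly_allocated)
        return self._buffers


registered_batches: list[list[str]] = []
cache = DestBufferCache(lambda bufs: registered_batches.append(sorted(bufs)))
layout = {"layer0.qkv": 1024, "layer0.mlp": 2048}
first_qkv = cache.prepare(layout)["layer0.qkv"]   # cycle 1: allocate + register
second_qkv = cache.prepare(layout)["layer0.qkv"]  # cycle 2: reuse, no registration
```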
Cluster validation on GB200 / Qwen3-4B-Thinking 2026-06-23, 3
back-to-back refit cycles via a standalone benchmark exercising the
same MxV2RefitReceiver + NIXL layer:
(all times in seconds)
              alloc   register   pull    translate   total
Before fix:
  Cycle 1     0.032   0.152      0.209   0.016       0.409
  Cycle 2     0.026   0.152      0.203   0.012       0.392
  Cycle 3     0.001   0.024      0.210   0.011       0.246
After fix:
  Cycle 1     0.028   0.085      0.206   0.014       0.333
  Cycle 2     0.000   0.000      0.204   0.010       0.215  (-45%)
  Cycle 3     0.000   0.000      0.204   0.011       0.215
Pull step unchanged (~205-210 ms for 8 GB at 308 Gbps single-NIC); the
fix touches setup overhead only. Larger models will see proportionally
larger savings since `ibv_reg_mr` scales with both buffer count and
pinned-memory size.
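
As a quick sanity check on the figures above, the pull time and the cycle-2 saving can be recomputed from the stated numbers. The sketch assumes 1 GB = 10^9 bytes and treats the table values as seconds.

```python
# Ideal wire time for the pull step at the stated link rate.
payload_bytes = 8e9           # 8 GB, assuming decimal gigabytes
link_bits_per_second = 308e9  # 308 Gbps single NIC
wire_time_s = payload_bytes * 8 / link_bits_per_second

# Cycle-2 totals before and after the fix, taken from the table.
before_total_s = 0.392
after_total_s = 0.215
saving_fraction = (before_total_s - after_total_s) / before_total_s

print(f"ideal pull time: {wire_time_s * 1e3:.0f} ms")
print(f"cycle-2 saving: {saving_fraction:.0%}")
```

The ideal wire time lands inside the measured 205-210 ms window, consistent with the pull step running close to line rate.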
The matching fix in the NeMo-RL trainer-side `vllm_backend.py` lives
at jthomson04/RL#7.
8dad80b to
4bd2ae9
…timizations
Lands the v2 client surface and the surrounding RL workstream needed
to unblock the downstream Megatron-MX and perf PRs.
## What this contributes
### v2 client surface (this PR's primary deliverable)
* `MxV2TrainingPublisher` + `MxV2RefitReceiver` (modelexpress.nemo_rl_v2)
-- the fat-client surface for per-rank shard publish and
receiver-side multi-source assembly.
* `MxWeightTransferEngine` (modelexpress.vllm_weight_transfer) -- an
adapter implementing vLLM's upstream WeightTransferEngine ABC.
* `TensorDescriptorV2.extra_parameters` (map<string,string>) plus
`SourceIdentity.revision` (string) -- the two proto extensions that
carry all per-tensor + per-source RL metadata.
### Phase 3a -- compile_target + compile_metadata on TensorDescriptorV2
New per-tensor fields default to ``hf_raw`` / ``{}``; the wire encoder
omits them when default so existing payloads stay byte-identical.
New constants: ``COMPILE_TARGET_HF_RAW``, ``_VLLM_FUSED``,
``_DEEPGEMM_FP8``, ``_CUTLASS_FP8``, ``_TRTLLM``. New helper
``compile_target_matches(descriptor, *, allowed_targets,
required_metadata=None)`` for receiver-side filtering with whitelist
plus required-metadata-subset semantics.
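The whitelist plus required-metadata-subset semantics can be illustrated with a small stand-alone sketch. It is not the library's implementation: `DescriptorStub` and `matches_sketch` are hypothetical names, and the `"deepgemm_fp8"` string and the `block` metadata key are assumed values chosen for illustration. Only the `hf_raw` / `{}` defaults come from the text.

```python
from dataclasses import dataclass, field
from typing import Dict, Iterable, Optional


@dataclass
class DescriptorStub:
    """Hypothetical stand-in for a tensor descriptor; not the real proto type."""
    name: str
    compile_target: str = "hf_raw"  # default stated in the text
    compile_metadata: Dict[str, str] = field(default_factory=dict)


def matches_sketch(
    descriptor: DescriptorStub,
    *,
    allowed_targets: Iterable[str],
    required_metadata: Optional[Dict[str, str]] = None,
) -> bool:
    # Whitelist check: the tensor's target must be one the receiver accepts.
    if descriptor.compile_target not in set(allowed_targets):
        return False
    # Subset check: every required key must be present with the same value.
    for key, value in (required_metadata or {}).items():
        if descriptor.compile_metadata.get(key) != value:
            return False
    return True


raw = DescriptorStub("w.q_proj")
fp8 = DescriptorStub("w.q_proj", "deepgemm_fp8", {"block": "128"})  # assumed values

accepts_raw = matches_sketch(raw, allowed_targets={"hf_raw"})
rejects_fp8 = not matches_sketch(fp8, allowed_targets={"hf_raw"})
```

Extra metadata keys on the descriptor do not cause a rejection; only missing or mismatched required keys do.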
### Phase 3b -- compile_target_filter on discover_v2_sources
New kwargs ``compile_target_filter`` (whitelist set) and
``required_compile_metadata`` (subset-of-every-tensor's
compile_metadata). Candidates with no v2 registry are rejected when
either filter is set; candidates with mixed compile targets are
rejected if any tensor falls outside the allowed set.
``V2SourceCandidate.compile_targets: frozenset[str]`` exposed for
caller introspection.
### Phase 4 -- Multi-source slice discovery for mixed trainer/inference TP
New types: ``TargetTPLayout``, ``SliceSource``, ``SliceCoveragePlan``.
New methods ``MxV2RefitReceiver.discover_v2_sources_for_slice`` and
``MxV2RefitReceiver.receive_via_plan`` -- planner walks v2 candidates
per tensor, intersects each publisher's local_shard_range against the
receiver's requested slice, emits the minimal candidate set covering
it; surfaces coverage gaps and shard_axis mismatches in plan.missing.
``receive_via_plan`` orchestrates per-candidate scratch RDMA pulls
and stitches via torch.cat along the shard axis.
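The planner's core step, intersecting publisher shard ranges with the requested slice and reporting gaps, resembles a greedy interval cover. The sketch below is an assumption about the general technique, not the code in `discover_v2_sources_for_slice`: `plan_slice`, the `Range` alias, and the rank labels are hypothetical, and the RDMA pull and `torch.cat` stitching are left out.

```python
from typing import List, Tuple

Range = Tuple[int, int]  # half-open [start, stop) along the shard axis


def plan_slice(
    wanted: Range, shards: List[Tuple[str, Range]]
) -> Tuple[List[Tuple[str, Range]], List[Range]]:
    """Greedy interval cover: returns (pulls, missing)."""
    pos, end = wanted
    pulls: List[Tuple[str, Range]] = []
    missing: List[Range] = []
    while pos < end:
        # Among shards that contain `pos`, take the one reaching furthest.
        best = max(
            (s for s in shards if s[1][0] <= pos < s[1][1]),
            key=lambda s: s[1][1],
            default=None,
        )
        if best is None:
            # Nothing covers `pos`: record a gap up to the next shard start.
            nxt = min((s[1][0] for s in shards if pos < s[1][0] < end), default=end)
            missing.append((pos, nxt))
            pos = nxt
            continue
        stop = min(best[1][1], end)
        pulls.append((best[0], (pos, stop)))
        pos = stop
    return pulls, missing


# Two publishers each hold half of a 1024-row tensor; the receiver wants rows 256..768.
publishers = [("rank0", (0, 512)), ("rank1", (512, 1024))]
pulls, missing = plan_slice((256, 768), publishers)
partial_pulls, partial_missing = plan_slice((256, 768), publishers[:1])
```

In the first call the slice is covered by one sub-range from each publisher; in the second, with only `rank0` available, the uncovered rows show up in `missing` rather than being silently dropped.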
### Proto extensions
* ``TensorDescriptorV2.extra_parameters: map<string, string>`` -- the
escape hatch downstream RL clients use for per-tensor metadata
(megatron_role, compile_target, expert_id, revision, training_step,
...). Heavily used by #429.
* ``SourceIdentity.revision: string`` -- content-addressed weight
version. Non-empty value guarantees two sources with identical
SourceIdentity have bit-identical weight bytes, enabling
decentralized modes (no central coordinator) to use mx_source_id
as a full content check.
* Redis backend ``SourceAttributesJson`` carries the new fields plus
the union of main's compatibility additions
(backend_framework_version, torch_version, cuda_version,
triton_version, gpu_arch, compile_config_digest).
## Tests
68/68 unit tests green on this branch's surface (test_v2_source_picker,
test_v2_shape_registry, test_source_id, test_types, test_vllm_adapter).
## Compatibility
Every new field has a backwards-compatible default and every new method
arg is optional. Existing demos and downstream consumers run unchanged.
## Downstream impact
This PR is the prerequisite for several PRs in flight:
* #421 -- publish_self_as_source tree fan-out fix
* #429 -- Megatron-Core MX clients
* #450 -- MX_RDMA_NIC_PIN=stripe multi-NIC mode
* ai-dynamo/dynamo#10900 -- first-time upstream port of
dynamo.vllm.mx_refit extension
* ai-dynamo/dynamo#10901 -- buffer-caching perf fix in Dynamo extension
* jthomson04/RL#2 (merged) + #7 -- NeMo RL Megatron-MX integration plus
perf fix
Once this PR lands, kavink/nemo_rl_moe (the integration branch holding
#421 and #429) can be retired by rebasing those PRs onto main.
Signed-off-by: Kavin Krishnan <kavink@nvidia.com>
Summary
Cache NIXL-registered dest buffers across Megatron-MX refit cycles in `dynamo.vllm.mx_refit`'s `_update_weights_via_mx_megatron`. Plan shapes (and buffer shapes more generally) are deterministic for a fixed `(source_tp, target_tp)` layout, so re-allocating and re-registering NIXL buffers on every refit cycle wastes ~0.15 s per cycle on small models and proportionally more on larger ones.
Follow-up to #10900 (the port PR). This PR contains BOTH commits (the port plus this fix), so the diff against `main` shows the full delta. Once the port PR merges, this PR will rebase to show only the perf-fix commit.
Motivation
Audit of refit-cycle wall time on multi-receiver setups surfaced a per-cycle bottleneck: NIXL `register_tensors` was being called every refit for ~290 buffers on Qwen3-4B (8 GB), costing ~150 ms of `ibv_reg_mr` overhead per receiver per cycle. The buffer shapes never change for a fixed TP layout, so the register call can be amortized to cycle 1 only.
Change
Three patterns applied:
- `buffers` cached on `self._mx_megatron_buffers` (line ~510)
- `plan_dests` cached on `self._mx_megatron_plan_dests` (line ~593)
- Skip `register_tensors` when nothing was newly allocated this cycle.
The cache survives for the lifetime of the worker. Plans that fall back to the v0 scratch path don't break the cache for plans that were v1 in prior cycles.
Validation
Cluster-validated on GB200 + Qwen3-4B-Thinking-2507 via 3 back-to-back refit cycles. The benchmark exercises the same `MxV2RefitReceiver` + NIXL layer that this extension uses, with an `MX_CACHE_BUFFERS` env toggle for A/B comparison (effectively a Python-level cache that mirrors the production code change line-for-line). The `register` cost (0.152 s) goes to zero from cycle 2 onward after the fix.
Pair
Matching NeMo-RL trainer-side fix at jthomson04/RL#7.
Test Plan