Skip to content

perf(vllm/mx_refit): cache NIXL-registered dest buffers across refit cycles#10901

Open
KavinKrishnan wants to merge 2 commits into
ai-dynamo:mainfrom
KavinKrishnan:kavink/megatron-mx-perf-fix
Open

perf(vllm/mx_refit): cache NIXL-registered dest buffers across refit cycles#10901
KavinKrishnan wants to merge 2 commits into
ai-dynamo:mainfrom
KavinKrishnan:kavink/megatron-mx-perf-fix

Conversation

@KavinKrishnan

@KavinKrishnan KavinKrishnan commented Jun 23, 2026

Copy link
Copy Markdown
Contributor

Summary

Cache NIXL-registered dest buffers across Megatron-MX refit cycles in dynamo.vllm.mx_refit's _update_weights_via_mx_megatron. Plan shapes (and buffer shapes more generally) are deterministic for a fixed (source_tp, target_tp) layout — re-allocating + re-registering NIXL buffers on every refit cycle wastes ~0.15 s per cycle on small models and proportionally more on larger ones.

Follow-up to #10900 (the port PR). This PR contains BOTH commits — the port plus this fix — so the diff against main shows the full delta. Once the port PR merges, this PR will rebase to show only the perf-fix commit.

Motivation

Audit of refit-cycle wall time on multi-receiver setups surfaced a per-cycle bottleneck: NIXL register_tensors was being called every refit for ~290 buffers on Qwen3-4B (8 GB), costing ~150 ms of ibv_reg_mr overhead per receiver per cycle. The buffer shapes never change for a fixed TP layout, so the register call can be amortized to cycle 1 only.

Change

Three patterns applied:

  1. Matched-TP buffers cached on self._mx_megatron_buffers (line ~510)
  2. Mixed-TP plan_dests cached on self._mx_megatron_plan_dests (line ~593)
  3. Both paths skip register_tensors when nothing was newly allocated this cycle.

The cache survives for the lifetime of the worker. Plans that fall back to the v0 scratch path don't break the cache for plans that were v1 in prior cycles.

Validation

Cluster-validated on GB200 + Qwen3-4B-Thinking-2507 via 3 back-to-back refit cycles. The benchmark exercises the same MxV2RefitReceiver + NIXL layer that this extension uses, with MX_CACHE_BUFFERS env toggle for A/B comparison (effectively a Python-level cache that mirrors the production code change line-for-line):

                 alloc  register   pull  translate  total
BEFORE FIX:
  Cycle 1       0.032     0.152   0.209     0.016   0.409
  Cycle 2       0.026     0.152   0.203     0.012   0.392
  Cycle 3       0.001     0.024   0.210     0.011   0.246

AFTER FIX:
  Cycle 1       0.028     0.085   0.206     0.014   0.333
  Cycle 2       0.000     0.000   0.204     0.010   0.215   (-45% vs cycle-2 baseline)
  Cycle 3       0.000     0.000   0.204     0.011   0.215
  • The pull step (~205-210 ms for 8 GB at 308 Gbps single-NIC) is unchanged — the fix touches setup overhead only.
  • Cycle 2's register cost (0.152 s) goes to zero after the fix.

Pair

Matching NeMo-RL trainer-side fix at jthomson04/RL#7.

Test Plan

  • Multi-cycle benchmark on GB200 / Qwen3-4B-Thinking — captured before/after numbers summarised above.
  • Re-run on a larger multi-receiver setup once paired with the NeMo-RL trainer fix.

@KavinKrishnan KavinKrishnan requested review from a team as code owners June 23, 2026 21:40
@copy-pr-bot

copy-pr-bot Bot commented Jun 23, 2026

Copy link
Copy Markdown

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@github-actions

Copy link
Copy Markdown
Contributor

👋 Hi KavinKrishnan! Thank you for contributing to ai-dynamo/dynamo.

Just a reminder: The NVIDIA Test Github Validation CI runs an essential subset of the testing framework to quickly catch errors.Your PR reviewers may elect to test the changes comprehensively before approving your changes.

🚀

@github-actions github-actions Bot added external-contribution Pull request is from an external contributor perf backend::vllm Relates to the vllm backend labels Jun 23, 2026
@datadog-official

datadog-official Bot commented Jun 23, 2026

Copy link
Copy Markdown

Pipelines

⚠️ Warnings

🚦 2 Pipeline jobs failed

Pre Merge | pre-commit   View in Datadog   GitHub Actions

Pre Merge | pre-merge-status-check   View in Datadog   GitHub Actions

This comment will be updated automatically if new data arrives.
🔗 Commit SHA: 4bd2ae9 | Docs | Give us feedback!

@coderabbitai

coderabbitai Bot commented Jun 23, 2026

Copy link
Copy Markdown
Contributor

Review Change Stack

Walkthrough

Introduces a new mx_refit package under components/src/dynamo/vllm/mx_refit/ consisting of a package __init__.py and extension.py. The package exposes MxConfig (a dataclass for refit configuration) and MxRefitWorkerExtension (a vLLM worker extension class) implementing standard HF-layout and Megatron-layout RDMA weight refit paths plus a background version-polling thread.

Changes

MX Refit Worker Extension

Layer / File(s) Summary
Package contract and MxConfig dataclass
components/src/dynamo/vllm/mx_refit/__init__.py, components/src/dynamo/vllm/mx_refit/extension.py
__init__.py re-exports MxConfig and MxRefitWorkerExtension; extension.py adds the module docstring, imports, logger, and MxConfig dataclass with from_dict parsing/normalization.
NIC pinning and weight-loading helpers
components/src/dynamo/vllm/mx_refit/extension.py
Adds _pin_local_nic for NUMA-local NIC pinning before NIXL init, prepare_refit_info for storing trainer state_dict_info on the worker, _mx_load_weights delegating to vLLM's model.load_weights, and _mx_maybe_process_fp8_kv_cache for conditional FP8 weight post-processing.
Main refit RPC (update_weights_via_mx)
components/src/dynamo/vllm/mx_refit/extension.py
Lazy-initializes MxV2RefitReceiver with NIC pinning, discovers candidate sources for the requested version, routes Megatron candidates to a separate method, otherwise performs scratch-path RDMA pulls for HF-shaped tensors, loads weights, syncs CUDA, runs FP8 KV post-processing, optionally tree-republishes, and returns True/False.
Megatron refit path (_update_weights_via_mx_megatron)
components/src/dynamo/vllm/mx_refit/extension.py
Handles matched-TP (pre-allocated NIXL buffers, bulk receive_from) and mixed-TP (cached destination views, v1 sliced pulls, v0 scratch/host-copy fallback) Megatron layouts; translates Megatron tensors to HF tensors and runs the full weight-load + FP8 + optional tree-republish pipeline.
Background poller (start_mx_refit_poller)
components/src/dynamo/vllm/mx_refit/extension.py
Spawns an idempotent daemon thread that periodically calls discover_v2_sources to detect new training_step versions and triggers update_weights_via_mx on advancement, with clean shutdown via a stop event.

Estimated code review effort

🎯 5 (Critical) | ⏱️ ~120 minutes

🚥 Pre-merge checks | ✅ 3 | ❌ 2

❌ Failed checks (2 warnings)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 63.64% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
Description check ⚠️ Warning The description is detailed, but it does not follow the required template and omits the required Related Issues section and reviewer-start guidance. Add the template sections Overview, Details, Where should the reviewer start?, and a completed Related Issues section with one selected path.
✅ Passed checks (3 passed)
Check name Status Explanation
Title check ✅ Passed The title clearly identifies the primary optimization: caching NIXL-registered buffers across refit cycles, which directly addresses the performance issue and matches the main change described in the PR.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.


Comment @coderabbitai help to get the list of available commands.

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 2

🧹 Nitpick comments (1)
components/src/dynamo/vllm/mx_refit/extension.py (1)

404-410: 📐 Maintainability & Code Quality | 🔵 Trivial | 💤 Low value

Move stdlib imports to module top.

import time as _time (Line 404) and import threading (Line 777) are standard-library imports inside function bodies. The lazy-import rationale only applies to the optional modelexpress/vllm deps; stdlib imports have no load-cost justification and hide dependencies. Hoist them to the top of the module.

As per coding guidelines: "Keep imports at the top of the file; always flag import statements inside function bodies, methods, or classes" (exception only for optional dependencies). As per path instructions: ".ai/python-guidelines.md — all imports are at module top".

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@components/src/dynamo/vllm/mx_refit/extension.py` around lines 404 - 410, The
standard library imports `import time as _time` (found near line 404) and
`import threading` (found near line 777) are currently located inside function
bodies rather than at the module top level. Move both of these stdlib imports to
the top of the module file, outside of any function or class definitions, since
the lazy-import rationale only applies to optional dependencies like
modelexpress and vllm, not standard library modules. This follows the coding
guideline that all imports should be at module top and ensures dependencies are
clearly visible.

Sources: Coding guidelines, Path instructions

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@components/src/dynamo/vllm/mx_refit/extension.py`:
- Line 1: The file extension.py has formatting and import ordering issues that
are not compliant with the project's code style standards as enforced by black
and isort. To fix this, run the pre-commit hooks on the entire repository using
`pre-commit run --all-files` or alternatively run `ruff format` followed by
`isort` on the specific file to automatically reformat the code and organize
imports correctly. Commit the resulting changes before merging.
- Around line 615-641: The bug is that the cleanup guard at line 635 only
removes freshly allocated buffers from plan_dests when cached_plan_dests is None
(no warm cycle). On warm cycles, freshly allocated row-parallel buffers that
fail the contiguity check remain in plan_dests and get consumed with stale data
instead of being routed to v0. Track which tensor names were newly allocated in
the current cycle (you can use a set or extend the existing
newly_allocated_this_cycle counter to track names), then modify the condition at
line 635 to also check if plan.tensor_name is in the current cycle's newly
allocated set. This ensures freshly allocated non-contiguous buffers are removed
regardless of whether we're in a warm cycle, allowing them to be properly
handled by the v0 path via assemble_into_destination.

---

Nitpick comments:
In `@components/src/dynamo/vllm/mx_refit/extension.py`:
- Around line 404-410: The standard library imports `import time as _time`
(found near line 404) and `import threading` (found near line 777) are currently
located inside function bodies rather than at the module top level. Move both of
these stdlib imports to the top of the module file, outside of any function or
class definitions, since the lazy-import rationale only applies to optional
dependencies like modelexpress and vllm, not standard library modules. This
follows the coding guideline that all imports should be at module top and
ensures dependencies are clearly visible.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: d9e26229-e4e4-4b73-b507-ff3ec5a9ee81

📥 Commits

Reviewing files that changed from the base of the PR and between f8b9c6d and 8dad80b.

📒 Files selected for processing (2)
  • components/src/dynamo/vllm/mx_refit/__init__.py
  • components/src/dynamo/vllm/mx_refit/extension.py

@@ -0,0 +1,864 @@
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

📐 Maintainability & Code Quality | 🟡 Minor | ⚡ Quick win

Run black and isort — pre-commit failed in CI.

The pre-merge pipeline reports both black (reformatting) and isort (import ordering) modified this file. Run pre-commit run --all-files (or ruff format + isort) and commit the result before merge.

🧰 Tools
🪛 GitHub Actions: Pre Merge / 8_pre-commit.txt

[error] 1-1: pre-commit hook 'isort' failed. Files were modified by this hook (Fixing /home/runner/work/dynamo/dynamo/components/src/dynamo/vllm/mx_refit/extension.py). Run 'pre-commit run isort --all-files' or apply the changes.


[error] 1-1: pre-commit hook 'black' failed. Files were modified by this hook (reformatted components/src/dynamo/vllm/mx_refit/extension.py). Run 'pre-commit run black --all-files' or apply the changes.

🪛 GitHub Actions: Pre Merge / pre-commit

[error] 1-1: pre-commit hook 'isort' failed (files were modified by this hook). Fix applied to the file; re-run the check after committing changes.


[error] 1-1: pre-commit hook 'black' failed (files were modified by this hook). Reformatting was applied to the file; re-run the check after committing changes.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@components/src/dynamo/vllm/mx_refit/extension.py` at line 1, The file
extension.py has formatting and import ordering issues that are not compliant
with the project's code style standards as enforced by black and isort. To fix
this, run the pre-commit hooks on the entire repository using `pre-commit run
--all-files` or alternatively run `ruff format` followed by `isort` on the
specific file to automatically reformat the code and organize imports correctly.
Commit the resulting changes before merging.

Source: Pipeline failures

Comment on lines +615 to +641
if plan.tensor_name in plan_dests:
dest = plan_dests[plan.tensor_name]
else:
dest = torch.empty(plan.target_shape, dtype=dt, device=device)
plan_dests[plan.tensor_name] = dest
newly_allocated_this_cycle += 1
axis = 1 if plan.assembly == "concat_dim1" else 0
routed_v1 = True
for src in plan.sources:
target_lo, target_hi = src.target_local_range
dest_view = dest.narrow(axis, target_lo, target_hi - target_lo)
if not dest_view.is_contiguous():
routed_v1 = False
break
v1_batches[src.mx_source_id].append(
(plan.tensor_name, src.source_subslice, dest_view)
)
if not routed_v1:
# Don't drop cached entries — they may be valid for
# other plans; just route this plan to v0.
if cached_plan_dests is None:
plan_dests.pop(plan.tensor_name, None)
for sid in v1_batches:
v1_batches[sid] = [
r for r in v1_batches[sid] if r[0] != plan.tensor_name
]
v0_plans.append(plan)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🗄️ Data Integrity & Integration | 🔴 Critical

🧩 Analysis chain

🏁 Script executed:

# First, map the file to understand structure
ast-grep outline components/src/dynamo/vllm/mx_refit/extension.py --view expanded

Repository: ai-dynamo/dynamo

Length of output: 1131


🏁 Script executed:

# Read the relevant section around the bug (lines 600-700)
sed -n '600,700p' components/src/dynamo/vllm/mx_refit/extension.py

Repository: ai-dynamo/dynamo

Length of output: 4937


🏁 Script executed:

# Search for where plan_dests is initialized and used
rg "plan_dests" components/src/dynamo/vllm/mx_refit/extension.py -n -B 2 -A 2

Repository: ai-dynamo/dynamo

Length of output: 3013


🏁 Script executed:

# Look for test files to verify multi-cycle test coverage
find components -name "*test*" -type f | grep -E "(mx_refit|extension)" | head -20

Repository: ai-dynamo/dynamo

Length of output: 154


🏁 Script executed:

# Search for test files in the repository
find components -path "*test*" -name "*.py" | grep -i "mx\|refit" | head -20

Repository: ai-dynamo/dynamo

Length of output: 154


🏁 Script executed:

# Search more broadly for test files
find . -name "test_*.py" -o -name "*_test.py" | grep -E "mx|refit" | head -20

Repository: ai-dynamo/dynamo

Length of output: 208


🏁 Script executed:

# Look for any pytest files or test directories
find . -type d -name "test*" | head -10

Repository: ai-dynamo/dynamo

Length of output: 562


🏁 Script executed:

# Search for references to mixed-TP tests or multi-cycle validation
rg "mixed.?TP|warm.?cycle|consecutive.*cycle" --type py -i

Repository: ai-dynamo/dynamo

Length of output: 154


🏁 Script executed:

# Search for tests that validate the _update_weights_via_mx_megatron method
rg "_update_weights_via_mx_megatron|MxRefitWorkerExtension" --type py | grep -i test

Repository: ai-dynamo/dynamo

Length of output: 154


🏁 Script executed:

# Check the mx_integration test file
cat lib/gpu_memory_service/tests/test_mx_integration.py

Repository: ai-dynamo/dynamo

Length of output: 3580


🏁 Script executed:

# Search for multi-cycle test scenarios in all test files
rg "cycle|refit" lib/gpu_memory_service/tests/ -A 3 -B 1

Repository: ai-dynamo/dynamo

Length of output: 936


🏁 Script executed:

# Look for test files in the vllm/mx_refit directory
find components/src/dynamo/vllm/mx_refit -name "*.py" | head -20

Repository: ai-dynamo/dynamo

Length of output: 251


🏁 Script executed:

# Look for any conftest or test utilities in the mx_refit area
find components/src/dynamo/vllm -name "test_*.py" -o -name "*_test.py"

Repository: ai-dynamo/dynamo

Length of output: 2728


🏁 Script executed:

# Read more of the code to understand the complete flow
sed -n '590,710p' components/src/dynamo/vllm/mx_refit/extension.py

Repository: ai-dynamo/dynamo

Length of output: 6096


🏁 Script executed:

# Search for PR/commit references to understand the context of this fix
rg "16-receiver|Llama 3.1|2026-06-22" components/src/dynamo/vllm/mx_refit/ -B 2 -A 2

Repository: ai-dynamo/dynamo

Length of output: 1411


🏁 Script executed:

# Look for tests that specifically test mixed-TP or warm cycle scenarios
find . -path "*/test*" -name "*.py" -exec grep -l "plan_dests\|mx_megatron\|warm.*cycle" {} \;

Repository: ai-dynamo/dynamo

Length of output: 154


🏁 Script executed:

# Check if there are integration tests or unit tests for MxRefitWorkerExtension
rg "MxRefitWorkerExtension|_update_weights_via_mx_megatron" --type py

Repository: ai-dynamo/dynamo

Length of output: 154


🏁 Script executed:

# Look at the register_tensors call and assembly loop more carefully
sed -n '643,710p' components/src/dynamo/vllm/mx_refit/extension.py

Repository: ai-dynamo/dynamo

Length of output: 3475


🏁 Script executed:

# Search for the validation lines mentioned: "8/8, 16/16"
rg "8/8|16/16" components/src/dynamo/vllm/ --type py -B 2 -A 2

Repository: ai-dynamo/dynamo

Length of output: 154


🏁 Script executed:

# Search for mixed-TP or megatron test files
find components/src/dynamo/vllm/tests -name "*.py" -exec grep -l "megatron\|mixed.?TP" {} \;

Repository: ai-dynamo/dynamo

Length of output: 154


🏁 Script executed:

# Look for any tests that call _update_weights_via_mx_megatron
rg "_update_weights_via_mx_megatron|plan_dests" --type py

Repository: ai-dynamo/dynamo

Length of output: 154


🏁 Script executed:

# Search for broader mixed-TP validation or validation tests
rg "validation|refit.*cycle|warm.*cycle" components/src/dynamo/vllm/ --type py -i

Repository: ai-dynamo/dynamo

Length of output: 2716


🏁 Script executed:

# Look for any pytest markers or test discovery in the extension module itself
rg "pytest|test_" components/src/dynamo/vllm/mx_refit/ --type py

Repository: ai-dynamo/dynamo

Length of output: 343


🏁 Script executed:

# Search for any validation comments or markers about single vs multi-cycle in the PR
rg "single.?cycle|multi.?cycle|consecutive.*cycle" . --type py -i

Repository: ai-dynamo/dynamo

Length of output: 154


🏁 Script executed:

# Check if there are any conftest files that might contain shared fixtures
find components/src/dynamo/vllm -name "conftest.py"

Repository: ai-dynamo/dynamo

Length of output: 199


🏁 Script executed:

# Look at the full context of the assembly loop to verify the bug path
sed -n '690,720p' components/src/dynamo/vllm/mx_refit/extension.py

Repository: ai-dynamo/dynamo

Length of output: 1778


🏁 Script executed:

# Final verification: check if there are any inline test markers or validation in the file
rg "TODO|FIXME|XXX|BUG|HACK" components/src/dynamo/vllm/mx_refit/extension.py -B 1 -A 1

Repository: ai-dynamo/dynamo

Length of output: 358


🏁 Script executed:

# Check the register_tensors signature to confirm it's being called correctly
rg "def register_tensors|_nixl.register" . --type py -B 2 -A 5 | head -30

Repository: ai-dynamo/dynamo

Length of output: 3281


Mixed-TP warm-cycle data integrity bug: freshly allocated v0 (row-parallel) destination buffers are left in plan_dests and consumed as empty tensors on subsequent cycles.

On a warm cycle (cached_plan_dests is not None), when a row-parallel plan is non-contiguous and freshly allocated at line 618, the pop guard at line 635 (if cached_plan_dests is None) evaluates to false, so the buffer is not removed. The buffer is then registered and consumed in the assembly loop at line 694, loading garbage/stale weights silently instead of taking the v0 assemble_into_destination path.

The fix must track per-cycle allocations, not just whether a cache existed:

Proposed fix
             plan_dests: dict[str, torch.Tensor] = cached_plan_dests or {}
             v1_batches: dict[str, list] = {c.ref.mx_source_id: [] for c in megatron_cands}
             v0_plans: list = []
             newly_allocated_this_cycle = 0
+            newly_allocated_names: set[str] = set()

             for plan in plans:
                 ...
                 if plan.tensor_name in plan_dests:
                     dest = plan_dests[plan.tensor_name]
                 else:
                     dest = torch.empty(plan.target_shape, dtype=dt, device=device)
                     plan_dests[plan.tensor_name] = dest
                     newly_allocated_this_cycle += 1
+                    newly_allocated_names.add(plan.tensor_name)
                 ...
                 if not routed_v1:
-                    # Don't drop cached entries — they may be valid for
-                    # other plans; just route this plan to v0.
-                    if cached_plan_dests is None:
-                        plan_dests.pop(plan.tensor_name, None)
+                    # Drop only buffers allocated this cycle; pre-existing
+                    # cached entries valid for other plans must stay.
+                    if plan.tensor_name in newly_allocated_names:
+                        plan_dests.pop(plan.tensor_name, None)
+                        newly_allocated_this_cycle -= 1
+                        newly_allocated_names.discard(plan.tensor_name)

Verify multi-cycle test coverage for row-parallel (axis-1, non-contiguous) tensors across two or more consecutive refit cycles. Current validation appears limited to single-cycle scenarios.

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
if plan.tensor_name in plan_dests:
dest = plan_dests[plan.tensor_name]
else:
dest = torch.empty(plan.target_shape, dtype=dt, device=device)
plan_dests[plan.tensor_name] = dest
newly_allocated_this_cycle += 1
axis = 1 if plan.assembly == "concat_dim1" else 0
routed_v1 = True
for src in plan.sources:
target_lo, target_hi = src.target_local_range
dest_view = dest.narrow(axis, target_lo, target_hi - target_lo)
if not dest_view.is_contiguous():
routed_v1 = False
break
v1_batches[src.mx_source_id].append(
(plan.tensor_name, src.source_subslice, dest_view)
)
if not routed_v1:
# Don't drop cached entries — they may be valid for
# other plans; just route this plan to v0.
if cached_plan_dests is None:
plan_dests.pop(plan.tensor_name, None)
for sid in v1_batches:
v1_batches[sid] = [
r for r in v1_batches[sid] if r[0] != plan.tensor_name
]
v0_plans.append(plan)
if plan.tensor_name in plan_dests:
dest = plan_dests[plan.tensor_name]
else:
dest = torch.empty(plan.target_shape, dtype=dt, device=device)
plan_dests[plan.tensor_name] = dest
newly_allocated_this_cycle += 1
newly_allocated_names.add(plan.tensor_name)
axis = 1 if plan.assembly == "concat_dim1" else 0
routed_v1 = True
for src in plan.sources:
target_lo, target_hi = src.target_local_range
dest_view = dest.narrow(axis, target_lo, target_hi - target_lo)
if not dest_view.is_contiguous():
routed_v1 = False
break
v1_batches[src.mx_source_id].append(
(plan.tensor_name, src.source_subslice, dest_view)
)
if not routed_v1:
# Drop only buffers allocated this cycle; pre-existing
# cached entries valid for other plans must stay.
if plan.tensor_name in newly_allocated_names:
plan_dests.pop(plan.tensor_name, None)
newly_allocated_this_cycle -= 1
newly_allocated_names.discard(plan.tensor_name)
for sid in v1_batches:
v1_batches[sid] = [
r for r in v1_batches[sid] if r[0] != plan.tensor_name
]
v0_plans.append(plan)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@components/src/dynamo/vllm/mx_refit/extension.py` around lines 615 - 641, The
bug is that the cleanup guard at line 635 only removes freshly allocated buffers
from plan_dests when cached_plan_dests is None (no warm cycle). On warm cycles,
freshly allocated row-parallel buffers that fail the contiguity check remain in
plan_dests and get consumed with stale data instead of being routed to v0. Track
which tensor names were newly allocated in the current cycle (you can use a set
or extend the existing newly_allocated_this_cycle counter to track names), then
modify the condition at line 635 to also check if plan.tensor_name is in the
current cycle's newly allocated set. This ensures freshly allocated
non-contiguous buffers are removed regardless of whether we're in a warm cycle,
allowing them to be properly handled by the v0 path via
assemble_into_destination.

@devin-ai-integration devin-ai-integration Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Devin Review found 7 potential issues.

Open in Devin Review

Comment on lines +524 to +527
full_shape[spec.shard_axis] = (
axis_extent if layout.tp_rank == target_tp - 1
else per_rank
)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🔴 Last TP rank buffer gets full global axis extent instead of per-rank shard size

In matched-TP buffer allocation (line 524-527), the last TP rank's buffer dimension is set to axis_extent (the entire global dimension) instead of the correct per-rank remainder axis_extent - per_rank * (target_tp - 1). For TP=2, spec.target_shape is the global shape (per_rank × source_tp at extension.py:457-459). The code computes per_rank = axis_extent // target_tp, which is correct for non-last ranks. But for layout.tp_rank == target_tp - 1, it assigns the full axis_extent, which is target_tp times larger than needed.

Concrete example for TP=2

Source per-rank weight shape = [1024, 4096] on shard_axis=0.
Global shape (spec.target_shape) = [2048, 4096].
axis_extent = 2048, per_rank = 1024.

  • Rank 0: buffer = [1024, 4096] ✓
  • Rank 1 (last): buffer = [2048, 4096] ✗ — should be [1024, 4096]

Rank 1 allocates a buffer 2× the needed size. The RDMA receive_from at line 550 fills only the first 1024 rows from the source's per-rank shard, leaving 1024 rows of uninitialized data. When run_refit_cycle at line 565 processes pre_assembled_buffers=buffers, it may interpret the full buffer shape as valid data, producing incorrect HF weight tensors with garbage.

Suggested change
full_shape[spec.shard_axis] = (
axis_extent if layout.tp_rank == target_tp - 1
else per_rank
)
full_shape[spec.shard_axis] = (
axis_extent - per_rank * (target_tp - 1)
if layout.tp_rank == target_tp - 1
else per_rank
)
Open in Devin Review

Was this helpful? React with 👍 or 👎 to provide feedback.

* synthetic TP=2 → TP=1 mixed-TP target-narrower: 8 / 8
* synthetic TP=1 → TP=2 mixed-TP target-wider (v1 sliced-pull): 16 / 16
"""
import time as _time

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 Stdlib imports (time, threading) placed inside method bodies instead of at module top

import time as _time at line 404 and import threading at line 777 are standard-library imports placed inside method bodies. The .ai/python-guidelines.md critical rule states: "Always flag any import statement that appears inside a function body, method, or class. Imports inside functions hide dependencies, make the module harder to understand at a glance, and can mask missing packages until a specific code path is hit at runtime." Other files in the same package (e.g. handlers.py:15-16, worker_factory.py:10) correctly import threading and time at the top of the file.

Prompt for agents
Move `import time as _time` (line 404) and `import threading` (line 777) to the top of the file alongside the other stdlib imports (gc, logging, os, traceback) at lines 27-30. These are standard library modules with no reason for lazy loading. The .ai/python-guidelines.md critical rule requires all imports at module level.
Open in Devin Review

Was this helpful? React with 👍 or 👎 to provide feedback.

if not getattr(self, "_mx_receiver", None):
# Import here so workers that never refit via MX don't pay
# the modelexpress import cost.
from modelexpress import MxV2RefitReceiver

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 modelexpress and vllm sub-module imports inside method bodies instead of using try/except ImportError at module top

Multiple modelexpress imports (lines 95, 237, 405-410, 790) and a vllm sub-module import (lines 177-179) are placed inside method bodies. The .ai/python-guidelines.md critical rule states these must be flagged. For optional dependencies like modelexpress, the approved exception pattern is try/except ImportError at the top of the file (setting fallbacks to None), not lazy imports inside functions. The rule explicitly states: "Using try/except ImportError is the correct pattern when: An optional backend dependency... may not be installed, and the code provides a fallback or sets the import to None."

Prompt for agents
Move all modelexpress and vllm sub-module imports to the top of the file using the approved try/except ImportError pattern for optional dependencies. For example:

try:
    from modelexpress import MxV2RefitReceiver
    from modelexpress.ucx_utils import apply_nic_pin_for_device
    from modelexpress.megatron_translator import (
        MegatronReceiverContext, ReceiveSpec,
        assemble_into_destination, discover_megatron_context,
        run_refit_cycle, translate_megatron_to_hf,
    )
    from modelexpress.nemo_rl_v2 import MegatronTensorSpec, TargetTpLayout
except ImportError:
    MxV2RefitReceiver = None
    # other fallbacks...

Similarly move the vllm sub-module import:
from vllm.model_executor.model_loader.utils import process_weights_after_loading

This affects lines 95, 177-179, 237, 405-410, and 790 in extension.py.
Open in Devin Review

Was this helpful? React with 👍 or 👎 to provide feedback.

Comment on lines +351 to +360
except Exception as exc: # noqa: BLE001
logger.error(
"[mx] update_weights_via_mx failed on rank=%d: %s\n%s",
getattr(
getattr(self, "_mx_receiver", None), "worker_rank", -1
),
exc,
traceback.format_exc(),
)
return False

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 Multiple except Exception blocks catch-and-suppress without re-raising

The .ai/python-guidelines.md states: "If you must catch broadly, use except Exception: (never bare except:) and always re-raise after logging." The main refit handler at lines 351-360 catches except Exception, logs the error, and returns False without re-raising. Similarly, lines 104-105, 341-346, 740-743, 819-824, and 855-856 all catch Exception and suppress it. While the RPC handler design deliberately returns failure status instead of propagating exceptions, this violates the guidelines' explicit "always re-raise" rule.

Open in Devin Review

Was this helpful? React with 👍 or 👎 to provide feedback.

Comment on lines +430 to +434
target_tp_rank = (
torch.distributed.get_rank()
if torch.distributed.is_initialized() else 0
)
layout = TargetTpLayout(tp_size=target_tp, tp_rank=target_tp_rank)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🚩 target_tp_rank uses global rank — correct for TP-only but wrong for TP+PP

At extension.py:430-433, target_tp_rank is set via torch.distributed.get_rank() which returns the rank in the default (world) process group. For pure TP deployments (the common vLLM case), global_rank == tp_rank, so this works. However, for TP+PP configurations (e.g. TP=2, PP=2 → world_size=4), rank 2 would have global_rank=2 but tp_rank=0. This would cause the matched-TP check at extension.py:486 to fail (no source has tp_rank=2 when tp_size=2), incorrectly falling through to the mixed-TP path. Since vLLM does support PP, this could be a latent issue if PP>1 is ever used with this refit extension. Not flagged as a bug because PP is uncommon for inference and the code explicitly targets TP-only workflows, but worth noting for future generalization.

Open in Devin Review

Was this helpful? React with 👍 or 👎 to provide feedback.

Comment on lines +842 to +844
ok = self.update_weights_via_mx(
version=latest, mx_config=cfg
)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🚩 Poller thread calls update_weights_via_mx outside collective_rpc — potential concurrency with inference

The poller thread at extension.py:842 calls self.update_weights_via_mx(version=latest, mx_config=cfg) directly from a background thread. Unlike the explicit RPC path (which goes through collective_rpc which pauses generation), the poller does not coordinate with the inference loop. This means model.load_weights() at extension.py:156 could execute concurrently with model.forward() in the main worker thread. While PyTorch's default CUDA stream is shared across threads (serializing GPU ops), CPU-side tensor metadata modifications during load_weights are not protected. In practice, CUDA stream serialization and the GIL may prevent observable corruption, but this is architecturally risky. The poller docstring doesn't mention this limitation.

Open in Devin Review

Was this helpful? React with 👍 or 👎 to provide feedback.

SHARD_AXIS_BY_ROLE = {
"column": 0, "qkv_column": 0, "gated_mlp_column": 0,
"vocab_parallel": 0, "row": 1,
"expert_column": 0, "expert_row": 0, "replicated": 0,

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🚩 expert_row shard_axis set to 0 while regular row uses 1 — intentional for MoE passthrough?

At extension.py:441, SHARD_AXIS_BY_ROLE maps "expert_row": 0 while "row": 1. This seems inconsistent — regular row-parallel shards along axis 1. The shard_axis is used at line 457-459 to compute global_shape: global_shape[shard_axis] *= source_tp_size. For expert_row with axis=0, this multiplies the wrong dimension relative to regular row. However, lines 528-531 treat expert tensors as passthrough (pass) and lines 611-612 route per_expert plans to v0 scratch. The shard_axis for experts may represent the expert-grouping dimension rather than the TP dimension, making axis=0 correct in context. Without access to the Megatron-Core MoE conventions, I couldn't definitively determine whether this is correct.

Open in Devin Review

Was this helpful? React with 👍 or 👎 to provide feedback.

Adds a new `dynamo.vllm.mx_refit` module that ports NeMo-RL's
`VllmInternalWorkerExtension.update_weights_via_mx` into Dynamo's
vLLM worker so external trainers can publish per-rank tensor shards
via NIXL RDMA into a running AsyncLLM mid-training without going
through HF-format round-tripping.

Two paths are supported:

  1. DTensor / v1 path: trainer publishes torch.distributed DTensors;
     receiver does a full-tensor pull (`.full_tensor()`-equivalent over
     NIXL) and lets vLLM's standard load_weights re-shard.

  2. Megatron-Core / v2 path: trainer publishes per-rank Megatron-native
     shards with a sidecar describing the model's transformer config and
     a name_map. Receiver builds a slice plan against its own vLLM TP/EP
     layout, pulls each rank's contribution directly into a pre-allocated
     global tensor, then applies role-aware translation (QKV
     un-interleave, gated-MLP split, name remap) via the vendored Bridge
     helpers shipped with the modelexpress Python client. Supports both
     matched-TP (source_tp == target_tp) and mixed-TP (source_tp != target_tp,
     with v1 sliced-pull for contiguous narrows + v0 host-scratch+slice
     for strided / per-expert dict assembly).

The Megatron client-side machinery (slice planner, translator, vendored
Bridge helpers, v1 NIXL sliced-pull primitive) lives in
ai-dynamo/modelexpress PR ai-dynamo#429 (merged). The trainer-side publisher
lives in NeMo-RL PR ai-dynamo#2 (merged into the dynamo-k8s-integration branch).

This file has been carried as a private kubectl-overlay against
ai-dynamo's container image and validated end-to-end on GB200 against:

  - Qwen3-4B-Thinking-2507: 398/398 HF tensors byte-identical
    (matched-TP).
  - Qwen3-MoE-30B-A3B: 18,867/18,867 HF tensors byte-identical
    (matched-TP, grouped-MoE).
  - Mixed-TP target-wider (trainer TP=1 -> vLLM TP=2): 16/16
    byte-identical via v1 NIXL sliced-pull.

No public Dynamo APIs change; the extension is opt-in via the standard
`mx_refit.MxRefitWorkerExtension` registration on the vLLM worker.

A separate follow-up PR adds buffer-caching across refit cycles to
avoid re-allocating + re-registering NIXL buffers on every cycle.
…cycles

Audit of refit-cycle wall time on multi-receiver setups surfaced a
receiver-side bug: both the matched-TP and mixed-TP branches in
`_update_weights_via_mx_megatron` re-allocated + re-registered per-rank
buffers with NIXL on every refit cycle, paying ~0.15 s of `ibv_reg_mr`
per cycle on Qwen3-4B (proportionally more on larger models).

Plan shapes (and buffer shapes more generally) are deterministic for a
fixed `(source_tp, target_tp)` layout, so cycle-N's allocations are
identical to cycle-1's. Cache the buffers / plan_dests dicts on
`self._mx_megatron_buffers` and `self._mx_megatron_plan_dests`
respectively, and skip the `register_tensors` call when nothing was
newly allocated this cycle.

Cluster validation on GB200 / Qwen3-4B-Thinking 2026-06-23, 3
back-to-back refit cycles via a standalone benchmark exercising the
same MxV2RefitReceiver + NIXL layer:

                 alloc  register   pull  translate  total
  Before fix:
    Cycle 1     0.032     0.152   0.209     0.016   0.409
    Cycle 2     0.026     0.152   0.203     0.012   0.392
    Cycle 3     0.001     0.024   0.210     0.011   0.246

  After fix:
    Cycle 1     0.028     0.085   0.206     0.014   0.333
    Cycle 2     0.000     0.000   0.204     0.010   0.215   (-45%)
    Cycle 3     0.000     0.000   0.204     0.011   0.215

Pull step unchanged (~205-210 ms for 8 GB at 308 Gbps single-NIC); the
fix touches setup overhead only. Larger models will see proportionally
larger savings since `ibv_reg_mr` scales with both buffer count and
pinned-memory size.

The matching fix in the NeMo-RL trainer-side `vllm_backend.py` lives
at jthomson04/RL#7.
@KavinKrishnan KavinKrishnan force-pushed the kavink/megatron-mx-perf-fix branch from 8dad80b to 4bd2ae9 Compare June 24, 2026 04:57
@KavinKrishnan KavinKrishnan temporarily deployed to external_collaborator June 24, 2026 04:57 — with GitHub Actions Inactive
KavinKrishnan added a commit to ai-dynamo/modelexpress that referenced this pull request Jun 26, 2026
…timizations

Lands the v2 client surface and the surrounding RL workstream needed
to unblock the downstream Megatron-MX and perf PRs.

## What this contributes

### v2 client surface (this PR's primary deliverable)

* `MxV2TrainingPublisher` + `MxV2RefitReceiver` (modelexpress.nemo_rl_v2)
  -- the fat-client surface for per-rank shard publish and
  receiver-side multi-source assembly.
* `MxWeightTransferEngine` (modelexpress.vllm_weight_transfer) -- an
  adapter implementing vLLM's upstream WeightTransferEngine ABC.
* `TensorDescriptorV2.extra_parameters` (map<string,string>) plus
  `SourceIdentity.revision` (string) -- the two proto extensions that
  carry all per-tensor + per-source RL metadata.

### Phase 3a -- compile_target + compile_metadata on TensorDescriptorV2

New per-tensor fields default to ``hf_raw`` / ``{}``; the wire encoder
omits them when default so existing payloads stay byte-identical.
New constants: ``COMPILE_TARGET_HF_RAW``, ``_VLLM_FUSED``,
``_DEEPGEMM_FP8``, ``_CUTLASS_FP8``, ``_TRTLLM``. New helper
``compile_target_matches(descriptor, *, allowed_targets,
required_metadata=None)`` for receiver-side filtering with whitelist
plus required-metadata-subset semantics.

### Phase 3b -- compile_target_filter on discover_v2_sources

New kwargs ``compile_target_filter`` (whitelist set) and
``required_compile_metadata`` (subset-of-every-tensor's
compile_metadata). Candidates with no v2 registry are rejected when
either filter is set; candidates with mixed compile targets are
rejected if any tensor falls outside the allowed set.
``V2SourceCandidate.compile_targets: frozenset[str]`` exposed for
caller introspection.

### Phase 4 -- Multi-source slice discovery for mixed trainer/inference TP

New types: ``TargetTPLayout``, ``SliceSource``, ``SliceCoveragePlan``.
New methods ``MxV2RefitReceiver.discover_v2_sources_for_slice`` and
``MxV2RefitReceiver.receive_via_plan`` -- planner walks v2 candidates
per tensor, intersects each publisher's local_shard_range against the
receiver's requested slice, emits the minimal candidate set covering
it; surfaces coverage gaps and shard_axis mismatches in plan.missing.
``receive_via_plan`` orchestrates per-candidate scratch RDMA pulls
and stitches via torch.cat along the shard axis.

### Proto extensions

* ``TensorDescriptorV2.extra_parameters: map<string, string>`` -- the
  escape hatch downstream RL clients use for per-tensor metadata
  (megatron_role, compile_target, expert_id, revision, training_step,
  ...). Heavily used by #429.
* ``SourceIdentity.revision: string`` -- content-addressed weight
  version. Non-empty value guarantees two sources with identical
  SourceIdentity have bit-identical weight bytes, enabling
  decentralized modes (no central coordinator) to use mx_source_id
  as a full content check.
* Redis backend ``SourceAttributesJson`` carries the new fields plus
  the union of main's compatibility additions
  (backend_framework_version, torch_version, cuda_version,
  triton_version, gpu_arch, compile_config_digest).

## Tests

68/68 unit tests green on this branch's surface (test_v2_source_picker,
test_v2_shape_registry, test_source_id, test_types, test_vllm_adapter).

## Compatibility

Every new field has a backwards-compatible default and every new method
arg is optional. Existing demos and downstream consumers run unchanged.

## Downstream impact

This PR is the prerequisite for several PRs in flight:

* #421 -- publish_self_as_source tree fan-out fix
* #429 -- Megatron-Core MX clients
* #450 -- MX_RDMA_NIC_PIN=stripe multi-NIC mode
* ai-dynamo/dynamo#10900 -- first-time upstream port of
  dynamo.vllm.mx_refit extension
* ai-dynamo/dynamo#10901 -- buffer-caching perf fix in Dynamo extension
* jthomson04/RL#2 (merged) + #7 -- NeMo RL Megatron-MX integration plus
  perf fix

Once this PR lands, kavink/nemo_rl_moe (the integration branch holding

Signed-off-by: Kavin Krishnan <kavink@nvidia.com>
#421 and #429) can be retired by rebasing those PRs onto main.
KavinKrishnan added a commit to ai-dynamo/modelexpress that referenced this pull request Jun 26, 2026
…timizations

Lands the v2 client surface and the surrounding RL workstream needed
to unblock the downstream Megatron-MX and perf PRs.

## What this contributes

### v2 client surface (this PR's primary deliverable)

* `MxV2TrainingPublisher` + `MxV2RefitReceiver` (modelexpress.nemo_rl_v2)
  -- the fat-client surface for per-rank shard publish and
  receiver-side multi-source assembly.
* `MxWeightTransferEngine` (modelexpress.vllm_weight_transfer) -- an
  adapter implementing vLLM's upstream WeightTransferEngine ABC.
* `TensorDescriptorV2.extra_parameters` (map<string,string>) plus
  `SourceIdentity.revision` (string) -- the two proto extensions that
  carry all per-tensor + per-source RL metadata.

### Phase 3a -- compile_target + compile_metadata on TensorDescriptorV2

New per-tensor fields default to ``hf_raw`` / ``{}``; the wire encoder
omits them when default so existing payloads stay byte-identical.
New constants: ``COMPILE_TARGET_HF_RAW``, ``_VLLM_FUSED``,
``_DEEPGEMM_FP8``, ``_CUTLASS_FP8``, ``_TRTLLM``. New helper
``compile_target_matches(descriptor, *, allowed_targets,
required_metadata=None)`` for receiver-side filtering with whitelist
plus required-metadata-subset semantics.

### Phase 3b -- compile_target_filter on discover_v2_sources

New kwargs ``compile_target_filter`` (whitelist set) and
``required_compile_metadata`` (subset-of-every-tensor's
compile_metadata). Candidates with no v2 registry are rejected when
either filter is set; candidates with mixed compile targets are
rejected if any tensor falls outside the allowed set.
``V2SourceCandidate.compile_targets: frozenset[str]`` exposed for
caller introspection.

### Phase 4 -- Multi-source slice discovery for mixed trainer/inference TP

New types: ``TargetTPLayout``, ``SliceSource``, ``SliceCoveragePlan``.
New methods ``MxV2RefitReceiver.discover_v2_sources_for_slice`` and
``MxV2RefitReceiver.receive_via_plan`` -- planner walks v2 candidates
per tensor, intersects each publisher's local_shard_range against the
receiver's requested slice, emits the minimal candidate set covering
it; surfaces coverage gaps and shard_axis mismatches in plan.missing.
``receive_via_plan`` orchestrates per-candidate scratch RDMA pulls
and stitches via torch.cat along the shard axis.

### Proto extensions

* ``TensorDescriptorV2.extra_parameters: map<string, string>`` -- the
  escape hatch downstream RL clients use for per-tensor metadata
  (megatron_role, compile_target, expert_id, revision, training_step,
  ...). Heavily used by #429.
* ``SourceIdentity.revision: string`` -- content-addressed weight
  version. Non-empty value guarantees two sources with identical
  SourceIdentity have bit-identical weight bytes, enabling
  decentralized modes (no central coordinator) to use mx_source_id
  as a full content check.
* Redis backend ``SourceAttributesJson`` carries the new fields plus
  the union of main's compatibility additions
  (backend_framework_version, torch_version, cuda_version,
  triton_version, gpu_arch, compile_config_digest).

## Tests

68/68 unit tests green on this branch's surface (test_v2_source_picker,
test_v2_shape_registry, test_source_id, test_types, test_vllm_adapter).

## Compatibility

Every new field has a backwards-compatible default and every new method
arg is optional. Existing demos and downstream consumers run unchanged.

## Downstream impact

This PR is the prerequisite for several PRs in flight:

* #421 -- publish_self_as_source tree fan-out fix
* #429 -- Megatron-Core MX clients
* #450 -- MX_RDMA_NIC_PIN=stripe multi-NIC mode
* ai-dynamo/dynamo#10900 -- first-time upstream port of
  dynamo.vllm.mx_refit extension
* ai-dynamo/dynamo#10901 -- buffer-caching perf fix in Dynamo extension
* jthomson04/RL#2 (merged) + #7 -- NeMo RL Megatron-MX integration plus
  perf fix

Once this PR lands, kavink/nemo_rl_moe (the integration branch holding

Signed-off-by: Kavin Krishnan <kavink@nvidia.com>
#421 and #429) can be retired by rebasing those PRs onto main.
KavinKrishnan added a commit to ai-dynamo/modelexpress that referenced this pull request Jun 26, 2026
…timizations

Lands the v2 client surface and the surrounding RL workstream needed
to unblock the downstream Megatron-MX and perf PRs.

## What this contributes

### v2 client surface (this PR's primary deliverable)

* `MxV2TrainingPublisher` + `MxV2RefitReceiver` (modelexpress.nemo_rl_v2)
  -- the fat-client surface for per-rank shard publish and
  receiver-side multi-source assembly.
* `MxWeightTransferEngine` (modelexpress.vllm_weight_transfer) -- an
  adapter implementing vLLM's upstream WeightTransferEngine ABC.
* `TensorDescriptorV2.extra_parameters` (map<string,string>) plus
  `SourceIdentity.revision` (string) -- the two proto extensions that
  carry all per-tensor + per-source RL metadata.

### Phase 3a -- compile_target + compile_metadata on TensorDescriptorV2

New per-tensor fields default to ``hf_raw`` / ``{}``; the wire encoder
omits them when default so existing payloads stay byte-identical.
New constants: ``COMPILE_TARGET_HF_RAW``, ``_VLLM_FUSED``,
``_DEEPGEMM_FP8``, ``_CUTLASS_FP8``, ``_TRTLLM``. New helper
``compile_target_matches(descriptor, *, allowed_targets,
required_metadata=None)`` for receiver-side filtering with whitelist
plus required-metadata-subset semantics.

### Phase 3b -- compile_target_filter on discover_v2_sources

New kwargs ``compile_target_filter`` (whitelist set) and
``required_compile_metadata`` (subset-of-every-tensor's
compile_metadata). Candidates with no v2 registry are rejected when
either filter is set; candidates with mixed compile targets are
rejected if any tensor falls outside the allowed set.
``V2SourceCandidate.compile_targets: frozenset[str]`` exposed for
caller introspection.

### Phase 4 -- Multi-source slice discovery for mixed trainer/inference TP

New types: ``TargetTPLayout``, ``SliceSource``, ``SliceCoveragePlan``.
New methods ``MxV2RefitReceiver.discover_v2_sources_for_slice`` and
``MxV2RefitReceiver.receive_via_plan`` -- planner walks v2 candidates
per tensor, intersects each publisher's local_shard_range against the
receiver's requested slice, emits the minimal candidate set covering
it; surfaces coverage gaps and shard_axis mismatches in plan.missing.
``receive_via_plan`` orchestrates per-candidate scratch RDMA pulls
and stitches via torch.cat along the shard axis.

### Proto extensions

* ``TensorDescriptorV2.extra_parameters: map<string, string>`` -- the
  escape hatch downstream RL clients use for per-tensor metadata
  (megatron_role, compile_target, expert_id, revision, training_step,
  ...). Heavily used by #429.
* ``SourceIdentity.revision: string`` -- content-addressed weight
  version. Non-empty value guarantees two sources with identical
  SourceIdentity have bit-identical weight bytes, enabling
  decentralized modes (no central coordinator) to use mx_source_id
  as a full content check.
* Redis backend ``SourceAttributesJson`` carries the new fields plus
  the union of main's compatibility additions
  (backend_framework_version, torch_version, cuda_version,
  triton_version, gpu_arch, compile_config_digest).

## Tests

68/68 unit tests green on this branch's surface (test_v2_source_picker,
test_v2_shape_registry, test_source_id, test_types, test_vllm_adapter).

## Compatibility

Every new field has a backwards-compatible default and every new method
arg is optional. Existing demos and downstream consumers run unchanged.

## Downstream impact

This PR is the prerequisite for several PRs in flight:

* #421 -- publish_self_as_source tree fan-out fix
* #429 -- Megatron-Core MX clients
* #450 -- MX_RDMA_NIC_PIN=stripe multi-NIC mode
* ai-dynamo/dynamo#10900 -- first-time upstream port of
  dynamo.vllm.mx_refit extension
* ai-dynamo/dynamo#10901 -- buffer-caching perf fix in Dynamo extension
* jthomson04/RL#2 (merged) + #7 -- NeMo RL Megatron-MX integration plus
  perf fix

Once this PR lands, kavink/nemo_rl_moe (the integration branch holding

Signed-off-by: Kavin Krishnan <kavink@nvidia.com>
#421 and #429) can be retired by rebasing those PRs onto main.
KavinKrishnan added a commit to ai-dynamo/modelexpress that referenced this pull request Jun 26, 2026
…timizations

Lands the v2 client surface and the surrounding RL workstream needed
to unblock the downstream Megatron-MX and perf PRs.

## What this contributes

### v2 client surface (this PR's primary deliverable)

* `MxV2TrainingPublisher` + `MxV2RefitReceiver` (modelexpress.nemo_rl_v2)
  -- the fat-client surface for per-rank shard publish and
  receiver-side multi-source assembly.
* `MxWeightTransferEngine` (modelexpress.vllm_weight_transfer) -- an
  adapter implementing vLLM's upstream WeightTransferEngine ABC.
* `TensorDescriptorV2.extra_parameters` (map<string,string>) plus
  `SourceIdentity.revision` (string) -- the two proto extensions that
  carry all per-tensor + per-source RL metadata.

### Phase 3a -- compile_target + compile_metadata on TensorDescriptorV2

New per-tensor fields default to ``hf_raw`` / ``{}``; the wire encoder
omits them when default so existing payloads stay byte-identical.
New constants: ``COMPILE_TARGET_HF_RAW``, ``_VLLM_FUSED``,
``_DEEPGEMM_FP8``, ``_CUTLASS_FP8``, ``_TRTLLM``. New helper
``compile_target_matches(descriptor, *, allowed_targets,
required_metadata=None)`` for receiver-side filtering with whitelist
plus required-metadata-subset semantics.

### Phase 3b -- compile_target_filter on discover_v2_sources

New kwargs ``compile_target_filter`` (whitelist set) and
``required_compile_metadata`` (subset-of-every-tensor's
compile_metadata). Candidates with no v2 registry are rejected when
either filter is set; candidates with mixed compile targets are
rejected if any tensor falls outside the allowed set.
``V2SourceCandidate.compile_targets: frozenset[str]`` exposed for
caller introspection.

### Phase 4 -- Multi-source slice discovery for mixed trainer/inference TP

New types: ``TargetTPLayout``, ``SliceSource``, ``SliceCoveragePlan``.
New methods ``MxV2RefitReceiver.discover_v2_sources_for_slice`` and
``MxV2RefitReceiver.receive_via_plan`` -- planner walks v2 candidates
per tensor, intersects each publisher's local_shard_range against the
receiver's requested slice, emits the minimal candidate set covering
it; surfaces coverage gaps and shard_axis mismatches in plan.missing.
``receive_via_plan`` orchestrates per-candidate scratch RDMA pulls
and stitches via torch.cat along the shard axis.

### Proto extensions

* ``TensorDescriptorV2.extra_parameters: map<string, string>`` -- the
  escape hatch downstream RL clients use for per-tensor metadata
  (megatron_role, compile_target, expert_id, revision, training_step,
  ...). Heavily used by #429.
* ``SourceIdentity.revision: string`` -- content-addressed weight
  version. Non-empty value guarantees two sources with identical
  SourceIdentity have bit-identical weight bytes, enabling
  decentralized modes (no central coordinator) to use mx_source_id
  as a full content check.
* Redis backend ``SourceAttributesJson`` carries the new fields plus
  the union of main's compatibility additions
  (backend_framework_version, torch_version, cuda_version,
  triton_version, gpu_arch, compile_config_digest).

## Tests

68/68 unit tests green on this branch's surface (test_v2_source_picker,
test_v2_shape_registry, test_source_id, test_types, test_vllm_adapter).

## Compatibility

Every new field has a backwards-compatible default and every new method
arg is optional. Existing demos and downstream consumers run unchanged.

## Downstream impact

This PR is the prerequisite for several PRs in flight:

* #421 -- publish_self_as_source tree fan-out fix
* #429 -- Megatron-Core MX clients
* #450 -- MX_RDMA_NIC_PIN=stripe multi-NIC mode
* ai-dynamo/dynamo#10900 -- first-time upstream port of
  dynamo.vllm.mx_refit extension
* ai-dynamo/dynamo#10901 -- buffer-caching perf fix in Dynamo extension
* jthomson04/RL#2 (merged) + #7 -- NeMo RL Megatron-MX integration plus
  perf fix

Once this PR lands, kavink/nemo_rl_moe (the integration branch holding

Signed-off-by: Kavin Krishnan <kavink@nvidia.com>
#421 and #429) can be retired by rebasing those PRs onto main.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

backend::vllm Relates to the vllm backend external-contribution Pull request is from an external contributor perf size/XL

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant