Skip to content

feat(ucx_utils): add MX_RDMA_NIC_PIN=stripe for multi-NIC parallelism#450

Open
KavinKrishnan wants to merge 1 commit into
mainfrom
kavink/multi-nic-stripe-mode
Open

feat(ucx_utils): add MX_RDMA_NIC_PIN=stripe for multi-NIC parallelism#450
KavinKrishnan wants to merge 1 commit into
mainfrom
kavink/multi-nic-stripe-mode

Conversation

@KavinKrishnan

@KavinKrishnan KavinKrishnan commented Jun 23, 2026

Copy link
Copy Markdown
Contributor

Summary

Adds MX_RDMA_NIC_PIN=stripe mode to modelexpress.ucx_utils.apply_nic_pin_for_device. Lists all visible compute-rate IB NICs in UCX_NET_DEVICES (comma-joined) and bumps UCX_MAX_RMA_RAILS to the NIC count, so UCX actually uses all of them for parallel RDMA pulls.

Closes #449.

Motivation

MX/NIXL's per-receiver RDMA pull is single-NIC-pinned by default. On multi-NIC pods (4-NIC GB200, 8-NIC GB300, 32-NIC AWS p5) this leaves 75-87.5% of allocated NIC bandwidth idle. Single-receiver bulk-pull on a 4-NIC GB200 pod measures at ~316 Gbps — saturating one NIC, but well below the ~1.2 Tbps the four NICs are capable of.

Striping the pull across all allocated NICs lifts per-receiver bandwidth proportionally to the number of NICs UCX can fan out across, which is the dominant lever for refit-cycle wall time on many-receiver setups.

Change

One file changed in production code (modelexpress/ucx_utils.py): adds _stripe_all_compute_nics, extends _resolve_nic_pin to handle stripe / all aliases, and extends apply_nic_pin_for_device to bump UCX_MAX_RMA_RAILS to the NIC count when striping.

Existing modes (auto, explicit list, off/unset) are unchanged. Stripe is opt-in.

Note on MAX_RMA_RAILS: UCX defaults this to 2, AND NIXL hardcodes it to 2 in nixl/src/plugins/ucx/ucx_utils.cpp:422. The NIXL hardcode overrides our env-set value unless we also patch NIXL — tracked in #449 as Layer 3. For now this PR is the deployment-side fix; users opting into stripe mode should also ensure they're on a NIXL build that respects the env-var override (or apply the future NIXL patch alongside).

Tests

9 new unit tests in tests/test_nic_pin_stripe.py:

Test Asserts
test_stripe_lists_all_compute_nics 4 NICs → mlx5_0:1,mlx5_1:1,mlx5_2:1,mlx5_3:1
test_stripe_bumps_max_rma_rails_to_nic_count UCX_MAX_RMA_RAILS=4 set on 4-NIC pod
test_stripe_respects_explicit_max_rma_rails User-set rails value is not overwritten
test_stripe_single_nic_skips_rails_bump 1-NIC pod degenerates cleanly
test_stripe_no_nics_visible NIC-less host is a no-op
test_stripe_alias_all MX_RDMA_NIC_PIN=all works as alias
test_auto_mode_does_not_set_rails Existing auto mode unchanged
test_min_rate_filter_passed_to_stripe Rate filter env var forwarded
test_off_mode_unchanged Sanity check for off mode

All 9 pass.

Cluster validation

Validated on a 4-NIC GB200 pod with the existing multi-cycle Qwen3-4B-Thinking refit benchmark, with both the trainer (publisher) and receiver opted into MX_RDMA_NIC_PIN=stripe. Real Megatron-Bridge-loaded model, 290 source tensors, 8.04 GB total, 3 back-to-back refit cycles per run.

Mode (both sides) Cycle 1 wall Warm pull (cycle 2/3) Warm wall Bandwidth
auto (current default — 1 NIC) 0.417 s 0.203 s 0.215 s 316 Gbps
stripe (this PR — 4 NICs) 0.543 s 0.103 s 0.114 s 622 Gbps
Δ +30% (cold) -49% -47% +97%

Cycle-1 wall regresses ~30% because there are 4× more NICs to register against; cycles 2+ recover that and net out ahead because the cached registration carries forward (paired with #421 / #429 / ai-dynamo/dynamo#10901 buffer caching).

Why only ~2× scaling on a 4-NIC pod

Bumping UCX_MAX_RMA_RAILS from 4 to 8 didn't help — bandwidth plateaus at ~625 Gbps. Likely causes:

  1. NIXL hardcodes MAX_RMA_RAILS=2 in nixl/src/plugins/ucx/ucx_utils.cpp:422 via config.modify(...). Per UCX semantics that variant is "modify if not already set via env var", so our env should win — but the empirical bandwidth ceiling is suspiciously close to 2-rail behavior. Worth instrumenting with UCX_LOG_LEVEL=debug to confirm which rail count UCX actually picked.
  2. NUMA crossings: NICs 0/1 are NUMA 0 and NICs 2/3 are NUMA 1 on this pod. If the GPU and scratch buffer are NUMA-0-affine, traffic over NICs 2/3 crosses NUMA, halving effective bandwidth.
  3. UCX's internal striping uses round-robin across rails — for many small tensors (290 buffers, avg 28 MB each), the protocol overhead may cap us before we hit the bandwidth ceiling.

#449's Layer 3 (expose max_rma_rails as a NIXL backend param) targets cause 1.

Test plan

  • Unit tests (9 new, all pass)
  • Cluster validation on 4-NIC GB200 pod — A/B confirmed, 47% warm-cycle wall reduction, ~2× pull bandwidth
  • Update example manifests in examples_dgd/k8s_exemplars/V*/ to use stripe mode (separate small PR after this lands)
  • NIXL Layer 3 — instrument why we plateau at ~2× scaling, expose max_rma_rails as a backend param

Related

@copy-pr-bot

copy-pr-bot Bot commented Jun 23, 2026

Copy link
Copy Markdown

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@coderabbitai

coderabbitai Bot commented Jun 23, 2026

Copy link
Copy Markdown
Contributor

Review Change Stack

Walkthrough

Adds stripe/all as a new MX_RDMA_NIC_PIN mode. A new _stripe_all_compute_nics() helper builds a comma-separated UCX_NET_DEVICES from all eligible compute IB NICs. _resolve_nic_pin() and apply_nic_pin_for_device() are extended to handle this mode and automatically set UCX_MAX_RMA_RAILS to the NIC count. A new test module validates all stripe behaviors.

Changes

stripe/all multi-NIC UCX striping support

Layer / File(s) Summary
_stripe_all_compute_nics helper and _resolve_nic_pin stripe handling
modelexpress_client/python/modelexpress/ucx_utils.py
Adds _stripe_all_compute_nics() returning a comma-separated UCX_NET_DEVICES string from all compute IB NICs filtered by MX_RDMA_NIC_PIN_MIN_RATE_GBPS; extends _resolve_nic_pin() to parse the min-rate env var, dispatch to the new helper in stripe/all mode, and emit a warning when no matching NICs are found.
apply_nic_pin_for_device UCX_MAX_RMA_RAILS extension
modelexpress_client/python/modelexpress/ucx_utils.py
In stripe/all mode, after UCX_NET_DEVICES is resolved, sets UCX_MAX_RMA_RAILS to the number of NIC entries when UCX_MAX_RMA_RAILS is not already present in the environment.
stripe mode test suite
modelexpress_client/python/tests/test_nic_pin_stripe.py
New test module with autouse env-clearing fixture, a _fake_nics helper, and tests covering multi-NIC UCX_NET_DEVICES generation, UCX_MAX_RMA_RAILS auto-bump and preservation, single-NIC degenerate case, no-NIC no-op, all alias, auto mode, min-rate forwarding, and off mode sanity check.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

Poem

🐇 Hop across the rails, one NIC won't do,
Four stripes of bandwidth, shining brand new!
UCX_NET_DEVICES lists them all in a row,
UCX_MAX_RMA_RAILS tells UCX: let it flow.
No NIC left idle, no bandwidth to waste,
Multi-rail RDMA — with cottontail haste! 🚀

🚥 Pre-merge checks | ✅ 5
✅ Passed checks (5 passed)
Check name Status Explanation
Title check ✅ Passed The title accurately and concisely summarizes the main change: adding MX_RDMA_NIC_PIN=stripe mode for multi-NIC parallel RDMA functionality.
Linked Issues check ✅ Passed The PR fully implements Layer 2 of issue #449, adding stripe mode with multi-NIC support, UCX_MAX_RMA_RAILS automation, and comprehensive test coverage as specified.
Out of Scope Changes check ✅ Passed All changes are scoped to Layer 2 implementation: ucx_utils.py enhancements and related test coverage. No unrelated modifications present.
Docstring Coverage ✅ Passed Docstring coverage is 93.75% which is sufficient. The required threshold is 80.00%.
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.


Comment @coderabbitai help to get the list of available commands.

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@modelexpress_client/python/modelexpress/ucx_utils.py`:
- Around line 409-420: The code parses MX_RDMA_NIC_PIN_MIN_RATE_GBPS as a float
but does not validate that the resulting value is finite and non-negative;
float() accepts nan, inf, and negative numbers which would break stripe
filtering semantics. After successfully converting raw_min to a float in the try
block, add validation to check if min_rate is finite (using math.isfinite()) and
greater than or equal to zero. If either check fails, log a warning message
indicating the invalid value and set min_rate to None to fall back to
auto-detect, similar to the existing ValueError handling.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 1143706b-0f88-4e8b-826e-cbdbe01185b1

📥 Commits

Reviewing files that changed from the base of the PR and between 485dfaf and 56e597c.

📒 Files selected for processing (2)
  • modelexpress_client/python/modelexpress/ucx_utils.py
  • modelexpress_client/python/tests/test_nic_pin_stripe.py

Comment on lines +409 to +420
raw_min = os.environ.get("MX_RDMA_NIC_PIN_MIN_RATE_GBPS")
if raw_min is None or raw_min.strip() == "":
min_rate: float | None = None
else:
try:
min_rate = float(raw_min)
except ValueError:
logger.warning(
f"MX_RDMA_NIC_PIN_MIN_RATE_GBPS={raw_min!r} not a float; "
f"falling back to max-rate auto-detect"
)
min_rate = None

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🎯 Functional Correctness | 🟡 Minor | ⚡ Quick win

Validate MX_RDMA_NIC_PIN_MIN_RATE_GBPS as finite and non-negative.

float() accepts values like nan, inf, and negatives, which can silently break stripe filtering semantics. Reject non-finite or < 0 values and fall back to auto-detect with a warning.

Suggested fix
     raw_min = os.environ.get("MX_RDMA_NIC_PIN_MIN_RATE_GBPS")
     if raw_min is None or raw_min.strip() == "":
         min_rate: float | None = None
     else:
         try:
             min_rate = float(raw_min)
+            if not math.isfinite(min_rate) or min_rate < 0:
+                raise ValueError("min rate must be finite and >= 0")
         except ValueError:
             logger.warning(
                 f"MX_RDMA_NIC_PIN_MIN_RATE_GBPS={raw_min!r} not a float; "
                 f"falling back to max-rate auto-detect"
             )
             min_rate = None
 import logging
+import math
 import os
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
raw_min = os.environ.get("MX_RDMA_NIC_PIN_MIN_RATE_GBPS")
if raw_min is None or raw_min.strip() == "":
min_rate: float | None = None
else:
try:
min_rate = float(raw_min)
except ValueError:
logger.warning(
f"MX_RDMA_NIC_PIN_MIN_RATE_GBPS={raw_min!r} not a float; "
f"falling back to max-rate auto-detect"
)
min_rate = None
raw_min = os.environ.get("MX_RDMA_NIC_PIN_MIN_RATE_GBPS")
if raw_min is None or raw_min.strip() == "":
min_rate: float | None = None
else:
try:
min_rate = float(raw_min)
if not math.isfinite(min_rate) or min_rate < 0:
raise ValueError("min rate must be finite and >= 0")
except ValueError:
logger.warning(
f"MX_RDMA_NIC_PIN_MIN_RATE_GBPS={raw_min!r} not a float; "
f"falling back to max-rate auto-detect"
)
min_rate = None
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@modelexpress_client/python/modelexpress/ucx_utils.py` around lines 409 - 420,
The code parses MX_RDMA_NIC_PIN_MIN_RATE_GBPS as a float but does not validate
that the resulting value is finite and non-negative; float() accepts nan, inf,
and negative numbers which would break stripe filtering semantics. After
successfully converting raw_min to a float in the try block, add validation to
check if min_rate is finite (using math.isfinite()) and greater than or equal to
zero. If either check fails, log a warning message indicating the invalid value
and set min_rate to None to fall back to auto-detect, similar to the existing
ValueError handling.

@KavinKrishnan

KavinKrishnan commented Jun 23, 2026

Copy link
Copy Markdown
Contributor Author

Cluster validation on 4-NIC GB200 — confirmed

Validated the stripe mode end-to-end against the existing multi-cycle Qwen3-4B-Thinking refit benchmark, with the trainer (publisher) and receiver both opted into MX_RDMA_NIC_PIN=stripe. Real Megatron-Bridge-loaded model, 290 source tensors, 8.04 GB total, 4 RDMA NICs available per pod.

Mode (both sides) Cycle 1 alloc+register Cycle 1 pull Cycle 1 wall Warm pull (cycle 2/3) Warm wall Bandwidth
auto (current default — 1 NIC) 0.043 + 0.151 s 0.209 s 0.417 s 0.203 s 0.215 s 316 Gbps
stripe (this PR — 4 NICs) 0.025 + 0.331 s 0.172 s 0.543 s 0.103 s 0.114 s 622 Gbps
Δ +50% reg (4 NICs) -18% +30% (cold) -49% -47% +97%

Trainer UCX log on publisher start confirms 4-NIC config picked up:

UCX  INFO  UCX_* env variables: ... UCX_NET_DEVICES=mlx5_0:1,mlx5_1:1,mlx5_2:1,mlx5_3:1 UCX_MAX_RMA_RAILS=4

Bandwidth caveat — only ~2× scaling, not 4×

Going from 1 NIC to 4 NICs got us 2× bandwidth (316 → 622 Gbps), not 4×. Bumping UCX_MAX_RMA_RAILS from 4 to 8 didn't help either — bandwidth plateaus at ~625 Gbps. Likely causes (in order of probability):

  1. NIXL hardcodes MAX_RMA_RAILS=2 in nixl/src/plugins/ucx/ucx_utils.cpp:422 via config.modify(...). Per UCX semantics that variant is "modify if not already set via env var", so our env should win — but the empirical bandwidth ceiling is suspiciously close to 2-rail behavior. Worth instrumenting with UCX_LOG_LEVEL=debug to confirm which rail count UCX actually picked.
  2. NUMA crossings: NICs 0/1 are NUMA 0 and NICs 2/3 are NUMA 1 on this pod. If the GPU + scratch buffer are NUMA-0-affine, traffic over NICs 2/3 crosses NUMA, halving effective bandwidth.
  3. UCX's internal striping uses round-robin across rails — for many small tensors (290 buffers, avg 28 MB each), the protocol overhead may cap us before we hit the bandwidth ceiling.

#449's Layer 3 (expose max_rma_rails as a NIXL backend param + bump default) plus the NUMA fix would target this. For now: 2× is the validated win on a 4-NIC pod, and on 8-NIC pods this should still scale meaningfully.

Extrapolation to larger multi-receiver setups

For a 16 GB model with ~16 concurrent receivers at the same ~50% bandwidth efficiency we saw here, paired with the buffer-caching fix in ai-dynamo/dynamo#10901:

Setup Per-receiver pull 16-receiver steady-state (4-wave dispatch)
Today (1 NIC + caching) 0.42 s 2.5-4.0 s
Stripe @ 2× (this PR, conservative) 0.21 s 1.0-2.0 s
Stripe @ 4× (with NIXL Layer 3 fix) 0.10 s 0.5-1.0 s

Stripe + caching together substantially reduce wall time on multi-receiver dispatches. Layer 3 (NIXL max_rma_rails param) would close most of the remaining ratio to NIC-saturating throughput.

Test plan checkbox update

  • Unit tests (9 new, all pass)
  • Cluster validation on 4-NIC GB200 — A/B confirmed, 47% warm-cycle wall reduction, 2× pull bandwidth
  • Update example manifests to opt into stripe (separate small PR after this lands)
  • NIXL Layer 3 — instrument why we plateau at 2× scaling, expose max_rma_rails param

Issue #449.

Single-NIC pinning leaves 75-87.5% of allocated NIC bandwidth idle on
multi-NIC pods (4-NIC GB200, 8-NIC GB300, 32-NIC AWS p5). Per-receiver
RDMA pull is capped at single-NIC bandwidth as a result, which is the
dominant factor in refit-cycle wall time on many-receiver setups.

This PR adds `MX_RDMA_NIC_PIN=stripe` (alias `all`):
  - Lists ALL compute-rate IB NICs visible to the pod in
    UCX_NET_DEVICES (comma-joined, :1 port suffix on each).
  - Bumps UCX_MAX_RMA_RAILS to the NIC count (only if not already set)
    so UCX actually uses all of them — without this, both UCX's
    default and NIXL's hard-set MAX_RMA_RAILS=2 cap us at 2 NICs.
  - Degenerates gracefully to single-NIC behaviour on 1-NIC pods.
  - Forwards MX_RDMA_NIC_PIN_MIN_RATE_GBPS to filter side-fabric NICs.

Existing modes (auto, explicit list, off) are unchanged. Stripe is
opt-in via the env var.

Cluster validation on a 4-NIC GB200 pod, single-receiver Qwen3-4B
refit benchmark, both publisher + receiver opted into stripe:

  Today (1 NIC, MX_RDMA_NIC_PIN=auto):  0.215 s warm cycle, 316 Gbps
  Stripe (4 NICs, both sides):          0.114 s warm cycle, 622 Gbps
                                        (-47% wall, +97% bandwidth)

The 2x scaling (vs 4x theoretical) is investigated in the PR
description; primary suspect is NIXL's hardcoded MAX_RMA_RAILS=2 in
src/plugins/ucx/ucx_utils.cpp:422 winning over the env-set value.
Tracked as Layer 3 in #449.

Doesn't address structural collective-broadcast wins (tree topology,
collective init amortization). For full-tensor broadcast in
tightly-coupled topologies those need separate work and may not be
worth chasing in MX's design space.

Tests: 9 new unit tests in tests/test_nic_pin_stripe.py covering the
4-NIC stripe path, MAX_RMA_RAILS bump logic, single-NIC degenerate
case, env-var precedence, alias handling, and rate filtering.
Existing nic-pin / ucx tests unaffected.
@KavinKrishnan KavinKrishnan force-pushed the kavink/multi-nic-stripe-mode branch from 56e597c to 2c6fd8a Compare June 24, 2026 04:56
KavinKrishnan added a commit that referenced this pull request Jun 26, 2026
…timizations

Lands the v2 client surface and the surrounding RL workstream needed
to unblock the downstream Megatron-MX and perf PRs.

## What this contributes

### v2 client surface (this PR's primary deliverable)

* `MxV2TrainingPublisher` + `MxV2RefitReceiver` (modelexpress.nemo_rl_v2)
  -- the fat-client surface for per-rank shard publish and
  receiver-side multi-source assembly.
* `MxWeightTransferEngine` (modelexpress.vllm_weight_transfer) -- an
  adapter implementing vLLM's upstream WeightTransferEngine ABC.
* `TensorDescriptorV2.extra_parameters` (map<string,string>) plus
  `SourceIdentity.revision` (string) -- the two proto extensions that
  carry all per-tensor + per-source RL metadata.

### Phase 3a -- compile_target + compile_metadata on TensorDescriptorV2

New per-tensor fields default to ``hf_raw`` / ``{}``; the wire encoder
omits them when default so existing payloads stay byte-identical.
New constants: ``COMPILE_TARGET_HF_RAW``, ``_VLLM_FUSED``,
``_DEEPGEMM_FP8``, ``_CUTLASS_FP8``, ``_TRTLLM``. New helper
``compile_target_matches(descriptor, *, allowed_targets,
required_metadata=None)`` for receiver-side filtering with whitelist
plus required-metadata-subset semantics.

### Phase 3b -- compile_target_filter on discover_v2_sources

New kwargs ``compile_target_filter`` (whitelist set) and
``required_compile_metadata`` (subset-of-every-tensor's
compile_metadata). Candidates with no v2 registry are rejected when
either filter is set; candidates with mixed compile targets are
rejected if any tensor falls outside the allowed set.
``V2SourceCandidate.compile_targets: frozenset[str]`` exposed for
caller introspection.

### Phase 4 -- Multi-source slice discovery for mixed trainer/inference TP

New types: ``TargetTPLayout``, ``SliceSource``, ``SliceCoveragePlan``.
New methods ``MxV2RefitReceiver.discover_v2_sources_for_slice`` and
``MxV2RefitReceiver.receive_via_plan`` -- planner walks v2 candidates
per tensor, intersects each publisher's local_shard_range against the
receiver's requested slice, emits the minimal candidate set covering
it; surfaces coverage gaps and shard_axis mismatches in plan.missing.
``receive_via_plan`` orchestrates per-candidate scratch RDMA pulls
and stitches via torch.cat along the shard axis.

### Proto extensions

* ``TensorDescriptorV2.extra_parameters: map<string, string>`` -- the
  escape hatch downstream RL clients use for per-tensor metadata
  (megatron_role, compile_target, expert_id, revision, training_step,
  ...). Heavily used by #429.
* ``SourceIdentity.revision: string`` -- content-addressed weight
  version. Non-empty value guarantees two sources with identical
  SourceIdentity have bit-identical weight bytes, enabling
  decentralized modes (no central coordinator) to use mx_source_id
  as a full content check.
* Redis backend ``SourceAttributesJson`` carries the new fields plus
  the union of main's compatibility additions
  (backend_framework_version, torch_version, cuda_version,
  triton_version, gpu_arch, compile_config_digest).

## Tests

68/68 unit tests green on this branch's surface (test_v2_source_picker,
test_v2_shape_registry, test_source_id, test_types, test_vllm_adapter).

## Compatibility

Every new field has a backwards-compatible default and every new method
arg is optional. Existing demos and downstream consumers run unchanged.

## Downstream impact

This PR is the prerequisite for several PRs in flight:

* #421 -- publish_self_as_source tree fan-out fix
* #429 -- Megatron-Core MX clients
* #450 -- MX_RDMA_NIC_PIN=stripe multi-NIC mode
* ai-dynamo/dynamo#10900 -- first-time upstream port of
  dynamo.vllm.mx_refit extension
* ai-dynamo/dynamo#10901 -- buffer-caching perf fix in Dynamo extension
* jthomson04/RL#2 (merged) + #7 -- NeMo RL Megatron-MX integration plus
  perf fix

Once this PR lands, kavink/nemo_rl_moe (the integration branch holding

Signed-off-by: Kavin Krishnan <kavink@nvidia.com>
#421 and #429) can be retired by rebasing those PRs onto main.
KavinKrishnan added a commit that referenced this pull request Jun 26, 2026
…timizations

Lands the v2 client surface and the surrounding RL workstream needed
to unblock the downstream Megatron-MX and perf PRs.

## What this contributes

### v2 client surface (this PR's primary deliverable)

* `MxV2TrainingPublisher` + `MxV2RefitReceiver` (modelexpress.nemo_rl_v2)
  -- the fat-client surface for per-rank shard publish and
  receiver-side multi-source assembly.
* `MxWeightTransferEngine` (modelexpress.vllm_weight_transfer) -- an
  adapter implementing vLLM's upstream WeightTransferEngine ABC.
* `TensorDescriptorV2.extra_parameters` (map<string,string>) plus
  `SourceIdentity.revision` (string) -- the two proto extensions that
  carry all per-tensor + per-source RL metadata.

### Phase 3a -- compile_target + compile_metadata on TensorDescriptorV2

New per-tensor fields default to ``hf_raw`` / ``{}``; the wire encoder
omits them when default so existing payloads stay byte-identical.
New constants: ``COMPILE_TARGET_HF_RAW``, ``_VLLM_FUSED``,
``_DEEPGEMM_FP8``, ``_CUTLASS_FP8``, ``_TRTLLM``. New helper
``compile_target_matches(descriptor, *, allowed_targets,
required_metadata=None)`` for receiver-side filtering with whitelist
plus required-metadata-subset semantics.

### Phase 3b -- compile_target_filter on discover_v2_sources

New kwargs ``compile_target_filter`` (whitelist set) and
``required_compile_metadata`` (subset-of-every-tensor's
compile_metadata). Candidates with no v2 registry are rejected when
either filter is set; candidates with mixed compile targets are
rejected if any tensor falls outside the allowed set.
``V2SourceCandidate.compile_targets: frozenset[str]`` exposed for
caller introspection.

### Phase 4 -- Multi-source slice discovery for mixed trainer/inference TP

New types: ``TargetTPLayout``, ``SliceSource``, ``SliceCoveragePlan``.
New methods ``MxV2RefitReceiver.discover_v2_sources_for_slice`` and
``MxV2RefitReceiver.receive_via_plan`` -- planner walks v2 candidates
per tensor, intersects each publisher's local_shard_range against the
receiver's requested slice, emits the minimal candidate set covering
it; surfaces coverage gaps and shard_axis mismatches in plan.missing.
``receive_via_plan`` orchestrates per-candidate scratch RDMA pulls
and stitches via torch.cat along the shard axis.

### Proto extensions

* ``TensorDescriptorV2.extra_parameters: map<string, string>`` -- the
  escape hatch downstream RL clients use for per-tensor metadata
  (megatron_role, compile_target, expert_id, revision, training_step,
  ...). Heavily used by #429.
* ``SourceIdentity.revision: string`` -- content-addressed weight
  version. Non-empty value guarantees two sources with identical
  SourceIdentity have bit-identical weight bytes, enabling
  decentralized modes (no central coordinator) to use mx_source_id
  as a full content check.
* Redis backend ``SourceAttributesJson`` carries the new fields plus
  the union of main's compatibility additions
  (backend_framework_version, torch_version, cuda_version,
  triton_version, gpu_arch, compile_config_digest).

## Tests

68/68 unit tests green on this branch's surface (test_v2_source_picker,
test_v2_shape_registry, test_source_id, test_types, test_vllm_adapter).

## Compatibility

Every new field has a backwards-compatible default and every new method
arg is optional. Existing demos and downstream consumers run unchanged.

## Downstream impact

This PR is the prerequisite for several PRs in flight:

* #421 -- publish_self_as_source tree fan-out fix
* #429 -- Megatron-Core MX clients
* #450 -- MX_RDMA_NIC_PIN=stripe multi-NIC mode
* ai-dynamo/dynamo#10900 -- first-time upstream port of
  dynamo.vllm.mx_refit extension
* ai-dynamo/dynamo#10901 -- buffer-caching perf fix in Dynamo extension
* jthomson04/RL#2 (merged) + #7 -- NeMo RL Megatron-MX integration plus
  perf fix

Once this PR lands, kavink/nemo_rl_moe (the integration branch holding

Signed-off-by: Kavin Krishnan <kavink@nvidia.com>
#421 and #429) can be retired by rebasing those PRs onto main.
KavinKrishnan added a commit that referenced this pull request Jun 26, 2026
…timizations

Lands the v2 client surface and the surrounding RL workstream needed
to unblock the downstream Megatron-MX and perf PRs.

## What this contributes

### v2 client surface (this PR's primary deliverable)

* `MxV2TrainingPublisher` + `MxV2RefitReceiver` (modelexpress.nemo_rl_v2)
  -- the fat-client surface for per-rank shard publish and
  receiver-side multi-source assembly.
* `MxWeightTransferEngine` (modelexpress.vllm_weight_transfer) -- an
  adapter implementing vLLM's upstream WeightTransferEngine ABC.
* `TensorDescriptorV2.extra_parameters` (map<string,string>) plus
  `SourceIdentity.revision` (string) -- the two proto extensions that
  carry all per-tensor + per-source RL metadata.

### Phase 3a -- compile_target + compile_metadata on TensorDescriptorV2

New per-tensor fields default to ``hf_raw`` / ``{}``; the wire encoder
omits them when default so existing payloads stay byte-identical.
New constants: ``COMPILE_TARGET_HF_RAW``, ``_VLLM_FUSED``,
``_DEEPGEMM_FP8``, ``_CUTLASS_FP8``, ``_TRTLLM``. New helper
``compile_target_matches(descriptor, *, allowed_targets,
required_metadata=None)`` for receiver-side filtering with whitelist
plus required-metadata-subset semantics.

### Phase 3b -- compile_target_filter on discover_v2_sources

New kwargs ``compile_target_filter`` (whitelist set) and
``required_compile_metadata`` (subset-of-every-tensor's
compile_metadata). Candidates with no v2 registry are rejected when
either filter is set; candidates with mixed compile targets are
rejected if any tensor falls outside the allowed set.
``V2SourceCandidate.compile_targets: frozenset[str]`` exposed for
caller introspection.

### Phase 4 -- Multi-source slice discovery for mixed trainer/inference TP

New types: ``TargetTPLayout``, ``SliceSource``, ``SliceCoveragePlan``.
New methods ``MxV2RefitReceiver.discover_v2_sources_for_slice`` and
``MxV2RefitReceiver.receive_via_plan`` -- planner walks v2 candidates
per tensor, intersects each publisher's local_shard_range against the
receiver's requested slice, emits the minimal candidate set covering
it; surfaces coverage gaps and shard_axis mismatches in plan.missing.
``receive_via_plan`` orchestrates per-candidate scratch RDMA pulls
and stitches via torch.cat along the shard axis.

### Proto extensions

* ``TensorDescriptorV2.extra_parameters: map<string, string>`` -- the
  escape hatch downstream RL clients use for per-tensor metadata
  (megatron_role, compile_target, expert_id, revision, training_step,
  ...). Heavily used by #429.
* ``SourceIdentity.revision: string`` -- content-addressed weight
  version. Non-empty value guarantees two sources with identical
  SourceIdentity have bit-identical weight bytes, enabling
  decentralized modes (no central coordinator) to use mx_source_id
  as a full content check.
* Redis backend ``SourceAttributesJson`` carries the new fields plus
  the union of main's compatibility additions
  (backend_framework_version, torch_version, cuda_version,
  triton_version, gpu_arch, compile_config_digest).

## Tests

68/68 unit tests green on this branch's surface (test_v2_source_picker,
test_v2_shape_registry, test_source_id, test_types, test_vllm_adapter).

## Compatibility

Every new field has a backwards-compatible default and every new method
arg is optional. Existing demos and downstream consumers run unchanged.

## Downstream impact

This PR is the prerequisite for several PRs in flight:

* #421 -- publish_self_as_source tree fan-out fix
* #429 -- Megatron-Core MX clients
* #450 -- MX_RDMA_NIC_PIN=stripe multi-NIC mode
* ai-dynamo/dynamo#10900 -- first-time upstream port of
  dynamo.vllm.mx_refit extension
* ai-dynamo/dynamo#10901 -- buffer-caching perf fix in Dynamo extension
* jthomson04/RL#2 (merged) + #7 -- NeMo RL Megatron-MX integration plus
  perf fix

Once this PR lands, kavink/nemo_rl_moe (the integration branch holding

Signed-off-by: Kavin Krishnan <kavink@nvidia.com>
#421 and #429) can be retired by rebasing those PRs onto main.
KavinKrishnan added a commit that referenced this pull request Jun 26, 2026
…timizations

Lands the v2 client surface and the surrounding RL workstream needed
to unblock the downstream Megatron-MX and perf PRs.

## What this contributes

### v2 client surface (this PR's primary deliverable)

* `MxV2TrainingPublisher` + `MxV2RefitReceiver` (modelexpress.nemo_rl_v2)
  -- the fat-client surface for per-rank shard publish and
  receiver-side multi-source assembly.
* `MxWeightTransferEngine` (modelexpress.vllm_weight_transfer) -- an
  adapter implementing vLLM's upstream WeightTransferEngine ABC.
* `TensorDescriptorV2.extra_parameters` (map<string,string>) plus
  `SourceIdentity.revision` (string) -- the two proto extensions that
  carry all per-tensor + per-source RL metadata.

### Phase 3a -- compile_target + compile_metadata on TensorDescriptorV2

New per-tensor fields default to ``hf_raw`` / ``{}``; the wire encoder
omits them when default so existing payloads stay byte-identical.
New constants: ``COMPILE_TARGET_HF_RAW``, ``_VLLM_FUSED``,
``_DEEPGEMM_FP8``, ``_CUTLASS_FP8``, ``_TRTLLM``. New helper
``compile_target_matches(descriptor, *, allowed_targets,
required_metadata=None)`` for receiver-side filtering with whitelist
plus required-metadata-subset semantics.

### Phase 3b -- compile_target_filter on discover_v2_sources

New kwargs ``compile_target_filter`` (whitelist set) and
``required_compile_metadata`` (subset-of-every-tensor's
compile_metadata). Candidates with no v2 registry are rejected when
either filter is set; candidates with mixed compile targets are
rejected if any tensor falls outside the allowed set.
``V2SourceCandidate.compile_targets: frozenset[str]`` exposed for
caller introspection.

### Phase 4 -- Multi-source slice discovery for mixed trainer/inference TP

New types: ``TargetTPLayout``, ``SliceSource``, ``SliceCoveragePlan``.
New methods ``MxV2RefitReceiver.discover_v2_sources_for_slice`` and
``MxV2RefitReceiver.receive_via_plan`` -- planner walks v2 candidates
per tensor, intersects each publisher's local_shard_range against the
receiver's requested slice, emits the minimal candidate set covering
it; surfaces coverage gaps and shard_axis mismatches in plan.missing.
``receive_via_plan`` orchestrates per-candidate scratch RDMA pulls
and stitches via torch.cat along the shard axis.

### Proto extensions

* ``TensorDescriptorV2.extra_parameters: map<string, string>`` -- the
  escape hatch downstream RL clients use for per-tensor metadata
  (megatron_role, compile_target, expert_id, revision, training_step,
  ...). Heavily used by #429.
* ``SourceIdentity.revision: string`` -- content-addressed weight
  version. Non-empty value guarantees two sources with identical
  SourceIdentity have bit-identical weight bytes, enabling
  decentralized modes (no central coordinator) to use mx_source_id
  as a full content check.
* Redis backend ``SourceAttributesJson`` carries the new fields plus
  the union of main's compatibility additions
  (backend_framework_version, torch_version, cuda_version,
  triton_version, gpu_arch, compile_config_digest).

## Tests

68/68 unit tests green on this branch's surface (test_v2_source_picker,
test_v2_shape_registry, test_source_id, test_types, test_vllm_adapter).

## Compatibility

Every new field has a backwards-compatible default and every new method
arg is optional. Existing demos and downstream consumers run unchanged.

## Downstream impact

This PR is the prerequisite for several PRs in flight:

* #421 -- publish_self_as_source tree fan-out fix
* #429 -- Megatron-Core MX clients
* #450 -- MX_RDMA_NIC_PIN=stripe multi-NIC mode
* ai-dynamo/dynamo#10900 -- first-time upstream port of
  dynamo.vllm.mx_refit extension
* ai-dynamo/dynamo#10901 -- buffer-caching perf fix in Dynamo extension
* jthomson04/RL#2 (merged) + #7 -- NeMo RL Megatron-MX integration plus
  perf fix

Once this PR lands, kavink/nemo_rl_moe (the integration branch holding

Signed-off-by: Kavin Krishnan <kavink@nvidia.com>
#421 and #429) can be retired by rebasing those PRs onto main.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

perf: multi-NIC parallel RDMA pull — MX is single-NIC-pinned by default, leaves 75-87% of allocated bandwidth idle

1 participant