feat(ucx_utils): add MX_RDMA_NIC_PIN=stripe for multi-NIC parallelism#450
KavinKrishnan wants to merge 1 commit into
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@modelexpress_client/python/modelexpress/ucx_utils.py`:
- Around line 409-420: The code parses MX_RDMA_NIC_PIN_MIN_RATE_GBPS as a float
but does not validate that the resulting value is finite and non-negative;
float() accepts nan, inf, and negative numbers which would break stripe
filtering semantics. After successfully converting raw_min to a float in the try
block, add validation to check if min_rate is finite (using math.isfinite()) and
greater than or equal to zero. If either check fails, log a warning message
indicating the invalid value and set min_rate to None to fall back to
auto-detect, similar to the existing ValueError handling.
📒 Files selected for processing (2)
modelexpress_client/python/modelexpress/ucx_utils.py
modelexpress_client/python/tests/test_nic_pin_stripe.py
    raw_min = os.environ.get("MX_RDMA_NIC_PIN_MIN_RATE_GBPS")
    if raw_min is None or raw_min.strip() == "":
        min_rate: float | None = None
    else:
        try:
            min_rate = float(raw_min)
        except ValueError:
            logger.warning(
                f"MX_RDMA_NIC_PIN_MIN_RATE_GBPS={raw_min!r} not a float; "
                f"falling back to max-rate auto-detect"
            )
            min_rate = None
🎯 Functional Correctness | 🟡 Minor | ⚡ Quick win
Validate MX_RDMA_NIC_PIN_MIN_RATE_GBPS as finite and non-negative.
float() accepts values like nan, inf, and negatives, which can silently break stripe filtering semantics. Reject non-finite or < 0 values and fall back to auto-detect with a warning.
Suggested fix

     raw_min = os.environ.get("MX_RDMA_NIC_PIN_MIN_RATE_GBPS")
     if raw_min is None or raw_min.strip() == "":
         min_rate: float | None = None
     else:
         try:
             min_rate = float(raw_min)
    +        if not math.isfinite(min_rate) or min_rate < 0:
    +            raise ValueError("min rate must be finite and >= 0")
         except ValueError:
             logger.warning(
                 f"MX_RDMA_NIC_PIN_MIN_RATE_GBPS={raw_min!r} not a float; "
                 f"falling back to max-rate auto-detect"
             )
             min_rate = None

     import logging
    +import math
     import os
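The validation the review asks for can be sketched as a standalone helper. This is a minimal illustration, not the PR's code: the function name `parse_min_rate_gbps` is hypothetical, and it uses a separate warning for out-of-range values rather than reusing the "not a float" message.

```python
import logging
import math

logger = logging.getLogger(__name__)


def parse_min_rate_gbps(raw):
    """Return a finite, non-negative float, or None to mean auto-detect.

    Hypothetical helper; `raw` is the env-var string or None when unset.
    """
    if raw is None or raw.strip() == "":
        return None
    try:
        value = float(raw)
    except ValueError:
        logger.warning("min rate %r is not a float; using auto-detect", raw)
        return None
    # float() happily parses "nan", "inf" and "-5", so check the range too.
    if not math.isfinite(value) or value < 0:
        logger.warning("min rate %r must be finite and >= 0; using auto-detect", raw)
        return None
    return value
```

Returning `None` for every rejected input keeps a single fallback path, so callers only need to distinguish "explicit threshold" from "auto-detect".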
Cluster validation on 4-NIC GB200: confirmed
Validated the stripe mode end-to-end against the existing multi-cycle Qwen3-4B-Thinking refit benchmark, with the trainer (publisher) and receiver both opted into
Trainer UCX log on publisher start confirms 4-NIC config picked up:
Bandwidth caveat: only ~2× scaling, not 4×
Going from 1 NIC to 4 NICs got us 2× bandwidth (316 → 622 Gbps), not 4×. Bumping
#449's Layer 3 (expose
Extrapolation to larger multi-receiver setups
For a 16 GB model with ~16 concurrent receivers at the same ~50% bandwidth efficiency we saw here, paired with the buffer-caching fix in ai-dynamo/dynamo#10901:
Stripe + caching together substantially reduce wall time on multi-receiver dispatches. Layer 3 (NIXL
Test plan checkbox update
Issue #449. Single-NIC pinning leaves 75-87.5% of allocated NIC bandwidth idle on multi-NIC pods (4-NIC GB200, 8-NIC GB300, 32-NIC AWS p5). Per-receiver RDMA pull is capped at single-NIC bandwidth as a result, which is the dominant factor in refit-cycle wall time on many-receiver setups.

This PR adds `MX_RDMA_NIC_PIN=stripe` (alias `all`):
- Lists ALL compute-rate IB NICs visible to the pod in UCX_NET_DEVICES (comma-joined, :1 port suffix on each).
- Bumps UCX_MAX_RMA_RAILS to the NIC count (only if not already set) so UCX actually uses all of them. Without this, both UCX's default and NIXL's hard-set MAX_RMA_RAILS=2 cap us at 2 NICs.
- Degenerates gracefully to single-NIC behaviour on 1-NIC pods.
- Forwards MX_RDMA_NIC_PIN_MIN_RATE_GBPS to filter side-fabric NICs.

Existing modes (auto, explicit list, off) are unchanged. Stripe is opt-in via the env var.

Cluster validation on a 4-NIC GB200 pod, single-receiver Qwen3-4B refit benchmark, both publisher + receiver opted into stripe:

Today (1 NIC, MX_RDMA_NIC_PIN=auto): 0.215 s warm cycle, 316 Gbps
Stripe (4 NICs, both sides): 0.114 s warm cycle, 622 Gbps (-47% wall, +97% bandwidth)

The 2x scaling (vs 4x theoretical) is investigated in the PR description; primary suspect is NIXL's hardcoded MAX_RMA_RAILS=2 in src/plugins/ucx/ucx_utils.cpp:422 winning over the env-set value. Tracked as Layer 3 in #449.

Doesn't address structural collective-broadcast wins (tree topology, collective init amortization). For full-tensor broadcast in tightly-coupled topologies those need separate work and may not be worth chasing in MX's design space.

Tests: 9 new unit tests in tests/test_nic_pin_stripe.py covering the 4-NIC stripe path, MAX_RMA_RAILS bump logic, single-NIC degenerate case, env-var precedence, alias handling, and rate filtering. Existing nic-pin / ucx tests unaffected.
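The environment changes described in this commit message can be sketched roughly as follows. This is an illustration under assumptions, not the PR's implementation: `stripe_env` is a hypothetical name, NIC discovery and rate filtering are assumed to have happened already, and leaving the environment untouched when no NICs are visible is a guess at the intended behaviour.

```python
def stripe_env(nic_names, env):
    """Return a copy of `env` with UCX striping variables applied.

    nic_names: compute-rate IB device names already filtered by the caller,
    e.g. ["mlx5_0", "mlx5_1"]. Hypothetical helper for illustration only.
    """
    updated = dict(env)
    if not nic_names:
        # Assumption: with nothing visible, leave UCX settings untouched.
        return updated
    # Comma-joined device list with the :1 port suffix on each entry.
    updated["UCX_NET_DEVICES"] = ",".join(f"{name}:1" for name in nic_names)
    # Raise the rail cap only when striping across more than one NIC,
    # and never override a value that is already set.
    if len(nic_names) > 1:
        updated.setdefault("UCX_MAX_RMA_RAILS", str(len(nic_names)))
    return updated
```

For example, `stripe_env(["mlx5_0", "mlx5_1", "mlx5_2", "mlx5_3"], {})` yields the four-device list and a rail cap of `"4"`, while a pre-existing `UCX_MAX_RMA_RAILS` is kept as-is.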
56e597c to 2c6fd8a
…timizations
Lands the v2 client surface and the surrounding RL workstream needed
to unblock the downstream Megatron-MX and perf PRs.
## What this contributes
### v2 client surface (this PR's primary deliverable)
* `MxV2TrainingPublisher` + `MxV2RefitReceiver` (modelexpress.nemo_rl_v2)
-- the fat-client surface for per-rank shard publish and
receiver-side multi-source assembly.
* `MxWeightTransferEngine` (modelexpress.vllm_weight_transfer) -- an
adapter implementing vLLM's upstream WeightTransferEngine ABC.
* `TensorDescriptorV2.extra_parameters` (map<string,string>) plus
`SourceIdentity.revision` (string) -- the two proto extensions that
carry all per-tensor + per-source RL metadata.
### Phase 3a -- compile_target + compile_metadata on TensorDescriptorV2
New per-tensor fields default to ``hf_raw`` / ``{}``; the wire encoder
omits them when default so existing payloads stay byte-identical.
New constants: ``COMPILE_TARGET_HF_RAW``, ``_VLLM_FUSED``,
``_DEEPGEMM_FP8``, ``_CUTLASS_FP8``, ``_TRTLLM``. New helper
``compile_target_matches(descriptor, *, allowed_targets,
required_metadata=None)`` for receiver-side filtering with whitelist
plus required-metadata-subset semantics.
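
The whitelist plus required-metadata-subset semantics could look roughly like the sketch below. Only the helper's signature and the ``hf_raw`` default come from the text above; the `Descriptor` dataclass is a hypothetical stand-in for `TensorDescriptorV2`, and the `"vllm_fused"` string used in the usage note is an assumed value for illustration.

```python
from dataclasses import dataclass, field

COMPILE_TARGET_HF_RAW = "hf_raw"


@dataclass
class Descriptor:
    # Hypothetical stand-in for TensorDescriptorV2; defaults follow the text.
    compile_target: str = COMPILE_TARGET_HF_RAW
    compile_metadata: dict = field(default_factory=dict)


def compile_target_matches(descriptor, *, allowed_targets, required_metadata=None):
    """True when the target is whitelisted and all required metadata matches."""
    if descriptor.compile_target not in allowed_targets:
        return False
    metadata = descriptor.compile_metadata
    # Subset check: every required key must be present with the same value.
    return all(
        key in metadata and metadata[key] == value
        for key, value in (required_metadata or {}).items()
    )
```

With `d = Descriptor("vllm_fused", {"dtype": "fp8"})`, the call `compile_target_matches(d, allowed_targets={"vllm_fused"}, required_metadata={"dtype": "fp8"})` passes, while requiring `{"dtype": "bf16"}` or allowing only `{"hf_raw"}` rejects it.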
### Phase 3b -- compile_target_filter on discover_v2_sources
New kwargs ``compile_target_filter`` (whitelist set) and
``required_compile_metadata`` (subset-of-every-tensor's
compile_metadata). Candidates with no v2 registry are rejected when
either filter is set; candidates with mixed compile targets are
rejected if any tensor falls outside the allowed set.
``V2SourceCandidate.compile_targets: frozenset[str]`` exposed for
caller introspection.
### Phase 4 -- Multi-source slice discovery for mixed trainer/inference TP
New types: ``TargetTPLayout``, ``SliceSource``, ``SliceCoveragePlan``.
New methods ``MxV2RefitReceiver.discover_v2_sources_for_slice`` and
``MxV2RefitReceiver.receive_via_plan`` -- planner walks v2 candidates
per tensor, intersects each publisher's local_shard_range against the
receiver's requested slice, emits the minimal candidate set covering
it; surfaces coverage gaps and shard_axis mismatches in plan.missing.
``receive_via_plan`` orchestrates per-candidate scratch RDMA pulls
and stitches via torch.cat along the shard axis.
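
The planning step can be pictured as an interval-cover problem along one shard axis. The sketch below is a simplified pure-Python illustration, not the planner in this PR: `plan_slice_coverage` and its return shape are hypothetical, ranges are assumed to be half-open `(start, stop)` pairs, and shard_axis mismatches are not modelled.

```python
def plan_slice_coverage(requested, publishers):
    """Greedy cover of the half-open range `requested` = (start, stop).

    publishers: dict mapping a source id to its (start, stop) shard range.
    Returns (pieces, missing): pieces are (source_id, start, stop) in axis
    order; missing lists the uncovered (start, stop) gaps.
    """
    start, stop = requested
    # Clip every publisher range to the requested slice; drop empty overlaps.
    clipped = []
    for source_id, (lo, hi) in publishers.items():
        lo, hi = max(lo, start), min(hi, stop)
        if lo < hi:
            clipped.append((lo, hi, source_id))
    clipped.sort()

    pieces, missing = [], []
    cursor, i = start, 0
    while cursor < stop:
        # Among ranges starting at or before the cursor, keep the furthest reach.
        best = None
        while i < len(clipped) and clipped[i][0] <= cursor:
            if best is None or clipped[i][1] > best[1]:
                best = clipped[i]
            i += 1
        if best is None or best[1] <= cursor:
            # Nothing extends past the cursor: record a gap up to the next start.
            next_start = clipped[i][0] if i < len(clipped) else stop
            missing.append((cursor, next_start))
            cursor = next_start
            continue
        pieces.append((best[2], cursor, best[1]))
        cursor = best[1]
    return pieces, missing
```

Because the pieces come back in axis order with no overlap, each one can be pulled into its own scratch buffer and the buffers concatenated along the shard axis, which is the role the text gives to `torch.cat`.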
### Proto extensions
* ``TensorDescriptorV2.extra_parameters: map<string, string>`` -- the
escape hatch downstream RL clients use for per-tensor metadata
(megatron_role, compile_target, expert_id, revision, training_step,
...). Heavily used by #429.
* ``SourceIdentity.revision: string`` -- content-addressed weight
version. Non-empty value guarantees two sources with identical
SourceIdentity have bit-identical weight bytes, enabling
decentralized modes (no central coordinator) to use mx_source_id
as a full content check.
* Redis backend ``SourceAttributesJson`` carries the new fields plus
the union of main's compatibility additions
(backend_framework_version, torch_version, cuda_version,
triton_version, gpu_arch, compile_config_digest).
## Tests
68/68 unit tests green on this branch's surface (test_v2_source_picker,
test_v2_shape_registry, test_source_id, test_types, test_vllm_adapter).
## Compatibility
Every new field has a backwards-compatible default and every new method
arg is optional. Existing demos and downstream consumers run unchanged.
## Downstream impact
This PR is the prerequisite for several PRs in flight:
* #421 -- publish_self_as_source tree fan-out fix
* #429 -- Megatron-Core MX clients
* #450 -- MX_RDMA_NIC_PIN=stripe multi-NIC mode
* ai-dynamo/dynamo#10900 -- first-time upstream port of
dynamo.vllm.mx_refit extension
* ai-dynamo/dynamo#10901 -- buffer-caching perf fix in Dynamo extension
* jthomson04/RL#2 (merged) + #7 -- NeMo RL Megatron-MX integration plus
perf fix
Once this PR lands, kavink/nemo_rl_moe (the integration branch holding
#421 and #429) can be retired by rebasing those PRs onto main.

Signed-off-by: Kavin Krishnan <kavink@nvidia.com>
Summary
Adds `MX_RDMA_NIC_PIN=stripe` mode to `modelexpress.ucx_utils.apply_nic_pin_for_device`. Lists all visible compute-rate IB NICs in `UCX_NET_DEVICES` (comma-joined) and bumps `UCX_MAX_RMA_RAILS` to the NIC count, so UCX actually uses all of them for parallel RDMA pulls.

Closes #449.
Motivation
MX/NIXL's per-receiver RDMA pull is single-NIC-pinned by default. On multi-NIC pods (4-NIC GB200, 8-NIC GB300, 32-NIC AWS p5) this leaves 75-87.5% of allocated NIC bandwidth idle. Single-receiver bulk-pull on a 4-NIC GB200 pod measures at ~316 Gbps — saturating one NIC, but well below the ~1.2 Tbps the four NICs are capable of.
Striping the pull across all allocated NICs lifts per-receiver bandwidth proportionally to the number of NICs UCX can fan out across, which is the dominant lever for refit-cycle wall time on many-receiver setups.
Change
One file changed in production code (`modelexpress/ucx_utils.py`): adds `_stripe_all_compute_nics`, extends `_resolve_nic_pin` to handle `stripe`/`all` aliases, and extends `apply_nic_pin_for_device` to bump `UCX_MAX_RMA_RAILS` to the NIC count when striping.

Existing modes (`auto`, explicit list, `off`/unset) are unchanged. Stripe is opt-in.

Note on `MAX_RMA_RAILS`: UCX defaults this to 2, and NIXL hardcodes it to 2 in `nixl/src/plugins/ucx/ucx_utils.cpp:422`. The NIXL hardcode overrides our env-set value unless we also patch NIXL; this is tracked in #449 as Layer 3. For now this PR is the deployment-side fix; users opting into `stripe` mode should also ensure they're on a NIXL build that respects the env-var override (or apply the future NIXL patch alongside).

Tests
9 new unit tests in `tests/test_nic_pin_stripe.py`:
- `test_stripe_lists_all_compute_nics`: `mlx5_0:1,mlx5_1:1,mlx5_2:1,mlx5_3:1`
- `test_stripe_bumps_max_rma_rails_to_nic_count`: `UCX_MAX_RMA_RAILS=4` set on 4-NIC pod
- `test_stripe_respects_explicit_max_rma_rails`
- `test_stripe_single_nic_skips_rails_bump`
- `test_stripe_no_nics_visible`
- `test_stripe_alias_all`: `MX_RDMA_NIC_PIN=all` works as alias
- `test_auto_mode_does_not_set_rails`: `auto` mode unchanged
- `test_min_rate_filter_passed_to_stripe`
- `test_off_mode_unchanged`: `off` mode

All 9 pass.
Cluster validation
Validated on a 4-NIC GB200 pod with the existing multi-cycle Qwen3-4B-Thinking refit benchmark, with both the trainer (publisher) and receiver opted into `MX_RDMA_NIC_PIN=stripe`. Real Megatron-Bridge-loaded model, 290 source tensors, 8.04 GB total, 3 back-to-back refit cycles per run.

Modes compared: `auto` (current default, 1 NIC) and `stripe` (this PR, 4 NICs).

Cycle-1 wall regresses ~30% because there are 4× more NICs to register against; cycles 2+ recover that and net out ahead because the cached registration carries forward (paired with #421 / #429 / ai-dynamo/dynamo#10901 buffer caching).
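
As a quick check on the scaling figures, the warm-cycle numbers quoted in the commit message (316 Gbps and 0.215 s on one NIC, 622 Gbps and 0.114 s on four) can be turned into a speedup and a per-NIC efficiency. The helper name `scaling_summary` is hypothetical; only the arithmetic is the point.

```python
def scaling_summary(baseline_gbps, stripe_gbps, baseline_s, stripe_s, nics):
    """Summarise multi-NIC scaling from two warm-cycle measurements."""
    speedup = stripe_gbps / baseline_gbps
    return {
        "bandwidth_speedup": speedup,
        # Fraction of the ideal N-fold scaling actually achieved.
        "efficiency": speedup / nics,
        # Negative means the striped cycle is faster.
        "wall_change": (stripe_s - baseline_s) / baseline_s,
    }


summary = scaling_summary(316.0, 622.0, 0.215, 0.114, nics=4)
print({key: round(value, 3) for key, value in summary.items()})
```

The speedup comes out just under 2× and the efficiency just under 50% of the 4-NIC ideal, consistent with the -47% wall and +97% bandwidth figures.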
Why only ~2× scaling on a 4-NIC pod
Bumping `UCX_MAX_RMA_RAILS` from 4 to 8 didn't help: bandwidth plateaus at ~625 Gbps. Likely causes:
1. NIXL sets `MAX_RMA_RAILS=2` in `nixl/src/plugins/ucx/ucx_utils.cpp:422` via `config.modify(...)`. Per UCX semantics that variant is "modify if not already set via env var", so our env should win, but the empirical bandwidth ceiling is suspiciously close to 2-rail behavior. Worth instrumenting with `UCX_LOG_LEVEL=debug` to confirm which rail count UCX actually picked.

#449's Layer 3 (expose `max_rma_rails` as a NIXL backend param) targets cause 1.

Test plan
- `examples_dgd/k8s_exemplars/V*/` to use stripe mode (separate small PR after this lands)
- `max_rma_rails` as a backend param

Related