Adapt SolLatencyEstimator for Triton intra-node AllReduce backend #888
Open
mfrancepillois wants to merge 1 commit into
Review Summary
Reviewed the three-layer change introducing Triton intra-node AllReduce cost modelling: the
Key findings (7 inline comments):
Overall the design is solid: the annotation-before-scheduling approach is clean and the cost formulas match the documented kernel behavior. The double-counting issue is the most impactful finding to address before merge.
🤖 Generated with Claude Code
Triton collective kernels (one-shot / two-shot) use GPU-side P2P
memory access rather than the NCCL/RCCL ring algorithm, so the existing
NIC-bandwidth-based cost model is wrong for those collectives.
== Motivation: incorrect scheduling with the old model ==
Without this change, enabling the Triton AllReduce kernel
(xla_gpu_unsupported_use_all_reduce_one_shot_kernel) causes two
independent categories of incorrect scheduling decisions:
1. Async Triton AllReduce: wrong GetLatencyBetween budget
SolLatencyEstimator::GetLatencyBetween(AllReduceStart, AllReduceDone)
returns the value of ComputeCollectiveTime(), which routes intra-node
(SINGLE_PARTITION) collectives to CollectiveInterpolator.
CollectiveInterpolator uses NCCL-calibrated empirical bandwidth tables.
The Latency Hiding Scheduler uses this budget to decide how much
compute to interleave between AllReduceStart and AllReduceDone.
An over-estimated budget causes the scheduler to try to fill a large
overlapping window that does not exist at runtime, potentially
reordering compute operations in ways that hurt rather than help.
An under-estimated budget would prevent profitable overlap
from being exploited.
2. Sync Triton AllReduce: NodeCost returns kLowCost instead of the
true blocking latency
The sync Triton AllReduce kernel is dispatched on the compute stream
and fully blocks all further compute until it finishes.
NodeCost() currently returns kLowCost=1 µs for any
IsAsyncCollectiveStartOp() instruction regardless of is_sync. This
causes the scheduler to treat a ~5 µs blocking op as if it cost only 1 µs
on the critical path. For a Transformer training step with many small
AllReduces (e.g. bias/norm gradient synchronisations), this error
compounds: the scheduler builds incorrect priority orders and may
produce a schedule that is far from optimal.
== Summary of changes ==
This change introduces a three-layer fix:
1. CollectiveKernelStrategy proto annotation (backend_configs.proto)
- Add CollectiveKernelStrategy enum to CollectiveBackendConfig with
KERNEL_STRATEGY_DEFAULT (NCCL), KERNEL_STRATEGY_TRITON_ONE_SHOT,
KERNEL_STRATEGY_TRITON_TWO_SHOT.
2. New HLO pass: CollectiveKernelStrategyAnnotator
- Runs after CollectiveBackendAssigner, before scheduling.
- Calls BuildAllReduceInfo() (same eligibility check as thunk_emitter)
to determine whether Triton one-shot / two-shot will be used at
runtime, and annotates the HLO instruction's backend_config.
- Only registered when xla_gpu_unsupported_use_all_reduce_one_shot_kernel
is enabled, keeping the NCCL path completely unchanged otherwise.
3. SolGPUCostModel / SolLatencyEstimator updates
- SolGPUCostModel::Config: add nvlink_bw_per_lane_gbps (from
CudaBandwidthSettings / RocmBandwidthSettings via the new
GpuPerformanceWithCollectiveModel::GetNvlinkBandwidthPerLaneGbps
helper) and nvlink_barrier_latency.
- SolGPUCostModel::TritonAllReduceLatency: NVLink-based formula
for one-shot (<= 256 KB) and two-shot (> 256 KB, <= 4 MB) strategies.
Strategy is derived internally from size_bytes via
GetAllReduceStrategy(), matching the runtime threshold.
- SolLatencyEstimator::DispatchEstimation SINGLE_PARTITION path:
when the kernel_strategy annotation is TRITON_ONE/TWO_SHOT, use
TritonAllReduceLatency (NVLink formula + HBM time) instead of
CollectiveInterpolator (NCCL empirical table).
- SolLatencyEstimator::NodeCost: sync Triton AllReduce runs on the
compute stream and fully blocks compute. When is_sync=true and
kernel_strategy is TRITON_ONE/TWO_SHOT, return ComputeCollectiveTime
rather than kLowCost=1.0, so the latency lands on the critical path.
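The size-based strategy selection and the SINGLE_PARTITION dispatch described above can be sketched as a small standalone model. This is an illustration only, not the XLA implementation: the names KernelStrategy, SelectStrategy, Estimator and PickEstimator are hypothetical, KB and MB are assumed to be 1024-based, and the real eligibility check in BuildAllReduceInfo() may depend on more than the payload size.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical C++ mirror of the CollectiveKernelStrategy proto enum.
enum class KernelStrategy { kDefault, kTritonOneShot, kTritonTwoShot };

// Thresholds quoted in the description (assumed to be 1024-based units).
constexpr int64_t kOneShotMaxBytes = 256 * 1024;       // 256 KB
constexpr int64_t kTwoShotMaxBytes = 4 * 1024 * 1024;  // 4 MB

// Picks a strategy from the payload size alone. Assumption: size is the
// only input; the real check may also look at device count, dtype, etc.
inline KernelStrategy SelectStrategy(int64_t size_bytes) {
  assert(size_bytes >= 0);
  if (size_bytes <= kOneShotMaxBytes) return KernelStrategy::kTritonOneShot;
  if (size_bytes <= kTwoShotMaxBytes) return KernelStrategy::kTritonTwoShot;
  return KernelStrategy::kDefault;  // falls back to the NCCL/RCCL path
}

// Which estimator the intra-node (SINGLE_PARTITION) path would consult.
enum class Estimator { kTritonNvlinkFormula, kCollectiveInterpolator };

inline Estimator PickEstimator(KernelStrategy strategy) {
  return strategy == KernelStrategy::kDefault
             ? Estimator::kCollectiveInterpolator
             : Estimator::kTritonNvlinkFormula;
}
```

Under these assumptions a 1 KB AllReduce would be annotated one-shot and costed with the NVLink formula, while an 8 MB AllReduce would keep the default annotation and continue to use the NCCL-calibrated interpolator.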
ROCm support: GetNvlinkBandwidthPerLaneGbps dispatches to
RocmBandwidthSettings::GetNvlinkBw() automatically,
so the fix covers both CUDA and ROCm targets.
Tests:
- collective_kernel_strategy_annotator_test.cc: verifies small (ONE_SHOT),
medium (TWO_SHOT) and large (DEFAULT/NCCL) AllReduces are annotated
correctly.
- sol_gpu_cost_model_test.cc: TritonAllReduceLatency formula validation.
- sol_latency_estimator_test.cc: scheduler integration - async Triton
ONE_SHOT uses NVLink formula and differs from NCCL; sync Triton
ONE_SHOT ComputeCollectiveTime exceeds kLowCost; async NCCL NodeCost
remains kLowCost=1 µs.
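The NodeCost expectations listed for sol_latency_estimator_test.cc can be illustrated with a toy model. This is a hedged sketch rather than the actual estimator: CollectiveStart, IsTriton and NodeCostUs are hypothetical names, collective_time_us stands in for whatever ComputeCollectiveTime() would return, and the sketch assumes that a sync collective on the default (NCCL) strategy keeps returning kLowCost.

```cpp
#include <cassert>

// Hypothetical C++ mirror of the CollectiveKernelStrategy proto enum.
enum class KernelStrategy { kDefault, kTritonOneShot, kTritonTwoShot };

constexpr double kLowCostUs = 1.0;  // kLowCost from the description, in µs

// Minimal stand-in for an async-collective-start instruction.
struct CollectiveStart {
  bool is_sync;               // true when dispatched on the compute stream
  KernelStrategy strategy;    // annotation written by the annotator pass
  double collective_time_us;  // stand-in for ComputeCollectiveTime()
};

inline bool IsTriton(KernelStrategy strategy) {
  return strategy != KernelStrategy::kDefault;
}

// A sync Triton collective blocks compute, so its full latency is charged
// to the node; every other start op keeps the low placeholder cost.
inline double NodeCostUs(const CollectiveStart& op) {
  assert(op.collective_time_us >= 0.0);
  if (op.is_sync && IsTriton(op.strategy)) return op.collective_time_us;
  return kLowCostUs;
}
```

With the ~5 µs figure from the motivation, a sync one-shot op would be charged 5 µs on the critical path, whereas the async variants would still be charged 1 µs and have their latency accounted for through GetLatencyBetween.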