Skip to content

Adapt SolLatencyEstimator for Triton intra-node AllReduce backend#888

Open
mfrancepillois wants to merge 1 commit into
mainfrom
maxime_latency_estimation_triton_collective_ops
Open

Adapt SolLatencyEstimator for Triton intra-node AllReduce backend#888
mfrancepillois wants to merge 1 commit into
mainfrom
maxime_latency_estimation_triton_collective_ops

Conversation

@mfrancepillois
Copy link
Copy Markdown

🎯 Justification

Triton collective kernels (one-shot / two-shot) use GPU-side P2P memory access rather than the NCCL/RCCL ring algorithm, so the existing NIC-bandwidth-based cost model is wrong for those collectives.

Without this change, enabling the Triton AllReduce kernel (xla_gpu_unsupported_use_all_reduce_one_shot_kernel) causes two independent categories of incorrect scheduling decisions:

  1. Async Triton AllReduce: wrong GetLatencyBetween budget SolLatencyEstimator::GetLatencyBetween(AllReduceStart, AllReduceDone) returns the value of ComputeCollectiveTime(), which routes intra-node (SINGLE_PARTITION) collectives to CollectiveInterpolator. CollectiveInterpolator uses NCCL-calibrated empirical bandwidth tables. The Latency Hiding Scheduler uses this budget to decide how much compute to interleave between AllReduceStart and AllReduceDone. An over-estimated budget causes the scheduler to try to fill a large overlapping window that does not exist at runtime, potentially reordering compute operations in ways that hurt rather than help. An under-estimated budget would prevent profitable overlap from being exploited.

  2. Sync Triton AllReduce: NodeCost returns kLowCost instead of the true blocking latency The sync Triton AllReduce kernel is dispatched on the compute stream and fully blocks all further compute until it finishes. NodeCost() currently returns kLowCost=1 µs for any IsAsyncCollectiveStartOp() instruction regardless of is_sync. This causes the scheduler to treat a ~5 µs blocking op as if it were free on the critical path. For a Transformer training step with many small AllReduces (e.g. bias/norm gradient synchronisations), this error compounds: the scheduler builds incorrect priority orders and may produce a schedule that is far from optimal.

📝 Summary of Changes

This change introduces a three-layer fix:

  1. CollectiveKernelStrategy proto annotation (backend_configs.proto)

    • Add CollectiveKernelStrategy enum to CollectiveBackendConfig with KERNEL_STRATEGY_DEFAULT (NCCL), KERNEL_STRATEGY_TRITON_ONE_SHOT, KERNEL_STRATEGY_TRITON_TWO_SHOT.
  2. New HLO pass: CollectiveKernelStrategyAnnotator

    • Runs after CollectiveBackendAssigner, before scheduling.
    • Calls BuildAllReduceInfo() (same eligibility check as thunk_emitter) to determine whether Triton one-shot / two-shot will be used at runtime, and annotates the HLO instruction's backend_config.
    • Only registered when xla_gpu_unsupported_use_all_reduce_one_shot_kernel is enabled, keeping the NCCL path completely unchanged otherwise.
  3. SolGPUCostModel / SolLatencyEstimator updates

    • SolGPUCostModel::Config: add nvlink_bw_per_lane_gbps (from CudaBandwidthSettings / RocmBandwidthSettings via the new GpuPerformanceWithCollectiveModel::GetNvlinkBandwidthPerLaneGbps helper) and nvlink_barrier_latency.
    • SolGPUCostModel::TritonAllReduceLatency: NVLink-based formula for one-shot (<= 256 KB) and two-shot (<= 4 MB) strategies. Strategy is derived internally from size_bytes via GetAllReduceStrategy(), matching the runtime threshold.
    • SolLatencyEstimator::DispatchEstimation SINGLE_PARTITION path: when the kernel_strategy annotation is TRITON_ONE/TWO_SHOT, use TritonAllReduceLatency (NVLink formula + HBM time) instead of CollectiveInterpolator (NCCL empirical table).
    • SolLatencyEstimator::NodeCost: sync Triton AllReduce runs on the compute stream and fully blocks compute. When is_sync=true and kernel_strategy is TRITON_ONE/TWO_SHOT, return ComputeCollectiveTime rather than kLowCost=1.0, so the latency lands on the critical path.

ROCm support: GetNvlinkBandwidthPerLaneGbps dispatches to RocmBandwidthSettings::GetNvlinkBw() automatically, so the fix covers both CUDA and ROCm targets.

🧪 Execution Tests:

  • collective_kernel_strategy_annotator_test.cc: verifies small (ONE_SHOT), medium (TWO_SHOT) and large (DEFAULT/NCCL) AllReduces are annotated correctly.
  • sol_gpu_cost_model_test.cc: TritonAllReduceLatency formula validation.
  • sol_latency_estimator_test.cc: scheduler integration - async Triton ONE_SHOT uses NVLink formula and differs from NCCL; sync Triton ONE_SHOT ComputeCollectiveTime exceeds kLowCost; async NCCL NodeCost remains kLowCost=1 µs.

@mfrancepillois mfrancepillois added the claude-review Request a Claude AI code review for this PR label May 28, 2026
@mfrancepillois mfrancepillois force-pushed the maxime_latency_estimation_triton_collective_ops branch 2 times, most recently from 28e8280 to 2095543 Compare May 28, 2026 16:43
Comment thread xla/service/gpu/model/sol_gpu_cost_model.cc
Comment thread xla/service/gpu/model/sol_latency_estimator.cc
Comment thread xla/service/gpu/model/gpu_collective_performance_model.cc Outdated
Comment thread xla/service/gpu/model/sol_gpu_cost_model.cc
Comment thread xla/service/gpu/model/sol_latency_estimator.cc Outdated
Comment thread xla/service/gpu/model/sol_latency_estimator_test.cc
@claude
Copy link
Copy Markdown

claude Bot commented May 28, 2026

Review Summary

Reviewed the three-layer change introducing Triton intra-node AllReduce cost modelling: the CollectiveKernelStrategy proto enum, the CollectiveKernelStrategyAnnotator HLO pass, and the NVLink-bandwidth-based cost formula in SolGPUCostModel/SolLatencyEstimator.

Key findings (7 inline comments):

  • Potential double-counting of sync Triton AllReduce latency — NodeCost and GetLatencyBetween both call ComputeCollectiveTime for the same instruction.
  • Fragile strategy re-derivationTritonAllReduceLatency re-derives the strategy from size_bytes instead of accepting the already-known enum, bypassing the 4 MB size guard.
  • Null-pointer risk in GetNvlinkBandwidthPerLaneGbps when the device is neither CUDA nor ROCm.
  • Silent under-estimation for unmodelled strategies (kMultimem) in the default case.
  • Test gap — the sync Triton NodeCost dispatch path (lines 504-526 in sol_latency_estimator.cc) lacks direct end-to-end test coverage.
  • Minor nits: implicit float→int64 narrowing, unchecked StatusOr dereference in test.

Overall the design is solid — the annotation-before-scheduling approach is clean and the cost formulas match the documented kernel behavior. The double-counting issue is the most impactful finding to address before merge.

🤖 Generated with Claude Code

@github-actions github-actions Bot removed the claude-review Request a Claude AI code review for this PR label May 28, 2026
Triton collective kernels (one-shot / two-shot) use GPU-side P2P
memory access rather than the NCCL/RCCL ring algorithm, so the existing
NIC-bandwidth-based cost model is wrong for those collectives.

== Motivation: incorrect scheduling with the old model ==

Without this change, enabling the Triton AllReduce kernel
(xla_gpu_unsupported_use_all_reduce_one_shot_kernel) causes two
independent categories of incorrect scheduling decisions:

1. Async Triton AllReduce: wrong GetLatencyBetween budget
   SolLatencyEstimator::GetLatencyBetween(AllReduceStart, AllReduceDone)
   returns the value of ComputeCollectiveTime(), which routes intra-node
   (SINGLE_PARTITION) collectives to CollectiveInterpolator.
   CollectiveInterpolator uses NCCL-calibrated empirical bandwidth tables.
   The Latency Hiding Scheduler uses this budget to decide how much
   compute to interleave between AllReduceStart and AllReduceDone.
   An over-estimated budget causes the scheduler to try to fill a large
   overlapping window that does not exist at runtime, potentially
   reordering compute operations in ways that hurt rather than help.
   An under-estimated budget would prevent profitable overlap
   from being exploited.

2. Sync Triton AllReduce: NodeCost returns kLowCost instead of the
   true blocking latency
   The sync Triton AllReduce kernel is dispatched on the compute stream
   and fully blocks all further compute until it finishes.
   NodeCost() currently returns kLowCost=1 µs for any
   IsAsyncCollectiveStartOp() instruction regardless of is_sync. This
   causes the scheduler to treat a ~5 µs blocking op as if it were free
   on the critical path.  For a Transformer training step with many small
   AllReduces (e.g. bias/norm gradient synchronisations), this error
   compounds: the scheduler builds incorrect priority orders and may
   produce a schedule that is far from optimal.

This change introduces a three-layer fix:

1. CollectiveKernelStrategy proto annotation (backend_configs.proto)
   - Add CollectiveKernelStrategy enum to CollectiveBackendConfig with
     KERNEL_STRATEGY_DEFAULT (NCCL), KERNEL_STRATEGY_TRITON_ONE_SHOT,
     KERNEL_STRATEGY_TRITON_TWO_SHOT.

2. New HLO pass: CollectiveKernelStrategyAnnotator
   - Runs after CollectiveBackendAssigner, before scheduling.
   - Calls BuildAllReduceInfo() (same eligibility check as thunk_emitter)
     to determine whether Triton one-shot / two-shot will be used at
     runtime, and annotates the HLO instruction's backend_config.
   - Only registered when xla_gpu_unsupported_use_all_reduce_one_shot_kernel
     is enabled, keeping the NCCL path completely unchanged otherwise.

3. SolGPUCostModel / SolLatencyEstimator updates
   - SolGPUCostModel::Config: add nvlink_bw_per_lane_gbps (from
     CudaBandwidthSettings / RocmBandwidthSettings via the new
     GpuPerformanceWithCollectiveModel::GetNvlinkBandwidthPerLaneGbps
     helper) and nvlink_barrier_latency.
   - SolGPUCostModel::TritonAllReduceLatency: NVLink-based formula
     for one-shot (<= 256 KB) and two-shot (<= 4 MB) strategies.
     Strategy is derived internally from size_bytes via
     GetAllReduceStrategy(), matching the runtime threshold.
   - SolLatencyEstimator::DispatchEstimation SINGLE_PARTITION path:
     when the kernel_strategy annotation is TRITON_ONE/TWO_SHOT, use
     TritonAllReduceLatency (NVLink formula + HBM time) instead of
     CollectiveInterpolator (NCCL empirical table).
   - SolLatencyEstimator::NodeCost: sync Triton AllReduce runs on the
     compute stream and fully blocks compute. When is_sync=true and
     kernel_strategy is TRITON_ONE/TWO_SHOT, return ComputeCollectiveTime
     rather than kLowCost=1.0, so the latency lands on the critical path.

ROCm support: GetNvlinkBandwidthPerLaneGbps dispatches to
RocmBandwidthSettings::GetNvlinkBw() automatically,
so the fix covers both CUDA and ROCm targets.

Tests:
- collective_kernel_strategy_annotator_test.cc: verifies small (ONE_SHOT),
  medium (TWO_SHOT) and large (DEFAULT/NCCL) AllReduces are annotated
  correctly.
- sol_gpu_cost_model_test.cc: TritonAllReduceLatency formula validation.
- sol_latency_estimator_test.cc: scheduler integration - async Triton
  ONE_SHOT uses NVLink formula and differs from NCCL; sync Triton
  ONE_SHOT ComputeCollectiveTime exceeds kLowCost; async NCCL NodeCost
  remains kLowCost=1 µs.
@mfrancepillois mfrancepillois force-pushed the maxime_latency_estimation_triton_collective_ops branch from 2095543 to 02e513d Compare May 28, 2026 17:04
@i-chaochen i-chaochen requested review from Eetusjo, ScXfjiang and alekstheod and removed request for draganmladjenovic May 28, 2026 22:44
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant