feat: add tensor-parallel entropy loss benchmark #500

@jlamypoirier

Motivation

The Triton entropy loss kernel supports tensor-parallel training natively via its group parameter: each rank holds vocab / tp_size logits, all-reduces two scalars per row (max and sum-exp), then computes loss and grad locally. This avoids materializing the full logits on any rank.

The naive PyTorch alternative (F.cross_entropy on full-vocab logits) must first all-gather the full logit tensor, which costs O(tokens × vocab) communication. At realistic scale this is prohibitive:

  • Llama 3.1 405B: vocab=128K, tokens=8K, TP=8 → 16 GB all-gather per step

The Triton path is not just faster; it is the only feasible approach at large vocab × TP. The current single-GPU benchmark does not demonstrate this. This issue tracks building a multi-GPU benchmark that makes it concrete.
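The two communication volumes can be estimated with a few lines of arithmetic. This is a rough sketch under stated assumptions: bf16 logits (2 bytes per element), fp32 row statistics, vocab rounded to 128 × 1024, and the 16 GB figure read as the gathered logits summed over all 8 ranks. None of these are confirmed by the description above, and the helper names are illustrative.

```python
GIB = 1024 ** 3


def gathered_logits_bytes(tokens, vocab, bytes_per_element=2):
    """Size of the full-vocab logit tensor one rank holds after an all-gather.

    bytes_per_element=2 assumes bf16 logits (an assumption, not from the issue).
    """
    return tokens * vocab * bytes_per_element


def row_stats_bytes(tokens, scalars_per_row=2, bytes_per_element=4):
    """Payload of the sharded path: two fp32 scalars (max, sum-exp) per row."""
    return tokens * scalars_per_row * bytes_per_element


tokens, vocab, tp_size = 8 * 1024, 128 * 1024, 8

per_rank = gathered_logits_bytes(tokens, vocab)
across_group = per_rank * tp_size
stats = row_stats_bytes(tokens)

print(f"all-gather, per rank:     {per_rank / GIB:.0f} GiB")
print(f"all-gather, across group: {across_group / GIB:.0f} GiB")
print(f"row statistics, per rank: {stats / 1024:.0f} KiB")
```

Under these assumptions the gathered logits take 2 GiB per rank (16 GiB across the group), while the two per-row scalars total 64 KiB.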

Variants to benchmark

Three qualitatively different approaches:

  1. triton_tp — existing triton_entropy_loss_forward_backward(..., group=group). Shards vocab across ranks, all-reduces two scalars per row. O(tokens) communication.

  2. pytorch_tp_manual — same algorithm in PyTorch without Triton: local_max = logits.max(-1) → all_reduce(MAX) → local_sum = (logits − max).exp().sum(-1) → all_reduce(SUM) → loss. Tests whether Triton fusion still wins when both paths use the same O(tokens) communication pattern.

  3. pytorch_gather — all-gather logits to full vocab on each rank, then F.cross_entropy. O(tokens × vocab) communication. Included as a reference to show where the naive approach becomes infeasible; expected to OOM at large vocab × TP.
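The pytorch_tp_manual steps can be sketched roughly as follows. This is a minimal sketch, not the existing implementation: the function name tp_cross_entropy_fwd_bwd and the vocab_start argument are hypothetical, and packing the target logit into the SUM all-reduce is an assumption (the description above lists only max and sum-exp, and how the Triton kernel handles the target logit is not assumed here). When no process group is initialized the collectives are skipped, so the math can be checked on a single device.

```python
import torch
import torch.distributed as dist
import torch.nn.functional as F


def tp_cross_entropy_fwd_bwd(local_logits, targets, vocab_start, group=None):
    """Mean cross-entropy and its gradient for one rank's vocab shard.

    local_logits: (tokens, shard) slice of the logits held by this rank.
    targets:      (tokens,) global class indices (int64).
    vocab_start:  global index of this shard's first column (hypothetical arg).
    """
    distributed = dist.is_available() and dist.is_initialized()
    logits = local_logits.float()
    tokens, shard = logits.shape

    # Collective 1: global row-wise max, for numerical stability.
    row_max = logits.max(dim=-1).values
    if distributed:
        dist.all_reduce(row_max, op=dist.ReduceOp.MAX, group=group)

    exp = (logits - row_max[:, None]).exp()

    # Only the rank that owns a target column contributes its logit.
    local_idx = targets - vocab_start
    owned = (local_idx >= 0) & (local_idx < shard)
    safe_idx = local_idx.clamp(0, shard - 1)
    target_logit = logits.gather(1, safe_idx[:, None]).squeeze(1) * owned

    # Collective 2: sum-exp and target logit packed into one SUM all-reduce
    # (the packing is an assumption of this sketch).
    packed = torch.stack([exp.sum(dim=-1), target_logit], dim=-1)
    if distributed:
        dist.all_reduce(packed, op=dist.ReduceOp.SUM, group=group)
    sum_exp, target_logit = packed.unbind(-1)

    loss = (sum_exp.log() + row_max - target_logit).mean()

    # Local gradient: softmax minus one-hot, scaled for the mean reduction.
    grad = exp / sum_exp[:, None]
    rows = torch.arange(tokens, device=grad.device)
    grad[rows, safe_idx] -= owned.to(grad.dtype)
    return loss, grad / tokens
```

Called with the full logits and vocab_start=0 outside a process group, the result should match F.cross_entropy and its autograd gradient, which gives a cheap correctness check before any multi-GPU run.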

Shapes

Fix tokens=4096. Sweep (vocab, tp_size) pairs:

| vocab | tp_size | shard / rank |
|---|---|---|
| 32768 | 2 | 16384 |
| 32768 | 4 | 8192 |
| 65536 | 4 | 16384 |
| 131072 | 4 | 32768 |
| 131072 | 8 | 16384 |
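The sweep could be encoded as plain data, with the shard size derived from each pair rather than typed in by hand. A small sketch; the names TOKENS, SWEEP, shard_size and shard_bounds are illustrative and not part of the existing benchmark code.

```python
TOKENS = 4096

# (vocab, tp_size) pairs from the table above.
SWEEP = [
    (32768, 2),
    (32768, 4),
    (65536, 4),
    (131072, 4),
    (131072, 8),
]


def shard_size(vocab, tp_size):
    """Number of vocab columns each rank holds; vocab must divide evenly."""
    if vocab % tp_size != 0:
        raise ValueError(f"vocab={vocab} is not divisible by tp_size={tp_size}")
    return vocab // tp_size


def shard_bounds(vocab, tp_size, rank):
    """Half-open [start, end) range of global vocab indices owned by `rank`."""
    size = shard_size(vocab, tp_size)
    return rank * size, (rank + 1) * size


for vocab, tp_size in SWEEP:
    print(f"vocab={vocab} tp_size={tp_size} shard/rank={shard_size(vocab, tp_size)}")
```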

Infrastructure changes

Multi-process runner

The current runner.py is single-process. Two options:

Option A — new tools/benchmark/run_tp.py entry point using torch.multiprocessing.spawn(worker, nprocs=tp_size). Each worker: initializes dist.init_process_group, creates a TP process group, runs benchmark variants with group=group, rank 0 collects and prints results.
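A rough shape for Option A is sketched below. Everything here is an assumption about how run_tp.py might look: the function names, the file-based rendezvous, and the gloo fallback (included only so the skeleton can be smoke-tested without GPUs) are not taken from the existing runner. The all-reduce probe stands in for the real benchmark variants.

```python
import os
import tempfile

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def pick_backend(tp_size):
    """NCCL when enough GPUs are visible; gloo is an assumed CPU fallback."""
    if torch.cuda.is_available() and torch.cuda.device_count() >= tp_size:
        return "nccl"
    return "gloo"


def worker(rank, tp_size, init_file, backend):
    # mp.spawn passes the process index as the first argument.
    dist.init_process_group(
        backend, init_method=f"file://{init_file}", rank=rank, world_size=tp_size
    )
    try:
        if backend == "nccl":
            torch.cuda.set_device(rank)
            device = torch.device("cuda", rank)
        else:
            device = torch.device("cpu")
        group = dist.new_group(ranks=list(range(tp_size)))

        # Placeholder for the benchmark variants: a one-element all-reduce
        # that confirms every rank joined the group.
        probe = torch.ones(1, device=device)
        dist.all_reduce(probe, op=dist.ReduceOp.SUM, group=group)
        if rank == 0:
            print(f"tp_size={tp_size} backend={backend} ranks_seen={int(probe.item())}")
    finally:
        dist.destroy_process_group()


def main(tp_size=2):
    backend = pick_backend(tp_size)
    with tempfile.TemporaryDirectory() as tmp:
        init_file = os.path.join(tmp, "rendezvous")
        mp.spawn(worker, args=(tp_size, init_file, backend), nprocs=tp_size, join=True)
```

Calling main(tp_size=4) from a script's `if __name__ == "__main__":` guard would launch four workers; the guard matters because spawned processes re-import the main module.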

Option B — extend __main__.py to detect a --tp N flag and re-launch itself via torchrun --nproc_per_node=N.
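Option B might look roughly like this. The module path tools.benchmark is an assumption inferred from the file layout, and the helper names are hypothetical. The sketch relies on torchrun setting LOCAL_RANK in each worker's environment, which is what stops a relaunched worker from relaunching again.

```python
import os


def running_under_torchrun(env=None):
    """True inside a torchrun worker, which has LOCAL_RANK set."""
    env = os.environ if env is None else env
    return "LOCAL_RANK" in env


def torchrun_command(tp_size, module="tools.benchmark", extra_args=()):
    """Build the relaunch command; the `module` default is an assumed path."""
    return ["torchrun", f"--nproc_per_node={tp_size}", "-m", module, *extra_args]


def maybe_relaunch(tp_size, argv):
    """Replace the current process with torchrun when --tp N asks for N > 1 ranks."""
    if tp_size > 1 and not running_under_torchrun():
        cmd = torchrun_command(tp_size, extra_args=argv)
        os.execvp(cmd[0], cmd)
```

The original arguments are forwarded unchanged, so each worker parses the same --tp N flag and skips the relaunch because LOCAL_RANK is already present.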

Option A is simpler to implement. Option B integrates more cleanly with the existing CLI.

Timing

  • dist.barrier() + torch.cuda.synchronize() before and after each timed region so all ranks agree on wall time.
  • Report the max latency across ranks: reduce each rank's time with dist.all_reduce(op=MAX), then print from rank 0.
  • Communication time is included automatically because each variant's collectives run inside the timed call.
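The timing rules above could be implemented along these lines. A hedged sketch: the helper name timed_ms and the warmup and iteration counts are illustrative, and host wall-clock timing via time.perf_counter is an assumed choice rather than what the existing runner does. Outside a process group the barrier and the cross-rank reduction are skipped.

```python
import time

import torch
import torch.distributed as dist


def timed_ms(fn, iters=10, warmup=3, group=None):
    """Average wall time of fn() in milliseconds, max-reduced across ranks."""
    distributed = dist.is_available() and dist.is_initialized()

    def sync():
        # Drain this rank's GPU work, then line all ranks up.
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        if distributed:
            dist.barrier(group=group)

    for _ in range(warmup):
        fn()

    sync()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    sync()
    elapsed_ms = (time.perf_counter() - start) * 1e3 / iters

    if distributed:
        # NCCL needs a CUDA tensor; other backends accept a CPU one.
        device = "cuda" if dist.get_backend(group) == "nccl" else "cpu"
        t = torch.tensor([elapsed_ms], dtype=torch.float64, device=device)
        dist.all_reduce(t, op=dist.ReduceOp.MAX, group=group)
        elapsed_ms = t.item()
    return elapsed_ms
```

Any collective issued inside fn is part of the measured interval, so the same helper can wrap all three variants.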

OOM guard for pytorch_gather

Wrap in try/except and report OOM in the table instead of a time. This makes the table show exactly where the naive approach becomes infeasible.
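One possible shape for the guard, together with a forward-only sketch of the gather variant it would wrap. The names run_guarded and pytorch_gather_loss are hypothetical, and the sketch assumes a PyTorch version that exposes torch.cuda.OutOfMemoryError. It does not handle the case where only some ranks run out of memory, which could leave the remaining ranks waiting inside a collective.

```python
import torch
import torch.distributed as dist
import torch.nn.functional as F


def pytorch_gather_loss(local_logits, targets, group=None):
    """Forward pass of the naive variant: all-gather every shard, then F.cross_entropy.

    Requires an initialized process group.
    """
    world_size = dist.get_world_size(group)
    shards = [torch.empty_like(local_logits) for _ in range(world_size)]
    dist.all_gather(shards, local_logits.contiguous(), group=group)
    # Rank order matches vocab order, so concatenating restores the full vocab.
    return F.cross_entropy(torch.cat(shards, dim=-1).float(), targets)


def run_guarded(fn):
    """Return fn()'s result, or the string "OOM" if CUDA runs out of memory."""
    try:
        return fn()
    except torch.cuda.OutOfMemoryError:
        # Release cached blocks so later variants start from a clean slate.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return "OOM"
```

A table cell would then be filled with something like run_guarded(lambda: time_variant(...)), where time_variant stands for whatever timing helper the benchmark ends up using.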

Expected outcome

At vocab=131072, TP=8:

| variant | result |
|---|---|
| triton_tp | fast (~O(tokens) communication) |
| pytorch_tp_manual | slightly slower (no kernel fusion, same communication) |
| pytorch_gather | OOM or ~10–20× slower (16 GB all-gather dominates) |

Files

  • New: tools/benchmark/run_tp.py (or extend __main__.py)
  • Modified: tools/benchmark/bench_entropy_loss.py — add TP variants alongside existing single-GPU variants
