Minimal single-file implementations of TP, SP, and PP for LLaMA on torch.distributed. Inspired by efficient-dl-systems; for production use, I'd recommend Picotron.
Each strategy is a standalone script. Run with torchrun:
torchrun --nproc_per_node=2 tp.py --length 20480 --mode benchmark --dtype bfloat16
torchrun --nproc_per_node=2 tp.py --length 20480 --mode benchmark --dtype bfloat16 --dtensor
torchrun --nproc_per_node=2 sp_megatron.py --length 20480 --mode correctness --dtype bfloat16
torchrun --nproc_per_node=2 pp.py --length 20480 --mode all --dtype bfloat16
Options: --mode {correctness|learning|benchmark|all}, --length N, --dtype {float32|bfloat16}, --model MODEL, --bench-iters N.
oom_search.py binary-searches for the longest sequence that fits in memory:
python oom_search.py --strategy sp --nproc 2 --lo 14000 --hi 40000 --step 512
python oom_search.py --strategy all --nproc 1
Every distributed strategy boils down to four operations.
reduce_scatter - each node ends up with one aggregated chunk. Cost: O(M).
all_gather - each node collects all chunks, reconstructing the full tensor. Cost: O(M).
all_reduce = reduce_scatter + all_gather. Cost: 2 x O(M).
all_to_all - redistributes chunks across nodes (a permutation, no aggregation). Cost: O(M) in bandwidth, but on ring topology it's slower than reduce_scatter/all_gather - each node must send its own P−1 chunks and help relay neighbors' chunks too.
So: all_reduce = 2 x reduce_scatter = 2 x all_gather. And all_to_all has the same bandwidth cost O(M) but higher latency on ring interconnects.
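The semantics of the four collectives can be checked with a toy pure-Python simulation (no torch.distributed, just nested lists standing in for per-node chunks); real implementations move the data over NCCL, this only shows what each node ends up holding:

```python
# Each "node" p holds a list of P chunks: node_chunks[p][c] is chunk c on node p.

def reduce_scatter(node_chunks):
    # node p ends up with the sum of everyone's chunk p
    P = len(node_chunks)
    return [sum(node_chunks[p][c] for p in range(P)) for c in range(P)]

def all_gather(shards):
    # every node ends up with the full list of shards
    return [list(shards) for _ in shards]

def all_reduce(node_chunks):
    # all_reduce = reduce_scatter followed by all_gather of the reduced shards
    return all_gather(reduce_scatter(node_chunks))

def all_to_all(node_chunks):
    # pure permutation: node d receives chunk d from every node, no aggregation
    P = len(node_chunks)
    return [[node_chunks[src][dst] for src in range(P)] for dst in range(P)]

chunks = [[1, 2], [10, 20]]          # 2 nodes, 2 chunks each
print(reduce_scatter(chunks))        # [11, 22]
print(all_reduce(chunks))            # [[11, 22], [11, 22]]
print(all_to_all(chunks))            # [[1, 10], [2, 20]]
```

Note how all_to_all moves the same O(M) of data as reduce_scatter but produces a transpose rather than a sum.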
One thing to watch: the result of distributed sync operations is an AsyncCollectiveTensor. PyTorch implicitly calls .wait() when you do matmuls, but .item() or custom kernels can silently read stale data.
Each device holds a shard of every layer's weights. For MLP: weights are split along the intermediate dimension (h -> h/P -> h). For attention: Q, K, V projections are split by heads, O projection split along input dim.
Forward pass produces partial results that must be summed - all_reduce after each attention and MLP block. Communication cost per layer: O(N x h), where N is sequence length and h is hidden size.
Backward: gradients on weights accumulate locally (each device owns its shard). Input gradients need the same all_reduce.
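The "partial results that must be summed" claim is easy to verify numerically. A minimal sketch of the Megatron MLP split, in pure Python with tiny hand-picked matrices (no torch): the intermediate dimension is sharded across 2 devices, each device computes a partial output, and summing the partials (exactly what the all_reduce does) recovers the unsharded result:

```python
def matvec(W, x):                      # W: list of rows -> W @ x
    return [sum(wij * xj for wij, xj in zip(row, x)) for row in W]

def relu(v):
    return [max(0.0, a) for a in v]

# h=2 hidden, m=4 intermediate; W1: (m, h) column-parallel, W2: (h, m) row-parallel
W1 = [[1, 2], [3, 4], [5, 6], [7, 8]]
W2 = [[1, 0, 1, 0], [0, 1, 0, 1]]
x = [1.0, 0.5]

full = matvec(W2, relu(matvec(W1, x)))       # unsharded reference

# shard the intermediate dim: device 0 gets rows 0..1 of W1 and cols 0..1 of W2
partials = []
for dev in range(2):
    rows = slice(2 * dev, 2 * dev + 2)
    W1_shard = W1[rows]
    W2_shard = [row[rows] for row in W2]
    partials.append(matvec(W2_shard, relu(matvec(W1_shard, x))))

summed = [a + b for a, b in zip(*partials)]  # this sum is what all_reduce performs
assert summed == full
```

The nonlinearity sits entirely inside each shard's slice of the intermediate activations, which is why the split is along m and the sum can wait until after W2.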
Two implementations here: manual sharding with custom autograd.Function, and PyTorch's DTensor via parallelize_module.
Memory: TP shards weights (each device holds 1/P of parameters) but every device still holds full-sequence activations. At long context lengths, activations dominate and TP alone doesn't help - this is why SP exists.
Extension of TP. Each device holds N/P tokens instead of the full sequence, directly cutting activation memory.
Two variants exist with different trade-offs:
Megatron SP (implemented here) - weights are sharded same as TP. Before attention, all_gather reconstructs the full sequence: O(N x h). Each device computes attention on its local heads over the full sequence. After attention, reduce_scatter back to chunks: O(N x h). Total communication: 2 x O(N x h) - same bandwidth as all_reduce in vanilla TP. The advantage is purely memory: between layers, only the local chunk (seq/P x hidden) is stored.
Ulysses SP - each device holds full QKV/O weights. Instead of gathering the full sequence, it uses all_to_all to swap from (my tokens, all heads) -> (all tokens, my heads). Each device computes attention for its subset of heads on the full sequence. Reverse all_to_all in backward. Communication: 2 x O(N/P x h), which is P times cheaper than Megatron SP. But it requires fast all_to_all interconnect (NVLink mesh, not ring) and keeps full weights on every device - pair with PP or FSDP to manage parameter memory. Pays off at large batch sizes.
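The Ulysses resharding is just a transpose of the (tokens, heads) layout across devices. A toy sketch with (token, head) tags as entries, so the permutation is inspectable by eye (pure Python, illustrative shapes only):

```python
P = 2                 # devices
T, H = 4, 4           # tokens, heads; each device starts with T//P tokens, all H heads

# before[dev][t][h]: device dev holds its local tokens over every head
before = [[[(dev * (T // P) + t, h) for h in range(H)] for t in range(T // P)]
          for dev in range(P)]

def ulysses_all_to_all(before):
    # split the head dim into P chunks; device d sends its head-chunk c to
    # device c and receives every device's head-chunk d -> (all tokens, my heads)
    return [
        [[(tok, head) for (tok, head) in row if head // (H // P) == dst]
         for dev in range(P) for row in before[dev]]
        for dst in range(P)
    ]

after = ulysses_all_to_all(before)
# device 0 now sees all T tokens but only heads 0..H//P-1
assert [e for row in after[0] for e in row] == \
       [(t, h) for t in range(T) for h in range(H // P)]
```

Each device contributes only its local N/P tokens to the exchange, which is where the P-fold bandwidth saving over the full-sequence all_gather comes from.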
Implementation detail: our autograd.Function recomputes the forward pass during backward (saves only the local input shard via ctx.save_for_backward, then re-gathers and re-forwards in backward). This is effectively activation checkpointing per block - but it applies at any world_size including 1, so it doesn't explain SP-vs-TP differences.
Implementation detail: when the sequence is split across devices, each chunk's positional indices must reflect its position in the full sequence, not start from 0. You manually pass position_ids into the model. The causal attention mask also needs rebuilding after all_gather since the gathered sequence has a different length than what the model expects.
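The position-id bookkeeping is small but easy to get wrong. A sketch with a hypothetical helper name (the actual scripts build an equivalent tensor and pass it as position_ids):

```python
def sp_position_ids(seq_len, world_size, rank):
    # rank r of P owns tokens [r * N/P, (r+1) * N/P); rotary embeddings must see
    # these global indices, not 0..N/P-1
    chunk = seq_len // world_size
    start = rank * chunk
    return list(range(start, start + chunk))

print(sp_position_ids(8, 2, 0))   # [0, 1, 2, 3]
print(sp_position_ids(8, 2, 1))   # [4, 5, 6, 7]
```

Using local indices instead would make every rank past the first rotate its queries and keys as if its chunk started the sequence, which corrupts attention silently.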
Each device holds a contiguous subset of layers. Activations flow forward via point-to-point send/recv, gradients flow backward the same way.
PP cuts both parameter memory and activation memory - each stage stores only its layers and their activations.
The autograd graph is split across devices, so you can't just call loss.backward() end-to-end. Each stage manually calls torch.autograd.backward(output, grad_tensors=received_grad).
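Why stitching per-stage backwards works is just the chain rule. A numeric sketch with two scalar "stages" (plain Python, no autograd): stage 1 computes its local gradient and sends it back; stage 0 multiplies through its own local graph, which is exactly what torch.autograd.backward(output, grad_tensors=received_grad) does:

```python
def f(x):  return 3.0 * x          # stage 0
def g(y):  return y * y            # stage 1

x = 2.0
y = f(x)                           # stage 0 forward; y is send/recv'd to stage 1
loss = g(y)

# stage 1 backward: d loss / d y, sent back to stage 0
grad_y = 2.0 * y                   # g'(y)

# stage 0 backward: multiply the received gradient through the local graph
grad_x = grad_y * 3.0              # f'(x) * grad_y

assert grad_x == 18.0 * x          # matches d/dx of (3x)^2 = 18x
```

The received grad_y plays the role of grad_tensors: it seeds the local backward with the upstream gradient instead of the implicit 1.0 that loss.backward() would use.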
Convenient API tricks for LLaMA:
- Replace skipped components (embed_tokens, norm, lm_head) with nn.Identity() - the model still runs and simply passes tensors through (far easier than implementing per-block synchronization!)
- Use the inputs_embeds kwarg to bypass the embedding layer on non-first stages and to inject gradient-tracked tensors for verification
- Use retain_grad() instead of requires_grad_() on intermediate tensors - after receiving and potentially casting, the tensor may not be a leaf anymore
This repo implements AFAB (All-Forward-All-Backward) schedule. 1F1B and interleaved schedules exist but are not implemented.
PP OOM note: logits must be materialized for inter-stage communication (~4.3GB for a 1B model). This extra allocation is why PP hits OOM earlier than TP/SP at the same sequence length on a single GPU.
Per-layer communication for a sequence of N tokens, hidden size h, P devices:
| Strategy | Attention | MLP | Total per layer |
|---|---|---|---|
| TP | O(Nh) all_reduce | O(Nh) all_reduce | 4 x O(Nh) |
| Megatron SP | 2 x O(Nh) gather+scatter | O(Nh) all_reduce | 4 x O(Nh) |
| Ulysses SP | 2 x O(Nh/P) all_to_all | O(Nh/P) all_reduce | 4 x O(Nh/P) |
| PP | - | - | O(Nh) per boundary |
TP and Megatron SP have the same total bandwidth cost. SP's advantage is purely memory - activations are distributed. Ulysses is Px cheaper in bandwidth but needs full weights on each device and fast all_to_all.
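The table can be turned into concrete byte counts. A back-of-envelope sketch assuming bf16 (2 bytes/element); N and h below are illustrative values, not measurements from this repo:

```python
def tp_bytes(N, h, bytes_per_el=2):
    # two all_reduces per layer, each 2 x O(Nh)
    return 4 * N * h * bytes_per_el

def megatron_sp_bytes(N, h, bytes_per_el=2):
    # all_gather + reduce_scatter around attention plus the MLP all_reduce:
    # same 4 x O(Nh) total as TP
    return 4 * N * h * bytes_per_el

def ulysses_bytes(N, h, P, bytes_per_el=2):
    # every exchange touches only the local N/P tokens
    return 4 * (N // P) * h * bytes_per_el

N, h, P = 20480, 2048, 2
assert tp_bytes(N, h) == megatron_sp_bytes(N, h)      # same bandwidth, different memory
assert ulysses_bytes(N, h, P) == tp_bytes(N, h) // P  # P times cheaper
print(tp_bytes(N, h) / 2**20, "MiB per layer")
```

The equality of the first two is the whole point of the section: Megatron SP buys memory, not bandwidth.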
PP communication is minimal (only at stage boundaries, not every layer) but introduces pipeline bubbles where some stages idle.
LLaMA 1B, bf16, Flash Attention 2, batch=1.
| GPUs | TP | SP | PP |
|---|---|---|---|
| 1 | 19k | 19k | 15k |
| 2 | 21k | 41k | 26k |
TP 1->2 GPU: modest gain. Weight sharding on a 1B model frees ~1GB - activations still dominate and aren't sharded.
SP 1->2 GPU: ~2.15x, consistent with linear scaling. With Flash Attention all allocations are linear in sequence length. Between layers SP stores seq/P x hidden instead of seq x hidden. The slight bonus over exactly 2x comes from weight sharding freeing additional memory.
PP single GPU: lower than TP/SP because logits must be materialized for inter-stage send (~4.3GB). PP at 2 GPUs is still limited by the last stage's logit allocation.
| Strategy | Forward | Forward + Backward |
|---|---|---|
| TP (manual) | 1.3s | 4.0s |
| TP (DTensor) | 1.3s | 4.9s |
| SP | 1.2s | 4.0s |
| PP | 0.7s* | 4.7s* |
*PP forward is rank-0 time only (prints before rank-1 finishes). True wall-clock is likely ~1.4s.
DTensor backward is ~20% slower than manual TP - probably the cost of the abstraction.
If activations are the bottleneck (long sequences) - SP. It's the only strategy that directly reduces per-device activation memory.
If the model doesn't fit on one GPU - PP to split layers, or TP to split weights. PP has less communication overhead but introduces bubbles. TP has no bubbles but communicates every layer.
In practice these combine (with something DDP-like).
TP/SP with LoRA is awkward - adapter weights aren't sharded, so their gradients need separate handling. Prompt tuning (prepending learnable tokens) is simpler since virtual tokens go through the same parallel path as real tokens, but in PP prompt tuning makes gradient tracking harder - LoRA is the better choice there. For reproducibility I ended up with a simple hack: I just optimized randomly sampled embeddings.
Weights are stored transposed. PyTorch linear layers store weight as (d_out, d_in), not (d_in, d_out). So x @ weight.T means sharding dim=0 splits the output dimension and dim=1 splits the input dimension. This bites you when manually chunking weights for TP.
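A tiny pure-Python check of the sharding direction (standing in for slicing an nn.Linear weight of shape (d_out, d_in)): chunking dim=0 yields a slice of the output features, not a partial sum.

```python
W = [[1, 2, 3],        # (d_out=2, d_in=3): row i produces output feature i
     [4, 5, 6]]
x = [1.0, 1.0, 1.0]

# x @ W.T: one dot product per weight row
full = [sum(w * xi for w, xi in zip(row, x)) for row in W]

W_dim0_shard = W[:1]   # column-parallel shard: keeps only the first output feature
shard_out = [sum(w * xi for w, xi in zip(row, x)) for row in W_dim0_shard]

assert full == [6.0, 15.0]
assert shard_out == [6.0]   # a slice of the output vector, not a partial result
```

Chunking dim=1 instead would split the dot products themselves, producing partials that need an all_reduce - the row-parallel case.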
Attention returns tuples. LLaMA attention forward returns (hidden_states, self_attn_weights, past_key_value), not a single tensor. If you all_reduce the tuple instead of output[0], you get cryptic errors. More dangerous: the MLP returns a plain tensor, so indexing it with [0] out of habit silently aggregates only one element of the batch.
kwargs in autograd. torch.autograd.Function.forward doesn't support kwargs. HuggingFace layers are called almost exclusively with kwargs. You need to pop them manually and thread them through via ctx.
Position embeddings in SP. When the sequence is split, rotary embeddings need the global position of each chunk, not local indices starting from 0. You must manually construct and pass position_ids. Miss this and the model computes wrong attention patterns - numerically wrong but won't crash.
copy_ needs no_grad. In-place copy_ on parameters only works inside torch.no_grad().
retain_grad vs requires_grad_. After sending/receiving and potential dtype casting, the tensor might not be a leaf. requires_grad_() only works on leaves. retain_grad() works on any tensor in the graph. Alternative: .detach().requires_grad_(True) to force anything into a leaf.
adaptive_autorange can deadlock. It decides iteration count based on timing, which differs per rank (especially in PP). One rank finishes while another is still waiting to sync. Use fixed iteration counts for distributed benchmarks.
parallelize_module is synchronized. DTensor's parallelize_module is a synchronized call. Model must be on-device before you call it, or you deadlock.
SP loss computation. After forward, each rank has logits for its chunk only. Computing global loss requires all_gather, then replacing your rank's chunk with the original tensor to preserve the autograd graph - otherwise gradients don't flow back.
Local shifts lose an extra token. In SP with causal LM, the shift-by-one for next-token prediction loses one token at each chunk boundary, not just one total.
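Counting the loss from a purely local shift makes the pitfall concrete (pure Python; a real fix would exchange boundary tokens or gather before shifting):

```python
def local_shift_targets(chunks):
    # next-token targets computed independently per chunk: each chunk drops
    # its own last target, so P tokens disappear instead of 1
    return [chunk[1:] for chunk in chunks]

seq = list(range(8))
chunks = [seq[:4], seq[4:]]                        # P=2 ranks

local = local_shift_targets(chunks)
global_targets = seq[1:]                           # full-sequence shift loses 1

assert sum(len(t) for t in local) == len(seq) - 2  # one lost per chunk
assert len(global_targets) == len(seq) - 1
```

The extra missing token is exactly the boundary prediction: the first rank's last token should predict the second rank's first token, but neither rank can form that pair locally.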