model-parallelisms

Minimal single-file implementations of TP, SP, and PP for LLaMA on torch.distributed. Inspired by efficient-dl-systems; for production use, I'd recommend Picotron.

Use

Each strategy is a standalone script. Run with torchrun:

torchrun --nproc_per_node=2 tp.py --length 20480 --mode benchmark --dtype bfloat16
torchrun --nproc_per_node=2 tp.py --length 20480 --mode benchmark --dtype bfloat16 --dtensor
torchrun --nproc_per_node=2 sp_megatron.py --length 20480 --mode correctness --dtype bfloat16
torchrun --nproc_per_node=2 pp.py --length 20480 --mode all --dtype bfloat16

Options: --mode {correctness|learning|benchmark|all}, --length N, --dtype {float32|bfloat16}, --model MODEL, --bench-iters N.

oom_search.py binary-searches for the longest sequence length that fits in memory:

python oom_search.py --strategy sp --nproc 2 --lo 14000 --hi 40000 --step 512
python oom_search.py --strategy all --nproc 1

Collective primitives

Every distributed strategy boils down to four operations.

reduce_scatter - each node ends up with one aggregated chunk. Cost: O(M).

all_gather - each node collects all chunks, reconstructing the full tensor. Cost: O(M).

all_reduce = reduce_scatter + all_gather. Cost: 2 x O(M).

all_to_all - redistributes chunks across nodes (a permutation, no aggregation). Cost: O(M) in bandwidth, but on ring topology it's slower than reduce_scatter/all_gather - each node must send its own P−1 chunks and help relay neighbors' chunks too.

So: all_reduce = 2 x reduce_scatter = 2 x all_gather. And all_to_all has the same bandwidth cost O(M) but higher latency on ring interconnects.
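A minimal sketch of all four side by side (assumes 2+ GPUs and a NCCL process group; launch with torchrun):

```python
import os
import torch
import torch.distributed as dist

# Sketch: the four primitives side by side. Run: torchrun --nproc_per_node=2 demo.py
dist.init_process_group(backend="nccl")
rank, world = dist.get_rank(), dist.get_world_size()
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

x = torch.full((world, 4), float(rank), device="cuda")  # `world` chunks per rank

out = torch.empty(4, device="cuda")
dist.reduce_scatter_tensor(out, x)        # each rank keeps one summed chunk: O(M)

full = torch.empty(world * 4, device="cuda")
dist.all_gather_into_tensor(full, out)    # every rank reconstructs the full tensor: O(M)

y = x.clone()
dist.all_reduce(y)                        # reduce_scatter + all_gather fused: 2 x O(M)

swapped = torch.empty_like(x)
dist.all_to_all_single(swapped, x)        # chunk permutation across ranks, no aggregation
```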

One thing to watch: the result of distributed sync operations is AsyncCollectiveTensor. PyTorch implicitly calls .wait() when you do matmuls, but .item() or custom kernels can silently read stale data.
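A sketch of the hazard using the private torch.distributed._functional_collectives API (names may shift between torch versions; assumes an initialized NCCL group):

```python
import torch
import torch.distributed as dist
# Private API - treat this as a sketch; it may change between torch versions.
import torch.distributed._functional_collectives as funcol

x = torch.ones(4, 4, device="cuda")
y = funcol.all_reduce(x, "sum", group=dist.group.WORLD)  # AsyncCollectiveTensor
z = y @ x        # matmul dispatch triggers the implicit .wait() - safe
y = y.wait()     # force completion explicitly before .item() or custom kernels
```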

Strategies

Tensor Parallelism (TP)

Each device holds a shard of every layer's weights. For the MLP: weights are split along the intermediate dimension (h -> d_ff/P -> h, where d_ff is the intermediate size). For attention: Q, K, V projections are split by heads, O projection split along the input dim.

Forward pass produces partial results that must be summed - all_reduce after each attention and MLP block. Communication cost per layer: O(N x h), where N is sequence length and h is hidden size.

Backward: gradients on weights accumulate locally (each device owns its shard). Input gradients need the same all_reduce.
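Megatron-style TP expresses this pair of all_reduces as two conjugate autograd functions; a condensed sketch of the idea (not the repo's exact code):

```python
import torch
import torch.distributed as dist

class _Copy(torch.autograd.Function):
    """Identity in forward; all_reduce the input gradient in backward."""
    @staticmethod
    def forward(ctx, x):
        return x
    @staticmethod
    def backward(ctx, grad):
        grad = grad.contiguous().clone()
        dist.all_reduce(grad)
        return grad

class _Reduce(torch.autograd.Function):
    """all_reduce partial sums in forward; identity in backward."""
    @staticmethod
    def forward(ctx, x):
        x = x.clone()
        dist.all_reduce(x)
        return x
    @staticmethod
    def backward(ctx, grad):
        return grad

def tp_mlp(x, w_up, w_down):
    # w_up: (d_ff/P, h) shard - output dim split; w_down: (h, d_ff/P) shard - input dim split
    x = _Copy.apply(x)
    h = torch.nn.functional.silu(x @ w_up.T)  # local intermediate of width d_ff/P
    return _Reduce.apply(h @ w_down.T)        # partial outputs summed across ranks
```

The two functions are conjugate: whatever one does in forward, the other does in backward, so forward and backward each cost one all_reduce.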

Two implementations here: manual sharding with custom autograd.Function, and PyTorch's DTensor via parallelize_module.
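The DTensor path boils down to a sharding plan over each block's projections; a sketch assuming `model` is an on-device HF LLaMA (module paths follow HF's naming):

```python
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel, RowwiseParallel, parallelize_module,
)

# Sketch: shard one decoder block on a 1-D mesh of 2 devices.
mesh = init_device_mesh("cuda", (2,))
parallelize_module(model.model.layers[0], mesh, {
    "self_attn.q_proj": ColwiseParallel(),   # split by heads (output dim)
    "self_attn.k_proj": ColwiseParallel(),
    "self_attn.v_proj": ColwiseParallel(),
    "self_attn.o_proj": RowwiseParallel(),   # split along input dim
    "mlp.gate_proj": ColwiseParallel(),
    "mlp.up_proj": ColwiseParallel(),
    "mlp.down_proj": RowwiseParallel(),
})
```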

Memory: TP shards weights (each device holds 1/P of parameters) but every device still holds full-sequence activations. At long context lengths, activations dominate and TP alone doesn't help - this is why SP exists.

Sequence Parallelism (SP)

Extension of TP. Each device holds N/P tokens instead of the full sequence, directly cutting activation memory.

Two variants exist with different trade-offs:

Megatron SP (implemented here) - weights are sharded same as TP. Before attention, all_gather reconstructs the full sequence: O(N x h). Each device computes attention on its local heads over the full sequence. After attention, reduce_scatter back to chunks: O(N x h). Total communication: 2 x O(N x h) - same bandwidth as all_reduce in vanilla TP. The advantage is purely memory: between layers, only the local chunk (seq/P x hidden) is stored.
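Forward-only sketch of the two boundary collectives (the repo wraps them in autograd.Functions that run the mirrored collective in backward; helper names are mine):

```python
import torch
import torch.distributed as dist

def gather_seq(local_x, world):
    """(B, N/P, h) chunk -> (B, N, h) full sequence via all_gather."""
    x = local_x.transpose(0, 1).contiguous()   # seq-first so the concat lands on dim 0
    out = torch.empty((x.shape[0] * world, *x.shape[1:]), device=x.device, dtype=x.dtype)
    dist.all_gather_into_tensor(out, x)        # rank order == global chunk order
    return out.transpose(0, 1).contiguous()

def scatter_seq(full_x, world):
    """(B, N, h) partial sums -> (B, N/P, h) summed local chunk via reduce_scatter."""
    x = full_x.transpose(0, 1).contiguous()
    out = torch.empty((x.shape[0] // world, *x.shape[1:]), device=x.device, dtype=x.dtype)
    dist.reduce_scatter_tensor(out, x)         # rank r keeps summed chunk r
    return out.transpose(0, 1).contiguous()
```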

Ulysses SP - each device holds full QKV/O weights. Instead of gathering the full sequence, it uses all_to_all to swap from (my tokens, all heads) -> (all tokens, my heads). Each device computes attention for its subset of heads on the full sequence. Reverse all_to_all in backward. Communication: 2 x O(N/P x h), which is P times cheaper than Megatron SP. But it requires fast all_to_all interconnect (NVLink mesh, not ring) and keeps full weights on every device - pair with PP or FSDP to manage parameter memory. Pays off at large batch sizes.
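The core of Ulysses is the head/sequence swap; a sketch of one direction, assuming the world size divides the head count (the reverse swap mirrors the permutes):

```python
import torch
import torch.distributed as dist

def seq_to_heads(x, world):
    """(B, N/P tokens, H heads, d) -> (B, N tokens, H/P heads, d) via all_to_all."""
    B, n, H, d = x.shape
    # group the heads destined for each rank along dim 0
    x = x.reshape(B, n, world, H // world, d).permute(2, 1, 0, 3, 4).contiguous()
    out = torch.empty_like(x)
    dist.all_to_all_single(out, x)   # rank r receives every rank's heads-group r
    # dim 0 is now the source rank, i.e. the global chunk order of the sequence
    return out.permute(2, 0, 1, 3, 4).reshape(B, world * n, H // world, d)
```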

Implementation detail: our autograd.Function recomputes the forward pass during backward (saves only the local input shard via ctx.save_for_backward, then re-gathers and re-forwards in backward). This is effectively activation checkpointing per block - but it applies at any world_size including 1, so it doesn't explain SP-vs-TP differences.

Implementation detail: when the sequence is split across devices, each chunk's positional indices must reflect its position in the full sequence, not start from 0. You manually pass position_ids into the model. The causal attention mask also needs rebuilding after all_gather since the gathered sequence has a different length than what the model expects.
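Sketch of the global position construction (rank, world, seq_len, local_embeds are assumed names):

```python
import torch

rank, world, seq_len = 0, 2, 4096   # placeholders; take these from dist in real code
n_local = seq_len // world          # tokens held by this rank

# Rank r's chunk covers global positions [r * n_local, (r + 1) * n_local)
position_ids = torch.arange(rank * n_local, (rank + 1) * n_local).unsqueeze(0)
# out = model(inputs_embeds=local_embeds, position_ids=position_ids)
```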

Pipeline Parallelism (PP)

Each device holds a contiguous subset of layers. Activations flow forward via point-to-point send/recv, gradients flow backward the same way.

PP cuts both parameter memory and activation memory - each stage stores only its layers and their activations.

The autograd graph is split across devices, so you can't just call loss.backward() end-to-end. Each stage manually calls torch.autograd.backward(output, grad_tensors=received_grad).
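One AFAB step on an intermediate stage, sketched; shapes and stage_layers are placeholders, and the real code also handles dtype casts and the first/last stages:

```python
import torch
import torch.distributed as dist

rank = dist.get_rank()
B, N, h = 1, 2048, 2048                 # placeholder shapes
stage_layers = torch.nn.Identity()      # stand-in for this stage's decoder layers

# Forward half: receive activations, run our layers, send downstream.
acts = torch.empty(B, N, h, device="cuda", dtype=torch.bfloat16)
dist.recv(acts, src=rank - 1)
acts = acts.detach().requires_grad_(True)   # local graph root for this stage
out = stage_layers(acts)
dist.send(out, dst=rank + 1)

# Backward half: receive grad w.r.t. our output, backprop locally, send upstream.
grad = torch.empty_like(out)
dist.recv(grad, src=rank + 1)
torch.autograd.backward(out, grad_tensors=grad)
dist.send(acts.grad, dst=rank - 1)
```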

Convenient API tricks for LLaMA:

  • Replace skipped components (embed_tokens, norm, lm_head) with nn.Identity() - the model still runs and just passes tensors through (waaaay easier than implementing per-block synchronization!); see the sketch after this list
  • Use the inputs_embeds kwarg to bypass the embedding layer on non-first stages and to inject gradient-tracked tensors for verification
  • retain_grad() instead of requires_grad_() on intermediate tensors - after receiving and potentially casting, the tensor may not be a leaf anymore
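A sketch of the stage surgery (attribute paths follow HF LLaMA; make_stage and its flags are mine):

```python
import torch.nn as nn

def make_stage(model, layer_range, is_first, is_last):
    """Keep only this stage's layers; stub out everything that belongs elsewhere."""
    keep = set(layer_range)
    model.model.layers = nn.ModuleList(
        [layer for i, layer in enumerate(model.model.layers) if i in keep]
    )
    if not is_first:
        model.model.embed_tokens = nn.Identity()  # inputs arrive via inputs_embeds
    if not is_last:
        model.model.norm = nn.Identity()          # passthrough, no final norm here
        model.lm_head = nn.Identity()             # no logits on inner stages
    return model

# Non-first stages then run: out = stage(inputs_embeds=acts, position_ids=pos)
```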

This repo implements AFAB (All-Forward-All-Backward) schedule. 1F1B and interleaved schedules exist but are not implemented.

PP OOM note: logits must be materialized for inter-stage communication (~4.3GB for a 1B model). This extra allocation is why PP hits OOM earlier than TP/SP at the same sequence length on a single GPU.

Communication costs

Per-layer communication for a sequence of N tokens, hidden size h, P devices:

| Strategy | Attention | MLP | Total per layer |
|---|---|---|---|
| TP | O(Nh) all_reduce | O(Nh) all_reduce | 4 x O(Nh) |
| Megatron SP | 2 x O(Nh) gather+scatter | O(Nh) all_reduce | 4 x O(Nh) |
| Ulysses SP | 2 x O(Nh/P) all_to_all | O(Nh/P) all_reduce | 4 x O(Nh/P) |
| PP | - | - | O(Nh) per boundary |

TP and Megatron SP have the same total bandwidth cost. SP's advantage is purely memory - activations are distributed. Ulysses is Px cheaper in bandwidth but needs full weights on each device and fast all_to_all.

PP communication is minimal (only at stage boundaries, not every layer) but introduces pipeline bubbles where some stages idle.

Benchmarks

LLaMA 1B, bf16, Flash Attention 2, batch=1.

Max sequence length before OOM

| GPUs | TP | SP | PP |
|---|---|---|---|
| 1 | 19k | 19k | 15k |
| 2 | 21k | 41k | 26k |

TP 1->2 GPU: modest gain. Weight sharding on a 1B model frees ~1GB - activations still dominate and aren't sharded.

SP 1->2 GPU: ~2.15x, consistent with linear scaling. With Flash Attention all allocations are linear in sequence length. Between layers SP stores seq/P x hidden instead of seq x hidden. The slight bonus over exactly 2x comes from weight sharding freeing additional memory.

PP single GPU: lower than TP/SP because logits must be materialized for inter-stage send (~4.3GB). PP at 2 GPU still limited by last-stage logit allocation.

Timing (2 GPUs)

| Strategy | Forward | Forward + Backward |
|---|---|---|
| TP (manual) | 1.3s | 4.0s |
| TP (DTensor) | 1.3s | 4.9s |
| SP | 1.2s | 4.0s |
| PP | 0.7s* | 4.7s* |

*PP forward is rank-0 time only (prints before rank-1 finishes). True wall-clock is likely ~1.4s.

DTensor backward is ~20% slower than manual TP - probably the cost of the abstraction.

What to use when

If activations are the bottleneck (long sequences) - SP. It's the only strategy that directly reduces per-device activation memory.

If the model doesn't fit on one GPU - PP to split layers, or TP to split weights. PP has less communication overhead but introduces bubbles. TP has no bubbles but communicates every layer.

In practice these strategies are combined, usually alongside data parallelism (DDP or FSDP).

Fine-tuning

TP/SP with LoRA is awkward - adapter weights aren't sharded, so their gradients need separate handling. Prompt tuning (prepending learnable tokens) is simpler, since virtual tokens go through the same parallel path as real tokens; but in PP, prompt tuning makes gradient tracking harder, and LoRA is the better choice there. For reproducibility I ended up with a simple hack: I just optimized randomly sampled embeddings.

Lessons / Notes

Weights are stored transposed. PyTorch linear layers store weight as (d_out, d_in), not (d_in, d_out). So x @ weight.T means sharding dim=0 splits the output dimension and dim=1 splits the input dimension. This bites you when manually chunking weights for TP.
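Three lines to see it:

```python
import torch
import torch.nn as nn

lin = nn.Linear(1024, 4096, bias=False)
print(lin.weight.shape)                          # torch.Size([4096, 1024]) == (d_out, d_in)
out_shards = torch.chunk(lin.weight, 2, dim=0)   # dim=0 splits the output dim (column parallel)
in_shards = torch.chunk(lin.weight, 2, dim=1)    # dim=1 splits the input dim (row parallel)
```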

Attention returns tuples. LLaMA attention forward returns (hidden_states, self_attn_weights, past_key_value), not a single tensor. If you all_reduce output instead of output[0], you get cryptic errors. More dangerous still: the MLP returns a plain tensor, so indexing output[0] there silently selects a single batch element and you aggregate only that.

kwargs in autograd. torch.autograd.Function.forward doesn't support kwargs. HuggingFace layers are called almost exclusively with kwargs. You need to pop them manually and thread them through via ctx.
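One workaround: smuggle the kwargs dict through as a single positional argument. A sketch (names mine; assumes the wrapped block returns a tuple, as HF decoder layers do) that also shows the recompute-in-backward pattern from the SP section:

```python
import torch

class BlockFn(torch.autograd.Function):
    """forward() can't take **kwargs - pass the dict as one positional argument."""
    @staticmethod
    def forward(ctx, hidden_states, block, kwargs):
        ctx.block, ctx.kwargs = block, kwargs
        ctx.save_for_backward(hidden_states)
        with torch.no_grad():
            return block(hidden_states, **kwargs)[0]

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        x = x.detach().requires_grad_(True)
        with torch.enable_grad():
            out = ctx.block(x, **ctx.kwargs)[0]   # recompute forward (checkpoint-style)
        torch.autograd.backward(out, grad_out)
        return x.grad, None, None                 # no grads for `block` and `kwargs`

# out = BlockFn.apply(x, layer, {"position_ids": pos, "attention_mask": mask})
```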

Position embeddings in SP. When the sequence is split, rotary embeddings need the global position of each chunk, not local indices starting from 0. You must manually construct and pass position_ids. Miss this and the model computes wrong attention patterns - numerically wrong but won't crash.

copy_ needs no_grad. In-place copy_ on parameters only works inside torch.no_grad().

retain_grad vs requires_grad_. After sending/receiving and potential dtype casting, the tensor might not be a leaf. requires_grad_() only works on leaves. retain_grad() works on any tensor in the graph. Alternative: .detach().requires_grad_(True) to force anything into a leaf.

adaptive_autorange can deadlock. It decides iteration count based on timing, which differs per rank (especially in PP). One rank finishes while another is still waiting to sync. Use fixed iteration counts for distributed benchmarks.

parallelize_module is synchronized. DTensor's parallelize_module is a synchronizing collective call: the model must already be on-device on every rank before you call it, or you deadlock.

SP loss computation. After forward, each rank has logits for its chunk only. Computing global loss requires all_gather, then replacing your rank's chunk with the original tensor to preserve the autograd graph - otherwise gradients don't flow back.
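A sketch of the graph-preserving splice (local_logits, labels are assumed names; labels are full-length on every rank):

```python
import torch
import torch.distributed as dist
import torch.nn.functional as F

# local_logits: (B, N/P, vocab) with grad history; labels: (B, N) on every rank
rank, world = dist.get_rank(), dist.get_world_size()
chunks = [torch.empty_like(local_logits) for _ in range(world)]
dist.all_gather(chunks, local_logits)   # gathered copies carry no autograd history
chunks[rank] = local_logits             # splice the live tensor back into the list
logits = torch.cat(chunks, dim=1)       # (B, N, vocab); grads flow through our chunk only
loss = F.cross_entropy(
    logits[:, :-1].reshape(-1, logits.size(-1)),
    labels[:, 1:].reshape(-1),
)
```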

Local shifts lose an extra token. In SP with causal LM, the shift-by-one for next-token prediction loses one token at each chunk boundary, not just one total.
