Static verification framework for distributed LLM training. Catches placement bugs and communication races at compile time — no GPUs needed.
Distributed training bugs are silent, intermittent, and catastrophic at scale:
| Bug Type | Symptom | Traditional Detection |
|---|---|---|
| Missing AllReduce after row-parallel MatMul | Silently diverging weights | Compare against single-GPU baseline |
| LayerNorm on sharded hidden dim | Numerically wrong normalization | Check mathematical equivalence |
| Async AllReduce without Wait | Race condition — works on some runs | Flaky, CI-unfriendly |
| Mismatched Send/Recv in pipeline | Deadlock under specific schedules | Only surfaces at scale |
This verifier catches them all statically — in milliseconds, with zero GPUs.
Input Source (PyTorch / Megatron / TileLang TIR)
│
▼
┌──────────────────┐
│ Verification IR │ 48 ops, symbolic metadata
│ verifier/ir/ │ not numeric values
└────────┬─────────┘
│
┌────────────┴────────────┐
▼ ▼
┌──────────────┐ ┌──────────────┐
│ SPATIAL │ │ TEMPORAL │
│ "Where" │ │ "When" │
│ │ │ │
│ Z3 SMT │ │ HB graph │
│ 6 checks │ │ 5 checks │
└──────────────┘ └──────────────┘
| Check | Property | Method |
|---|---|---|
| Postcondition | Output not Partial at boundaries | Z3 UNSAT proof |
| Communication legality | Collectives on valid inputs only | Structural + Z3 |
| Gradient duality | fwd collective ↔ bwd dual | Type duality table |
| Placement consistency | Output placement follows inputs | Symbolic propagation |
| Shape consistency | Shapes valid through collectives | Divisibility constraints |
| PP deadlock freedom | No circular Send/Recv waits | DFS cycle detection |
Multi-dim mesh support: N-dimensional meshes (e.g. TP×DP) with per-dim Z3 variables. Collectives with mesh_dim target only the specified dimension, preserving placements on other dims.
| Check | Property | Method |
|---|---|---|
| Data race | Unordered concurrent R/W on same buffer | HB interval overlap |
| Missing Wait | Async output read before sync | Handle-waited analysis |
| Buffer aliasing | Two async ops write same buffer | WAW detection |
| Dependency violation | Recv before matching Send | HB ordering |
| Orphaned handle | Async handle never waited on | Handle lifecycle tracking |
| Category | Ops | Sharding Semantics |
|---|---|---|
| Linear algebra | MatMul | S(1)×S(0)→P, R×S(1)→S(1), S(0)×R→S(0) |
| Element-wise binary | Add, Multiply | Merge placements; P×P forbidden |
| Activations | SiLU, GELU, ReLU | Passthrough |
| Regularization | Dropout | Passthrough |
| Normalization | LayerNorm, RMSNorm | Error if Shard(norm_dim) |
| Reduction | Softmax | Error if Shard(reduction_dim) |
| Attention | FlashAttention | Follows Q placement |
| Vocab-parallel | Embedding | weight Shard(vocab)→Partial |
| Loss | CrossEntropyLoss | logits Shard(vocab)→Partial |
All collectives support optional mesh_dim for multi-dim meshes — targets only the specified dimension.
| Op | Forward | Backward Dual |
|---|---|---|
| AllReduce | Partial→Replicate | AllReduce (self) |
| AllGather | Shard(d)→Replicate | ReduceScatter |
| ReduceScatter | R/P→Shard(d) | AllGather |
| Broadcast | any→Replicate | Reduce |
| Reduce | Partial→Replicate(root) | Broadcast |
| AllToAll | Shard(split)→Shard(concat) | AllToAll (dim-swap) |
| Scatter | Replicate→Shard(d) | Gather |
| Gather | Shard(d)→Replicate | Scatter |
| Send/Recv | P2P data movement | Recv/Send (reversed) |
| SendAsync/RecvAsync | Non-blocking P2P | — |
AllReduceAsync, Wait, WaitAll, OverlapRegion
Cast, LossScale, FP8Quantize, FP8Dequantize, AmaxUpdate
| Domain | Ops |
|---|---|
| ZeRO | ZeROGatherParam, ZeROScatterGrad, ZeROPartitionOptState |
| Context Parallelism | RingRotate, RingAttentionStep, RingAttention |
| Mixture of Experts | TopKGate, MoEDispatch, MoECombine, ExpertCompute |
Reshape, Transpose, Reinterpret, Convert
verifier/
├── state/ # TensorState, ShardingSpec, DeviceMesh, placements
├── ir/ # 48 IR ops across 10 modules
│ ├── compute.py # MatMul, activations, norm, attention, embedding, loss
│ ├── collective.py # 8 NCCL collectives
│ ├── p2p.py # Send/Recv (sync + async)
│ ├── async_ops.py # AllReduceAsync, Wait, OverlapRegion
│ ├── precision.py # Cast, FP8, LossScale
│ ├── shape.py # Reshape, Transpose
│ ├── spmd.py # Reinterpret, Convert, SPMDGuard
│ ├── zero.py # ZeRO-1/2/3 ops
│ ├── cp.py # Ring attention ops
│ └── moe.py # MoE dispatch/combine
├── executor.py # Multi-device symbolic executor
├── solver.py # Z3 spatial verifier (6 checks, multi-dim mesh)
├── temporal.py # HB graph + race detection (5 checks)
├── autograd.py # VJP engine + gradient duality
├── rewrite.py # Pattern matching + cost model
├── synthesis.py # Beam-search tactic synthesis
├── schedules.py # PP 1F1B schedule + deadlock checker
├── llm_frontend.py # PyTorch → IR via LLM + feedback loop
└── tir_lifter.py # TileLang TIR → IR lifter
tests/ # 660 tests across 13 files
examples/ # 9 runnable demos
benchmarks/ # 3 suites
docs/ # Architecture & API docs
pip install -r requirements.txt # z3-solver, pytest
python -m pytest tests/ -v # run all 660 testsfrom verifier import *
mesh = DeviceMesh(shape=(2,), dim_names=("tp",))
x_spec = ShardingSpec(placements=(Shard(dim=1),), mesh=mesh)
w_spec = ShardingSpec(placements=(Shard(dim=0),), mesh=mesh)
x = TensorState("x", (8, 16), compute_local_shape((8, 16), x_spec), x_spec, expr="x")
w = TensorState("w", (16, 32), compute_local_shape((16, 32), w_spec), w_spec, expr="w")
program = Program("tp_linear", ops=[
MatMul(a="x", b="w", output="y_partial"),
AllReduce(x="y_partial", output="y"),
])
executor = MultiDeviceExecutor(mesh=mesh)
executor.register_tensor(x)
executor.register_tensor(w)
final = executor.run_program(program)
verifier = DistributedVerifier()
results = verifier.verify_all(program, final)
print(verifier.summary())
# Verification Summary: 4 passed, 0 failedfrom verifier import *
from verifier.temporal import verify_temporal
prog = Program("race", ops=[
MatMul(a="x", b="w", output="y_p"),
AllReduceAsync(x="y_p", output="y", handle="h1"),
MatMul(a="y", b="w2", output="z"), # BUG: reads y before Wait!
])
result = verify_temporal(prog)
# Detected: MISSING_WAIT on tensor 'y'from verifier import *
mesh = DeviceMesh(shape=(2,), dim_names=("tp",))
w_spec = ShardingSpec(placements=(Shard(dim=0),), mesh=mesh)
ids_spec = ShardingSpec(placements=(Replicate(),), mesh=mesh)
ids = TensorState("ids", (32,), (32,), ids_spec, expr="ids")
W = TensorState("W", (50000, 128), compute_local_shape((50000, 128), w_spec), w_spec, expr="W")
program = Program("vocab_parallel", ops=[
Embedding(indices="ids", weight="W", output="emb_partial"),
AllReduce(x="emb_partial", output="emb"),
])
executor = MultiDeviceExecutor(mesh=mesh)
executor.register_tensor(ids)
executor.register_tensor(W)
final = executor.run_program(program)
assert not final["emb"].partial # Output is Replicate after AllReduce# Tensor Parallelism
python examples/tp_linear.py # Row-parallel: correct vs bug detection
python examples/tp_mlp.py # Megatron MLP: Column+Row parallel
# Pipeline Parallelism
python examples/pp_2stage.py # 2-stage 1F1B schedule verification
# Context Parallelism
python examples/cp_ring_attn.py # Ring attention with FlashAttention
# Temporal / Async
python examples/overlap_demo.py # Race, missing wait, buffer aliasing
# SPMD Type System
python examples/spmd_demo.py # R/I/V/P types + gradient duality
# Synthesis
python examples/synthesis_demo.py # Auto-synthesis from unannotated compute
# End-to-End
python examples/megatron_gpt2_verify.py # Full Megatron GPT-2 verification
python examples/consistency_demo.py # Single-GPU equivalence proofFour local types with gradient duality:
| Type | Meaning | Gradient Dual |
|---|---|---|
| R (Replicate) | Same data on all ranks | P (Partial) |
| I (Invariant) | Same data, no gradient comm needed | I (Invariant) |
| V (Varying) | Different data per rank | V (Varying) |
| P (Partial) | Pending sum across ranks | R (Replicate) |
Key rules enforced by SPMDGuard:
Partial × Partialis forbidden (doesn't distribute over pending sum)AllReduce(Replicate)is an error (no pending sum exists)AllReduce(Invariant)is an error (gradient already identical)
| Parallelism Strategy | Spatial Checks | Temporal Checks |
|---|---|---|
| Tensor Parallelism | Placement, postcondition, shape | Async AllReduce races |
| Pipeline Parallelism | Send/Recv matching, deadlock | 1F1B schedule ordering |
| Context Parallelism | Ring placement propagation | Async ring communication |
| Data Parallelism | Gradient duality | Async gradient AllReduce |
| TP + DP (multi-dim mesh) | Per-dim placement, mesh_dim targeting | Per-dim async races |
| ZeRO-1/2/3 | Shard/gather consistency | ReduceScatter ordering |
| MoE (Expert Parallel) | AllToAll dispatch/combine | Token routing races |
| Mixed Precision (FP8) | Scale freshness, format usage | Delayed scaling ordering |
- Symbolic over numeric — verify properties for ALL possible inputs, not samples
- LLM proposes, Verifier decides — Z3 is the final authority, not model confidence
- Spatial × Temporal — independent dimensions, composed results
- Every op has VJP — forward and backward are equally verifiable
- Zero dependencies beyond Z3 — no PyTorch, no CUDA, no cluster needed
Bugs sourced from real distributed training issues:
| Repository | Issues |
|---|---|
| pytorch/pytorch | #144359, #173041, #175690, #139681, #140227 |
| NVIDIA/Megatron-LM | #4092, #3952, #1525, #4382 |
| tile-ai/tilelang | #2035, #2042, #2054, #2158, #2172 |
| triton-lang/triton | #9991, #9963, #10106, #10176 |
| deepseek-ai/TileKernels | #2 |