A hands-on LLM inference systems repo built from first principles: KV cache, paged KV allocation, Triton kernels, FlashAttention-style tiled attention, PagedAttention-style decode, continuous batching simulation, INT8 weight-only quantization, and speculative decoding.
This repo is intentionally educational. The goal is not to replace production kernels from FlashAttention, vLLM, PyTorch SDPA, Marlin, cuBLAS, TensorRT-LLM, or vendor libraries. The goal is to make the core mechanisms visible, runnable, benchmarkable, and explainable.
Built by Anay Dongre while studying LLM inference systems, GPU kernels, quantized inference, and serving internals.
Modern LLM inference speed is controlled by a few recurring ideas:
- KV cache turns each decode step from recomputing K/V for the entire sequence into appending one new token's K/V per layer.
- Kernel fusion removes avoidable global-memory traffic for memory-bound operations like RMSNorm, RoPE, SwiGLU activation, masked softmax, and dequantization.
- FlashAttention avoids materializing the full [T, T] attention score/probability matrices by streaming K/V tiles and using online softmax.
- PagedAttention avoids requiring contiguous KV cache allocation by reading K/V through block-table indirection.
- Continuous batching improves serving efficiency by evicting finished requests immediately and admitting new requests while other requests continue decoding.
- Weight-only quantization reduces model weight bandwidth by storing weights in INT8/INT4 while keeping activations in FP16/BF16.
- Speculative decoding reduces expensive target-model forward passes by using a cheaper draft model to propose multiple tokens and the target model to verify them.
- GEMM is different: matmul is usually best delegated to cuBLAS/cuBLASLt or specialized tensor-core kernels because vendor kernels exploit tensor cores, deep tiling, pipelining, and algorithm selection.
This repo implements each of those ideas directly.
| Area | File(s) | What it demonstrates |
|---|---|---|
| KV cache | 01_kv_cache/ | TinyGPT, prefill vs decode, cached generation, TTFT/TPOT-style timing |
| Paged KV simulation | 02_paged_kv/ | Logical blocks, physical blocks, block tables, scattered KV reads |
| Triton fundamentals | 03_triton_kernels/vector_add.py, matmul.py | Program IDs, offsets, masks, strides, blocked matmul |
| Fused row/element kernels | softmax.py, causal_softmax.py, rmsnorm.py, swiglu.py, rope.py | Fusion, memory traffic, row-wise reductions, real LLM ops |
| FlashAttention | 04_flash_attention/ | Online softmax, tiled Q/K/V streaming, no [T,T] materialization |
| PagedAttention decode | 05_paged_attention/paged_decode.py | Block-table lookup, scattered physical KV reads, decode-time attention |
| Continuous batching | 06_continuous_batching/ | Static vs continuous batching, request admission, dynamic eviction, latency/work comparison |
| Quantization | 07_quantization/ | INT8 weight-only quantization, per-channel scales, fused dequant + matmul |
| Speculative decoding | 08_speculative_decoding/ | Draft/verify loop, acceptance rate, target-forward reduction, correctness vs target greedy |
| Benchmarks | benchmarks/results.md | Tesla T4, RTX 5070 Laptop GPU, and CPU simulator measurements collected during development |
Benchmarks were run during development across a Tesla T4, an RTX 5070 Laptop GPU, and CPU-only simulators. Exact timings vary by GPU, CUDA stack, PyTorch/Triton versions, dtype, shape, and kernel path. These numbers are meant as learning artifacts, not production claims.
| Kernel / System | Best observed result | Key insight |
|---|---|---|
| RMSNorm | ~3.3x to 3.5x vs naive PyTorch | Fuse square, reduction, rsqrt, scale, weight multiply |
| Fused SwiGLU activation | ~1.68x vs composed PyTorch | Fuse silu(gate) * up; keep GEMMs in cuBLAS |
| RoPE | ~5x to 6.5x vs PyTorch slicing baseline | Pairwise rotation is layout-sensitive and highly fusible |
| Causal softmax | Up to ~3.3x on larger T | Fuse causal masking into softmax, avoid masked-score materialization |
| FlashAttention-style forward | ~1.11x (best FP32 config on T4) | Online softmax avoids a 64 MB score matrix at B=1,H=16,T=1024,D=64 |
| PagedAttention-style decode | ~2.14x in the smaller decode case, ~36% KV memory savings | Block-table indirection enables scattered physical KV reads |
| Continuous batching simulator | 3.19x avg latency improvement, 1.97x total-time speedup, 40.86% work reduction | Evict finished requests and admit new ones dynamically |
| INT8 weight-only linear | 2.00x weight memory reduction, 1.34x vs PyTorch dequant + matmul | Fused dequant avoids materializing the full W_deq matrix |
| Speculative decoding simulator | 4.57x target-forward reduction with a good draft, 1.78x with a bad draft | Speedup depends on draft-token acceptance rate |
The most important result is conceptual: these implementations make the memory path, compute path, and scheduling path explicit.
The INT8 quantization experiment was run on an NVIDIA GeForce RTX 5070 Laptop GPU using:
torch: 2.7.0+cu128
triton: 3.3.0
cuda: 12.8
compute capability: (12, 0)
Run command:
python 07_quantization/quant_linear_triton.py --M 32 --K 4096 --N 11008 --iters 10 --block-m 8 --block-n 32 --block-k 32
Shape:
x: [32, 4096]
W: [4096, 11008]
Results:
| Metric | Result |
|---|---|
| Max error vs PyTorch dequant reference | 0.000017 |
| Mean error | 0.000001 |
| Relative error | 0.000005 |
| FP16 weight memory | 86.00 MB |
| INT8 + scales memory | 43.04 MB |
| Memory reduction | 2.00x |
| PyTorch pre-dequantized FP16 matmul | 0.4869 ms |
| PyTorch dequant + matmul | 2.9091 ms |
| Triton fused dequant + matmul | 2.1667 ms |
| Triton vs PyTorch dequant + matmul | 1.34x |
| Triton vs pre-dequantized FP16 GEMM | 0.22x |
Important note:
This kernel uses a SIMT fallback path that avoids tl.dot.
On the RTX 5070 Laptop GPU / compute capability 12.0, Triton 3.3.0 failed while lowering tl.dot through the matmul acceleration pass. The fallback version avoids that compiler path and still demonstrates fused dequantization correctly.
Because it avoids tl.dot and tensor-core lowering, it is not expected to beat cuBLAS-backed FP16 matmul.
The correct takeaway:
The Triton kernel proves that fused dequantization avoids W_deq materialization and beats the naive PyTorch dequantize-then-matmul baseline.
It does not yet prove production INT4/INT8 speedups over cuBLAS, because that requires tensor-core-friendly quantized kernels such as Marlin-style kernels.
The speculative decoding demo is CPU-only and uses a toy greedy bigram target model plus a noisier/corrupted draft model. The target model remains the source of truth.
The demo verifies that speculative decoding produces the same output as target-only greedy decoding while reducing the number of expensive target-model forward passes.
Run command:
python 08_speculative_decoding/speculative_decode.py --draft-k 4 --draft-noise 0.10 --corrupt-frac 0.05
Results:
| Metric | Result |
|---|---|
| Generated tokens | 64 |
| Output matches target greedy | True |
| Target-only forwards | 64 |
| Speculative target forwards | 14 |
| Draft forwards | 55 |
| Proposed draft tokens | 55 |
| Accepted draft tokens | 51 |
| Rejected draft tokens | 1 |
| Acceptance rate | 92.73% |
| Tokens per target forward | 4.57 |
| Target forward reduction | 4.57x |
Run command:
python 08_speculative_decoding/speculative_decode.py --draft-k 4 --draft-noise 0.70 --corrupt-frac 0.30
Results:
| Metric | Result |
|---|---|
| Generated tokens | 64 |
| Output matches target greedy | True |
| Target-only forwards | 64 |
| Speculative target forwards | 20 |
| Draft forwards | 77 |
| Proposed draft tokens | 77 |
| Accepted draft tokens | 45 |
| Rejected draft tokens | 12 |
| Acceptance rate | 58.44% |
| Tokens per target forward | 3.20 |
| Target forward reduction | 3.20x |
Run command:
python 08_speculative_decoding/speculative_decode.py --draft-k 4 --draft-noise 1.5 --corrupt-frac 0.70
Results:
| Metric | Result |
|---|---|
| Generated tokens | 64 |
| Output matches target greedy | True |
| Target-only forwards | 64 |
| Speculative target forwards | 36 |
| Draft forwards | 141 |
| Proposed draft tokens | 141 |
| Accepted draft tokens | 29 |
| Rejected draft tokens | 33 |
| Acceptance rate | 20.57% |
| Tokens per target forward | 1.78 |
| Target forward reduction | 1.78x |
The main lesson:
Speculative decoding does not change the target model's output in this greedy demo. The target model remains the source of truth.
The draft model proposes several tokens cheaply. The target model verifies those tokens in one pass. If many draft tokens are accepted, each target forward produces multiple final tokens. If many draft tokens are rejected, speculative decoding collapses toward normal decoding plus wasted draft work.
FlashAttention is an attention-compute optimization. It avoids storing full [T, T] scores and probabilities by streaming K/V tiles and maintaining online softmax state:
m = running max per query row
l = running softmax denominator per query row
acc = running output numerator per query row
out = acc / l
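The running state above can be sketched in plain Python for a single query row. This is a minimal illustration, not the repo's Triton kernel: the function names and the `tile` parameter are hypothetical, and the usual 1/sqrt(D) score scaling and causal mask are left out to keep the update rule visible.

```python
import math

def online_softmax_attention(q, keys, values, tile=2):
    """Attention output for one query row, streaming K/V in tiles (illustrative)."""
    m = float("-inf")               # running max of the scores seen so far
    l = 0.0                         # running softmax denominator
    acc = [0.0] * len(values[0])    # running output numerator
    for start in range(0, len(keys), tile):
        k_tile = keys[start:start + tile]
        v_tile = values[start:start + tile]
        scores = [sum(qi * ki for qi, ki in zip(q, k)) for k in k_tile]
        m_new = max(m, max(scores))
        rescale = math.exp(m - m_new)   # exp(-inf) == 0.0 on the first tile
        l *= rescale                    # bring old state onto the new max
        acc = [a * rescale for a in acc]
        for s, v in zip(scores, v_tile):
            p = math.exp(s - m_new)
            l += p
            acc = [a + p * vi for a, vi in zip(acc, v)]
        m = m_new
    return [a / l for a in acc]         # out = acc / l

def reference_attention(q, keys, values):
    """Same result computed with the full score row materialized."""
    scores = [sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
    mx = max(scores)
    probs = [math.exp(s - mx) for s in scores]
    z = sum(probs)
    return [sum(p * v[d] for p, v in zip(probs, values)) / z
            for d in range(len(values[0]))]
```

Only one tile of scores exists at a time; the rescale step is what lets earlier tiles be corrected after a larger score shows up later.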
PagedAttention is a serving-memory layout optimization. It avoids requiring each request's KV cache to be physically contiguous:
logical token position -> logical block -> block table -> physical block -> K/V read
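The lookup chain above can be sketched with a list-based block table. The block size, pool size, and the names `locate` and `read_kv` are assumptions chosen for illustration; the repo's own block manager may be organized differently.

```python
BLOCK_SIZE = 4  # tokens per physical block (illustrative value)

def locate(block_table, token_pos, block_size=BLOCK_SIZE):
    """Map a logical token position to (physical block, offset in block)."""
    logical_block, offset = divmod(token_pos, block_size)
    physical_block = block_table[logical_block]   # block-table indirection
    return physical_block, offset

def read_kv(kv_pool, block_table, token_pos):
    """Read one token's K/V entry through the block table."""
    physical_block, offset = locate(block_table, token_pos)
    return kv_pool[physical_block][offset]

# A shared pool of 8 physical blocks; one request owns blocks 5, 2 and 7,
# which are neither contiguous nor in order.
kv_pool = [[None] * BLOCK_SIZE for _ in range(8)]
block_table = [5, 2, 7]          # logical block 0, 1, 2 -> physical block

# Write ten tokens' worth of placeholder K/V entries for this request.
for pos in range(10):
    pb, off = locate(block_table, pos)
    kv_pool[pb][off] = f"kv{pos}"
```

The request sees positions 0 through 9 as a contiguous sequence, while the pool stores them wherever free blocks happened to be available.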
PagedAttention answers:
Where is this request's KV cache stored?
It uses fixed-size physical blocks and block tables so KV memory does not need to be contiguous.
Continuous batching answers:
Which requests should run in this decode step?
It keeps the active batch full by removing completed requests and admitting new waiting requests dynamically.
Together, they explain a large part of modern LLM serving systems:
PagedAttention = KV memory layout
Continuous batching = serving scheduler
FlashAttention = attention compute
Continuous batching improves serving throughput across many requests:
Which requests should be active in this decode step?
Speculative decoding improves per-request generation efficiency:
Can one target-model verification step produce multiple accepted tokens?
They solve different problems and can be used together.
In a real serving engine:
continuous batching schedules many active requests
speculative decoding changes how each request advances
paged KV / TransKV-style logic manages the extra KV-cache pressure from draft and target verification paths
Quantized inference has two different problems:
1. Accuracy problem:
How do we compress weights without hurting model quality?
2. Kernel problem:
How do we run the compressed weights faster than FP16?
Methods like AWQ and GPTQ mainly address the accuracy problem.
Specialized kernels like Marlin mainly address the kernel problem.
The mental model:
AWQ / GPTQ = decide what quantized numbers to store
Marlin = multiply those quantized numbers fast
GGUF = model storage/runtime format in llama.cpp ecosystem
bitsandbytes = convenient runtime/library path for quantized loading
Triton helped most for operations where a naive PyTorch expression creates intermediate tensors or multiple memory passes:
- RMSNorm
- RoPE
- SwiGLU activation
- causal masked softmax
- fused dequantization
Triton did not beat cuBLAS-style matmul in this repo. That is expected. GEMM is one of the most optimized operations in the CUDA ecosystem.
A normal LLM linear layer computes:
y = x @ W
Typical shapes:
x: [M, K]
W: [K, N]
y: [M, N]
For example:
x: [32, 4096]
W: [4096, 11008]
y: [32, 11008]
In FP16 inference, weights usually use 2 bytes per value:
4096 * 11008 * 2 bytes = 90,177,536 bytes ≈ 86 MiB (shown as 86.00 MB in the benchmark results)
Weight-only quantization compresses only the weights:
x stays FP16/BF16
W becomes INT8 or INT4
The operation becomes conceptually:
y = x @ dequant(W_quantized)
For INT8:
W_fp16 -> W_int8 + scales
For INT4:
W_fp16 -> packed_W_int4 + scales
This is called weight-only quantization because only the model weights are quantized. Runtime activations stay floating point.
That matters because activations are dynamic. They depend on the prompt, layer, token position, batch, and generated sequence. Quantizing activations safely is harder because their range can change at runtime. Weights are fixed after training, so they can be quantized offline.
The main benefit is memory bandwidth reduction:
| Format | Bytes per weight | Approx reduction vs FP16 |
|---|---|---|
| FP16 | 2.0 | 1.0x |
| INT8 | 1.0 | 2.0x |
| INT4 | 0.5 | 4.0x |
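The table can be checked against the benchmark shape with a few lines of arithmetic. This sketch counts weight storage only (scales and other metadata are excluded), and `weight_bytes` is a hypothetical helper name.

```python
def weight_bytes(K, N, bits_per_weight):
    """Storage for a [K, N] weight matrix, excluding scales/metadata."""
    return K * N * bits_per_weight // 8

K, N = 4096, 11008
fp16_bytes = weight_bytes(K, N, 16)

for name, bits in [("FP16", 16), ("INT8", 8), ("INT4", 4)]:
    size = weight_bytes(K, N, bits)
    mib = size / 2**20                 # MiB = 2**20 bytes
    reduction = fp16_bytes / size
    print(f"{name}: {mib:.2f} MiB, {reduction:.1f}x vs FP16")
```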
The main tradeoff:
Quantization reduces memory bandwidth.
Dequantization adds extra math.
A quantized kernel is useful only if the memory savings outweigh the dequantization overhead.
INT8 can only represent a small integer range:
-128 to 127
Neural network weights are floating-point values:
0.018
-0.241
1.372
A scale maps floating-point values to integer values.
Quantization:
W_q = round(W_fp / scale)
Dequantization:
W_deq = W_q * scale
Without the scale, an integer value has no useful meaning.
Example:
scale = 0.01
W_q = 25
W_deq = 25 * 0.01 = 0.25
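A scalar version of the two formulas, as a minimal sketch. The clamp to [-127, 127] is an assumption (a symmetric range that matches a max/127 scale), and Python's built-in `round` rounds exact halves to the nearest even integer, which may differ from the rounding mode a real kernel uses.

```python
def quantize(w, scale):
    """W_q = round(W_fp / scale), clamped to a symmetric INT8 range (assumed)."""
    q = round(w / scale)
    return max(-127, min(127, q))

def dequantize(q, scale):
    """W_deq = W_q * scale."""
    return q * scale

# Round-trip the example weights with a scale derived from the largest value.
weights = [0.018, -0.241, 1.372]
scale = max(abs(w) for w in weights) / 127
codes = [quantize(w, scale) for w in weights]
restored = [dequantize(q, scale) for q in codes]
errors = [abs(w - r) for w, r in zip(weights, restored)]
```

For values inside the representable range, the round-trip error is bounded by half a quantization step, i.e. `scale / 2`.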
The simplest scheme uses one scale for the whole matrix:
one scale for all values in W
This is simple but often too crude. Different output channels can have very different ranges.
One output column may contain small values:
[0.01, -0.03, 0.02, ...]
Another output column may contain larger values:
[0.8, -1.1, 0.5, ...]
A single global scale may preserve the large values but destroy precision in the small-value channel.
Per-channel quantization gives each output channel its own scale.
For a linear layer:
W: [K, N]
The output channels are the columns:
W[:, 0], W[:, 1], ..., W[:, N-1]
So per-output-channel scales have shape:
scales: [N]
Each column gets its own range:
scale[j] = max(abs(W[:, j])) / 127
W_q[:, j] = round(W[:, j] / scale[j])
W_deq[:, j] = W_q[:, j] * scale[j]
This repo implements per-output-channel INT8 scales.
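The per-column formulas can be tried on the two example columns from above. This is a plain-Python sketch using lists of rows; the helper names are hypothetical, and it assumes no column is all zeros (which would give a zero scale).

```python
def per_channel_scales(W):
    """scale[j] = max(abs(W[:, j])) / 127 for a [K, N] matrix stored as rows."""
    n_cols = len(W[0])
    return [max(abs(row[j]) for row in W) / 127 for j in range(n_cols)]

def quantize_cols(W, scales):
    return [[round(w / scales[j]) for j, w in enumerate(row)] for row in W]

def dequantize_cols(W_q, scales):
    return [[q * scales[j] for j, q in enumerate(row)] for row in W_q]

def column_error(W, W_deq, j):
    """Largest absolute round-trip error in column j."""
    return max(abs(a[j] - b[j]) for a, b in zip(W, W_deq))

# Column 0 holds small values, column 1 holds large values.
W = [[0.01, 0.8],
     [-0.03, -1.1],
     [0.02, 0.5]]

pc_scales = per_channel_scales(W)
global_scale = max(abs(w) for row in W for w in row) / 127
gl_scales = [global_scale] * len(W[0])   # same scale for every column

W_pc = dequantize_cols(quantize_cols(W, pc_scales), pc_scales)
W_gl = dequantize_cols(quantize_cols(W, gl_scales), gl_scales)

err_small_pc = column_error(W, W_pc, 0)
err_small_gl = column_error(W, W_gl, 0)
```

With the global scale, the small-value column is squeezed into a handful of integer codes; with its own scale it uses the full INT8 range.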
For the benchmark shape:
W: [4096, 11008]
scales: [11008]
The scale overhead is tiny compared with the number of weights:
scales: 11008 values
weights: 4096 * 11008 = 45,088,768 values
So per-channel scaling usually improves accuracy with negligible memory overhead.
The simple PyTorch reference is:
W_deq = W_q.float() * scales
y = x.float() @ W_deq
This is easy to understand, but it creates a full dequantized matrix in GPU memory.
For the benchmark shape:
W_q: [4096, 11008] INT8
W_deq: [4096, 11008] FP16 or FP32
If W_deq is FP16:
4096 * 11008 * 2 bytes ≈ 86 MiB
If W_deq is FP32:
4096 * 11008 * 4 bytes ≈ 172 MiB
The naive path does unnecessary memory work:
1. Read compressed W_q.
2. Read scales.
3. Write full W_deq to global memory.
4. Read W_deq again for matmul.
5. Compute y.
That partly defeats the point of quantization.
The fused Triton kernel avoids this by working tile by tile:
1. Load a small tile of W_q.
2. Load the scale values for that tile's output columns.
3. Convert INT8 values to floating point inside the kernel.
4. Multiply by scales inside the kernel.
5. Immediately multiply with x.
6. Accumulate the output.
7. Store only y.
Conceptually:
for output tile:
    acc = 0
    for K tile:
        x_tile = load x
        w_q_tile = load W_q
        scales_tile = load scales
        w_deq_tile = w_q_tile * scales_tile
        acc += x_tile @ w_deq_tile
    store acc
The important point:
w_deq_tile exists only temporarily inside the kernel.
The full W_deq matrix is never written to global memory.
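The tile loop can be mirrored in plain Python to check that the fused order of operations gives the same result as dequantize-then-matmul. This is a CPU sketch of the idea only, not the repo's Triton kernel: the function names and the `block_n` / `block_k` defaults are illustrative, and nothing here models GPU memory or performance.

```python
def fused_dequant_matmul(x, W_q, scales, block_n=2, block_k=2):
    """y = x @ (W_q * scales), dequantizing one small tile at a time."""
    M, K, N = len(x), len(W_q), len(W_q[0])
    y = [[0.0] * N for _ in range(M)]
    for n0 in range(0, N, block_n):                 # output tile
        cols = range(n0, min(n0 + block_n, N))
        acc = [[0.0] * len(cols) for _ in range(M)]
        for k0 in range(0, K, block_k):             # K tile
            ks = range(k0, min(k0 + block_k, K))
            # Dequantize only this [block_k, block_n] tile; it is discarded
            # at the end of the iteration and never stored as a full matrix.
            w_tile = [[W_q[k][j] * scales[j] for j in cols] for k in ks]
            for m in range(M):
                for ki, k in enumerate(ks):
                    for ji in range(len(cols)):
                        acc[m][ji] += x[m][k] * w_tile[ki][ji]
        for m in range(M):                          # store only y
            for ji, j in enumerate(cols):
                y[m][j] = acc[m][ji]
    return y

def naive_dequant_matmul(x, W_q, scales):
    """Reference path: build the full W_deq first, then multiply."""
    W_deq = [[q * scales[j] for j, q in enumerate(row)] for row in W_q]
    K, N = len(W_q), len(W_q[0])
    return [[sum(row[k] * W_deq[k][j] for k in range(K)) for j in range(N)]
            for row in x]
```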
That is why the Triton kernel beats the PyTorch dequantize-then-matmul baseline:
PyTorch dequant + matmul: 2.9091 ms
Triton fused dequant + matmul: 2.1667 ms
Speedup: 1.34x
But it is still slower than pre-dequantized FP16 cuBLAS matmul because the current compatibility kernel avoids tl.dot and tensor-core lowering:
PyTorch pre-dequantized FP16 matmul: 0.4869 ms
Triton fused dequant + matmul: 2.1667 ms
So the correct lesson is:
Fusing dequantization avoids W_deq materialization.
To beat FP16 GEMM, the quantized kernel also needs optimized tensor-core usage, packed low-bit weight layouts, and architecture-aware scheduling.
AWQ and GPTQ are quantization methods. Marlin is an optimized inference kernel.
They solve different parts of the quantized inference problem.
AWQ stands for Activation-aware Weight Quantization.
It is a low-bit weight-only quantization method for LLMs. Its key observation is that not all weights are equally important, and activation statistics help identify which channels matter most.
AWQ uses calibration activations to protect important channels through scaling.
In simple terms:
AWQ decides how to quantize the weights while preserving model quality.
It answers:
How should the weights be scaled and quantized so accuracy stays high?
AWQ is mainly about quantization quality.
It does not automatically make inference fast unless the runtime has an efficient kernel that can consume the AWQ weight format.
GPTQ is another post-training weight quantization method.
It uses approximate second-order information to quantize weights while compensating for quantization error.
In simple terms:
GPTQ decides how to round and compress weights while minimizing output error.
It answers:
How can we quantize a large transformer layer with low accuracy loss?
Like AWQ, GPTQ is mainly about quantization quality. It still needs an optimized runtime kernel for speed.
Marlin is different.
Marlin is not mainly a quantization algorithm. It is an optimized mixed-precision inference kernel.
Its target operation is roughly:
FP16 activations × INT4 weights -> FP16/FP32 output
Marlin answers:
How do we multiply FP16 activations by INT4 weights very fast on GPU?
That requires more than loading INT4 weights and multiplying by scales.
A production INT4 kernel needs specialized layout and scheduling.
INT4 uses 4 bits per weight.
That means two weights fit in one byte:
byte = [high 4-bit weight][low 4-bit weight]
Before multiplying, the kernel must unpack:
packed byte -> two int4 values -> dequantized values
If unpacking is too expensive, the memory savings do not become speedups.
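Packing and unpacking two 4-bit values per byte can be sketched as follows. The signed two's-complement nibble encoding and the high/low ordering are assumptions for illustration; real formats differ in both, and many use unsigned codes with a zero-point instead.

```python
def pack_int4(hi, lo):
    """Pack two signed 4-bit values (each in [-8, 7]) into one byte."""
    if not (-8 <= hi <= 7 and -8 <= lo <= 7):
        raise ValueError("INT4 values must be in [-8, 7]")
    return ((hi & 0xF) << 4) | (lo & 0xF)

def unpack_int4(byte):
    """Recover the two signed 4-bit values from one packed byte."""
    def sign_extend(nibble):
        # Nibbles 8..15 encode the negative values -8..-1.
        return nibble - 16 if nibble >= 8 else nibble
    return sign_extend(byte >> 4), sign_extend(byte & 0xF)

packed = pack_int4(-3, 5)
restored = unpack_int4(packed)
```

The shifts, masks, and sign extension are the per-weight cost that a fast kernel has to hide behind memory loads and tensor-core work.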
Quantized weights usually need metadata:
scales
zero-points
group sizes
packing layout
A common INT4 format may use one scale per group, such as every 128 weights.
The kernel must apply:
w_deq = scale * (w_q - zero_point)
without adding too much overhead.
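A sketch of group-wise dequantization using the formula above. It assumes unsigned 4-bit codes in [0, 15], one scale and one zero-point per group of consecutive weights, and a hypothetical helper name; the group size of 2 in the example is only to keep the data small.

```python
def dequant_groupwise(w_q, scales, zero_points, group_size):
    """w_deq[i] = scale[g] * (w_q[i] - zero_point[g]) with g = i // group_size."""
    out = []
    for i, q in enumerate(w_q):
        g = i // group_size          # which group this weight belongs to
        out.append(scales[g] * (q - zero_points[g]))
    return out

# Four weights, two groups of two, each group with its own scale/zero-point.
w_q = [0, 15, 8, 10]
scales = [0.1, 0.5]
zero_points = [8, 8]
w_deq = dequant_groupwise(w_q, scales, zero_points, group_size=2)
```

Smaller groups track local weight ranges more closely but add more scale and zero-point metadata to load per tile.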
cuBLAS-level speed comes from tensor cores.
Tensor cores operate on specific tile shapes and data layouts. A fast INT4 kernel must arrange data so that the GPU can feed tensor cores efficiently.
That means:
weights are packed in a tensor-core-friendly layout
loads are coalesced
tiles are scheduled for reuse
shared memory/register usage is controlled
warps are assigned carefully
The layout that is convenient for storage is not always the layout that is fastest for tensor cores.
A slow quantized path does this:
INT4 weights -> full FP16 W_deq matrix -> normal GEMM
That materializes the expanded weight matrix and loses much of the memory benefit.
A fast path does this:
load packed INT4 tile
unpack/dequantize inside the kernel
feed values into tensor-core MMA pipeline
accumulate output
The dequantized values should be temporary. They should live in registers or shared memory long enough to be consumed, not written as a full matrix to global memory.
Autoregressive decode often has:
M = active batch size
K = hidden dimension
N = output dimension
For example:
x: [32, 4096]
W: [4096, 11008]
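This is a hard performance regime because M is relatively small while K and N are large: each weight loaded from memory is reused only M times, so the matmul tends to be limited by weight bandwidth rather than by compute.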
Specialized kernels like Marlin are designed to stay efficient in this LLM decode regime.
This is why combinations like these exist:
Marlin-AWQ
Marlin-GPTQ
That means:
AWQ or GPTQ determines the quantized weights.
Marlin provides the fast CUDA kernel to run those weights.
Full stack:
Quantization method:
AWQ / GPTQ
Weight format:
packed INT4 weights + scales/metadata
Kernel:
Marlin-style FP16 x INT4 matmul
Runtime:
vLLM / TensorRT-LLM / custom server
Main mental model:
AWQ/GPTQ decide what numbers to store.
Marlin decides how to multiply those numbers fast.
This repo currently implements the first kernel-level building block:
FP16 activations x INT8 weights
per-output-channel scales
fused dequantization
manual SIMT accumulation fallback
It does not yet implement:
INT4 packing
group-wise scales
zero-points
tensor-core MMA layout
Marlin-style scheduling
AWQ search
GPTQ reconstruction/error compensation
That is intentional. The learning progression is:
Step 1: INT8 per-channel quantization
Step 2: fused dequant + matmul
Step 3: INT4 packing/unpacking
Step 4: group-wise scales
Step 5: tensor-core-friendly layouts
Step 6: AWQ/GPTQ-compatible formats
Step 7: Marlin-style optimized kernel
The continuous batching simulator is CPU-only. It compares a simple static batching baseline against a dynamic continuous batching scheduler.
Static batching keeps a batch together until the slowest request finishes. This wastes decode work on requests that already completed and pads prompt work to the longest prompt in the batch.
Continuous batching evicts finished requests immediately and admits new waiting requests into freed batch slots.
Example output:
======================================================================
Continuous batching
======================================================================
total_time: 133
completed: 40
total_tokens_processed: 11563
generated_tokens: 880
tokens_per_step: 86.94
avg_latency: 39.75
p95_latency: 90
======================================================================
Static batching
======================================================================
total_time: 262
completed: 40
total_tokens_processed: 19552
generated_tokens: 880
tokens_per_step: 74.63
avg_latency: 126.65
p95_latency: 225
======================================================================
Comparison
======================================================================
Latency speedup: 3.19x
Total time speedup: 1.97x
Work reduction: 40.86%
The scheduler does not know true output length upfront. In the simulator, hidden_actual_output_len is used only to emulate when a request would naturally hit EOS. The scheduler only sees max_new_tokens, current generated token count, and request status.
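The difference between the two policies can be reproduced with a much smaller model than the repo's simulator. This sketch makes strong simplifying assumptions: all requests arrive at time zero, prefill is ignored, each request is represented only by its number of decode steps (at least 1), and the function names are hypothetical.

```python
from collections import deque

def static_batching(lengths, batch_size):
    """Each batch runs until its slowest request finishes."""
    steps, latencies = 0, []
    for i in range(0, len(lengths), batch_size):
        batch = lengths[i:i + batch_size]
        steps += max(batch)                  # wait for the slowest request
        latencies += [steps] * len(batch)    # results return with the batch
    return steps, latencies

def continuous_batching(lengths, batch_size):
    """Evict finished requests every step and admit waiting ones."""
    waiting = deque(lengths)
    active, steps, latencies = [], 0, []
    while waiting or active:
        while waiting and len(active) < batch_size:
            active.append(waiting.popleft())          # admit into a free slot
        steps += 1
        active = [remaining - 1 for remaining in active]   # one decode step
        finished = sum(1 for remaining in active if remaining == 0)
        latencies += [steps] * finished
        active = [remaining for remaining in active if remaining > 0]  # evict
    return steps, latencies

lengths = [2, 8, 3, 8, 1, 2, 8, 2]   # decode steps needed per request
s_steps, s_lat = static_batching(lengths, batch_size=2)
c_steps, c_lat = continuous_batching(lengths, batch_size=2)
```

In this toy setting the scheduler decrements a known remaining length; the repo's simulator instead hides the true output length from the scheduler, as described above.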
Normal autoregressive decoding uses the target model once per generated token:
target forward -> one token
target forward -> one token
target forward -> one token
...
Speculative decoding uses two models:
draft model = smaller/cheaper model that proposes tokens
target model = larger/source-of-truth model that verifies tokens
The loop is:
1. Draft model proposes k tokens autoregressively.
2. Target model verifies those k tokens in one forward pass.
3. Matching draft tokens are accepted.
4. At the first mismatch, the target token is used instead.
5. If all draft tokens are accepted, the target can provide one additional bonus token.
The key property:
The target model remains the source of truth.
In this repo's greedy toy demo, the speculative output exactly matches target-only greedy decoding for good, medium, and bad draft settings.
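The five steps can be sketched with toy next-token functions standing in for the two models. Everything here is illustrative: `target_next`, `make_draft`, and the arithmetic "model" are hypothetical stand-ins, not the repo's bigram models, and the single target forward is emulated by evaluating the target at all k+1 positions in one call site.

```python
VOCAB = 50

def target_next(tok):
    """Toy greedy target model: next token depends only on the last token."""
    return (tok * 7 + 3) % VOCAB

def make_draft(wrong_on):
    """Toy draft model: agrees with the target except after tokens in wrong_on."""
    def draft_next(tok):
        t = target_next(tok)
        return (t + 1) % VOCAB if tok in wrong_on else t
    return draft_next

def greedy_decode(first, n):
    out = [first]
    while len(out) < n + 1:
        out.append(target_next(out[-1]))
    return out[1:]

def speculative_decode(first, n, draft_next, k=4):
    out, target_forwards = [first], 0
    while len(out) < n + 1:
        # 1. Draft proposes k tokens autoregressively.
        proposal, cur = [], out[-1]
        for _ in range(k):
            cur = draft_next(cur)
            proposal.append(cur)
        # 2. One target forward yields predictions at all k+1 positions.
        target_forwards += 1
        context = [out[-1]] + proposal
        verified = [target_next(t) for t in context]
        # 3. Accept the longest prefix where draft and target agree.
        n_accept = 0
        while n_accept < k and proposal[n_accept] == verified[n_accept]:
            n_accept += 1
        out.extend(proposal[:n_accept])
        # 4./5. Append the target's own token: a correction at the first
        # mismatch, or the bonus token if all k proposals were accepted.
        out.append(verified[n_accept])
    return out[1:n + 1], target_forwards
```

Because every emitted token is either confirmed or produced by the target, the output equals target-only greedy decoding regardless of draft quality; only the number of target forwards changes.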
Speculative decoding helps when:
draft model is much cheaper than target model
draft predictions are frequently accepted
target verification can check multiple positions in parallel
KV-cache overhead is managed carefully
Speculative decoding hurts or gives weak gains when:
draft model is too inaccurate
acceptance rate is low
draft overhead is too high
KV-cache management overhead dominates
Speculative decoding creates extra KV-cache complexity.
Normal decode appends one token at a time:
target accepts one token -> append one token's KV
Speculative decoding may temporarily evaluate several draft tokens:
draft proposes k tokens
target verifies k tokens
some prefix is accepted
remaining draft tokens are rejected
That means the system must handle temporary or rollbackable KV states. If a draft token is rejected, its KV entries should not permanently remain in the accepted target sequence.
This is why speculative decoding connects directly to KV-cache memory-management work:
Accepted draft tokens -> keep/commit KV
Rejected draft suffix -> discard/free/rollback KV
Target correction token -> append correct KV
The serving system must make this efficient, especially under continuous batching where many requests may speculate concurrently.
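The commit/rollback rule can be sketched with a list-backed cache for a single request. The class and method names are hypothetical, and a real serving system would be freeing or reusing paged physical blocks rather than truncating a Python list.

```python
class SpeculativeKV:
    """Toy per-request KV store with tentative (rollbackable) entries."""

    def __init__(self):
        self.entries = []     # committed entries followed by tentative ones
        self.committed = 0    # how many leading entries are final

    def append_tentative(self, kv):
        """Store KV for a draft token that has not been verified yet."""
        self.entries.append(kv)

    def commit(self, n_accepted):
        """Keep the accepted draft prefix and discard the rejected suffix."""
        self.committed += n_accepted
        del self.entries[self.committed:]

    def append_committed(self, kv):
        """Append KV for a token that is already final (correction or bonus)."""
        assert len(self.entries) == self.committed, "tentative entries pending"
        self.entries.append(kv)
        self.committed += 1

kv = SpeculativeKV()
kv.append_committed("prompt")
for draft_kv in ["d1", "d2", "d3", "d4"]:
    kv.append_tentative(draft_kv)     # draft proposes 4 tokens
kv.commit(2)                          # target accepts the first 2
kv.append_committed("t3")             # target correction token
```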
inference-kernels/
├── README.md
├── requirements.txt
├── 01_kv_cache/
│ ├── tiny_gpt.py
│ ├── cached_generation.py
│ └── benchmark.py
├── 02_paged_kv/
│ ├── block_manager.py
│ └── paged_cache_demo.py
├── 03_triton_kernels/
│ ├── vector_add.py
│ ├── matmul.py
│ ├── softmax.py
│ ├── causal_softmax.py
│ ├── rmsnorm.py
│ ├── swiglu.py
│ └── rope.py
├── 04_flash_attention/
│ ├── online_softmax.py
│ └── flash_attention.py
├── 05_paged_attention/
│ └── paged_decode.py
├── 06_continuous_batching/
│ ├── scheduler.py
│ └── simulate.py
├── 07_quantization/
│ ├── README.md
│ └── quant_linear_triton.py
├── 08_speculative_decoding/
│ └── speculative_decode.py
├── benchmarks/
│ └── results.md
└── scripts/
  └── check_syntax.py
Use a CUDA machine with PyTorch and Triton installed for the GPU kernels. Colab with a T4 works for most scripts. RTX 50-series users should use a recent CUDA/PyTorch stack.
pip install -r requirements.txt
For the RTX 5070 Laptop GPU test environment used in the quantization benchmark:
python -m pip install torch==2.7.0 --index-url https://download.pytorch.org/whl/cu128
python -m pip install numpy
Check syntax for all files:
python scripts/check_syntax.py
Run individual GPU/Triton scripts:
python 03_triton_kernels/rmsnorm.py
python 03_triton_kernels/swiglu.py
python 03_triton_kernels/rope.py
python 04_flash_attention/flash_attention.py
python 05_paged_attention/paged_decode.py
python 07_quantization/quant_linear_triton.py
Run CPU-safe scripts:
python 01_kv_cache/cached_generation.py
python 02_paged_kv/paged_cache_demo.py
python 04_flash_attention/online_softmax.py
python 06_continuous_batching/simulate.py
python 08_speculative_decoding/speculative_decode.py
Run speculative decoding quality sweep:
python 08_speculative_decoding/speculative_decode.py --draft-k 4 --draft-noise 0.10 --corrupt-frac 0.05
python 08_speculative_decoding/speculative_decode.py --draft-k 4 --draft-noise 0.70 --corrupt-frac 0.30
python 08_speculative_decoding/speculative_decode.py --draft-k 4 --draft-noise 1.5 --corrupt-frac 0.70
The repo is ordered to build intuition:
- KV cache: prefill, decode, cached generation, KV memory formula.
- Paged KV simulation: physical blocks, logical blocks, request block tables.
- Triton basics: vector add, masks, program IDs, blocked matmul.
- Fused kernels: softmax, causal softmax, RMSNorm, SwiGLU activation, RoPE.
- FlashAttention: online softmax and tiled attention without [T,T] intermediate matrices.
- PagedAttention decode: block-table lookup over scattered KV blocks.
- Continuous batching: request scheduling, dynamic eviction, admission, latency/work comparison.
- Quantized inference: weight-only quantization, per-channel scales, fused dequantization, and why production INT4 needs specialized tensor-core kernels.
- Speculative decoding: draft/verify loop, acceptance rate, target-forward reduction, and KV-cache implications.
Decode usually processes one new token per request. The model must read weights and read all cached K/V for attention, but the batch/token dimension is small compared with prefill. This makes memory bandwidth and KV-cache layout central bottlenecks.
Fusion helps when separate PyTorch ops read/write large intermediate tensors. A fused Triton kernel keeps intermediate values in registers and writes only the final result. RMSNorm, RoPE, SwiGLU activation, causal masked softmax, and dequantization are good examples.
Custom Triton kernels rarely beat cuBLAS at GEMM because GEMM is compute-heavy and vendor libraries are heavily optimized for tensor cores, tiling, memory hierarchy, and scheduling. Custom kernels are more valuable around GEMMs: fusing activations, normalization, position encodings, quantization/dequantization, and custom attention layouts.
FlashAttention streams K/V tiles, computes score tiles in SRAM/registers, applies online softmax, accumulates the output, and never writes the full [T,T] score/probability matrices to global memory.
Each request has a block table mapping logical KV blocks to physical KV blocks. Decode-time attention reads K/V through this block table, allowing non-contiguous KV allocation and reduced fragmentation/padding waste.
A scheduler keeps a waiting queue and an active batch. New requests are admitted when capacity is available. Prefill runs once per request. Decode then runs repeatedly, one token per active request per step. When a request hits EOS, a stop condition, or max_new_tokens, it is evicted immediately and a new waiting request can enter the active batch.
Requests finish at different, unpredictable times. Static batching waits for the slowest request in the batch, so completed requests waste slots. Continuous batching frees those slots immediately, improving average latency, p95 latency, and overall work efficiency.
Weight-only quantization stores model weights in a smaller integer format, such as INT8 or INT4, while keeping runtime activations in FP16/BF16. During inference, weights are dequantized and multiplied with the floating-point activations. It reduces weight memory bandwidth, which is important because LLM inference repeatedly reads large weight matrices.
Different output channels can have different weight ranges. A single global scale is set by the largest values in the matrix, so channels with small values are left with only a few usable integer levels and lose precision. Per-channel scales give each output channel its own numeric range, reducing quantization error with very little metadata overhead.
If the runtime first creates a full W_deq matrix and then runs matmul, it writes and rereads a large expanded weight matrix. Fused dequantization avoids this by dequantizing only small weight tiles inside the kernel and immediately consuming them for matmul.
AWQ and GPTQ decide how to quantize weights with low accuracy loss. They do not automatically make inference fast. To get speedups, the runtime needs specialized kernels that can unpack low-bit weights, apply scales/zero-points, feed tensor cores efficiently, and avoid materializing dequantized weights. Marlin-style kernels solve this kernel-side problem.
A smaller draft model proposes multiple tokens. The larger target model verifies those proposed tokens in one forward pass. Matching draft tokens are accepted. At the first mismatch, the target token is used instead. If all proposed tokens are accepted, the target can provide an extra bonus token. The target model remains the source of truth, so output correctness is preserved.
Speculative decoding helps when the draft model is much cheaper than the target model and has a high acceptance rate. A high acceptance rate means each expensive target forward can produce multiple final tokens. If the draft model is poor, the target rejects often and the speedup collapses.
Speculative decoding creates temporary KV-cache states. Accepted draft tokens can be committed, but rejected draft suffixes must be discarded or rolled back. The target correction token then needs to be appended. This makes speculative decoding closely tied to KV-cache memory management and paged allocation.
This is a learning repo, not a production inference engine.
- Kernels are forward-only.
- FlashAttention implementation is simplified and lacks production optimizations like causal block skipping, software pipelining, advanced scheduling, and tuned tensor-core paths.
- PagedAttention decode kernel demonstrates block-table indirection but does not implement a full production allocator or serving runtime.
- Continuous batching simulator is CPU-only and models scheduling behavior, not actual GPU execution.
- INT8 quantization kernel demonstrates fused dequantization, but the RTX 50-series compatibility path avoids tl.dot and tensor-core lowering.
- Speculative decoding demo is CPU-only and uses toy greedy bigram models to explain the draft/verify loop.
- INT4 packing, group-wise scales, zero-points, AWQ/GPTQ formats, and Marlin-style scheduling are not implemented yet.
- Benchmarks are development runs and should be re-run on the target GPU.
- 08_speculative_decoding/README.md with a deeper explanation of draft/verify, acceptance rate, and KV-cache rollback.
- INT4 packing/unpacking demo.
- Group-wise scales and zero-points.
- Tensor-core-friendly quantized matmul path when Triton SM120 support is stable.
- Mini inference server/capstone that ties KV cache, paged allocation, scheduling, quantization, speculative decoding, and decode together.
- Nsight Compute profiling notes.
Anay Dongre