
MrAnayDongre/Inference-Kernels


Inference Kernels from Scratch

A hands-on LLM inference systems repo built from first principles: KV cache, paged KV allocation, Triton kernels, FlashAttention-style tiled attention, PagedAttention-style decode, continuous batching simulation, INT8 weight-only quantization, and speculative decoding.

This repo is intentionally educational. The goal is not to replace production kernels from FlashAttention, vLLM, PyTorch SDPA, Marlin, cuBLAS, TensorRT-LLM, or vendor libraries. The goal is to make the core mechanisms visible, runnable, benchmarkable, and explainable.

Built by Anay Dongre while studying LLM inference systems, GPU kernels, quantized inference, and serving internals.


Why this repo exists

Modern LLM inference speed is controlled by a few recurring ideas:

  • KV cache turns decode from recomputing K/V for the entire prefix at every step into appending one new token's K/V per layer.
  • Kernel fusion removes avoidable global-memory traffic for memory-bound operations like RMSNorm, RoPE, SwiGLU activation, masked softmax, and dequantization.
  • FlashAttention avoids materializing the full [T, T] attention score/probability matrices by streaming K/V tiles and using online softmax.
  • PagedAttention avoids requiring contiguous KV cache allocation by reading K/V through block-table indirection.
  • Continuous batching improves serving efficiency by evicting finished requests immediately and admitting new requests while other requests continue decoding.
  • Weight-only quantization reduces model weight bandwidth by storing weights in INT8/INT4 while keeping activations in FP16/BF16.
  • Speculative decoding reduces expensive target-model forward passes by using a cheaper draft model to propose multiple tokens and the target model to verify them.
  • GEMM is different: matmul is usually best delegated to cuBLAS/cuBLASLt or specialized tensor-core kernels because vendor kernels exploit tensor cores, deep tiling, pipelining, and algorithm selection.

This repo implements each of those ideas directly.


What is included

| Area | File(s) | What it demonstrates |
| --- | --- | --- |
| KV cache | 01_kv_cache/ | TinyGPT, prefill vs decode, cached generation, TTFT/TPOT-style timing |
| Paged KV simulation | 02_paged_kv/ | Logical blocks, physical blocks, block tables, scattered KV reads |
| Triton fundamentals | 03_triton_kernels/vector_add.py, matmul.py | Program IDs, offsets, masks, strides, blocked matmul |
| Fused row/element kernels | softmax.py, causal_softmax.py, rmsnorm.py, swiglu.py, rope.py | Fusion, memory traffic, row-wise reductions, real LLM ops |
| FlashAttention | 04_flash_attention/ | Online softmax, tiled Q/K/V streaming, no [T,T] materialization |
| PagedAttention decode | 05_paged_attention/paged_decode.py | Block-table lookup, scattered physical KV reads, decode-time attention |
| Continuous batching | 06_continuous_batching/ | Static vs continuous batching, request admission, dynamic eviction, latency/work comparison |
| Quantization | 07_quantization/ | INT8 weight-only quantization, per-channel scales, fused dequant + matmul |
| Speculative decoding | 08_speculative_decoding/ | Draft/verify loop, acceptance rate, target-forward reduction, correctness vs target greedy |
| Benchmarks | benchmarks/results.md | Tesla T4, RTX 5070 Laptop GPU, and CPU simulator measurements collected during development |

Benchmark highlights

Benchmarks were run during development across a Tesla T4, an RTX 5070 Laptop GPU, and CPU-only simulators. Exact timings vary by GPU, CUDA stack, PyTorch/Triton versions, dtype, shape, and kernel path. These numbers are meant as learning artifacts, not production claims.

| Kernel / System | Best observed result | Key insight |
| --- | --- | --- |
| RMSNorm | ~3.3x to 3.5x vs naive PyTorch | Fuse square, reduction, rsqrt, scale, weight multiply |
| Fused SwiGLU activation | ~1.68x vs composed PyTorch | Fuse silu(gate) * up; keep GEMMs in cuBLAS |
| RoPE | ~5x to 6.5x vs PyTorch slicing baseline | Pairwise rotation is layout-sensitive and highly fusible |
| Causal softmax | Up to ~3.3x on larger T | Fuse causal masking into softmax, avoid masked-score materialization |
| FlashAttention-style forward | ~1.11x, best FP32 config on T4 | Online softmax avoids a 64 MB score matrix at B=1, H=16, T=1024, D=64 |
| PagedAttention-style decode | ~2.14x on the smaller decode case, ~36% KV memory savings | Block-table indirection enables scattered physical KV reads |
| Continuous batching simulator | 3.19x avg latency improvement, 1.97x total-time speedup, 40.86% work reduction | Evict finished requests and admit new ones dynamically |
| INT8 weight-only linear | 2.00x weight memory reduction, 1.34x vs PyTorch dequant + matmul | Fused dequant avoids materializing the full W_deq matrix |
| Speculative decoding simulator | 4.57x target-forward reduction with a good draft, 1.78x with a bad draft | Speedup depends on draft-token acceptance rate |

The most important result is conceptual: these implementations make the memory path, compute path, and scheduling path explicit.


INT8 weight-only quantization result

The INT8 quantization experiment was run on an NVIDIA GeForce RTX 5070 Laptop GPU using:

torch:  2.7.0+cu128
triton: 3.3.0
cuda:   12.8
compute capability: (12, 0)

Run command:

python 07_quantization/quant_linear_triton.py --M 32 --K 4096 --N 11008 --iters 10 --block-m 8 --block-n 32 --block-k 32

Shape:

x: [32, 4096]
W: [4096, 11008]

Results:

| Metric | Result |
| --- | --- |
| Max error vs PyTorch dequant reference | 0.000017 |
| Mean error | 0.000001 |
| Relative error | 0.000005 |
| FP16 weight memory | 86.00 MB |
| INT8 + scales memory | 43.04 MB |
| Memory reduction | 2.00x |
| PyTorch pre-dequantized FP16 matmul | 0.4869 ms |
| PyTorch dequant + matmul | 2.9091 ms |
| Triton fused dequant + matmul | 2.1667 ms |
| Triton vs PyTorch dequant + matmul | 1.34x |
| Triton vs pre-dequantized FP16 GEMM | 0.22x |

Important note:

This kernel uses a SIMT fallback path that avoids tl.dot.

On the RTX 5070 Laptop GPU / compute capability 12.0, Triton 3.3.0 failed while lowering tl.dot through the matmul acceleration pass. The fallback version avoids that compiler path and still demonstrates fused dequantization correctly.

Because it avoids tl.dot and tensor-core lowering, it is not expected to beat cuBLAS-backed FP16 matmul.

The correct takeaway:

The Triton kernel shows that fused dequantization avoids materializing W_deq and, on this shape, beats the naive PyTorch dequantize-then-matmul baseline.

It does not yet prove production INT4/INT8 speedups over cuBLAS, because that requires tensor-core-friendly quantized kernels such as Marlin-style kernels.

Speculative decoding result

The speculative decoding demo is CPU-only and uses a toy greedy bigram target model plus a noisier/corrupted draft model. The target model remains the source of truth.

The demo verifies that speculative decoding produces the same output as target-only greedy decoding while reducing the number of expensive target-model forward passes.

Good draft model

Run command:

python 08_speculative_decoding/speculative_decode.py --draft-k 4 --draft-noise 0.10 --corrupt-frac 0.05

Results:

| Metric | Result |
| --- | --- |
| Generated tokens | 64 |
| Output matches target greedy | True |
| Target-only forwards | 64 |
| Speculative target forwards | 14 |
| Draft forwards | 55 |
| Proposed draft tokens | 55 |
| Accepted draft tokens | 51 |
| Rejected draft tokens | 1 |
| Acceptance rate | 92.73% |
| Tokens per target forward | 4.57 |
| Target forward reduction | 4.57x |

Medium draft model

Run command:

python 08_speculative_decoding/speculative_decode.py --draft-k 4 --draft-noise 0.70 --corrupt-frac 0.30

Results:

| Metric | Result |
| --- | --- |
| Generated tokens | 64 |
| Output matches target greedy | True |
| Target-only forwards | 64 |
| Speculative target forwards | 20 |
| Draft forwards | 77 |
| Proposed draft tokens | 77 |
| Accepted draft tokens | 45 |
| Rejected draft tokens | 12 |
| Acceptance rate | 58.44% |
| Tokens per target forward | 3.20 |
| Target forward reduction | 3.20x |

Bad draft model

Run command:

python 08_speculative_decoding/speculative_decode.py --draft-k 4 --draft-noise 1.5 --corrupt-frac 0.70

Results:

| Metric | Result |
| --- | --- |
| Generated tokens | 64 |
| Output matches target greedy | True |
| Target-only forwards | 64 |
| Speculative target forwards | 36 |
| Draft forwards | 141 |
| Proposed draft tokens | 141 |
| Accepted draft tokens | 29 |
| Rejected draft tokens | 33 |
| Acceptance rate | 20.57% |
| Tokens per target forward | 1.78 |
| Target forward reduction | 1.78x |

The main lesson:

Speculative decoding does not change the target model's output in this greedy demo. The target model remains the source of truth.

The draft model proposes several tokens cheaply. The target model verifies those tokens in one pass. If many draft tokens are accepted, each target forward produces multiple final tokens. If many draft tokens are rejected, speculative decoding collapses toward normal decoding plus wasted draft work.

Core conceptual distinctions

FlashAttention vs PagedAttention

FlashAttention is an attention-compute optimization. It avoids storing the full [T, T] score and probability matrices by streaming K/V tiles and maintaining online softmax state:

m   = running max per query row
l   = running softmax denominator per query row
acc = running output numerator per query row
out = acc / l
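To make the (m, l, acc) update concrete, here is a minimal plain-Python sketch for a single query row. It assumes small Python lists instead of GPU tiles and omits the 1/sqrt(D) score scaling for brevity; the function names (naive_attention_row, online_attention_row) and the tile parameter are hypothetical, not the API of 04_flash_attention/. The key step is rescaling the old l and acc by exp(m_old - m_new) whenever a new tile raises the running max.

```python
import math

# Hypothetical helpers for illustration; not the repo's flash_attention.py API.
# Assumes one query row, plain lists, and no 1/sqrt(D) scaling.

def naive_attention_row(q, keys, values):
    """Reference: materialize all scores, softmax them, then weight the values."""
    scores = [sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
    m = max(scores)
    probs = [math.exp(s - m) for s in scores]
    denom = sum(probs)
    dim = len(values[0])
    return [sum(p * v[d] for p, v in zip(probs, values)) / denom for d in range(dim)]

def online_attention_row(q, keys, values, tile=2):
    """Stream K/V in tiles, keeping only the running (m, l, acc) state."""
    dim = len(values[0])
    m = float("-inf")   # running max of the scores seen so far
    l = 0.0             # running softmax denominator
    acc = [0.0] * dim   # running output numerator
    for start in range(0, len(keys), tile):
        k_tile = keys[start:start + tile]
        v_tile = values[start:start + tile]
        scores = [sum(qi * ki for qi, ki in zip(q, k)) for k in k_tile]
        m_new = max(m, max(scores))
        # Rescale the old state so it is relative to the new running max.
        correction = math.exp(m - m_new)
        l *= correction
        acc = [a * correction for a in acc]
        for s, v in zip(scores, v_tile):
            p = math.exp(s - m_new)
            l += p
            acc = [a + p * vd for a, vd in zip(acc, v)]
        m = m_new
    return [a / l for a in acc]   # out = acc / l
```

Any tile size gives the same result as the materialized softmax, up to floating-point rounding; only the amount of score state held at once changes.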

PagedAttention is a serving-memory layout optimization. It avoids requiring each request's KV cache to be physically contiguous:

logical token position -> logical block -> block table -> physical block -> K/V read
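The lookup chain above can be sketched in a few lines of Python. This assumes a fixed block size of 4 tokens and a pool that maps physical block ids to lists of slots; BLOCK_SIZE, physical_slot, and read_kv are hypothetical names, not the API of 02_paged_kv/block_manager.py.

```python
# Hypothetical names for illustration; not the repo's block_manager.py API.
BLOCK_SIZE = 4  # tokens per physical KV block (assumed illustrative value)

def physical_slot(block_table, token_pos, block_size=BLOCK_SIZE):
    """logical token position -> logical block -> block table -> (physical block, offset)."""
    logical_block = token_pos // block_size
    offset = token_pos % block_size
    return block_table[logical_block], offset

def read_kv(kv_pool, block_table, seq_len, block_size=BLOCK_SIZE):
    """Gather one request's K/V entries in logical order from scattered physical blocks.

    kv_pool maps a physical block id to a list of block_size slots.
    """
    entries = []
    for pos in range(seq_len):
        block, offset = physical_slot(block_table, pos, block_size)
        entries.append(kv_pool[block][offset])
    return entries

# Example: a 10-token request whose logical blocks 0, 1, 2 live in physical blocks 7, 2, 5.
pool = {b: [None] * BLOCK_SIZE for b in range(8)}
table = [7, 2, 5]
for pos in range(10):
    blk, off = physical_slot(table, pos)
    pool[blk][off] = f"kv{pos}"
print(read_kv(pool, table, 10))
```

The physical blocks are not contiguous (7, then 2, then 5), yet the reader sees the tokens in logical order; only the block table needs to know where they live.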

PagedAttention vs continuous batching

PagedAttention answers:

Where is this request's KV cache stored?

It uses fixed-size physical blocks and block tables so KV memory does not need to be contiguous.

Continuous batching answers:

Which requests should run in this decode step?

It keeps the active batch full by removing completed requests and admitting new waiting requests dynamically.

Together, they explain a large part of modern LLM serving systems:

PagedAttention      = KV memory layout
Continuous batching = serving scheduler
FlashAttention      = attention compute

Speculative decoding vs continuous batching

Continuous batching improves serving throughput across many requests:

Which requests should be active in this decode step?

Speculative decoding improves per-request generation efficiency:

Can one target-model verification step produce multiple accepted tokens?

They solve different problems and can be used together.

In a real serving engine:

continuous batching schedules many active requests
speculative decoding changes how each request advances
paged KV / TransKV-style logic manages the extra KV-cache pressure from draft and target verification paths

Quantization method vs quantized kernel

Quantized inference has two different problems:

1. Accuracy problem:
   How do we compress weights without hurting model quality?

2. Kernel problem:
   How do we run the compressed weights faster than FP16?

Methods like AWQ and GPTQ mainly address the accuracy problem.

Specialized kernels like Marlin mainly address the kernel problem.

The mental model:

AWQ / GPTQ   = decide what quantized numbers to store
Marlin       = multiply those quantized numbers fast
GGUF         = model file format used in the llama.cpp ecosystem
bitsandbytes = convenient runtime/library path for quantized loading

When Triton helps

Triton helped most for operations where a naive PyTorch expression creates intermediate tensors or multiple memory passes:

  • RMSNorm
  • RoPE
  • SwiGLU activation
  • causal masked softmax
  • fused dequantization

Triton did not beat cuBLAS-style matmul in this repo. That is expected. GEMM is one of the most optimized operations in the CUDA ecosystem.


What weight-only quantization is

A normal LLM linear layer computes:

y = x @ W

Typical shapes:

x: [M, K]
W: [K, N]
y: [M, N]

For example:

x: [32, 4096]
W: [4096, 11008]
y: [32, 11008]

In FP16 inference, weights usually use 2 bytes per value:

4096 * 11008 * 2 bytes = 90,177,536 bytes ≈ 86 MiB

Weight-only quantization compresses only the weights:

x stays FP16/BF16
W becomes INT8 or INT4

The operation becomes conceptually:

y = x @ dequant(W_quantized)

For INT8:

W_fp16 -> W_int8 + scales

For INT4:

W_fp16 -> packed_W_int4 + scales

This is called weight-only quantization because only the model weights are quantized. Runtime activations stay floating point.

That matters because activations are dynamic. They depend on the prompt, layer, token position, batch, and generated sequence. Quantizing activations safely is harder because their range can change at runtime. Weights are fixed after training, so they can be quantized offline.

The main benefit is memory bandwidth reduction:

| Format | Bytes per weight | Approx. reduction vs FP16 |
| --- | --- | --- |
| FP16 | 2.0 | 1.0x |
| INT8 | 1.0 | 2.0x |
| INT4 | 0.5 | 4.0x |
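A quick arithmetic check of these ratios for the benchmark shape K=4096, N=11008. Two assumptions are made here: "MB" in this README is read as binary megabytes (2^20 bytes), and scales are stored as one FP32 value per output column. Under those assumptions the sketch reproduces the 86.00 MB and 43.04 MB figures reported earlier. The INT4 line is a hypothetical extension (per-channel FP32 scales, no zero-points, no group-wise scales).

```python
MIB = 1024 ** 2  # assumption: README "MB" means 2**20 bytes

def weight_memory_mib(k, n, bits_per_weight, scale_bytes_per_channel=0):
    """Memory of a [k, n] weight matrix plus optional per-output-channel scales, in MiB."""
    weight_bytes = k * n * bits_per_weight // 8
    scale_bytes = n * scale_bytes_per_channel
    return (weight_bytes + scale_bytes) / MIB

K, N = 4096, 11008
fp16 = weight_memory_mib(K, N, 16)
int8 = weight_memory_mib(K, N, 8, scale_bytes_per_channel=4)  # assumes FP32 scales
int4 = weight_memory_mib(K, N, 4, scale_bytes_per_channel=4)  # hypothetical: FP32 scales, no zero-points

print(f"FP16: {fp16:.2f} MiB, INT8+scales: {int8:.2f} MiB, INT4+scales: {int4:.2f} MiB")
```

It prints `FP16: 86.00 MiB, INT8+scales: 43.04 MiB, INT4+scales: 21.54 MiB`; the scales add only about 0.04 MiB, which is why the INT8 reduction rounds to 2.00x.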

The main tradeoff:

Quantization reduces memory bandwidth.
Dequantization adds extra math.

A quantized kernel is useful only if the memory savings outweigh the dequantization overhead.


What per-channel scales are

INT8 can only represent a small integer range:

-128 to 127

Neural network weights are floating-point values:

0.018
-0.241
1.372

A scale maps floating-point values to integer values.

Quantization:

W_q = round(W_fp / scale)

Dequantization:

W_deq = W_q * scale

Without the scale, an integer value has no useful meaning.

Example:

scale = 0.01
W_q = 25
W_deq = 25 * 0.01 = 0.25

Per-tensor scale

The simplest scheme uses one scale for the whole matrix:

one scale for all values in W

This is simple but often too crude. Different output channels can have very different ranges.

One output column may contain small values:

[0.01, -0.03, 0.02, ...]

Another output column may contain larger values:

[0.8, -1.1, 0.5, ...]

A single global scale may preserve the large values but destroy precision in the small-value channel.
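Using the two example columns above, and treating them as the only values in W (an assumption for illustration), this sketch measures the INT8 round-trip error on the small column under a single global scale versus a scale sized for that column alone. The helper names are hypothetical. On these numbers the per-channel error comes out roughly 50 times smaller.

```python
# Hypothetical helpers; assumes W contains only the two example columns from the text.

def quant_dequant(values, scale):
    """Symmetric INT8 round trip: quantize (clamped to [-128, 127]), then dequantize."""
    return [max(-128, min(127, round(v / scale))) * scale for v in values]

def max_abs_error(values, scale):
    """Largest absolute difference between the original and round-tripped values."""
    return max(abs(v - d) for v, d in zip(values, quant_dequant(values, scale)))

small_col = [0.01, -0.03, 0.02]
large_col = [0.8, -1.1, 0.5]

# Per-tensor: one scale sized for the largest magnitude anywhere in W.
tensor_scale = max(abs(v) for v in small_col + large_col) / 127
# Per-channel: the small column gets a scale sized for its own range.
channel_scale = max(abs(v) for v in small_col) / 127

print("per-tensor  max error on small column:", max_abs_error(small_col, tensor_scale))
print("per-channel max error on small column:", max_abs_error(small_col, channel_scale))
```

With the global scale, -0.03 can only be represented as a small integer such as -3, so the step size is large relative to the values in that column.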

Per-channel scale

Per-channel quantization gives each output channel its own scale.

For a linear layer:

W: [K, N]

The output channels are the columns:

W[:, 0], W[:, 1], ..., W[:, N-1]

So per-output-channel scales have shape:

scales: [N]

Each column gets its own range:

scale[j] = max(abs(W[:, j])) / 127
W_q[:, j] = round(W[:, j] / scale[j])
W_deq[:, j] = W_q[:, j] * scale[j]

This repo implements per-output-channel INT8 scales.
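The three formulas above can be sketched in plain Python for a [K][N] list-of-lists. This is an illustrative stand-in, not the code in 07_quantization/; the function names are hypothetical, and the guard that gives an all-zero column a scale of 1.0 is an added assumption to avoid dividing by zero.

```python
# Hypothetical helpers; not the repo's quant_linear_triton.py API.

def quantize_per_channel(W):
    """Symmetric per-output-channel INT8 quantization of a [K][N] list-of-lists."""
    K, N = len(W), len(W[0])
    scales = []
    for j in range(N):
        col_max = max(abs(W[i][j]) for i in range(K))
        # scale[j] = max(abs(W[:, j])) / 127; assumed guard for all-zero columns.
        scales.append(col_max / 127 if col_max > 0 else 1.0)
    # W_q[:, j] = round(W[:, j] / scale[j]), clamped to the symmetric range.
    W_q = [[max(-127, min(127, round(W[i][j] / scales[j]))) for j in range(N)]
           for i in range(K)]
    return W_q, scales

def dequantize_per_channel(W_q, scales):
    """W_deq[:, j] = W_q[:, j] * scale[j]."""
    return [[q * s for q, s in zip(row, scales)] for row in W_q]

W = [[0.01, 0.8], [-0.03, -1.1], [0.02, 0.5]]   # K = 3, N = 2
W_q, scales = quantize_per_channel(W)
print(W_q, scales)
```

Each column's round-trip error is bounded by half of that column's own scale, so small-valued columns are no longer penalized by large-valued ones.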

For the benchmark shape:

W:      [4096, 11008]
scales: [11008]

The scale overhead is tiny compared with the number of weights:

scales: 11008 values
weights: 4096 * 11008 = 45,088,768 values

So per-channel scaling usually improves accuracy with negligible memory overhead: here there is one scale per 4096 weights, about 0.02% extra values.


Why fused dequant avoids materializing W_deq

The simple PyTorch reference is:

W_deq = W_q.float() * scales
y = x.float() @ W_deq

This is easy to understand, but it creates a full dequantized matrix in GPU memory.

For the benchmark shape:

W_q:   [4096, 11008] INT8
W_deq: [4096, 11008] FP16 or FP32

If W_deq is FP16:

4096 * 11008 * 2 bytes ≈ 86 MB

If W_deq is FP32:

4096 * 11008 * 4 bytes ≈ 172 MB

The naive path does unnecessary memory work:

1. Read compressed W_q.
2. Read scales.
3. Write full W_deq to global memory.
4. Read W_deq again for matmul.
5. Compute y.

That partly defeats the point of quantization.

The fused Triton kernel avoids this by working tile by tile:

1. Load a small tile of W_q.
2. Load the scale values for that tile's output columns.
3. Convert INT8 values to floating point inside the kernel.
4. Multiply by scales inside the kernel.
5. Immediately multiply with x.
6. Accumulate the output.
7. Store only y.

Conceptually:

for output tile:
    acc = 0

    for K tile:
        x_tile = load x
        w_q_tile = load W_q
        scales_tile = load scales

        w_deq_tile = w_q_tile * scales_tile

        acc += x_tile @ w_deq_tile

    store acc

The important point:

w_deq_tile exists only temporarily inside the kernel.
The full W_deq matrix is never written to global memory.
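The tile loop can be sketched in plain Python to show the data flow (not the performance) of the fused kernel. Assumptions: INT8 values are Python ints in a [K][N] list, scales are per output column, and a single M tile covers all rows of x; block_n and block_k stand in for BLOCK_N and BLOCK_K. The function names are hypothetical and this is not the code of quant_linear_triton.py.

```python
# Hypothetical sketch; mirrors the pseudocode above, not the repo's Triton kernel.

def matmul_dequant_reference(x, W_q, scales):
    """Naive path: materialize the full W_deq matrix, then multiply."""
    W_deq = [[q * s for q, s in zip(row, scales)] for row in W_q]  # full [K][N] copy
    M, K, N = len(x), len(W_q), len(scales)
    return [[sum(x[m][k] * W_deq[k][n] for k in range(K)) for n in range(N)]
            for m in range(M)]

def matmul_dequant_tiled(x, W_q, scales, block_n=2, block_k=2):
    """Fused-style path: dequantize one small W_q tile at a time and consume it immediately."""
    M, K, N = len(x), len(W_q), len(scales)
    y = [[0.0] * N for _ in range(M)]
    for n0 in range(0, N, block_n):                       # for output tile
        cols = range(n0, min(n0 + block_n, N))
        acc = [[0.0] * len(cols) for _ in range(M)]       # acc = 0
        for k0 in range(0, K, block_k):                   # for K tile
            rows = range(k0, min(k0 + block_k, K))
            # Temporary dequantized tile; the full W_deq is never built.
            w_deq_tile = [[W_q[k][n] * scales[n] for n in cols] for k in rows]
            for m in range(M):                            # acc += x_tile @ w_deq_tile
                for ci in range(len(cols)):
                    acc[m][ci] += sum(x[m][k] * w_deq_tile[ri][ci]
                                      for ri, k in enumerate(rows))
        for m in range(M):                                # store acc
            for ci, n in enumerate(cols):
                y[m][n] = acc[m][ci]
    return y

x = [[1.0, 2.0, -1.0], [0.5, 0.0, 3.0]]
W_q = [[10, -20, 5], [0, 127, -128], [-3, 4, 64]]
scales = [0.01, 0.002, 0.05]
print(matmul_dequant_tiled(x, W_q, scales))
```

Because each scale belongs to one output column, the multiply by scales[n] could also be moved after the K loop (applied once to acc); the sketch keeps it per tile to match the pseudocode.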

That is why the Triton kernel beats the PyTorch dequantize-then-matmul baseline:

PyTorch dequant + matmul:      2.9091 ms
Triton fused dequant + matmul: 2.1667 ms
Speedup:                       1.34x

But it is still slower than pre-dequantized FP16 cuBLAS matmul because the current compatibility kernel avoids tl.dot and tensor-core lowering:

PyTorch pre-dequantized FP16 matmul: 0.4869 ms
Triton fused dequant + matmul:       2.1667 ms

So the correct lesson is:

Fusing dequantization avoids W_deq materialization.

To beat FP16 GEMM, the quantized kernel also needs optimized tensor-core usage, packed low-bit weight layouts, and architecture-aware scheduling.

Why Marlin, AWQ, and GPTQ need specialized tensor-core kernels

AWQ and GPTQ are quantization methods. Marlin is an optimized inference kernel.

They solve different parts of the quantized inference problem.

AWQ

AWQ stands for Activation-aware Weight Quantization.

It is a low-bit weight-only quantization method for LLMs. Its key observation is that not all weights are equally important, and activation statistics help identify which channels matter most.

AWQ uses calibration activations to protect important channels through scaling.

In simple terms:

AWQ decides how to quantize the weights while preserving model quality.

It answers:

How should the weights be scaled and quantized so accuracy stays high?

AWQ is mainly about quantization quality.

It does not automatically make inference fast unless the runtime has an efficient kernel that can consume the AWQ weight format.

GPTQ

GPTQ is another post-training weight quantization method.

It uses approximate second-order information to quantize weights while compensating for quantization error.

In simple terms:

GPTQ decides how to round and compress weights while minimizing output error.

It answers:

How can we quantize a large transformer layer with low accuracy loss?

Like AWQ, GPTQ is mainly about quantization quality. It still needs an optimized runtime kernel for speed.

Marlin

Marlin is different.

Marlin is not mainly a quantization algorithm. It is an optimized mixed-precision inference kernel.

Its target operation is roughly:

FP16 activations × INT4 weights -> FP16/FP32 output

Marlin answers:

How do we multiply FP16 activations by INT4 weights very fast on GPU?

That requires more than loading INT4 weights and multiplying by scales.

A production INT4 kernel needs specialized layout and scheduling.

Why specialized kernels are needed

1. INT4 weights are packed

INT4 uses 4 bits per weight.

That means two weights fit in one byte:

byte = [high 4-bit weight][low 4-bit weight]

Before multiplying, the kernel must unpack:

packed byte -> two int4 values -> dequantized values

If unpacking is too expensive, the memory savings do not become speedups.
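A minimal sketch of the [high nibble][low nibble] byte layout, assuming signed INT4 values in [-8, 7] stored as 4-bit two's complement. Production formats (GPTQ, AWQ, Marlin) define their own packing orders, often several values per 32-bit word with interleaving, so treat this layout and the helper names as illustrative only.

```python
# Hypothetical helpers; assumes signed INT4 in two's complement, one pair per byte.

def pack_int4_pair(low, high):
    """Pack two signed INT4 values in [-8, 7] into one byte: [high 4 bits][low 4 bits]."""
    if not (-8 <= low <= 7 and -8 <= high <= 7):
        raise ValueError("INT4 values must be in [-8, 7]")
    return ((high & 0xF) << 4) | (low & 0xF)

def unpack_int4_pair(byte):
    """Recover (low, high) signed INT4 values from one packed byte."""
    def to_signed(nibble):
        return nibble - 16 if nibble >= 8 else nibble  # 4-bit two's complement
    return to_signed(byte & 0xF), to_signed((byte >> 4) & 0xF)

packed = pack_int4_pair(-3, 5)
print(hex(packed), unpack_int4_pair(packed))
```

In a GPU kernel the same mask, shift, and sign-extend steps run per element inside the inner loop, which is why their cost matters.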

2. Scales and zero-points must be applied efficiently

Quantized weights usually need metadata:

scales
zero-points
group sizes
packing layout

A common INT4 format may use one scale per group, such as every 128 weights.

The kernel must apply:

w_deq = scale * (w_q - zero_point)

without adding too much overhead.
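The w_deq = scale * (w_q - zero_point) formula can be sketched with asymmetric group-wise quantization. Assumptions: unsigned 4-bit codes in [0, 15], a group size of 4 so the example stays small (real formats often use larger groups such as 128), and each group's range widened to include 0.0, a common choice that keeps the zero-point inside the code range. The function names are hypothetical.

```python
# Hypothetical helpers; unsigned 4-bit codes with one (scale, zero_point) per group.

def quantize_groups(weights, group_size=4, bits=4):
    """Asymmetric group-wise quantization of a flat list of weights."""
    qmax = (1 << bits) - 1                              # code range [0, qmax]
    codes, scales, zeros = [], [], []
    for start in range(0, len(weights), group_size):
        group = weights[start:start + group_size]
        lo = min(min(group), 0.0)                       # assumed: include 0.0 in the range
        hi = max(max(group), 0.0)
        scale = (hi - lo) / qmax if hi > lo else 1.0    # guard all-zero groups
        zero = round(-lo / scale)                       # code that represents 0.0
        codes += [max(0, min(qmax, round(w / scale) + zero)) for w in group]
        scales.append(scale)
        zeros.append(zero)
    return codes, scales, zeros

def dequantize_groups(codes, scales, zeros, group_size=4):
    """w_deq = scale * (w_q - zero_point), using each code's group metadata."""
    return [scales[i // group_size] * (c - zeros[i // group_size])
            for i, c in enumerate(codes)]

w = [0.12, -0.40, 0.05, 0.33, 1.5, 1.2, -0.2, 0.9, 0.0, 0.0, 0.0, 0.0]
codes, scales, zeros = quantize_groups(w, group_size=4)
print(codes, zeros)
```

The metadata here is one scale and one zero-point per 4 weights; with 128-weight groups the same overhead shrinks by 32x, which is the trade-off group size controls.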

3. Tensor cores expect specific tile layouts

For FP16 GEMM, cuBLAS-level speed comes largely from tensor cores.

Tensor cores operate on specific tile shapes and data layouts. A fast INT4 kernel must arrange data so that the GPU can feed tensor cores efficiently.

That means:

weights are packed in a tensor-core-friendly layout
loads are coalesced
tiles are scheduled for reuse
shared memory/register usage is controlled
warps are assigned carefully

The layout that is convenient for storage is not always the layout that is fastest for tensor cores.

4. Dequantization must be fused into the matmul pipeline

A slow quantized path does this:

INT4 weights -> full FP16 W_deq matrix -> normal GEMM

That materializes the expanded weight matrix and loses much of the memory benefit.

A fast path does this:

load packed INT4 tile
unpack/dequantize inside the kernel
feed values into tensor-core MMA pipeline
accumulate output

The dequantized values should be temporary. They should live in registers or shared memory long enough to be consumed, not written as a full matrix to global memory.

5. Decode often has small or medium batch sizes

Autoregressive decode often has:

M = active batch size
K = hidden dimension
N = output dimension

For example:

x: [32, 4096]
W: [4096, 11008]

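This is a hard performance regime: because M is small while K and N are large, each weight loaded from memory is reused for only M rows of x. The kernel is therefore limited by weight-memory bandwidth rather than raw tensor-core throughput, and any unpacking or dequantization work must be cheap enough to hide behind those weight loads.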

Specialized kernels like Marlin are designed to stay efficient in this LLM decode regime.
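The small-M difficulty can be quantified with arithmetic intensity (FLOPs per byte of global-memory traffic). This sketch assumes each of W, x, and y crosses global memory exactly once and ignores scales and other metadata, which is an idealization; the function name is hypothetical. Under that assumption, intensity with FP16 weights is roughly M FLOPs per byte when M is much smaller than K and N, typically well below the ratio a GPU needs to keep its tensor cores busy. Shrinking the weights (bytes_per_weight=0.5 for INT4) raises it.

```python
# Hypothetical helper; assumes each tensor is moved exactly once (ideal reuse).

def linear_arithmetic_intensity(M, K, N, bytes_per_weight=2.0, bytes_per_act=2):
    """FLOPs per byte of global-memory traffic for y = x @ W."""
    flops = 2 * M * K * N                        # one multiply + one add per (m, k, n)
    bytes_moved = (K * N * bytes_per_weight      # read W
                   + M * K * bytes_per_act       # read x
                   + M * N * bytes_per_act)      # write y
    return flops / bytes_moved

for M in (1, 32, 512):
    fp16 = linear_arithmetic_intensity(M, 4096, 11008)
    int4 = linear_arithmetic_intensity(M, 4096, 11008, bytes_per_weight=0.5)
    print(f"M={M:4d}  FP16 weights: {fp16:7.1f} FLOP/B   INT4 weights: {int4:7.1f} FLOP/B")
```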

AWQ/GPTQ + Marlin

This is why combinations like these exist:

Marlin-AWQ
Marlin-GPTQ

That means:

AWQ or GPTQ determines the quantized weights.
Marlin provides the fast CUDA kernel to run those weights.

Full stack:

Quantization method:
  AWQ / GPTQ

Weight format:
  packed INT4 weights + scales/metadata

Kernel:
  Marlin-style FP16 x INT4 matmul

Runtime:
  vLLM / TensorRT-LLM / custom server

Main mental model:

AWQ/GPTQ decide what numbers to store.
Marlin decides how to multiply those numbers fast.

This repo currently implements the first kernel-level building block:

FP16 activations x INT8 weights
per-output-channel scales
fused dequantization
manual SIMT accumulation fallback

It does not yet implement:

INT4 packing
group-wise scales
zero-points
tensor-core MMA layout
Marlin-style scheduling
AWQ search
GPTQ reconstruction/error compensation

That is intentional. The learning progression is:

Step 1: INT8 per-channel quantization
Step 2: fused dequant + matmul
Step 3: INT4 packing/unpacking
Step 4: group-wise scales
Step 5: tensor-core-friendly layouts
Step 6: AWQ/GPTQ-compatible formats
Step 7: Marlin-style optimized kernel

Continuous batching result

The continuous batching simulator is CPU-only. It compares a simple static batching baseline against a dynamic continuous batching scheduler.

Static batching keeps a batch together until the slowest request finishes. This wastes decode work on requests that already completed and pads prompt work to the longest prompt in the batch.

Continuous batching evicts finished requests immediately and admits new waiting requests into freed batch slots.

Example output:

======================================================================
Continuous batching
======================================================================
total_time: 133
completed: 40
total_tokens_processed: 11563
generated_tokens: 880
tokens_per_step: 86.94
avg_latency: 39.75
p95_latency: 90
======================================================================
Static batching
======================================================================
total_time: 262
completed: 40
total_tokens_processed: 19552
generated_tokens: 880
tokens_per_step: 74.63
avg_latency: 126.65
p95_latency: 225
======================================================================
Comparison
======================================================================
Latency speedup: 3.19x
Total time speedup: 1.97x
Work reduction: 40.86%

The scheduler does not know true output length upfront. In the simulator, hidden_actual_output_len is used only to emulate when a request would naturally hit EOS. The scheduler only sees max_new_tokens, current generated token count, and request status.
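A stripped-down sketch of the two policies. Assumptions: all requests arrive at time 0, each step decodes one token per active request, prefill cost is ignored, and static batching returns results only when the whole batch finishes. As in the repo's simulator, the per-request output lengths only emulate when a request hits EOS; the admission logic never reads them. The function names are hypothetical and the numbers are not meant to match the output above.

```python
from collections import deque

# Hypothetical toy simulator; output_lens plays the role of hidden_actual_output_len.

def static_batching(output_lens, batch_size):
    """Each batch runs until its longest request finishes; returns (total_steps, latencies)."""
    t, latencies = 0, []
    for i in range(0, len(output_lens), batch_size):
        batch = output_lens[i:i + batch_size]
        t += max(batch)                      # every slot waits for the slowest request
        latencies += [t] * len(batch)        # results returned when the batch ends
    return t, latencies

def continuous_batching(output_lens, batch_size):
    """Evict finished requests every step and admit waiting ones into free slots."""
    waiting = deque(enumerate(output_lens))
    active = {}                              # request id -> tokens still to generate
    latencies = [0] * len(output_lens)
    t = 0
    while waiting or active:
        while waiting and len(active) < batch_size:   # admission (length not inspected)
            rid, remaining = waiting.popleft()
            active[rid] = remaining
        t += 1                               # one decode step for every active request
        for rid in list(active):
            active[rid] -= 1
            if active[rid] == 0:             # emulated EOS: evict immediately
                del active[rid]
                latencies[rid] = t
    return t, latencies

lens = [2, 9, 3, 1, 8, 2, 2, 7]
print(static_batching(lens, 4))
print(continuous_batching(lens, 4))
```

With these lengths and 4 slots, static batching needs 17 steps while continuous batching finishes in 11, and every request completes at or before its static-batching completion time.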


How speculative decoding works

Normal autoregressive decoding uses the target model once per generated token:

target forward -> one token
target forward -> one token
target forward -> one token
...

Speculative decoding uses two models:

draft model  = smaller/cheaper model that proposes tokens
target model = larger/source-of-truth model that verifies tokens

The loop is:

1. Draft model proposes k tokens autoregressively.
2. Target model verifies those k tokens in one forward pass.
3. Matching draft tokens are accepted.
4. At the first mismatch, the target token is used instead.
5. If all draft tokens are accepted, the target can provide one additional bonus token.
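A minimal greedy sketch of the five steps, assuming toy next-token functions target_next(seq) and draft_next(seq); these names are hypothetical, not the interface of speculative_decode.py. For clarity the target is called once per draft position here, whereas a real engine scores all positions in one batched forward pass, which is what target_steps counts.

```python
# Hypothetical toy interface: target_next(seq) and draft_next(seq) return the greedy next token.

def greedy_generate(prefix, target_next, num_tokens):
    """Target-only greedy decoding: one target forward per token."""
    seq = list(prefix)
    for _ in range(num_tokens):
        seq.append(target_next(seq))
    return seq[len(prefix):]

def speculative_generate(prefix, target_next, draft_next, num_tokens, k=4):
    """Greedy draft/verify loop; returns (generated tokens, target verification steps)."""
    seq = list(prefix)
    target_steps = 0
    while len(seq) - len(prefix) < num_tokens:
        draft = []
        for _ in range(k):                         # 1. draft proposes k tokens
            draft.append(draft_next(seq + draft))
        target_steps += 1                          # 2. one (emulated) verification pass
        accepted = []
        for tok in draft:
            expected = target_next(seq + accepted)
            if tok != expected:
                accepted.append(expected)          # 4. first mismatch: use target token
                break
            accepted.append(tok)                   # 3. matching draft token accepted
        else:
            accepted.append(target_next(seq + accepted))  # 5. all accepted: bonus token
        seq.extend(accepted)
    return seq[len(prefix):len(prefix) + num_tokens], target_steps

target = lambda seq: (3 * seq[-1] + 1) % 11        # toy "target model"
good_draft = target                                # always agrees with the target
bad_draft = lambda seq: (target(seq) + 1) % 11     # always disagrees
print(speculative_generate([2], target, good_draft, 20))
print(speculative_generate([2], target, bad_draft, 20))
```

With a perfect draft each verification step yields k + 1 = 5 tokens (4 steps for 20 tokens); with a draft that is always wrong it degrades to one token per step (20 steps). In both cases the output equals target-only greedy decoding.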

The key property:

The target model remains the source of truth.

In this repo's greedy toy demo, the speculative output exactly matches target-only greedy decoding for good, medium, and bad draft settings.

Speculative decoding helps when:

draft model is much cheaper than target model
draft predictions are frequently accepted
target verification can check multiple positions in parallel
KV-cache overhead is managed carefully

Speculative decoding hurts or gives weak gains when:

draft model is too inaccurate
acceptance rate is low
draft overhead is too high
KV-cache management overhead dominates

Speculative decoding and KV cache pressure

Speculative decoding creates extra KV-cache complexity.

Normal decode appends one token at a time:

target accepts one token -> append one token's KV

Speculative decoding may temporarily evaluate several draft tokens:

draft proposes k tokens
target verifies k tokens
some prefix is accepted
remaining draft tokens are rejected

That means the system must handle temporary or rollbackable KV states. If a draft token is rejected, its KV entries should not permanently remain in the accepted target sequence.

This is why speculative decoding connects directly to KV-cache memory-management work:

Accepted draft tokens -> keep/commit KV
Rejected draft suffix -> discard/free/rollback KV
Target correction token -> append correct KV

The serving system must make this efficient, especially under continuous batching where many requests may speculate concurrently.
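A toy, flat-list sketch of the commit and rollback bookkeeping for one request. Assumptions: KV entries are opaque placeholders, and rollback simply truncates a Python list; a paged engine would instead free or reuse whole physical blocks and update the block table. The class and method names are hypothetical.

```python
# Hypothetical toy cache; real engines manage this at physical-block granularity.

class SpeculativeKVCache:
    """Per-request KV cache that can commit or roll back speculative entries."""

    def __init__(self):
        self.entries = []        # KV placeholders: committed prefix first, then pending
        self.committed_len = 0   # length of the accepted target sequence

    def append_committed(self, kv):
        """Append KV for a token the target has accepted (prompt, correction, bonus)."""
        if len(self.entries) != self.committed_len:
            raise RuntimeError("resolve pending speculative entries before appending")
        self.entries.append(kv)
        self.committed_len += 1

    def append_speculative(self, kvs):
        """Append KV for draft tokens that are still being verified."""
        self.entries.extend(kvs)

    def commit(self, num_accepted):
        """Keep the first num_accepted pending entries; roll back the rejected suffix."""
        pending = len(self.entries) - self.committed_len
        if not 0 <= num_accepted <= pending:
            raise ValueError("num_accepted must be between 0 and the number of pending entries")
        self.committed_len += num_accepted
        del self.entries[self.committed_len:]    # discard/free rejected draft KV

cache = SpeculativeKVCache()
cache.append_committed("kv_prompt")
cache.append_speculative(["kv_d1", "kv_d2", "kv_d3", "kv_d4"])  # draft proposes k = 4
cache.commit(2)                     # d1, d2 accepted; d3, d4 rejected and rolled back
cache.append_committed("kv_fix")    # target correction token
print(cache.entries)
```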


Repo structure

inference-kernels/
├── README.md
├── requirements.txt
├── 01_kv_cache/
│   ├── tiny_gpt.py
│   ├── cached_generation.py
│   └── benchmark.py
├── 02_paged_kv/
│   ├── block_manager.py
│   └── paged_cache_demo.py
├── 03_triton_kernels/
│   ├── vector_add.py
│   ├── matmul.py
│   ├── softmax.py
│   ├── causal_softmax.py
│   ├── rmsnorm.py
│   ├── swiglu.py
│   └── rope.py
├── 04_flash_attention/
│   ├── online_softmax.py
│   └── flash_attention.py
├── 05_paged_attention/
│   └── paged_decode.py
├── 06_continuous_batching/
│   ├── scheduler.py
│   └── simulate.py
├── 07_quantization/
│   ├── README.md
│   └── quant_linear_triton.py
├── 08_speculative_decoding/
│   └── speculative_decode.py
├── benchmarks/
│   └── results.md
└── scripts/
    └── check_syntax.py

Setup

Use a CUDA machine with PyTorch and Triton installed for the GPU kernels. Colab with a T4 works for most scripts. RTX 50-series users should use a recent CUDA/PyTorch stack.

pip install -r requirements.txt

For the RTX 5070 Laptop GPU test environment used in the quantization benchmark:

python -m pip install torch==2.7.0 --index-url https://download.pytorch.org/whl/cu128
python -m pip install numpy

Check syntax for all files:

python scripts/check_syntax.py

Run individual GPU/Triton scripts:

python 03_triton_kernels/rmsnorm.py
python 03_triton_kernels/swiglu.py
python 03_triton_kernels/rope.py
python 04_flash_attention/flash_attention.py
python 05_paged_attention/paged_decode.py
python 07_quantization/quant_linear_triton.py

Run CPU-safe scripts:

python 01_kv_cache/cached_generation.py
python 02_paged_kv/paged_cache_demo.py
python 04_flash_attention/online_softmax.py
python 06_continuous_batching/simulate.py
python 08_speculative_decoding/speculative_decode.py

Run speculative decoding quality sweep:

python 08_speculative_decoding/speculative_decode.py --draft-k 4 --draft-noise 0.10 --corrupt-frac 0.05
python 08_speculative_decoding/speculative_decode.py --draft-k 4 --draft-noise 0.70 --corrupt-frac 0.30
python 08_speculative_decoding/speculative_decode.py --draft-k 4 --draft-noise 1.5 --corrupt-frac 0.70

Learning path

The repo is ordered to build intuition:

  1. KV cache: prefill, decode, cached generation, KV memory formula.
  2. Paged KV simulation: physical blocks, logical blocks, request block tables.
  3. Triton basics: vector add, masks, program IDs, blocked matmul.
  4. Fused kernels: softmax, causal softmax, RMSNorm, SwiGLU activation, RoPE.
  5. FlashAttention: online softmax and tiled attention without [T,T] intermediate matrices.
  6. PagedAttention decode: block-table lookup over scattered KV blocks.
  7. Continuous batching: request scheduling, dynamic eviction, admission, latency/work comparison.
  8. Quantized inference: weight-only quantization, per-channel scales, fused dequantization, and why production INT4 needs specialized tensor-core kernels.
  9. Speculative decoding: draft/verify loop, acceptance rate, target-forward reduction, and KV-cache implications.

Interview-ready explanations

Why is decode memory-bound?

Decode usually processes one new token per request. The model must read weights and read all cached K/V for attention, but the batch/token dimension is small compared with prefill. This makes memory bandwidth and KV-cache layout central bottlenecks.

Why does kernel fusion help?

Fusion helps when separate PyTorch ops read/write large intermediate tensors. A fused Triton kernel keeps intermediate values in registers and writes only the final result. RMSNorm, RoPE, SwiGLU activation, causal masked softmax, and dequantization are good examples.

Why not write your own GEMM?

Because GEMM is compute-heavy and vendor libraries are heavily optimized for tensor cores, tiling, memory hierarchy, and scheduling. Custom kernels are more valuable around GEMMs: fusing activations, normalization, position encodings, quantization/dequantization, and custom attention layouts.

How does FlashAttention work?

It streams K/V tiles, computes score tiles in SRAM/registers, applies online softmax, accumulates the output, and never writes the full [T,T] score/probability matrices to global memory.

How does PagedAttention work?

Each request has a block table mapping logical KV blocks to physical KV blocks. Decode-time attention reads K/V through this block table, allowing non-contiguous KV allocation and reduced fragmentation/padding waste.

How does continuous batching work?

A scheduler keeps a waiting queue and an active batch. New requests are admitted when capacity is available. Prefill runs once per request. Decode then runs repeatedly, one token per active request per step. When a request hits EOS, a stop condition, or max_new_tokens, it is evicted immediately and a new waiting request can enter the active batch.

Why does continuous batching improve latency?

Requests finish at different, unpredictable times. Static batching waits for the slowest request in the batch, so completed requests waste slots. Continuous batching frees those slots immediately, improving average latency, p95 latency, and overall work efficiency.

What is weight-only quantization?

Weight-only quantization stores model weights in a smaller integer format, such as INT8 or INT4, while keeping runtime activations in FP16/BF16. During inference, weights are dequantized and multiplied with the floating-point activations. It reduces weight memory bandwidth, which is important because LLM inference repeatedly reads large weight matrices.

Why use per-channel scales?

Different output channels can have different weight ranges. A single global scale is dominated by the largest values and loses precision in channels with smaller weights. Per-channel scales give each output channel its own numeric range, reducing quantization error with very little metadata overhead.

Why fuse dequantization into matmul?

If the runtime first creates a full W_deq matrix and then runs matmul, it writes and rereads a large expanded weight matrix. Fused dequantization avoids this by dequantizing only small weight tiles inside the kernel and immediately consuming them for matmul.

Why do AWQ/GPTQ need kernels like Marlin?

AWQ and GPTQ decide how to quantize weights with low accuracy loss. They do not automatically make inference fast. To get speedups, the runtime needs specialized kernels that can unpack low-bit weights, apply scales/zero-points, feed tensor cores efficiently, and avoid materializing dequantized weights. Marlin-style kernels solve this kernel-side problem.

How does speculative decoding work?

A smaller draft model proposes multiple tokens. The larger target model verifies those proposed tokens in one forward pass. Matching draft tokens are accepted. At the first mismatch, the target token is used instead. If all proposed tokens are accepted, the target can provide an extra bonus token. The target model remains the source of truth, so output correctness is preserved.

When does speculative decoding help?

Speculative decoding helps when the draft model is much cheaper than the target model and has a high acceptance rate. A high acceptance rate means each expensive target forward can produce multiple final tokens. If the draft model is poor, the target rejects often and the speedup collapses.

How does speculative decoding affect KV cache?

Speculative decoding creates temporary KV-cache states. Accepted draft tokens can be committed, but rejected draft suffixes must be discarded or rolled back. The target correction token then needs to be appended. This makes speculative decoding closely tied to KV-cache memory management and paged allocation.


Current limitations

This is a learning repo, not a production inference engine.

  • Kernels are forward-only.
  • FlashAttention implementation is simplified and lacks production optimizations like causal block skipping, software pipelining, advanced scheduling, and tuned tensor-core paths.
  • PagedAttention decode kernel demonstrates block-table indirection but does not implement a full production allocator or serving runtime.
  • Continuous batching simulator is CPU-only and models scheduling behavior, not actual GPU execution.
  • INT8 quantization kernel demonstrates fused dequantization, but the RTX 50-series compatibility path avoids tl.dot and tensor-core lowering.
  • Speculative decoding demo is CPU-only and uses toy greedy bigram models to explain the draft/verify loop.
  • INT4 packing, group-wise scales, zero-points, AWQ/GPTQ formats, and Marlin-style scheduling are not implemented yet.
  • Benchmarks are development runs and should be re-run on the target GPU.

Next planned additions

  • 08_speculative_decoding/README.md with a deeper explanation of draft/verify, acceptance rate, and KV-cache rollback.
  • INT4 packing/unpacking demo.
  • Group-wise scales and zero-points.
  • Tensor-core-friendly quantized matmul path when Triton SM120 support is stable.
  • Mini inference server/capstone that ties KV cache, paged allocation, scheduling, quantization, speculative decoding, and decode together.
  • Nsight Compute profiling notes.

Author

Anay Dongre

About

LLM inference kernels from scratch in Triton: KV cache, FlashAttention, PagedAttention, RMSNorm, RoPE, SwiGLU, and benchmarks.
