A tiny decoder-only transformer in JAX where every component, from scaled dot-product attention to autoregressive generation, is implemented from scratch. Built to trace every tensor shape through every layer, not to produce coherent text.
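As a minimal sketch of the core operation, causal scaled dot-product attention computes `softmax(Q Kᵀ / sqrt(d_k)) V`, with future positions masked out before the softmax. The sketch traces shapes with a leading batch axis. The function names are hypothetical and this is not the code in `src/attention.py`:

```python
import jax
import jax.numpy as jnp


def causal_mask(seq_len):
    # (seq_len, seq_len) boolean mask: True where query i may attend to key j (j <= i)
    return jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))


def scaled_dot_product_attention(q, k, v, mask=None):
    """q: (..., T_q, d_k), k: (..., T_k, d_k), v: (..., T_k, d_v) -> (..., T_q, d_v)."""
    d_k = q.shape[-1]
    # (..., T_q, T_k) scores, divided by sqrt(d_k) so their variance does not grow with d_k
    scores = q @ jnp.swapaxes(k, -1, -2) / jnp.sqrt(d_k)
    if mask is not None:
        # Disallowed positions get the most negative float, so softmax gives them 0 weight
        scores = jnp.where(mask, scores, jnp.finfo(scores.dtype).min)
    # jax.nn.softmax subtracts the row max internally, so large scores do not overflow
    weights = jax.nn.softmax(scores, axis=-1)
    return weights @ v, weights


kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
batch, seq_len, d_k = 2, 5, 8
q = jax.random.normal(kq, (batch, seq_len, d_k))
k = jax.random.normal(kk, (batch, seq_len, d_k))
v = jax.random.normal(kv, (batch, seq_len, d_k))

out, weights = scaled_dot_product_attention(q, k, v, mask=causal_mask(seq_len))
print(out.shape)      # (2, 5, 8)
print(weights.shape)  # (2, 5, 5)
```

Row 0 of the mask allows only position 0, so the first output row equals `v[:, 0]`. This makes a quick sanity check on the masking.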
| Tag | Description |
|---|---|
| `tinformer-from-scratch` | Base transformer implementation |
| `kv-caching` | KV caching with benchmarks |
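The idea behind the `kv-caching` tag can be sketched as follows. At each generation step, only the newest token's key and value are projected and written into a cache. The single new query then attends over the cached prefix. This sketch makes several simplifying assumptions: one attention head, no batch axis, a preallocated cache, and hypothetical helper names. It is not the repo's implementation. The final comparison checks the cached path against recomputing full causal attention over the whole prefix:

```python
import jax
import jax.numpy as jnp


def attend(q, k, v, mask):
    # q: (T_q, d), k/v: (T_k, d), mask: (T_q, T_k) bool -> (T_q, d)
    scores = q @ k.T / jnp.sqrt(q.shape[-1])
    scores = jnp.where(mask, scores, jnp.finfo(scores.dtype).min)
    return jax.nn.softmax(scores, axis=-1) @ v


def full_causal_attention(x, w_q, w_k, w_v):
    # Naive path: recompute Q, K and V for every token in the prefix
    T = x.shape[0]
    mask = jnp.tril(jnp.ones((T, T), dtype=bool))
    return attend(x @ w_q, x @ w_k, x @ w_v, mask)


def cached_step(x_t, cache_k, cache_v, t, w_q, w_k, w_v):
    # Cached path: project only the newest token x_t (shape (d,)) and
    # write its key/value into slot t of the preallocated cache
    cache_k = cache_k.at[t].set(x_t @ w_k)
    cache_v = cache_v.at[t].set(x_t @ w_v)
    q = (x_t @ w_q)[None, :]                               # (1, d)
    # Slots after t are still zeros, so mask them out
    mask = (jnp.arange(cache_k.shape[0]) <= t)[None, :]    # (1, max_len)
    out = attend(q, cache_k, cache_v, mask)[0]             # (d,)
    return out, cache_k, cache_v


d, max_len = 16, 6
keys = jax.random.split(jax.random.PRNGKey(0), 4)
w_q, w_k, w_v = (jax.random.normal(kk, (d, d)) / jnp.sqrt(d) for kk in keys[:3])
x = jax.random.normal(keys[3], (max_len, d))  # stand-in token embeddings

cache_k = jnp.zeros((max_len, d))
cache_v = jnp.zeros((max_len, d))
steps = []
for t in range(max_len):
    out_t, cache_k, cache_v = cached_step(x[t], cache_k, cache_v, t, w_q, w_k, w_v)
    steps.append(out_t)
cached_outputs = jnp.stack(steps)  # (max_len, d)

naive_outputs = full_causal_attention(x, w_q, w_k, w_v)
print("cached matches naive:", bool(jnp.allclose(cached_outputs, naive_outputs, atol=1e-4)))
```

Per step, the cached path computes one key/value projection instead of re-projecting all `t + 1` prefix tokens. Attention itself still reads all `t + 1` cached entries. The cache costs `2 × max_len × d` floats of memory per layer in this single-head sketch.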
```
.gitignore
LICENSE
README.md
generate.py                         # Naive + cached autoregressive generation
requirements.txt
src/
    config.py                       # TinformerConfig dataclass
    attention.py                    # Scaled dot-product attention + multi-head attention (with KV cache)
    layernorm.py                    # Layer normalization
    ffn.py                          # Feed-forward network
    decoder.py                      # Decoder block (LN → MHA → residual → LN → FFN → residual)
    tinformer.py                    # Full model with KV cache support
benchmarks/
    benchmark_kv_caching.py         # Naive vs cached generation benchmarks
tests/
    test_shapes.py                  # Tensor shape verification
    test_causal_masking.py          # Causal mask validation
    test_attention_stability.py     # Numerical stability tests
    test_cached_generate.py         # KV cache correctness tests
```
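The block order listed for `src/decoder.py` (LN → MHA → residual → LN → FFN → residual) is the pre-LN arrangement. Below is a minimal JAX sketch of one such block. It assumes no batch axis, a GELU feed-forward network, and a hypothetical flat-dict parameter layout. None of these details are taken from the repo's own code:

```python
import jax
import jax.numpy as jnp


def layer_norm(x, gamma, beta, eps=1e-5):
    # Normalize each token over its feature axis, then apply learned scale and shift
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return gamma * (x - mean) / jnp.sqrt(var + eps) + beta


def multi_head_attention(x, p, n_heads):
    # x: (T, d_model) -> (T, d_model); each head works on a d_model // n_heads slice
    T, d_model = x.shape
    d_head = d_model // n_heads

    def split(h):  # (T, d_model) -> (n_heads, T, d_head)
        return jnp.transpose(h.reshape(T, n_heads, d_head), (1, 0, 2))

    q, k, v = split(x @ p["w_q"]), split(x @ p["w_k"]), split(x @ p["w_v"])
    scores = q @ jnp.swapaxes(k, -1, -2) / jnp.sqrt(d_head)      # (n_heads, T, T)
    mask = jnp.tril(jnp.ones((T, T), dtype=bool))                 # causal mask
    scores = jnp.where(mask, scores, jnp.finfo(scores.dtype).min)
    heads = jax.nn.softmax(scores, axis=-1) @ v                   # (n_heads, T, d_head)
    merged = jnp.transpose(heads, (1, 0, 2)).reshape(T, d_model)  # concatenate heads
    return merged @ p["w_o"]


def ffn(x, p):
    # Position-wise MLP: expand to d_ff, GELU (an assumption), project back to d_model
    return jax.nn.gelu(x @ p["w_1"] + p["b_1"]) @ p["w_2"] + p["b_2"]


def decoder_block(x, p, n_heads):
    # Pre-LN order: LN -> MHA -> residual, then LN -> FFN -> residual
    x = x + multi_head_attention(layer_norm(x, p["ln1_g"], p["ln1_b"]), p, n_heads)
    x = x + ffn(layer_norm(x, p["ln2_g"], p["ln2_b"]), p)
    return x


def scaled_normal(key, shape, fan_in):
    # Gaussian init scaled by 1/sqrt(fan_in) to keep activations O(1)
    return jax.random.normal(key, shape) / jnp.sqrt(fan_in)


def init_params(key, d_model, d_ff):
    # Hypothetical parameter layout: a flat dict of arrays
    ks = jax.random.split(key, 6)
    return {
        "w_q": scaled_normal(ks[0], (d_model, d_model), d_model),
        "w_k": scaled_normal(ks[1], (d_model, d_model), d_model),
        "w_v": scaled_normal(ks[2], (d_model, d_model), d_model),
        "w_o": scaled_normal(ks[3], (d_model, d_model), d_model),
        "w_1": scaled_normal(ks[4], (d_model, d_ff), d_model),
        "b_1": jnp.zeros(d_ff),
        "w_2": scaled_normal(ks[5], (d_ff, d_model), d_ff),
        "b_2": jnp.zeros(d_model),
        "ln1_g": jnp.ones(d_model), "ln1_b": jnp.zeros(d_model),
        "ln2_g": jnp.ones(d_model), "ln2_b": jnp.zeros(d_model),
    }


T, d_model, n_heads, d_ff = 5, 32, 4, 128
params = init_params(jax.random.PRNGKey(0), d_model, d_ff)
x = jax.random.normal(jax.random.PRNGKey(1), (T, d_model))
y = decoder_block(x, params, n_heads)
print(y.shape)  # (5, 32)
```

Changing only the last token leaves the earlier output rows unchanged. Layer norm and the FFN act per position, and the causal mask hides later keys from earlier queries.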
```bash
# Install
git clone https://github.com/othakkar/tinformer.git
cd tinformer
pip install -r requirements.txt

# Run generation, the KV caching benchmark, and the test suite
python -m generate
python -m benchmarks.benchmark_kv_caching
pytest tests/
```

License: Apache 2.0