othakkar/tinformer

Tinformer

A tiny decoder-only transformer in JAX where every component, from scaled dot-product attention to autoregressive generation, is implemented from scratch. Built to trace every tensor shape through every layer, not to produce coherent text.
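As a rough illustration of the first component named above, here is a minimal single-head causal scaled dot-product attention in JAX. It is a sketch, not the code in src/attention.py. The function name `causal_attention` and the (seq_len, d_k) shape convention are assumptions made for this example.

```python
import jax
import jax.numpy as jnp

def causal_attention(q, k, v):
    """Single-head causal scaled dot-product attention (illustrative sketch).

    q, k, v: arrays of shape (seq_len, d_k). Returns (output, weights), where
    output is (seq_len, d_k) and weights is (seq_len, seq_len).
    """
    seq_len, d_k = q.shape
    # Similarity scores, scaled by sqrt(d_k) to keep their variance near 1
    scores = q @ k.T / jnp.sqrt(d_k)
    # Lower-triangular mask: position i may attend only to positions <= i
    mask = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))
    # Use the most negative finite float so masked entries get ~0 weight
    scores = jnp.where(mask, scores, jnp.finfo(scores.dtype).min)
    weights = jax.nn.softmax(scores, axis=-1)
    return weights @ v, weights

key = jax.random.PRNGKey(0)
q, k, v = jax.random.normal(key, (3, 4, 8))  # seq_len=4, d_k=8
out, w = causal_attention(q, k, v)
print(out.shape)  # (4, 8)
```

Each row of `w` sums to 1, and everything above the diagonal is zero, so no token sees the future.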

Tags

Tag                      Description
tinformer-from-scratch   Base transformer implementation
kv-caching               KV caching with benchmarks
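The sketch below shows the idea behind KV caching. It assumes single-head attention and uses hypothetical projection matrices `W_q`, `W_k` and `W_v`, which are not taken from this repo. At each step only the newest token is projected, and its key and value are appended to a cache. The result should match full causal attention recomputed from scratch over the whole prefix.

```python
import jax
import jax.numpy as jnp

def attend(q, k, v, mask=None):
    # Scaled dot-product attention; mask (if given) marks allowed positions
    scores = q @ k.T / jnp.sqrt(q.shape[-1])
    if mask is not None:
        scores = jnp.where(mask, scores, jnp.finfo(scores.dtype).min)
    return jax.nn.softmax(scores, axis=-1) @ v

d_model = 8
key_w, key_x = jax.random.split(jax.random.PRNGKey(0))
# Hypothetical single-head projections, scaled to keep activations ~unit variance
W_q, W_k, W_v = jax.random.normal(key_w, (3, d_model, d_model)) / jnp.sqrt(d_model)
x = jax.random.normal(key_x, (5, d_model))  # 5 token embeddings
T = x.shape[0]

# Naive: recompute Q, K, V for the whole sequence with a causal mask
causal = jnp.tril(jnp.ones((T, T), dtype=bool))
naive = attend(x @ W_q, x @ W_k, x @ W_v, causal)

# Cached: project only the newest token and append its K/V to the cache
k_cache = jnp.zeros((0, d_model))
v_cache = jnp.zeros((0, d_model))
cached_rows = []
for t in range(T):
    x_t = x[t:t + 1]                                  # (1, d_model)
    k_cache = jnp.concatenate([k_cache, x_t @ W_k])   # (t + 1, d_model)
    v_cache = jnp.concatenate([v_cache, x_t @ W_v])
    # The newest query may attend to every cached position, so no mask is needed
    cached_rows.append(attend(x_t @ W_q, k_cache, v_cache))
cached = jnp.concatenate(cached_rows)                 # (T, d_model)

print(jnp.allclose(naive, cached, atol=1e-5))
```

Growing the cache with `jnp.concatenate` changes array shapes at every step, which forces retracing under `jax.jit`. Implementations often preallocate a fixed-size cache instead and write into it with `jax.lax.dynamic_update_slice`.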

Repo structure

.gitignore
LICENSE
README.md
generate.py                          # Naive + cached autoregressive generation
requirements.txt
src/
  config.py                          # TinformerConfig dataclass
  attention.py                       # Scaled dot-product attention + multi-head attention (with KV cache)
  layernorm.py                       # Layer normalization
  ffn.py                             # Feed-forward network
  decoder.py                         # Decoder block (LN → MHA → residual → LN → FFN → residual)
  tinformer.py                       # Full model with KV cache support
benchmarks/
  benchmark_kv_caching.py            # Naive vs cached generation benchmarks
tests/
  test_shapes.py                     # Tensor shape verification
  test_causal_masking.py             # Causal mask validation
  test_attention_stability.py        # Numerical stability tests
  test_cached_generate.py            # KV cache correctness tests
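The decoder.py entry above gives the block's ordering: LN → MHA → residual → LN → FFN → residual. The sketch below follows that pre-LN ordering, with layer norm, a position-wise FFN and causal multi-head attention inlined so it runs on its own. It is not the repo's implementation. The parameter dict keys, the fused `W_qkv` projection, the GELU activation and the 0.02 init scale are all illustrative assumptions.

```python
import jax
import jax.numpy as jnp

def layer_norm(x, gamma, beta, eps=1e-5):
    # Normalize each token over the feature (last) axis, then scale and shift
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return gamma * (x - mean) / jnp.sqrt(var + eps) + beta

def ffn(x, W1, b1, W2, b2):
    # (T, d_model) -> (T, d_ff) -> (T, d_model); GELU is an assumed choice
    return jax.nn.gelu(x @ W1 + b1) @ W2 + b2

def multi_head_attention(x, W_qkv, W_o, n_heads):
    T, d_model = x.shape
    d_head = d_model // n_heads
    q, k, v = jnp.split(x @ W_qkv, 3, axis=-1)        # each (T, d_model)

    def split_heads(a):
        # (T, d_model) -> (n_heads, T, d_head)
        return a.reshape(T, n_heads, d_head).transpose(1, 0, 2)

    q, k, v = split_heads(q), split_heads(k), split_heads(v)
    scores = q @ k.transpose(0, 2, 1) / jnp.sqrt(d_head)  # (n_heads, T, T)
    mask = jnp.tril(jnp.ones((T, T), dtype=bool))          # broadcast over heads
    scores = jnp.where(mask, scores, jnp.finfo(scores.dtype).min)
    out = jax.nn.softmax(scores, axis=-1) @ v              # (n_heads, T, d_head)
    out = out.transpose(1, 0, 2).reshape(T, d_model)       # concatenate heads
    return out @ W_o

def decoder_block(x, p, n_heads):
    # Pre-LN: LN -> MHA -> residual, then LN -> FFN -> residual
    h = layer_norm(x, p["ln1_g"], p["ln1_b"])
    x = x + multi_head_attention(h, p["W_qkv"], p["W_o"], n_heads)
    h = layer_norm(x, p["ln2_g"], p["ln2_b"])
    x = x + ffn(h, p["W1"], p["b1"], p["W2"], p["b2"])
    return x

def init_block(key, d_model, d_ff, scale=0.02):
    # Hypothetical parameter layout, for illustration only
    k1, k2, k3, k4 = jax.random.split(key, 4)
    return {
        "ln1_g": jnp.ones(d_model), "ln1_b": jnp.zeros(d_model),
        "ln2_g": jnp.ones(d_model), "ln2_b": jnp.zeros(d_model),
        "W_qkv": scale * jax.random.normal(k1, (d_model, 3 * d_model)),
        "W_o": scale * jax.random.normal(k2, (d_model, d_model)),
        "W1": scale * jax.random.normal(k3, (d_model, d_ff)), "b1": jnp.zeros(d_ff),
        "W2": scale * jax.random.normal(k4, (d_ff, d_model)), "b2": jnp.zeros(d_model),
    }

params = init_block(jax.random.PRNGKey(0), d_model=16, d_ff=64)
x = jax.random.normal(jax.random.PRNGKey(1), (6, 16))  # 6 tokens, d_model=16
y = decoder_block(x, params, n_heads=4)
print(y.shape)  # (6, 16)
```

The block preserves the (seq_len, d_model) shape, which is what lets blocks be stacked.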

Quickstart

git clone https://github.com/othakkar/tinformer.git
cd tinformer
pip install -r requirements.txt
python -m generate
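generate.py is described as doing naive and cached autoregressive generation. The naive greedy loop can be sketched as follows: re-run the model on the full prefix, take the argmax of the last position's logits, append it, and repeat. `toy_model` is a hypothetical stand-in for the real model, using a causal running mean of embeddings. Greedy decoding is an assumed choice; the repo may sample differently.

```python
import jax
import jax.numpy as jnp

VOCAB, D = 10, 8
k_emb, k_out = jax.random.split(jax.random.PRNGKey(0))
params = {
    "emb": jax.random.normal(k_emb, (VOCAB, D)),
    "out": jax.random.normal(k_out, (D, VOCAB)),
}

def toy_model(params, tokens):
    """Hypothetical stand-in model: (T,) int ids -> (T, VOCAB) logits.

    A causal running mean of token embeddings plays the role of context.
    """
    h = params["emb"][tokens]                             # (T, D)
    counts = jnp.arange(1, tokens.shape[0] + 1)[:, None]  # (T, 1)
    h = jnp.cumsum(h, axis=0) / counts                    # causal running mean
    return h @ params["out"]                              # (T, VOCAB)

def generate_greedy(params, prompt, max_new_tokens):
    tokens = jnp.asarray(prompt, dtype=jnp.int32)
    for _ in range(max_new_tokens):
        logits = toy_model(params, tokens)  # naive: re-run on the whole prefix
        next_id = jnp.argmax(logits[-1])    # greedy: most likely next token
        tokens = jnp.append(tokens, next_id)
    return tokens

out = generate_greedy(params, [1, 2, 3], max_new_tokens=5)
print(out.shape)  # (8,)
```

Each step repeats work on the whole prefix, so total cost grows quadratically with length. That redundancy is what a KV cache removes.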

Run benchmarks

python -m benchmarks.benchmark_kv_caching
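The benchmark script itself is not reproduced here. Timing JAX code has two well-known pitfalls. First, dispatch is asynchronous, so you must wait on the result with `block_until_ready()`. Second, the first call to a jitted function includes compilation. The helper below handles both. `time_fn` and `matmul_chain` are hypothetical names for illustration, not functions from this repo.

```python
import statistics
import time

import jax
import jax.numpy as jnp

def time_fn(fn, *args, n_iters=10):
    """Median wall-clock seconds per call, excluding the first (compile) call."""
    fn(*args).block_until_ready()  # warm-up: triggers JIT compilation
    times = []
    for _ in range(n_iters):
        start = time.perf_counter()
        fn(*args).block_until_ready()  # JAX dispatch is async; wait for the result
        times.append(time.perf_counter() - start)
    return statistics.median(times)

@jax.jit
def matmul_chain(x):
    return (x @ x) @ x

x = jnp.ones((256, 256))
print(f"{time_fn(matmul_chain, x) * 1e3:.3f} ms per call")
```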

Run tests

pytest tests/
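tests/test_attention_stability.py is not reproduced here. The pytest-style function below is a hedged sketch of the kind of property such a test might check. It compares a naive softmax with a max-subtracted one on very large attention scores. All function names are hypothetical.

```python
import jax.numpy as jnp

def naive_softmax(x):
    e = jnp.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

def stable_softmax(x):
    # Subtracting the row max leaves the result mathematically unchanged
    # but keeps every exp() argument <= 0, so nothing overflows
    z = x - x.max(axis=-1, keepdims=True)
    e = jnp.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def test_softmax_large_logits():
    scores = jnp.array([[1000.0, 1001.0, 1002.0]], dtype=jnp.float32)
    # exp(1000) overflows float32 to inf, and inf / inf is nan
    assert jnp.isnan(naive_softmax(scores)).any()
    probs = stable_softmax(scores)
    assert jnp.isfinite(probs).all()
    assert jnp.allclose(probs.sum(axis=-1), 1.0)
    # Softmax is shift-invariant: same result for the logits shifted by -1000
    assert jnp.allclose(probs, stable_softmax(scores - 1000.0))
```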

License

Apache 2.0
