Per-Slot-EMA Transformer: a sequence modeling architecture that achieves lossless representation and
From Bijection to Injection: Lossless and $O(1)$ Inference Are No Longer in Conflict. Yuyue Li. arXiv preprint, 2026.
This repository accompanies the arXiv paper above. The code is the research prototype used to produce all experiments in the paper. It is not a production implementation; it has not been tuned for wall-clock performance, and a kernel-level rewrite would be needed for deployment-grade throughput. See §7.6 of the paper for the asymptotic comparison and the rationale for this disclaimer.
.
├── README.md / README_zh.md This file (English / Chinese)
├── LICENSE MIT
├── ARCHITECTURE.md Internal architecture notes
├── record.md Experimental record (source of truth for paper numbers)
├── training_log.txt Sample training log (350M run on SlimPajama)
├── train.py Main training entry point
├── cal_mean.py Utility: mean PPL from training log
│
├── src/ Core implementation
│ ├── model.py EMA / RoPE model definitions
│ ├── ema_kernel.py Multi-scale EMA scan (Triton)
│ ├── score_kernel.py Three-factor attention score kernel (Triton)
│ ├── flash_attn_kernel.py FlashAttention-style fused kernel
│ ├── linear_attn_kernel.py Linear-attention reference kernel
│ ├── rope_model.py RoPE Transformer baseline
│ ├── slot_producer.py Per-token-id slot management
│ ├── token_state.py Active-slot state container
│ ├── state.py State serialization (warm-start)
│ ├── batch_planner.py Heaps-aware batch planner
│ ├── bucket_autogen.py K_max bucket auto-tuning
│ ├── doc_index.py Document-level index for packing
│ ├── doc_similarity.py MinHash similarity for packing
│ ├── datapipe.py v1 data pipeline
│ ├── datapipe_v2.py v2 data pipeline (chunk-aware, with State carry)
│ ├── collate.py Batch collation
│ ├── ingest.py / prepare.py Tokenization + bin file ingestion
│ └── partition.py Train/val/test partitioning
│
├── tests/ Causality + sanity tests
│ ├── test_no_future_leak.py Strict causality check (most important)
│ ├── test_cross_chunk_causality.py
│ ├── test_v2_causality.py
│ └── test_model.py Model forward/backward sanity
│
├── tools/ Reproduction tools (paper tables)
│ ├── eval_extrapolation.py Length extrapolation (Table 4)
│ ├── eval_extrapolation.sh
│ ├── profile_k_distribution.py K_max distribution (Table 2 right)
│ ├── doc_length_stats.py Heaps' law fit (Table 2 left)
│ └── per_position_loss.py Per-position loss diagnostic
│
└── paper/
    ├── main.tex / main.pdf Chinese version
    └── main_en.tex / main_en.pdf English (arXiv) version
- Python 3.10+
- PyTorch 2.1+
- Triton 2.1+
- A CUDA-capable GPU (tested on RTX 5060 Ti and A100; bf16 default)
- For the paper PDFs: TeX Live 2023+ (English: pdflatex; Chinese: xelatex with ctex)
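
The version floors in the list above can be checked before launching a long run. The following is a minimal sketch, not part of the repository: the helper names parse_major_minor and check_environment are hypothetical, and the distribution names torch and triton are assumed to match how the packages were installed. It does not probe the GPU or bf16 support.

```python
import re
import sys
from importlib import metadata

# Minimum versions copied from the requirements list.
MIN_PYTHON = (3, 10)
# Assumption: the packages are installed under these distribution names.
MIN_PACKAGES = {"torch": (2, 1), "triton": (2, 1)}


def parse_major_minor(version):
    """Turn a string such as '2.1.0+cu121' into (2, 1); the local suffix is ignored."""
    numbers = re.findall(r"\d+", version.split("+")[0])
    major = int(numbers[0]) if numbers else 0
    minor = int(numbers[1]) if len(numbers) > 1 else 0
    return major, minor


def check_environment():
    """Return a list of human-readable problems; an empty list means all checks passed."""
    problems = []
    if sys.version_info[:2] < MIN_PYTHON:
        problems.append(f"Python {sys.version_info[0]}.{sys.version_info[1]} is older than 3.10")
    for name, minimum in MIN_PACKAGES.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            problems.append(f"{name} is not installed")
            continue
        if parse_major_minor(installed) < minimum:
            problems.append(f"{name} {installed} is older than {minimum[0]}.{minimum[1]}")
    return problems


for line in check_environment() or ["environment looks OK"]:
    print(line)
```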
src/prepare.py ingests raw text and produces tokenized binary files used for training:
python -c "from src.prepare import prepare; \
prepare(train_file='/path/to/wt103/train.txt', \
data_root='/tmp/ema_data/wt103', \
tokenizer='gpt2', \
text_field='text')"

This is also called automatically the first time you run train.py --train_file <raw> --data_root <bin_dir>. The output is a directory containing:

- train.bin / val.bin: int32 token streams, contiguous, with document boundaries
- doc_offsets.bin: document-start indices
- meta.json: vocab size, tokenizer, doc count
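
To inspect the prepared files outside the training pipeline, something like the sketch below may be enough. It relies on assumptions the README does not state: that doc_offsets.bin stores int64 indices, that those indices refer to the split being loaded, and that the offsets are sorted. The function names load_prepared and iter_documents are hypothetical; the repository's own readers live under src/.

```python
import json
from pathlib import Path

import numpy as np


def load_prepared(bin_dir, split="train", offsets_dtype=np.int64):
    """Memory-map one prepared split and read its offsets and metadata.

    offsets_dtype is an assumption; only the token dtype (int32) is documented.
    """
    root = Path(bin_dir)
    tokens = np.memmap(root / f"{split}.bin", dtype=np.int32, mode="r")
    offsets = np.fromfile(root / "doc_offsets.bin", dtype=offsets_dtype)
    meta = json.loads((root / "meta.json").read_text())
    return tokens, offsets, meta


def iter_documents(tokens, offsets):
    """Yield one token array per document, assuming sorted document-start indices."""
    bounds = np.append(offsets, len(tokens))
    for start, end in zip(bounds[:-1], bounds[1:]):
        yield tokens[start:end]
```

Using a memory map keeps the token stream on disk, which matters once the .bin files grow to billions of tokens.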
Supported sources: plain .txt (one doc per line), .jsonl (with --text_field), .parquet, and Hugging Face datasets.
python train.py \
--train_file /path/to/wt103/train.txt \
--val_file /path/to/wt103/val.txt \
--data_root /tmp/ema_data/wt103 \
--tokenizer gpt2 \
--model ema \
--d_model 512 --n_layers 6 --n_heads 8 \
--n_scales 32 --rho_min_halflife 1 --rho_max_halflife 8192 \
--pipe_version v2 \
--chunk_min 2048 --chunk_max 2048 \
--steps 200000 \
--lr 3e-4 --warmup_steps 2000 \
--device cuda

For the RoPE baseline, use the same command but pass --model rope and drop the EMA-specific flags (--n_scales, --rho_*).
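
For intuition about the --n_scales, --rho_min_halflife and --rho_max_halflife flags in the EMA command, the sketch below converts half-lives into per-step decay factors. It is an illustration under assumptions, not the repository's code: the log-spaced placement of half-lives, the normalized recurrence in ema_scan, and all function names are hypothetical. See src/ema_kernel.py and the paper for the actual parameterization.

```python
def halflife_to_decay(halflife):
    """Decay rho for which an impulse's weight halves every `halflife` steps."""
    return 0.5 ** (1.0 / halflife)


def log_spaced_decays(n_scales, min_halflife, max_halflife):
    """Assumed layout: half-lives evenly spaced in log space, endpoints included."""
    if n_scales == 1:
        return [halflife_to_decay(min_halflife)]
    ratio = (max_halflife / min_halflife) ** (1.0 / (n_scales - 1))
    return [halflife_to_decay(min_halflife * ratio**i) for i in range(n_scales)]


def ema_scan(xs, rho):
    """Sequential reference: h_t = rho * h_{t-1} + (1 - rho) * x_t, with h_0 = 0."""
    h, out = 0.0, []
    for x in xs:
        h = rho * h + (1.0 - rho) * x
        out.append(h)
    return out


# Values taken from the training command: 32 scales, half-lives from 1 to 8192 steps.
decays = log_spaced_decays(32, 1, 8192)
print(f"fastest rho = {decays[0]:.6f}, slowest rho = {decays[-1]:.6f}")
```

Under this convention a half-life of 1 gives rho = 0.5, and a half-life of 8192 gives a rho very close to 1, so the slowest scale retains information over thousands of tokens.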
python tools/eval_extrapolation.py \
--ema_ckpt experiments/<run_name>/model_best.pt \
--rope_ckpt experiments/<rope_run>/model_best.pt \
--val_file /path/to/wt103/val.txt \
--data_root /tmp/ema_data/wt103 \
--tokenizer gpt2 \
--train_len 512 \
--eval_lens 256,512,1024,2048,4096,8192 \
--device cuda

python tools/doc_length_stats.py --bin_dir /tmp/ema_data/wt103
python tools/profile_k_distribution.py --bin_dir /tmp/ema_data/wt103

- All numbers in the paper come from record.md; each row in the paper tables can be traced back to a section there. If you reproduce a number and it differs by more than 0.5 PPL, please open an issue.
- Random seed: --seed 0 everywhere unless noted. bf16 is the default precision.
- The 350M SlimPajama run (15B tokens) is the most expensive experiment (~3 days on 4×A100). Smaller-scale results (Tables 3, 5, 6) reproduce in hours on a single GPU.
- The first train.py invocation builds the bin files and a MinHash index; subsequent runs reuse them (use --force_prepare / --rebuild_buckets to force a rebuild).
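
For readers unfamiliar with the Heaps' law fit reported by tools/doc_length_stats.py, the sketch below shows the textbook form of the technique: the number of distinct tokens V(n) after n tokens is modeled as C * n^beta and fitted by least squares in log-log space. This is a generic illustration with hypothetical function names; how the repository's tool samples the curve or aggregates over documents is not described in this README.

```python
import math


def vocab_growth(tokens):
    """Return (n, V(n)) pairs: distinct token ids seen within the first n tokens."""
    seen, curve = set(), []
    for n, tok in enumerate(tokens, start=1):
        seen.add(tok)
        curve.append((n, len(seen)))
    return curve


def fit_heaps(curve):
    """Least-squares fit of log V = log C + beta * log n; returns (C, beta)."""
    xs = [math.log(n) for n, _ in curve]
    ys = [math.log(v) for _, v in curve]
    mean_x = sum(xs) / len(xs)
    mean_y = sum(ys) / len(ys)
    sxx = sum((x - mean_x) ** 2 for x in xs)
    sxy = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    beta = sxy / sxx
    return math.exp(mean_y - beta * mean_x), beta


# Synthetic check: points generated from V(n) = 3 * n^0.5 should recover C and beta.
synthetic = [(n, 3.0 * n**0.5) for n in (1, 4, 16, 64, 256)]
coeff, beta = fit_heaps(synthetic)
print(f"C = {coeff:.3f}, beta = {beta:.3f}")
```

A beta below 1 means the count of distinct tokens grows sublinearly with document length.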
@misc{li2026pse,
title = {From Bijection to Injection: Lossless and $O(1)$ Inference Are No Longer in Conflict},
author = {Yuyue Li},
year = {2026},
eprint = {arXiv:XXXX.XXXXX},
archivePrefix = {arXiv},
primaryClass = {cs.LG}
}

(arXiv ID will be filled in once the paper is publicly available.)
MIT. See LICENSE.
Thanks to Yuwei Feng for helpful discussions throughout the development of this work.