Mr-Skeleton-Max/PSE_Transformer

PSE Transformer

Per-Slot-EMA Transformer: a sequence modeling architecture that achieves lossless representation and $O(1)$ inference simultaneously, by relaxing the implicit "information retention $=$ bijective invertibility" assumption to injection without bijection.

From Bijection to Injection: Lossless and $O(1)$ Inference Are No Longer in Conflict. Yuyue Li. arXiv preprint, 2026.

Chinese README → README_zh.md

Status

This repository accompanies the arXiv paper above. The code is the research prototype used to produce all experiments in the paper. It is not a production implementation; it has not been tuned for wall-clock performance, and a kernel-level rewrite would be needed for deployment-grade throughput. See §7.6 of the paper for the asymptotic comparison and the rationale for this disclaimer.

Repository layout

.
├── README.md / README_zh.md   This file (English / Chinese)
├── LICENSE                    MIT
├── ARCHITECTURE.md            Internal architecture notes
├── record.md                  Experimental record (source of truth for paper numbers)
├── training_log.txt           Sample training log (350M run on SlimPajama)
├── train.py                   Main training entry point
├── cal_mean.py                Utility: mean PPL from training log
│
├── src/                       Core implementation
│   ├── model.py               EMA / RoPE model definitions
│   ├── ema_kernel.py          Multi-scale EMA scan (Triton)
│   ├── score_kernel.py        Three-factor attention score kernel (Triton)
│   ├── flash_attn_kernel.py   FlashAttention-style fused kernel
│   ├── linear_attn_kernel.py  Linear-attention reference kernel
│   ├── rope_model.py          RoPE Transformer baseline
│   ├── slot_producer.py       Per-token-id slot management
│   ├── token_state.py         Active-slot state container
│   ├── state.py               State serialization (warm-start)
│   ├── batch_planner.py       Heaps-aware batch planner
│   ├── bucket_autogen.py      K_max bucket auto-tuning
│   ├── doc_index.py           Document-level index for packing
│   ├── doc_similarity.py      MinHash similarity for packing
│   ├── datapipe.py            v1 data pipeline
│   ├── datapipe_v2.py         v2 data pipeline (chunk-aware, with State carry)
│   ├── collate.py             Batch collation
│   ├── ingest.py / prepare.py Tokenization + bin file ingestion
│   └── partition.py           Train/val/test partitioning
│
├── tests/                     Causality + sanity tests
│   ├── test_no_future_leak.py        Strict causality check (most important)
│   ├── test_cross_chunk_causality.py
│   ├── test_v2_causality.py
│   └── test_model.py                 Model forward/backward sanity
│
├── tools/                     Reproduction tools (paper tables)
│   ├── eval_extrapolation.py         Length extrapolation (Table 4)
│   ├── eval_extrapolation.sh
│   ├── profile_k_distribution.py     K_max distribution (Table 2 right)
│   ├── doc_length_stats.py           Heaps' law fit (Table 2 left)
│   └── per_position_loss.py          Per-position loss diagnostic
│
└── paper/
    ├── main.tex / main.pdf           Chinese version
    └── main_en.tex / main_en.pdf     English (arXiv) version
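
The tests/ directory above is described as checking strict causality. As a minimal sketch of what such a check can look like (this is not the repository's test code, and the toy EMA scan below is only a stand-in for the real model), one can perturb every input after position t and confirm that the outputs up to t do not change. The names ema_scan and leaks_future are hypothetical.

```python
import random

def ema_scan(xs, rho):
    """Toy causal model: state_t = rho * state_{t-1} + (1 - rho) * x_t."""
    state, out = 0.0, []
    for x in xs:
        state = rho * state + (1.0 - rho) * x
        out.append(state)
    return out

def leaks_future(model, xs, t, trials=20, seed=0):
    """Return True if changing inputs after position t alters outputs at or before t."""
    rng = random.Random(seed)
    base = model(xs)
    for _ in range(trials):
        perturbed = list(xs)
        for i in range(t + 1, len(xs)):
            perturbed[i] = rng.uniform(-10.0, 10.0)
        if model(perturbed)[: t + 1] != base[: t + 1]:
            return True
    return False

xs = [0.5, -1.0, 2.0, 0.25, 3.0, -0.75]
causal = lambda seq: ema_scan(seq, rho=0.9)
# Deliberately non-causal: every output is the mean of the whole sequence.
acausal = lambda seq: [sum(seq) / len(seq)] * len(seq)

print(leaks_future(causal, xs, t=2))   # False
print(leaks_future(acausal, xs, t=2))  # True
```

The same pattern carries over to a real model by replacing the lambdas with a forward pass and comparing outputs within a numerical tolerance rather than exactly.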

Requirements

  • Python 3.10+
  • PyTorch 2.1+
  • Triton 2.1+
  • A CUDA-capable GPU (tested on RTX 5060 Ti and A100; bf16 default)
  • For the paper PDFs: TeX Live 2023+ (English: pdflatex; Chinese: xelatex with ctex)

Data preparation

src/prepare.py ingests raw text and produces tokenized binary files used for training:

python -c "from src.prepare import prepare; \
           prepare(train_file='/path/to/wt103/train.txt', \
                   data_root='/tmp/ema_data/wt103', \
                   tokenizer='gpt2', \
                   text_field='text')"

train.py runs the same preparation step automatically the first time it is invoked with --train_file <raw> --data_root <bin_dir>. The output is a directory containing:

  • train.bin / val.bin: int32 token streams, contiguous, with document boundaries
  • doc_offsets.bin: document-start indices
  • meta.json: vocab size, tokenizer, doc count

Supported sources: plain .txt (one doc per line), .jsonl (with --text_field), .parquet, and Hugging Face datasets.
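
To make the output layout concrete, here is a small sketch for reading the files back and splitting the token stream into documents. It rests on several assumptions that the README does not confirm: the files are headerless and in native byte order, doc_offsets.bin stores int64 indices into train.bin, and the helper names load_flat and iter_documents are hypothetical. It uses only the standard library so that it runs anywhere; for real corpora a memory-mapped array would be the more usual choice.

```python
import tempfile
from array import array
from pathlib import Path

def load_flat(path, typecode):
    """Read a headerless binary file of fixed-width integers (native byte order)."""
    values = array(typecode)
    count = Path(path).stat().st_size // values.itemsize
    with open(path, "rb") as f:
        values.fromfile(f, count)
    return values

def iter_documents(tokens, doc_starts):
    """Yield one token slice per document, given the document-start indices."""
    bounds = list(doc_starts) + [len(tokens)]
    for start, end in zip(bounds, bounds[1:]):
        yield tokens[start:end]

# Write a tiny synthetic directory so the sketch runs without real data.
with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    with open(root / "train.bin", "wb") as f:
        array("i", [11, 12, 13, 21, 22, 31]).tofile(f)  # "i": int32 on common platforms
    with open(root / "doc_offsets.bin", "wb") as f:
        array("q", [0, 3, 5]).tofile(f)                 # "q": int64 (assumed dtype)

    tokens = load_flat(root / "train.bin", "i")
    offsets = load_flat(root / "doc_offsets.bin", "q")
    docs = [list(doc) for doc in iter_documents(tokens, offsets)]

print(docs)  # [[11, 12, 13], [21, 22], [31]]
```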

Quick start

Train PSE on WikiText-103

python train.py \
    --train_file /path/to/wt103/train.txt \
    --val_file   /path/to/wt103/val.txt \
    --data_root  /tmp/ema_data/wt103 \
    --tokenizer  gpt2 \
    --model      ema \
    --d_model    512 --n_layers 6 --n_heads 8 \
    --n_scales   32  --rho_min_halflife 1 --rho_max_halflife 8192 \
    --pipe_version v2 \
    --chunk_min  2048 --chunk_max 2048 \
    --steps      200000 \
    --lr 3e-4 --warmup_steps 2000 \
    --device cuda
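
The --n_scales, --rho_min_halflife and --rho_max_halflife flags suggest a bank of EMA decay rates parameterized by half-life. The sketch below shows one plausible reading, not the repository's actual schedule: it assumes rho is the per-step decay, that a half-life h means rho^h = 0.5 (so rho = 2^(-1/h)), and that half-lives are spaced geometrically between the two bounds. The function name decay_rates is hypothetical.

```python
def decay_rates(n_scales, min_halflife, max_halflife):
    """Per-step decay rates whose half-lives are geometrically spaced (assumed spacing)."""
    if n_scales == 1:
        return [0.5 ** (1.0 / min_halflife)]
    ratio = (max_halflife / min_halflife) ** (1.0 / (n_scales - 1))
    halflives = [min_halflife * ratio ** k for k in range(n_scales)]
    # A half-life of h steps means the weight of an input halves after h steps.
    return [0.5 ** (1.0 / h) for h in halflives]

rhos = decay_rates(n_scales=32, min_halflife=1, max_halflife=8192)
print(rhos[0])            # fastest scale: weight halves every step
print(rhos[-1])           # slowest scale: weight halves after about 8192 steps
print(rhos[-1] ** 8192)   # close to 0.5 by construction
```

Under this reading, the fastest scale forgets almost immediately while the slowest retains context over thousands of tokens, which is one way a fixed-size state can cover a wide range of time horizons.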

Train RoPE baseline (matched parameter count)

Use the same command with --model rope, and drop the EMA-specific flags (--n_scales, --rho_*).

Length-extrapolation evaluation (reproduces Table 4)

python tools/eval_extrapolation.py \
    --ema_ckpt   experiments/<run_name>/model_best.pt \
    --rope_ckpt  experiments/<rope_run>/model_best.pt \
    --val_file   /path/to/wt103/val.txt \
    --data_root  /tmp/ema_data/wt103 \
    --tokenizer  gpt2 \
    --train_len  512 \
    --eval_lens  256,512,1024,2048,4096,8192 \
    --device     cuda
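
For orientation, the quantity reported at each evaluation length is presumably perplexity, that is, the exponential of the mean per-token negative log-likelihood. The following sketch shows that computation over non-overlapping windows of each length. It is not the logic of tools/eval_extrapolation.py: the windowing scheme is an assumption, ppl_by_length is a hypothetical name, and nll_fn stands in for a real model's per-token loss.

```python
import math

def ppl_by_length(nll_fn, tokens, eval_lens):
    """Perplexity per evaluation length, using non-overlapping windows (assumed scheme)."""
    results = {}
    for n in eval_lens:
        total, count = 0.0, 0
        for start in range(0, len(tokens) - n + 1, n):
            nlls = nll_fn(tokens[start:start + n])  # one NLL (natural log) per token
            total += sum(nlls)
            count += len(nlls)
        if count:
            results[n] = math.exp(total / count)
    return results

# Stand-in model: uniform over a vocabulary of size V, so every token costs log(V).
V = 50257
uniform_nll = lambda window: [math.log(V)] * len(window)

tokens = list(range(1024))
ppl = ppl_by_length(uniform_nll, tokens, eval_lens=[256, 512, 2048])
print(ppl)  # lengths longer than the stream (here 2048) are skipped
```

A uniform model gives a perplexity equal to the vocabulary size at every length, which makes it a convenient sanity check before plugging in a trained checkpoint.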

Heaps' law fit and K_max distribution (reproduces Table 2, left and right)

python tools/doc_length_stats.py --bin_dir /tmp/ema_data/wt103
python tools/profile_k_distribution.py --bin_dir /tmp/ema_data/wt103
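
Heaps' law models the number of distinct tokens V among the first n tokens as V(n) = K · n^β. As a rough illustration of the kind of fit involved (a sketch, not the implementation in tools/doc_length_stats.py), the parameters can be estimated by ordinary least squares on log V = log K + β log n. The function names vocab_growth and fit_heaps are hypothetical.

```python
import math

def vocab_growth(tokens, step):
    """Return (n, distinct tokens among the first n) sampled every `step` tokens."""
    seen, points = set(), []
    for i, tok in enumerate(tokens, start=1):
        seen.add(tok)
        if i % step == 0:
            points.append((i, len(seen)))
    return points

def fit_heaps(points):
    """Least-squares fit of log V = log K + beta * log n; returns (K, beta)."""
    xs = [math.log(n) for n, _ in points]
    ys = [math.log(v) for _, v in points]
    mx, my = sum(xs) / len(xs), sum(ys) / len(ys)
    beta = (sum((x - mx) * (y - my) for x, y in zip(xs, ys))
            / sum((x - mx) ** 2 for x in xs))
    return math.exp(my - beta * mx), beta

# Synthetic points that follow V = 3 * n^0.6 exactly, so the fit should recover them.
points = [(n, 3.0 * n ** 0.6) for n in (10, 100, 1000, 10000)]
K, beta = fit_heaps(points)
print(round(K, 6), round(beta, 6))

print(vocab_growth([1, 2, 2, 3, 1, 4], step=2))  # [(2, 2), (4, 3), (6, 4)]
```

On a real corpus, vocab_growth would be run per document or over the token stream, and the resulting points passed to fit_heaps.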

Reproducibility notes

  • All numbers in the paper come from record.md; each row in the paper tables can be traced back to a section there. If you reproduce a number and it differs by more than 0.5 PPL, please open an issue.
  • Random seed: --seed 0 everywhere unless noted. bf16 is the default precision.
  • The 350M SlimPajama run (15B tokens) is the most expensive experiment (~3 days on 4×A100). Smaller-scale results (Tables 3, 5, 6) reproduce in hours on a single GPU.
  • The first train.py invocation builds the bin files and a MinHash index; subsequent runs reuse them (use --force_prepare / --rebuild_buckets to force rebuild).

Citation

@misc{li2026pse,
  title  = {From Bijection to Injection: Lossless and $O(1)$ Inference Are No Longer in Conflict},
  author = {Yuyue Li},
  year   = {2026},
  eprint = {arXiv:XXXX.XXXXX},
  archivePrefix = {arXiv},
  primaryClass  = {cs.LG}
}

(arXiv ID will be filled in once the paper is publicly available.)

License

MIT. See LICENSE.

Acknowledgments

Thanks to Yuwei Feng for helpful discussions throughout the development of this work.
