Delete all torch from the repo (JAX-only) #876
Merged
The repo is now torch-free except `nano_param_decomp/run.py` — the standalone
single-file VPD reference impl for paper readers (self-contained, no lab/pretrain
deps, excluded from `make type`). Training has been JAX (`jsp-train` via `pd-jax-lm`)
since the torch trainer retirement; this sheds the remaining torch consumer/infra
layer. Zero `import torch` in `param_decomp` / `param_decomp_lab` / `param_decomp_jax`.
Deleted entirely:
- `param_decomp_lab/experiments/lm/pretrain/` (torch model defs, train loop,
`pd-pretrain` CLI, run_info, configs) — pretraining will be reimplemented in JAX
when next needed; the trainer loads target weights from the on-disk cache via its
own torch-free loaders, never through pretrain code.
- `param_decomp/{distributed,decomposition_targets}.py`,
`param_decomp_lab/{distributed,seed}.py`,
`param_decomp_lab/infra/{ddp_launch,wandb_tensor_info}.py`,
`param_decomp_lab/toy_models/`, `param_decomp_lab/topology/topology.py`,
`param_decomp_lab/experiments/lm/run.py` (the torch `build_target` bridge) —
all dead after de-torching their sole consumers.
- `jax_single_pool/tools/convert_llama_simple_mlp_checkpoint.py` (one-off torch
checkpoint->safetensors converter; existing caches already converted).
- `nano_param_decomp/{pile_4L,simplestories_2L}.py` — model-wiring entry points that
imported the deleted torch pretrain archs (now broken); `run.py` (the method) stays.
- Four tests covering deleted modules.
De-torched (relocated metadata path):
- `JaxPDAdapter` now derives target topology (n_blocks / vocab / per-site (name, C))
from the new torch-free `jax_single_pool.load_run.run_metadata` (config + pretrain
cache, no orbax restore) + the torch-free `path_schema_for_model_type`. No torch
model construction.
- `experiments/lm/data.py` keeps only `tokenize_and_concatenate` (numpy, for the
offline prestage tool); the torch DataLoader machinery is gone.
- `infra/run_files.py` drops the dead `save_file`/torch.save path.
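As an illustration of what a numpy-only `tokenize_and_concatenate` can look like, here is a minimal sketch. It is not the repo's implementation: the signature, the EOS-separator convention, and the decision to drop a trailing partial row are all assumptions, and `tokenize_and_concatenate_sketch` is a hypothetical name.

```python
import numpy as np


def tokenize_and_concatenate_sketch(texts, encode, eos_id, seq_len):
    """Tokenize each text, join with EOS separators, and cut into fixed-length rows.

    `encode` is any callable mapping str -> list[int] (assumed interface).
    Trailing tokens that do not fill a whole row are dropped.
    """
    ids = []
    for text in texts:
        ids.extend(encode(text))
        ids.append(eos_id)  # mark the document boundary
    flat = np.asarray(ids, dtype=np.int32)
    n_rows = len(flat) // seq_len
    return flat[: n_rows * seq_len].reshape(n_rows, seq_len)
```

With a toy character-level `encode`, two short strings become a `(2, 3)` int32 array; no torch `DataLoader` or tensor type is involved at any point.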
Tooling: dropped torch + the pytorch-cu128 index from the root pyproject/lock; the
two-stack conflict is gone; the main venv bridges CPU jax + beartype + editable
`param_decomp_jax` (`make install-dev`). Docs updated (root/jax/experiments CLAUDE.md,
MIGRATION_HOLES, READMEs).
Validation: zero-torch grep empty; trainer still builds+loads the pile LlamaSimpleMLP
target from cache (CPU); `make type` 0 errors; lab+core 209 passed; `make check-jax`
0 errors; JAX suite 193 passed/2 skipped; equivalence goldens 12 passed (bit-identical
— training semantics unchanged).
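"Bit-identical" is a stricter bar than numerical closeness. A minimal sketch of such a comparison, assuming the goldens can be loaded as numpy arrays (the actual golden test harness is not shown in this PR, and `bit_identical` is a hypothetical helper):

```python
import numpy as np


def bit_identical(a, b) -> bool:
    """True only if dtype, shape, and raw bytes all match (no tolerance)."""
    a, b = np.asarray(a), np.asarray(b)
    return a.dtype == b.dtype and a.shape == b.shape and a.tobytes() == b.tobytes()
```

Unlike `np.allclose`, this treats `0.0` and `-0.0` as different and treats two NaNs with the same bit pattern as equal, which is the behaviour wanted when asserting that a refactor left training semantics untouched.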
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Summary
The repo is now torch-free. `grep -rl "import torch\|from torch" param_decomp param_decomp_lab param_decomp_jax --include=*.py | grep -v /tests/` returns nothing. Training has been the JAX single-pool trainer (`jsp-train` via `pd-jax-lm`) since the torch-trainer retirement; this PR sheds the remaining torch consumer/infra layer.
The only torch left anywhere is `nano_param_decomp/run.py`, the standalone single-file VPD reference impl for paper readers (self-contained, no lab deps, excluded from `make type`, outside the goal-grep scope).
Deleted entirely
- `param_decomp_lab/experiments/lm/pretrain/`: torch model defs, train loop, `pd-pretrain` CLI, run_info, configs. Pretraining will be reimplemented in JAX when next needed. The trainer loads target weights from the on-disk cache via its own torch-free loaders (`llama_simple_mlp.load_target_from_pretrain_cache` / `load_prefix_from_pretrain_cache`), never through pretrain code, so this does not touch the weight cache or break target loading.
- `param_decomp/{distributed,decomposition_targets}.py`, `param_decomp_lab/{distributed,seed}.py`, `param_decomp_lab/infra/{ddp_launch,wandb_tensor_info}.py`, `param_decomp_lab/toy_models/`, `param_decomp_lab/topology/topology.py` (`TransformerTopology`), `param_decomp_lab/experiments/lm/run.py` (the torch `build_target` bridge): all dead after de-torching their sole consumers.
- `jax_single_pool/tools/convert_llama_simple_mlp_checkpoint.py` (one-off torch checkpoint→safetensors converter; existing caches already hold their safetensors).
- `nano_param_decomp/{pile_4L,simplestories_2L}.py`: model-wiring entry points that imported the deleted torch pretrain archs (now broken). `run.py` (the method itself) stays.
Relocated / de-torched (no torch-model construction)
- `JaxPDAdapter` derives target topology (`n_blocks` / vocab / per-site `(name, C)` / layer descriptions) from a new torch-free `jax_single_pool.load_run.run_metadata` (reads config + the pretrain-cache `model_config.yaml`, no orbax restore) + the already-torch-free `path_schema_for_model_type`.
- `experiments/lm/data.py` keeps only `tokenize_and_concatenate` (numpy) for the offline `prestage_tokenized` tool; the torch `DataLoader` machinery is gone.
- `infra/run_files.py` drops the dead `save_file`/`torch.save` path.
Tooling
Dropped `torch` + the pytorch-cu128 index/source from the root `pyproject.toml` + regenerated `uv.lock`. The two-stack conflict is gone; `make install-dev` bridges CPU jax + beartype + editable `param_decomp_jax` into the main venv. CI install drops the pytorch cpu extra-index. Docs updated (root / jax / experiments CLAUDE.md, MIGRATION_HOLES, READMEs).
Zero-torch proof
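The grep check from the Summary can also be approximated in Python. The sketch below is not part of the PR; `imports_torch` and `torch_importers` are hypothetical helper names. Because it parses each file with `ast`, it is stricter than the grep pattern: it ignores strings and comments, and it does not match look-alike modules such as `torchvision`.

```python
import ast
from pathlib import Path


def imports_torch(source: str) -> bool:
    """Return True if the source imports `torch` or any `torch.*` submodule."""
    for node in ast.walk(ast.parse(source)):
        if isinstance(node, ast.Import):
            if any(a.name == "torch" or a.name.startswith("torch.") for a in node.names):
                return True
        elif isinstance(node, ast.ImportFrom):
            mod = node.module or ""
            # level == 0 means an absolute import, not `from .torch import x`
            if node.level == 0 and (mod == "torch" or mod.startswith("torch.")):
                return True
    return False


def torch_importers(roots: list[str]) -> list[Path]:
    """List non-test .py files under `roots` that import torch."""
    hits = []
    for root in roots:
        for path in sorted(Path(root).rglob("*.py")):
            if "tests" in path.parts:  # mirrors `grep -v /tests/`
                continue
            if imports_torch(path.read_text(encoding="utf-8")):
                hits.append(path)
    return hits
```

Calling `torch_importers(["param_decomp", "param_decomp_lab", "param_decomp_jax"])` from the repo root would be expected to return an empty list if the grep above is empty.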
Trainer-still-loads-targets proof (CPU)
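One way to exercise imports as if torch were missing, even on a machine that has it installed, is to block the name in `sys.modules`. This is a minimal sketch and not the PR's actual proof; `torch_blocked` and `imports_without_torch` are hypothetical names.

```python
import importlib
import sys
from contextlib import contextmanager


@contextmanager
def torch_blocked():
    """Make `import torch` fail inside the block, even if torch is installed."""
    saved = {
        name: mod
        for name, mod in list(sys.modules.items())
        if name == "torch" or name.startswith("torch.")
    }
    for name in saved:
        del sys.modules[name]
    sys.modules["torch"] = None  # a None entry makes the import system raise ImportError
    try:
        yield
    finally:
        del sys.modules["torch"]
        sys.modules.update(saved)  # restore any real torch modules


def imports_without_torch(module_name: str) -> bool:
    """Freshly import `module_name` with torch blocked; False on ImportError."""
    with torch_blocked():
        sys.modules.pop(module_name, None)  # force a fresh import
        try:
            importlib.import_module(module_name)
            return True
        except ImportError:
            return False
```

Only the named module is re-imported here; dependencies already cached in `sys.modules` are not re-executed, so a clean environment with torch uninstalled remains the stronger test.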
(`run_metadata` on a real run + the full target+prefix build from cache both succeed with no torch installed.)
Validation
- `make type`: 0 errors
- … (`-m "not slow"`)
- `make check-jax`: 0 errors
- … (`run_worker_jax`, autointerp, intruder, postprocess, jax_launch, `jsp-train`, `JaxPDAdapter`) import with torch absent.
Note for review
`nano_param_decomp/run.py` (torch reference impl, paper-linked from README) is left in place but is the lone torch file. Flag if you'd prefer it deleted too; it's outside the goal-grep scope and self-contained, so it was kept as a pedagogical artifact.
🤖 Generated with Claude Code