Delete all torch from the repo (JAX-only) #876

Merged: ocg-goodfire merged 1 commit into feature/jax from worktree-agent-ab39118894fbeca2e on Jun 17, 2026
Conversation

@ocg-goodfire

Collaborator

Summary

The three main packages are now torch-free: grep -rl "import torch\|from torch" param_decomp param_decomp_lab param_decomp_jax --include=*.py | grep -v /tests/ returns nothing. Training has run on the JAX single-pool trainer (jsp-train via pd-jax-lm) since the torch trainer was retired; this PR removes the remaining torch consumer/infra layer.

The only torch left anywhere is nano_param_decomp/run.py — the standalone single-file VPD reference impl for paper readers (self-contained, no lab deps, excluded from make type, outside the goal-grep scope).

Deleted entirely

  • param_decomp_lab/experiments/lm/pretrain/ — torch model defs, train loop, pd-pretrain CLI, run_info, configs. Pretraining will be reimplemented in JAX when next needed. The trainer loads target weights from the on-disk cache via its own torch-free loaders (llama_simple_mlp.load_target_from_pretrain_cache / load_prefix_from_pretrain_cache), never through pretrain code — so this does not touch the weight cache or break target loading.
  • param_decomp/{distributed,decomposition_targets}.py, param_decomp_lab/{distributed,seed}.py, param_decomp_lab/infra/{ddp_launch,wandb_tensor_info}.py, param_decomp_lab/toy_models/, param_decomp_lab/topology/topology.py (TransformerTopology), param_decomp_lab/experiments/lm/run.py (the torch build_target bridge) — all dead after de-torching their sole consumers.
  • jax_single_pool/tools/convert_llama_simple_mlp_checkpoint.py (one-off torch checkpoint→safetensors converter; existing caches already hold their safetensors).
  • nano_param_decomp/{pile_4L,simplestories_2L}.py: model-wiring entry points that imported the now-deleted torch pretrain architectures and would otherwise be left broken. run.py (the method itself) stays.
  • Four tests covering deleted modules.

Relocated / de-torched (no torch-model construction)

  • JaxPDAdapter derives target topology (n_blocks / vocab / per-site (name, C) / layer descriptions) from a new torch-free jax_single_pool.load_run.run_metadata (reads config + the pretrain-cache model_config.yaml, no orbax restore) + the already-torch-free path_schema_for_model_type.
  • experiments/lm/data.py keeps only tokenize_and_concatenate (numpy) for the offline prestage_tokenized tool; the torch DataLoader machinery is gone.
  • infra/run_files.py drops the dead save_file / torch.save path.
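For readers unfamiliar with the pattern, here is a minimal numpy-only sketch of what a tokenize-and-concatenate step usually does: join already-tokenized documents with a separator token and cut the stream into fixed-length rows. It is an illustration under stated assumptions, not the repo's `tokenize_and_concatenate`; the function name `concat_and_chunk`, its parameters, and the choice to drop the ragged tail are all hypothetical.

```python
import numpy as np


def concat_and_chunk(token_seqs, eos_id: int, seq_len: int) -> np.ndarray:
    """Join tokenized documents with an EOS separator and cut into fixed-length rows.

    Hypothetical helper for illustration; the real tool may differ in dtype,
    separator handling, and how it treats the leftover tail.
    """
    pieces = []
    for seq in token_seqs:
        pieces.append(np.asarray(seq, dtype=np.int32))
        pieces.append(np.array([eos_id], dtype=np.int32))  # document separator
    if not pieces:
        return np.empty((0, seq_len), dtype=np.int32)
    flat = np.concatenate(pieces)
    n_rows = flat.size // seq_len  # assumption: the ragged tail is dropped
    return flat[: n_rows * seq_len].reshape(n_rows, seq_len)


if __name__ == "__main__":
    rows = concat_and_chunk([[1, 2, 3], [4, 5]], eos_id=0, seq_len=3)
    print(rows.shape)
```

Because nothing here touches a DataLoader or a tensor library, a step like this can run offline in a prestaging tool with only numpy installed.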

Tooling

  • Dropped torch + the pytorch-cu128 index/source from the root pyproject.toml + regenerated uv.lock. The two-stack conflict is gone; make install-dev bridges CPU jax + beartype + editable param_decomp_jax into the main venv. CI install drops the pytorch cpu extra-index. Docs updated (root / jax / experiments CLAUDE.md, MIGRATION_HOLES, READMEs).

Zero-torch proof

$ grep -rl "import torch\|from torch\|import torch\b" param_decomp param_decomp_lab param_decomp_jax --include=*.py | grep -v /tests/
(empty)
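The grep above matches text, so it would also flag a commented-out import or a module such as torchvision. As a complementary check, the sketch below parses each file and looks only at real import statements. It is an assumption-laden illustration, not part of this PR: the helper names `imports_torch` and `find_torch_files` are hypothetical, and unlike the grep it deliberately ignores comments, strings, and packages whose names merely start with "torch".

```python
import ast
from pathlib import Path


def imports_torch(source: str) -> bool:
    """Return True if the source has an absolute import of torch or a torch submodule."""
    for node in ast.walk(ast.parse(source)):
        if isinstance(node, ast.Import):
            if any(a.name == "torch" or a.name.startswith("torch.") for a in node.names):
                return True
        elif isinstance(node, ast.ImportFrom):
            module = node.module or ""
            # level == 0 means an absolute import; relative imports are skipped.
            if node.level == 0 and (module == "torch" or module.startswith("torch.")):
                return True
    return False


def find_torch_files(roots, skip_dir: str = "tests"):
    """List .py files under the given roots that import torch, skipping test directories."""
    hits = []
    for root in roots:
        for path in sorted(Path(root).rglob("*.py")):
            if skip_dir in path.parts:
                continue
            if imports_torch(path.read_text(encoding="utf-8")):
                hits.append(path)
    return hits


if __name__ == "__main__":
    for hit in find_torch_files(["param_decomp", "param_decomp_lab", "param_decomp_jax"]):
        print(hit)
```

An empty listing from a script like this would agree with the grep result, though the two checks are not equivalent: the grep is the stricter text-level test, while the AST pass avoids false positives.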

Trainer-still-loads-targets proof (CPU)

1. build_from_schema(pile_pgd1.yaml) OK -> target: LlamaSimpleMLPTargetConfig, n_sites: 24
2. build_target OK -> lm sites=24 vocab=50277
   loaded the pile LlamaSimpleMLP weights from the on-disk pretrain cache (torch-free)

(run_metadata on a real run + the full target+prefix build from cache both succeed with no torch installed.)

Validation

  • make type: 0 errors
  • lab+core tests: 209 passed (-m "not slow")
  • make check-jax: 0 errors
  • JAX suite: 193 passed, 2 skipped
  • equivalence goldens: 12 passed (bit-identical — training semantics unchanged)
  • import-smoke: all kept consumers (harvest/clustering run_worker_jax, autointerp, intruder, postprocess, jax_launch, jsp-train, JaxPDAdapter) import with torch absent.
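One way to run an import smoke test like the last item, even on a machine where torch happens to be installed, is to block the name in `sys.modules` before importing. The sketch below is an assumption about how such a check could be written, not the script used for this PR; `import_without_torch` is a hypothetical helper, and the module names passed to it are placeholders.

```python
import importlib
import sys


def import_without_torch(module_names):
    """Import each module with torch blocked; return {module name: error text} for failures."""
    # Set aside any torch modules that are already loaded so the block takes effect.
    saved = {k: v for k, v in sys.modules.items() if k == "torch" or k.startswith("torch.")}
    for key in saved:
        del sys.modules[key]
    sys.modules["torch"] = None  # a None entry makes `import torch` raise ImportError
    failures = {}
    try:
        for name in module_names:
            try:
                importlib.import_module(name)
            except Exception as exc:  # record any failure, not only ImportError
                failures[name] = f"{type(exc).__name__}: {exc}"
    finally:
        del sys.modules["torch"]
        sys.modules.update(saved)  # restore the previous state
    return failures


if __name__ == "__main__":
    # Placeholder module names; substitute the consumers you want to check.
    print(import_without_torch(["json", "collections"]))
```

Modules that were already imported before the call are not re-executed, so a check like this is most meaningful in a fresh interpreter process.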

Note for review

nano_param_decomp/run.py (torch reference impl, paper-linked from README) is left in place but is the lone torch file. Flag if you'd prefer it deleted too — it's outside the goal-grep scope and self-contained, so it was kept as a pedagogical artifact.

🤖 Generated with Claude Code

The repo is now torch-free except `nano_param_decomp/run.py` — the standalone
single-file VPD reference impl for paper readers (self-contained, no lab/pretrain
deps, excluded from `make type`). Training has been JAX (`jsp-train` via `pd-jax-lm`)
since the torch trainer retirement; this sheds the remaining torch consumer/infra
layer. Zero `import torch` in `param_decomp` / `param_decomp_lab` / `param_decomp_jax`.

Deleted entirely:
- `param_decomp_lab/experiments/lm/pretrain/` (torch model defs, train loop,
  `pd-pretrain` CLI, run_info, configs) — pretraining will be reimplemented in JAX
  when next needed; the trainer loads target weights from the on-disk cache via its
  own torch-free loaders, never through pretrain code.
- `param_decomp/{distributed,decomposition_targets}.py`,
  `param_decomp_lab/{distributed,seed}.py`,
  `param_decomp_lab/infra/{ddp_launch,wandb_tensor_info}.py`,
  `param_decomp_lab/toy_models/`, `param_decomp_lab/topology/topology.py`,
  `param_decomp_lab/experiments/lm/run.py` (the torch `build_target` bridge) —
  all dead after de-torching their sole consumers.
- `jax_single_pool/tools/convert_llama_simple_mlp_checkpoint.py` (one-off torch
  checkpoint->safetensors converter; existing caches already converted).
- `nano_param_decomp/{pile_4L,simplestories_2L}.py`: model-wiring entry points that
  imported the now-deleted torch pretrain archs and would otherwise be left broken;
  `run.py` (the method) stays.
- Four tests covering deleted modules.

De-torched (relocated metadata path):
- `JaxPDAdapter` now derives target topology (n_blocks / vocab / per-site (name, C))
  from the new torch-free `jax_single_pool.load_run.run_metadata` (config + pretrain
  cache, no orbax restore) + the torch-free `path_schema_for_model_type`. No torch
  model construction.
- `experiments/lm/data.py` keeps only `tokenize_and_concatenate` (numpy, for the
  offline prestage tool); the torch DataLoader machinery is gone.
- `infra/run_files.py` drops the dead `save_file`/torch.save path.

Tooling: dropped torch + the pytorch-cu128 index from the root pyproject/lock; the
two-stack conflict is gone; the main venv bridges CPU jax + beartype + editable
`param_decomp_jax` (`make install-dev`). Docs updated (root/jax/experiments CLAUDE.md,
MIGRATION_HOLES, READMEs).

Validation: zero-torch grep empty; trainer still builds+loads the pile LlamaSimpleMLP
target from cache (CPU); `make type` 0 errors; lab+core 209 passed; `make check-jax`
0 errors; JAX suite 193 passed/2 skipped; equivalence goldens 12 passed (bit-identical
— training semantics unchanged).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
@ocg-goodfire ocg-goodfire merged commit bdda5ce into feature/jax Jun 17, 2026
1 check failed