Skip to content

docs: imp-min parity + 1→9-layer / batch scaling analysis#869

Open
ocg-goodfire wants to merge 1 commit into
feature/jaxfrom
analysis/impmin-scaling
Open

docs: imp-min parity + 1→9-layer / batch scaling analysis#869
ocg-goodfire wants to merge 1 commit into
feature/jaxfrom
analysis/impmin-scaling

Conversation

@ocg-goodfire

Copy link
Copy Markdown
Collaborator

What

Analysis-only PR (one markdown file, no code changes) ahead of the from-scratch
9-layer (Llama-8B L18–26, 27-site, C=49,152) decomposition run. Answers three
questions about the importance-minimality loss vs the 1-layer (3-site) reference run
(coeff 5e-6, beta 0.2, pnorm 2.0→0.4, eps 1e-12, batch 512×2048).

Findings

  1. Parity (confirmed, bit-exact). JAX importance_minimality_terms
    (feature/jax) is a line-for-line semantic match to the torch oracle
    (torch-oracle:param_decomp/metrics/importance_minimality.py): per-component sum
    over positions, eps inside (ci+eps)^p, mean over positions, lp + beta·entropy,
    global-sum inside log2(1+·), per-site grouping. The equivalence test
    test_jax_matches_torch_reference[imp] passes with rel err 0.00e+00. The only
    torch surface absent is the _no_beta eval-only diagnostic — never optimized.

  2. Site scaling (the key question): keep coeff = 5e-6. The loss is
    coeff·Σ_s Σ_c(independent per-component block). The per-component gradient
    ∂L/∂ci[s,c] depends only on coeff, p, eps, N and that one component's own
    S_{s,c}, M_{s,c}no factor of the site count or total component count, for
    BOTH the lp and beta·entropy terms (entropy couples a component only to its own
    S_{s,c}). So per-component sparsity pressure is invariant to 3→27 sites. Dividing
    coeff by the component count would be a 9× under-penalty bug.

  3. Batch scaling: no compensation needed. lp is a mean over positions →
    batch-invariant. entropy carries a +log2(N) term via the global sum
    S = N·M: 512→256 shifts the entropy term ~−7–15% per component (512→128
    ~−15–30%), total imp-min only a few %. Immaterial; if ever needed, nudge beta,
    not coeff.

Recommendation table

knob reference 27-site run
coeff 5e-6 5e-6 UNCHANGED
beta 0.2 0.2
pnorm 2.0→0.4 unchanged
eps 1e-12 unchanged
batch 512×2048 smaller is fine, no coeff/beta change

See IMPMIN_SCALING_ANALYSIS.md for the full side-by-side, equivalence-test output, and
gradient derivation.

🤖 Generated with Claude Code

Confirms JAX importance_minimality_terms is a bit-exact semantic match to the
torch oracle (equivalence test passes, rel err 0.00e+00) and derives the
per-component gradient: imp-min is coeff·Σ_s Σ_c(independent block), so per-
component sparsity pressure is invariant to site count. Recommendation: keep
coeff=5e-6 unchanged for the 27-site run (do NOT divide by component count);
lp is batch-invariant, entropy carries only a sub-log2(N) shift — no batch
compensation needed.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant