docs: imp-min parity + 1→9-layer / batch scaling analysis#869
Open
ocg-goodfire wants to merge 1 commit into
Open
docs: imp-min parity + 1→9-layer / batch scaling analysis#869ocg-goodfire wants to merge 1 commit into
ocg-goodfire wants to merge 1 commit into
Conversation
Confirms JAX importance_minimality_terms is a bit-exact semantic match to the torch oracle (equivalence test passes, rel err 0.00e+00) and derives the per-component gradient: imp-min is coeff·Σ_s Σ_c(independent block), so per- component sparsity pressure is invariant to site count. Recommendation: keep coeff=5e-6 unchanged for the 27-site run (do NOT divide by component count); lp is batch-invariant, entropy carries only a sub-log2(N) shift — no batch compensation needed. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
What
Analysis-only PR (one markdown file, no code changes) ahead of the from-scratch
9-layer (Llama-8B L18–26, 27-site, C=49,152) decomposition run. Answers three
questions about the importance-minimality loss vs the 1-layer (3-site) reference run
(
coeff 5e-6,beta 0.2,pnorm 2.0→0.4,eps 1e-12, batch 512×2048).Findings
Parity (confirmed, bit-exact). JAX
importance_minimality_terms(
feature/jax) is a line-for-line semantic match to the torch oracle(
torch-oracle:param_decomp/metrics/importance_minimality.py): per-component sumover positions,
epsinside(ci+eps)^p, mean over positions,lp + beta·entropy,global-sum inside
log2(1+·), per-site grouping. The equivalence testtest_jax_matches_torch_reference[imp]passes with rel err 0.00e+00. The onlytorch surface absent is the
_no_betaeval-only diagnostic — never optimized.Site scaling (the key question): keep
coeff = 5e-6. The loss iscoeff·Σ_s Σ_c(independent per-component block). The per-component gradient∂L/∂ci[s,c]depends only oncoeff, p, eps, Nand that one component's ownS_{s,c}, M_{s,c}— no factor of the site count or total component count, forBOTH the
lpandbeta·entropyterms (entropy couples a component only to its ownS_{s,c}). So per-component sparsity pressure is invariant to 3→27 sites. Dividingcoeffby the component count would be a 9× under-penalty bug.Batch scaling: no compensation needed.
lpis a mean over positions →batch-invariant.
entropycarries a+log2(N)term via the global sumS = N·M:512→256shifts the entropy term ~−7–15% per component (512→128~−15–30%), total imp-min only a few %. Immaterial; if ever needed, nudge
beta,not
coeff.Recommendation table
See
IMPMIN_SCALING_ANALYSIS.mdfor the full side-by-side, equivalence-test output, andgradient derivation.
🤖 Generated with Claude Code