metrics: split frequency-minimality out of importance-minimality (batch-invariant)#543
Draft
ocg-goodfire wants to merge 1 commit into
…ch-invariant)

Split the frequency-weighted log2 term out of ImportanceMinimalityLoss into a standalone FrequencyMinimalityLoss, normalized to be batch-invariant.

Math (per component c, f_c = per-token firing frequency over the global batch):

    imp  = Σ_c f_c                        (bare L_p mean term)
    freq = Σ_c f_c · log2(1 + a' · f_c)   (a' = reference_token_count)

Setting a' = B·T reproduces the old rolled `imp + beta·log2(1 + B·T·f_c)`, so coefficients transfer as freq.coeff = old imp.coeff · old beta.

Core:
- importance_minimality.py: finalize_imp_min → bare Σ_c f_c; new finalize_freq_min + FrequencyMinimalityLoss/Config (reference_token_count knob); drop beta from ImportanceMinimalityLoss; shared autograd-aware global-sum reduce helper.
- Register FrequencyMinimalityLoss in dispatch + AnyLossMetricConfig.

3-pool:
- ThreePoolLosses gains a `freq` field (+ validator requiring freq/imp to share pnorm/eps/anneal, since step_ci shares one set of (ci+eps)^p sums).
- step_ci computes both terms from the same all-reduced sums (one extra finalize); reductions + total-loss + logging carry loss/freq.

Migrations:
- App circuit-optimization keeps the old rolled form via a local AppImpMinConfig + app_importance_minimality_loss (its beta is persisted in the app DB), decoupling it from the core signature change.
- In-repo 3-pool/LM/TMS/resid YAMLs: drop imp.beta, add freq block where beta>0 with reference_token_count = that config's B·T and coeff = imp.coeff·beta.

Tests: batch-invariance, a'=B·T reproduces the old rolled value, f=0→0, and the 3-pool _sparsity_losses matches a single-process reference.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
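As a minimal sketch of the two terms defined in the commit message, the pure-Python snippet below computes `f_c`, `imp`, and `freq` from a toy table of per-token values, then checks that doubling the token count at fixed firing rates leaves `freq` unchanged. The function names (`firing_frequency`, `imp_min`, `freq_min`), the `eps=0.0` setting, and the toy data are all hypothetical; this is not the code in `importance_minimality.py`, which additionally handles autograd and the global-sum reduction.

```python
import math


def firing_frequency(ci, pnorm=1.0, eps=1e-12):
    """Per-component f_c: sum over tokens of (g_c + eps)^p, divided by the token count.

    `ci` is a list of per-token rows; each row holds one value g_c per component.
    """
    n_tokens = len(ci)
    n_components = len(ci[0])
    return [
        sum((row[c] + eps) ** pnorm for row in ci) / n_tokens
        for c in range(n_components)
    ]


def imp_min(freqs):
    """imp = sum_c f_c."""
    return sum(freqs)


def freq_min(freqs, reference_token_count):
    """freq = sum_c f_c * log2(1 + a' * f_c), with a' = reference_token_count."""
    return sum(f * math.log2(1.0 + reference_token_count * f) for f in freqs)


# Four tokens, three components; the third component never fires.
ci = [
    [0.5, 0.0, 0.0],
    [0.0, 0.25, 0.0],
    [0.5, 0.0, 0.0],
    [0.0, 0.25, 0.0],
]
ci_doubled = ci + ci  # twice as many tokens, identical firing rates

# eps=0.0 keeps the toy numbers exact (an illustrative choice, not a recommended setting).
f_small = firing_frequency(ci, eps=0.0)
f_large = firing_frequency(ci_doubled, eps=0.0)

# Same a' for both batch sizes, so the penalty depends only on the firing rates.
loss_small = freq_min(f_small, reference_token_count=4)
loss_large = freq_min(f_large, reference_token_count=4)
```

Because `reference_token_count` is passed explicitly instead of being read off the batch, the two losses agree even though the second batch has twice as many tokens. A component with `f_c = 0` contributes `0 · log2(1) = 0`.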
Description
Splits the frequency-weighted `log2` term out of `ImportanceMinimalityLoss` into a standalone, batch-invariant `FrequencyMinimalityLoss`. (Lucius signed off on this exact conceptual shape.)

Per component `c`, with `f_c` = per-token firing frequency over the whole global batch (`f_c = Σ_{all B·T tokens} (g_c + ε)^p / (B·T)`):

- `L_imp = Σ_c f_c`
- `L_freq = Σ_c f_c · log2(1 + a'·f_c)`, where `a' = reference_token_count`

The `f=0 → 0` cutoff is inherent to the form. Setting `a' = B·T` reproduces the old implicit behavior (the old code's `log2(1 + B·T·f_c)`), so existing coefficients transfer cleanly.

The `reference_token_count` knob

The old rolled term baked `B·T` implicitly inside the `log2` (via the un-normalized per-component sum). That made the penalty's curvature batch-size-dependent. `L_freq` makes the normalizer explicit as `reference_token_count` (`a'`), so the loss is invariant to batch size at fixed firing rate: a config can change `B` without silently rescaling the frequency penalty.

Coefficient translation
For migrated configs: new `freq.coeff = old imp.coeff · old beta`, `imp.coeff` unchanged, `reference_token_count = B·T`. `beta: 0` configs just drop `beta` (no freq block needed, since single-pool `loss_metrics` is a free list and the term was zero).

Motivation and Context
The implicit-`B·T` coupling meant frequency-penalty strength drifted with batch size. Splitting + normalizing makes the two sparsity pressures independently tunable and batch-invariant.

How Has This Been Tested?
- `make check` (basedpyright + ruff) clean: zero new type errors vs baseline (122 pre-existing, all wandb/fire/fastapi import noise).
- `make test` (excluding slow): 467 passed, 8 skipped.

New tests in `param_decomp/tests/metrics/test_importance_minimality_loss.py` + `param_decomp_lab/tests/test_three_pool_sparsity_losses.py`:

- Same `f` at different `B` → same `L_freq`
- `a' = B·T` reproduces the old rolled `imp + beta·log` value exactly
- `f=0 → 0` contribution
- 3-pool `_sparsity_losses` matches a single-process `finalize_imp_min` / `finalize_freq_min` reference

All 21 migrated YAMLs verified to parse through their config classes.
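The `a' = B·T` equivalence can be illustrated with a small numeric check. This sketch assumes the old rolled term had the form `Σ_c f_c · (1 + beta · log2(1 + S_c))`, where `S_c = B·T·f_c` is the un-normalized per-component sum; that reading is inferred from the description, not copied from the old code. The names `old_rolled` and `split_terms` and all numbers are hypothetical.

```python
import math


def old_rolled(sums, n_tokens, beta):
    """Assumed pre-split form: sum_c f_c * (1 + beta * log2(1 + S_c)), with f_c = S_c / n_tokens."""
    return sum((s / n_tokens) * (1.0 + beta * math.log2(1.0 + s)) for s in sums)


def split_terms(sums, n_tokens, reference_token_count):
    """Return (L_imp, L_freq) computed from the same per-component sums."""
    freqs = [s / n_tokens for s in sums]
    l_imp = sum(freqs)
    l_freq = sum(f * math.log2(1.0 + reference_token_count * f) for f in freqs)
    return l_imp, l_freq


sums = [3.0, 0.5, 0.0]  # per-component sums of (g_c + eps)^p over all tokens
n_tokens = 8            # B*T for this toy batch
old_coeff = 2.0         # old imp.coeff
beta = 0.5              # old imp.beta

# With a' = B*T, the split terms recombine into the old value once
# freq.coeff is set to old imp.coeff * old beta.
l_imp, l_freq = split_terms(sums, n_tokens, reference_token_count=n_tokens)
old_total = old_coeff * old_rolled(sums, n_tokens, beta)
new_total = old_coeff * l_imp + (old_coeff * beta) * l_freq
```

Both terms are built from one list of per-component sums, mirroring the description's point that the second term costs one extra finalize and no recompute.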
Code changes
- Core (`param_decomp/metrics/importance_minimality.py`): `finalize_imp_min` → bare `Σ_c f_c`; new `finalize_freq_min(...)`; new `FrequencyMinimalityLoss` + `FrequencyMinimalityLossConfig`; `beta` removed from `ImportanceMinimalityLoss`. Registered in `dispatch.py` + `AnyLossMetricConfig`.
- 3-pool: `ThreePoolLosses` gains a `freq` field. `step_ci` computes both terms from the same all-reduced per-component sums (one extra finalize, no recompute), so a validator requires `freq`/`imp` to share `pnorm`/`eps`/anneal (only `reference_token_count` differs). `reductions.py`, total-loss, and logging now carry `loss/freq` / `_raw/freq_num`.
- Migrations: 3-pool ([…] `_xl_production`, `_resumption_validation`), single-pool LM, TMS, resid-MLP YAMLs.

Notes / decisions made under ambiguity
- App circuit-optimization (`app/backend/optim_cis.py`, `editing/`) genuinely depended on the combined `imp + beta·log` form, and the app persists `beta` in its SQLite graph-optimization params. Retrofitting the split there would require a DB migration + behavior change, which is out of scope. Instead I introduced an app-local `AppImpMinConfig` + `app_importance_minimality_loss` that preserve the old rolled math exactly, decoupling the app from the core signature change. App behavior is unchanged.
- […] (`p-ecfda851`) is unaffected: it snapshotted its config at launch, so editing the in-repo YAML doesn't touch it.
- `scripts/three_pool_grad_check/base.yaml` and the `scripts/repro_3pool_eval_deadlock/*` configs carry a top-level `three_pool: null` key that `LMExperimentConfig` already rejects (pre-existing, from the config-fork PR), so they were stale before this change. I migrated their loss blocks for consistency but did not fix the unrelated `three_pool` staleness.

Does this PR introduce a breaking change?
Yes. `ImportanceMinimalityLoss` no longer accepts `beta`. No legacy fallback (house style): existing YAMLs need manual migration (drop `imp.beta`, add a `FrequencyMinimalityLoss` / `freq` block). All in-repo configs are migrated here.

🤖 Generated with Claude Code
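For out-of-repo configs, the manual migration could be scripted roughly as follows. The dict layout (`coeff`, `beta`, `pnorm`, `eps` keys) and the helper name `migrate_imp_block` are assumptions for illustration; the real YAML schema may differ, so treat this as a sketch of the coefficient translation, not a drop-in tool.

```python
def migrate_imp_block(imp, batch_size, seq_len):
    """Split an old-style importance-minimality block into (new_imp, freq_or_None).

    Hypothetical helper: `imp` is a plain dict that may carry a `beta` key.
    """
    # The new importance block is the old one without `beta`; its coeff is unchanged.
    new_imp = {k: v for k, v in imp.items() if k != "beta"}
    beta = imp.get("beta", 0)
    if beta == 0:
        # The rolled log term was zero, so no freq block is needed.
        return new_imp, None
    # Start from the imp settings, since freq and imp are required to share pnorm/eps/anneal.
    freq = dict(new_imp)
    freq["coeff"] = imp["coeff"] * beta                    # old imp.coeff * old beta
    freq["reference_token_count"] = batch_size * seq_len   # that config's B*T
    return new_imp, freq


old_block = {"coeff": 2.0, "beta": 0.25, "pnorm": 2.0, "eps": 1e-12}
new_imp, freq = migrate_imp_block(old_block, batch_size=64, seq_len=512)

zero_beta_block = {"coeff": 2.0, "beta": 0, "pnorm": 2.0, "eps": 1e-12}
imp_only, no_freq = migrate_imp_block(zero_beta_block, batch_size=64, seq_len=512)
```

Fixing `reference_token_count` to the config's `B·T` at migration time preserves the old behavior for that config, while later batch-size changes no longer rescale the penalty.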