
metrics: split frequency-minimality out of importance-minimality (batch-invariant) #543

Draft
ocg-goodfire wants to merge 1 commit into feature/multipool from feature/frequency-minimality-loss


@ocg-goodfire

Collaborator

Description

Splits the frequency-weighted log2 term out of ImportanceMinimalityLoss into a standalone, batch-invariant FrequencyMinimalityLoss. (Lucius signed off on this exact conceptual shape.)

Per component c, with f_c = per-token firing frequency over the whole global batch (f_c = Σ_{all B·T tokens} (g_c + ε)^p / (B·T)):

  • Importance-minimality becomes the bare mean term: L_imp = Σ_c f_c
  • Frequency-minimality (new, separate loss): L_freq = Σ_c f_c · log2(1 + a'·f_c), where a' = reference_token_count

The f=0 → 0 cutoff is inherent to the form: f_c · log2(1 + a'·f_c) is exactly zero at f_c = 0. Setting a' = B·T reproduces the old implicit behavior (the old code's log2(1 + B·T·f_c)), so existing coefficients transfer cleanly.
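The two definitions above can be sketched in plain Python. This is a minimal illustration, not the code in importance_minimality.py: the function names, the nested-list input layout, and the default eps/pnorm values are assumptions made for the example, and the example uses eps = 0 only so that the f=0 case is exact.

```python
import math


def firing_frequencies(gates, eps=1e-12, pnorm=1.0):
    """Per-component f_c = sum over all tokens of (g_c + eps)^p / n_tokens.

    `gates` is a list of per-token rows, each holding one gate value per
    component (a hypothetical layout chosen for this sketch).
    """
    n_tokens = len(gates)
    n_components = len(gates[0])
    return [
        sum((row[c] + eps) ** pnorm for row in gates) / n_tokens
        for c in range(n_components)
    ]


def importance_minimality(freqs):
    """L_imp = sum_c f_c (the bare mean term)."""
    return sum(freqs)


def frequency_minimality(freqs, reference_token_count):
    """L_freq = sum_c f_c * log2(1 + a' * f_c), with a' = reference_token_count."""
    return sum(f * math.log2(1.0 + reference_token_count * f) for f in freqs)


# Four tokens, two components: component 0 never fires, component 1 fires
# on half of the tokens.
gates = [[0.0, 1.0], [0.0, 0.0], [0.0, 1.0], [0.0, 0.0]]
freqs = firing_frequencies(gates, eps=0.0, pnorm=1.0)
l_imp = importance_minimality(freqs)
l_freq = frequency_minimality(freqs, reference_token_count=2)
```

Here component 0 contributes nothing to either loss, and component 1 contributes 0.5 · log2(1 + 2 · 0.5) = 0.5 to L_freq.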

The reference_token_count knob

The old rolled term baked B·T implicitly inside the log2 (via the un-normalized per-component sum). That made the penalty's curvature batch-size-dependent. L_freq makes the normalizer explicit as reference_token_count (a'), so the loss is invariant to batch size at fixed firing rate — a config can change B without silently rescaling the frequency penalty.
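A small sketch of the batch-size effect described above, assuming the old term had the form Σ_c f_c · log2(1 + s_c), where s_c = B·T·f_c is the un-normalized per-component sum. The function names and the toy numbers are hypothetical and chosen for illustration only.

```python
import math


def new_freq_loss(per_component_sums, n_tokens, reference_token_count):
    """Sketch of L_freq: the log2 argument uses the explicit a', not n_tokens."""
    total = 0.0
    for s in per_component_sums:
        f = s / n_tokens  # per-token firing frequency
        total += f * math.log2(1.0 + reference_token_count * f)
    return total


def old_log_term(per_component_sums, n_tokens):
    """Assumed old form: the log2 sees the un-normalized sum s_c = B*T*f_c."""
    total = 0.0
    for s in per_component_sums:
        f = s / n_tokens
        total += f * math.log2(1.0 + s)
    return total


# Same firing rates (0.25 and 0.5) observed at two batch sizes.
small_sums, small_tokens = [1.0, 2.0], 4
large_sums, large_tokens = [2.0, 4.0], 8

new_small = new_freq_loss(small_sums, small_tokens, reference_token_count=4)
new_large = new_freq_loss(large_sums, large_tokens, reference_token_count=4)
old_small = old_log_term(small_sums, small_tokens)
old_large = old_log_term(large_sums, large_tokens)
```

With a fixed reference_token_count the new term is identical at both batch sizes, while the assumed old term changes when the token count doubles. Choosing reference_token_count equal to the token count (4 in the small case) recovers the old value.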

Coefficient translation

For migrated configs: new freq.coeff = old imp.coeff · old beta; imp.coeff is unchanged; reference_token_count = B·T. Configs with beta: 0 just drop beta. They need no freq block, because single-pool loss_metrics is a free list and the term was already zero.
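The translation rule can be checked numerically. The sketch below assumes the old weighted loss was imp.coeff · (Σ_c f_c + beta · Σ_c f_c · log2(1 + B·T·f_c)), which is what the rule above implies. The helper names and the dict layout are hypothetical and do not mirror the actual YAML schema.

```python
import math


def migrate_coefficients(old_imp_coeff, old_beta, batch_size, seq_len):
    """Hypothetical helper: translate old (imp.coeff, beta) into the split form."""
    migrated = {"imp": {"coeff": old_imp_coeff}}
    if old_beta != 0:
        # beta == 0 configs get no freq block: the term was already zero.
        migrated["freq"] = {
            "coeff": old_imp_coeff * old_beta,
            "reference_token_count": batch_size * seq_len,
        }
    return migrated


def old_weighted_loss(freqs, imp_coeff, beta, n_tokens):
    """Assumed old rolled form, with B*T baked into the log2."""
    imp = sum(freqs)
    log_term = sum(f * math.log2(1.0 + n_tokens * f) for f in freqs)
    return imp_coeff * (imp + beta * log_term)


def new_weighted_loss(freqs, migrated):
    """Split form: separately weighted L_imp and L_freq."""
    total = migrated["imp"]["coeff"] * sum(freqs)
    if "freq" in migrated:
        a = migrated["freq"]["reference_token_count"]
        freq_term = sum(f * math.log2(1.0 + a * f) for f in freqs)
        total += migrated["freq"]["coeff"] * freq_term
    return total


freqs = [0.0, 0.125, 0.5]
migrated = migrate_coefficients(0.5, 0.25, batch_size=4, seq_len=8)
old_value = old_weighted_loss(freqs, 0.5, 0.25, n_tokens=32)
new_value = new_weighted_loss(freqs, migrated)
```

Under that assumption, old_value and new_value agree up to floating-point rounding, so a migrated config keeps its total loss as long as B·T stays at the recorded reference_token_count.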

Motivation and Context

The implicit-B·T coupling meant frequency-penalty strength drifted with batch size. Splitting + normalizing makes the two sparsity pressures independently tunable and batch-invariant.

How Has This Been Tested?

make check (basedpyright + ruff) clean — zero new type errors vs baseline (122 pre-existing, all wandb/fire/fastapi import noise). make test (excluding slow): 467 passed, 8 skipped. New tests in param_decomp/tests/metrics/test_importance_minimality_loss.py + param_decomp_lab/tests/test_three_pool_sparsity_losses.py:

  • batch-invariance: same per-token f at different B → same L_freq
  • a' = B·T reproduces the old rolled imp + beta·log value exactly
  • f=0 → 0 contribution
  • 3-pool _sparsity_losses matches a single-process finalize_imp_min/finalize_freq_min reference

All 21 migrated YAMLs verified to parse through their config classes.

Code changes

  • Core (param_decomp/metrics/importance_minimality.py): finalize_imp_min → bare Σ_c f_c; new finalize_freq_min(...); new FrequencyMinimalityLoss + FrequencyMinimalityLossConfig; beta removed from ImportanceMinimalityLoss. Registered in dispatch.py + AnyLossMetricConfig.
  • 3-pool: ThreePoolLosses gains a freq field. step_ci computes both terms from the same all-reduced per-component sums (one extra finalize — no recompute), so a validator requires freq/imp to share pnorm/eps/anneal (only reference_token_count differs). reductions.py, total-loss, and logging now carry loss/freq / _raw/freq_num.
  • Migrations: 3-pool (_xl_production, _resumption_validation), single-pool LM, TMS, resid-MLP YAMLs.
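The shared-settings requirement in the 3-pool item can be illustrated with a small stand-alone check. This is a sketch under assumptions: the dataclasses, field types, and the check_shared_settings function are hypothetical and do not reproduce the actual ThreePoolLosses validator.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ImpConfig:
    """Hypothetical stand-in for the importance-minimality settings."""
    coeff: float
    pnorm: float
    eps: float
    anneal: str  # placeholder for whatever the anneal schedule settings are


@dataclass(frozen=True)
class FreqConfig(ImpConfig):
    """Hypothetical stand-in for the freq block: same fields plus a'."""
    reference_token_count: int


def check_shared_settings(imp, freq):
    """Reject configs whose freq/imp terms could not reuse one set of sums."""
    mismatched = [
        name
        for name in ("pnorm", "eps", "anneal")
        if getattr(imp, name) != getattr(freq, name)
    ]
    if mismatched:
        raise ValueError(f"freq and imp must share: {', '.join(mismatched)}")


imp = ImpConfig(coeff=1.0, pnorm=2.0, eps=1e-12, anneal="none")
freq_ok = FreqConfig(coeff=0.1, pnorm=2.0, eps=1e-12, anneal="none",
                     reference_token_count=4096)
freq_bad = FreqConfig(coeff=0.1, pnorm=1.0, eps=1e-12, anneal="none",
                      reference_token_count=4096)

check_shared_settings(imp, freq_ok)  # passes silently
```

Only coeff and reference_token_count are allowed to differ here, because both terms are finalized from the same (ci+eps)^p sums. Calling check_shared_settings(imp, freq_bad) raises ValueError, since the two pnorm values disagree.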

Notes / decisions made under ambiguity

  • App circuit-optimization (app/backend/optim_cis.py, editing/) genuinely depended on the combined imp + beta·log form, and the app persists beta in its SQLite graph-optimization params. Retrofitting the split there would require a DB migration + behavior change — out of scope. Instead I introduced an app-local AppImpMinConfig + app_importance_minimality_loss that preserve the old rolled math exactly, decoupling the app from the core signature change. App behavior is unchanged.
  • The running production run (p-ecfda851) is unaffected — it snapshotted its config at launch; editing the in-repo YAML doesn't touch it.
  • scripts/three_pool_grad_check/base.yaml and the scripts/repro_3pool_eval_deadlock/* configs carry a top-level three_pool: null key that LMExperimentConfig already rejects (pre-existing, from the config-fork PR) — they were stale before this change. I migrated their loss blocks for consistency but did not fix the unrelated three_pool staleness.

Does this PR introduce a breaking change?

Yes — ImportanceMinimalityLoss no longer accepts beta. No legacy fallback (house style): existing YAMLs need manual migration (drop imp.beta, add a FrequencyMinimalityLoss / freq block). All in-repo configs are migrated here.

🤖 Generated with Claude Code

…ch-invariant)

Split the frequency-weighted log2 term out of ImportanceMinimalityLoss into a
standalone FrequencyMinimalityLoss, normalized to be batch-invariant.

Math (per component c, f_c = per-token firing frequency over the global batch):
  imp  = Σ_c f_c                          (bare L_p mean term)
  freq = Σ_c f_c · log2(1 + a' · f_c)      (a' = reference_token_count)

Setting a' = B·T reproduces the old rolled
`imp + beta·Σ_c f_c·log2(1 + B·T·f_c)`, so coefficients transfer as
freq.coeff = old imp.coeff · old beta.

Core:
- importance_minimality.py: finalize_imp_min → bare Σ_c f_c; new finalize_freq_min
  + FrequencyMinimalityLoss/Config (reference_token_count knob); drop beta from
  ImportanceMinimalityLoss; shared autograd-aware global-sum reduce helper.
- Register FrequencyMinimalityLoss in dispatch + AnyLossMetricConfig.

3-pool:
- ThreePoolLosses gains a `freq` field (+ validator requiring freq/imp to share
  pnorm/eps/anneal, since step_ci shares one set of (ci+eps)^p sums).
- step_ci computes both terms from the same all-reduced sums (one extra finalize);
  reductions + total-loss + logging carry loss/freq.

Migrations:
- App circuit-optimization keeps the old rolled form via a local AppImpMinConfig +
  app_importance_minimality_loss (its beta is persisted in the app DB), decoupling
  it from the core signature change.
- In-repo 3-pool/LM/TMS/resid YAMLs: drop imp.beta, add freq block where beta>0
  with reference_token_count = that config's B·T and coeff = imp.coeff·beta.

Tests: batch-invariance, a'=B·T reproduces the old rolled value, f=0→0, and the
3-pool _sparsity_losses matches a single-process reference.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>