Skip to content

Normalize functional score loss per variant within each condition#230

Merged
jaredgalloway merged 4 commits into
mainfrom
227-normalize-loss-per-variant
Apr 17, 2026
Merged

Normalize functional score loss per variant within each condition#230
jaredgalloway merged 4 commits into
mainfrom
227-normalize-loss-per-variant

Conversation

@jaredgalloway

Copy link
Copy Markdown
Member

Summary

  • Replace .sum() with .mean() in functional_score_loss so each condition contributes equally to the total loss regardless of variant count
  • This restores the V0.4.0/V1.0.0 design that was lost in the V2 (jaxmodels) rewrite, directly addressing the data-imbalance problem from Add sample-size-weighted fusion regularization #198 at the loss level
  • Add experiments/loss-normalization/ pipeline for validating rescaled hyperparameters against V0.4.0 anchors (fusionreg × l2reg 2D grid)

Changes

Core (multidms/jaxmodels.py):

  • functional_score_loss: .sum().mean() (one-line change)
  • count_loss: added TODO comment noting .sum() is retained by design
  • Convergence trajectory: loss_per_variant_trajectory now divides by n_conditions (not n_variants_total); loss_per_variant_{d} = loss_{d} since loss is already per-variant
  • Removed unused n_variants_total variable

Tests:

  • 4 new tests in TestMeanLossNormalization: known answer, duplicate invariance, gradient scaling, count_loss unchanged
  • Updated test_convergence_trajectory_per_variant_loss to check ratio equals n_conditions
  • Updated test_per_condition_loss_per_variant_normalization to check loss_per_variant_{d} == loss_{d}

Experimental pipeline (experiments/loss-normalization/):

  • Snakefile + config for sweeping fusionreg × l2reg (36 combinations at full grid)
  • fit_models.ipynb — fits the grid via fit_models()
  • evaluate.ipynb — sparsity & correlation diagnostics vs ground truth
  • Config grids anchored to V0.4.0 hyperparameter values

Hyperparameter impact: This is a breaking change for hyperparameter values. The loss scale drops by ~n_variants per condition, so regularization parameters (fusionreg, l2reg) need proportional rescaling. The experimental pipeline validates the V0.4.0-anchored ranges.

Test plan

  • All 182 existing tests pass (including updated trajectory tests)
  • 4 new mean-loss tests pass
  • Ruff lint clean
  • Black format clean
  • Run experiments/loss-normalization/ pipeline on simulation data (remote, post-merge or on branch)
  • Verify rescaled hyperparameter grids produce comparable or better β correlation vs ground truth

Closes #227

🤖 Generated with Claude Code

jaredgalloway and others added 3 commits April 9, 2026 16:35
Replace .sum() with .mean() in functional_score_loss so each condition
contributes equally to the total loss regardless of variant count. This
restores the V0.4.0/V1.0.0 design that was lost in the V2 rewrite.

- Core change: .sum() → .mean() in functional_score_loss (jaxmodels.py)
- Add TODO comment in count_loss noting it retains .sum() by design
- Update convergence trajectory: loss_per_variant_trajectory now divides
  by n_conditions; loss_per_variant_{d} = loss_{d} (already per-variant)
- Remove unused n_variants_total variable
- Add 4 new tests: known answer, duplicate invariance, gradient scaling,
  count_loss unchanged
- Update 2 existing tests for new semantics
- Add experiments/loss-normalization/ pipeline for fusionreg×l2reg grid
  validation against V0.4.0 hyperparameter anchors

Closes #227

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- Add papermill 'parameters' cell tags so config_path/output_dir are
  injected correctly
- Add subsample_frac support preserving wildtype rows for fast test runs
- Fix evaluate notebook: use Model.get_mutations_df() directly, handle
  column name suffixes from merge, make aggregation robust to missing cols

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- fusionreg: V0.4.0 grid [0, 5e-6, ..., 6.4e-4] (was [0, 0.4, ..., 12.8])
- l2reg: 0.0 (was 1e-4; V0.4.0 spike used 0.0)
- beta0_ridge: 1e-4 (was 10; scaled by ~93K avg variants/condition)
- Skip cross-validation for initial run
- Add 'mean_loss' profile to spike Snakefile

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@jaredgalloway

Copy link
Copy Markdown
Member Author

Experimental results

Both experiments ran against the V0.4.0-anchored hyperparameter grid described in the issue. Results confirm that .mean() loss normalization works correctly and that the V0.4.0 hyperparameter scale is appropriate.

1. Simulation experiment (experiments/loss-normalization/)

Setup: 216 models — 9 fusionreg × 4 l2reg × 6 datasets (2 libraries × 3 func_score_types). Production-scale data (50K variants/condition), maxiter=15. Ran on orca02 in ~36 min.

β correlation with ground truth:

l2reg Best fusionreg Pearson r
0 6.4e-4 0.9979
1e-8 6.4e-4 0.9979
1e-7 6.4e-4 0.9978
1e-6 6.4e-4 0.9960

All well above the 0.95 success criterion. l2reg has minimal effect — values 0 through 1e-7 are essentially equivalent, and even 1e-6 only drops by 0.002.

Sparsity recovery: Best match to true sparsity (0.80) at fusionreg=3.2e-4 → sparsity=0.811. The V0.4.0 chosen value of 4e-5 gives sparsity ~0.07 (low), while 8e-5 gives ~0.38. The optimal fusionreg for sparsity is higher than for correlation, as expected.

Conclusion: l2reg=0 is fine for the simulation; fusionreg in the 1e-4 to 3e-4 range balances correlation and sparsity well.


2. Spike experiment (experiments/scv2-spike/, profile mean_loss)

Setup: 18 models — 9 fusionreg × 2 replicates. Full spike data (~280K variants across 3 conditions), maxiter=50, l2reg=0, beta0_ridge=1e-4. Ran on orca01 in ~20 min.

Rescaled config (from V2 .sum().mean()):

Parameter Old (.sum()) New (.mean()) Scaling
fusionreg grid [0, 0.4 ... 12.8] [0, 5e-6 ... 6.4e-4] V0.4.0 grid
l2reg 1e-4 0.0 V0.4.0 spike value
beta0_ridge 10 1e-4 ÷ ~93K avg variants

Beta replicate correlation (higher = more reproducible):

fusionreg Delta BA.1 (ref) BA.2 Average
0 0.516 0.782 0.735 0.677
5e-6 0.690 0.788 0.831 0.770
1e-5 0.720 0.797 0.843 0.787
2e-5 0.750 0.800 0.837 0.796
4e-5 0.770 0.800 0.816 0.795
8e-5 0.780 0.795 0.796 0.790
1.6e-4 0.774 0.775 0.774 0.775
3.2e-4 0.771 0.771 0.770 0.771

Peak average beta correlation at fusionreg=2e-5 (0.796). BA.1 and BA.2 peak at lower fusionreg (1e-5 to 2e-5) while Delta — the data-poorest condition — peaks at 8e-5. This is the expected pattern: the .mean() normalization gives Delta equal fitting weight, so Delta benefits from moderate regularization rather than being dominated by the larger conditions.

Shift replicate correlation:

fusionreg Delta BA.2
4e-5 0.573 0.529
8e-5 0.570 0.621
3.2e-4 0.414 0.819

Delta shifts peak at 4e-5, BA.2 shifts peak at 3.2e-4. This condition-dependent optimal is expected — Delta has fewer true shifts relative to BA.2.

Nonsynonymous shift sparsity (avg across replicates):

fusionreg Delta BA.2
0 4% 6%
4e-5 50% 58%
8e-5 71% 75%
1.6e-4 99% 99%

V0.4.0's chosen value of 4e-5 gives ~50% sparsity — a reasonable operating point for initial analysis.


Summary

  1. .mean() normalization works as designed — conditions contribute equally regardless of variant count, and V0.4.0-scale hyperparameters are appropriate.
  2. l2reg is not critical for simulation data — 0.0 through 1e-7 give equivalent results. Using l2reg=0 (the V0.4.0 spike value) is a safe default.
  3. fusionreg in the 2e-5 to 8e-5 range is the sweet spot for the spike data, with 4e-5 (V0.4.0's chosen value) as a good default.
  4. Delta benefits most from the .mean() normalization — it achieves 0.780 beta replicate correlation (vs being gradient-dominated under .sum()).
  5. Results are in the worktree at .claude/worktrees/issue-227/experiments/ and on orca01/orca02. Production configs should be updated to the rescaled values after this PR merges.

Simulation:
- fusionreg: V0.4.0 grid [0, 5e-6, ..., 6.4e-4]
- l2reg: 0.0 (experiment showed no benefit from nonzero values)
- beta0_ridge: 0.0 (V0.4.0 sim used 0.0)
- lasso_choice: 8e-5 (best sparsity/correlation tradeoff from experiment)

Spike (all profiles):
- fusionreg: V0.4.0 grid [0, 5e-6, ..., 6.4e-4]
- l2reg: 0.0 (V0.4.0 spike value)
- beta0_ridge: 1e-4 (10 / ~93K avg variants)
- lasso_choice: 4e-5 (V0.4.0 chosen value)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@jaredgalloway jaredgalloway merged commit 4383648 into main Apr 17, 2026
6 checks passed
jaredgalloway added a commit that referenced this pull request Apr 17, 2026
Since PR #230, functional_score_loss is already a per-variant .mean(),
so total_loss_{training,validation} on ModelCollection.fit_models are
already per-variant averages. Both CV notebooks were dividing by
sample count again, producing doubly-normalized values that no
longer matched the "mean Huber loss per variant" axis label.

- Drop the n_samples dict and the `loss / n_samples` division in both
  experiments/simulation/notebooks/cross_validation.ipynb and
  experiments/scv2-spike/notebooks/cross_validation.ipynb
- Keep mean_loss as an alias for loss in the emitted CSVs so
  downstream readers (manuscript_figures.ipynb) pick up the fix
  without schema changes
- Delete experiments/scv2-spike/config/config_mean_loss.yaml (its grid
  was already absorbed into the default config.yaml) and drop the
  corresponding profile branch from the spike Snakefile
- Update CLAUDE.md to warn that total_loss_* columns are per-variant

Closes #231

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Normalize functional score loss per variant within each condition

1 participant