Normalize functional score loss per variant within each condition#230
Conversation
Replace .sum() with .mean() in functional_score_loss so each condition
contributes equally to the total loss regardless of variant count. This
restores the V0.4.0/V1.0.0 design that was lost in the V2 rewrite.
- Core change: .sum() → .mean() in functional_score_loss (jaxmodels.py)
- Add TODO comment in count_loss noting it retains .sum() by design
- Update convergence trajectory: loss_per_variant_trajectory now divides
by n_conditions; loss_per_variant_{d} = loss_{d} (already per-variant)
- Remove unused n_variants_total variable
- Add 4 new tests: known answer, duplicate invariance, gradient scaling,
count_loss unchanged
- Update 2 existing tests for new semantics
- Add experiments/loss-normalization/ pipeline for fusionreg×l2reg grid
validation against V0.4.0 hyperparameter anchors
Closes #227
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- Add papermill 'parameters' cell tags so config_path/output_dir are injected correctly - Add subsample_frac support preserving wildtype rows for fast test runs - Fix evaluate notebook: use Model.get_mutations_df() directly, handle column name suffixes from merge, make aggregation robust to missing cols Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- fusionreg: V0.4.0 grid [0, 5e-6, ..., 6.4e-4] (was [0, 0.4, ..., 12.8]) - l2reg: 0.0 (was 1e-4; V0.4.0 spike used 0.0) - beta0_ridge: 1e-4 (was 10; scaled by ~93K avg variants/condition) - Skip cross-validation for initial run - Add 'mean_loss' profile to spike Snakefile Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Experimental resultsBoth experiments ran against the V0.4.0-anchored hyperparameter grid described in the issue. Results confirm that 1. Simulation experiment (
|
l2reg |
Best fusionreg |
Pearson r |
|---|---|---|
| 0 | 6.4e-4 | 0.9979 |
| 1e-8 | 6.4e-4 | 0.9979 |
| 1e-7 | 6.4e-4 | 0.9978 |
| 1e-6 | 6.4e-4 | 0.9960 |
All well above the 0.95 success criterion. l2reg has minimal effect — values 0 through 1e-7 are essentially equivalent, and even 1e-6 only drops by 0.002.
Sparsity recovery: Best match to true sparsity (0.80) at fusionreg=3.2e-4 → sparsity=0.811. The V0.4.0 chosen value of 4e-5 gives sparsity ~0.07 (low), while 8e-5 gives ~0.38. The optimal fusionreg for sparsity is higher than for correlation, as expected.
Conclusion: l2reg=0 is fine for the simulation; fusionreg in the 1e-4 to 3e-4 range balances correlation and sparsity well.
2. Spike experiment (experiments/scv2-spike/, profile mean_loss)
Setup: 18 models — 9 fusionreg × 2 replicates. Full spike data (~280K variants across 3 conditions), maxiter=50, l2reg=0, beta0_ridge=1e-4. Ran on orca01 in ~20 min.
Rescaled config (from V2 .sum() → .mean()):
| Parameter | Old (.sum()) |
New (.mean()) |
Scaling |
|---|---|---|---|
fusionreg grid |
[0, 0.4 ... 12.8] | [0, 5e-6 ... 6.4e-4] | V0.4.0 grid |
l2reg |
1e-4 | 0.0 | V0.4.0 spike value |
beta0_ridge |
10 | 1e-4 | ÷ ~93K avg variants |
Beta replicate correlation (higher = more reproducible):
fusionreg |
Delta | BA.1 (ref) | BA.2 | Average |
|---|---|---|---|---|
| 0 | 0.516 | 0.782 | 0.735 | 0.677 |
| 5e-6 | 0.690 | 0.788 | 0.831 | 0.770 |
| 1e-5 | 0.720 | 0.797 | 0.843 | 0.787 |
| 2e-5 | 0.750 | 0.800 | 0.837 | 0.796 |
| 4e-5 | 0.770 | 0.800 | 0.816 | 0.795 |
| 8e-5 | 0.780 | 0.795 | 0.796 | 0.790 |
| 1.6e-4 | 0.774 | 0.775 | 0.774 | 0.775 |
| 3.2e-4 | 0.771 | 0.771 | 0.770 | 0.771 |
Peak average beta correlation at fusionreg=2e-5 (0.796). BA.1 and BA.2 peak at lower fusionreg (1e-5 to 2e-5) while Delta — the data-poorest condition — peaks at 8e-5. This is the expected pattern: the .mean() normalization gives Delta equal fitting weight, so Delta benefits from moderate regularization rather than being dominated by the larger conditions.
Shift replicate correlation:
fusionreg |
Delta | BA.2 |
|---|---|---|
| 4e-5 | 0.573 | 0.529 |
| 8e-5 | 0.570 | 0.621 |
| 3.2e-4 | 0.414 | 0.819 |
Delta shifts peak at 4e-5, BA.2 shifts peak at 3.2e-4. This condition-dependent optimal is expected — Delta has fewer true shifts relative to BA.2.
Nonsynonymous shift sparsity (avg across replicates):
fusionreg |
Delta | BA.2 |
|---|---|---|
| 0 | 4% | 6% |
| 4e-5 | 50% | 58% |
| 8e-5 | 71% | 75% |
| 1.6e-4 | 99% | 99% |
V0.4.0's chosen value of 4e-5 gives ~50% sparsity — a reasonable operating point for initial analysis.
Summary
.mean()normalization works as designed — conditions contribute equally regardless of variant count, and V0.4.0-scale hyperparameters are appropriate.l2regis not critical for simulation data — 0.0 through 1e-7 give equivalent results. Usingl2reg=0(the V0.4.0 spike value) is a safe default.fusionregin the 2e-5 to 8e-5 range is the sweet spot for the spike data, with4e-5(V0.4.0's chosen value) as a good default.- Delta benefits most from the
.mean()normalization — it achieves 0.780 beta replicate correlation (vs being gradient-dominated under.sum()). - Results are in the worktree at
.claude/worktrees/issue-227/experiments/and on orca01/orca02. Production configs should be updated to the rescaled values after this PR merges.
Simulation: - fusionreg: V0.4.0 grid [0, 5e-6, ..., 6.4e-4] - l2reg: 0.0 (experiment showed no benefit from nonzero values) - beta0_ridge: 0.0 (V0.4.0 sim used 0.0) - lasso_choice: 8e-5 (best sparsity/correlation tradeoff from experiment) Spike (all profiles): - fusionreg: V0.4.0 grid [0, 5e-6, ..., 6.4e-4] - l2reg: 0.0 (V0.4.0 spike value) - beta0_ridge: 1e-4 (10 / ~93K avg variants) - lasso_choice: 4e-5 (V0.4.0 chosen value) Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Since PR #230, functional_score_loss is already a per-variant .mean(), so total_loss_{training,validation} on ModelCollection.fit_models are already per-variant averages. Both CV notebooks were dividing by sample count again, producing doubly-normalized values that no longer matched the "mean Huber loss per variant" axis label. - Drop the n_samples dict and the `loss / n_samples` division in both experiments/simulation/notebooks/cross_validation.ipynb and experiments/scv2-spike/notebooks/cross_validation.ipynb - Keep mean_loss as an alias for loss in the emitted CSVs so downstream readers (manuscript_figures.ipynb) pick up the fix without schema changes - Delete experiments/scv2-spike/config/config_mean_loss.yaml (its grid was already absorbed into the default config.yaml) and drop the corresponding profile branch from the spike Snakefile - Update CLAUDE.md to warn that total_loss_* columns are per-variant Closes #231 Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Summary
.sum()with.mean()infunctional_score_lossso each condition contributes equally to the total loss regardless of variant countjaxmodels) rewrite, directly addressing the data-imbalance problem from Add sample-size-weighted fusion regularization #198 at the loss levelexperiments/loss-normalization/pipeline for validating rescaled hyperparameters against V0.4.0 anchors (fusionreg × l2reg 2D grid)Changes
Core (
multidms/jaxmodels.py):functional_score_loss:.sum()→.mean()(one-line change)count_loss: added TODO comment noting.sum()is retained by designloss_per_variant_trajectorynow divides byn_conditions(notn_variants_total);loss_per_variant_{d}=loss_{d}since loss is already per-variantn_variants_totalvariableTests:
TestMeanLossNormalization: known answer, duplicate invariance, gradient scaling, count_loss unchangedtest_convergence_trajectory_per_variant_lossto check ratio equalsn_conditionstest_per_condition_loss_per_variant_normalizationto checkloss_per_variant_{d}==loss_{d}Experimental pipeline (
experiments/loss-normalization/):fit_models.ipynb— fits the grid viafit_models()evaluate.ipynb— sparsity & correlation diagnostics vs ground truthHyperparameter impact: This is a breaking change for hyperparameter values. The loss scale drops by ~n_variants per condition, so regularization parameters (fusionreg, l2reg) need proportional rescaling. The experimental pipeline validates the V0.4.0-anchored ranges.
Test plan
experiments/loss-normalization/pipeline on simulation data (remote, post-merge or on branch)Closes #227
🤖 Generated with Claude Code