Add continuation-path fitting for fusionreg sweeps #234
Merged
Adds fit_models_path(), a sibling to fit_models() that fits a single
multidms.Model sequentially along an ascending fusionreg grid, warm-
starting each step from the previous fit's (β, β0, α). This replaces
the current practice of independent zero-initialized fits, which
produces pathological latent-phenotype ranges for data-poor
conditions (notably Delta in the spike pipeline) at high lasso
strengths.
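The warm-start idea above can be illustrated on a toy problem. The following is a minimal sketch, not the multidms implementation: it uses a plain lasso solved by coordinate descent in pure Python, and the names `fit_lasso`, `fit_path`, `X`, `y`, and the penalty grid are all hypothetical stand-ins. Each step along the ascending penalty grid is seeded with the previous step's solution, and a NaN check stops a bad fit from becoming the next seed.

```python
import math
from typing import List, Sequence, Tuple

Vector = List[float]


def soft_threshold(z: float, t: float) -> float:
    """Proximal operator of t * |z| (the lasso shrinkage step)."""
    if z > t:
        return z - t
    if z < -t:
        return z + t
    return 0.0


def fit_lasso(X: Sequence[Vector], y: Vector, lam: float,
              beta_init: Vector, n_sweeps: int = 200) -> Vector:
    """Coordinate descent for (1/2n)||y - X b||^2 + lam * ||b||_1, started at beta_init."""
    n, p = len(X), len(X[0])
    beta = list(beta_init)
    for _ in range(n_sweeps):
        for j in range(p):
            rho = 0.0
            for i in range(n):
                # Prediction for row i with feature j left out.
                pred = sum(X[i][k] * beta[k] for k in range(p) if k != j)
                rho += X[i][j] * (y[i] - pred)
            rho /= n
            norm_j = sum(X[i][j] ** 2 for i in range(n)) / n
            beta[j] = soft_threshold(rho, lam) / norm_j
    return beta


def fit_path(X: Sequence[Vector], y: Vector,
             lams: Sequence[float]) -> List[Tuple[float, Vector, Vector]]:
    """Fit along an ascending penalty grid; each step is seeded by the previous fit."""
    seed = [0.0] * len(X[0])  # only the first step starts from zeros
    rows = []
    for lam in sorted(lams):
        beta = fit_lasso(X, y, lam, seed)
        if any(math.isnan(b) for b in beta):
            raise ValueError(f"NaN in fit at lam={lam}; refusing to seed the next step")
        rows.append((lam, list(seed), list(beta)))
        seed = beta  # warm start for the next, stronger penalty
    return rows


# Toy data generated from beta = [2, 0] (hypothetical, for illustration only).
X = [[1.0, 0.5], [0.5, 1.0], [1.0, 1.0], [0.0, 1.0]]
y = [2.0, 1.0, 2.0, 0.0]
path = fit_path(X, y, [0.0, 0.5, 2.0])
for lam, seed, beta in path:
    print(f"lam={lam} seed={seed} beta={beta}")
```

In this sketch the seed recorded for step k+1 is exactly the fitted vector from step k, which is the same round-trip property the PR's tests check for beta_init/beta0_init/alpha_init.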
Additions:
- multidms/model_collection.py:
  - fit_models_path(params, verbose): sequential path fit along
    ascending fusionreg. Returns the same (n_fit, n_failed, df)
    schema as fit_models, so the output is a drop-in to
    ModelCollection.
  - _fit_one_path_step(): helper seeding beta_init/beta0_init/
    alpha_init from the previous fit and forcing warmstart=False.
  - _assert_no_nan(): guard that raises ModelCollectionFitError
    if a fit produced NaNs in β, β0, or α; this prevents NaN
    propagation into the next step's seed.
  - concat_path_trajectories(): stitches per-step convergence
    trajectories into a single long DataFrame for visualization.
- experiments/scv2-spike:
  - New notebook fit_models_path.ipynb mirroring fit_models.ipynb
    but calling fit_models_path().
  - Snakefile selects notebook based on spike.fitting.strategy
    ("independent" | "continuation"); output filename unchanged.
  - config.yaml / config_test.yaml / config_experimental.yaml
    gain strategy: "independent" (default; flip to "continuation"
    to exercise the path fitter).
- Tests (tests/test_model_collection.py, TestFitModelsPath +
  TestConcatPathTrajectories):
  - Schema parity with fit_models.
  - Seeding round-trip: beta_init/beta0_init/alpha_init at step
    k+1 equal step k's fitted params.
  - Constant-fusionreg identity (tight tolerance after convergence).
  - Single-step degeneracy (path ≡ fit_models at len-1).
  - Order invariance across datasets.
  - Monotone non-decreasing sparsity along the path.
  - NaN guard raises on bad β, β0, or α (scalar and dict α).
  - Trajectory concatenation: iteration_global monotonicity,
    fusionreg constancy per step, row-count parity.
  - Empty-input handling and non-zero-start warning.
- Docs: CLAUDE.md, experiments/scv2-spike/README.md.
Closes #232
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
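The trajectory stitching described in this commit can be sketched as follows. This is an illustration under stated assumptions, not the body of concat_path_trajectories(): it works on plain lists of dicts rather than a DataFrame, and the input shape (a list of (fusionreg, per-iteration losses) pairs) and the column names `step`, `iteration`, and `loss` are hypothetical. Only `iteration_global` and `fusionreg` are names taken from the commit message.

```python
from typing import Any, Dict, List, Sequence, Tuple

Row = Dict[str, Any]


def concat_trajectories(steps: Sequence[Tuple[float, Sequence[float]]]) -> List[Row]:
    """Stitch per-step loss trajectories into one long table.

    `steps` lists (fusionreg, losses) in path order. Each output row keeps the
    within-step iteration and adds a running iteration_global counter.
    """
    rows: List[Row] = []
    offset = 0  # number of iterations consumed by earlier steps
    for step_idx, (fusionreg, losses) in enumerate(steps):
        for it, loss in enumerate(losses):
            rows.append({
                "step": step_idx,
                "fusionreg": fusionreg,
                "iteration": it,
                "iteration_global": offset + it,
                "loss": loss,
            })
        offset += len(losses)
    return rows


# Two hypothetical steps with 3 and 2 recorded iterations.
long_rows = concat_trajectories([(0.0, [3.0, 2.0, 1.5]), (1e-5, [1.6, 1.55])])
for row in long_rows:
    print(row)
```

The three properties the tests name fall out of the construction: iteration_global increases across step boundaries, fusionreg is constant within a step, and the row count equals the sum of the per-step lengths.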
Three follow-ups from the review of the original commit:
1. Flip spike.fitting.strategy in experiments/scv2-spike/config/config.yaml
(and config_test.yaml) to "continuation". The whole point of this PR
is to replace the default fitter that produces the Delta epistasis
pathology at production fusionreg — leaving the default on the
broken strategy after merging the fix would be backwards.
2. fit_models_path output is now schema-identical to fit_models output:
_fit_one_path_step clears beta_init / beta0_init / alpha_init to
None in the returned row after fitting. Before: those cells held
jax arrays (or dicts of them) pulled from the previous fit, which
broke .apply(str), groupby, and any pandas operation that assumed
scalar cells. The seeds are still fully recoverable from the
previous row's model column, so nothing is lost.
3. The notebook's dict-stringification loop now only stringifies
actual dicts (leaves None and other scalars alone), so a mixed
column of dicts + None no longer coerces everything to strings.
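Item 3 can be sketched in a few lines. The helper name `stringify_if_dict` and the sample column are hypothetical, and the notebook's actual loop may differ; the point is only that the conversion is gated on `isinstance(cell, dict)`, so None and scalar cells pass through unchanged.

```python
from typing import Any, List


def stringify_if_dict(cell: Any) -> Any:
    """Return str(cell) for dict cells; leave None and scalars untouched."""
    if isinstance(cell, dict):
        return str(cell)
    return cell


# A mixed column, as it might look after seed cells were cleared to None.
column: List[Any] = [{"a": 1.0, "b": 2.0}, None, 0.5]
cleaned = [stringify_if_dict(cell) for cell in column]
print(cleaned)
```

A per-cell function like this could be passed to a pandas Series via `.map(...)` instead of calling `.apply(str)` on the whole column, which is what coerces None to the string "None".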
Tests added:
- test_extract_seed_round_trip — direct unit test on _extract_seed
replacing the now-redundant column-read round-trip test.
- test_path_step_rows_have_no_jax_seed_leakage — asserts the three
seed columns hold only None values in path output.
- test_schema_matches_fit_models_exactly — asserts no column of the
path DataFrame holds a dict or jax.Array (the contract that makes
path output a true drop-in to ModelCollection).
- test_verbose_in_params_does_not_collide — regression test for the
setdefault("verbose", verbose) fix in fit_models_path.
195 tests pass (10 → 13 in the TestFitModelsPath class).
Integration: spike-test profile still passes end-to-end with
strategy="continuation" (prepare_data → fit_models_path → evaluate
+ cross_validation). The produced fit_collection.pkl has no jax or
dict leakage in any non-model column.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
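The "no leakage" contract can be checked with a small scan over the output rows. This is a rough sketch on plain dict rows, not the test in the repository: `clear_seeds`, `find_leaks`, and the allow-list of scalar types are assumptions made for illustration, and a real check against a DataFrame would also need to recognise jax arrays explicitly.

```python
from typing import Any, Dict, List, Sequence, Tuple

SEED_COLUMNS = ("beta_init", "beta0_init", "alpha_init")
# Assumed allow-list: anything else in a non-model cell counts as a leak.
SCALAR_TYPES = (type(None), bool, int, float, str)


def clear_seeds(row: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of the row with the three seed cells set to None."""
    return {col: (None if col in SEED_COLUMNS else cell) for col, cell in row.items()}


def find_leaks(rows: Sequence[Dict[str, Any]],
               skip: Tuple[str, ...] = ("model",)) -> List[Tuple[int, str]]:
    """List (row index, column) pairs whose cell is not a plain scalar."""
    return [
        (i, col)
        for i, row in enumerate(rows)
        for col, cell in row.items()
        if col not in skip and not isinstance(cell, SCALAR_TYPES)
    ]


# Hypothetical row as it looked before the fix: seeds still hold containers.
raw_row = {
    "fusionreg": 1e-5,
    "beta_init": [0.1, 0.2],
    "beta0_init": {"Delta": 0.0},
    "alpha_init": 1.0,
    "model": object(),
}
leaks_before = find_leaks([raw_row])
leaks_after = find_leaks([clear_seeds(raw_row)])
print(leaks_before, leaks_after)
```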
The order-invariance test ran two back-to-back fit_models_path calls with a 3-step fusionreg path, which compiles enough JAX kernels on constrained runners (~7GB CI images) to hit "LLVM compilation error: Cannot allocate memory" on the second call. When a step failed from OOM, fit_models_path correctly aborted the remainder of that path, but the test then asserted len(sub_ab) == len(sub_ba) > 0 and got "3 == 1", a misleading failure message.
Three changes:
- Shorten the path to 2 fusionreg values (the invariance claim does not require more length).
- Call jax.clear_caches() between the two calls to release compiled kernels before the second fit starts.
- Assert n_failed == 0 explicitly on both calls so a future OOM shows up as "forward path had N step failure(s)" rather than a length mismatch.
Local pytest still passes (13/13 in TestFitModelsPath + TestConcatPathTrajectories).
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Force-pushed from a70c343 to e54e7f5.
Earlier in this branch, spike/config.yaml was flipped to strategy: "continuation" as the production default. Reverting that flip: production stays on the established independent fitter, and continuation remains opt-in via the strategy key (defaulted to "independent" in the Snakefile when absent from YAML).
The continuation code path itself, schema cleanup, and tests stay; only the prod-default flip is reverted. config_test.yaml still exercises continuation so CI covers the new code path end-to-end.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
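The opt-in behaviour described above (strategy defaults to "independent" when the key is absent) can be sketched as a small lookup. This is a guess at the shape of the logic, not the Snakefile's actual code: the function `select_notebook` and the nested-dict layout of the config are assumptions, while the key path spike.fitting.strategy and the two notebook names come from this PR.

```python
from typing import Any, Mapping

NOTEBOOKS = {
    "independent": "fit_models.ipynb",
    "continuation": "fit_models_path.ipynb",
}


def select_notebook(config: Mapping[str, Any]) -> str:
    """Pick the fitting notebook from spike.fitting.strategy, defaulting to independent."""
    strategy = (
        config.get("spike", {})
        .get("fitting", {})
        .get("strategy", "independent")
    )
    if strategy not in NOTEBOOKS:
        raise ValueError(
            f"unknown spike.fitting.strategy {strategy!r}; expected one of {sorted(NOTEBOOKS)}"
        )
    return NOTEBOOKS[strategy]


print(select_notebook({}))
print(select_notebook({"spike": {"fitting": {"strategy": "continuation"}}}))
```

Rejecting unknown values explicitly is a design choice of this sketch; it keeps a typo in the YAML from silently falling back to the default fitter.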
Summary
- Adds fit_models_path(), a sibling to fit_models() that fits a single multidms.Model sequentially along an ascending fusionreg grid, warm-starting each step from the previous fit's (β, β0, α). Output schema is identical, so it is a drop-in replacement anywhere a fit_collection.pkl is consumed.
- […] spike.fitting.strategy: "independent" | "continuation" config knob into the spike Snakefile, selecting between fit_models.ipynb and the new fit_models_path.ipynb. Default stays independent.
- Adds concat_path_trajectories() for stitching per-step convergence trajectories into a single long DataFrame.

The motivation is the Delta epistasis pathology at high fusionreg in the spike pipeline: under independent zero-initialized fits, a strong shift lasso pulls Delta's latent phenotype below the sigmoid asymptote and never recovers the β0/α calibration. A continuation path, starting at the unregularized solution and tightening the lasso step by step, keeps each fit in the basin of its predecessor so the wildtype positioning follows the shifts as they shrink rather than being jointly distorted.

Validation
- pixi run lint, pixi run fmt-check, and the full pixi run test suite (192 passed, up from 182 on main; 10 new tests added) are green on this branch.
- […] test profile twice locally from a clean results_test/:
  - strategy: "independent" → pipeline completes; fit_collection.pkl produced by fit_models.ipynb.
  - strategy: "continuation" → pipeline completes; fit_collection.pkl produced by fit_models_path.ipynb (verified by inspecting the executed notebook's title and imports). Downstream evaluate and cross_validation rules run unchanged.
- […] fusionreg ∈ [0.0, 4e-5].

What's not in this PR (deferred)
- […] fusionreg=6.4e-4 with strategy: "continuation": that is the criterion from the issue for flipping the production default, and requires a remote-pipeline run + dashboard inspection. File as follow-up once reviewer is happy with the code.
- Flipping strategy to continuation in the production config.yaml: keeping the default conservative until the visual check passes.

Follow-ups
- […] fusionreg grid axes (e.g. per replicate) while keeping each individual path sequential.
- […] experiments/simulation/ and experiments/loss-normalization/ if spike results are compelling.

Test plan
- pixi run lint
- pixi run fmt-check
- pixi run test (192 passed)
- pixi run spike-test with strategy: "independent"
- pixi run spike-test with strategy: "continuation"

Closes #232
🤖 Generated with Claude Code