Skip to content

Add continuation-path fitting for fusionreg sweeps#234

Merged
jaredgalloway merged 4 commits into
mainfrom
232-fit-models-path
May 6, 2026
Merged

Add continuation-path fitting for fusionreg sweeps#234
jaredgalloway merged 4 commits into
mainfrom
232-fit-models-path

Conversation

@jaredgalloway

@jaredgalloway jaredgalloway commented Apr 17, 2026

Copy link
Copy Markdown
Member

Summary

  • Adds fit_models_path(), a sibling to fit_models() that fits a single multidms.Model sequentially along an ascending fusionreg grid, warm-starting each step from the previous fit's (β, β0, α). Output schema is identical, so it is a drop-in replacement anywhere a fit_collection.pkl is consumed.
  • Wires a spike.fitting.strategy: "independent" | "continuation" config knob into the spike Snakefile, selecting between fit_models.ipynb and the new fit_models_path.ipynb. Default stays independent.
  • Adds concat_path_trajectories() for stitching per-step convergence trajectories into a single long DataFrame.

The motivation is the Delta epistasis pathology at high fusionreg in the spike pipeline: under independent zero-initialized fits, a strong shift lasso pulls Delta's latent phenotype below the sigmoid asymptote and never recovers the β0 / α calibration. A continuation path, starting at the unregularized solution and tightening the lasso step by step, keeps each fit in the basin of its predecessor so the wildtype positioning follows the shifts as they shrink rather than being jointly distorted.

Validation

  • pixi run lint, pixi run fmt-check, and the full pixi run test suite (192 passed, up from 182 on main; 10 new tests added) are green on this branch.
  • Integration test: ran the spike test profile twice locally from a clean results_test/:
    • strategy: "independent" → pipeline completes; fit_collection.pkl produced by fit_models.ipynb.
    • strategy: "continuation" → pipeline completes; fit_collection.pkl produced by fit_models_path.ipynb (verified by inspecting the executed notebook's title and imports). Downstream evaluate and cross_validation rules run unchanged.
  • Spot-check monotonicity: Delta sparsity at the final trajectory step, per replicate, is non-decreasing along the path — rep_1: [0.129, 0.149], rep_2: [0.140, 0.170] at fusionreg ∈ [0.0, 4e-5].

What's not in this PR (deferred)

  • Human visual-validation on the dashboard that the Delta epistasis pathology actually disappears at production fusionreg=6.4e-4 with strategy: "continuation" — that is the criterion from the issue for flipping the production default, and requires a remote-pipeline run + dashboard inspection. File as follow-up once reviewer is happy with the code.
  • Flipping strategy to continuation in the production config.yaml — keeping the default conservative until the visual check passes.
  • Cross-validation loss comparison (continuation vs independent at best fusionreg) — same rationale; wait for the remote run.

Follow-ups

  • Parallelize across non-fusionreg grid axes (e.g. per replicate) while keeping each individual path sequential.
  • Apply the same strategy in experiments/simulation/ and experiments/loss-normalization/ if spike results are compelling.

Test plan

  • pixi run lint
  • pixi run fmt-check
  • pixi run test (192 passed)
  • pixi run spike-test with strategy: "independent"
  • pixi run spike-test with strategy: "continuation"
  • CI rerun across Python 3.9 / 3.10 / 3.11, macos + ubuntu
  • Manual dashboard check on remote-pipeline continuation fit (follow-up)

Closes #232

🤖 Generated with Claude Code

jaredgalloway and others added 3 commits April 21, 2026 11:09
Adds fit_models_path(), a sibling to fit_models() that fits a single
multidms.Model sequentially along an ascending fusionreg grid, warm-
starting each step from the previous fit's (β, β0, α). This replaces
the current practice of independent zero-initialized fits, which
produces pathological latent-phenotype ranges for data-poor
conditions (notably Delta in the spike pipeline) at high lasso
strengths.

Additions:
- multidms/model_collection.py:
  - fit_models_path(params, verbose): sequential path fit along
    ascending fusionreg. Returns the same (n_fit, n_failed, df)
    schema as fit_models, so the output is a drop-in to
    ModelCollection.
  - _fit_one_path_step(): helper seeding beta_init/beta0_init/
    alpha_init from the previous fit and forcing warmstart=False.
  - _assert_no_nan(): guard that raises ModelCollectionFitError
    if a fit produced NaNs in β, β0, or α — prevents NaN
    propagation into the next step's seed.
  - concat_path_trajectories(): stitches per-step convergence
    trajectories into a single long DataFrame for visualization.
- experiments/scv2-spike:
  - New notebook fit_models_path.ipynb mirroring fit_models.ipynb
    but calling fit_models_path().
  - Snakefile selects notebook based on spike.fitting.strategy
    ("independent" | "continuation"); output filename unchanged.
  - config.yaml / config_test.yaml / config_experimental.yaml
    gain strategy: "independent" (default; flip to "continuation"
    to exercise the path fitter).
- Tests (tests/test_model_collection.py, TestFitModelsPath +
  TestConcatPathTrajectories):
  - Schema parity with fit_models.
  - Seeding round-trip: beta_init/beta0_init/alpha_init at step
    k+1 equal step k's fitted params.
  - Constant-fusionreg identity (tight tolerance after convergence).
  - Single-step degeneracy (path ≡ fit_models at len-1).
  - Order invariance across datasets.
  - Monotone non-decreasing sparsity along the path.
  - NaN guard raises on bad β, β0, or α (scalar and dict α).
  - Trajectory concatenation: iteration_global monotonicity,
    fusionreg constancy per step, row-count parity.
  - Empty-input handling and non-zero-start warning.
- Docs: CLAUDE.md, experiments/scv2-spike/README.md.

Closes #232

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Three follow-ups from the review of the original commit:

1. Flip spike.fitting.strategy in experiments/scv2-spike/config/config.yaml
   (and config_test.yaml) to "continuation". The whole point of this PR
   is to replace the default fitter that produces the Delta epistasis
   pathology at production fusionreg — leaving the default on the
   broken strategy after merging the fix would be backwards.

2. fit_models_path output is now schema-identical to fit_models output:
   _fit_one_path_step clears beta_init / beta0_init / alpha_init to
   None in the returned row after fitting. Before: those cells held
   jax arrays (or dicts of them) pulled from the previous fit, which
   broke .apply(str), groupby, and any pandas operation that assumed
   scalar cells. The seeds are still fully recoverable from the
   previous row's model column, so nothing is lost.

3. The notebook's dict-stringification loop now only stringifies
   actual dicts (leaves None and other scalars alone), so a mixed
   column of dicts + None no longer coerces everything to strings.

Tests added:
- test_extract_seed_round_trip — direct unit test on _extract_seed
  replacing the now-redundant column-read round-trip test.
- test_path_step_rows_have_no_jax_seed_leakage — asserts the three
  seed columns hold only None values in path output.
- test_schema_matches_fit_models_exactly — asserts no column of the
  path DataFrame holds a dict or jax.Array (the contract that makes
  path output a true drop-in to ModelCollection).
- test_verbose_in_params_does_not_collide — regression test for the
  setdefault("verbose", verbose) fix in fit_models_path.

195 tests pass (10 → 13 in the TestFitModelsPath class).

Integration: spike-test profile still passes end-to-end with
strategy="continuation" (prepare_data → fit_models_path → evaluate
+ cross_validation). The produced fit_collection.pkl has no jax or
dict leakage in any non-model column.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The test ran two back-to-back fit_models_path calls with a 3-step
fusionreg path, which compiles enough JAX kernels on constrained
runners (~7GB CI images) to hit "LLVM compilation error: Cannot
allocate memory" on the second call. When a step failed from OOM,
fit_models_path correctly aborted the remainder of that path, but
the test then asserted len(sub_ab) == len(sub_ba) > 0 and got
"3 == 1" — a misleading failure message.

Two changes:
- Shorten the path to 2 fusionreg values (the invariance claim does
  not require more length).
- jax.clear_caches() between the two calls to release compiled
  kernels before the second fit starts.
- Assert n_failed == 0 explicitly on both calls so a future OOM
  shows up as "forward path had N step failure(s)" rather than
  length mismatch.

Local pytest still passes (13/13 in TestFitModelsPath +
TestConcatPathTrajectories).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Earlier in this branch, spike/config.yaml was flipped to
strategy: "continuation" as the production default. Reverting that flip:
production stays on the established independent fitter, and continuation
remains opt-in via the strategy key (defaulted to "independent" in the
Snakefile when absent from YAML).

The continuation code path itself, schema cleanup, and tests stay; only
the prod-default flip is reverted. config_test.yaml still exercises
continuation so CI covers the new code path end-to-end.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@jaredgalloway jaredgalloway merged commit d82657e into main May 6, 2026
6 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Add continuation-path fitting for fusionreg sweeps

1 participant