
State-Centered Temporal Processes #828

Open
cdc-mitzimorris wants to merge 79 commits into main from mem_810_centered_parameterization


@cdc-mitzimorris
Collaborator

Added state-centered parameterizations for all three temporal-process
classes in pyrenew.latent:

  • AR1 — stationary AR(1) on log-Rt levels
  • DifferencedAR1 — AR(1) on first differences of log-Rt (the production
    process)
  • RandomWalk — unconstrained drift on log-Rt

Each class now takes a constructor argument
parameterization: Literal["innovation", "state"], defaulting to
"innovation" to preserve current behavior. Setting "state" switches
the internal sampling from standardized increments to the latent state
path directly.
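
To make the two views concrete, here is a minimal NumPy sketch of how standardized increments map to a state path for each of the three processes, and how the AR(1) map can be inverted. The function names, the zero-mean AR(1) form, and the convention that the returned path excludes the initial value are assumptions made for illustration; they are not taken from the PR's code.

```python
import numpy as np


def random_walk_path(z, sd, x0):
    """Random walk: x_t = x_{t-1} + sd * z_t, starting from x0."""
    return x0 + sd * np.cumsum(z)


def ar1_path(z, sd, ar, x0):
    """AR(1): x_t = ar * x_{t-1} + sd * z_t, starting from x0."""
    x = np.empty(len(z))
    prev = x0
    for t, z_t in enumerate(z):
        prev = ar * prev + sd * z_t
        x[t] = prev
    return x


def differenced_ar1_path(z, sd, ar, x0, d0):
    """AR(1) on first differences d_t = x_t - x_{t-1}, starting from d0."""
    d = ar1_path(z, sd, ar, d0)
    return x0 + np.cumsum(d)


def ar1_innovations(x, sd, ar, x0):
    """Invert ar1_path: recover the standardized increments from the path."""
    prev = np.concatenate(([x0], x[:-1]))
    return (x - ar * prev) / sd


rng = np.random.default_rng(0)
z = rng.standard_normal(8)                 # what an "innovation" sampler sees
x = ar1_path(z, sd=0.1, ar=0.5, x0=0.0)    # what a "state" sampler sees
assert np.allclose(ar1_innovations(x, sd=0.1, ar=0.5, x0=0.0), z)
```

The map between `z` and `x` is one-to-one, so a sampler can work in either coordinate system; only the geometry it explores changes.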

The state-centered variants are implemented via:

  • For RandomWalk: NumPyro's built-in dist.GaussianRandomWalk, shifted
    by the initial value.
  • For AR1 and DifferencedAR1: two new custom NumPyro Distribution
    subclasses (StateAR1, StateDifferencedAR1) in
    pyrenew/latent/state_centered_distributions.py. Both have vectorized
    log_prob using slice arithmetic (no scan during MCMC) and
    lax.scan-based sample (only called for prior/posterior predictive,
    not on the MCMC gradient path).
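
The "slice arithmetic" idea can be sketched as follows. This is an illustration in plain NumPy, not the PR's StateAR1 code: it assumes a zero-mean AR(1) whose first state is drawn from the stationary distribution, whereas the actual classes may condition on a supplied initial value. The point is only that every transition term can be evaluated at once from `x[1:]` and `x[:-1]`, with no sequential loop; the same slicing works with `jax.numpy` arrays.

```python
import numpy as np


def normal_logpdf(x, loc, scale):
    """Elementwise log density of Normal(loc, scale)."""
    return (
        -0.5 * np.log(2.0 * np.pi)
        - np.log(scale)
        - 0.5 * ((x - loc) / scale) ** 2
    )


def ar1_log_prob_vectorized(x, ar, sd):
    """Joint log density of a zero-mean stationary AR(1) path (assumed form)."""
    sd_stationary = sd / np.sqrt(1.0 - ar**2)
    first = normal_logpdf(x[0], 0.0, sd_stationary)
    # All transitions at once: x[1:] given x[:-1].
    transitions = normal_logpdf(x[1:], ar * x[:-1], sd)
    return first + transitions.sum()


def ar1_log_prob_loop(x, ar, sd):
    """Reference implementation with an explicit loop over time steps."""
    total = normal_logpdf(x[0], 0.0, sd / np.sqrt(1.0 - ar**2))
    for t in range(1, len(x)):
        total += normal_logpdf(x[t], ar * x[t - 1], sd)
    return total


x = np.array([0.02, 0.01, -0.03, 0.00, 0.04])
assert np.isclose(
    ar1_log_prob_vectorized(x, ar=0.9, sd=0.01),
    ar1_log_prob_loop(x, ar=0.9, sd=0.01),
)
```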

Both parameterizations encode the same prior distribution over the
state path. They differ only in sampler geometry — which latent
variables HMC sees and operates on.
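
One way to see the equivalence, sketched for an AR(1) with a fixed initial value (the symbols \phi, \sigma, z_t, and x_0 are illustrative notation, not names from the PR): the map from standardized increments to states is linear and lower-triangular with \sigma on the diagonal, so the change of variables contributes only the constant factor \sigma^{-T}, which is exactly the normalization of the state-space transition densities. Setting \phi = 1 gives the random-walk case.

```latex
x_t = \phi\, x_{t-1} + \sigma z_t, \qquad z_t \sim \mathcal{N}(0, 1), \qquad t = 1, \dots, T

\frac{\partial x_t}{\partial z_s} =
\begin{cases}
  \sigma\, \phi^{\,t-s} & s \le t \\
  0 & s > t
\end{cases}
\quad\Longrightarrow\quad
\left| \det \frac{\partial x_{1:T}}{\partial z_{1:T}} \right| = \sigma^{T}

p(x_{1:T} \mid x_0)
  = \sigma^{-T} \prod_{t=1}^{T} \mathcal{N}\!\left( \frac{x_t - \phi x_{t-1}}{\sigma} \,\Big|\, 0, 1 \right)
  = \prod_{t=1}^{T} \mathcal{N}\!\left( x_t \mid \phi x_{t-1}, \sigma^2 \right)
```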

Code added

| File | Type | Purpose |
| --- | --- | --- |
| pyrenew/latent/state_centered_distributions.py | new | StateAR1, StateDifferencedAR1 |
| pyrenew/latent/temporal_processes.py | modified | parameterization flag on all three classes; _prepare_initial_value helper |
| test/test_temporal_processes.py | modified | +31 unit tests (parameterization flag, state-centered shape/site/prior-equivalence) |
| test/test_helpers.py | modified | fixed_ar1_state, fixed_differenced_ar1_state factories |
| test/integration/conftest.py | modified | he_model_state_centered, he_weekly_rt_model_state_centered, he_weekly_model_state_centered fixtures |
| test/integration/test_population_infections_he_state_centered.py | new | 5 end-to-end tests, daily Rt |
| test/integration/test_population_infections_he_weekly_rt_state_centered.py | new | 5 end-to-end tests, weekly Rt via WeeklyTemporalProcess |
| _typos.toml | modified | Whitelist reparametrized_params (NumPyro upstream attribute name) |

@cdc-mitzimorris
Collaborator Author

I ran the benchmarks on my machine; here are the results:

time python -m benchmarks.suites.rt_params --candidate he --prior both --repeats 3
rt_params suite: 4 candidate(s) x 2 prior(s) x 3 repeat(s) = 24 fits
>> fitting he_daily_innovation@sd=0.01,ar=0.9 (repeat 1/3) ...
   done he_daily_innovation@sd=0.01,ar=0.9 (repeat 1/3): 62.9s, divergences=0, min ESS/s=0.15
>> fitting he_daily_innovation@sd=0.01,ar=0.9 (repeat 2/3) ...
   done he_daily_innovation@sd=0.01,ar=0.9 (repeat 2/3): 66.4s, divergences=0, min ESS/s=0.25
>> fitting he_daily_innovation@sd=0.01,ar=0.9 (repeat 3/3) ...
   done he_daily_innovation@sd=0.01,ar=0.9 (repeat 3/3): 68.4s, divergences=0, min ESS/s=0.14
>> fitting he_daily_state@sd=0.01,ar=0.9 (repeat 1/3) ...
   done he_daily_state@sd=0.01,ar=0.9 (repeat 1/3): 63.5s, divergences=0, min ESS/s=5.79
>> fitting he_daily_state@sd=0.01,ar=0.9 (repeat 2/3) ...
   done he_daily_state@sd=0.01,ar=0.9 (repeat 2/3): 62.7s, divergences=0, min ESS/s=5.59
>> fitting he_daily_state@sd=0.01,ar=0.9 (repeat 3/3) ...
   done he_daily_state@sd=0.01,ar=0.9 (repeat 3/3): 63.4s, divergences=0, min ESS/s=6.61
>> fitting he_weekly_innovation@sd=0.01,ar=0.9 (repeat 1/3) ...
   done he_weekly_innovation@sd=0.01,ar=0.9 (repeat 1/3): 68.9s, divergences=0, min ESS/s=1.05
>> fitting he_weekly_innovation@sd=0.01,ar=0.9 (repeat 2/3) ...
   done he_weekly_innovation@sd=0.01,ar=0.9 (repeat 2/3): 69.3s, divergences=0, min ESS/s=0.12
>> fitting he_weekly_innovation@sd=0.01,ar=0.9 (repeat 3/3) ...
   done he_weekly_innovation@sd=0.01,ar=0.9 (repeat 3/3): 70.7s, divergences=0, min ESS/s=0.47
>> fitting he_weekly_state@sd=0.01,ar=0.9 (repeat 1/3) ...
   done he_weekly_state@sd=0.01,ar=0.9 (repeat 1/3): 17.9s, divergences=0, min ESS/s=28.92
>> fitting he_weekly_state@sd=0.01,ar=0.9 (repeat 2/3) ...
   done he_weekly_state@sd=0.01,ar=0.9 (repeat 2/3): 16.6s, divergences=0, min ESS/s=30.93
>> fitting he_weekly_state@sd=0.01,ar=0.9 (repeat 3/3) ...
   done he_weekly_state@sd=0.01,ar=0.9 (repeat 3/3): 16.8s, divergences=0, min ESS/s=32.74
>> fitting he_daily_innovation@sd=0.1,ar=0.5 (repeat 1/3) ...
   done he_daily_innovation@sd=0.1,ar=0.5 (repeat 1/3): 79.7s, divergences=0, min ESS/s=0.03
>> fitting he_daily_innovation@sd=0.1,ar=0.5 (repeat 2/3) ...
   done he_daily_innovation@sd=0.1,ar=0.5 (repeat 2/3): 79.4s, divergences=0, min ESS/s=0.03
>> fitting he_daily_innovation@sd=0.1,ar=0.5 (repeat 3/3) ...
   done he_daily_innovation@sd=0.1,ar=0.5 (repeat 3/3): 80.2s, divergences=0, min ESS/s=0.03
>> fitting he_daily_state@sd=0.1,ar=0.5 (repeat 1/3) ...
   done he_daily_state@sd=0.1,ar=0.5 (repeat 1/3): 30.0s, divergences=0, min ESS/s=10.49
>> fitting he_daily_state@sd=0.1,ar=0.5 (repeat 2/3) ...
   done he_daily_state@sd=0.1,ar=0.5 (repeat 2/3): 31.4s, divergences=0, min ESS/s=9.56
>> fitting he_daily_state@sd=0.1,ar=0.5 (repeat 3/3) ...
   done he_daily_state@sd=0.1,ar=0.5 (repeat 3/3): 29.4s, divergences=0, min ESS/s=10.88
>> fitting he_weekly_innovation@sd=0.1,ar=0.5 (repeat 1/3) ...
   done he_weekly_innovation@sd=0.1,ar=0.5 (repeat 1/3): 72.2s, divergences=0, min ESS/s=0.03
>> fitting he_weekly_innovation@sd=0.1,ar=0.5 (repeat 2/3) ...
   done he_weekly_innovation@sd=0.1,ar=0.5 (repeat 2/3): 72.8s, divergences=0, min ESS/s=0.04
>> fitting he_weekly_innovation@sd=0.1,ar=0.5 (repeat 3/3) ...
   done he_weekly_innovation@sd=0.1,ar=0.5 (repeat 3/3): 73.8s, divergences=0, min ESS/s=0.04
>> fitting he_weekly_state@sd=0.1,ar=0.5 (repeat 1/3) ...
   done he_weekly_state@sd=0.1,ar=0.5 (repeat 1/3): 22.6s, divergences=0, min ESS/s=42.72
>> fitting he_weekly_state@sd=0.1,ar=0.5 (repeat 2/3) ...
   done he_weekly_state@sd=0.1,ar=0.5 (repeat 2/3): 22.3s, divergences=0, min ESS/s=44.03
>> fitting he_weekly_state@sd=0.1,ar=0.5 (repeat 3/3) ...
   done he_weekly_state@sd=0.1,ar=0.5 (repeat 3/3): 22.3s, divergences=0, min ESS/s=52.17

--- synthetic_he_weekly_hospital | cadence=daily | innovation_sd=0.01 ---
metric                   innovation        state  state/innov
--------------------------------------------------------------
Wall time (s)                  65.9         63.2        0.96x
ESS/s Rt (median)             0.748       27.329     36.53x *
ESS/s Rt (min)                0.183        5.997     32.78x *
Divergences                       0            0          n/a
Tree depth (mean)             10.00         9.91        0.99x
Tree depth (max)                 10           10        1.00x
E-BFMI (min)                  0.888        0.943      1.06x *
R-hat Rt (max)                1.275        1.006      0.79x *

--- synthetic_he_weekly_hospital | cadence=weekly | innovation_sd=0.01 ---
metric                   innovation        state  state/innov
--------------------------------------------------------------
Wall time (s)                  69.7         17.1      0.25x *
ESS/s Rt (median)             1.856       98.404     53.02x *
ESS/s Rt (min)                0.546       30.864     56.49x *
Divergences                       0            0          n/a
Tree depth (mean)             10.00         7.17      0.72x *
Tree depth (max)                 10            9      0.90x *
E-BFMI (min)                  0.896        0.925        1.03x
R-hat Rt (max)                1.150        1.005      0.87x *

--- synthetic_he_weekly_hospital | cadence=daily | innovation_sd=0.1 ---
metric                   innovation        state  state/innov
--------------------------------------------------------------
Wall time (s)                  79.8         30.2      0.38x *
ESS/s Rt (median)             0.078       72.302    928.36x *
ESS/s Rt (min)                0.032       10.311    322.37x *
Divergences                       0            0          n/a
Tree depth (mean)             10.00         8.06      0.81x *
Tree depth (max)                 10           10        1.00x
E-BFMI (min)                  0.901        0.920        1.02x
R-hat Rt (max)                2.350        1.014      0.43x *

--- synthetic_he_weekly_hospital | cadence=weekly | innovation_sd=0.1 ---
metric                   innovation        state  state/innov
--------------------------------------------------------------
Wall time (s)                  73.0         22.4      0.31x *
ESS/s Rt (median)             0.098       75.878    772.58x *
ESS/s Rt (min)                0.038       46.309   1226.16x *
Divergences                       0            0          n/a
Tree depth (mean)             10.00         7.58      0.76x *
Tree depth (max)                 10            9      0.90x *
E-BFMI (min)                  0.980        0.941        0.96x
R-hat Rt (max)                2.165        1.004      0.46x *

(* marks an improvement over innovation; ratios are state / innovation)

Wrote results to benchmarks/results

real	21m15.997s
user	80m23.152s
sys	0m7.630s
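
As a cross-check, the summary rows can be reproduced from the per-repeat log lines. The sketch below uses only the standard library and the six `done` lines for the daily, sd=0.01 configuration copied from the log above; the regular expression is an assumption about the log format based on those lines, and treating the summary value as the mean across repeats is an inference from the numbers, not something stated in the output.

```python
import re
from statistics import mean

log = """\
   done he_daily_innovation@sd=0.01,ar=0.9 (repeat 1/3): 62.9s, divergences=0, min ESS/s=0.15
   done he_daily_innovation@sd=0.01,ar=0.9 (repeat 2/3): 66.4s, divergences=0, min ESS/s=0.25
   done he_daily_innovation@sd=0.01,ar=0.9 (repeat 3/3): 68.4s, divergences=0, min ESS/s=0.14
   done he_daily_state@sd=0.01,ar=0.9 (repeat 1/3): 63.5s, divergences=0, min ESS/s=5.79
   done he_daily_state@sd=0.01,ar=0.9 (repeat 2/3): 62.7s, divergences=0, min ESS/s=5.59
   done he_daily_state@sd=0.01,ar=0.9 (repeat 3/3): 63.4s, divergences=0, min ESS/s=6.61
"""

# Assumed line format: "done <name> (repeat i/n): <secs>s, divergences=<k>, min ESS/s=<v>"
pattern = re.compile(
    r"done (\S+) \(repeat \d+/\d+\): ([\d.]+)s, divergences=(\d+), min ESS/s=([\d.]+)"
)

# Collect wall times per candidate name.
wall = {}
for name, seconds, _divergences, _ess in pattern.findall(log):
    wall.setdefault(name, []).append(float(seconds))

means = {name: mean(values) for name, values in wall.items()}
innov = means["he_daily_innovation@sd=0.01,ar=0.9"]
state = means["he_daily_state@sd=0.01,ar=0.9"]

# Ratio follows the table's convention: state / innovation.
print(f"{innov:.1f} {state:.1f} {state / innov:.2f}x")  # prints: 65.9 63.2 0.96x
```

These values match the "Wall time (s)" row of the first summary table (65.9, 63.2, 0.96x).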

cdc-mitzimorris and others added 21 commits May 19, 2026 17:47
…e time 0) (#827)

* bug fix and unit tests

* Potential fix for pull request finding

Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>

* update unit test to match code

* revert changes, apply simpler fix

---------

Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
…v/PyRenew into mem_810_centered_parameterization
…v/PyRenew into mem_810_centered_parameterization
@cdc-mitzimorris
Collaborator Author

@dylanhmorris @sbidari ready for code review

@SamuelBrand1
Collaborator

Linking epiforecasts/EpiNow2#1396 because it seems to cover a similar reparametrisation

@cdc-mitzimorris
Collaborator Author

> Linking epiforecasts/EpiNow2#1396 because it seems to cover a similar reparametrisation

Thanks! Good to know there's precedent.

@cdc-mitzimorris
Collaborator Author

To simplify this PR, I propose splitting out the benchmarks code (~2,346 lines of source plus a 718-line test file) so that it can be reviewed separately from the API change.

The code added or modified by this PR, excluding the benchmarks code, is:

| File | Type | +/− | Purpose |
| --- | --- | --- | --- |
| pyrenew/latent/state_centered_distributions.py | new | +387 | StateAR1, StateDifferencedAR1 custom NumPyro distributions |
| pyrenew/latent/temporal_processes.py | modified | +207 / −71 | parameterization flag on AR1, DifferencedAR1, RandomWalk; _prepare_initial_value helper |
| test/test_temporal_processes.py | modified | +493 | Unit tests for parameterization flag, state-centered shape/site/prior-equivalence |
| test/test_helpers.py | modified | +46 | fixed_ar1_state, fixed_differenced_ar1_state factories |
| test/integration/conftest.py | modified | +122 / −111 | he_model_state_centered, he_weekly_rt_model_state_centered, he_weekly_model_state_centered fixtures |
| test/integration/test_population_infections_he_state_centered.py | new | +185 | 5 end-to-end tests, daily Rt |
| test/integration/test_population_infections_he_weekly_rt_state_centered.py | new | +227 | 5 end-to-end tests, weekly Rt via WeeklyTemporalProcess |
| _typos.toml | modified | +6 | Whitelist reparametrized_params (NumPyro upstream attribute) |

API-change subtotal: ~1,673 added / ~182 removed.

Other changes also picked up on this branch (not in the PR description)

| File | Type | +/− | Notes |
| --- | --- | --- | --- |
| .gitignore | modified | +3 | |
| pyproject.toml | modified | +2 | |
| docs/tutorials/_quarto.yml | modified | +1 / −1 | |
| docs_scripts/add_markdown_to_divs.py | new | +32 | Replacement for the removed postprocessing script |
| docs_scripts/postprocess_generated_markdown.py | removed | −62 | Replaced by add_markdown_to_divs.py |
| test/test_docs_postprocessing.py | removed | −39 | Tests for the removed script |
| test/test_distributional_rv.py | modified | +2 / −2 | Minor tweak |

Other-changes subtotal: ~40 added / ~104 removed.

The benchmarking component adds roughly 3K lines of code (oof!):

| File | Type | Lines | Purpose |
| --- | --- | --- | --- |
| benchmarks/__init__.py | new | 11 | Package docstring; entry-point conventions |
| benchmarks/README.md | new | 177 | How to run suites, layout, extension points |
| benchmarks/core/__init__.py | new | 1 | Package marker |
| benchmarks/core/signals.py | new | 104 | DatasetProvider protocol: seam between suites and data source |
| benchmarks/core/datasets.py | new | 101 | Synthetic DatasetProvider wrapping pyrenew/datasets/ |
| benchmarks/core/real_data.py | new | 245 | Real-data DatasetProvider (CDC NHSN + NSSP feeds) |
| benchmarks/core/reference_data.py | new | 85 | Static US location/population table (replaces R forecasttools dep) |
| benchmarks/core/priors.py | new | 42 | Benchmark-local priors mirroring HEW production subset |
| benchmarks/core/models.py | new | 295 | build_* model builders pairing a dataset with a model family |
| benchmarks/core/runner.py | new | 275 | Single-fit MCMC wrapper; collects timing + diagnostic metrics |
| benchmarks/core/reporting.py | new | 610 | CSV/JSON/Markdown reporters; pair comparisons, candidate summaries |
| benchmarks/suites/__init__.py | new | 1 | Package marker |
| benchmarks/suites/rt_params.py | new | 399 | The rt_params suite: compares innovation vs state on weekly Rt |
| test/test_benchmarks_rt_params.py | new | 718 | Unit tests for the suite, runner, reporting, and providers |

Totals: 13 source files (~2,346 lines) + 718 lines of tests. The benchmark framework is self-contained under benchmarks/; the only PR change outside that tree is the test file above.

@dylanhmorris, @sbidari
