State-Centered Temporal Processes#828
Conversation
…v/PyRenew into mem_810_centered_parameterization
|
ran the benchmarks on my machine - here are the results: |
…e time 0) (#827) * bug fix and unit tests * Potential fix for pull request finding Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> * update unit test to match code * revert changes, apply simpler fix --------- Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
for more information, see https://pre-commit.ci
…v/PyRenew into mem_810_centered_parameterization
for more information, see https://pre-commit.ci
…v/PyRenew into mem_810_centered_parameterization
for more information, see https://pre-commit.ci
|
@dylanhmorris @sbidari ready for code review |
…v/PyRenew into mem_810_centered_parameterization
|
Linking epiforecasts/EpiNow2#1396 because it seems to cover a similar reparametrisation |
thanks! good to know there's good precedent. |
|
To simplify this PR, I propose splitting out the The code added/ modified by this PR, exclusive of benchmarks code is:
API-change subtotal: ~1,673 added / ~182 removed. Other changes also picked up on this branch (not in the PR description)
Other-changes subtotal: ~40 added / ~104 removed. Benchmarking component adds roughly 3K lines of code. (oof!)
Totals: 13 source files (~2,346 lines) + 718 lines of tests. The benchmark framework is self-contained under |
Added state-centered parameterizations for all three temporal-process
classes in
pyrenew.latent:AR1— stationary AR(1) on log-Rt levelsDifferencedAR1— AR(1) on first differences of log-Rt (the productionprocess)
RandomWalk— unconstrained drift on log-RtEach class now takes a constructor argument
parameterization: Literal["innovation", "state"], defaulting to"innovation"to preserve current behavior. Setting"state"switchesthe internal sampling from standardized increments to the latent state
path directly.
The state-centered variants are implemented via:
RandomWalk: NumPyro's built-indist.GaussianRandomWalk, shiftedby the initial value.
AR1andDifferencedAR1: two new custom NumPyroDistributionsubclasses (
StateAR1,StateDifferencedAR1) inpyrenew/latent/state_centered_distributions.py. Both have vectorizedlog_probusing slice arithmetic (no scan during MCMC) andlax.scan-basedsample(only called for prior/posterior predictive,not on the MCMC gradient path).
Both parameterizations encode the same prior distribution over the
state path. They differ only in sampler geometry — which latent
variables HMC sees and operates on.
Code added
pyrenew/latent/state_centered_distributions.pyStateAR1,StateDifferencedAR1pyrenew/latent/temporal_processes.pyparameterizationflag on all three classes;_prepare_initial_valuehelpertest/test_temporal_processes.pytest/test_helpers.pyfixed_ar1_state,fixed_differenced_ar1_statefactoriestest/integration/conftest.pyhe_model_state_centered,he_weekly_rt_model_state_centered,he_weekly_model_state_centeredfixturestest/integration/test_population_infections_he_state_centered.pytest/integration/test_population_infections_he_weekly_rt_state_centered.pyWeeklyTemporalProcess_typos.tomlreparametrized_params(NumPyro upstream attribute name)