Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
81 commits
Select commit Hold shift + click to select a range
680bb1e
Merge branch 'main' of https://github.com/CDCgov/PyRenew
cdc-mitzimorris Sep 15, 2025
2cb876b
update
cdc-mitzimorris Sep 18, 2025
60db8df
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Sep 22, 2025
32a5314
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Oct 5, 2025
d6213f2
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Oct 8, 2025
96f27c9
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Nov 17, 2025
1cb6fa2
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Nov 24, 2025
f62e1e4
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Dec 4, 2025
0c6785d
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Dec 22, 2025
1ee62b9
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Jan 29, 2026
0629461
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Feb 4, 2026
efeadee
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Feb 5, 2026
371ba98
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Feb 5, 2026
0304bed
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Feb 6, 2026
ffeea65
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Feb 9, 2026
50e7261
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Feb 9, 2026
dae6af8
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Feb 10, 2026
5cb3097
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Feb 11, 2026
1d80ccc
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Feb 11, 2026
e73b401
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Feb 12, 2026
b1473b5
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Feb 18, 2026
0b929b5
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Feb 18, 2026
3ee00a7
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Feb 24, 2026
307982a
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Feb 24, 2026
b862bc6
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Feb 26, 2026
2c665a5
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Mar 11, 2026
60d6458
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Mar 12, 2026
ec8c464
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Mar 19, 2026
c018bf7
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Mar 24, 2026
d0207dd
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Apr 4, 2026
f3c706a
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Apr 9, 2026
684c6c5
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Apr 10, 2026
ca2454f
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Apr 13, 2026
0f38afc
merge
cdc-mitzimorris Apr 14, 2026
d8e7a57
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Apr 16, 2026
7e9b5fe
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Apr 24, 2026
e1d8014
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Apr 24, 2026
83ddbf0
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Apr 24, 2026
69ea4ea
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Apr 24, 2026
555e87b
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris Apr 28, 2026
fa5a7cb
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris May 4, 2026
69cdab0
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris May 6, 2026
c28a89f
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris May 6, 2026
fd091ca
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris May 7, 2026
8cee471
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris May 7, 2026
b2a1e1a
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris May 11, 2026
2006afd
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris May 12, 2026
a31ec85
Merge branch 'main' of github-bf06:CDCgov/PyRenew
cdc-mitzimorris May 13, 2026
0db35a9
implementation and unit tests for centered versions of temporal proce…
cdc-mitzimorris May 19, 2026
a92d58b
checkpointing - test cleanup
cdc-mitzimorris May 19, 2026
9911f4a
checkpointing
cdc-mitzimorris May 19, 2026
7795672
fix unit test
cdc-mitzimorris May 19, 2026
b7030a3
benchmark test suite
cdc-mitzimorris May 19, 2026
c0d4684
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 19, 2026
fe81470
Merge branch 'mem_810_centered_parameterization' of github-bf06:CDCgo…
cdc-mitzimorris May 19, 2026
ee0c276
lint fix
cdc-mitzimorris May 19, 2026
1e99920
more unit tests
cdc-mitzimorris May 19, 2026
c8a8764
checkpointing
cdc-mitzimorris May 26, 2026
0c852ab
refactoring benchmarks
cdc-mitzimorris May 26, 2026
c67fe92
Day-of-week effects applied on observation time axis (i.e., not befor…
cdc-mitzimorris May 19, 2026
197d9da
simplify benchmarks
cdc-mitzimorris May 27, 2026
3fe2f2b
remove dependency on R forecasttools package by substituting local po…
cdc-mitzimorris May 27, 2026
fec5f4c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 27, 2026
f910b2b
remove dependency on R forecasttools package by substituting local po…
cdc-mitzimorris May 27, 2026
a3e34ab
Merge branch 'mem_810_centered_parameterization' of github-bf06:CDCgo…
cdc-mitzimorris May 27, 2026
7d9619a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 27, 2026
360e8f0
more informative benchmark outputs
cdc-mitzimorris May 27, 2026
ab623e6
Merge branch 'mem_810_centered_parameterization' of github-bf06:CDCgo…
cdc-mitzimorris May 27, 2026
07f5bb6
checkpointing
cdc-mitzimorris May 27, 2026
72a4f19
fixing real data loading
cdc-mitzimorris May 27, 2026
670ee27
fix typo
cdc-mitzimorris May 28, 2026
b62cddf
fix typo
cdc-mitzimorris May 28, 2026
1f0b68f
deptry fix
cdc-mitzimorris May 28, 2026
7a9031e
changes per bot review
cdc-mitzimorris May 29, 2026
ed614e6
cleanup
cdc-mitzimorris May 29, 2026
70e116c
cleanup
cdc-mitzimorris May 29, 2026
a16ce91
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 29, 2026
26616bf
tweak benchmarks report
cdc-mitzimorris May 29, 2026
543135b
Merge branch 'mem_810_centered_parameterization' of github-bf06:CDCgo…
cdc-mitzimorris May 29, 2026
b1367ac
simplify PR; remove all benchmarks code and tests
cdc-mitzimorris Jun 1, 2026
29eff16
more benchmarks cleanup
cdc-mitzimorris Jun 1, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions _typos.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,8 @@
arange = "arange"
lod = "lod"
dows = "dows"

[default.extend-identifiers]
# NumPyro's Distribution base class spells this with a typo; we must
# match the upstream attribute name for `has_rsample` to work correctly.
reparametrized_params = "reparametrized_params"
387 changes: 387 additions & 0 deletions pyrenew/latent/state_centered_distributions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,387 @@
"""NumPyro distributions for state-centered temporal-process priors."""

from __future__ import annotations

import jax
import jax.numpy as jnp
from jax import lax, random
from jax.typing import ArrayLike
from numpyro.distributions import constraints
from numpyro.distributions.continuous import Normal
from numpyro.distributions.distribution import Distribution
from numpyro.distributions.util import validate_sample
from numpyro.util import is_prng_key


class StateRandomWalk(Distribution):
r"""
State-centered random-walk prior on a post-initial state path.

Given a deterministic initial state $x_0$ = ``initial_loc``:

$$
x_t \sim \mathrm{Normal}(x_{t-1}, \sigma), \quad t = 1, \dots, T
$$

The sampled value is the post-initial path
$[x_1, x_2, \ldots, x_{\mathrm{num\_steps}}]$ of length ``num_steps``.
"""
Comment on lines +16 to +28

arg_constraints = {
"scale": constraints.positive,
"initial_loc": constraints.real,
}
support = constraints.real_vector
reparametrized_params = ["scale", "initial_loc"]
pytree_aux_fields = ("num_steps",)

def __init__(
self,
scale: ArrayLike,
initial_loc: ArrayLike = 0.0,
num_steps: int = 1,
*,
validate_args: bool | None = None,
) -> None:
"""Construct a state-centered random-walk distribution."""
if not isinstance(num_steps, int) or num_steps <= 0:
raise ValueError(f"num_steps must be a positive integer; got {num_steps!r}")
self.scale = scale
self.initial_loc = initial_loc
self.num_steps = num_steps

batch_shape = lax.broadcast_shapes(
jnp.shape(scale),
jnp.shape(initial_loc),
)
super().__init__(batch_shape, (num_steps,), validate_args=validate_args)

def sample(self, key: jax.Array, sample_shape: tuple[int, ...] = ()) -> ArrayLike:
"""
Forward-sample a post-initial random-walk state path.

Returns
-------
ArrayLike
Array of shape ``sample_shape + batch_shape + (num_steps,)``.
"""
assert is_prng_key(key)

per_step_shape = sample_shape + self.batch_shape
scale = jnp.broadcast_to(jnp.asarray(self.scale), per_step_shape)
initial_loc = jnp.broadcast_to(jnp.asarray(self.initial_loc), per_step_shape)
noise = random.normal(key, shape=per_step_shape + (self.num_steps,))
increments = scale[..., jnp.newaxis] * noise
return initial_loc[..., jnp.newaxis] + jnp.cumsum(increments, axis=-1)

@validate_sample
def log_prob(self, value: ArrayLike) -> ArrayLike:
"""
Compute the log-density of an observed post-initial state path.

Parameters
----------
value
Post-initial path of shape
``sample_shape + batch_shape + (num_steps,)``.

Returns
-------
ArrayLike
Log-density of shape ``sample_shape + batch_shape``.
"""
scale = jnp.asarray(self.scale)
initial_loc = jnp.asarray(self.initial_loc)
init_with_event = jnp.expand_dims(initial_loc, -1)
init_bcast = jnp.broadcast_to(init_with_event, value.shape[:-1] + (1,))
v = jnp.concatenate([init_bcast, value], axis=-1)
step_probs = Normal(v[..., :-1], jnp.expand_dims(scale, -1)).log_prob(
v[..., 1:]
)
return jnp.sum(step_probs, axis=-1)


class StateAR1(Distribution):
r"""
State-centered AR(1) prior on a length-``num_steps`` state path.

Generative form:

$$
x_0 \sim \mathrm{Normal}(\mu_0, \sigma_{\text{stat}})
$$
$$
x_t \sim \mathrm{Normal}(\phi \, x_{t-1}, \sigma), \quad t = 1, \dots, T-1
$$

where $\sigma_{\text{stat}} = \sigma / \sqrt{1 - \phi^2}$ is the
stationary standard deviation, $\mu_0$ is ``initial_loc``, $\phi$ is
``autoreg``, and $\sigma$ is ``scale``.

The sampled value is the full path $[x_0, x_1, \ldots, x_{T-1}]$.

Parameters
----------
autoreg
AR(1) coefficient $\phi$. For stationarity, $|\phi| < 1$; this is
not enforced.
scale
Innovation standard deviation $\sigma$. Must be positive.
initial_loc
Prior mean $\mu_0$ of the initial state $x_0$. Defaults to ``0.0``.
num_steps
Length of the state path. Must be a positive integer.
validate_args
Forwarded to the base [`numpyro.distributions.Distribution`][].
"""

arg_constraints = {
"autoreg": constraints.real,
"scale": constraints.positive,
"initial_loc": constraints.real,
}
support = constraints.real_vector
reparametrized_params = ["autoreg", "scale", "initial_loc"]
pytree_aux_fields = ("num_steps",)

def __init__(
self,
autoreg: ArrayLike,
scale: ArrayLike,
initial_loc: ArrayLike = 0.0,
num_steps: int = 1,
*,
validate_args: bool | None = None,
) -> None:
"""
Construct a state-centered AR(1) distribution.

Raises
------
ValueError
If ``num_steps`` is not a positive integer.
"""
if not isinstance(num_steps, int) or num_steps <= 0:
raise ValueError(f"num_steps must be a positive integer; got {num_steps!r}")
self.autoreg = autoreg
self.scale = scale
self.initial_loc = initial_loc
self.num_steps = num_steps

batch_shape = lax.broadcast_shapes(
jnp.shape(autoreg),
jnp.shape(scale),
jnp.shape(initial_loc),
)
event_shape = (num_steps,)
super().__init__(batch_shape, event_shape, validate_args=validate_args)

def sample(self, key: jax.Array, sample_shape: tuple[int, ...] = ()) -> ArrayLike:
"""
Forward-sample a state path.

Returns
-------
ArrayLike
Array of shape ``sample_shape + batch_shape + (num_steps,)``.
"""
assert is_prng_key(key)

per_step_shape = sample_shape + self.batch_shape
autoreg = jnp.broadcast_to(jnp.asarray(self.autoreg), per_step_shape)
scale = jnp.broadcast_to(jnp.asarray(self.scale), per_step_shape)
initial_loc = jnp.broadcast_to(jnp.asarray(self.initial_loc), per_step_shape)
stationary_sd = scale / jnp.sqrt(1 - autoreg**2)

noise = random.normal(key, shape=(self.num_steps,) + per_step_shape)
z0 = noise[0]
x0 = initial_loc + stationary_sd * z0

if self.num_steps == 1:
return x0[..., jnp.newaxis]

def step(
prev: ArrayLike, z_t: ArrayLike
) -> tuple[ArrayLike, ArrayLike]: # numpydoc ignore=GL08
new = autoreg * prev + scale * z_t
return new, new

_, xs = lax.scan(step, x0, noise[1:])
path_time_first = jnp.concatenate([x0[jnp.newaxis], xs], axis=0)
return jnp.moveaxis(path_time_first, 0, -1)

@validate_sample
def log_prob(self, value: ArrayLike) -> ArrayLike:
"""
Compute the log-density of an observed state path.

Parameters
----------
value
State path of shape ``sample_shape + batch_shape + (num_steps,)``.

Returns
-------
ArrayLike
Log-density of shape ``sample_shape + batch_shape``.
"""
scale = jnp.asarray(self.scale)
autoreg = jnp.asarray(self.autoreg)
stationary_sd = scale / jnp.sqrt(1 - autoreg**2)

init_prob = Normal(self.initial_loc, stationary_sd).log_prob(value[..., 0])

scale_t = jnp.expand_dims(scale, -1)
autoreg_t = jnp.expand_dims(autoreg, -1)
step_locs = autoreg_t * value[..., :-1]
step_probs = Normal(step_locs, scale_t).log_prob(value[..., 1:])
return init_prob + jnp.sum(step_probs, axis=-1)


class StateDifferencedAR1(Distribution):
r"""
State-centered differenced AR(1) prior on a length-``num_steps`` post-initial path.

Generative form, given a deterministic initial state $x_0$ = ``initial_loc``:

$$
x_1 \sim \mathrm{Normal}(x_0, \sigma_{\text{stat}})
$$
$$
x_t \sim \mathrm{Normal}(x_{t-1} + \phi \, (x_{t-1} - x_{t-2}), \sigma),
\quad t \geq 2
$$

where $\sigma_{\text{stat}} = \sigma / \sqrt{1 - \phi^2}$, $\phi$ is
``autoreg``, and $\sigma$ is ``scale``.

The sampled value is the post-initial path
$[x_1, x_2, \ldots, x_{\mathrm{num\_steps}}]$ of length ``num_steps``.
The initial state $x_0$ is not part of the sample; it is supplied as
``initial_loc`` and used to score the first transition.

Parameters
----------
autoreg
AR(1) coefficient $\phi$ on first differences. For stationarity,
$|\phi| < 1$; this is not enforced.
scale
Innovation standard deviation $\sigma$. Must be positive.
initial_loc
Deterministic initial state $x_0$. Used to score the first
transition; not itself sampled.
num_steps
Length of the post-initial path. Must be a positive integer.
validate_args
Forwarded to the base [`numpyro.distributions.Distribution`][].
"""

arg_constraints = {
"autoreg": constraints.real,
"scale": constraints.positive,
"initial_loc": constraints.real,
}
support = constraints.real_vector
reparametrized_params = ["autoreg", "scale", "initial_loc"]
pytree_aux_fields = ("num_steps",)

def __init__(
self,
autoreg: ArrayLike,
scale: ArrayLike,
initial_loc: ArrayLike = 0.0,
num_steps: int = 1,
*,
validate_args: bool | None = None,
) -> None:
"""
Construct a state-centered differenced AR(1) distribution.

Raises
------
ValueError
If ``num_steps`` is not a positive integer.
"""
if not isinstance(num_steps, int) or num_steps <= 0:
raise ValueError(f"num_steps must be a positive integer; got {num_steps!r}")
self.autoreg = autoreg
self.scale = scale
self.initial_loc = initial_loc
self.num_steps = num_steps

batch_shape = lax.broadcast_shapes(
jnp.shape(autoreg),
jnp.shape(scale),
jnp.shape(initial_loc),
)
event_shape = (num_steps,)
super().__init__(batch_shape, event_shape, validate_args=validate_args)

def sample(self, key: jax.Array, sample_shape: tuple[int, ...] = ()) -> ArrayLike:
"""
Forward-sample a post-initial path.

Returns
-------
ArrayLike
Array of shape ``sample_shape + batch_shape + (num_steps,)``.
"""
assert is_prng_key(key)

per_step_shape = sample_shape + self.batch_shape
autoreg = jnp.broadcast_to(jnp.asarray(self.autoreg), per_step_shape)
scale = jnp.broadcast_to(jnp.asarray(self.scale), per_step_shape)
initial_loc = jnp.broadcast_to(jnp.asarray(self.initial_loc), per_step_shape)
stationary_sd = scale / jnp.sqrt(1 - autoreg**2)

noise = random.normal(key, shape=(self.num_steps,) + per_step_shape)
z1 = noise[0]
x1 = initial_loc + stationary_sd * z1

if self.num_steps == 1:
return x1[..., jnp.newaxis]

def step(
carry: tuple[ArrayLike, ArrayLike], z_t: ArrayLike
) -> tuple[tuple[ArrayLike, ArrayLike], ArrayLike]: # numpydoc ignore=GL08
prev_2, prev_1 = carry
new = prev_1 + autoreg * (prev_1 - prev_2) + scale * z_t
return (prev_1, new), new

_, xs = lax.scan(step, (initial_loc, x1), noise[1:])
path_time_first = jnp.concatenate([x1[jnp.newaxis], xs], axis=0)
return jnp.moveaxis(path_time_first, 0, -1)

@validate_sample
def log_prob(self, value: ArrayLike) -> ArrayLike:
"""
Compute the log-density of an observed post-initial path.

Parameters
----------
value
Post-initial path of shape
``sample_shape + batch_shape + (num_steps,)``.

Returns
-------
ArrayLike
Log-density of shape ``sample_shape + batch_shape``.
"""
scale = jnp.asarray(self.scale)
autoreg = jnp.asarray(self.autoreg)
initial_loc = jnp.asarray(self.initial_loc)
stationary_sd = scale / jnp.sqrt(1 - autoreg**2)

init_prob = Normal(initial_loc, stationary_sd).log_prob(value[..., 0])

init_with_event = jnp.expand_dims(initial_loc, -1)
init_bcast = jnp.broadcast_to(init_with_event, value.shape[:-1] + (1,))
v = jnp.concatenate([init_bcast, value], axis=-1)

prev_delta = v[..., 1:-1] - v[..., :-2]
scale_t = jnp.expand_dims(scale, -1)
autoreg_t = jnp.expand_dims(autoreg, -1)
means = v[..., 1:-1] + autoreg_t * prev_delta
step_probs = Normal(means, scale_t).log_prob(v[..., 2:])
return init_prob + jnp.sum(step_probs, axis=-1)
Loading
Loading