Skip to content
16 changes: 16 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
This is the changelog for evofr.
All notable changes in a release will be documented in this file.

This changelog is intended for _humans_ and follows many of the principles from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).

Versions for this project follow the [Semantic Versioning rules](https://semver.org/spec/v2.0.0.html).
Each heading below is a version released to [PyPI](https://pypi.org/project/evofr/) and the date it was released.
The "__NEXT__" heading below describes changes in the unreleased development source code and as such may not be routinely kept up to date.

# __NEXT__

# 0.2.0 (February 9, 2026)

## Features

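- Support alternate point estimators (e.g., mean or median) for posterior distributions of frequencies and growth advantages. Exported results now include a `mean` value alongside the `median`, and the new `ps_point_estimator` export option (default: `"median"`) is recorded in the results metadata. See [#65](https://github.com/blab/evofr/pull/65) for details.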
1 change: 1 addition & 0 deletions evofr/commands/run_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ def export_results(posterior, export_config):
forecasts=forecasts,
ps=[0.5, 0.8, 0.95], # Default percentiles
name=posterior.name,
ps_point_estimator=export_config.get("ps_point_estimator", "median"),
)

results["metadata"]["updated"] = pd.to_datetime(date.today()).isoformat()
Expand Down
37 changes: 35 additions & 2 deletions evofr/posterior/posterior_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,11 @@ def get_quantile(samples: Dict, p, site):
return jnp.quantile(samples[site], q=q, axis=0)


def get_mean(samples: Dict, site):
    """Average the values stored under ``site`` across all samples.

    The reduction runs over axis 0 of ``samples[site]``.
    """
    site_draws = samples[site]
    return jnp.mean(site_draws, axis=0)


def get_median(samples: Dict, site):
    """Take the median of the values stored under ``site`` across all samples.

    The reduction runs over axis 0 of ``samples[site]``.
    """
    site_draws = samples[site]
    return jnp.median(site_draws, axis=0)
Expand Down Expand Up @@ -288,17 +293,24 @@ def get_sites_variants_tidy(
forecasts: List[bool],
ps,
name: Optional[str] = None,
ps_point_estimator: Optional[str] = "median",
):
# Save metadata
metadata = dict()

# Make keys for probability levels
ps_keys = ["median"]
ps_keys = [
"median",
"mean",
]
for p in ps:
ps_keys.append(f"HDI_{round(p * 100)}_upper")
ps_keys.append(f"HDI_{round(p * 100)}_lower")
metadata["ps"] = ps_keys

# Record the name of the requested point estimator (e.g., "median" or "mean").
metadata["ps_point_estimator"] = ps_point_estimator

metadata["sites"] = sites
if name:
metadata["location"] = [name]
Expand Down Expand Up @@ -332,6 +344,7 @@ def tidy_site_date(site, forecast):
# Compute the median and interval quantiles for this site
med, quants = get_quantiles(samples, ps, site)
med, quants = np.array(med), np.array(quants)
means = np.array(get_mean(samples, site))

entries = []
T, N_variants = med.shape
Expand Down Expand Up @@ -363,6 +376,14 @@ def tidy_site_date(site, forecast):
# Add median entry
entries.append(entry_med)

# Create mean entry
entry_mean = entry.copy()
entry_mean["value"] = np.around(means[index, v], decimals=3)
entry_mean["ps"] = "mean"

# Add mean entry
entries.append(entry_mean)

# Loop over intervals of interest
for i, p in enumerate(ps):
entry_lower = entry.copy()
Expand All @@ -384,6 +405,7 @@ def tidy_site_flat(site):
# Compute the median and interval quantiles for this site
med, quants = get_quantiles(samples, ps, site)
med, quants = np.array(med), np.array(quants)
means = np.array(get_mean(samples, site))

entries = []
N_variants = med.shape[0]
Expand All @@ -407,6 +429,14 @@ def tidy_site_flat(site):
# Add median entry
entries.append(entry_med)

# Create mean entry
entry_mean = entry.copy()
entry_mean["value"] = np.around(means[v], decimals=3)
entry_mean["ps"] = "mean"

# Add mean entry
entries.append(entry_mean)

# Loop over intervals of interest
for i, p in enumerate(ps):
entry_lower = entry.copy()
Expand Down Expand Up @@ -439,7 +469,10 @@ def combine_sites_tidy(tidy_dicts):

for tidy_dict in tidy_dicts:
for key, value in tidy_dict["metadata"].items():
metadata[key].extend([v for v in value if v not in metadata[key]])
if isinstance(value, list):
metadata[key].extend([v for v in value if v not in metadata[key]])
else:
metadata[key] = value

# Loop over data
entries = []
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "evofr"
version = "0.1.27"
version = "0.2.0"
description = "Tools for evolutionary forecasting."
authors = ["marlinfiggins <marlinfiggins@gmail.com>"]
license = "AGPL-3.0"
Expand Down
1 change: 1 addition & 0 deletions test/configs/run_model.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,4 @@ export:
sites: ["freq", "ga"]
dated: [True, False]
forecasts: [False, False]
ps_point_estimator: "mean"
Empty file added test/posterior/__init__.py
Empty file.
32 changes: 32 additions & 0 deletions test/posterior/test_posterior_helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from evofr.posterior.posterior_helpers import combine_sites_tidy


def test_combine_sites_tidy():
    """Check how ``combine_sites_tidy`` merges two tidy dictionaries.

    List-valued metadata ("location") should contain the entries from both
    inputs, scalar metadata ("ps_point_estimator") should come through as a
    plain value, and the data records from both inputs should be present.
    """
    # One tidy dict per location, each carrying a single numbered record.
    tidy_dicts = [
        {
            "metadata": {
                "location": [location],
                "ps_point_estimator": "mean",
            },
            "data": [{"record": record}],
        }
        for record, location in enumerate(("Africa", "Europe"), start=1)
    ]

    combined_dict = combine_sites_tidy(tidy_dicts)
    merged_metadata = combined_dict["metadata"]

    assert sorted(merged_metadata["location"]) == ["Africa", "Europe"]
    assert merged_metadata["ps_point_estimator"] == "mean"
    assert len(combined_dict["data"]) == 2
8 changes: 8 additions & 0 deletions test/test_cli.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
import shutil
import subprocess
from pathlib import Path
Expand Down Expand Up @@ -58,3 +59,10 @@ def test_run_model_creates_results():
run_cli(command)
result_file = export_dir / "results.json"
assert result_file.exists(), "Results JSON file not created."

with open(result_file, "r", encoding="utf-8") as fh:
model = json.load(fh)

assert model["metadata"].get("ps_point_estimator") == "mean"
assert any((record["ps"] == "mean") & (record["site"] == "freq") for record in model["data"])
assert any((record["ps"] == "mean") & (record["site"] == "ga") for record in model["data"])