diff --git a/CHANGES.md b/CHANGES.md new file mode 100644 index 0000000..a92dde6 --- /dev/null +++ b/CHANGES.md @@ -0,0 +1,16 @@ +This is the changelog for evofr. +All notable changes in a release will be documented in this file. + +This changelog is intended for _humans_ and follows many of the principles from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/). + +Versions for this project follow the [Semantic Versioning rules](https://semver.org/spec/v2.0.0.html). +Each heading below is a version released to [PyPI](https://pypi.org/project/evofr/) and the date it was released. +The "__NEXT__" heading below describes changes in the unreleased development source code and as such may not be routinely kept up to date. + +# __NEXT__ + +# 0.2.0 (February 9, 2026) + +## Features + + - Support alternate point estimators for posterior distributions of frequencies and growth advantages (e.g., mean or median). See [#65](https://github.com/blab/evofr/pull/65) for more. diff --git a/evofr/commands/run_model.py b/evofr/commands/run_model.py index e9e01d3..6f3a0dc 100644 --- a/evofr/commands/run_model.py +++ b/evofr/commands/run_model.py @@ -156,6 +156,7 @@ def export_results(posterior, export_config): forecasts=forecasts, ps=[0.5, 0.8, 0.95], # Default percentiles name=posterior.name, + ps_point_estimator=export_config.get("ps_point_estimator", "median"), ) results["metadata"]["updated"] = pd.to_datetime(date.today()).isoformat() diff --git a/evofr/posterior/posterior_helpers.py b/evofr/posterior/posterior_helpers.py index 0df2aab..24856b2 100644 --- a/evofr/posterior/posterior_helpers.py +++ b/evofr/posterior/posterior_helpers.py @@ -34,6 +34,11 @@ def get_quantile(samples: Dict, p, site): return jnp.quantile(samples[site], q=q, axis=0) +def get_mean(samples: Dict, site): + """Returns mean value across all samples for a site""" + return jnp.mean(samples[site], axis=0) + + def get_median(samples: Dict, site): """Returns median value across all samples for a site""" return jnp.median(samples[site], axis=0) @@ -288,17 +293,24 @@ def get_sites_variants_tidy( forecasts: List[bool], ps, name: Optional[str] = None, + ps_point_estimator: Optional[str] = "median", ): # Save metadata metadata = dict() # Make keys for probability levels - ps_keys = ["median"] + ps_keys = [ + "median", + "mean", + ] for p in ps: ps_keys.append(f"HDI_{round(p * 100)}_upper") ps_keys.append(f"HDI_{round(p * 100)}_lower") metadata["ps"] = ps_keys + # Save the requested point estimator function. + metadata["ps_point_estimator"] = ps_point_estimator + metadata["sites"] = sites if name: metadata["location"] = [name] @@ -332,6 +344,7 @@ def tidy_site_date(site, forecast): # Loop over entries of median and med, quants = get_quantiles(samples, ps, site) med, quants = np.array(med), np.array(quants) + means = np.array(get_mean(samples, site)) entries = [] T, N_variants = med.shape @@ -363,6 +376,14 @@ def tidy_site_date(site, forecast): # Add median entry entries.append(entry_med) + # Create mean entry + entry_mean = entry.copy() + entry_mean["value"] = np.around(means[index, v], decimals=3) + entry_mean["ps"] = "mean" + + # Add mean entry + entries.append(entry_mean) + # Loop over intervals of interest for i, p in enumerate(ps): entry_lower = entry.copy() @@ -384,6 +405,7 @@ def tidy_site_flat(site): # Loop over entries of median and med, quants = get_quantiles(samples, ps, site) med, quants = np.array(med), np.array(quants) + means = np.array(get_mean(samples, site)) entries = [] N_variants = med.shape[0] @@ -407,6 +429,14 @@ def tidy_site_flat(site): # Add median entry entries.append(entry_med) + # Create mean entry + entry_mean = entry.copy() + entry_mean["value"] = np.around(means[v], decimals=3) + entry_mean["ps"] = "mean" + + # Add mean entry + entries.append(entry_mean) + # Loop over intervals of interest for i, p in enumerate(ps): entry_lower = entry.copy() @@ -439,7 +469,10 @@ def combine_sites_tidy(tidy_dicts): for tidy_dict in tidy_dicts: for key, value in tidy_dict["metadata"].items(): - metadata[key].extend([v for v in value if v not in metadata[key]]) + if isinstance(value, list): + metadata[key].extend([v for v in value if v not in metadata[key]]) + else: + metadata[key] = value # Loop over data entries = [] diff --git a/pyproject.toml b/pyproject.toml index 9516da3..e0b0ce8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "evofr" -version = "0.1.27" +version = "0.2.0" description = "Tools for evolutionary forecasting." authors = ["marlinfiggins "] license = "AGPL-3.0" diff --git a/test/configs/run_model.yaml b/test/configs/run_model.yaml index a81ee0a..3218315 100644 --- a/test/configs/run_model.yaml +++ b/test/configs/run_model.yaml @@ -14,3 +14,4 @@ export: sites: ["freq", "ga"] dated: [True, False] forecasts: [False, False] + ps_point_estimator: "mean" diff --git a/test/posterior/__init__.py b/test/posterior/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test/posterior/test_posterior_helpers.py b/test/posterior/test_posterior_helpers.py new file mode 100644 index 0000000..58f8ca5 --- /dev/null +++ b/test/posterior/test_posterior_helpers.py @@ -0,0 +1,32 @@ +from evofr.posterior.posterior_helpers import combine_sites_tidy + + +def test_combine_sites_tidy(): + tidy_dicts = [ + { + "metadata": { + "location": ["Africa"], + "ps_point_estimator": "mean", + }, + "data": [ + { + "record": 1, + } + ], + }, + { + "metadata": { + "location": ["Europe"], + "ps_point_estimator": "mean", + }, + "data": [ + { + "record": 2, + } + ], + } + ] + combined_dict = combine_sites_tidy(tidy_dicts) + assert sorted(combined_dict["metadata"]["location"]) == ["Africa", "Europe"] + assert combined_dict["metadata"]["ps_point_estimator"] == "mean" + assert len(combined_dict["data"]) == 2 diff --git a/test/test_cli.py b/test/test_cli.py index 37660b8..ddb56b5 100644 --- a/test/test_cli.py +++ b/test/test_cli.py @@ -1,3 +1,4 @@ +import json import shutil import subprocess from pathlib import Path @@ -58,3 +59,10 @@ def test_run_model_creates_results(): run_cli(command) result_file = export_dir / "results.json" assert result_file.exists(), "Results JSON file not created." + + with open(result_file, "r", encoding="utf-8") as fh: + model = json.load(fh) + + assert model["metadata"].get("ps_point_estimator") == "mean" + assert any((record["ps"] == "mean") & (record["site"] == "freq") for record in model["data"]) + assert any((record["ps"] == "mean") & (record["site"] == "ga") for record in model["data"])