From 4c29e5b37d3967fe663afa8090699eaabbd8e99a Mon Sep 17 00:00:00 2001 From: John Huddleston Date: Fri, 6 Feb 2026 14:18:33 -0800 Subject: [PATCH 1/2] Use mean instead of median for point estimates Adds logic to parse a `ps_point_estimator` function name from the model config YAML and pass that estimator name to the evofr function for generating tidy data frames. This interface allows users to choose one of the available point estimates generated by evofr as the representative value to plot in evofr-viz and other downstream tools. Updates the configs for all three subtypes to use the "mean" instead of the "median" for reporting frequencies and GAs. Updates the local JSON parsing script to export mean along with median in the data frame output. Closes #32 Requires https://github.com/blab/evofr/pull/65 Requires https://github.com/nextstrain/forecasts-viz/pull/33 --- config/mlr/h1n1pdm.yaml | 1 + config/mlr/h3n2.yaml | 1 + config/mlr/vic.yaml | 1 + scripts/parse-json.py | 10 +++++----- scripts/run-model.py | 11 ++++++++--- 5 files changed, 16 insertions(+), 8 deletions(-) diff --git a/config/mlr/h1n1pdm.yaml b/config/mlr/h1n1pdm.yaml index 2a4d36b..f963753 100644 --- a/config/mlr/h1n1pdm.yaml +++ b/config/mlr/h1n1pdm.yaml @@ -11,6 +11,7 @@ settings: load: false # Load old model? export_json: true # Export model results as json ps: [0.5, 0.8, 0.95] # HPDI intervals to be exported + ps_point_estimator: "mean" model: forecast_L: 52 diff --git a/config/mlr/h3n2.yaml b/config/mlr/h3n2.yaml index 82fda17..eee4ca9 100644 --- a/config/mlr/h3n2.yaml +++ b/config/mlr/h3n2.yaml @@ -11,6 +11,7 @@ settings: load: false # Load old model? export_json: true # Export model results as json ps: [0.5, 0.8, 0.95] # HPDI intervals to be exported + ps_point_estimator: "mean" model: forecast_L: 52 diff --git a/config/mlr/vic.yaml b/config/mlr/vic.yaml index 0981cdf..5b61e5c 100644 --- a/config/mlr/vic.yaml +++ b/config/mlr/vic.yaml @@ -11,6 +11,7 @@ settings: load: false # Load old model? export_json: true # Export model results as json ps: [0.5, 0.8, 0.95] # HPDI intervals to be exported + ps_point_estimator: "mean" model: forecast_L: 52 diff --git a/scripts/parse-json.py b/scripts/parse-json.py index 00aa73b..883a030 100644 --- a/scripts/parse-json.py +++ b/scripts/parse-json.py @@ -33,7 +33,7 @@ def parse_json(input_file, output_ga, output_rf, output_freq, output_raw, output # Parse freq from MLR or Latent model for record in data["data"]: - if record["site"] == "freq" and record["ps"] in ["median", "HDI_95_upper", "HDI_95_lower"]: + if record["site"] == "freq" and record["ps"] in ["mean", "median", "HDI_95_upper", "HDI_95_lower"]: key = (record["location"], record["variant"], record["date"]) if key not in grouped_freq: grouped_freq[key] = {"location": record["location"], "date": record["date"], "variant": record["variant"]} @@ -48,7 +48,7 @@ def parse_json(input_file, output_ga, output_rf, output_freq, output_raw, output print("Parsing forecast freq from model results.") grouped_freq_forecast = {} for record in data["data"]: - if record["site"] == "freq_forecast" and record["ps"] in ["median", "HDI_95_upper", "HDI_95_lower"]: + if record["site"] == "freq_forecast" and record["ps"] in ["mean", "median", "HDI_95_upper", "HDI_95_lower"]: key = (record["location"], record["variant"], record["date"]) if key not in grouped_freq_forecast: grouped_freq_forecast[key] = {"location": record["location"], "date": record["date"], "variant": record["variant"]} @@ -74,7 +74,7 @@ def parse_json(input_file, output_ga, output_rf, output_freq, output_raw, output if model_version == "MLR": print("Parsing ga (growth advantage) from MLR model results.") for record in data["data"]: - if record["site"] == "ga" and record["ps"] in ["median", "HDI_95_upper", "HDI_95_lower"]: + if record["site"] == "ga" and record["ps"] in ["mean", "median", "HDI_95_upper", "HDI_95_lower"]: key = (record["location"], record["variant"]) if key not in grouped_ga: grouped_ga[key] = {"location": record["location"], "variant": record["variant"]} @@ -88,12 +88,12 @@ def parse_json(input_file, output_ga, output_rf, output_freq, output_raw, output print("Parsing delta (relative fitness) from Latent model results.") print("Parsing ga (growth advantage) from Latent model results.") for record in data["data"]: - if record["site"] == "delta" and record["ps"] in ["median", "HDI_95_upper", "HDI_95_lower"]: + if record["site"] == "delta" and record["ps"] in ["mean", "median", "HDI_95_upper", "HDI_95_lower"]: key = (record["location"], record["variant"], record["date"]) if key not in grouped_rf: grouped_rf[key] = {"location": record["location"], "date": record["date"], "variant": record["variant"]} grouped_rf[key][record["ps"]] = record["value"] - if record["site"] == "ga" and record["ps"] in ["median", "HDI_95_upper", "HDI_95_lower"]: + if record["site"] == "ga" and record["ps"] in ["mean", "median", "HDI_95_upper", "HDI_95_lower"]: key = (record["location"], record["variant"], record["date"]) if key not in grouped_ga: grouped_ga[key] = {"location": record["location"], "date": record["date"], "variant": record["variant"]} diff --git a/scripts/run-model.py b/scripts/run-model.py index 91c4b80..cd68381 100644 --- a/scripts/run-model.py +++ b/scripts/run-model.py @@ -333,7 +333,7 @@ def make_raw_freq_tidy(data, location): return {"metadata": metadata, "data": entries} # export results MLR model (with GA) -def export_results_mlr(multi_posterior, ps, path, data_name, hier, ga_inclusion_threshold, variant_location_counts): +def export_results_mlr(multi_posterior, ps, path, data_name, hier, ga_inclusion_threshold, variant_location_counts, ps_point_estimator): EXPORT_SITES = ["freq", "ga", "freq_forecast"] EXPORT_DATED = [True, False, True] EXPORT_FORECASTS = [False, False, True] @@ -383,7 +383,8 @@ def get_group_samples(samples, sites, group): [False], [False], ps, - location + location, + ps_point_estimator=ps_point_estimator, ) ) else: @@ -395,6 +396,7 @@ def get_group_samples(samples, sites, group): EXPORT_FORECASTS, ps, location, + ps_point_estimator=ps_point_estimator, ) # Apply filtering on ga values @@ -637,8 +639,11 @@ def nonnegative_int(value): ps = parse_with_default( config.config["settings"], "ps", dflt=[0.5, 0.8, 0.95] ) + ps_point_estimator = parse_with_default( + config.config["settings"], "ps_point_estimator", dflt="median" + ) data_name = args.data_name or config.config["data"]["name"] if config.config["model"]["version"] == "MLR": - export_results_mlr(multi_posterior, ps, export_path, data_name, hier, location_ga_inclusion_threshold, variant_location_counts) + export_results_mlr(multi_posterior, ps, export_path, data_name, hier, location_ga_inclusion_threshold, variant_location_counts, ps_point_estimator) elif config.config["model"]["version"] == "Latent": export_results_latent(multi_posterior, ps, export_path, data_name, hier) From 4b5886334950525f5d3030c51da19af0e30d21b2 Mon Sep 17 00:00:00 2001 From: John Huddleston Date: Fri, 6 Feb 2026 14:26:48 -0800 Subject: [PATCH 2/2] Update change log --- CHANGES.md | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/CHANGES.md b/CHANGES.md index 262f3cf..1e3cf74 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -6,7 +6,11 @@ This changelog is intended for _humans_ and follows many of the principles from Changes for this project _do not_ currently follow the [Semantic Versioning rules](https://semver.org/spec/v2.0.0.html). Instead, changes appear below grouped by the date they were added to the workflow. -# 13 January 2025 +# 6 February 2026 + + - Use mean instead of median as point estimator for frequency and GA values. See [#38](https://github.com/nextstrain/forecasts-flu/pull/38) for details. + +# 13 January 2026 - Collapse low-count haplotype counts into their parental clades instead of the "other" group. See [#30](https://github.com/nextstrain/forecasts-flu/pull/30) for details. - Fix access to dated model results produced before we added support for amino acid haplotypes. See [#28](https://github.com/nextstrain/forecasts-flu/pull/28) for details.