diff --git a/CHANGES.md b/CHANGES.md index 262f3cf..1e3cf74 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -6,7 +6,11 @@ This changelog is intended for _humans_ and follows many of the principles from Changes for this project _do not_ currently follow the [Semantic Versioning rules](https://semver.org/spec/v2.0.0.html). Instead, changes appear below grouped by the date they were added to the workflow. -# 13 January 2025 +# 6 February 2026 + + - Use the mean instead of the median as the point estimator for frequency and growth advantage (GA) values. See [#38](https://github.com/nextstrain/forecasts-flu/pull/38) for details. + +# 13 January 2026 - Collapse low-count haplotype counts into their parental clades instead of the "other" group. See [#30](https://github.com/nextstrain/forecasts-flu/pull/30) for details. - Fix access to dated model results produced before we added support for amino acid haplotypes. See [#28](https://github.com/nextstrain/forecasts-flu/pull/28) for details. diff --git a/config/mlr/h1n1pdm.yaml b/config/mlr/h1n1pdm.yaml index 2a4d36b..f963753 100644 --- a/config/mlr/h1n1pdm.yaml +++ b/config/mlr/h1n1pdm.yaml @@ -11,6 +11,7 @@ settings: load: false # Load old model? export_json: true # Export model results as json ps: [0.5, 0.8, 0.95] # HPDI intervals to be exported + ps_point_estimator: "mean" model: forecast_L: 52 diff --git a/config/mlr/h3n2.yaml b/config/mlr/h3n2.yaml index 82fda17..eee4ca9 100644 --- a/config/mlr/h3n2.yaml +++ b/config/mlr/h3n2.yaml @@ -11,6 +11,7 @@ settings: load: false # Load old model? export_json: true # Export model results as json ps: [0.5, 0.8, 0.95] # HPDI intervals to be exported + ps_point_estimator: "mean" model: forecast_L: 52 diff --git a/config/mlr/vic.yaml b/config/mlr/vic.yaml index 0981cdf..5b61e5c 100644 --- a/config/mlr/vic.yaml +++ b/config/mlr/vic.yaml @@ -11,6 +11,7 @@ settings: load: false # Load old model? 
export_json: true # Export model results as json ps: [0.5, 0.8, 0.95] # HPDI intervals to be exported + ps_point_estimator: "mean" model: forecast_L: 52 diff --git a/scripts/parse-json.py b/scripts/parse-json.py index 00aa73b..883a030 100644 --- a/scripts/parse-json.py +++ b/scripts/parse-json.py @@ -33,7 +33,7 @@ def parse_json(input_file, output_ga, output_rf, output_freq, output_raw, output # Parse freq from MLR or Latent model for record in data["data"]: - if record["site"] == "freq" and record["ps"] in ["median", "HDI_95_upper", "HDI_95_lower"]: + if record["site"] == "freq" and record["ps"] in ["mean", "median", "HDI_95_upper", "HDI_95_lower"]: key = (record["location"], record["variant"], record["date"]) if key not in grouped_freq: grouped_freq[key] = {"location": record["location"], "date": record["date"], "variant": record["variant"]} @@ -48,7 +48,7 @@ def parse_json(input_file, output_ga, output_rf, output_freq, output_raw, output print("Parsing forecast freq from model results.") grouped_freq_forecast = {} for record in data["data"]: - if record["site"] == "freq_forecast" and record["ps"] in ["median", "HDI_95_upper", "HDI_95_lower"]: + if record["site"] == "freq_forecast" and record["ps"] in ["mean", "median", "HDI_95_upper", "HDI_95_lower"]: key = (record["location"], record["variant"], record["date"]) if key not in grouped_freq_forecast: grouped_freq_forecast[key] = {"location": record["location"], "date": record["date"], "variant": record["variant"]} @@ -74,7 +74,7 @@ def parse_json(input_file, output_ga, output_rf, output_freq, output_raw, output if model_version == "MLR": print("Parsing ga (growth advantage) from MLR model results.") for record in data["data"]: - if record["site"] == "ga" and record["ps"] in ["median", "HDI_95_upper", "HDI_95_lower"]: + if record["site"] == "ga" and record["ps"] in ["mean", "median", "HDI_95_upper", "HDI_95_lower"]: key = (record["location"], record["variant"]) if key not in grouped_ga: grouped_ga[key] = 
{"location": record["location"], "variant": record["variant"]} @@ -88,12 +88,12 @@ def parse_json(input_file, output_ga, output_rf, output_freq, output_raw, output print("Parsing delta (relative fitness) from Latent model results.") print("Parsing ga (growth advantage) from Latent model results.") for record in data["data"]: - if record["site"] == "delta" and record["ps"] in ["median", "HDI_95_upper", "HDI_95_lower"]: + if record["site"] == "delta" and record["ps"] in ["mean", "median", "HDI_95_upper", "HDI_95_lower"]: key = (record["location"], record["variant"], record["date"]) if key not in grouped_rf: grouped_rf[key] = {"location": record["location"], "date": record["date"], "variant": record["variant"]} grouped_rf[key][record["ps"]] = record["value"] - if record["site"] == "ga" and record["ps"] in ["median", "HDI_95_upper", "HDI_95_lower"]: + if record["site"] == "ga" and record["ps"] in ["mean", "median", "HDI_95_upper", "HDI_95_lower"]: key = (record["location"], record["variant"], record["date"]) if key not in grouped_ga: grouped_ga[key] = {"location": record["location"], "date": record["date"], "variant": record["variant"]} diff --git a/scripts/run-model.py b/scripts/run-model.py index 91c4b80..cd68381 100644 --- a/scripts/run-model.py +++ b/scripts/run-model.py @@ -333,7 +333,7 @@ def make_raw_freq_tidy(data, location): return {"metadata": metadata, "data": entries} # export results MLR model (with GA) -def export_results_mlr(multi_posterior, ps, path, data_name, hier, ga_inclusion_threshold, variant_location_counts): +def export_results_mlr(multi_posterior, ps, path, data_name, hier, ga_inclusion_threshold, variant_location_counts, ps_point_estimator): EXPORT_SITES = ["freq", "ga", "freq_forecast"] EXPORT_DATED = [True, False, True] EXPORT_FORECASTS = [False, False, True] @@ -383,7 +383,8 @@ def get_group_samples(samples, sites, group): [False], [False], ps, - location + location, + ps_point_estimator=ps_point_estimator, ) ) else: @@ -395,6 +396,7 @@ def 
get_group_samples(samples, sites, group): EXPORT_FORECASTS, ps, location, + ps_point_estimator=ps_point_estimator, ) # Apply filtering on ga values @@ -637,8 +639,11 @@ def nonnegative_int(value): ps = parse_with_default( config.config["settings"], "ps", dflt=[0.5, 0.8, 0.95] ) + ps_point_estimator = parse_with_default( + config.config["settings"], "ps_point_estimator", dflt="median" + ) data_name = args.data_name or config.config["data"]["name"] if config.config["model"]["version"] == "MLR": - export_results_mlr(multi_posterior, ps, export_path, data_name, hier, location_ga_inclusion_threshold, variant_location_counts) + export_results_mlr(multi_posterior, ps, export_path, data_name, hier, location_ga_inclusion_threshold, variant_location_counts, ps_point_estimator) elif config.config["model"]["version"] == "Latent": export_results_latent(multi_posterior, ps, export_path, data_name, hier)