Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,11 @@ This changelog is intended for _humans_ and follows many of the principles from
Changes for this project _do not_ currently follow the [Semantic Versioning rules](https://semver.org/spec/v2.0.0.html).
Instead, changes appear below grouped by the date they were added to the workflow.

# 13 January 2025
# 6 February 2026

- Use mean instead of median as point estimator for frequency and GA values. See [#38](https://github.com/nextstrain/forecasts-flu/pull/38) for details.

# 13 January 2026

- Collapse low-count haplotype counts into their parental clades instead of the "other" group. See [#30](https://github.com/nextstrain/forecasts-flu/pull/30) for details.
- Fix access to dated model results produced before we added support for amino acid haplotypes. See [#28](https://github.com/nextstrain/forecasts-flu/pull/28) for details.
Expand Down
1 change: 1 addition & 0 deletions config/mlr/h1n1pdm.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ settings:
load: false # Load old model?
export_json: true # Export model results as json
ps: [0.5, 0.8, 0.95] # HPDI intervals to be exported
ps_point_estimator: "mean"

model:
forecast_L: 52
Expand Down
1 change: 1 addition & 0 deletions config/mlr/h3n2.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ settings:
load: false # Load old model?
export_json: true # Export model results as json
ps: [0.5, 0.8, 0.95] # HPDI intervals to be exported
ps_point_estimator: "mean"

model:
forecast_L: 52
Expand Down
1 change: 1 addition & 0 deletions config/mlr/vic.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ settings:
load: false # Load old model?
export_json: true # Export model results as json
ps: [0.5, 0.8, 0.95] # HPDI intervals to be exported
ps_point_estimator: "mean"

model:
forecast_L: 52
Expand Down
10 changes: 5 additions & 5 deletions scripts/parse-json.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def parse_json(input_file, output_ga, output_rf, output_freq, output_raw, output

# Parse freq from MLR or Latent model
for record in data["data"]:
if record["site"] == "freq" and record["ps"] in ["median", "HDI_95_upper", "HDI_95_lower"]:
if record["site"] == "freq" and record["ps"] in ["mean", "median", "HDI_95_upper", "HDI_95_lower"]:
key = (record["location"], record["variant"], record["date"])
if key not in grouped_freq:
grouped_freq[key] = {"location": record["location"], "date": record["date"], "variant": record["variant"]}
Expand All @@ -48,7 +48,7 @@ def parse_json(input_file, output_ga, output_rf, output_freq, output_raw, output
print("Parsing forecast freq from model results.")
grouped_freq_forecast = {}
for record in data["data"]:
if record["site"] == "freq_forecast" and record["ps"] in ["median", "HDI_95_upper", "HDI_95_lower"]:
if record["site"] == "freq_forecast" and record["ps"] in ["mean", "median", "HDI_95_upper", "HDI_95_lower"]:
key = (record["location"], record["variant"], record["date"])
if key not in grouped_freq_forecast:
grouped_freq_forecast[key] = {"location": record["location"], "date": record["date"], "variant": record["variant"]}
Expand All @@ -74,7 +74,7 @@ def parse_json(input_file, output_ga, output_rf, output_freq, output_raw, output
if model_version == "MLR":
print("Parsing ga (growth advantage) from MLR model results.")
for record in data["data"]:
if record["site"] == "ga" and record["ps"] in ["median", "HDI_95_upper", "HDI_95_lower"]:
if record["site"] == "ga" and record["ps"] in ["mean", "median", "HDI_95_upper", "HDI_95_lower"]:
key = (record["location"], record["variant"])
if key not in grouped_ga:
grouped_ga[key] = {"location": record["location"], "variant": record["variant"]}
Expand All @@ -88,12 +88,12 @@ def parse_json(input_file, output_ga, output_rf, output_freq, output_raw, output
print("Parsing delta (relative fitness) from Latent model results.")
print("Parsing ga (growth advantage) from Latent model results.")
for record in data["data"]:
if record["site"] == "delta" and record["ps"] in ["median", "HDI_95_upper", "HDI_95_lower"]:
if record["site"] == "delta" and record["ps"] in ["mean", "median", "HDI_95_upper", "HDI_95_lower"]:
key = (record["location"], record["variant"], record["date"])
if key not in grouped_rf:
grouped_rf[key] = {"location": record["location"], "date": record["date"], "variant": record["variant"]}
grouped_rf[key][record["ps"]] = record["value"]
if record["site"] == "ga" and record["ps"] in ["median", "HDI_95_upper", "HDI_95_lower"]:
if record["site"] == "ga" and record["ps"] in ["mean", "median", "HDI_95_upper", "HDI_95_lower"]:
key = (record["location"], record["variant"], record["date"])
if key not in grouped_ga:
grouped_ga[key] = {"location": record["location"], "date": record["date"], "variant": record["variant"]}
Expand Down
11 changes: 8 additions & 3 deletions scripts/run-model.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,7 @@ def make_raw_freq_tidy(data, location):
return {"metadata": metadata, "data": entries}

# export results MLR model (with GA)
def export_results_mlr(multi_posterior, ps, path, data_name, hier, ga_inclusion_threshold, variant_location_counts):
def export_results_mlr(multi_posterior, ps, path, data_name, hier, ga_inclusion_threshold, variant_location_counts, ps_point_estimator):
EXPORT_SITES = ["freq", "ga", "freq_forecast"]
EXPORT_DATED = [True, False, True]
EXPORT_FORECASTS = [False, False, True]
Expand Down Expand Up @@ -383,7 +383,8 @@ def get_group_samples(samples, sites, group):
[False],
[False],
ps,
location
location,
ps_point_estimator=ps_point_estimator,
)
)
else:
Expand All @@ -395,6 +396,7 @@ def get_group_samples(samples, sites, group):
EXPORT_FORECASTS,
ps,
location,
ps_point_estimator=ps_point_estimator,
)

# Apply filtering on ga values
Expand Down Expand Up @@ -637,8 +639,11 @@ def nonnegative_int(value):
ps = parse_with_default(
config.config["settings"], "ps", dflt=[0.5, 0.8, 0.95]
)
ps_point_estimator = parse_with_default(
config.config["settings"], "ps_point_estimator", dflt="median"
)
data_name = args.data_name or config.config["data"]["name"]
if config.config["model"]["version"] == "MLR":
export_results_mlr(multi_posterior, ps, export_path, data_name, hier, location_ga_inclusion_threshold, variant_location_counts)
export_results_mlr(multi_posterior, ps, export_path, data_name, hier, location_ga_inclusion_threshold, variant_location_counts, ps_point_estimator)
elif config.config["model"]["version"] == "Latent":
export_results_latent(multi_posterior, ps, export_path, data_name, hier)