From efc391d7374e940f463a0c78f3b1952a801b1330 Mon Sep 17 00:00:00 2001 From: Trevor Bedford Date: Wed, 17 Sep 2025 19:34:33 -0700 Subject: [PATCH] Aggregate regional frequencies into global frequency This inserts a new rule toadd_global_to_regional_mlr that takes the MLR_results.json from a regional run and aggregates frequencies across regions into a new "Global" region which is inserted into the JSON. Growth advantages for Global are copied from the hierarchical estimates. Aggregation is weighted sum based on regional population size. There's some additional tweaks and testing to be done here. --- Snakefile | 30 ++- config/defaults.yaml | 2 + config/regional_population_weights.tsv | 11 + scripts/add_global_to_mlr_results.py | 214 ++++++++++++++++++ .../generate_regional_population_weights.py | 84 +++++++ scripts/plot-ga.py | 4 + 6 files changed, 344 insertions(+), 1 deletion(-) create mode 100644 config/regional_population_weights.tsv create mode 100644 scripts/add_global_to_mlr_results.py create mode 100644 scripts/generate_regional_population_weights.py diff --git a/Snakefile b/Snakefile index 1427850..41b673d 100644 --- a/Snakefile +++ b/Snakefile @@ -177,7 +177,7 @@ rule add_colors_to_mlr_model: model="results/{lineage}/{geo_resolution}/mlr/initial_MLR_results.json", auspice_config="results/{lineage}/auspice_config.json", output: - model="results/{lineage}/{geo_resolution}/mlr/MLR_results.json", + model="results/{lineage}/{geo_resolution}/mlr/MLR_results_with_colors.json", params: coloring_field=config["coloring_field"], shell: @@ -189,6 +189,34 @@ rule add_colors_to_mlr_model: --output {output.model:q} """ +rule add_global_to_regional_mlr: + input: + model="results/{lineage}/region/mlr/MLR_results_with_colors.json", + regional_weights="config/regional_population_weights.tsv" + output: + model="results/{lineage}/region/mlr/MLR_results.json" + params: + aggregate_regions=config.get("aggregate_regions", False) + run: + if params.aggregate_regions and wildcards.lineage in config["lineages"]: + shell(""" + python scripts/add_global_to_mlr_results.py \ + --input-json {input.model} \ + --regional-weights {input.regional_weights} \ + --output-json {output.model} + """) + else: + shell("cp {input.model} {output.model}") + +# For non-region geo_resolutions, just copy the file +rule finalize_country_mlr: + input: + model="results/{lineage}/country/mlr/MLR_results_with_colors.json", + output: + model="results/{lineage}/country/mlr/MLR_results.json", + shell: + "cp {input.model} {output.model}" + rule parse_mlr_json: input: model="results/{lineage}/{geo_resolution}/mlr/MLR_results.json", diff --git a/config/defaults.yaml b/config/defaults.yaml index 494b2c9..e85b137 100644 --- a/config/defaults.yaml +++ b/config/defaults.yaml @@ -33,3 +33,5 @@ prepare_data: haplotype_variant_column: "subclade" variant: "haplotype" coloring_field: "emerging_haplotype" + +aggregate_regions: true # Set to true to add Global region to regional MLR results diff --git a/config/regional_population_weights.tsv b/config/regional_population_weights.tsv new file mode 100644 index 0000000..be51d1a --- /dev/null +++ b/config/regional_population_weights.tsv @@ -0,0 +1,11 @@ +region weight +Africa 977096.0 +China 1445901.0 +Europe 515136.0 +JapanKorea 124370.0 +NorthAmerica 579891.0 +Oceania 27915.0 +SouthAmerica 431886.0 +SouthAsia 1640015.0 +SoutheastAsia 688728.0 +WestAsia 826100.0 diff --git a/scripts/add_global_to_mlr_results.py b/scripts/add_global_to_mlr_results.py new file mode 100644 index 0000000..b91b2ca --- /dev/null +++ b/scripts/add_global_to_mlr_results.py @@ -0,0 +1,214 @@ +#!/usr/bin/env python3 +# script by Claude Code +""" +Add Global region to MLR results by aggregating regional frequencies using population weights +and copying hierarchical GA values as Global GA values. +""" + +import argparse +import json +import pandas as pd +import numpy as np +from collections import defaultdict + +def main(): + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument('--input-json', required=True, + help='Path to input MLR results JSON') + parser.add_argument('--regional-weights', required=True, + help='Path to regional population weights TSV') + parser.add_argument('--output-json', required=True, + help='Path to output MLR results JSON with Global added') + args = parser.parse_args() + + # Read the MLR results JSON + with open(args.input_json, 'r') as f: + mlr_data = json.load(f) + + # Read regional population weights + weights_df = pd.read_csv(args.regional_weights, sep='\t') + + # Map regional names to match MLR data + region_name_map = { + 'Africa': 'Africa', + 'Europe': 'Europe', + 'NorthAmerica': 'North America', + 'SouthAmerica': 'South America', + 'SoutheastAsia': 'Southeast Asia', + 'WestAsia': 'West Asia', + 'Oceania': 'Oceania', + 'China': 'China', + 'JapanKorea': 'Japan Korea', + 'SouthAsia': 'South Asia' + } + + # Apply mapping to weights + weights_df['mapped_region'] = weights_df['region'].map(region_name_map) + + # Filter to only regions present in the MLR data + available_regions = [loc for loc in mlr_data['metadata']['location'] + if loc != 'hierarchical'] + weights_df = weights_df[weights_df['mapped_region'].isin(available_regions)] + + # Normalize weights + total_weight = weights_df['weight'].sum() + weights_df['normalized_weight'] = weights_df['weight'] / total_weight + + # Create weight lookup + weight_lookup = dict(zip(weights_df['mapped_region'], weights_df['normalized_weight'])) + + print(f"Using population weights for {len(weight_lookup)} regions:") + for region, weight in weight_lookup.items(): + print(f" {region:<20} weight: {weight:.4f}") + + # Add "Global" to metadata locations if not present + if "Global" not in mlr_data['metadata']['location']: + mlr_data['metadata']['location'].append("Global") + + # Collect data by site, date, variant, and ps for aggregation + freq_data = defaultdict(lambda: defaultdict(lambda: defaultdict(dict))) + raw_freq_data = defaultdict(lambda: defaultdict(lambda: defaultdict(dict))) + freq_forecast_data = defaultdict(lambda: defaultdict(lambda: defaultdict(dict))) + + # Process existing data to collect regional values + for record in mlr_data['data']: + if record['location'] in weight_lookup: + if record['site'] == 'freq': + key = (record['date'], record['variant'], record['ps']) + freq_data[key[0]][key[1]][key[2]][record['location']] = record['value'] + elif record['site'] == 'raw_freq': + key = (record['date'], record['variant']) + raw_freq_data[key[0]][key[1]][record['location']] = record['value'] + elif record['site'] == 'freq_forecast': + key = (record['date'], record['variant'], record['ps']) + freq_forecast_data[key[0]][key[1]][key[2]][record['location']] = record['value'] + + # Create new Global records + new_global_records = [] + + # Aggregate freq data + for date in freq_data: + for variant in freq_data[date]: + for ps in freq_data[date][variant]: + regional_values = freq_data[date][variant][ps] + if regional_values: + # Calculate weighted average, skipping None values + valid_regions = [region for region in regional_values + if region in weight_lookup and regional_values[region] is not None] + if valid_regions: + weighted_sum = sum(regional_values[region] * weight_lookup[region] + for region in valid_regions) + weight_sum = sum(weight_lookup[region] for region in valid_regions) + if weight_sum > 0: + global_value = weighted_sum / weight_sum + new_global_records.append({ + 'location': 'Global', + 'date': date, + 'variant': variant, + 'site': 'freq', + 'ps': ps, + 'value': global_value + }) + + # Aggregate raw_freq data + for date in raw_freq_data: + for variant in raw_freq_data[date]: + regional_values = raw_freq_data[date][variant] + if regional_values: + # Calculate weighted average, skipping None values + valid_regions = [region for region in regional_values + if region in weight_lookup and regional_values[region] is not None] + if valid_regions: + weighted_sum = sum(regional_values[region] * weight_lookup[region] + for region in valid_regions) + weight_sum = sum(weight_lookup[region] for region in valid_regions) + if weight_sum > 0: + global_value = weighted_sum / weight_sum + new_global_records.append({ + 'location': 'Global', + 'date': date, + 'variant': variant, + 'site': 'raw_freq', + 'value': global_value + }) + + # Aggregate freq_forecast data + for date in freq_forecast_data: + for variant in freq_forecast_data[date]: + for ps in freq_forecast_data[date][variant]: + regional_values = freq_forecast_data[date][variant][ps] + if regional_values: + # Calculate weighted average, skipping None values + valid_regions = [region for region in regional_values + if region in weight_lookup and regional_values[region] is not None] + if valid_regions: + weighted_sum = sum(regional_values[region] * weight_lookup[region] + for region in valid_regions) + weight_sum = sum(weight_lookup[region] for region in valid_regions) + if weight_sum > 0: + global_value = weighted_sum / weight_sum + new_global_records.append({ + 'location': 'Global', + 'date': date, + 'variant': variant, + 'site': 'freq_forecast', + 'ps': ps, + 'value': global_value + }) + + # Copy hierarchical GA values as Global GA + hierarchical_ga_records = [r for r in mlr_data['data'] + if r['location'] == 'hierarchical' and r['site'] == 'ga'] + for record in hierarchical_ga_records: + new_record = record.copy() + new_record['location'] = 'Global' + new_global_records.append(new_record) + + # For smoothed_raw_freq and agg_counts, we'll create empty/zero records for Global + # Get unique dates and variants + dates = mlr_data['metadata']['dates'] + variants = mlr_data['metadata']['variants'] + + # Add smoothed_raw_freq records (set to 0 or could aggregate if needed) + for date in dates: + for variant in variants: + new_global_records.append({ + 'location': 'Global', + 'date': date, + 'variant': variant, + 'site': 'smoothed_raw_freq', + 'value': 0.0 # Could aggregate these too if needed + }) + + # Add agg_counts records (set to 0 or could aggregate if needed) + for date in dates: + for variant in variants: + new_global_records.append({ + 'location': 'Global', + 'date': date, + 'variant': variant, + 'site': 'agg_counts', + 'value': 0 # Could aggregate these too if needed + }) + + # Add all new Global records to the data + mlr_data['data'].extend(new_global_records) + + # Write output JSON + with open(args.output_json, 'w') as f: + json.dump(mlr_data, f, indent=2) + + print(f"\nAdded {len(new_global_records)} Global records to MLR results") + print(f"Output saved to {args.output_json}") + + # Summary of what was added + sites_added = defaultdict(int) + for record in new_global_records: + sites_added[record['site']] += 1 + + print("\nRecords added by site:") + for site, count in sorted(sites_added.items()): + print(f" {site:<20} {count:>6} records") + +if __name__ == "__main__": + main() diff --git a/scripts/generate_regional_population_weights.py b/scripts/generate_regional_population_weights.py new file mode 100644 index 0000000..222d181 --- /dev/null +++ b/scripts/generate_regional_population_weights.py @@ -0,0 +1,84 @@ +#!/usr/bin/env python3 +# script by Claude Code +""" +Generate regional population weights by aggregating country populations +according to the geo_regions mapping from nextstrain/seasonal-flu. +""" + +import pandas as pd +import sys + +def main(): + # Download and read the geo_regions mapping + geo_regions_url = "https://raw.githubusercontent.com/nextstrain/seasonal-flu/master/config/geo_regions.tsv" + geo_regions = pd.read_csv(geo_regions_url, sep='\t') + + # Download and read the population weights + population_url = "https://raw.githubusercontent.com/nextstrain/ncov/master/defaults/population_weights.tsv" + population_weights = pd.read_csv(population_url, sep='\t', comment='#') + + # Standardize column names for merging + geo_regions.columns = ['country', 'region'] + population_weights.columns = ['country', 'weight'] + + # Merge the dataframes + merged = pd.merge(geo_regions, population_weights, on='country', how='left') + + # Check for countries without population data + missing_pop = merged[merged['weight'].isna()] + if not missing_pop.empty: + print("Warning: Countries without population data:", file=sys.stderr) + for _, row in missing_pop.iterrows(): + print(f" {row['country']} ({row['region']})", file=sys.stderr) + print(file=sys.stderr) + + # Remove rows with missing population data + merged = merged.dropna(subset=['weight']) + + # Aggregate by region + regional_weights = merged.groupby('region')['weight'].sum().reset_index() + regional_weights.columns = ['region', 'weight'] + + # Sort by region name + regional_weights = regional_weights.sort_values('region') + + # Save to TSV + output_path = "config/regional_population_weights.tsv" + regional_weights.to_csv(output_path, sep='\t', index=False) + print(f"Regional population weights saved to {output_path}") + + # Print summary statistics + print("\nRegional Population Weights Summary:") + print("-" * 50) + total_pop = regional_weights['weight'].sum() + for _, row in regional_weights.iterrows(): + percentage = (row['weight'] / total_pop) * 100 + # Convert weight to millions for readability + pop_millions = row['weight'] / 1000 + print(f"{row['region']:<20} {pop_millions:>10.1f}M ({percentage:>5.1f}%)") + print("-" * 50) + print(f"{'Total':<20} {total_pop/1000:>10.1f}M (100.0%)") + + # Also check which regions from the results are present + expected_regions = ['Africa', 'Europe', 'North America', 'Oceania', + 'South America', 'Southeast Asia', 'West Asia'] + + print("\nRegions in current workflow:") + for region in expected_regions: + # Map display names to actual region names in geo_regions.tsv + region_map = { + 'North America': 'NorthAmerica', + 'South America': 'SouthAmerica', + 'Southeast Asia': 'SoutheastAsia', + 'West Asia': 'WestAsia' + } + actual_region = region_map.get(region, region) + + if actual_region in regional_weights['region'].values: + weight = regional_weights[regional_weights['region'] == actual_region]['weight'].iloc[0] + print(f" ✓ {region:<20} (population: {weight/1000:.1f}M)") + else: + print(f" ✗ {region:<20} (NOT FOUND in geo_regions)") + +if __name__ == "__main__": + main() diff --git a/scripts/plot-ga.py b/scripts/plot-ga.py index c8872dc..f4e46ab 100644 --- a/scripts/plot-ga.py +++ b/scripts/plot-ga.py @@ -26,6 +26,10 @@ def plot_ga(input_file, virus, color_file, out_var, out_loc, loc_lst, var_lst, p df = df[df["location"].isin(loc_filter)] df = df[df["variant"].isin(var_filter)] + + # Exclude "Global" from GA plots + df = df[df["location"] != "Global"] + base_chart = alt.Chart(df) # Parse pivot from file