# Produce the final regional MLR results. When `aggregate_regions` is enabled
# in the config (and the lineage is one of the configured lineages), a
# population-weighted "Global" location is added to the colored regional
# results; otherwise the colored results are copied through unchanged so that
# downstream rules always find MLR_results.json.
rule add_global_to_regional_mlr:
    input:
        model="results/{lineage}/region/mlr/MLR_results_with_colors.json",
        regional_weights="config/regional_population_weights.tsv",
    output:
        model="results/{lineage}/region/mlr/MLR_results.json",
    params:
        aggregate_regions=config.get("aggregate_regions", False),
    run:
        if params.aggregate_regions and wildcards.lineage in config["lineages"]:
            # Paths are quoted with :q, matching the other rules in this file,
            # so lineage names containing shell metacharacters are safe.
            shell("""
                python scripts/add_global_to_mlr_results.py \
                    --input-json {input.model:q} \
                    --regional-weights {input.regional_weights:q} \
                    --output-json {output.model:q}
            """)
        else:
            shell("cp {input.model:q} {output.model:q}")

# Country-level results get no Global aggregate: just copy the colored file.
# NOTE(review): only the `country` geo_resolution is handled here; any other
# non-region resolution would need its own rule (or a wildcard) - confirm that
# `region` and `country` are the only resolutions in use.
rule finalize_country_mlr:
    input:
        model="results/{lineage}/country/mlr/MLR_results_with_colors.json",
    output:
        model="results/{lineage}/country/mlr/MLR_results.json",
    shell:
        "cp {input.model:q} {output.model:q}"
#!/usr/bin/env python3
# script by Claude Code
"""
Add Global region to MLR results by aggregating regional frequencies using population weights
and copying hierarchical GA values as Global GA values.
"""

import argparse
import json
import sys
import pandas as pd
import numpy as np
from collections import defaultdict

# Region identifiers as written in the weights TSV -> location names used in
# the MLR results JSON.
REGION_NAME_MAP = {
    'Africa': 'Africa',
    'Europe': 'Europe',
    'NorthAmerica': 'North America',
    'SouthAmerica': 'South America',
    'SoutheastAsia': 'Southeast Asia',
    'WestAsia': 'West Asia',
    'Oceania': 'Oceania',
    'China': 'China',
    'JapanKorea': 'Japan Korea',
    'SouthAsia': 'South Asia',
}

# Sites aggregated into Global as a population-weighted mean, in output order.
# The flag says whether records of that site carry a 'ps' field.
AGGREGATED_SITES = (
    ('freq', True),
    ('raw_freq', False),
    ('freq_forecast', True),
)

# Sites for which Global only receives placeholder records with this value.
# NOTE(review): these are zeros, not aggregates - downstream consumers should
# not interpret them as real Global counts/frequencies.
PLACEHOLDER_SITES = (
    ('smoothed_raw_freq', 0.0),
    ('agg_counts', 0),
)


def _load_weight_lookup(weights_path, locations):
    """Return {MLR location name: normalized population weight}.

    Only regions present in *locations* are kept, and the weights are
    normalized to sum to 1 over those regions. Weight rows whose name is not
    in REGION_NAME_MAP fall back to the raw name, so a TSV that already uses
    MLR location names also works. Returns an empty dict when no usable
    weight is found.
    """
    weights_df = pd.read_csv(weights_path, sep='\t')
    raw_names = weights_df['region']
    weights_df['mapped_region'] = raw_names.map(REGION_NAME_MAP).fillna(raw_names)

    # 'hierarchical' is the pooled model estimate and 'Global' is our own
    # output; neither is a real region to average over.
    available_regions = [loc for loc in locations
                         if loc not in ('hierarchical', 'Global')]
    weights_df = weights_df[weights_df['mapped_region'].isin(available_regions)]

    total_weight = weights_df['weight'].sum()
    if not total_weight > 0:
        # Empty selection or all-zero weights: nothing sensible to normalize.
        weight_lookup = {}
    else:
        weight_lookup = {
            region: float(weight) / float(total_weight)
            for region, weight in zip(weights_df['mapped_region'], weights_df['weight'])
        }

    # Regions without a weight are silently excluded from every Global mean;
    # make that visible instead of producing a quietly biased aggregate.
    unweighted = [loc for loc in available_regions if loc not in weight_lookup]
    if unweighted:
        print("Warning: no population weight for location(s): "
              + ", ".join(unweighted)
              + " (excluded from Global aggregation)", file=sys.stderr)

    return weight_lookup


def _weighted_mean(regional_values, weight_lookup):
    """Population-weighted mean of *regional_values*, or None if undefined.

    Regions with a None value or without a weight are skipped and the
    remaining weights are renormalized over the regions actually used.
    """
    valid_regions = [region for region, value in regional_values.items()
                     if region in weight_lookup and value is not None]
    if not valid_regions:
        return None
    weight_sum = sum(weight_lookup[region] for region in valid_regions)
    if not weight_sum > 0:
        return None
    weighted_sum = sum(regional_values[region] * weight_lookup[region]
                       for region in valid_regions)
    return weighted_sum / weight_sum


def _aggregate_regional_records(records, weight_lookup):
    """Build Global records for every site listed in AGGREGATED_SITES.

    Records are grouped by site -> date -> variant -> ps (None for sites
    without 'ps') -> location, preserving first-seen order at every level.
    """
    site_has_ps = dict(AGGREGATED_SITES)
    grouped = {site: {} for site in site_has_ps}

    for record in records:
        if record['location'] not in weight_lookup:
            continue
        site = record['site']
        if site not in grouped:
            continue
        ps = record['ps'] if site_has_ps[site] else None
        by_location = (grouped[site]
                       .setdefault(record['date'], {})
                       .setdefault(record['variant'], {})
                       .setdefault(ps, {}))
        by_location[record['location']] = record['value']

    global_records = []
    for site, has_ps in AGGREGATED_SITES:
        for date, by_variant in grouped[site].items():
            for variant, by_ps in by_variant.items():
                for ps, regional_values in by_ps.items():
                    global_value = _weighted_mean(regional_values, weight_lookup)
                    if global_value is None:
                        continue
                    new_record = {
                        'location': 'Global',
                        'date': date,
                        'variant': variant,
                        'site': site,
                    }
                    if has_ps:
                        new_record['ps'] = ps
                    new_record['value'] = global_value
                    global_records.append(new_record)
    return global_records


def main(argv=None):
    """Read MLR results, append Global records, and write the result.

    Args:
        argv: Optional list of command-line arguments; defaults to
            ``sys.argv[1:]`` when None (normal script usage).
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument('--input-json', required=True,
                        help='Path to input MLR results JSON')
    parser.add_argument('--regional-weights', required=True,
                        help='Path to regional population weights TSV')
    parser.add_argument('--output-json', required=True,
                        help='Path to output MLR results JSON with Global added')
    args = parser.parse_args(argv)

    # Read the MLR results JSON
    with open(args.input_json, 'r') as f:
        mlr_data = json.load(f)

    locations = mlr_data['metadata']['location']
    weight_lookup = _load_weight_lookup(args.regional_weights, locations)

    print(f"Using population weights for {len(weight_lookup)} regions:")
    for region, weight in weight_lookup.items():
        print(f"  {region:<20} weight: {weight:.4f}")

    # Add "Global" to metadata locations if not present
    if "Global" not in locations:
        locations.append("Global")

    # Population-weighted aggregates for freq / raw_freq / freq_forecast
    new_global_records = _aggregate_regional_records(mlr_data['data'], weight_lookup)

    # Copy hierarchical GA values as Global GA
    for record in mlr_data['data']:
        if record['location'] == 'hierarchical' and record['site'] == 'ga':
            new_record = record.copy()
            new_record['location'] = 'Global'
            new_global_records.append(new_record)

    # Placeholder records for the remaining sites, one per (date, variant)
    dates = mlr_data['metadata']['dates']
    variants = mlr_data['metadata']['variants']
    for site, placeholder in PLACEHOLDER_SITES:
        for date in dates:
            for variant in variants:
                new_global_records.append({
                    'location': 'Global',
                    'date': date,
                    'variant': variant,
                    'site': site,
                    'value': placeholder,
                })

    # Add all new Global records to the data
    mlr_data['data'].extend(new_global_records)

    # Write output JSON
    with open(args.output_json, 'w') as f:
        json.dump(mlr_data, f, indent=2)

    print(f"\nAdded {len(new_global_records)} Global records to MLR results")
    print(f"Output saved to {args.output_json}")

    # Summary of what was added
    sites_added = defaultdict(int)
    for record in new_global_records:
        sites_added[record['site']] += 1

    print("\nRecords added by site:")
    for site, count in sorted(sites_added.items()):
        print(f"  {site:<20} {count:>6} records")


if __name__ == "__main__":
    main()
#!/usr/bin/env python3
# script by Claude Code
"""
Generate regional population weights by aggregating country populations
according to the geo_regions mapping from nextstrain/seasonal-flu.
"""

import pandas as pd
import sys

# Default inputs/outputs; each can be overridden through main()'s parameters.
GEO_REGIONS_URL = "https://raw.githubusercontent.com/nextstrain/seasonal-flu/master/config/geo_regions.tsv"
POPULATION_WEIGHTS_URL = "https://raw.githubusercontent.com/nextstrain/ncov/master/defaults/population_weights.tsv"
DEFAULT_OUTPUT_PATH = "config/regional_population_weights.tsv"

# Display names of the regions reported in the final presence check.
EXPECTED_REGIONS = ['Africa', 'Europe', 'North America', 'Oceania',
                    'South America', 'Southeast Asia', 'West Asia']

# Display names -> region identifiers as written in geo_regions.tsv.
# Names absent from this map are used unchanged.
REGION_DISPLAY_TO_ID = {
    'North America': 'NorthAmerica',
    'South America': 'SouthAmerica',
    'Southeast Asia': 'SoutheastAsia',
    'West Asia': 'WestAsia',
}


def _aggregate_by_region(geo_regions, population_weights):
    """Sum country weights per region.

    Both inputs must be two-column frames (country/region and country/weight,
    in that order); their columns are renamed positionally. Returns a tuple
    ``(regional_weights, missing_pop)`` where ``regional_weights`` has columns
    ``region`` and ``weight`` sorted by region, and ``missing_pop`` holds the
    countries that had no population entry (and were therefore excluded).
    """
    # Work on copies so the caller's frames keep their original column names.
    geo_regions = geo_regions.copy()
    population_weights = population_weights.copy()
    geo_regions.columns = ['country', 'region']
    population_weights.columns = ['country', 'weight']

    merged = pd.merge(geo_regions, population_weights, on='country', how='left')
    missing_pop = merged[merged['weight'].isna()]
    merged = merged.dropna(subset=['weight'])

    regional_weights = merged.groupby('region')['weight'].sum().reset_index()
    regional_weights.columns = ['region', 'weight']
    return regional_weights.sort_values('region'), missing_pop


def main(geo_regions_source=GEO_REGIONS_URL,
         population_source=POPULATION_WEIGHTS_URL,
         output_path=DEFAULT_OUTPUT_PATH):
    """Build the regional population weights TSV and print a summary.

    Args:
        geo_regions_source: URL or path of the country -> region TSV.
        population_source: URL or path of the country population TSV
            (lines starting with '#' are treated as comments).
        output_path: Where the aggregated ``region``/``weight`` TSV is written.
    """
    geo_regions = pd.read_csv(geo_regions_source, sep='\t')
    population_weights = pd.read_csv(population_source, sep='\t', comment='#')

    regional_weights, missing_pop = _aggregate_by_region(geo_regions, population_weights)

    # Check for countries without population data
    if not missing_pop.empty:
        print("Warning: Countries without population data:", file=sys.stderr)
        for _, row in missing_pop.iterrows():
            print(f"  {row['country']} ({row['region']})", file=sys.stderr)
        print(file=sys.stderr)

    # Save to TSV
    regional_weights.to_csv(output_path, sep='\t', index=False)
    print(f"Regional population weights saved to {output_path}")

    # Print summary statistics
    print("\nRegional Population Weights Summary:")
    print("-" * 50)
    total_pop = regional_weights['weight'].sum()
    for _, row in regional_weights.iterrows():
        percentage = (row['weight'] / total_pop) * 100
        # NOTE(review): dividing by 1000 yields millions only if the source
        # weights are expressed in thousands - confirm against the input file.
        pop_millions = row['weight'] / 1000
        print(f"{row['region']:<20} {pop_millions:>10.1f}M ({percentage:>5.1f}%)")
    print("-" * 50)
    print(f"{'Total':<20} {total_pop/1000:>10.1f}M (100.0%)")

    # Also check which regions from the results are present
    weight_by_region = dict(zip(regional_weights['region'], regional_weights['weight']))
    print("\nRegions in current workflow:")
    for region in EXPECTED_REGIONS:
        actual_region = REGION_DISPLAY_TO_ID.get(region, region)
        if actual_region in weight_by_region:
            weight = weight_by_region[actual_region]
            print(f"  ✓ {region:<20} (population: {weight/1000:.1f}M)")
        else:
            print(f"  ✗ {region:<20} (NOT FOUND in geo_regions)")


if __name__ == "__main__":
    main()