Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 29 additions & 1 deletion Snakefile
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ rule add_colors_to_mlr_model:
model="results/{lineage}/{geo_resolution}/mlr/initial_MLR_results.json",
auspice_config="results/{lineage}/auspice_config.json",
output:
model="results/{lineage}/{geo_resolution}/mlr/MLR_results.json",
model="results/{lineage}/{geo_resolution}/mlr/MLR_results_with_colors.json",
params:
coloring_field=config["coloring_field"],
shell:
Expand All @@ -189,6 +189,34 @@ rule add_colors_to_mlr_model:
--output {output.model:q}
"""

rule add_global_to_regional_mlr:
    # Produce the final region-level MLR results. When `aggregate_regions` is
    # enabled in the config and the lineage is one of config["lineages"], a
    # population-weighted "Global" location is added by
    # scripts/add_global_to_mlr_results.py; otherwise the coloured results are
    # copied through unchanged so downstream rules always find MLR_results.json.
    input:
        model="results/{lineage}/region/mlr/MLR_results_with_colors.json",
        regional_weights="config/regional_population_weights.tsv"
    output:
        model="results/{lineage}/region/mlr/MLR_results.json"
    params:
        aggregate_regions=config.get("aggregate_regions", False)
    run:
        # Paths are interpolated with `:q` so that they are shell-quoted,
        # matching the quoting used by add_colors_to_mlr_model.
        if params.aggregate_regions and wildcards.lineage in config["lineages"]:
            shell("""
                python scripts/add_global_to_mlr_results.py \
                    --input-json {input.model:q} \
                    --regional-weights {input.regional_weights:q} \
                    --output-json {output.model:q}
            """)
        else:
            shell("cp {input.model:q} {output.model:q}")

# Country-level results get no "Global" aggregation: the coloured results are
# copied to the final MLR_results.json name expected by downstream rules.
# Only the `country` geo_resolution is handled by this rule.
rule finalize_country_mlr:
    input:
        model="results/{lineage}/country/mlr/MLR_results_with_colors.json",
    output:
        model="results/{lineage}/country/mlr/MLR_results.json",
    shell:
        # `:q` shell-quotes the paths, matching add_colors_to_mlr_model.
        "cp {input.model:q} {output.model:q}"

rule parse_mlr_json:
input:
model="results/{lineage}/{geo_resolution}/mlr/MLR_results.json",
Expand Down
2 changes: 2 additions & 0 deletions config/defaults.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,5 @@ prepare_data:
haplotype_variant_column: "subclade"
variant: "haplotype"
coloring_field: "emerging_haplotype"

aggregate_regions: true # Set to true to add Global region to regional MLR results
11 changes: 11 additions & 0 deletions config/regional_population_weights.tsv
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
region weight
Africa 977096.0
China 1445901.0
Europe 515136.0
JapanKorea 124370.0
NorthAmerica 579891.0
Oceania 27915.0
SouthAmerica 431886.0
SouthAsia 1640015.0
SoutheastAsia 688728.0
WestAsia 826100.0
214 changes: 214 additions & 0 deletions scripts/add_global_to_mlr_results.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,214 @@
#!/usr/bin/env python3
# script by Claude Code
"""
Add Global region to MLR results by aggregating regional frequencies using population weights
and copying hierarchical GA values as Global GA values.
"""

import argparse
import json
import pandas as pd
import numpy as np
from collections import defaultdict

def _load_weight_lookup(weights_path, available_regions):
    """Return ``{MLR region name: normalized weight}`` for the usable regions.

    The weights TSV is expected to have ``region`` and ``weight`` columns.
    Region keys from the TSV are translated to the display names used in the
    MLR JSON, rows whose translated name is not in ``available_regions`` are
    dropped, and the remaining weights are rescaled to sum to one.
    """
    # The weights file uses compact keys; the MLR JSON uses spaced names.
    region_name_map = {
        'Africa': 'Africa',
        'Europe': 'Europe',
        'NorthAmerica': 'North America',
        'SouthAmerica': 'South America',
        'SoutheastAsia': 'Southeast Asia',
        'WestAsia': 'West Asia',
        'Oceania': 'Oceania',
        'China': 'China',
        'JapanKorea': 'Japan Korea',
        'SouthAsia': 'South Asia'
    }

    weights_df = pd.read_csv(weights_path, sep='\t')
    weights_df['mapped_region'] = weights_df['region'].map(region_name_map)
    weights_df = weights_df[weights_df['mapped_region'].isin(available_regions)]

    # Computed as a separate Series rather than assigned onto the filtered
    # frame, which would be a write to a slice of the original DataFrame.
    normalized = weights_df['weight'] / weights_df['weight'].sum()
    return {region: float(weight)
            for region, weight in zip(weights_df['mapped_region'], normalized)}


def _weighted_mean(regional_values, weight_lookup):
    """Return the weight-averaged value over regions, or ``None`` if undefined.

    Regions without a weight or with a ``None`` value are skipped, and the
    average is renormalized over the regions that remain.
    """
    usable = [region for region, value in regional_values.items()
              if region in weight_lookup and value is not None]
    if not usable:
        return None
    weight_sum = sum(weight_lookup[region] for region in usable)
    if weight_sum > 0:
        weighted_sum = sum(regional_values[region] * weight_lookup[region]
                           for region in usable)
        return weighted_sum / weight_sum
    return None


def _aggregate_site(records, site, weight_lookup, has_ps):
    """Build Global records for one ``site`` by weighting the regional records.

    Records are grouped by date, then variant, then (when ``has_ps`` is true)
    the ``ps`` field; each group yields at most one Global record. Sites
    without a ``ps`` field produce records without a ``ps`` key.
    """
    # Nested grouping keeps the output ordered by first-seen date, then
    # variant, then ps.
    grouped = defaultdict(lambda: defaultdict(lambda: defaultdict(dict)))
    for record in records:
        if record['location'] not in weight_lookup or record['site'] != site:
            continue
        ps = record['ps'] if has_ps else None
        grouped[record['date']][record['variant']][ps][record['location']] = record['value']

    aggregated = []
    for date, by_variant in grouped.items():
        for variant, by_ps in by_variant.items():
            for ps, regional_values in by_ps.items():
                global_value = _weighted_mean(regional_values, weight_lookup)
                if global_value is None:
                    continue
                new_record = {
                    'location': 'Global',
                    'date': date,
                    'variant': variant,
                    'site': site,
                }
                if has_ps:
                    new_record['ps'] = ps
                new_record['value'] = global_value
                aggregated.append(new_record)
    return aggregated


def main():
    """Add a population-weighted "Global" location to an MLR results JSON.

    Command-line arguments:
        --input-json        MLR results JSON to read.
        --regional-weights  TSV with ``region`` and ``weight`` columns.
        --output-json       Destination for the JSON with Global added.

    Global ``freq``, ``raw_freq`` and ``freq_forecast`` values are weighted
    averages of the regional values. Global ``ga`` values are copies of the
    ``hierarchical`` location's ``ga`` records. ``smoothed_raw_freq`` and
    ``agg_counts`` are written as zero placeholders for every date/variant
    listed in the metadata.

    Any Global records already present in the input are replaced, so running
    the script on its own output does not duplicate records.
    """
    import sys  # local import: only needed for the warning below

    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument('--input-json', required=True,
                        help='Path to input MLR results JSON')
    parser.add_argument('--regional-weights', required=True,
                        help='Path to regional population weights TSV')
    parser.add_argument('--output-json', required=True,
                        help='Path to output MLR results JSON with Global added')
    args = parser.parse_args()

    # Read the MLR results JSON
    with open(args.input_json, 'r') as f:
        mlr_data = json.load(f)

    locations = mlr_data['metadata']['location']
    # 'hierarchical' and 'Global' are not regions and never carry a weight.
    available_regions = [loc for loc in locations
                         if loc not in ('hierarchical', 'Global')]

    weight_lookup = _load_weight_lookup(args.regional_weights, available_regions)

    print(f"Using population weights for {len(weight_lookup)} regions:")
    for region, weight in weight_lookup.items():
        print(f" {region:<20} weight: {weight:.4f}")

    # Regions without a weight cannot contribute to the Global average;
    # say so instead of dropping them silently.
    unweighted = [region for region in available_regions
                  if region not in weight_lookup]
    if unweighted:
        print("Warning: no population weight for: " + ", ".join(unweighted)
              + " (excluded from Global aggregation)", file=sys.stderr)

    # Add "Global" to metadata locations if not present
    if "Global" not in locations:
        locations.append("Global")

    # Discard stale Global records so the data list, like the metadata list
    # above, never ends up with duplicates.
    existing_records = [record for record in mlr_data['data']
                        if record['location'] != 'Global']

    new_global_records = []

    # Weighted aggregation of the frequency-like sites (raw_freq has no ps).
    for site, has_ps in (('freq', True), ('raw_freq', False), ('freq_forecast', True)):
        new_global_records.extend(
            _aggregate_site(existing_records, site, weight_lookup, has_ps))

    # Copy hierarchical GA values as Global GA
    for record in existing_records:
        if record['location'] == 'hierarchical' and record['site'] == 'ga':
            new_global_records.append({**record, 'location': 'Global'})

    # NOTE: smoothed_raw_freq and agg_counts are not aggregated; zero
    # placeholders are emitted for every date/variant pair in the metadata.
    dates = mlr_data['metadata']['dates']
    variants = mlr_data['metadata']['variants']
    for site, placeholder in (('smoothed_raw_freq', 0.0), ('agg_counts', 0)):
        for date in dates:
            for variant in variants:
                new_global_records.append({
                    'location': 'Global',
                    'date': date,
                    'variant': variant,
                    'site': site,
                    'value': placeholder
                })

    # Add all new Global records to the data
    mlr_data['data'] = existing_records + new_global_records

    # Write output JSON
    with open(args.output_json, 'w') as f:
        json.dump(mlr_data, f, indent=2)

    print(f"\nAdded {len(new_global_records)} Global records to MLR results")
    print(f"Output saved to {args.output_json}")

    # Summary of what was added
    sites_added = defaultdict(int)
    for record in new_global_records:
        sites_added[record['site']] += 1

    print("\nRecords added by site:")
    for site, count in sorted(sites_added.items()):
        print(f" {site:<20} {count:>6} records")


if __name__ == "__main__":
    main()
84 changes: 84 additions & 0 deletions scripts/generate_regional_population_weights.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
#!/usr/bin/env python3
# script by Claude Code
"""
Generate regional population weights by aggregating country populations
according to the geo_regions mapping from nextstrain/seasonal-flu.
"""

import pandas as pd
import sys

def main():
    """Regenerate ``config/regional_population_weights.tsv``.

    Downloads the country-to-region table from nextstrain/seasonal-flu and the
    country population weights from nextstrain/ncov, sums the weights per
    region, writes the result as a TSV, and prints a summary. Countries with
    no population entry are reported on stderr and left out of the sums.
    """
    regions_url = "https://raw.githubusercontent.com/nextstrain/seasonal-flu/master/config/geo_regions.tsv"
    population_url = "https://raw.githubusercontent.com/nextstrain/ncov/master/defaults/population_weights.tsv"

    # Fetch both upstream tables
    country_regions = pd.read_csv(regions_url, sep='\t')
    country_population = pd.read_csv(population_url, sep='\t', comment='#')

    # Give both tables a shared 'country' column to join on
    country_regions.columns = ['country', 'region']
    country_population.columns = ['country', 'weight']

    # Left join keeps every country of the region table, matched or not
    joined = country_regions.merge(country_population, on='country', how='left')

    # Report the countries the population table does not cover
    unmatched = joined[joined['weight'].isna()]
    if not unmatched.empty:
        print("Warning: Countries without population data:", file=sys.stderr)
        for country, region in zip(unmatched['country'], unmatched['region']):
            print(f" {country} ({region})", file=sys.stderr)
        print(file=sys.stderr)

    # Only countries with a population contribute to the regional totals
    joined = joined.dropna(subset=['weight'])

    # One row per region, ordered by region name
    regional_weights = joined.groupby('region')['weight'].sum().reset_index()
    regional_weights.columns = ['region', 'weight']
    regional_weights = regional_weights.sort_values('region')

    # Write the TSV
    output_path = "config/regional_population_weights.tsv"
    regional_weights.to_csv(output_path, sep='\t', index=False)
    print(f"Regional population weights saved to {output_path}")

    # Summary table
    divider = "-" * 50
    total_pop = regional_weights['weight'].sum()
    print("\nRegional Population Weights Summary:")
    print(divider)
    for region, weight in zip(regional_weights['region'], regional_weights['weight']):
        percentage = (weight / total_pop) * 100
        # Weight divided by 1000 is what the summary labels as millions
        pop_millions = weight / 1000
        print(f"{region:<20} {pop_millions:>10.1f}M ({percentage:>5.1f}%)")
    print(divider)
    print(f"{'Total':<20} {total_pop/1000:>10.1f}M (100.0%)")

    # Check the regions listed for the workflow against the generated table
    expected_regions = ['Africa', 'Europe', 'North America', 'Oceania',
                        'South America', 'Southeast Asia', 'West Asia']
    # Display name -> key used in geo_regions.tsv (identity when absent)
    region_map = {
        'North America': 'NorthAmerica',
        'South America': 'SouthAmerica',
        'Southeast Asia': 'SoutheastAsia',
        'West Asia': 'WestAsia'
    }
    # Region names are unique after the groupby, so a dict lookup is enough
    weight_by_region = dict(zip(regional_weights['region'], regional_weights['weight']))

    print("\nRegions in current workflow:")
    for region in expected_regions:
        actual_region = region_map.get(region, region)
        if actual_region not in weight_by_region:
            print(f" ✗ {region:<20} (NOT FOUND in geo_regions)")
            continue
        weight = weight_by_region[actual_region]
        print(f" ✓ {region:<20} (population: {weight/1000:.1f}M)")


if __name__ == "__main__":
    main()
4 changes: 4 additions & 0 deletions scripts/plot-ga.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ def plot_ga(input_file, virus, color_file, out_var, out_loc, loc_lst, var_lst, p

df = df[df["location"].isin(loc_filter)]
df = df[df["variant"].isin(var_filter)]

# Exclude "Global" from GA plots
df = df[df["location"] != "Global"]

base_chart = alt.Chart(df)

# Parse pivot from file
Expand Down