
Add per-sample spectral whitening option to the energy-score loss #1303

Draft
mcgibbon wants to merge 2 commits into main from feature/energy-score-spectral-whitening-pr


@mcgibbon

Contributor

Small-scale (high-l) spectral power converges very slowly in SFNO training because the loss lives in a red (grid) spectrum: high-l modes carry tiny amplitude, so they contribute little to the energy score, and their dedicated dhconv filter weights see proportionally small, low-SNR gradients. A per-l gradient probe on a trained checkpoint confirmed an ~8× monotone decline in high-l gradient magnitude. The energy score, although computed per (l,m) mode, uses uniform mode_weights, so it does not correct this.

This PR adds an opt-in spectral_whitening='per_sample' mode to EnergyScoreLoss that reweights the per-(l,m) energy score by the inverse per-degree RMS amplitude of each target sample. This flattens each target's angular power spectrum so high-l modes are no longer starved, raising their gradient SNR. The factor is computed from the detached target coefficients, so it adds no new gradient path. It is also magnitude-preserving: it is rescaled per (sample, channel) so that the overall energy-score magnitude, and with it the meaning of energy_score_weight, is unchanged; only the per-scale balance shifts. A white-spectrum target yields a uniform factor (no-op). The default 'none' is bit-for-bit backward compatible.
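To make the reweighting concrete, here is a minimal plain-Python sketch of a whitening factor for a single (sample, channel) pair. It is an illustration under stated assumptions, not the code in this PR: the function name `whitening_factor`, the `coeffs[l][m]` layout, and the default `eps_frac` are hypothetical, the PR operates on batched torch tensors, and the specific magnitude-preserving normalization used here (preserving total target power) is only one plausible choice.

```python
import math


def whitening_factor(coeffs, eps_frac=1e-3):
    """Per-degree whitening factor for one (sample, channel) pair.

    coeffs[l] holds the complex target coefficients for degree l,
    one entry per valid order m <= l (hypothetical layout).
    """
    # Per-degree RMS amplitude over the valid orders of each degree.
    rms = [math.sqrt(sum(abs(a) ** 2 for a in row) / len(row)) for row in coeffs]
    mean_rms = sum(rms) / len(rms)
    if mean_rms == 0.0:
        # All-zero target: nothing to whiten, fall back to a uniform factor.
        return [1.0] * len(rms)

    # Floor the amplitude at a fraction of the mean degree amplitude so that
    # near-empty degrees cannot produce unbounded weights.
    floor = eps_frac * mean_rms
    inverse = [1.0 / max(r, floor) for r in rms]

    # Assumed normalization: rescale so the reweighted target carries the
    # same total power as the original target.
    power = sum(len(row) * r ** 2 for row, r in zip(coeffs, rms))
    weighted = sum(len(row) * (w * r) ** 2 for row, w, r in zip(coeffs, inverse, rms))
    scale = math.sqrt(power / weighted)
    return [scale * w for w in inverse]


# A red target whose amplitude halves with each degree gets larger
# factors at higher degrees.
red_target = [[complex(2.0 ** -l, 0.0)] * (l + 1) for l in range(4)]
red_factor = whitening_factor(red_target)
```

In a torch implementation the same quantities would be computed on detached target coefficients so that the factor acts as a fixed weight and contributes no gradient of its own.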

An A/B validation run (a perturbation of a fg16 sr0.125 residual SFNO baseline, identical except whitening enabled) showed the targeted effect with neutral skill:

  • Small-scale spectral fidelity (the target): large win. Train smallest_scale_norm_bias reaches the baseline's final (epoch-120) value by epoch ~11 and finishes ~2× lower (0.031 vs 0.061). Inference-time spectral bias also improves substantially (e.g. 10year h500 1.23 → 0.40, TMP850 0.13 → −0.006).
  • Probabilistic skill (CRPS): roughly neutral. Validation 1-step CRPS is ~1% better on every field (geomean 0.989); day-5 CRPS is tied on one weather period (geomean 1.001) and ~2% worse on another (1.019), within single-seed noise. (The training loss value is not comparable across arms because whitening redefines the energy-score term; CRPS is the proper definition-stable comparison.)
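For readers unfamiliar with the geomean summaries quoted above, the following sketch shows how a geometric mean of per-field CRPS ratios (whitening arm divided by baseline arm) can be computed. The helper name `geomean` and the field names and ratio values in the example are made up for illustration; they are not results from this PR.

```python
import math


def geomean(ratios):
    """Geometric mean of strictly positive ratios."""
    if not ratios or any(r <= 0 for r in ratios):
        raise ValueError("ratios must be a non-empty sequence of positive numbers")
    return math.exp(sum(math.log(r) for r in ratios) / len(ratios))


# Hypothetical per-field CRPS ratios (whitening / baseline); a value below 1
# means the whitening arm has lower, i.e. better, CRPS on that field.
example_ratios = {"field_a": 0.98, "field_b": 1.00, "field_c": 0.99}
summary = geomean(list(example_ratios.values()))
```

A geometric mean is a natural summary here because it treats a 2% improvement on one field and a 2% degradation on another as nearly cancelling, regardless of each field's units.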

Changes:

  • fme.core.loss.EnergyScoreLoss — add spectral_whitening / whitening_eps_frac args and _spectral_whitening_factor (per-(sample, channel, l), detached, magnitude-preserving, floored at whitening_eps_frac of the per-sample mean degree amplitude); applied as a multiplicative reweight after mode_weights.

  • fme.core.loss.EnsembleLoss — thread energy_score_whitening / energy_score_whitening_eps_frac kwargs through to EnergyScoreLoss.

  • fme.core.test_loss — tests for the no-op default, white-target invariance, small-scale boost, magnitude preservation, and config wiring.

  • Tests added

  • If dependencies changed, "deps only" image rebuilt and "latest_deps_only_image.txt" file updated

Resolves # (delete if none)

… loss

EnergyScoreLoss gains a spectral_whitening='per_sample' mode that reweights
the per-(l,m) energy score by the inverse per-degree RMS amplitude of each
target sample (computed over valid orders m<=l, detached, magnitude-preserving
so energy_score_weight keeps its meaning). This flattens each target sample's
spectrum so high-l (small-scale) modes are no longer starved by the red
spectrum, raising their gradient SNR. Default 'none' is a no-op (backward
compatible). Threaded through EnsembleLoss via energy_score_whitening /
energy_score_whitening_eps_frac kwargs. Tests cover the no-op, white-target
invariance, small-scale boost, magnitude preservation, and config wiring.
Comment thread fme/core/loss.py Outdated
def __init__(
    self,
    sht: Callable[[torch.Tensor], torch.Tensor],
    spectral_whitening: str = "none",

Contributor Author


Make this a Literal[(options)] instead of str

Comment thread fme/core/loss.py Outdated
finite_difference_crps_weight: float = 0.0,
finite_difference_crps_levels: int = 1,
almost_fair_crps_alpha: float = 1.0,
energy_score_whitening: str = "none",

Contributor Author


Similar type comment.
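One way the two typing suggestions above could look is sketched below. The option strings 'none' and 'per_sample' come from the PR description; the alias name `SpectralWhitening` and the helper `validate_spectral_whitening` are hypothetical and are not part of fme.

```python
from typing import Literal, cast, get_args

# Allowed whitening modes, taken from the PR description.
SpectralWhitening = Literal["none", "per_sample"]


def validate_spectral_whitening(value: str) -> SpectralWhitening:
    """Check a config string at runtime and narrow it to the Literal type.

    Literal annotations are enforced only by static type checkers, so a
    runtime check is still useful for values loaded from configuration.
    """
    allowed = get_args(SpectralWhitening)
    if value not in allowed:
        raise ValueError(f"spectral_whitening must be one of {allowed}, got {value!r}")
    return cast(SpectralWhitening, value)


mode = validate_spectral_whitening("per_sample")
```

Sharing a single alias between the EnergyScoreLoss and EnsembleLoss arguments would keep the two sets of allowed values from drifting apart.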

