Add keep_gradient_through_clamps STE for corrector clamps #1322

Open
mcgibbon wants to merge 1 commit into main from feature/corrector-clamp-ste

The hard torch.clamp operations in the correctors run inside the training autograd path, and clamp has zero gradient wherever it saturates. A cell the network pushes out of range therefore gets no learning signal to come back, and the projection silently absorbs the error. This PR adds a straight-through estimator (STE) that keeps the exact projected (clamped) value in the forward pass while letting the gradient flow as identity in the backward pass, gated behind a new default-off keep_gradient_through_clamps flag on each corrector. Because forward values are unchanged when the flag is on, the change is cleanly A/B-able against a baseline: only the gradient path differs, and default-off preserves existing behavior.
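The difference between the two gradient paths can be seen on a three-element tensor. This is a minimal sketch, assuming only the formula quoted in the change list (x + (new_value - x).detach()); the body of the helper in the PR is not shown here and may differ in detail.

```python
import torch


def replace_value_keep_gradient(
    x: torch.Tensor, new_value: torch.Tensor
) -> torch.Tensor:
    # Forward: x + (new_value - x) evaluates to new_value
    # (up to floating-point rounding in general).
    # Backward: the detached term is a constant, so d(out)/dx is identity.
    return x + (new_value - x).detach()


x = torch.tensor([-2.0, 0.5, 3.0], requires_grad=True)

# Hard clamp: the out-of-range element saturates and gets zero gradient.
hard = torch.clamp(x, min=0.0)
hard.sum().backward()
hard_grad = x.grad.clone()

# STE: same forward values, but every element receives gradient.
x.grad = None
ste = replace_value_keep_gradient(x, torch.clamp(x, min=0.0))
ste.sum().backward()
ste_grad = x.grad.clone()

print(hard.detach(), hard_grad)
print(ste.detach(), ste_grad)
```

The first element is the case described above: under the hard clamp its gradient is zero, while under the STE it is one, so the optimizer can still move it back into range.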

Validated by a same-seed A/B on the selected 4° residual baseline: with the flag on, training is stable to convergence (best_val_loss neutral) and the predicted near-zero precipitation distribution moves toward target (the dry-cell over-prediction is roughly halved across all four climate evals).

Changes:

  • fme.core.corrector.utils.replace_value_keep_gradient: new STE helper, x + (new_value - x).detach(), whose forward value is new_value and whose backward pass is identity with respect to x.

  • fme.core.corrector.utils.force_positive / ForcePositive: optional keep_gradient that applies the precip clamp via the STE.

  • fme.core.corrector.ocean.SeaIceFractionConfig / SeaIceFractionCorrection: optional keep_gradient applying the STE to the 0-1 sea-ice-fraction clamp and the negative-ocean-fraction rebalance.

  • fme.core.corrector.atmosphere.AtmosphereCorrectorConfig / ocean.OceanCorrectorConfig: new default-off keep_gradient_through_clamps field wiring the above.

  • Tests added

  • If dependencies changed, "deps only" image rebuilt and "latest_deps_only_image.txt" file updated

Add replace_value_keep_gradient helper to corrector/utils.py: a
straight-through estimator that keeps the exact projected value in the
forward pass while letting gradient flow as identity in the backward pass.

Gate its use behind a new default-off boolean config field
keep_gradient_through_clamps, global to each corrector:
- AtmosphereCorrectorConfig: wraps force_positive (precip clamp, site 1).
- OceanCorrectorConfig: wraps force_positive and the SeaIceFractionConfig
  0-1 clamp (site 2) plus its negative-ocean-fraction rebalance (site 3).

Forward values are unchanged when the flag is on, so it is cleanly A/B-able
against a baseline; only the gradient path differs. Default off preserves
existing behavior.
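
As an illustration of the flag-gated wiring at site 1, the sketch below applies a positivity clamp to named fields in a dictionary of tensors. The function name force_positive_sketch, its signature, and the "precip" key are hypothetical stand-ins; they are not the actual fme.core.corrector.utils.force_positive API.

```python
from typing import Dict, List, Tuple

import torch


def force_positive_sketch(
    data: Dict[str, torch.Tensor],
    names: List[str],
    keep_gradient: bool = False,
) -> Dict[str, torch.Tensor]:
    # Hypothetical stand-in for the corrector's positivity clamp.
    out = dict(data)
    for name in names:
        x = data[name]
        clamped = torch.clamp(x, min=0.0)
        if keep_gradient:
            # STE: forward value is the clamped value, backward is identity.
            clamped = x + (clamped - x).detach()
        out[name] = clamped
    return out


def run(keep_gradient: bool) -> Tuple[torch.Tensor, torch.Tensor]:
    # Same input for both arms, mimicking a same-seed A/B comparison.
    precip = torch.tensor([-1.0, 0.25, 2.0], requires_grad=True)
    out = force_positive_sketch({"precip": precip}, ["precip"], keep_gradient)
    out["precip"].sum().backward()
    return out["precip"].detach(), precip.grad


baseline_value, baseline_grad = run(keep_gradient=False)
flag_on_value, flag_on_grad = run(keep_gradient=True)
```

With the flag off the function reduces to a plain clamp, which is why default-off preserves existing behavior; with it on, only the gradient changes.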