Add keep_gradient_through_clamps STE for corrector clamps#1322
Open
mcgibbon wants to merge 1 commit into
Open
Conversation
Add replace_value_keep_gradient helper to corrector/utils.py: a straight-through estimator that keeps the exact projected value in the forward pass while letting gradient flow as identity in the backward pass. Gate its use behind a new default-off boolean config field keep_gradient_through_clamps, global to each corrector: - AtmosphereCorrectorConfig: wraps force_positive (precip clamp, site 1). - OceanCorrectorConfig: wraps force_positive and the SeaIceFractionConfig 0-1 clamp (site 2) plus its negative-ocean-fraction rebalance (site 3). Forward values are unchanged when the flag is on, so it is cleanly A/B-able against a baseline; only the gradient path differs. Default off preserves existing behavior.
2 tasks
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
The hard
torch.clampoperations in the correctors run inside the training autograd path, andclamphas zero gradient in its saturated tail — so a cell the network pushes out of range gets no learning signal to come back; the projection silently absorbs the error. This adds a straight-through estimator (STE) that keeps the exact projected (clamped) value in the forward pass while letting gradient flow as identity in the backward pass, gated behind a new default-offkeep_gradient_through_clampsflag on each corrector. Because forward values are unchanged when the flag is on, it is cleanly A/B-able against a baseline; only the gradient path differs, and default-off preserves existing behavior.Validated by a same-seed A/B on the selected 4° residual baseline: with the flag on, training is stable to convergence (best_val_loss neutral) and the predicted near-zero precipitation distribution moves toward target (the dry-cell over-prediction is roughly halved across all four climate evals).
Changes:
fme.core.corrector.utils.replace_value_keep_gradient: new STE helper (x + (new_value - x).detach()— forward isnew_value, backward is identity).fme.core.corrector.utils.force_positive/ForcePositive: optionalkeep_gradientthat applies the precip clamp via the STE.fme.core.corrector.ocean.SeaIceFractionConfig/SeaIceFractionCorrection: optionalkeep_gradientapplying the STE to the 0-1 sea-ice-fraction clamp and the negative-ocean-fraction rebalance.fme.core.corrector.atmosphere.AtmosphereCorrectorConfig/ocean.OceanCorrectorConfig: new default-offkeep_gradient_through_clampsfield wiring the above.Tests added
If dependencies changed, "deps only" image rebuilt and "latest_deps_only_image.txt" file updated