What to build
Two corrector training features, unified behind a single StepOutputLoss object (per Jeremy's suggestion): a Loss-like object that takes a StepOutput and folds in both the pre-corrector loss target and the correction-magnitude penalty, jointly configurable. Landed as two commits in one PR. This is PR 3 of a 3-PR split of #1218 and #1222, depending only on #1271.
Both features operate in the stepper's loss-normalized space (get_loss_normalizer), consistent with how StepLoss already normalizes, so there is no normalizer mismatch between the main loss and the penalty.
StepOutputLoss (in fme/core/loss.py, alongside StepLoss)
Wraps the existing StepLoss (the main loss, unchanged) plus an optional regularization WeightedMappingLoss, a reg weight, and an optional pre-corrector exclusion matcher.
Call signature: __call__(step_output, target_step, step, data_mask) -> StepOutputLossResult, unfolding the ensemble dimension internally (n_ensemble is fixed at build time).
StepOutputLossResult exposes .main (a LossOutput), .regularization (a Tensor | None), and .total() == main.total() + regularization.
Exposes needs_uncorrected_grad: bool (true iff either feature is configured), used by the stepper to decide detaching.
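A minimal sketch of the result object and the needs_uncorrected_grad rule, assuming plain floats stand in for tensors. LossOutputStandIn is a hypothetical stand-in for LossOutput whose total() is assumed to be a plain sum of per-channel losses; the real reduction may differ.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Optional


@dataclass
class LossOutputStandIn:
    """Hypothetical stand-in for LossOutput: per-channel losses and a total()."""

    per_channel: dict[str, float]

    def total(self) -> float:
        # Assumption: a plain sum over channels.
        return sum(self.per_channel.values())


@dataclass
class StepOutputLossResult:
    main: LossOutputStandIn
    regularization: Optional[float] = None  # a Tensor | None in the real design

    def total(self) -> float:
        # main.total() + regularization, treating an absent penalty as zero.
        reg = 0.0 if self.regularization is None else self.regularization
        return self.main.total() + reg


def needs_uncorrected_grad(precorrector_on: bool, regularization_on: bool) -> bool:
    # True iff either corrector feature is configured.
    return precorrector_on or regularization_on
```

Under this sketch the pass-through case (no corrector_loss) yields a result whose total() equals main.total(), and needs_uncorrected_grad is False, so the stepper keeps detaching.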
Commit 1 — pre-corrector optimization (from #1222)
Realized inside StepOutputLoss.__call__: for every corrector-modified variable not matched by the exclusion matcher, the main-loss target replaces the corrected value with step_output.uncorrected; the call then defers to the wrapped StepLoss.
exclude_names_and_prefixes uses the shared name-and-prefix matcher consolidated with the spatial-masking convention (bare name matches 2D + all 3D levels, trailing-underscore prefix matches all levels, explicit name_<level> matches exactly).
The exclusion matcher and its warn-once validation of unmatched entries live on the loss object, since it is what sees step_output.uncorrected.keys() each step.
Rollout state and returned predictions always use fully corrected outputs; only the loss target changes.
Adds a new corrector_enabled: bool abstract property on StepABC, implemented by every step type. The TrainStepper raises at build time if corrector_loss is configured but corrector_enabled is False (a config mistake that would otherwise silently fall back to standard training with no corrector signal). A warn-once check on the first step covers the case where a corrector runs but uncorrected comes back empty.
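A minimal sketch of the Commit 1 mechanics, assuming plain floats stand in for tensors and variables are dicts keyed by name. matches_entry, swap_in_uncorrected, and unmatched_entries are hypothetical names, not the fme API. The matcher follows the convention stated above; the trailing-underscore case is assumed to be a plain string-prefix check.

```python
from __future__ import annotations

import re
from typing import Iterable, Mapping

# A trailing "_<integer>" level suffix, e.g. "ta_3" -> level 3.
_LEVEL_SUFFIX = re.compile(r"_(\d+)$")


def matches_entry(name: str, entry: str) -> bool:
    """Hypothetical matcher for one exclude_names_and_prefixes entry."""
    if entry.endswith("_"):
        # Trailing-underscore prefix: "ta_" matches "ta_0", "ta_1", ...
        # but not the bare name "ta".
        return name.startswith(entry)
    if name == entry:
        # Exact hit: a 2D variable, or an explicit name_<level> entry.
        return True
    # Bare name: additionally matches every 3D level "<entry>_<int>".
    m = _LEVEL_SUFFIX.search(name)
    return m is not None and name[: m.start()] == entry


def swap_in_uncorrected(
    corrected: Mapping[str, float],
    uncorrected: Mapping[str, float],
    exclude: Iterable[str],
) -> dict[str, float]:
    """Use pre-correction values for every corrector-modified, non-excluded variable."""
    exclude = list(exclude)
    out = dict(corrected)
    for name, value in uncorrected.items():
        if not any(matches_entry(name, e) for e in exclude):
            out[name] = value
    return out


def unmatched_entries(exclude: Iterable[str], modified_names: Iterable[str]) -> list[str]:
    """Entries that match no corrector-modified variable (input to a warn-once)."""
    names = list(modified_names)
    return [e for e in exclude if not any(matches_entry(n, e) for n in names)]
```

In this sketch, unmatched_entries would feed the first-step warn-once, so a typo such as "tpyo" surfaces instead of silently doing nothing. Unmodified variables (here "u") never appear in uncorrected and keep their corrected values.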
Commit 2 — corrector regularization (from #1218, single-module only)
Derives corrections as output − uncorrected over the sparse corrected keys (subsuming #1218's separate dense corrections payload, so no extra field on StepOutput is needed) and penalizes them against a zero baseline.
Penalty is the mean over corrected channels only (consistent with PR2's channel_mean), realized via the existing LossOutput active-channel masking. ⚠️ This differs from #1218's all-channel dilution by a factor of n_corrected / n_all, so weights tuned in exp/2026-06-03-corrector-regularization must be rescaled; call this out in the PR body.
The reg term carries no per-step decay (only weight scales it, matching #1218); document the asymmetry with the main loss's sqrt_loss_step_decay in the config docstring.
Per-step (corrector_regularization_step_{i}) and epoch-aggregated (corrector_regularization) metrics; the trainer's standard batch_ prefix yields batch_corrector_regularization with no extra work.
This commit credits Co-authored-by: Jeremy McGibbon <jeremy.mcgibbon@gmail.com>.
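A worked sketch of the penalty arithmetic, assuming fields are already in loss-normalized space, plain lists of floats stand in for tensors, and the per-channel loss is MSE against the zero baseline (the real penalty uses whatever loss the config wraps). regularization_penalty and its n_all_channels switch are hypothetical, included only to show the n_corrected / n_all factor.

```python
from __future__ import annotations

from typing import Mapping, Optional, Sequence


def corrections(
    output: Mapping[str, Sequence[float]],
    uncorrected: Mapping[str, Sequence[float]],
) -> dict[str, list[float]]:
    # Sparse: only variables the corrector modified appear in `uncorrected`.
    return {k: [o - u for o, u in zip(output[k], uncorrected[k])] for k in uncorrected}


def mse_to_zero(values: Sequence[float]) -> float:
    # Loss of one channel's corrections against a zero baseline.
    return sum(v * v for v in values) / len(values)


def regularization_penalty(
    output: Mapping[str, Sequence[float]],
    uncorrected: Mapping[str, Sequence[float]],
    weight: float,
    n_all_channels: Optional[int] = None,
) -> float:
    per_channel = [mse_to_zero(c) for c in corrections(output, uncorrected).values()]
    if not per_channel:
        return 0.0  # nothing was corrected
    if n_all_channels is None:
        # This PR: mean over corrected channels only, scaled by weight (no step decay).
        return weight * sum(per_channel) / len(per_channel)
    # #1218-style dilution over all channels, for comparison only.
    return weight * sum(per_channel) / n_all_channels
```

With two of four channels corrected, the corrected-only penalty is 2.25 and the diluted one is 1.125, i.e. smaller by n_corrected / n_all = 2 / 4, which is the rescaling the ⚠️ note refers to.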
Config (single CorrectorLossConfig, in new fme/core/corrector/loss.py)
Holds precorrector_optimization: PreCorrectorOptimizationConfig | None and regularization: CorrectorRegularizationConfig | None; its __post_init__ errors if both are None.
One corrector_loss: CorrectorLossConfig | None field on the train stepper config replaces the two sibling fields from the source PRs.
CorrectorRegularizationConfig.__post_init__ rejects ensemble losses, NaN losses, and global-mean loss types.
Combining the two features is first-class (one config carries both); no validation forbids it, and the interaction is documented in the config docstrings. Validating the combined regime is out of scope; the combination is merely permitted.
The import is one-way (corrector/loss.py → loss.py): the config's build() assembles the primitives StepOutputLoss takes, so StepOutputLoss itself imports no config.
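A minimal sketch of the config shape and its __post_init__ checks, assuming a simplified loss config with a string type field. LossConfigSketch and the labels in REJECTED_LOSS_TYPES are hypothetical placeholders for fme's real loss configs and its ensemble, NaN, and global-mean loss types; build() is omitted.

```python
from __future__ import annotations

from dataclasses import dataclass, field
from typing import List, Optional

# Hypothetical labels standing in for the rejected loss types.
REJECTED_LOSS_TYPES = frozenset({"ensemble", "nan", "global_mean"})


@dataclass
class LossConfigSketch:
    type: str = "mse"


@dataclass
class CorrectorRegularizationConfig:
    loss: LossConfigSketch = field(default_factory=LossConfigSketch)
    weight: float = 1.0

    def __post_init__(self) -> None:
        if self.loss.type in REJECTED_LOSS_TYPES:
            raise ValueError(
                f"loss type {self.loss.type!r} is not supported for corrector regularization"
            )


@dataclass
class PreCorrectorOptimizationConfig:
    exclude_names_and_prefixes: List[str] = field(default_factory=list)


@dataclass
class CorrectorLossConfig:
    precorrector_optimization: Optional[PreCorrectorOptimizationConfig] = None
    regularization: Optional[CorrectorRegularizationConfig] = None

    def __post_init__(self) -> None:
        # An empty corrector_loss is a config mistake; omit the field instead.
        if self.precorrector_optimization is None and self.regularization is None:
            raise ValueError(
                "CorrectorLossConfig needs precorrector_optimization, regularization, or both"
            )
```

Setting both fields is accepted, matching the "combination is merely permitted" stance above.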
TrainStepper integration
Always builds one StepOutputLoss (a pass-through to StepLoss when corrector_loss is None), so _accumulate_loss has a single uniform call site with no per-feature branching.
Accumulates result.total() once per optimized step, which structurally eliminates the double-backward-under-gradient-accumulation error (no separate accumulate_loss call for the penalty to backward a freed graph).
Passes detach_uncorrected = not loss_obj.needs_uncorrected_grad into the predict generator. Detaching the pre-correction values would zero the (C′ − I) term of the correction gradient and silently corrupt both features.
Metrics preserved: loss_step_{i} from .main, corrector_regularization_step_{i} from .regularization, epoch-aggregated corrector_regularization, per-channel losses from .main.
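One way to see why the detach_uncorrected rule matters is a toy finite-difference check. Everything here is hypothetical (a scalar model x(θ) = θ·U and a linear corrector C(x) = A·x + B, not fme code), and frozen_x stands in for detach(). With the uncorrected value attached, the correction d = C(x) − x has derivative (C′ − I)·∂x/∂θ; freezing x in the subtraction leaves only C′·∂x/∂θ, and the pre-corrector loss loses its gradient entirely.

```python
from typing import Callable, Optional

# Hypothetical toy: dx/dtheta = U and C'(x) = A.
U, A, B = 2.0, 0.5, 1.0
TARGET = 4.0


def model(theta: float) -> float:
    return theta * U


def corrector(x: float) -> float:
    return A * x + B


def correction(theta: float, frozen_x: Optional[float] = None) -> float:
    """d = C(x) - x; frozen_x mimics a detached uncorrected value."""
    x = model(theta)
    uncorrected = x if frozen_x is None else frozen_x
    return corrector(x) - uncorrected


def precorrector_loss(theta: float, frozen_x: Optional[float] = None) -> float:
    """Squared error of the pre-correction value against a target."""
    uncorrected = model(theta) if frozen_x is None else frozen_x
    return (uncorrected - TARGET) ** 2


def central_diff(f: Callable[[float], float], theta: float, h: float = 1e-6) -> float:
    return (f(theta + h) - f(theta - h)) / (2 * h)


theta0 = 3.0
x0 = model(theta0)  # the value a detach() would freeze at this step

# Attached: approximately (A - 1) * U. Frozen: approximately A * U.
grad_reg_attached = central_diff(correction, theta0)
grad_reg_detached = central_diff(lambda t: correction(t, frozen_x=x0), theta0)
# Attached: approximately 2 * (x0 - TARGET) * U. Frozen: exactly zero.
grad_pre_attached = central_diff(precorrector_loss, theta0)
grad_pre_detached = central_diff(lambda t: precorrector_loss(t, frozen_x=x0), theta0)
```

In this toy both features train on the wrong signal once x is frozen, which is why the stepper keeps uncorrected attached whenever needs_uncorrected_grad is True.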
Coupled support (per-realm metrics, coupled pre-corrector optimization, and the coupled regularization drafted in #1218) is out of scope, deferred to a later PR.
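The TrainStepper integration points can be put together in a minimal sketch, assuming scalars stand in for tensors and the loss object exposes needs_uncorrected_grad and returns something with .main, .regularization, and .total(). check_corrector_config, accumulate_loss, and detach_uncorrected are hypothetical names; backward passes, per-channel metrics, and epoch aggregation are omitted.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Callable, Optional


@dataclass
class Result:
    """Minimal stand-in for StepOutputLossResult (scalars instead of tensors)."""

    main: float
    regularization: Optional[float] = None

    def total(self) -> float:
        reg = 0.0 if self.regularization is None else self.regularization
        return self.main + reg


def check_corrector_config(corrector_loss: Optional[Any], corrector_enabled: bool) -> None:
    # Build-time guard: corrector_loss on a step without a corrector would
    # otherwise silently fall back to standard training.
    if corrector_loss is not None and not corrector_enabled:
        raise ValueError("corrector_loss is configured but corrector_enabled is False")


def detach_uncorrected(loss_obj: Any) -> bool:
    # Value passed to the predict generator.
    return not loss_obj.needs_uncorrected_grad


def accumulate_loss(
    loss_obj: Callable[[Any, Any, int], Result],
    step_outputs: list,
    targets: list,
) -> tuple[float, dict[str, float]]:
    """One uniform call site: no branching on which corrector features are on."""
    total = 0.0
    metrics: dict[str, float] = {}
    for i, (step_output, target) in enumerate(zip(step_outputs, targets)):
        result = loss_obj(step_output, target, i)
        metrics[f"loss_step_{i}"] = result.main
        if result.regularization is not None:
            metrics[f"corrector_regularization_step_{i}"] = result.regularization
        # Exactly one accumulation of result.total() per optimized step.
        total += result.total()
    return total, metrics
```

Because the loop accumulates result.total() once per step, enabling or disabling either feature only changes what loss_obj returns, never the call site.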
Acceptance criteria
StepOutputLoss unit-tested directly with a deterministic StepOutput: pre-corrector target swap (incl. exclusions), the analytically expected penalty, the corrected-channels-only reg scale, the .main/.regularization/.total() decomposition, and needs_uncorrected_grad for each config combination.
With a deterministic corrector, the training loss numerically equals the loss computed against pre-correction values, including exclusion behavior and warn-once validation. (Prior art: train-on-batch tests in both source PRs.)
Regularization adds the analytically expected penalty, with per-step and epoch-aggregated metrics logged.
Gradients flow through the correction (e.g. corrector-dependent parameters receive grad) when either feature is configured; plain prediction still detaches.
The two features compose in one run.
Both features work under use_gradient_accumulation=True: the accumulation path detaches rollout state between steps, and tests show per-step losses/penalties and gradients through corrections survive it (the single result.total() accumulation must not double-backward a freed graph).
Config / build validation: __post_init__ tests for CorrectorRegularizationConfig's rejected loss types and CorrectorLossConfig's both-None error, plus a build-time test that configuring corrector_loss on a stepper whose corrector_enabled is False raises.
All new configs are optional with defaults preserving current behavior; existing checkpoints load unchanged.
Original two-config design (superseded — kept for posterity)
The original breakdown kept #1222's pre-corrector optimization and #1218's regularization as two independent configs on the train stepper, unified only implicitly. It was superseded by the StepOutputLoss design above (Jeremy's suggestion) to make both features flow through one swappable, jointly-configurable Loss-like object. Substantive differences in the new version: a single CorrectorLossConfig replaces the two sibling fields; the regularization penalty averages over corrected channels only (an n_corrected / n_all rescaling vs the original) rather than diluting by all channels; a new StepABC.corrector_enabled accessor enables a build-time error for a misconfigured corrector_loss; and the single result.total() accumulation structurally avoids the gradient-accumulation double-backward rather than patching it.
What to build
Two training features that use the pre-correction values carried by StepOutput, landed as two commits in one PR. This is PR 3 of a 3-PR split of #1218 and #1222, depending only on #1271.
Commit 1 — pre-corrector optimization (adopt from #1222 as-is):
An optional pre-corrector optimization config on the train stepper config; presence of the object is the on/off switch.
Loss for corrector-modified variables is computed against pre-correction values; exclude_names_and_prefixes opts variables back to post-corrector targets, using a shared name-and-prefix matcher consolidated with the spatial-masking matching convention (bare name matches 2D + all 3D levels, trailing-underscore prefix matches all levels, explicit name_<level> matches exactly).
Warn-once runtime validation of exclusion entries against the actual corrector-modified variable set on the first training step, so config typos surface instead of silently doing nothing.
Rollout state and returned predictions always use fully corrected outputs; only the loss target changes.
The stepper keeps uncorrected tensors attached to the autograd graph when either feature here is configured, detached otherwise. (Detaching the pre-correction values would zero the (C′ − I) term of the correction gradient and silently corrupt both features.)
Commit 2 — corrector regularization (adopt from #1218, single-module only):
A regularization config wrapping a loss config plus weight; __post_init__ validation rejects ensemble losses, NaN losses, and global-mean loss types.
Penalty computed between per-variable corrections (derived as output − uncorrected, undetached) and a zero baseline in loss-normalized space, accumulated into the training loss per optimized step.
Per-step (corrector_regularization_step_{i}) and epoch-aggregated (corrector_regularization) metrics; the trainer's standard batch_ prefix yields batch_corrector_regularization with no extra work.
Combining regularization with pre-corrector optimization is allowed (no validation forbids it); document the interaction in both config docstrings. Validating the combined training regime is out of scope; the combination is merely permitted.
This commit credits Co-authored-by: Jeremy McGibbon <jeremy.mcgibbon@gmail.com>.
Coupled support (per-realm metrics, coupled pre-corrector optimization, and the coupled regularization drafted in #1218) is out of scope, deferred to a later PR.
Acceptance criteria
With a deterministic corrector, the training loss numerically equals the loss computed against pre-correction values, including exclusion behavior and warn-once validation. (Prior art: train-on-batch tests in both source PRs.)
Regularization adds the analytically expected penalty, with per-step and epoch-aggregated metrics logged.
Gradients flow through the correction (e.g. corrector-dependent parameters receive grad) when either feature is configured; plain prediction still detaches.
The two features compose in one run.
Both features work under use_gradient_accumulation=True: the accumulation path detaches rollout state between steps, and tests show per-step losses/penalties and gradients through corrections survive it.
__post_init__ tests for the regularization config's rejected loss types.
All new configs are optional with defaults preserving current behavior; existing checkpoints load unchanged.