Corrector training features: pre-corrector loss optimization and corrector regularization #1273

Description

@jpdunc23

What to build

Two corrector training features, unified behind a single StepOutputLoss object (per Jeremy's suggestion): a Loss-like object that takes a StepOutput and folds in both the pre-corrector loss target and the correction-magnitude penalty, jointly configurable. Landed as two commits in one PR. This is PR 3 of a 3-PR split of #1218 and #1222, depending only on #1271.

Both features operate in the stepper's loss-normalized space (get_loss_normalizer), consistent with how StepLoss already normalizes, so there is no normalizer mismatch between the main loss and the penalty.

StepOutputLoss (in fme/core/loss.py, alongside StepLoss)

Wraps the existing StepLoss (the main loss, unchanged) plus an optional regularization WeightedMappingLoss, a reg weight, and an optional pre-corrector exclusion matcher.

  • Call signature: __call__(step_output, target_step, step, data_mask) -> StepOutputLossResult, unfolding the ensemble dim internally (n_ensemble fixed at build).
  • StepOutputLossResult exposes .main (a LossOutput), .regularization (a Tensor | None), and .total() == main.total() + regularization.
  • Exposes needs_uncorrected_grad: bool (true iff either feature is configured), used by the stepper to decide detaching.
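The result decomposition described above can be sketched as follows. This is a minimal illustration, not the fme implementation: plain floats stand in for tensors, `MainLossSketch` is a hypothetical stand-in for LossOutput, the variable names are made up, and treating a None regularization as contributing nothing is an assumption consistent with the pass-through behavior described under TrainStepper integration.

```python
from dataclasses import dataclass
from typing import Dict, Optional


@dataclass
class MainLossSketch:
    """Hypothetical stand-in for LossOutput: per-channel losses as floats."""

    per_channel: Dict[str, float]

    def total(self) -> float:
        # Assumption: the total is the sum of the per-channel losses.
        return sum(self.per_channel.values())


@dataclass
class StepOutputLossResultSketch:
    """Mirrors the .main / .regularization / .total() decomposition."""

    main: MainLossSketch
    regularization: Optional[float] = None

    def total(self) -> float:
        # Assumption: a None regularization contributes nothing, so the
        # result reduces to the main loss when no penalty is configured.
        if self.regularization is None:
            return self.main.total()
        return self.main.total() + self.regularization


with_penalty = StepOutputLossResultSketch(
    MainLossSketch({"surface_pressure": 1.0, "air_temperature_0": 2.0}),
    regularization=0.5,
)
main_only = StepOutputLossResultSketch(MainLossSketch({"surface_pressure": 1.0}))
```

Keeping `.main` and `.regularization` separate is what lets the stepper log both metrics while still backpropagating a single `.total()`.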

Commit 1 — pre-corrector optimization (from #1222)

  • Realized inside StepOutputLoss.__call__: the main-loss target overrides corrected values with step_output.uncorrected for every corrector-modified variable not matched by the exclusion matcher, then defers to the wrapped StepLoss.
  • exclude_names_and_prefixes uses the shared name-and-prefix matcher consolidated with the spatial-masking convention (bare name matches 2D + all 3D levels, trailing-underscore prefix matches all levels, explicit name_<level> matches exactly).
  • The exclusion matcher and its warn-once validation of unmatched entries live on the loss object, since it is what sees step_output.uncorrected.keys() each step.
  • Rollout state and returned predictions always use fully corrected outputs; only the loss target changes.
  • Adds a new corrector_enabled: bool abstract property on StepABC, implemented by every step type. The TrainStepper raises at build time if corrector_loss is configured but corrector_enabled is False (a config mistake that would otherwise silently train as a standard run). A first-step warn-once covers the case where a corrector runs but uncorrected comes back empty.
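A rough sketch of the matching convention and the value swap from Commit 1 is given below. It assumes 3D levels are encoded as an integer suffix (`name_<level>`); the function names `matches`, `select_loss_values`, and `unmatched_entries` are hypothetical, and floats stand in for tensors. The real shared matcher may differ in its details.

```python
import re
from typing import Dict, Iterable, List


def matches(entry: str, name: str) -> bool:
    """Return True if an exclusion entry matches a variable name."""
    if entry.endswith("_"):
        # Trailing-underscore prefix: every level, e.g. "t_" -> "t_0", "t_1".
        return re.fullmatch(re.escape(entry) + r"\d+", name) is not None
    if re.search(r"_\d+$", entry):
        # Explicit name_<level>: exact match only.
        return name == entry
    # Bare name: the 2D variable itself, or any 3D level of it.
    level_pattern = re.escape(entry) + r"_\d+"
    return name == entry or re.fullmatch(level_pattern, name) is not None


def select_loss_values(
    output: Dict[str, float], uncorrected: Dict[str, float], exclude: List[str]
) -> Dict[str, float]:
    """Swap in pre-correction values for non-excluded corrected variables."""
    selected = dict(output)
    for name, pre_correction in uncorrected.items():
        if not any(matches(entry, name) for entry in exclude):
            selected[name] = pre_correction
    return selected


def unmatched_entries(uncorrected_names: Iterable[str], exclude: List[str]) -> List[str]:
    """Entries matching no corrected variable: candidates for a warn-once."""
    names = list(uncorrected_names)
    return [e for e in exclude if not any(matches(e, n) for n in names)]


output = {"surface_pressure": 1.0, "air_temperature_0": 2.0,
          "air_temperature_1": 3.0, "specific_total_water_0": 4.0}
uncorrected = {"surface_pressure": 10.0, "air_temperature_0": 20.0,
               "air_temperature_1": 30.0}
exclude = ["air_temperature_1", "typo_name"]
selected = select_loss_values(output, uncorrected, exclude)
typos = unmatched_entries(uncorrected.keys(), exclude)
```

In this example only `air_temperature_1` keeps its post-corrector value, variables the corrector never touched are left alone, and `typo_name` is reported as unmatched.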

Commit 2 — corrector regularization (from #1218, single-module only)

  • Derives corrections as output − uncorrected over the sparse corrected keys and penalizes them against a zero baseline. This subsumes #1218's separate dense corrections payload, so no extra field on StepOutput is needed.
  • Penalty is the mean over corrected channels only (consistent with PR2's channel_mean), realized via the existing LossOutput active-channel masking. ⚠️ This differs from #1218's all-channel dilution by a factor of n_corrected / n_all: for the same weight, the new penalty is n_all / n_corrected times larger. Weights tuned in exp/2026-06-03-corrector-regularization must therefore be rescaled. Call this out in the PR body.
  • The reg term carries no per-step decay (only weight scales it, matching #1218); document the asymmetry vs the main loss's sqrt_loss_step_decay in the config docstring.
  • Per-step (corrector_regularization_step_{i}) and epoch-aggregated (corrector_regularization) metrics; the trainer's standard batch_ prefix yields batch_corrector_regularization with no extra work.
  • Credits Co-authored-by: Jeremy McGibbon <jeremy.mcgibbon@gmail.com>.
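The effect of averaging over corrected channels only, versus diluting over all channels, can be checked with a small numeric sketch. It assumes a mean-squared-error penalty against the zero baseline and uses lists of floats in place of tensors. The function names and data are illustrative only, and loss normalization is omitted.

```python
from typing import Dict, List


def mean_square(values: List[float]) -> float:
    return sum(v * v for v in values) / len(values)


def channel_penalties(
    output: Dict[str, List[float]], uncorrected: Dict[str, List[float]]
) -> List[float]:
    """Per-channel penalty of (output - uncorrected) against zero."""
    penalties = []
    for name, pre_correction in uncorrected.items():
        corrections = [o - u for o, u in zip(output[name], pre_correction)]
        penalties.append(mean_square(corrections))
    return penalties


def penalty_corrected_only(output, uncorrected) -> float:
    per_channel = channel_penalties(output, uncorrected)
    return sum(per_channel) / len(per_channel)  # divide by n_corrected


def penalty_all_channels(output, uncorrected) -> float:
    per_channel = channel_penalties(output, uncorrected)
    return sum(per_channel) / len(output)  # divide by n_all


output = {"a": [1.0, 3.0], "b": [2.0, 2.0], "c": [5.0, 5.0], "d": [0.0, 0.0]}
uncorrected = {"a": [0.0, 1.0]}  # only channel "a" is corrector-modified

new_penalty = penalty_corrected_only(output, uncorrected)
old_penalty = penalty_all_channels(output, uncorrected)

# Rescale an old weight by n_corrected / n_all to keep weight * penalty fixed.
old_weight = 2.0
new_weight = old_weight * len(uncorrected) / len(output)
```

Here one of four channels is corrected, so the corrected-only mean is four times the all-channel mean, and multiplying the old weight by 1/4 reproduces the old effective strength.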

Config (single CorrectorLossConfig, in new fme/core/corrector/loss.py)

  • Holds precorrector_optimization: PreCorrectorOptimizationConfig | None and regularization: CorrectorRegularizationConfig | None; its __post_init__ errors if both are None.
  • One corrector_loss: CorrectorLossConfig | None field on the train stepper config replaces the two sibling fields from the source PRs.
  • CorrectorRegularizationConfig.__post_init__ rejects ensemble losses, NaN losses, and global-mean loss types.
  • Combining the two features is first-class (one config carries both); no validation forbids it, and the interaction is documented in the config docstrings. Validating the combined regime is out of scope; the combination is merely permitted.
  • Import is one-way (corrector/loss.py → loss.py): the config's build() assembles the primitives StepOutputLoss takes, so StepOutputLoss itself imports no config.
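The validation rules above might look roughly like the following dataclass sketch. The field `loss_type` and the string labels in `REJECTED_LOSS_TYPES` are placeholders chosen for illustration; the actual regularization config wraps a loss config whose type names are not given here, and `build()` is omitted.

```python
from dataclasses import dataclass, field
from typing import List, Optional

# Placeholder labels for the rejected categories (not real fme identifiers).
REJECTED_LOSS_TYPES = {"ensemble", "nan", "global_mean"}


@dataclass
class PreCorrectorOptimizationConfig:
    exclude_names_and_prefixes: List[str] = field(default_factory=list)


@dataclass
class CorrectorRegularizationConfig:
    loss_type: str = "mse"  # hypothetical field standing in for a loss config
    weight: float = 1.0

    def __post_init__(self):
        if self.loss_type in REJECTED_LOSS_TYPES:
            raise ValueError(
                f"loss type {self.loss_type!r} is not supported for regularization"
            )


@dataclass
class CorrectorLossConfig:
    precorrector_optimization: Optional[PreCorrectorOptimizationConfig] = None
    regularization: Optional[CorrectorRegularizationConfig] = None

    def __post_init__(self):
        # An empty corrector_loss would configure nothing, so reject it.
        if self.precorrector_optimization is None and self.regularization is None:
            raise ValueError("configure at least one corrector loss feature")


both = CorrectorLossConfig(
    precorrector_optimization=PreCorrectorOptimizationConfig(["air_temperature_"]),
    regularization=CorrectorRegularizationConfig(weight=0.1),
)
```

Because presence of each sub-config is the on/off switch, leaving `corrector_loss` itself as None on the train stepper config remains the way to preserve current behavior.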

TrainStepper integration

  • Always builds one StepOutputLoss (a pass-through to StepLoss when corrector_loss is None), so _accumulate_loss has a single uniform call site with no per-feature branching.
  • Accumulates result.total() once per optimized step, which structurally eliminates the double-backward-under-gradient-accumulation error (no separate accumulate_loss call for the penalty to backward a freed graph).
  • Passes detach_uncorrected = not loss_obj.needs_uncorrected_grad into the predict generator. Detaching the pre-correction values would zero the (C′ − I) term of the correction gradient and silently corrupt both features.
  • Metrics preserved: loss_step_{i} from .main, corrector_regularization_step_{i} from .regularization, epoch-aggregated corrector_regularization, per-channel losses from .main.
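The integration points above can be illustrated with a simplified sketch. The helper names `resolve_detach_uncorrected` and `accumulate_steps` are hypothetical, each step result is reduced to a `(main, regularization)` pair of floats, and the epoch-level aggregation is left out because its exact form is not specified here.

```python
from typing import Dict, List, Optional, Tuple


def resolve_detach_uncorrected(
    corrector_loss_configured: bool, corrector_enabled: bool
) -> bool:
    """Build-time check plus the detach decision passed to prediction."""
    if corrector_loss_configured and not corrector_enabled:
        raise ValueError("corrector_loss is configured but no corrector is enabled")
    # needs_uncorrected_grad is true iff either feature is configured.
    needs_uncorrected_grad = corrector_loss_configured
    return not needs_uncorrected_grad


def accumulate_steps(
    step_results: List[Tuple[float, Optional[float]]]
) -> Tuple[float, Dict[str, float]]:
    """Accumulate one total per optimized step and collect per-step metrics."""
    total = 0.0
    metrics: Dict[str, float] = {}
    for i, (main, regularization) in enumerate(step_results):
        metrics[f"loss_step_{i}"] = main
        step_total = main
        if regularization is not None:
            metrics[f"corrector_regularization_step_{i}"] = regularization
            step_total += regularization
        # Single accumulation per step: main loss and penalty go in together.
        total += step_total
    return total, metrics


total, metrics = accumulate_steps([(1.0, 0.5), (2.0, 0.25)])
plain_total, plain_metrics = accumulate_steps([(1.0, None)])
```

Summing the main loss and penalty before accumulating is the point of the design: there is no second accumulation call that could run backward through an already freed graph.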

Coupled support (per-realm metrics, coupled pre-corrector optimization, and the coupled regularization drafted in #1218) is out of scope, deferred to a later PR.

Acceptance criteria

  • StepOutputLoss unit-tested directly with a deterministic StepOutput: pre-corrector target swap (incl. exclusions), the analytically expected penalty, the corrected-channels-only reg scale, the .main/.regularization/.total() decomposition, and needs_uncorrected_grad for each config combination.
  • With a deterministic corrector, the training loss numerically equals the loss computed against pre-correction values, including exclusion behavior and warn-once validation. (Prior art: train-on-batch tests in both source PRs.)
  • Regularization adds the analytically expected penalty, with per-step and epoch-aggregated metrics logged.
  • Gradients flow through the correction (e.g. corrector-dependent parameters receive grad) when either feature is configured; plain prediction still detaches.
  • The two features compose in one run.
  • Both features work under use_gradient_accumulation=True: the accumulation path detaches rollout state between steps, and tests show per-step losses/penalties and gradients through corrections survive it (the single result.total() accumulation must not double-backward a freed graph).
  • Config / build validation: __post_init__ tests for CorrectorRegularizationConfig's rejected loss types and CorrectorLossConfig's both-None error; a build-time test that a corrector_loss configured against a stepper whose corrector_enabled is False raises.
  • All new configs are optional with defaults preserving current behavior; existing checkpoints load unchanged.

Blocked by


Original two-config design (superseded — kept for posterity)

The original breakdown kept #1222's pre-corrector optimization and #1218's regularization as two independent configs on the train stepper, unified only implicitly. It was superseded by the StepOutputLoss design above (Jeremy's suggestion) to make both features flow through one swappable, jointly-configurable Loss-like object. Substantive differences in the new version: a single CorrectorLossConfig replaces the two sibling fields; the regularization penalty averages over corrected channels only rather than diluting by all channels (an n_corrected / n_all rescaling vs the original); a new StepABC.corrector_enabled accessor enables a build-time error for a misconfigured corrector_loss; and the single result.total() accumulation structurally avoids the gradient-accumulation double-backward rather than patching it.

What to build

Two training features that use the pre-correction values carried by StepOutput, landed as two commits in one PR. This is PR 3 of a 3-PR split of #1218 and #1222, depending only on #1271.

Commit 1 — pre-corrector optimization (adopt from #1222 as-is):

  • An optional pre-corrector optimization config on the train stepper config; presence of the object is the on/off switch.
  • Loss for corrector-modified variables is computed against pre-correction values; exclude_names_and_prefixes opts variables back to post-corrector targets, using a shared name-and-prefix matcher consolidated with the spatial-masking matching convention (bare name matches 2D + all 3D levels, trailing-underscore prefix matches all levels, explicit name_<level> matches exactly).
  • Warn-once runtime validation of exclusion entries against the actual corrector-modified variable set on the first training step, so config typos surface instead of silently doing nothing.
  • Rollout state and returned predictions always use fully corrected outputs; only the loss target changes.
  • The stepper keeps uncorrected tensors attached to the autograd graph when either feature here is configured, detached otherwise. (Detaching the pre-correction values would zero the (C′ − I) term of the correction gradient and silently corrupt both features.)

Commit 2 — corrector regularization (adopt from #1218, single-module only):

  • A regularization config wrapping a loss config plus weight; __post_init__ validation rejects ensemble losses, NaN losses, and global-mean loss types.
  • Penalty computed between per-variable corrections (derived as output − uncorrected, undetached) and a zero baseline in loss-normalized space, accumulated into the training loss per optimized step.
  • Per-step (corrector_regularization_step_{i}) and epoch-aggregated (corrector_regularization) metrics; the trainer's standard batch_ prefix yields batch_corrector_regularization with no extra work.
  • Combining regularization with pre-corrector optimization is allowed (no validation forbids it); document the interaction in both config docstrings. Validating the combined training regime is out of scope; the combination is merely permitted.
  • This commit credits Co-authored-by: Jeremy McGibbon <jeremy.mcgibbon@gmail.com>.

Coupled support (per-realm metrics, coupled pre-corrector optimization, and the coupled regularization drafted in #1218) is out of scope, deferred to a later PR.

Acceptance criteria

  • With a deterministic corrector, the training loss numerically equals the loss computed against pre-correction values, including exclusion behavior and warn-once validation. (Prior art: train-on-batch tests in both source PRs.)
  • Regularization adds the analytically expected penalty, with per-step and epoch-aggregated metrics logged.
  • Gradients flow through the correction (e.g. corrector-dependent parameters receive grad) when either feature is configured; plain prediction still detaches.
  • The two features compose in one run.
  • Both features work under use_gradient_accumulation=True: the accumulation path detaches rollout state between steps, and tests show per-step losses/penalties and gradients through corrections survive it.
  • __post_init__ tests for the regularization config's rejected loss types.
  • All new configs are optional with defaults preserving current behavior; existing checkpoints load unchanged.

Blocked by

Labels: enhancement