
Replace step return tuple with a StepOutput dataclass carrying pre-correction values #1271

Description

@jpdunc23

What to build

Replace the tuple[TensorDict, StepperState | None] return of the step abstraction with a StepOutput dataclass, and plumb pre-correction values up through the stepper so prediction BatchData optionally carries them. Pure plumbing: no user-visible behavior change.
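A minimal sketch of the new return type. It uses only the standard library, so plain floats stand in for tensors and `stepper_state` is typed loosely. The field names follow this issue. `correction_deltas` is a hypothetical helper added to show that deltas can be derived from the stored values. It is not a proposed API.

```python
from dataclasses import dataclass, field
from typing import Any, Optional

# Stand-in alias: in the repo, TensorDict maps variable names to torch tensors.
TensorDict = dict[str, Any]


@dataclass
class StepOutput:
    # Denormalized data after all adjustments, including any correction.
    output: TensorDict
    # Recurrent per-sample state, passed through unchanged
    # (typed StepperState | None in the real code).
    stepper_state: Optional[Any] = None
    # Pre-correction values of exactly the corrector-modified variables;
    # an empty dict when no corrector ran, so consumers need no None checks.
    uncorrected: TensorDict = field(default_factory=dict)


def correction_deltas(result: StepOutput) -> TensorDict:
    """Hypothetical helper: deltas are derivable as output - uncorrected."""
    return {
        name: result.output[name] - value
        for name, value in result.uncorrected.items()
    }


# Before: output, stepper_state = step(...)
# After:  result = step(...), then read result.output, result.stepper_state, result.uncorrected
result = StepOutput(output={"a": 1.5, "b": 2.0}, uncorrected={"a": 1.0})
print(correction_deltas(result))  # {'a': 0.5}
```

Storing values rather than deltas keeps the design one-directional. A delta can always be recomputed from `output` and `uncorrected`, but a value cannot be recovered from a delta alone.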

This is PR 1 of a 3-PR split of #1218 and #1222, both of which predate the StepperState refactor on main. Branch from current main HEAD. Follow-up issues cover correction inference metrics and corrector training features.

Design decisions (settled in design review):

  • Named StepOutput, following the repo's *Output convention (TrainOutput, LossOutput); *Result is reserved for benchmarks.
  • Fields: output (denormalized post-adjustment data), stepper_state (unchanged passthrough semantics), and uncorrected.
  • uncorrected carries pre-correction values (not deltas) of exactly the corrector-modified variables; empty dict when no corrector ran, so consumers need no None checks. Deltas are always derivable as output − uncorrected; values are not recoverable from deltas, and the training-features follow-up needs the values.
  • StepperState is not used to carry uncorrected data: its contract is recurrent per-sample state threaded across predict windows, while uncorrected values are a per-window, per-timestep diagnostic series.
  • Scope: the step package and corrector tap, the stepper's step/predict generator plumbing, and an optional uncorrected field on BatchData (following the data_mask/stepper_state optional-field precedent), populated only on prediction BatchData returned by predict. There is no return_uncorrected argument and no overloads. PairedData exposes the uncorrected prediction for the metrics follow-up. If the BatchData layer bloats this PR, it may slip to the metrics PR.
  • Uncorrected tensors are detached unconditionally here; attach/detach control arrives in the training-features follow-up with its first consumer.
  • The output-process function (e.g. spatial masking) operates per variable and preserves variable names, so it is applied identically to the uncorrected subset.
  • The rollout always feeds corrected output forward as state.
  • fme.coupled touches only what is needed to unwrap StepOutput.output; all coupled feature support is deferred to a later PR.
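The corrector tap and the predict plumbing described in these decisions can be sketched as follows. This is a stdlib-only illustration under stated assumptions. `tap_corrector`, `step`, and `predict` are hypothetical stand-ins, not the repo's functions. The corrector is assumed to return only the variables it modifies. Floats stand in for tensors, and `_detach` calls `.detach()` only when the value has one. The sketch shows three behaviors: uncorrected values are captured only for modified variables, the output process is applied to the uncorrected subset, and the rollout feeds the corrected output forward.

```python
from dataclasses import dataclass, field
from typing import Any, Callable, Iterator, Optional

TensorDict = dict[str, Any]  # stand-in: names -> tensors in the real code
Corrector = Callable[[TensorDict], TensorDict]
Process = Callable[[TensorDict], TensorDict]


@dataclass
class StepOutput:
    output: TensorDict
    stepper_state: Optional[Any] = None
    uncorrected: TensorDict = field(default_factory=dict)


def _detach(value: Any) -> Any:
    # Uncorrected values are detached unconditionally; plain floats pass through.
    return value.detach() if hasattr(value, "detach") else value


def tap_corrector(
    data: TensorDict, corrector: Optional[Corrector]
) -> tuple[TensorDict, TensorDict]:
    """Hypothetical tap: returns (corrected data, pre-correction values)."""
    if corrector is None:
        return data, {}  # no corrector ran: empty dict, never None
    updates = corrector(data)  # assumed contract: only the modified variables
    uncorrected = {name: _detach(data[name]) for name in updates}
    return {**data, **updates}, uncorrected


def step(
    prev: TensorDict,
    network: Callable[[TensorDict], TensorDict],
    corrector: Optional[Corrector],
    process: Process,
) -> StepOutput:
    raw = network(prev)
    corrected, uncorrected = tap_corrector(raw, corrector)
    # The output process is per-variable and name-preserving, so it is
    # applied to the uncorrected subset exactly as to the full output.
    return StepOutput(output=process(corrected), uncorrected=process(uncorrected))


def predict(
    initial: TensorDict,
    n_steps: int,
    network: Callable[[TensorDict], TensorDict],
    corrector: Optional[Corrector],
    process: Process,
) -> Iterator[StepOutput]:
    state = initial
    for _ in range(n_steps):
        result = step(state, network, corrector, process)
        state = result.output  # the rollout always feeds corrected output forward
        yield result
```

As a usage example, take a network that doubles every value and a corrector that caps `a` at 3.0. Starting from `{"a": 1.0, "b": 1.0}`, the second step yields output `{"a": 3.0, "b": 4.0}` with uncorrected `{"a": 4.0}`. The third step starts from the corrected `a = 3.0` and yields uncorrected `{"a": 6.0}`, not 8.0.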

Acceptance criteria

  • Every step implementation returns StepOutput, with uncorrected populated exactly when a corrector modifies variables and an empty dict otherwise; stepper_state semantics unchanged. (Prior art: the step package's existing unit tests; updated versions exist in both source PRs.)
  • predict returns prediction BatchData with uncorrected set; rollout state is the corrected output; uncorrected tensors are detached. (Prior art: existing stepper unit tests.)
  • The existing distributed/spatial-parallel step test matrix passes (baselines generated single-rank, verified under torchrun, per repo convention).
  • No user-visible behavior change; existing checkpoints load unchanged.

Blocked by

None - can start immediately.

Notes

The source PRs serve only as reference implementations: #1222 for the corrector tap and step-return change, and #1218 for its alternative StepResult framing. Neither merges as-is (#1222 is explicitly marked don't-merge). Where they disagree, the decisions above win.

Labels: enhancement