What to build
Replace the tuple[TensorDict, StepperState | None] return type of the step abstraction with a StepOutput dataclass, and plumb pre-correction values up through the stepper so that prediction BatchData can optionally carry them. This is pure plumbing: no user-visible behavior change.
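A minimal sketch of the new step return shape and the corrector tap. It assumes plain dicts of floats stand in for fme's TensorDict (so it runs without torch), and it omits normalization. The corrector signature (a callable that returns only the variables it replaces) and the helpers tap_corrector and step are hypothetical, not the repo's API. They illustrate the uncorrected contract from the design decisions: pre-correction values of exactly the corrector-modified variables, and an empty dict when no corrector runs.

```python
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, Optional, Tuple

TensorDict = Dict[str, Any]  # stand-in for fme's TensorDict (name -> tensor)
StepperState = Any  # stand-in for the stepper's recurrent state type


@dataclass
class StepOutput:
    output: TensorDict  # post-adjustment data
    stepper_state: Optional[StepperState]  # unchanged passthrough semantics
    # Pre-correction values of exactly the corrector-modified variables;
    # empty when no corrector ran, so consumers never need a None check.
    uncorrected: TensorDict = field(default_factory=dict)


# Hypothetical corrector signature: receives the raw prediction and returns
# only the variables it replaces (e.g. a clamp or conservation fix).
Corrector = Callable[[TensorDict], TensorDict]


def tap_corrector(
    prediction: TensorDict, corrector: Optional[Corrector]
) -> Tuple[TensorDict, TensorDict]:
    """Apply the corrector and record pre-correction values it overwrote."""
    if corrector is None:
        return prediction, {}
    corrections = corrector(prediction)
    uncorrected = {name: prediction[name] for name in corrections}
    return {**prediction, **corrections}, uncorrected


def step(
    state: TensorDict,
    corrector: Optional[Corrector] = None,
    stepper_state: Optional[StepperState] = None,
) -> StepOutput:
    # Toy "model": doubles every variable (hypothetical, for illustration).
    prediction = {name: 2.0 * value for name, value in state.items()}
    output, uncorrected = tap_corrector(prediction, corrector)
    return StepOutput(output=output, stepper_state=stepper_state, uncorrected=uncorrected)
```

With a corrector that clamps q at zero, step({"q": -1.0, "t": 3.0}, corrector=lambda p: {"q": max(p["q"], 0.0)}) gives output {"q": 0.0, "t": 6.0} and uncorrected {"q": -2.0}; the delta 0.0 − (−2.0) is derivable from the two, but not the other way round. Without a corrector, uncorrected is {}.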
This is PR 1 of a 3-PR split of #1218 and #1222, both of which predate the StepperState refactor on main. Branch from current main HEAD. Follow-up issues cover correction inference metrics and corrector training features.
Design decisions (settled in design review):
- Named StepOutput, following the repo's *Output convention (TrainOutput, LossOutput); *Result is reserved for benchmarks.
- Fields: output (denormalized, post-adjustment data), stepper_state (unchanged passthrough semantics), and uncorrected.
  uncorrected carries pre-correction values (not deltas) of exactly the variables the corrector modified; it is an empty dict when no corrector ran, so consumers need no None checks. Deltas are always derivable as output − uncorrected, but values are not recoverable from deltas, and the training-features follow-up needs the values.
- StepperState is not used to carry uncorrected data: its contract is recurrent per-sample state threaded across predict windows, whereas uncorrected values are a per-window, per-timestep diagnostic series.
- Scope: the step package and corrector tap, the stepper's step/predict generator plumbing, and an optional uncorrected field on BatchData (following the data_mask/stepper_state optional-field precedent), populated only on the prediction BatchData returned by predict. There is no return_uncorrected argument and no overloads. PairedData exposes the uncorrected prediction for the metrics follow-up. If the BatchData layer bloats this PR, it may slip to the metrics PR.
- Uncorrected tensors are detached unconditionally here; attach/detach control arrives in the training-features follow-up with its first consumer.
- The output-process function (e.g. spatial masking) is applied per variable and preserves variable names, so it is applied identically to the uncorrected subset.
- The rollout always feeds corrected output forward as state.
- fme.coupled touches only what is needed to unwrap StepOutput.output; all coupled feature support is deferred to a later PR.
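The predict-side decisions can be sketched the same way. Everything here is hypothetical: PredictionBatch stands in for prediction BatchData, step_fn for the stepper's step, and detach is injected (with torch it would be torch.Tensor.detach) so the sketch runs on plain floats. It shows the rollout continuing from the corrected output, uncorrected values collected per timestep and detached unconditionally, and the same per-variable output_process applied to both the full output and the uncorrected subset. Applying output_process only to the collected values, and not to the state fed forward, is an assumption of this sketch.

```python
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional

TensorDict = Dict[str, Any]  # stand-in for fme's TensorDict


@dataclass
class StepOutput:
    output: TensorDict
    stepper_state: Optional[Any]
    uncorrected: TensorDict = field(default_factory=dict)


@dataclass
class PredictionBatch:
    """Hypothetical stand-in for prediction BatchData."""

    data: List[TensorDict]
    stepper_state: Optional[Any] = None
    # Optional field, following the data_mask/stepper_state precedent;
    # set only on prediction batches returned by predict.
    uncorrected: Optional[List[TensorDict]] = None


def predict(
    initial: TensorDict,
    step_fn: Callable[[TensorDict, Optional[Any]], StepOutput],
    n_steps: int,
    output_process: Callable[[TensorDict], TensorDict] = lambda d: d,
    detach: Callable[[Any], Any] = lambda v: v,
) -> PredictionBatch:
    state: TensorDict = initial
    stepper_state: Optional[Any] = None
    outputs: List[TensorDict] = []
    uncorrected_series: List[TensorDict] = []
    for _ in range(n_steps):
        result = step_fn(state, stepper_state)
        outputs.append(output_process(result.output))
        # Detach unconditionally, then apply the same per-variable,
        # name-preserving output process to the uncorrected subset.
        uncorrected_series.append(
            output_process({k: detach(v) for k, v in result.uncorrected.items()})
        )
        # The rollout always continues from the corrected output.
        state = result.output
        stepper_state = result.stepper_state
    return PredictionBatch(outputs, stepper_state, uncorrected_series)
```

In a toy rollout where the corrector clamps q at zero, a step that follows a clamp starts from the clamped 0.0 rather than from the uncorrected negative value, and timesteps where no correction happened contribute an empty dict to the uncorrected series.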
Acceptance criteria
- Steps return StepOutput, with uncorrected populated exactly when a corrector modifies variables and an empty dict otherwise; stepper_state semantics are unchanged. (Prior art: the step package's existing unit tests; updated versions exist in both source PRs.)
- predict returns prediction BatchData with uncorrected set; the rollout state is the corrected output; uncorrected tensors are detached. (Prior art: existing stepper unit tests.)
Blocked by
None - can start immediately.
Notes
Source PRs serve as reference implementations only: #1222 for the corrector tap and step-return change, #1218 for its alternative StepResult framing. Neither merges as-is (#1222 is explicitly marked don't-merge). Where they disagree, the decisions above win.