
Replace step return tuple with StepOutput dataclass carrying pre-correction values#1283

Open
jpdunc23 wants to merge 4 commits into main from feature/step-output


@jpdunc23 jpdunc23 commented Jun 16, 2026

Member

ACE steppers can apply correctors that adjust the network's raw output each step, but today the corrected and uncorrected outputs are conflated the moment step returns, so nothing downstream can interrogate the correction itself. This PR replaces the step abstraction's tuple[TensorDict, StepperState | None] return with a StepOutput dataclass that additionally carries the pre-correction values of the corrector-modified variables. It is pure plumbing within the step/stepper layer: no user-visible behavior change, and existing checkpoints load unchanged.

StepOutput.uncorrected holds the pre-correction values (not deltas) of exactly the variables a corrector modified, captured at the corrector boundary by tensor-identity detection of out-of-place edits and detached unconditionally; it is an empty dict when no corrector ran, so consumers need no None checks. The correction is derivable as output - uncorrected. Ocean and prescribed-prognostic adjustments run after the corrector and are excluded; the rollout always feeds the corrected output forward as state; fme.coupled only unwraps StepOutput.output. The corrector ABC is intentionally left unchanged (its existing corrector_state tuple is untouched).
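As a rough illustration of the shape described above, here is a minimal sketch of a StepOutput-like dataclass and of deriving the correction as output - uncorrected. The three field names follow the PR; everything else is an assumption made for illustration: plain lists stand in for torch tensors, the variable names are made up, and derive_correction is a hypothetical helper, not part of fme.

```python
import dataclasses
from typing import Dict, List, Optional

# Stand-in for a tensor mapping (variable name -> values). The real code
# uses torch tensors; plain lists keep this sketch dependency-free.
FakeTensorDict = Dict[str, List[float]]


@dataclasses.dataclass
class StepOutput:
    output: FakeTensorDict
    stepper_state: Optional[object]  # opaque here; StepperState in the PR
    # Empty dict when no corrector ran, so consumers need no None checks.
    uncorrected: FakeTensorDict = dataclasses.field(default_factory=dict)


def derive_correction(step_output: StepOutput) -> FakeTensorDict:
    """Hypothetical helper: output - uncorrected for each corrected variable."""
    return {
        name: [o - u for o, u in zip(step_output.output[name], before)]
        for name, before in step_output.uncorrected.items()
    }


# Made-up variable names; only "PRESsfc" was touched by the (imaginary) corrector.
result = StepOutput(
    output={"PRESsfc": [1001.0, 1002.0], "TMP2m": [280.0, 281.0]},
    stepper_state=None,
    uncorrected={"PRESsfc": [1000.0, 1000.5]},
)
correction = derive_correction(result)

# No corrector: the default empty dict yields an empty correction.
no_corrector = StepOutput(output={"TMP2m": [280.0]}, stepper_state=None)
```

Because uncorrected is sparse, iterating over it (rather than over output) visits only the corrected variables, and the no-corrector case falls out of the empty default without a special branch.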

This PR is now scoped to the step/stepper layer only. Per review feedback, the consumer-facing plumbing that surfaces uncorrected on prediction BatchData/PairedData has been moved to the follow-on aggregator PR (#1284), where it can be reviewed together with the metrics that consume it.

Changes:

  • fme.core.step.step.StepOutput: new dataclass (output, stepper_state, uncorrected) replacing the step return tuple; StepABC.step and all implementations (SingleModuleStep, SecondaryModuleStep, SeparateRadiationStep, FCN3Step) return it.

  • fme.core.corrector.utils.captured_before: new helper returning the pre-correction values of variables modified out-of-place.

  • fme.core.step.single_module.step_with_adjustments: taps the corrector boundary to build the detached uncorrected shadow and returns StepOutput.

  • fme.core.step.multi_call / _multi_call: MultiCallStep composes the wrapped step's shadow; the MultiCall helper returns an empty shadow.

  • fme.ace.stepper.single_module.Stepper.step: applies the name-preserving output process func to the shadow and returns StepOutput; predict_generator yields StepOutput and feeds the corrected output forward; _stack_step_outputs/process_prediction_generator_list consume StepOutput (stacking only output and threading the terminal stepper_state).

  • fme.coupled.stepper.CoupledStepper: unwraps StepOutput.output.

  • Tests added

  • If dependencies changed, "deps only" image rebuilt and "latest_deps_only_image.txt" file updated
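The rollout behavior listed above (predict_generator yields StepOutput and always feeds the corrected output forward as state) can be sketched roughly as follows. This is a toy under stated assumptions, not the fme implementation: toy_step and rollout are hypothetical names, lists stand in for tensors, and the "corrector" is just a clamp.

```python
import dataclasses
from typing import Callable, Dict, Iterator, List, Optional

FakeTensorDict = Dict[str, List[float]]  # stand-in for a tensor mapping


@dataclasses.dataclass
class StepOutput:
    output: FakeTensorDict
    stepper_state: Optional[object]
    uncorrected: FakeTensorDict = dataclasses.field(default_factory=dict)


def toy_step(state: FakeTensorDict) -> StepOutput:
    """Hypothetical step: the 'network' adds 1.0, then a 'corrector' clamps at 2.5."""
    raw = {"x": [v + 1.0 for v in state["x"]]}
    corrected = {"x": [min(v, 2.5) for v in raw["x"]]}
    return StepOutput(output=corrected, stepper_state=None, uncorrected=raw)


def rollout(
    step: Callable[[FakeTensorDict], StepOutput],
    initial: FakeTensorDict,
    n_steps: int,
) -> Iterator[StepOutput]:
    state = initial
    for _ in range(n_steps):
        result = step(state)
        yield result
        # The next step starts from the corrected output, never from
        # result.uncorrected.
        state = result.output


results = list(rollout(toy_step, {"x": [0.0]}, n_steps=4))
outputs = [r.output["x"][0] for r in results]
raws = [r.uncorrected["x"][0] for r in results]
```

In this toy the last raw value is 3.5 rather than 4.0, which shows that the clamped state, not the raw network output, was carried into the final step.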

Resolves #1271

🤖 Generated with Claude Code

@jpdunc23 jpdunc23 changed the title Replace step return tuple with StepOutput dataclass carrying pre-correction values Replace step return tuple with StepOutput dataclass carrying pre-correction values Jun 16, 2026

@mcgibbon mcgibbon left a comment

Contributor

It looks like many BatchData methods like compute_derived_variables drop these uncorrected values. If that's intentional, it should get tested; if not, it should get tested and fixed.

I would suggest taking this PR and moving all changes to BatchData/PairedData into the follow-on aggregator PR. The StepOutput changes look pretty mergeable as-is, though I'll do a full agent review and check of that review once you've done the move.

Comment thread fme/ace/data_loading/batch_data.py Outdated
n_ensemble: int = 1
data_mask: TensorMapping | None = None
stepper_state: StepperState | None = None
uncorrected: TensorMapping | None = None

Contributor

Issue: This attribute is very public, but it pertains to the low-level details of a specific Step implementation. Can we avoid tightly coupling BatchData to that Step using an encapsulation/hiding strategy like we did for StepperState (which also contains information specific to certain Step implementations, but only makes that information available to those implementations)?

Contributor

Issue: BatchData.broadcast_ensemble silently drops uncorrected

If we add this as a container instead of this specific low-level instance, then future private data we add could be covered by the same code/tests that we add at this stage to cover, e.g., a StepperMetrics container being added.

Comment thread fme/ace/data_loading/batch_data.py Outdated
n_ensemble: int = 1,
data_mask: TensorMapping | None = None,
stepper_state: StepperState | None = None,
uncorrected: TensorMapping | None = None,

Contributor

Issue: uncorrected got added to new methods for BatchData but not for PairedData; it would be nice to keep them consistent.

…ection values

StepABC.step now returns a StepOutput(output, stepper_state, uncorrected)
dataclass instead of a tuple[TensorDict, StepperState | None]. The new
uncorrected field is a sparse, detached snapshot of the pre-correction values
of exactly the variables a corrector modified, so downstream features can
derive the correction (output - uncorrected) or use the raw pre-correction
values without re-running the model. It is an empty dict when no corrector ran,
so consumers need no None checks. stepper_state keeps its existing
passthrough semantics.

step_with_adjustments captures the shadow at the corrector boundary via a new
captured_before helper (tensor-identity detection of out-of-place edits),
detaching unconditionally; ocean and prescribed-prognostic adjustments run
after the corrector and are intentionally excluded. The corrector ABC is left
unchanged. All step implementations (single/secondary/radiation/fcn3) inherit
the new return type; MultiCallStep composes its wrapped step's shadow and the
MultiCall helper returns an empty shadow. The rollout in predict_generator
always feeds the corrected output forward as state; Stepper.step applies the
name-preserving output process func to the shadow too.
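The tensor-identity detection described above can be sketched as follows. The function name and argument names match the signature shown in this PR, but the body is an assumption written for illustration: it compares object identity on plain Python values, whereas the real helper works on tensors and the caller detaches the result.

```python
from typing import Dict, Mapping


def captured_before(
    original: Mapping[str, object], corrected: Mapping[str, object]
) -> Dict[str, object]:
    """Sketch: return original values for keys rebound to a different object.

    A key counts as "modified" only if the corrected mapping holds a
    different object under that name (an out-of-place edit).
    """
    return {
        name: value
        for name, value in original.items()
        if name in corrected and corrected[name] is not value
    }


raw = {"a": [1.0, 2.0], "b": [3.0, 4.0]}

# A toy corrector: shallow-copy the mapping, then rebind "a" to a new object.
after = dict(raw)
after["a"] = [v * 0.5 for v in raw["a"]]

shadow = captured_before(raw, after)
```

Only "a" appears in the shadow, and it holds the pre-correction values. "b" is absent because the same object is still bound under that name.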

This PR is pure StepOutput-through-step plumbing: the per-step
StepOutput.uncorrected is computed at the corrector boundary but discarded at
the Stepper.predict boundary, so predict returns its existing corrected-only
BatchData and no BatchData/PairedData surface changes. Carrying the
uncorrected series on the prediction is deferred to the correction-metrics PR
(#1284), which introduces an encapsulated, time-aware container for it.

Pure plumbing: no user-visible behavior change, and existing checkpoints load
unchanged. Adds step- and stepper-seam tests plus captured_before unit tests;
the spatial-parallel step regression matrix passes unchanged under torchrun.

Part of #1271 (PR 1 of the #1218/#1222 split).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
@jpdunc23 jpdunc23 force-pushed the feature/step-output branch from d6e1d56 to cb8029d Compare June 17, 2026 23:39
jpdunc23 added a commit that referenced this pull request Jun 17, 2026
Adds normalized-space metrics of the corrector's correction
(output - uncorrected) to the inference aggregators, plus an optional
denormalized correction netCDF, on by default behind aggregator config flags.
This is PR 2 of the 3-PR split of #1218/#1222 and builds on the StepOutput
plumbing from #1271.

Carriage: the pre-correction ``uncorrected`` series is carried from
Stepper.predict to the consumers through a new opaque, time-aware container,
StepDiagnostics (fme/ace/data_loading/step_diagnostics.py), instead of a raw
public field on BatchData. It follows the StepperState encapsulation pattern:
BatchData and PairedData each hold a single opaque step_diagnostics field
(default None) and never inspect its contents. Unlike StepperState (terminal
per-sample state), the payload is a per-timestep diagnostic series, so the
container is time-aware: it is forwarded by reference through every
structure-preserving method and time-sliced/padded alongside data by the
time-touching ones (select_time_slice, remove_initial_condition,
get_start/get_end, prepend), scattered by scatter_spatial, broadcast by
broadcast_ensemble, and moved by to_device/to_cpu/pin_memory. __post_init__
validates its leading sample dim like stepper_state. This fixes the
silently-dropped path the reviewer flagged on #1283: compute_derived_variables
and PairedData.from_batch_data now preserve the series, so it survives the real
inference loop. Stepper.predict builds the container from the stacked per-step
StepOutput shadows and attaches it to the prediction. The correction aggregator
and netCDF writer reach into it via the single get_uncorrected accessor.
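A minimal sketch of the opaque, time-aware container pattern described above, assuming a much smaller surface than the real one. StepDiagnostics, get_uncorrected, select_time_slice and step_diagnostics are names from the commit message; their signatures here, the ToyBatch class, and the nested-list layout (variable -> [sample][time]) are assumptions for illustration.

```python
import dataclasses
from typing import Dict, List, Optional

# variable -> [sample][time]; stand-in for (sample, time, ...) tensors
Series = Dict[str, List[List[float]]]


def _slice_time(series: Series, time_slice: slice) -> Series:
    return {k: [row[time_slice] for row in v] for k, v in series.items()}


class StepDiagnostics:
    """Opaque payload: holders forward it but never inspect its contents."""

    def __init__(self, uncorrected: Series):
        self._uncorrected = uncorrected

    def select_time_slice(self, time_slice: slice) -> "StepDiagnostics":
        return StepDiagnostics(_slice_time(self._uncorrected, time_slice))

    def get_uncorrected(self) -> Series:
        # Single accessor used by consumers (aggregators, writers).
        return self._uncorrected


@dataclasses.dataclass
class ToyBatch:
    data: Series
    step_diagnostics: Optional[StepDiagnostics] = None

    def select_time_slice(self, time_slice: slice) -> "ToyBatch":
        # Time-touching methods slice the container alongside the data.
        diag = self.step_diagnostics
        if diag is not None:
            diag = diag.select_time_slice(time_slice)
        return ToyBatch(_slice_time(self.data, time_slice), diag)


batch = ToyBatch(
    data={"x": [[0.0, 1.0, 2.0, 3.0]]},
    step_diagnostics=StepDiagnostics({"x": [[10.0, 11.0, 12.0, 13.0]]}),
)
sliced = batch.select_time_slice(slice(1, 3))
```

The holder only ever calls a method on the container with the same arguments it applies to its own data, which is what keeps the series time-aligned without the holder knowing what is inside.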

Metrics (computed as normalize(output) - normalize(uncorrected) per corrected
key, using the network normalizer the existing *_norm metrics use):

- inference/time_mean_norm/correction_magnitude/{var}: area-weighted global mean
  of the time-mean of |normalized correction|, plus a channel_mean over the
  corrected variables only.
- inference/time_mean_norm/correction_map/{var}: signed time-mean map, logged as
  an image and flushed to time_mean_norm_correction_diagnostics.nc.
- inference/mean_norm/weighted_correction_magnitude/{var}: per-step area-weighted
  global mean of |normalized correction|.
- inference/mean_norm/weighted_correction_std/{var}: per-step area-weighted
  spatial std of the signed normalized correction (mirrors weighted_std_gen).
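To make the per-step quantities above concrete, here is a small worked sketch for a single variable on a flattened four-cell grid with a constant-offset correction. The helper names, the weights, and the normalizer values are all hypothetical, and the weighted standard deviation shown (square root of the weighted mean squared deviation) is an assumed definition, not necessarily the one fme uses.

```python
import math
from typing import List


def normalize(values: List[float], mean: float, std: float) -> List[float]:
    return [(v - mean) / std for v in values]


def weighted_mean(values: List[float], weights: List[float]) -> float:
    return sum(w * v for v, w in zip(values, weights)) / sum(weights)


def weighted_std(values: List[float], weights: List[float]) -> float:
    m = weighted_mean(values, weights)
    return math.sqrt(weighted_mean([(v - m) ** 2 for v in values], weights))


output = [3.0, 5.0, 7.0, 9.0]
uncorrected = [1.0, 3.0, 5.0, 7.0]  # constant offset of 2.0 in physical units
area = [1.0, 2.0, 2.0, 1.0]         # hypothetical area weights
mean, std = 4.0, 2.0                # hypothetical normalizer statistics

# normalize(output) - normalize(uncorrected), per grid cell
norm_correction = [
    o - u
    for o, u in zip(normalize(output, mean, std), normalize(uncorrected, mean, std))
]

magnitude = weighted_mean([abs(c) for c in norm_correction], area)
spread = weighted_std(norm_correction, area)
```

The normalizer mean cancels in the difference, so the normalized correction equals (output - uncorrected) / std. For this constant offset the magnitude is 1.0 and the spatial std is 0.0.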

These live in a new fme/ace/aggregator/inference/correction.py with dedicated
CorrectionTimeMeanAggregator / CorrectionMeanAggregator and a CorrectionRecorder
shared by both inference aggregators. They are kept in a separate group merged
into the existing time_mean_norm / mean_norm label groups, so the time-series
table uses a distinct "correction_series" key that to_inference_logs resolves to
the same prefix without colliding with the main series table.

Availability and gating:

- Time-mean metrics in all inference types; time-series metrics only in
  standalone evaluator and no-target inference (inline training drops them via
  the existing enable_time_series path).
- The no-target inference aggregator now receives the stepper's network
  normalizer (plumbed through InferenceAggregatorConfig.build and the inference
  job), introducing mean_norm / time_mean_norm groups there containing only
  correction metrics. Correction metrics are skipped when no normalizer is
  available, preserving backward compatibility for callers that omit it.
- log_correction_metrics: bool = True on both the evaluator and no-target
  aggregator configs. No effect when the stepper has no corrector: the
  container's uncorrected mapping is empty and the correction aggregators stay
  silent.

Disk output:

- save_correction_files: bool = False on DataWriterConfig writes
  autoregressive_corrections.nc with the denormalized correction time series
  (output - uncorrected, physical units, with variable metadata) for the sparse
  corrected variables, respecting the save-names subset and time-coarsening, via
  a single-source RawDataWriter in PairedDataWriter.

The uncorrected/-prefixed error metrics from #1222 are intentionally dropped.

Adds a shared parametrized round-trip test asserting the container survives (and
stays time-aligned through) every structure-preserving method on BatchData and
PairedData, so a future method that forgets to thread it fails CI; aggregator
unit tests asserting exact magnitude/std/map/channel_mean values for a
constant-offset correction and the flag-off/no-corrector silence paths; writer
tests for the sparse denormalized file (incl. time-coarsening); config
validation/defaults tests; and an end-to-end train+inference test asserting
time-mean correction metrics on the inline inference-loop path (series dropped),
per-step series in standalone inference, and the corrections netCDF.

Part of #1272 (PR 2 of the #1218/#1222 split).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
jpdunc23 and others added 3 commits June 17, 2026 16:54
The main merge brought in test_step_masked_nan_input_does_not_raise
(from #1297), which unpacked stepper.step(...) as a tuple. Use
StepOutput.output instead.
@jpdunc23

Member Author

🤖 Posted by Claude Code on James's behalf.

Thanks @mcgibbon. Done — I've moved all BatchData/PairedData changes out of this PR. It's now scoped to the step/stepper layer only:

  • BatchData.uncorrected / PairedData.uncorrected and their threading through the movement methods are gone from this branch (no diff to fme/ace/data_loading/batch_data.py vs main).
  • Stepper.predict no longer stacks the per-step shadows onto the returned BatchData; _stack_step_outputs stacks only output and threads the terminal stepper_state. The uncorrected shadow still flows through StepOutput at the step/stepper boundary, but stops there.
  • The consumer-facing plumbing (and therefore the question of compute_derived_variables dropping uncorrected) now lives in the follow-on aggregator PR, "Add correction inference metrics and optional correction netCDF output" (#1284), where the metrics that consume it can be reviewed and tested together.

The description has been updated to reflect the narrowed scope. Ready for the full agent review whenever you are.

return out


def captured_before(original: TensorMapping, corrected: TensorMapping) -> TensorDict:

Contributor

Issue: I'm wary of relying on the memory mutation behavior of the step code (which isn't constrained by a test to have this property) to determine...

Issue: it's not very clear what this function does based on its name.
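One way to read the first concern above is that identity-based detection only sees out-of-place edits. The sketch below is an illustration under that assumption, using plain lists rather than tensors; it shows an in-place edit passing an identity check unnoticed while also changing the "original" values.

```python
raw = {"a": [1.0, 2.0]}

# A toy corrector that shallow-copies the mapping and then edits in place.
after = dict(raw)
after["a"][0] = 0.0  # mutates the object shared by raw and after

# Identity-based detection: which keys were rebound to a new object?
modified_by_identity = [name for name in raw if after[name] is not raw[name]]
```

Nothing is reported as modified, and raw["a"] no longer holds the pre-correction values. If the step code ever switched to in-place updates, a test pinning the out-of-place property would be what catches it.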



def captured_before(original: TensorMapping, corrected: TensorMapping) -> TensorDict:
"""Return the pre-correction values of variables a corrector modified.

Contributor

Issue: the relationship between "a corrector" and these inputs should be concretely defined, if we are going to define the output based on that relationship.

# adjustments run after and are intentionally excluded). Detached
# unconditionally here: unused on the train path, so this avoids
# retaining the autograd graph. The correction is output - uncorrected.
uncorrected = {

Contributor

Clarifying: This dict contains all output variables, correct?

Comment thread fme/core/step/step.py

output: TensorDict
stepper_state: StepperState | None
uncorrected: TensorDict = dataclasses.field(default_factory=dict)

Contributor

Issue: The corrector doesn't strictly apply corrections (despite the name); it also does things like derive the tendency of water vapor due to advection entirely from scratch, ignoring the NN output. We likely don't want to optimize these unused weights. I'm actually not sure whether the current code would or wouldn't optimize them, because of the memory-update-based logic for selection.

If we supplied values of corrections as direct metrics, we'd have a direct pathway to control whether a corrector action is classified as a "correction" or not, by not including non-correction effects (like derived overwrites, or residual relationships) in the correction metric.

Claude was particularly concerned about this correction, which would allegedly conflict with the residual prediction of hfds and should probably not be included in pre-corrector loss or corrector regularization:

fme/core/corrector/ocean.py:260-261
if method == "residual_prediction":
    out[hfds_name] = net_flux * ocean_fraction + gen_hfds

@mcgibbon

Contributor

I would suggest re-framing from uncorrected values to the value of the correction being applied, and only storing corrections for variables that get corrected. This change would resolve several issues in this and the next PR.
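A rough sketch of what this re-framing could look like, purely as an assumption for discussion: the corrector itself reports the corrections it considers genuine, and derived overwrites are left out. The corrections field, toy_corrector, and the variable names are hypothetical and do not exist in fme; lists stand in for tensors.

```python
import dataclasses
from typing import Dict, List, Optional, Tuple

FakeTensorDict = Dict[str, List[float]]  # stand-in for a tensor mapping


@dataclasses.dataclass
class StepOutput:
    output: FakeTensorDict
    stepper_state: Optional[object]
    # Hypothetical replacement for `uncorrected`: the applied corrections.
    corrections: FakeTensorDict = dataclasses.field(default_factory=dict)


def toy_corrector(gen: FakeTensorDict) -> Tuple[FakeTensorDict, FakeTensorDict]:
    out = dict(gen)
    corrections: FakeTensorDict = {}

    # A genuine correction: nudge "a" and report the delta explicitly.
    out["a"] = [v + 0.5 for v in gen["a"]]
    corrections["a"] = [0.5 for _ in gen["a"]]

    # A derived overwrite: "b" is recomputed from scratch, ignoring the
    # network's value, so it is deliberately not reported as a correction.
    out["b"] = [2.0 * v for v in out["a"]]

    return out, corrections


corrected, corrections = toy_corrector({"a": [1.0, 2.0], "b": [9.0, 9.0]})
result = StepOutput(output=corrected, stepper_state=None, corrections=corrections)
```

With explicit reporting, whether an action counts as a correction is a decision made in the corrector rather than a side effect of how it writes to memory.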


Development

Successfully merging this pull request may close these issues.

Replace step return tuple with a StepOutput dataclass carrying pre-correction values

2 participants