Replace step return tuple with StepOutput dataclass carrying pre-correction values #1283
jpdunc23 wants to merge 4 commits into
mcgibbon left a comment
It looks like many BatchData methods like compute_derived_variables drop these uncorrected values. If that is intentional, it should be tested; if not, it should be tested and fixed.
I would suggest taking this PR and moving all changes to BatchData/PairedData into the follow-on aggregator PR. The StepOutput changes look pretty mergeable as-is, though I'll do a full agent review and check of that review once you've done the move.
    n_ensemble: int = 1
    data_mask: TensorMapping | None = None
    stepper_state: StepperState | None = None
    uncorrected: TensorMapping | None = None
Issue: This attribute is very public, but it pertains to the low-level details of a specific Step implementation. Can we avoid tightly coupling BatchData to that Step using an encapsulation/hiding strategy like we did for StepperState (which also contains information specific to certain Step implementations, but only makes that information available to those implementations)?
Issue: BatchData.broadcast_ensemble silently drops uncorrected
If we add this as a container instead of this specific low-level instance, then future private data we add could be covered by the same code/tests that we add at this stage, e.g. if a StepperMetrics container were added.
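A minimal sketch of the encapsulation idea raised above, assuming a hypothetical opaque container (`StepDiagnosticsSketch`) and a stand-in batch type (`BatchSketch`); neither is the real `BatchData` or `StepperState` API. The point is only that the holder forwards the container by reference and never reads its payload:

```python
from __future__ import annotations

import dataclasses
from typing import Any


class StepDiagnosticsSketch:
    """Hypothetical opaque container; only producers and consumers read the payload."""

    def __init__(self, payload: dict[str, Any]):
        self._payload = payload  # private: holders never inspect this

    def _unwrap(self) -> dict[str, Any]:
        # Intended for the step implementation or a metrics consumer only.
        return self._payload


@dataclasses.dataclass
class BatchSketch:
    """Stand-in for a batch type; not the real BatchData."""

    data: dict[str, list[float]]
    diagnostics: StepDiagnosticsSketch | None = None

    def scaled(self, factor: float) -> BatchSketch:
        # A structure-preserving method: data changes, the container is
        # forwarded by reference without being opened.
        new_data = {k: [factor * x for x in v] for k, v in self.data.items()}
        return dataclasses.replace(self, data=new_data)


batch = BatchSketch(
    {"t": [1.0, 2.0]},
    StepDiagnosticsSketch({"uncorrected": {"t": [0.5, 1.5]}}),
)
scaled = batch.scaled(2.0)
```

With this shape, a test that checks the container survives every method would also cover any payload added to the container later.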
    n_ensemble: int = 1,
    data_mask: TensorMapping | None = None,
    stepper_state: StepperState | None = None,
    uncorrected: TensorMapping | None = None,
Issue: uncorrected got added to new methods for BatchData but not for PairedData; it would be nice to keep them consistent.
…ection values

StepABC.step now returns a StepOutput(output, stepper_state, uncorrected) dataclass instead of a tuple[TensorDict, StepperState | None]. The new uncorrected field is a sparse, detached snapshot of the pre-correction values of exactly the variables a corrector modified, so downstream features can derive the correction (output - uncorrected) or use the raw pre-correction values without re-running the model. It is an empty dict when no corrector ran, so consumers need no None checks. stepper_state keeps its existing passthrough semantics.

step_with_adjustments captures the shadow at the corrector boundary via a new captured_before helper (tensor-identity detection of out-of-place edits), detaching unconditionally; ocean and prescribed-prognostic adjustments run after the corrector and are intentionally excluded. The corrector ABC is left unchanged. All step implementations (single/secondary/radiation/fcn3) inherit the new return type; MultiCallStep composes its wrapped step's shadow and the MultiCall helper returns an empty shadow. The rollout in predict_generator always feeds the corrected output forward as state; Stepper.step applies the name-preserving output process func to the shadow too.

This PR is pure StepOutput-through-step plumbing: the per-step StepOutput.uncorrected is computed at the corrector boundary but discarded at the Stepper.predict boundary, so predict returns its existing corrected-only BatchData and no BatchData/PairedData surface changes. Carrying the uncorrected series on the prediction is deferred to the correction-metrics PR (#1284), which introduces an encapsulated, time-aware container for it. Pure plumbing: no user-visible behavior change, and existing checkpoints load unchanged.

Adds step- and stepper-seam tests plus captured_before unit tests; the spatial-parallel step regression matrix passes unchanged under torchrun.

Part of #1271 (PR 1 of the #1218/#1222 split).
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
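The commit message above can be illustrated with a rough sketch. The `StepOutput` fields follow the diff quoted later in this thread; the body of `captured_before` is an assumption based on the description "tensor-identity detection of out-of-place edits", and plain Python lists stand in for tensors so the sketch runs without torch:

```python
from __future__ import annotations

import dataclasses
from typing import Any, Mapping


@dataclasses.dataclass
class StepOutput:
    output: dict[str, Any]
    stepper_state: Any | None
    # Empty dict when no corrector ran, so consumers need no None checks.
    uncorrected: dict[str, Any] = dataclasses.field(default_factory=dict)


def captured_before(
    original: Mapping[str, Any], corrected: Mapping[str, Any]
) -> dict[str, Any]:
    """Return original values for keys whose object was replaced out-of-place.

    Assumed logic, not the PR's implementation: a key counts as modified
    when the corrected mapping holds a different object under that name.
    """
    captured = {}
    for name, value in original.items():
        if name in corrected and corrected[name] is not value:
            # Real tensors would be detached here; plain objects pass through.
            detach = getattr(value, "detach", None)
            captured[name] = detach() if callable(detach) else value
    return captured


raw = {"t": [280.0], "q": [0.01]}
corrected = dict(raw)
corrected["q"] = [0.0]  # out-of-place: a new object replaces the old one
result = StepOutput(
    output=corrected,
    stepper_state=None,
    uncorrected=captured_before(raw, corrected),
)
```

Only `q` lands in `result.uncorrected`, because `t` is still the same object in both mappings.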
d6e1d56 to cb8029d
Adds normalized-space metrics of the corrector's correction (output - uncorrected) to the inference aggregators, plus an optional denormalized correction netCDF, on by default behind aggregator config flags. This is PR 2 of the 3-PR split of #1218/#1222 and builds on the StepOutput plumbing from #1271.

Carriage: the pre-correction ``uncorrected`` series is carried from Stepper.predict to the consumers through a new opaque, time-aware container, StepDiagnostics (fme/ace/data_loading/step_diagnostics.py), instead of a raw public field on BatchData. It follows the StepperState encapsulation pattern: BatchData and PairedData each hold a single opaque step_diagnostics field (default None) and never inspect its contents. Unlike StepperState (terminal per-sample state), the payload is a per-timestep diagnostic series, so the container is time-aware: it is forwarded by reference through every structure-preserving method and time-sliced/padded alongside data by the time-touching ones (select_time_slice, remove_initial_condition, get_start/get_end, prepend), scattered by scatter_spatial, broadcast by broadcast_ensemble, and moved by to_device/to_cpu/pin_memory. __post_init__ validates its leading sample dim like stepper_state. This fixes the silently-dropped path the reviewer flagged on #1283: compute_derived_variables and PairedData.from_batch_data now preserve the series, so it survives the real inference loop. Stepper.predict builds the container from the stacked per-step StepOutput shadows and attaches it to the prediction. The correction aggregator and netCDF writer reach into it via the single get_uncorrected accessor.

Metrics (computed as normalize(output) - normalize(uncorrected) per corrected key, using the network normalizer the existing *_norm metrics use):
- inference/time_mean_norm/correction_magnitude/{var}: area-weighted global mean of the time-mean of |normalized correction|, plus a channel_mean over the corrected variables only.
- inference/time_mean_norm/correction_map/{var}: signed time-mean map, logged as an image and flushed to time_mean_norm_correction_diagnostics.nc.
- inference/mean_norm/weighted_correction_magnitude/{var}: per-step area-weighted global mean of |normalized correction|.
- inference/mean_norm/weighted_correction_std/{var}: per-step area-weighted spatial std of the signed normalized correction (mirrors weighted_std_gen).

These live in a new fme/ace/aggregator/inference/correction.py with dedicated CorrectionTimeMeanAggregator / CorrectionMeanAggregator and a CorrectionRecorder shared by both inference aggregators. They are kept in a separate group merged into the existing time_mean_norm / mean_norm label groups, so the time-series table uses a distinct "correction_series" key that to_inference_logs resolves to the same prefix without colliding with the main series table.

Availability and gating:
- Time-mean metrics in all inference types; time-series metrics only in standalone evaluator and no-target inference (inline training drops them via the existing enable_time_series path).
- The no-target inference aggregator now receives the stepper's network normalizer (plumbed through InferenceAggregatorConfig.build and the inference job), introducing mean_norm / time_mean_norm groups there containing only correction metrics. Correction metrics are skipped when no normalizer is available, preserving backward compatibility for callers that omit it.
- log_correction_metrics: bool = True on both the evaluator and no-target aggregator configs. No effect when the stepper has no corrector: the container's uncorrected mapping is empty and the correction aggregators stay silent.

Disk output:
- save_correction_files: bool = False on DataWriterConfig writes autoregressive_corrections.nc with the denormalized correction time series (output - uncorrected, physical units, with variable metadata) for the sparse corrected variables, respecting the save-names subset and time-coarsening, via a single-source RawDataWriter in PairedDataWriter.

The uncorrected/-prefixed error metrics from #1222 are intentionally dropped.

Adds a shared parametrized round-trip test asserting the container survives (and stays time-aligned through) every structure-preserving method on BatchData and PairedData, so a future method that forgets to thread it fails CI; aggregator unit tests asserting exact magnitude/std/map/channel_mean values for a constant-offset correction and the flag-off/no-corrector silence paths; writer tests for the sparse denormalized file (incl. time-coarsening); config validation/defaults tests; and an end-to-end train+inference test asserting time-mean correction metrics on the inline inference-loop path (series dropped), per-step series in standalone inference, and the corrections netCDF.

Part of #1272 (PR 2 of the #1218/#1222 split).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
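As a rough illustration of the correction_magnitude metric described in this commit message (area-weighted global mean of the time-mean of |normalized correction|), here is a pure-Python sketch for one variable. The function names, the `[time][point]` list layout, and the scalar mean/std normalizer are assumptions for illustration, not the aggregator's actual code:

```python
from __future__ import annotations


def normalize(series: list[list[float]], mean: float, std: float) -> list[list[float]]:
    # Assumed normalizer form: (x - mean) / std, applied element-wise.
    return [[(x - mean) / std for x in row] for row in series]


def correction_magnitude(
    output: list[list[float]],
    uncorrected: list[list[float]],
    area_weights: list[float],
    mean: float,
    std: float,
) -> float:
    """Area-weighted mean of the time-mean of |normalized correction|."""
    out_n = normalize(output, mean, std)
    unc_n = normalize(uncorrected, mean, std)
    n_time = len(out_n)
    n_points = len(area_weights)
    # Time-mean of the absolute normalized correction at each grid point.
    time_mean_abs = [
        sum(abs(out_n[t][i] - unc_n[t][i]) for t in range(n_time)) / n_time
        for i in range(n_points)
    ]
    total_weight = sum(area_weights)
    return sum(w * m for w, m in zip(area_weights, time_mean_abs)) / total_weight


# Constant offset of 1.0 in physical units, two time steps, two grid points.
output = [[2.0, 4.0], [2.0, 4.0]]
uncorrected = [[1.0, 3.0], [1.0, 3.0]]
magnitude = correction_magnitude(output, uncorrected, [1.0, 3.0], mean=10.0, std=2.0)
```

For a constant offset the mean cancels and the result is the offset divided by the std (here 1.0 / 2.0), independent of the area weights.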
The main merge brought in test_step_masked_nan_input_does_not_raise (from #1297), which unpacked stepper.step(...) as a tuple. Use StepOutput.output instead.
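A small sketch of why a tuple-style call site needs updating, assuming `StepOutput` is a plain dataclass that does not define `__iter__` (the `step` function here is a hypothetical stand-in for `stepper.step(...)`):

```python
import dataclasses
from typing import Any, Optional


@dataclasses.dataclass
class StepOutput:
    output: dict
    stepper_state: Optional[Any]
    uncorrected: dict = dataclasses.field(default_factory=dict)


def step() -> StepOutput:
    # Hypothetical stand-in for stepper.step(...).
    return StepOutput(output={"t": 1.0}, stepper_state=None)


try:
    output, state = step()  # old tuple-style call site
    unpack_failed = False
except TypeError:
    # A plain dataclass instance is not iterable, so unpacking fails.
    unpack_failed = True

output = step().output  # updated call site
```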
Thanks @mcgibbon. Done — I've moved all
The description has been updated to reflect the narrowed scope. Ready for the full agent review whenever you are.
    return out


def captured_before(original: TensorMapping, corrected: TensorMapping) -> TensorDict:
Issue: I'm wary of relying on the memory mutation behavior of the step code (which isn't constrained by a test to have this property) to determine...
Issue: it's not very clear what this function does based on its name.
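The concern about relying on memory-mutation behavior can be made concrete with a sketch. It assumes the detection is a plain object-identity comparison (`is not`), which is an inference from the PR description rather than the actual helper, and uses lists in place of tensors:

```python
def replaced_keys(original: dict, corrected: dict) -> set:
    # Assumed detection rule: a key is "modified" only if its object was replaced.
    return {k for k, v in original.items() if k in corrected and corrected[k] is not v}


raw = {"a": [1.0], "b": [2.0]}

# Out-of-place edit: a new object is bound to the key, so identity changes.
out_of_place = dict(raw)
out_of_place["a"] = [v + 1.0 for v in raw["a"]]
detected = replaced_keys(raw, out_of_place)

# In-place edit: the same object is mutated, so identity is unchanged and the
# "original" mapping silently sees the new value too.
in_place = dict(raw)
in_place["b"][0] += 1.0
missed = replaced_keys(raw, in_place)
```

Under this assumption an in-place edit is both undetected and destructive to the pre-correction value, which is why a test constraining the step code to out-of-place edits would matter.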
def captured_before(original: TensorMapping, corrected: TensorMapping) -> TensorDict:
    """Return the pre-correction values of variables a corrector modified.
Issue: the relationship between "a corrector" and these inputs should be concretely defined, if we are going to define the output based on that relationship.
# adjustments run after and are intentionally excluded). Detached
# unconditionally here: unused on the train path, so this avoids
# retaining the autograd graph. The correction is output - uncorrected.
uncorrected = {
Clarifying: This dict contains all output variables, correct?
    output: TensorDict
    stepper_state: StepperState | None
    uncorrected: TensorDict = dataclasses.field(default_factory=dict)
Issue: The corrector doesn't strictly apply corrections (despite the name); it also does things like completely derive the tendency of water vapor due to advection from scratch, ignoring the NN output. We likely don't want to optimize these unused weights. I'm actually not sure whether the current code would or wouldn't optimize them, because of the memory-update-based logic for selection.
If we supplied values of corrections as direct metrics, we'd have a direct pathway to control whether a corrector action is classified as a "correction" or not, by not including non-correction effects (like derived overwrites, or residual relationships) in the correction metric.
Claude was particularly concerned about this correction, which would allegedly conflict with the residual prediction of hfds and should probably not be included in pre-corrector loss or corrector regularization:
fme/core/corrector/ocean.py:260-261
    if method == "residual_prediction":
        out[hfds_name] = net_flux * ocean_fraction + gen_hfds
I would suggest re-framing from uncorrected values to the value of the correction being applied, and only storing corrections for variables that get corrected. This change would resolve several issues in this and the next PR.
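One way the suggested re-framing could look, as a sketch only: the corrector reports the corrections it applied, and derived overwrites are simply not reported. The function name, the variable names (`q`, `advective_tendency`), and the 0.25 factor are all hypothetical and not taken from the fme codebase:

```python
from __future__ import annotations


def apply_corrector_sketch(
    gen: dict[str, float],
) -> tuple[dict[str, float], dict[str, float]]:
    """Hypothetical corrector that reports its own corrections explicitly."""
    out = dict(gen)
    corrections: dict[str, float] = {}

    # A true correction: clip negative moisture to zero and record the delta.
    clipped = max(gen["q"], 0.0)
    corrections["q"] = clipped - gen["q"]
    out["q"] = clipped

    # A derived overwrite: the network value is ignored entirely, so it is
    # deliberately not recorded as a correction.
    out["advective_tendency"] = 0.25 * out["q"]

    return out, corrections


out, corrections = apply_corrector_sketch({"q": -0.5, "advective_tendency": 9.0})
```

Because the corrector decides what counts as a correction, no identity-based detection is needed and overwritten variables never enter a correction metric.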
ACE steppers can apply correctors that adjust the network's raw output each step, but today the corrected and uncorrected outputs are conflated the moment `step` returns, so nothing downstream can interrogate the correction itself. This PR replaces the step abstraction's `tuple[TensorDict, StepperState | None]` return with a `StepOutput` dataclass that additionally carries the pre-correction values of the corrector-modified variables. It is pure plumbing within the step/stepper layer: no user-visible behavior change, and existing checkpoints load unchanged.

`StepOutput.uncorrected` holds the pre-correction values (not deltas) of exactly the variables a corrector modified, captured at the corrector boundary by tensor-identity detection of out-of-place edits and detached unconditionally; it is an empty dict when no corrector ran, so consumers need no None checks. The correction is derivable as `output - uncorrected`. Ocean and prescribed-prognostic adjustments run after the corrector and are excluded; the rollout always feeds the corrected output forward as state; `fme.coupled` only unwraps `StepOutput.output`. The corrector ABC is intentionally left unchanged (its existing `corrector_state` tuple is untouched).

This PR is now scoped to the step/stepper layer only. Per review feedback, the consumer-facing plumbing that surfaces `uncorrected` on prediction `BatchData`/`PairedData` has been moved to the follow-on aggregator PR (#1284), where it can be reviewed together with the metrics that consume it.

Changes:
- `fme.core.step.step.StepOutput`: new dataclass (`output`, `stepper_state`, `uncorrected`) replacing the step return tuple; `StepABC.step` and all implementations (`SingleModuleStep`, `SecondaryModuleStep`, `SeparateRadiationStep`, `FCN3Step`) return it.
- `fme.core.corrector.utils.captured_before`: new helper returning the pre-correction values of variables modified out-of-place.
- `fme.core.step.single_module.step_with_adjustments`: taps the corrector boundary to build the detached `uncorrected` shadow and returns `StepOutput`.
- `fme.core.step.multi_call` / `_multi_call`: `MultiCallStep` composes the wrapped step's shadow; the `MultiCall` helper returns an empty shadow.
- `fme.ace.stepper.single_module.Stepper.step`: applies the name-preserving output process func to the shadow and returns `StepOutput`; `predict_generator` yields `StepOutput` and feeds the corrected output forward; `_stack_step_outputs` / `process_prediction_generator_list` consume `StepOutput` (stacking only `output` and threading the terminal `stepper_state`).
- `fme.coupled.stepper.CoupledStepper`: unwraps `StepOutput.output`.

- Tests added
- If dependencies changed, "deps only" image rebuilt and "latest_deps_only_image.txt" file updated
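The rollout behavior described above (each step yields a `StepOutput`, the corrected output is fed forward, and the correction is derivable as output - uncorrected) can be sketched with a toy model. `toy_step` and `rollout` are hypothetical stand-ins, floats replace tensors, and the toy uses a value comparison rather than the identity-based capture the PR describes:

```python
from __future__ import annotations

import dataclasses
from typing import Iterator


@dataclasses.dataclass
class StepOutput:
    output: dict[str, float]
    stepper_state: object | None
    uncorrected: dict[str, float] = dataclasses.field(default_factory=dict)


def toy_step(state: dict[str, float]) -> StepOutput:
    raw = {"q": state["q"] - 1.0}  # stand-in for the network
    corrected = {"q": max(raw["q"], 0.0)}  # stand-in for the corrector
    # Sparse shadow: only variables whose value actually changed.
    changed = {k: raw[k] for k in raw if corrected[k] != raw[k]}
    return StepOutput(corrected, None, changed)


def rollout(initial: dict[str, float], n_steps: int) -> Iterator[StepOutput]:
    state = initial
    for _ in range(n_steps):
        result = toy_step(state)
        yield result
        state = result.output  # the corrected output is fed forward as state


steps = list(rollout({"q": 1.5}, 3))
# Derive the correction per step, only for the variables that were corrected.
corrections = [
    {k: s.output[k] - s.uncorrected[k] for k in s.uncorrected} for s in steps
]
```

The first step needs no correction, so its `uncorrected` mapping is empty and no None check is required when deriving corrections.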
Resolves #1271
🤖 Generated with Claude Code