Refactor step return to StepOutput carrying opaque StepMetrics #1286

Closed
mcgibbon wants to merge 1 commit into main from refactor/step-metrics-abstraction

@mcgibbon

Contributor

Replaces the StepABC.step return type tuple[TensorDict, StepperState | None] with a StepOutput dataclass that additionally carries an opaque, step-implementation-specific StepMetrics payload, and threads a StepperMetrics container (parallel to StepperState) through BatchData and the prediction path. This is the foundation for surfacing per-step quantities that only exist inside a particular Step — e.g. the pre-correction values at a corrector boundary — to that step's loss and aggregator, without leaking step-specific fields (corrections, uncorrected) into the generic step return, into BatchData, or into generic loss/aggregator code. That leakage is the shared cost of #1218, #1283, and #1284; this PR is the seam those features will be re-homed onto.

It is pure plumbing with no behavior change: everywhere, metrics defaults to NullStepMetrics and stepper_metrics to a null payload, so generic consumers need no None/isinstance branches. The StepMetricsLoss / StepMetricsAggregator classes, and the builder config that constructs them per step type, deliberately land later with the features that consume them, to avoid merging unused code.
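To illustrate the null-object default described above, here is a minimal sketch, not the code in this PR. The field names (output, stepper_state, metrics) come from the description; the `to` lifecycle method, the plain-dict `TensorDict` stand-in, and `toy_step` are hypothetical, chosen only to keep the example free of project dependencies.

```python
from __future__ import annotations

import abc
import dataclasses
from typing import Any, Dict, Optional

# Stand-in for the project's TensorDict (assumed: a name -> value mapping).
TensorDict = Dict[str, Any]


class StepMetrics(abc.ABC):
    """Opaque, step-specific payload; generic code only calls lifecycle methods."""

    @abc.abstractmethod
    def to(self, device: str) -> StepMetrics:
        """Hypothetical lifecycle method, standing in for the real interface."""


class NullStepMetrics(StepMetrics):
    """Null object: carries nothing, so every lifecycle call is a no-op."""

    def to(self, device: str) -> NullStepMetrics:
        return self


@dataclasses.dataclass
class StepOutput:
    output: TensorDict
    stepper_state: Optional[Any]  # StepperState | None in the PR
    # Defaulting to the null object means a step with nothing to report
    # never has to mention metrics at all.
    metrics: StepMetrics = dataclasses.field(default_factory=NullStepMetrics)


def toy_step(inputs: TensorDict) -> StepOutput:
    # Hypothetical step: previously this would have returned a bare tuple
    # (output, stepper_state).
    return StepOutput(
        output={name: value + 1 for name, value in inputs.items()},
        stepper_state=None,
    )


result = toy_step({"T": 1.0})
# Generic consumers can call lifecycle methods unconditionally:
moved = result.metrics.to("cpu")
```

Because `metrics` is always a `StepMetrics` instance, a generic caller never needs an `if metrics is not None` or `isinstance` branch; only a step-owned loss or aggregator would look inside a concrete payload.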

Changes:

  • fme.core.step.metrics (new): StepMetrics ABC + NullStepMetrics (the step-focused payload; its lifecycle interface mirrors StepperState), and StepperMetrics + null_stepper_metrics (the concrete container that rides on BatchData beside StepperState).

  • fme.core.step.step.StepOutput: new dataclass (output, stepper_state, metrics); StepABC.step and all implementations (SingleModuleStep, SeparateRadiationStep, SecondaryModuleStep, FCN3Step, MultiCallStep, _multi_call.MultiCall) return it.

  • fme.ace.stepper.single_module: Stepper.step and predict_generator build/forward StepOutput; process_prediction_generator_list concatenates the per-step metrics into a windowed StepperMetrics on the returned BatchData.

  • fme.ace.data_loading.batch_data.BatchData: new stepper_metrics field, threaded through every lifecycle method (to_device/to_cpu/pin_memory/broadcast_ensemble/new_on_*/derived/slicing) alongside stepper_state.

  • fme.coupled.stepper: consume StepOutput from the component step generators.

  • Tests added: no new tests, since a pure refactor has no new behavior to cover; the existing step / stepper / batch_data / coupled suites are updated to the StepOutput return. Verified green: fme/core/step (91), test_batch_data (48), fme/ace/stepper + fme/coupled non-parallel (387), inference smoke (35), parallel test_step under torchrun (15); ruff, ruff-format, and mypy clean.

  • If dependencies changed, "deps only" image rebuilt and "latest_deps_only_image.txt" file updated (n/a — no dependency changes)
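As a rough picture of how per-step metrics might be concatenated into a windowed StepperMetrics (the fme.ace.stepper.single_module item above), consider the sketch below. The `concat_time` classmethod, `ToyCorrectorMetrics`, the `payload` attribute, and `window_metrics` are all hypothetical names; the actual interface in this PR may differ.

```python
from __future__ import annotations

import abc
from typing import Dict, List, Sequence


class StepMetrics(abc.ABC):
    """Opaque per-step payload (hypothetical interface for this sketch)."""

    @classmethod
    @abc.abstractmethod
    def concat_time(cls, items: Sequence[StepMetrics]) -> StepMetrics:
        """Join payloads from consecutive steps along the time axis."""


class NullStepMetrics(StepMetrics):
    @classmethod
    def concat_time(cls, items: Sequence[StepMetrics]) -> NullStepMetrics:
        return cls()  # nothing to join


class ToyCorrectorMetrics(StepMetrics):
    """Step-owned payload; plain lists stand in for tensors with a time axis."""

    def __init__(self, uncorrected: Dict[str, List[float]]):
        self.uncorrected = uncorrected

    @classmethod
    def concat_time(cls, items: Sequence[ToyCorrectorMetrics]) -> ToyCorrectorMetrics:
        names = items[0].uncorrected.keys()
        return cls(
            {n: [v for m in items for v in m.uncorrected[n]] for n in names}
        )


class StepperMetrics:
    """Container that would ride on BatchData beside StepperState."""

    def __init__(self, payload: StepMetrics):
        self.payload = payload


def window_metrics(per_step: Sequence[StepMetrics]) -> StepperMetrics:
    # Generic code dispatches on the payload type and never reads its fields.
    if not per_step:
        return StepperMetrics(NullStepMetrics())
    return StepperMetrics(type(per_step[0]).concat_time(per_step))


rollout = [ToyCorrectorMetrics({"T": [float(i)]}) for i in range(3)]
windowed = window_metrics(rollout)
print(windowed.payload.uncorrected)  # {'T': [0.0, 1.0, 2.0]}
```

The design point is that `window_metrics` stays generic: it delegates the join to the payload's own class, so step-specific fields such as `uncorrected` remain visible only to step-owned code.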

Draft: foundation only. The corrector-regularization (#1218), pre-correction uncorrected (#1283), and correction-metrics (#1284) features will be ported on top of this seam as follow-ups.

🤖 Generated with Claude Code

Replace the StepABC.step return type tuple[TensorDict, StepperState | None]
with a StepOutput dataclass (output, stepper_state, metrics), where metrics
is an opaque, step-implementation-specific StepMetrics payload. Thread a
StepperMetrics container (parallel to StepperState) through BatchData and the
prediction path, stacked over the rollout window.

Pure structural refactor with zero behavior change: metrics defaults to
NullStepMetrics everywhere and StepperMetrics to a Null payload, so generic
consumers need no None/isinstance branches. Establishes the seam for
compartmentalizing per-step loss and metrics (corrector regularization,
pre-correction values, correction metrics) behind step-owned classes; the
StepMetricsLoss / StepMetricsAggregator classes and their wiring land with
the features that consume them.

- fme/core/step/metrics.py: StepMetrics ABC + NullStepMetrics, StepperMetrics
  container + null_stepper_metrics, mirroring StepperState's lifecycle.
- fme/core/step/step.py: StepOutput dataclass; StepABC.step returns it.
- All Step implementations (single_module, radiation, secondary_module, fcn3,
  multi_call, _multi_call) return StepOutput.
- fme/ace/stepper/single_module.py: Stepper.step, predict_generator and
  process_prediction_generator_list build/forward StepOutput / StepperMetrics.
- fme/ace/data_loading/batch_data.py: BatchData carries stepper_metrics,
  threaded through every lifecycle method alongside stepper_state.
- fme/coupled/stepper.py: consume StepOutput from component generators.
- Tests updated to the StepOutput return.

mcgibbon closed this on Jun 16, 2026