Refactor step return to StepOutput carrying opaque StepMetrics #1286
Closed
mcgibbon wants to merge 1 commit into
Replace the StepABC.step return type tuple[TensorDict, StepperState | None] with a StepOutput dataclass (output, stepper_state, metrics), where metrics is an opaque, step-implementation-specific StepMetrics payload. Thread a StepperMetrics container (parallel to StepperState) through BatchData and the prediction path, stacked over the rollout window.

Pure structural refactor with zero behavior change: metrics defaults to NullStepMetrics everywhere and StepperMetrics to a Null payload, so generic consumers need no None/isinstance branches.

Establishes the seam for compartmentalizing per-step loss and metrics (corrector regularization, pre-correction values, correction metrics) behind step-owned classes; the StepMetricsLoss / StepMetricsAggregator classes and their wiring land with the features that consume them.

- fme/core/step/metrics.py: StepMetrics ABC + NullStepMetrics, StepperMetrics container + null_stepper_metrics, mirroring StepperState's lifecycle.
- fme/core/step/step.py: StepOutput dataclass; StepABC.step returns it.
- All Step implementations (single_module, radiation, secondary_module, fcn3, multi_call, _multi_call) return StepOutput.
- fme/ace/stepper/single_module.py: Stepper.step, predict_generator and process_prediction_generator_list build/forward StepOutput / StepperMetrics.
- fme/ace/data_loading/batch_data.py: BatchData carries stepper_metrics, threaded through every lifecycle method alongside stepper_state.
- fme/coupled/stepper.py: consume StepOutput from component generators.
- Tests updated to the StepOutput return.
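The shape of the change described in this commit message can be sketched roughly as follows. This is a minimal illustration, not the code in the PR: `TensorDict` and `StepperState` are plain-Python stand-ins so the example runs without torch, the `to_device` hook on `StepMetrics` is an assumed example of a lifecycle method, and `old_style_step` / `new_style_step` are hypothetical helpers. Only the names `StepOutput`, `StepMetrics`, `NullStepMetrics`, and the three field names come from the description.

```python
import abc
from dataclasses import dataclass, field
from typing import Optional

# Stand-in for the real TensorDict (assumption: plain lists instead of tensors).
TensorDict = dict[str, list[float]]


@dataclass
class StepperState:
    """Placeholder for the real StepperState (assumption)."""

    values: TensorDict = field(default_factory=dict)


class StepMetrics(abc.ABC):
    """Opaque per-step payload; only the owning step knows its fields."""

    @abc.abstractmethod
    def to_device(self, device: str) -> "StepMetrics":
        """Hypothetical lifecycle hook, assumed to mirror StepperState."""


class NullStepMetrics(StepMetrics):
    """Null object: carries nothing, so consumers never branch on None."""

    def to_device(self, device: str) -> "NullStepMetrics":
        return self


@dataclass
class StepOutput:
    output: TensorDict
    stepper_state: Optional[StepperState] = None
    # Defaulting to the null payload keeps existing behavior unchanged.
    metrics: StepMetrics = field(default_factory=NullStepMetrics)


def old_style_step(x: TensorDict) -> tuple[TensorDict, Optional[StepperState]]:
    """Hypothetical step using the previous tuple return."""
    return {name: [v + 1.0 for v in vals] for name, vals in x.items()}, None


def new_style_step(x: TensorDict) -> StepOutput:
    """The same step, wrapped in the new return type."""
    output, state = old_style_step(x)
    return StepOutput(output=output, stepper_state=state)


result = new_style_step({"T": [1.0, 2.0]})
# A generic consumer can call lifecycle methods without isinstance checks.
moved = result.metrics.to_device("cpu")
```

The null-object default is what lets the refactor stay behavior-neutral in this sketch: callers that ignore `metrics` see the same `output` and `stepper_state` as before.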
Replaces the `StepABC.step` return type `tuple[TensorDict, StepperState | None]` with a `StepOutput` dataclass that additionally carries an opaque, step-implementation-specific `StepMetrics` payload, and threads a `StepperMetrics` container (parallel to `StepperState`) through `BatchData` and the prediction path. This is the foundation for surfacing per-step quantities that only exist inside a particular `Step` (e.g. the pre-correction values at a corrector boundary) to that step's loss and aggregator, without leaking step-specific fields (`corrections`, `uncorrected`) into the generic step return, into `BatchData`, or into generic loss/aggregator code. That leakage is the shared cost of #1218, #1283, and #1284; this PR is the seam those features will be re-homed onto.

It is pure plumbing with no behavior change: `metrics` defaults to `NullStepMetrics` and `stepper_metrics` to a Null payload everywhere, so generic consumers need no `None`/`isinstance` branches. The `StepMetricsLoss` / `StepMetricsAggregator` classes, and the builder-config that constructs them per step type, deliberately land later, with the features that consume them, to avoid merging unused code.

Changes:
- `fme.core.step.metrics` (new): `StepMetrics` ABC + `NullStepMetrics` (the step-focused payload; its lifecycle interface mirrors `StepperState`), and `StepperMetrics` + `null_stepper_metrics` (the concrete container that rides on `BatchData` beside `StepperState`).
- `fme.core.step.step.StepOutput`: new dataclass (`output`, `stepper_state`, `metrics`); `StepABC.step` and all implementations (`SingleModuleStep`, `SeparateRadiationStep`, `SecondaryModuleStep`, `FCN3Step`, `MultiCallStep`, `_multi_call.MultiCall`) return it.
- `fme.ace.stepper.single_module`: `Stepper.step` and `predict_generator` build/forward `StepOutput`; `process_prediction_generator_list` concatenates the per-step metrics into a windowed `StepperMetrics` on the returned `BatchData`.
- `fme.ace.data_loading.batch_data.BatchData`: new `stepper_metrics` field, threaded through every lifecycle method (`to_device` / `to_cpu` / `pin_memory` / `broadcast_ensemble` / `new_on_*` / derived / slicing) alongside `stepper_state`.
- `fme.coupled.stepper`: consume `StepOutput` from the component step generators.

- Tests added: existing step / stepper / batch_data / coupled suites updated to the `StepOutput` return; no new behavior to cover (pure refactor). Verified green: `fme/core/step` (91), `test_batch_data` (48), `fme/ace/stepper` + `fme/coupled` non-parallel (387), inference smoke (35), parallel `test_step` under torchrun (15); ruff, ruff-format, and mypy clean.
- If dependencies changed, "deps only" image rebuilt and "latest_deps_only_image.txt" file updated (n/a: no dependency changes)
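The windowed concatenation mentioned for `process_prediction_generator_list` might look something like the sketch below. Everything beyond the names `StepMetrics`, `NullStepMetrics`, `StepperMetrics`, and `null_stepper_metrics` is an assumption made for illustration: the `cat_time` classmethod, the `payload` field, the `stack_window` helper, and the `UncorrectedMetrics` subclass (a made-up example of a step-specific payload) are all hypothetical, and plain lists stand in for tensors.

```python
import abc
from dataclasses import dataclass
from typing import Sequence


class StepMetrics(abc.ABC):
    """Opaque per-step payload owned by a particular step implementation."""

    @classmethod
    @abc.abstractmethod
    def cat_time(cls, items: Sequence["StepMetrics"]) -> "StepMetrics":
        """Hypothetical hook: combine per-step payloads along the time axis."""


class NullStepMetrics(StepMetrics):
    """Null payload: concatenating any number of them is still null."""

    @classmethod
    def cat_time(cls, items: Sequence[StepMetrics]) -> "NullStepMetrics":
        return cls()


@dataclass
class UncorrectedMetrics(StepMetrics):
    """Hypothetical step-specific payload holding pre-correction values."""

    uncorrected: list[float]  # one entry per time step in this payload

    @classmethod
    def cat_time(cls, items: Sequence[StepMetrics]) -> "UncorrectedMetrics":
        return cls([v for item in items for v in item.uncorrected])


@dataclass
class StepperMetrics:
    """Container assumed to ride on BatchData beside StepperState."""

    payload: StepMetrics


def null_stepper_metrics() -> StepperMetrics:
    return StepperMetrics(NullStepMetrics())


def stack_window(per_step: Sequence[StepMetrics]) -> StepperMetrics:
    """Stack the payloads produced by each step of a rollout window."""
    if not per_step:
        return null_stepper_metrics()
    # The payload type decides how to concatenate; this code stays generic.
    return StepperMetrics(type(per_step[0]).cat_time(per_step))


window = stack_window(
    [UncorrectedMetrics([0.1]), UncorrectedMetrics([0.2]), UncorrectedMetrics([0.3])]
)
null_window = stack_window([NullStepMetrics(), NullStepMetrics()])
```

In this sketch the generic stepper code never inspects the payload's fields; it only delegates to the payload type, which is one way to keep step-specific names such as `uncorrected` out of generic code.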
Draft: foundation only. The corrector-regularization (#1218), pre-correction `uncorrected` (#1283), and correction-metrics (#1284) features will be ported on top of this seam as follow-ups.

🤖 Generated with Claude Code