Refactor step to take [batch, ensemble] dims (on var_masking_simple)#1304
Draft
mcgibbon wants to merge 5 commits into
The step (`step_with_adjustments` and all four `StepABC` implementations) now receives tensors with an explicit `[batch, ensemble, *spatial]` leading dimension pair instead of a folded `[batch*ensemble, *spatial]` dimension. The nn.Module calls still operate on a folded `[batch*ensemble, channel, *spatial]` batch: each `network_call` folds the ensemble dimension into the batch immediately before the module and unfolds the module output. This makes the ensemble structure visible inside the step, a prerequisite for input channel dropout that is shared across ensemble members but independent across the batch (a follow-up).

- Global mean removal reduces every dimension after the leading sample dim, so it is folded around (each ensemble member is an independent sample for mean removal), keeping it byte-identical and agnostic to the spatial rank.
- The corrector, ocean and normalizer are leading-dim agnostic (the corrector reduces only HORIZONTAL_DIMS) and run on the [batch, ensemble] layout directly. CorrectorState/StepperState gain fold/unfold_ensemble helpers; the corrector state is [batch, ensemble, 1, 1] inside the step and folded to [batch*ensemble, 1, 1] in the externally-threaded state.
- predict_generator keeps its external contract folded: it unfolds the per-step inputs, masks and stepper_state, threads [batch, ensemble] through the step, and folds the yielded outputs, so training and inference callers are unchanged.
- The input-mask helpers handle the leading [batch, ensemble] pair, and new fold/unfold_ensemble_tensor helpers fold a single tensor.

Behavior-preserving: an added ensemble-independence test asserts each member's output equals an independent single-member step, and the spatial-parallel regression baselines reproduce the pre-refactor outputs bit-for-bit (modulo the size-1 ensemble dim).
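The "folded around" global mean removal described above can be illustrated with a minimal sketch. It uses NumPy arrays as a stand-in for the torch tensors in the PR, and `fold_ensemble`, `unfold_ensemble` and `remove_global_mean` are hypothetical names chosen for this example, not the repository's actual helpers or signatures.

```python
import numpy as np


def fold_ensemble(x):
    """[batch, ensemble, *rest] -> [batch*ensemble, *rest] (hypothetical helper)."""
    return x.reshape((x.shape[0] * x.shape[1],) + x.shape[2:])


def unfold_ensemble(x, n_ensemble):
    """[batch*ensemble, *rest] -> [batch, ensemble, *rest] (hypothetical helper)."""
    return x.reshape((-1, n_ensemble) + x.shape[1:])


def remove_global_mean(x):
    """Subtract the mean over every dim after the leading sample dim."""
    axes = tuple(range(1, x.ndim))
    return x - x.mean(axis=axes, keepdims=True)


def remove_global_mean_ensemble(x):
    """Apply the per-sample op to [batch, ensemble, *spatial] by folding around it."""
    n_ensemble = x.shape[1]
    return unfold_ensemble(remove_global_mean(fold_ensemble(x)), n_ensemble)


rng = np.random.default_rng(0)
x = rng.normal(size=(3, 2, 4, 5))  # assumed layout: [batch, ensemble, lat, lon]
out = remove_global_mean_ensemble(x)
```

Because the per-sample function only ever sees a single leading sample dimension, it needs no knowledge of the spatial rank, and each ensemble member is treated as its own sample.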
…emble] The regression input/output fixtures store tensors, so they must reflect the new [batch, ensemble, *spatial] step contract. The values are unchanged (verified bit-for-bit modulo the size-1 ensemble dim); only a size-1 ensemble dimension is added. Verified under both H=2 and W=2 spatial decompositions.
…ble] The input dropout mask now flows as [batch, ensemble] (unfolded in predict_generator alongside data_mask), so the direct-step dropout tests pass [batch, ensemble] masks and StepArgs documents the new shape. Behavior is unchanged: the mask is still ensemble-shared (sampled per base sample, repeat-interleaved), now applied in the explicit [batch, ensemble] layout.
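As a rough illustration of an ensemble-shared mask, the sketch below samples one keep/drop decision per base sample and channel, repeats it across ensemble members, and views it as `[batch, ensemble, channel]`. It uses NumPy (`np.repeat` along axis 0 plays the role of a repeat-interleave), and the function name, the per-channel mask shape and the `drop_prob` parameter are assumptions made for this example only.

```python
import numpy as np


def sample_ensemble_shared_mask(n_batch, n_ensemble, n_channels, drop_prob, rng):
    """Hypothetical sampler: True means the channel is kept."""
    # One independent draw per base sample and channel.
    base = rng.random((n_batch, n_channels)) >= drop_prob
    # Repeat each base row n_ensemble times consecutively:
    # [batch*ensemble, channel], members of a base sample are adjacent.
    folded = np.repeat(base, n_ensemble, axis=0)
    # Explicit [batch, ensemble, channel] view of the same values.
    return folded.reshape(n_batch, n_ensemble, n_channels)


rng = np.random.default_rng(0)
mask = sample_ensemble_shared_mask(
    n_batch=4, n_ensemble=3, n_channels=5, drop_prob=0.5, rng=rng
)
```

Every member along the ensemble axis carries the same mask as its base sample, while different batch entries are drawn independently.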
…nnel Per review: the stepper interfaces must not be handed a folded [batch*ensemble, ...] TensorMapping together with an n_ensemble: int that re-interprets its combined leading dimension. Instead the explicit [batch, ensemble, *spatial] layout flows through, folding only at the encapsulated storage/IO boundaries.

- BatchData keeps its folded `data` + `n_ensemble` as encapsulated storage and now exposes `ensemble_data` (already present) plus `ensemble_data_mask` and `ensemble_stepper_state`, all deriving the explicit [batch, ensemble, ...] view (ensemble length 1 is valid even before broadcast_ensemble).
- Stepper.predict_generator no longer takes n_ensemble: it consumes the explicit ensemble views (time at dim 2), threads [batch, ensemble, ...] through the step, and yields that layout directly. Callers pass `.ensemble_data` / `.ensemble_data_mask` / `.ensemble_stepper_state`; _accumulate_loss drops its post-yield unfold.
- process_prediction_generator_list consumes the explicit yields and folds back into BatchData's folded storage at that one boundary (deriving n_ensemble from the shape, no argument).
- The coupled stepper drives its components over a flat batch; it folds the components' explicit yields back at the boundary and treats the component initial conditions as a flat (n_ensemble=1) view, matching the freshly built component forcings. A deeper coupled migration is a separate follow-up.

Behavior-preserving: ace step/stepper, coupled, inference, aggregator and the spatial-parallel regression suites all pass unchanged.
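The storage-boundary idea in this commit (folded storage plus `n_ensemble` on one side, a self-describing explicit view on the other) might look roughly like the following. This is a sketch with NumPy arrays and plain dicts; `ensemble_view` and `fold_back` are hypothetical names and do not reproduce `BatchData`'s real implementation.

```python
import numpy as np


def ensemble_view(data, n_ensemble):
    """Folded dict of [batch*ensemble, time, *spatial] arrays
    -> explicit dict of [batch, ensemble, time, *spatial] arrays."""
    return {
        name: value.reshape((-1, n_ensemble) + value.shape[1:])
        for name, value in data.items()
    }


def fold_back(ensemble_data):
    """Explicit dict -> (folded dict, n_ensemble).

    n_ensemble is read off the shape, so no separate argument is needed.
    """
    first = next(iter(ensemble_data.values()))
    n_ensemble = first.shape[1]
    folded = {
        name: value.reshape((-1,) + value.shape[2:])
        for name, value in ensemble_data.items()
    }
    return folded, n_ensemble


# Folded storage: 2 base samples x 3 members, 4 times, 5x6 grid.
stored = {"a": np.arange(6 * 4 * 5 * 6, dtype=float).reshape(6, 4, 5, 6)}
view = ensemble_view(stored, n_ensemble=3)
folded, n_ensemble = fold_back(view)
```

In the explicit view the time axis sits at dim 2, which matches the "time at dim 2" note in the commit message.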
Per review comments on #1302:

- Remove BatchData.ensemble_data_mask and ensemble_stepper_state. The data_mask is constant across ensemble members, so it stays per-sample and broadcasts at apply; the stepper_state stays in its folded layout.
- step_with_adjustments now folds the explicit [batch, ensemble] data into the batch at entry (each ensemble member is an independent sample) and unfolds the output, running the existing per-sample body on the folded data. This reverts the per-op network_call / mask-helper / global-mean-removal / corrector changes back to main, and drops the CorrectorState/StepperState fold-helpers and the single-tensor fold/unfold tensor helpers.
- predict_generator consumes BatchData.ensemble_data (self-describing, no n_ensemble argument) but takes data_mask/stepper_state in their folded form.
- Revert the predict_generator parameter rename (ic_dict/forcing_dict).
- Keep Stepper.TIME_DIM = 1 (the coupled stepper indexes its folded component data via it); predict_generator uses a local time_dim = TIME_DIM + 1.

Behavior-preserving: the step body is byte-identical to main on the folded batch. ace step/stepper, coupled, inference, aggregator and the spatial-parallel regression suites pass.
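A minimal sketch of the "fold at entry, unfold the output" pattern, together with the shifted time index, is given below. NumPy stands in for torch, `per_sample_step` is a made-up placeholder for the unchanged step body, and `step_on_ensemble` is a hypothetical wrapper name; only `TIME_DIM = 1` and the local `time_dim = TIME_DIM + 1` are taken from the commit message.

```python
import numpy as np

TIME_DIM = 1  # time axis in the folded [batch*ensemble, time, ...] layout


def per_sample_step(x):
    """Placeholder per-sample body operating on [sample, *spatial]."""
    mean = x.mean(axis=tuple(range(1, x.ndim)), keepdims=True)
    return 2.0 * x + mean


def step_on_ensemble(x):
    """Fold [batch, ensemble, *spatial] at entry, run the body, unfold the output."""
    n_batch, n_ensemble = x.shape[:2]
    folded = x.reshape((n_batch * n_ensemble,) + x.shape[2:])
    out = per_sample_step(folded)
    return out.reshape((n_batch, n_ensemble) + out.shape[1:])


rng = np.random.default_rng(0)
series = rng.normal(size=(2, 3, 4, 5, 6))  # [batch, ensemble, time, lat, lon]

# With the extra ensemble axis, time moves one position to the right.
time_dim = TIME_DIM + 1
x0 = np.take(series, 0, axis=time_dim)  # [batch, ensemble, lat, lon]
out = step_on_ensemble(x0)
```

Keeping the body untouched and wrapping it this way is what allows the claim that the step is unchanged on the folded batch: the wrapper only reshapes.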
Same refactor as #1302, based on `feature/var_masking_simple` so the explicit `[batch, ensemble]` step contract is integrated with the input-channel-dropout work on that branch. Base: `feature/var_masking_simple` (not `main`). Mirrors the review changes made on #1302.

The step/stepper receive data with an explicit `[batch, ensemble, *spatial]` leading dimension pair (via `BatchData.ensemble_data`): self-describing, no `n_ensemble` integer side-channel. `step_with_adjustments` folds the ensemble back into the batch at entry (each member is an independent sample), so the per-sample body, including the input-dropout masking in `network_call`, runs on the folded batch exactly as before, and the output is unfolded. Behavior-preserving (the step body is byte-identical to the base on the folded batch).

On top of #1302's changes, integrated with var_masking's input dropout:

- `[batch*ensemble]` layout (shared across ensemble members of a base sample, independent across the batch, matching `broadcast_ensemble`'s block ordering); `network_call` applies it on the folded batch. No `[batch, ensemble]` mask is materialized.
- `StepArgs.input_dropout_mask` flows through `predict_generator` in its folded layout.

(The ensemble-shared/batch-independent dropout behavior is unchanged; this refactor just gives the step the explicit `[batch, ensemble]` data contract.)

Verification: ace step / stepper / tensors / var_masking unit suites pass (incl. the input-dropout and ensemble tests); data_loading + inference + aggregator + coupled pass; spatial-parallel `test_step_regression` passes under `H=2` and `W=2`.