Skip to content

Refactor step to take [batch, ensemble] dims (on var_masking_simple)#1304

Draft
mcgibbon wants to merge 5 commits into
feature/var_masking_simplefrom
refactor/step-batch-ensemble-on-var-masking
Draft

Refactor step to take [batch, ensemble] dims (on var_masking_simple)#1304
mcgibbon wants to merge 5 commits into
feature/var_masking_simplefrom
refactor/step-batch-ensemble-on-var-masking

Conversation

@mcgibbon

@mcgibbon mcgibbon commented Jun 22, 2026

Copy link
Copy Markdown
Contributor

Same refactor as #1302, based on feature/var_masking_simple so the explicit [batch, ensemble] step contract is integrated with the input-channel-dropout work on that branch. Base: feature/var_masking_simple (not main). Mirrors the review changes made on #1302.

The step/stepper receive data with an explicit [batch, ensemble, *spatial] leading dimension pair (via BatchData.ensemble_data) — self-describing, no n_ensemble integer side-channel. step_with_adjustments folds the ensemble back into the batch at entry (each member is an independent sample), so the per-sample body — incl. the input-dropout masking in network_call — runs on the folded batch exactly as before, and the output is unfolded. Behavior-preserving (the step body is byte-identical to the base on the folded batch).

On top of #1302's changes, integrated with var_masking's input dropout:

  • The input-dropout mask is built per base sample and repeat-interleaved across the ensemble into the folded [batch*ensemble] layout (shared across ensemble members of a base sample, independent across the batch — matching broadcast_ensemble's block ordering); network_call applies it on the folded batch. No [batch, ensemble] mask is materialized.
  • StepArgs.input_dropout_mask flows through predict_generator in its folded layout.

(The ensemble-shared/batch-independent dropout behavior is unchanged; this refactor just gives the step the explicit [batch, ensemble] data contract.)

Verification: ace step / stepper / tensors / var_masking unit suites pass (incl. the input-dropout and ensemble tests); data_loading + inference + aggregator + coupled pass; spatial-parallel test_step_regression passes under H=2 and W=2.

  • Tests added
  • If dependencies changed, "deps only" image rebuilt and "latest_deps_only_image.txt" file updated

mcgibbon added 5 commits June 22, 2026 16:07
The step (`step_with_adjustments` and all four `StepABC` implementations)
now receives tensors with an explicit `[batch, ensemble, *spatial]` leading
dimension pair instead of a folded `[batch*ensemble, *spatial]` dimension.
The nn.Module calls still operate on a folded `[batch*ensemble, channel,
*spatial]` batch: each `network_call` folds the ensemble dimension into the
batch immediately before the module and unfolds the module output.

This makes the ensemble structure visible inside the step, a prerequisite
for input channel dropout that is shared across ensemble members but
independent across the batch (a follow-up).

- Global mean removal reduces every dimension after the leading sample
  dim, so it is folded around (each ensemble member is an independent
  sample for mean removal) — keeping it byte-identical and agnostic to the
  spatial rank.
- The corrector, ocean and normalizer are leading-dim agnostic (the
  corrector reduces only HORIZONTAL_DIMS) and run on the [batch, ensemble]
  layout directly. CorrectorState/StepperState gain fold/unfold_ensemble
  helpers; the corrector state is [batch, ensemble, 1, 1] inside the step
  and folded to [batch*ensemble, 1, 1] in the externally-threaded state.
- predict_generator keeps its external contract folded: it unfolds the
  per-step inputs, masks and stepper_state, threads [batch, ensemble]
  through the step, and folds the yielded outputs — so training and
  inference callers are unchanged.
- The input-mask helpers handle the leading [batch, ensemble] pair, and
  new fold/unfold_ensemble_tensor helpers fold a single tensor.

Behavior-preserving: an added ensemble-independence test asserts each
member's output equals an independent single-member step, and the
spatial-parallel regression baselines reproduce the pre-refactor outputs
bit-for-bit (modulo the size-1 ensemble dim).
…emble]

The regression input/output fixtures store tensors, so they must reflect the
new [batch, ensemble, *spatial] step contract. The values are unchanged
(verified bit-for-bit modulo the size-1 ensemble dim); only a size-1 ensemble
dimension is added. Verified under both H=2 and W=2 spatial decompositions.
…ble]

The input dropout mask now flows as [batch, ensemble] (unfolded in
predict_generator alongside data_mask), so the direct-step dropout tests pass
[batch, ensemble] masks and StepArgs documents the new shape. Behavior is
unchanged: the mask is still ensemble-shared (sampled per base sample, repeat-
interleaved), now applied in the explicit [batch, ensemble] layout.
…nnel

Per review: the stepper interfaces must not be handed a folded
[batch*ensemble, ...] TensorMapping together with an n_ensemble: int that
re-interprets its combined leading dimension. Instead the explicit
[batch, ensemble, *spatial] layout flows through, folding only at the
encapsulated storage/IO boundaries.

- BatchData keeps its folded `data` + `n_ensemble` as encapsulated storage
  and now exposes `ensemble_data` (already present) plus `ensemble_data_mask`
  and `ensemble_stepper_state`, all deriving the explicit [batch, ensemble,
  ...] view (ensemble length 1 is valid even before broadcast_ensemble).
- Stepper.predict_generator no longer takes n_ensemble: it consumes the
  explicit ensemble views (time at dim 2), threads [batch, ensemble, ...]
  through the step, and yields that layout directly. Callers pass
  `.ensemble_data` / `.ensemble_data_mask` / `.ensemble_stepper_state`;
  _accumulate_loss drops its post-yield unfold.
- process_prediction_generator_list consumes the explicit yields and folds
  back into BatchData's folded storage at that one boundary (deriving
  n_ensemble from the shape, no argument).
- The coupled stepper drives its components over a flat batch; it folds the
  components' explicit yields back at the boundary and treats the component
  initial conditions as a flat (n_ensemble=1) view, matching the freshly
  built component forcings. A deeper coupled migration is a separate follow-up.

Behavior-preserving: ace step/stepper, coupled, inference, aggregator and the
spatial-parallel regression suites all pass unchanged.
Per review comments on #1302:
- Remove BatchData.ensemble_data_mask and ensemble_stepper_state. The
  data_mask is constant across ensemble members, so it stays per-sample and
  broadcasts at apply; the stepper_state stays in its folded layout.
- step_with_adjustments now folds the explicit [batch, ensemble] data into the
  batch at entry (each ensemble member is an independent sample) and unfolds
  the output, running the existing per-sample body on the folded data. This
  reverts the per-op network_call / mask-helper / global-mean-removal /
  corrector changes back to main, and drops the CorrectorState/StepperState
  fold-helpers and the single-tensor fold/unfold tensor helpers.
- predict_generator consumes BatchData.ensemble_data (self-describing, no
  n_ensemble argument) but takes data_mask/stepper_state in their folded form.
- Revert the predict_generator parameter rename (ic_dict/forcing_dict).
- Keep Stepper.TIME_DIM = 1 (the coupled stepper indexes its folded component
  data via it); predict_generator uses a local time_dim = TIME_DIM + 1.

Behavior-preserving: the step body is byte-identical to main on the folded
batch. ace step/stepper, coupled, inference, aggregator and the spatial-parallel
regression suites pass.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant