From a3bbb86f76e543a971d54cb15ea1b71e194a8178 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Fri, 5 Jun 2026 17:22:21 +0200 Subject: [PATCH] fix[next]: canonical order for gather output dims `return_type_field` (frontend) and `_gather_output_domain` (embedded) built the gather output in insertion order. When a connectivity introduces a dim of a different kind than the codomain it replaced (e.g. a vertical iteration axis while a local dim survives), insertion order is non-canonical: the result is rejected by `check_dims` / mismatches the out field, and runtime and frontend can disagree. Return both outputs via `order_dimensions` so they agree on a canonical, valid field domain. The frontend deduces the dims as the set `(field.dims - source) | target`; the source dim survives iff it reappears in the target. Tests: vertical-axis canonicalization regressions (frontend + embedded), update the raw-premap tests to expect canonical order, and a guard that same-dim connectivities (V2V) keep the source dim. --- src/gt4py/next/embedded/nd_array_field.py | 5 +- src/gt4py/next/type_system/type_info.py | 3 +- .../embedded_tests/test_nd_array_field.py | 50 ++++++++++++++++--- .../ffront_tests/test_type_deduction.py | 35 +++++++++++++ 4 files changed, 82 insertions(+), 11 deletions(-) diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index 69bb89da3a..336e0ac4f6 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -593,7 +593,8 @@ def _gather_output_domain( field_domain: common.Domain, connectivities: Sequence[common.GatherConnectivity] ) -> common.Domain: """Output domain of a simultaneous gather: each codomain is replaced by the dimensions of its - connectivity's domain; dimensions shared with the field domain are intersected in place.""" + connectivity's domain; dimensions shared with the field domain are intersected. Returned in + canonical order so it is a valid field domain matching the frontend-deduced type.""" domain = field_domain for conn in connectivities: cod = conn.codomain @@ -616,7 +617,7 @@ def _gather_output_domain( else: result.append(nr) domain = common.Domain(*result) - return domain + return common.Domain(*(domain[d] for d in common.order_dimensions(domain.dims))) def _gather_premap(data: NdArrayField, *connectivities: common.GatherConnectivity) -> NdArrayField: diff --git a/src/gt4py/next/type_system/type_info.py b/src/gt4py/next/type_system/type_info.py index eb70d15947..26b4df04bd 100644 --- a/src/gt4py/next/type_system/type_info.py +++ b/src/gt4py/next/type_system/type_info.py @@ -663,7 +663,8 @@ def return_type_field( new_dims.append(d) else: new_dims.extend(target_dims) - return ts.FieldType(dims=new_dims, dtype=field_type.dtype) + # Canonical order so the deduced type matches the embedded `premap` output domain. + return ts.FieldType(dims=common.order_dimensions(new_dims), dtype=field_type.dtype) @return_type.register diff --git a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py index 3ed346df2e..ecb6e4ff7f 100644 --- a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py +++ b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py @@ -553,10 +553,12 @@ def test_premap_same_dim_multineighbor_with_extra_dim(): result = c_field.premap(conn) + # canonical order: K (horizontal) before C2E2CO (local) assert result.domain == common.Domain( - dims=(C, C2E2CO, K), ranges=(UnitRange(0, NC), UnitRange(0, NN), UnitRange(0, NK)) + dims=(C, K, C2E2CO), ranges=(UnitRange(0, NC), UnitRange(0, NK), UnitRange(0, NN)) ) - assert np.all(result.ndarray == c_field.ndarray[table]) # out[c, n, k] == f[table[c, n], k] + # out[c, k, n] == f[table[c, n], k] + assert np.all(result.ndarray == np.transpose(c_field.ndarray[table], (0, 2, 1))) def test_gather_premap_multiple_connectivities(): @@ -655,9 +657,10 @@ def test_gather_premap_reads_non_codomain_field_dim(): result = f.premap(conn) - assert result.domain == common.Domain(dims=(L, B), ranges=(UnitRange(0, 2), UnitRange(0, NB))) - # out[l, b] = f[table[b, l], b] - assert np.all(result.ndarray == f.ndarray[table.T, np.arange(NB)[None, :]]) + # canonical order: B (horizontal) before L (local) + assert result.domain == common.Domain(dims=(B, L), ranges=(UnitRange(0, NB), UnitRange(0, 2))) + # out[b, l] = f[table[b, l], b] + assert np.all(result.ndarray == f.ndarray[table, np.arange(NB)[:, None]]) def test_gather_premap_shared_domain_dim(): @@ -713,9 +716,40 @@ def test_gather_premap_mix_introducing_and_preserving(): result = f.premap(conn_a, conn_b) - assert result.domain == common.Domain(dims=(X, B), ranges=(UnitRange(0, 3), UnitRange(0, NB))) - # out[x, b] = f[ca[x], cb[b]] - assert np.all(result.ndarray == f.ndarray[ca[:, None], cb[None, :]]) + # canonical order: B before X (both horizontal, ordered by name) + assert result.domain == common.Domain(dims=(B, X), ranges=(UnitRange(0, NB), UnitRange(0, 3))) + # out[b, x] = f[ca[x], cb[b]] + assert np.all(result.ndarray == f.ndarray[ca[None, :], cb[:, None]]) + + +def test_gather_premap_introduced_vertical_dim_canonicalized(): + # Introduce a vertical dim (K) while a local dim (Band) survives: insertion gives the + # non-canonical (Cell, K, Band), the output must be canonical (Cell, Band, K). + PT = Dimension("PT") + Band = Dimension("Band", kind=DimensionKind.LOCAL) + Cell = Dimension("Cell") + K = Dimension("K", kind=DimensionKind.VERTICAL) + + NPT, NBND, NC, NK = 4, 3, 5, 2 + f = common._field( + np.arange(NPT * NBND).reshape(NPT, NBND).astype(float), + domain=common.Domain(dims=(PT, Band), ranges=(UnitRange(0, NPT), UnitRange(0, NBND))), + ) + table = (np.arange(NC * NK).reshape(NC, NK)) % NPT # (Cell, K) -> PT + conn = common._connectivity( + table, + domain=common.Domain(dims=(Cell, K), ranges=(UnitRange(0, NC), UnitRange(0, NK))), + codomain=PT, + ) + + result = f.premap(conn) + + assert result.domain == common.Domain( + dims=(Cell, Band, K), ranges=(UnitRange(0, NC), UnitRange(0, NBND), UnitRange(0, NK)) + ) + # out[c, b, k] = f[table[c, k], b] + expected = f.ndarray[table[:, None, :], np.arange(NBND)[None, :, None]] + assert np.all(result.ndarray == expected) def test_premap_chained_connectivities_raises(): diff --git a/tests/next_tests/unit_tests/ffront_tests/test_type_deduction.py b/tests/next_tests/unit_tests/ffront_tests/test_type_deduction.py index 22bd1a7a9e..49c07f9eb4 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_type_deduction.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_type_deduction.py @@ -293,6 +293,41 @@ def premap_fo(bar: Field[[X, K], int64]) -> Field[[Y, Y2XDim, K], int64]: ) +def test_premap_introduced_vertical_dim_canonicalized(): + # Introduce a vertical dim (K) while a local dim (B) survives: insertion gives the non-canonical + # (K, B), the deduced type must be canonical (B, K). + Src = Dimension("Src") + B = Dimension("B", kind=DimensionKind.LOCAL) + K = Dimension("K", kind=DimensionKind.VERTICAL) + SrcToK = FieldOffset("SrcToK", source=Src, target=(K,)) + + def premap_fo(bar: Field[[Src, B], int64]) -> Field[[B, K], int64]: + return bar(SrcToK) + + parsed = FieldOperatorParser.apply_to_function(premap_fo) + + assert parsed.body.stmts[0].value.type == ts.FieldType( + dims=[B, K], dtype=ts.ScalarType(kind=ts.ScalarKind.INT64) + ) + + +def test_premap_same_dim_connectivity_keeps_source(): + # A same-dim connectivity (V2V: source reappears in the target) keeps the source dim: + # [V] -> [V, V2VDim]. + V = Dimension("V") + V2VDim = Dimension("V2V", kind=DimensionKind.LOCAL) + V2V = FieldOffset("V2V", source=V, target=(V, V2VDim)) + + def premap_fo(bar: Field[[V], int64]) -> Field[[V, V2VDim], int64]: + return bar(V2V) + + parsed = FieldOperatorParser.apply_to_function(premap_fo) + + assert parsed.body.stmts[0].value.type == ts.FieldType( + dims=[V, V2VDim], dtype=ts.ScalarType(kind=ts.ScalarKind.INT64) + ) + + def test_premap_reduce(premap_setup): X, Y, Y2XDim, Y2X = premap_setup