From 1b4707cfd1d6b5eb9623f3e51aa31c8d4696655c Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Fri, 20 Feb 2026 11:41:56 +0100 Subject: [PATCH 01/24] Tracer prototype --- src/gt4py/next/ffront/field_operator_ast.py | 12 +++++ .../ffront/foast_passes/type_deduction.py | 26 +++++++++- src/gt4py/next/ffront/foast_pretty_printer.py | 2 + src/gt4py/next/ffront/foast_to_gtir.py | 9 ++++ src/gt4py/next/ffront/foast_to_past.py | 8 ++-- src/gt4py/next/ffront/func_to_foast.py | 31 ++++++++++-- .../next/ffront/past_passes/type_deduction.py | 2 +- src/gt4py/next/iterator/builtins.py | 3 +- .../next/iterator/transforms/pass_manager.py | 3 ++ .../iterator/transforms/unroll_map_tuple.py | 47 +++++++++++++++++++ .../iterator/type_system/type_synthesizer.py | 13 +++++ src/gt4py/next/type_system/type_info.py | 8 ++++ .../next/type_system/type_specifications.py | 9 ++++ .../next/type_system/type_translation.py | 22 +++++++-- tests/next_tests/integration_tests/cases.py | 11 +++++ .../ffront_tests/test_execution.py | 30 ++++++++++++ 16 files changed, 222 insertions(+), 14 deletions(-) create mode 100644 src/gt4py/next/iterator/transforms/unroll_map_tuple.py diff --git a/src/gt4py/next/ffront/field_operator_ast.py b/src/gt4py/next/ffront/field_operator_ast.py index fa5bc4889f..f4950ea402 100644 --- a/src/gt4py/next/ffront/field_operator_ast.py +++ b/src/gt4py/next/ffront/field_operator_ast.py @@ -113,6 +113,18 @@ class TupleExpr(Expr): elts: list[Expr] +# TODO: give a good error for tuple(... for el in iter if ...) so that users understand that and why we don't support conditionals +# TODO: should this have SymbolTableTrait since target declares a new symbol. Write test that has two comprehensions using the same target name. +class TupleComprehension(Expr): + """ + tuple(element_expr for target in iterable) + """ + + element_expr: Expr + target: DataSymbol # TODO: how about `tuple(el1+el2 for el1, el2 in var_arg)`? + iterable: Expr + + class UnaryOp(Expr): op: dialect_ast_enums.UnaryOperator operand: Expr diff --git a/src/gt4py/next/ffront/foast_passes/type_deduction.py b/src/gt4py/next/ffront/foast_passes/type_deduction.py index 68bf108a0a..e545f9e002 100644 --- a/src/gt4py/next/ffront/foast_passes/type_deduction.py +++ b/src/gt4py/next/ffront/foast_passes/type_deduction.py @@ -5,7 +5,7 @@ # # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause - +import collections from typing import Any, Optional, TypeAlias, TypeVar, cast import gt4py.next.ffront.field_operator_ast as foast @@ -501,6 +501,10 @@ def visit_Subscript(self, node: foast.Subscript, **kwargs: Any) -> foast.Subscri f"Tuples need to be indexed with literal integers, got '{node.index}'.", ) from ex new_type = types[index] + case ts.VarArgType(element_type=element_type): + new_type = ( + element_type # TODO: we only temporarily allow any index for vararg types + ) case ts.OffsetType(source=source, target=(target1, target2)): if not target2.kind == DimensionKind.LOCAL: raise errors.DSLError( @@ -747,6 +751,26 @@ def visit_TupleExpr(self, node: foast.TupleExpr, **kwargs: Any) -> foast.TupleEx new_type = ts.TupleType(types=[element.type for element in new_elts]) return foast.TupleExpr(elts=new_elts, type=new_type, location=node.location) + def visit_TupleComprehension( + self, node: foast.TupleComprehension, **kwargs: Any + ) -> foast.TupleComprehension: + symtable: collections.ChainMap = kwargs["symtable"] # todo annotation + iterable = self.visit(node.iterable, **kwargs) + target = self.visit(node.target, **kwargs) + assert isinstance(iterable.type, ts.VarArgType) + target.type = iterable.type.element_type + element_expr = self.visit( + node.element_expr, + **{**kwargs, "symtable": symtable.new_child({node.target.id: target})}, + ) + return foast.TupleComprehension( + element_expr=element_expr, + target=target, + iterable=iterable, + location=node.location, + type=ts.VarArgType(element_type=element_expr.type), + ) + def visit_Call(self, node: foast.Call, **kwargs: Any) -> foast.Call: new_func = self.visit(node.func, **kwargs) new_args = self.visit(node.args, **kwargs) diff --git a/src/gt4py/next/ffront/foast_pretty_printer.py b/src/gt4py/next/ffront/foast_pretty_printer.py index 8b2e369501..77495d78f7 100644 --- a/src/gt4py/next/ffront/foast_pretty_printer.py +++ b/src/gt4py/next/ffront/foast_pretty_printer.py @@ -118,6 +118,8 @@ def apply(cls, node: foast.LocatedNode, **kwargs: Any) -> str: # type: ignore[o TupleExpr = as_fmt("({', '.join(elts)}{',' if len(elts)==1 else ''})") + TupleComprehension = as_fmt("tuple(({element_expr} for {target} in {iterable}))") + UnaryOp = as_fmt("{op}{operand}") def visit_UnaryOp(self, node: foast.UnaryOp, **kwargs: Any) -> str: diff --git a/src/gt4py/next/ffront/foast_to_gtir.py b/src/gt4py/next/ffront/foast_to_gtir.py index 3825072cb7..2e587c346e 100644 --- a/src/gt4py/next/ffront/foast_to_gtir.py +++ b/src/gt4py/next/ffront/foast_to_gtir.py @@ -257,6 +257,15 @@ def visit_Subscript(self, node: foast.Subscript, **kwargs: Any) -> itir.Expr: def visit_TupleExpr(self, node: foast.TupleExpr, **kwargs: Any) -> itir.Expr: return im.make_tuple(*[self.visit(el, **kwargs) for el in node.elts]) + def visit_TupleComprehension(self, node: foast.TupleComprehension, **kwargs: Any) -> itir.Expr: + return im.call( + im.call("map_tuple")( + im.lambda_(self.visit(node.target, **kwargs))( + self.visit(node.element_expr, **kwargs) + ) + ) + )(self.visit(node.iterable, **kwargs)) + def visit_UnaryOp(self, node: foast.UnaryOp, **kwargs: Any) -> itir.Expr: # TODO(tehrengruber): extend iterator ir to support unary operators dtype = type_info.extract_dtype(node.type) diff --git a/src/gt4py/next/ffront/foast_to_past.py b/src/gt4py/next/ffront/foast_to_past.py index 05b080b70b..c37cba5a78 100644 --- a/src/gt4py/next/ffront/foast_to_past.py +++ b/src/gt4py/next/ffront/foast_to_past.py @@ -21,7 +21,7 @@ from gt4py.next.ffront.stages import ConcreteFOASTOperatorDef, ConcretePASTProgramDef from gt4py.next.iterator import ir as itir from gt4py.next.otf import toolchain, workflow -from gt4py.next.type_system import type_info, type_specifications as ts +from gt4py.next.type_system import type_specifications as ts @dataclasses.dataclass(frozen=True) @@ -113,9 +113,9 @@ def __call__(self, inp: ConcreteFOASTOperatorDef) -> ConcretePASTProgramDef: *partial_program_type.definition.kw_only_args.keys(), ] assert isinstance(type_, ts.CallableType) - assert arg_types[-1] == type_info.return_type( - type_, with_args=list(arg_types), with_kwargs=kwarg_types - ) + # assert arg_types[-1] == type_info.return_type( + # type_, with_args=list(arg_types), with_kwargs=kwarg_types + # ) assert args_names[-1] == "out" params_decl: list[past.Symbol] = [ diff --git a/src/gt4py/next/ffront/func_to_foast.py b/src/gt4py/next/ffront/func_to_foast.py index ced0ff3905..adefa7ba9e 100644 --- a/src/gt4py/next/ffront/func_to_foast.py +++ b/src/gt4py/next/ffront/func_to_foast.py @@ -337,7 +337,12 @@ def visit_Expr(self, node: ast.Expr) -> foast.Expr: return self.visit(node.value) def visit_Name(self, node: ast.Name, **kwargs: Any) -> foast.Name: - return foast.Name(id=node.id, location=self.get_location(node)) + loc = self.get_location(node) + if isinstance(node.ctx, ast.Store): + return foast.DataSymbol(id=node.id, location=loc, type=ts.DeferredType(constraint=None)) + else: + assert isinstance(node.ctx, ast.Load) + return foast.Name(id=node.id, location=loc) def visit_UnaryOp(self, node: ast.UnaryOp, **kwargs: Any) -> foast.UnaryOp: return foast.UnaryOp( @@ -469,8 +474,10 @@ def visit_NotEq(self, node: ast.NotEq, **kwargs: Any) -> foast.CompareOperator: return foast.CompareOperator.NOTEQ def _verify_builtin_type_constructor(self, node: ast.Call) -> None: - if len(node.args) > 0: - arg = node.args[0] + (arg,) = ( + node.args + ) # note for review: the change here is unrelated to the actual pr and just a small cleanup + if node.func.id == "tuple": if not ( isinstance(arg, ast.Constant) or (isinstance(arg, ast.UnaryOp) and isinstance(arg.operand, ast.Constant)) @@ -484,9 +491,25 @@ def _func_name(self, node: ast.Call) -> str: return node.func.id # type: ignore[attr-defined] # We want this to fail if the attribute does not exist unexpectedly. def visit_Call(self, node: ast.Call, **kwargs: Any) -> foast.Call: - # TODO(tehrengruber): is this still needed or redundant with the checks in type deduction? if isinstance(node.func, ast.Name): func_name = self._func_name(node) + + if func_name == "tuple": + (gen_expr,) = node.args + assert ( + len(gen_expr.generators) == 1 + ) # we don't support (... for ... in ... for ... in ...) + assert ( + gen_expr.generators[0].ifs == [] + ) # we don't support if conditions in comprehensions + return foast.TupleComprehension( + element_expr=self.visit(gen_expr.elt, **kwargs), + target=self.visit(gen_expr.generators[0].target, **kwargs), + iterable=self.visit(gen_expr.generators[0].iter, **kwargs), + location=self.get_location(node), + ) + + # TODO(tehrengruber): is this still needed or redundant with the checks in type deduction? if func_name in fbuiltins.TYPE_BUILTIN_NAMES: self._verify_builtin_type_constructor(node) diff --git a/src/gt4py/next/ffront/past_passes/type_deduction.py b/src/gt4py/next/ffront/past_passes/type_deduction.py index 9d021ceb51..530d407459 100644 --- a/src/gt4py/next/ffront/past_passes/type_deduction.py +++ b/src/gt4py/next/ffront/past_passes/type_deduction.py @@ -248,7 +248,7 @@ def visit_Call(self, node: past.Call, **kwargs: Any) -> past.Call: operator_return_type = type_info.return_type( new_func.type, with_args=arg_types, with_kwargs=kwarg_types ) - if operator_return_type != new_kwargs["out"].type: + if not type_info.is_compatible_type(operator_return_type, new_kwargs["out"].type): raise ValueError( "Expected keyword argument 'out' to be of " f"type '{operator_return_type}', got " diff --git a/src/gt4py/next/iterator/builtins.py b/src/gt4py/next/iterator/builtins.py index e54c6ea3d7..7b24c91884 100644 --- a/src/gt4py/next/iterator/builtins.py +++ b/src/gt4py/next/iterator/builtins.py @@ -498,7 +498,8 @@ def get_domain_range(*args): "lift", "make_const_list", "make_tuple", - "map_", + "map_tuple", + "map_", # TODO: rename to map_list "named_range", "neighbors", "reduce", diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index 08ca9d94e0..4102790129 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -24,6 +24,7 @@ prune_empty_concat_where, remove_broadcast, symbol_ref_utils, + unroll_map_tuple, ) from gt4py.next.iterator.transforms.collapse_list_get import CollapseListGet from gt4py.next.iterator.transforms.collapse_tuple import CollapseTuple @@ -179,6 +180,7 @@ def apply_common_transforms( ) # domain inference does not support dynamic offsets yet ir = infer_domain_ops.InferDomainOps.apply(ir) ir = concat_where.canonicalize_domain_argument(ir) + ir = unroll_map_tuple.UnrollMapTuple.apply(ir, uids=uids) ir = infer_domain.infer_program( ir, @@ -293,6 +295,7 @@ def apply_fieldview_transforms( ir = infer_domain_ops.InferDomainOps.apply(ir) ir = concat_where.canonicalize_domain_argument(ir) + ir = unroll_map_tuple.UnrollMapTuple.apply(ir, uids=uids) ir = ConstantFolding.apply(ir) # type: ignore[assignment] # always an itir.Program ir = infer_domain.infer_program( diff --git a/src/gt4py/next/iterator/transforms/unroll_map_tuple.py b/src/gt4py/next/iterator/transforms/unroll_map_tuple.py new file mode 100644 index 0000000000..66f96d66fa --- /dev/null +++ b/src/gt4py/next/iterator/transforms/unroll_map_tuple.py @@ -0,0 +1,47 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause +import dataclasses + +from gt4py import eve +from gt4py.next import utils +from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im +from gt4py.next.iterator.type_system import inference as itir_inference +from gt4py.next.type_system import type_specifications as ts + + +@dataclasses.dataclass +class UnrollMapTuple(eve.NodeTranslator): + PRESERVED_ANNEX_ATTRS = ("domain",) + + uids: utils.IDGeneratorPool + + @classmethod + def apply(cls, program: itir.Program, *, uids: utils.IDGeneratorPool): + return cls(uids=uids).visit(program) + + def visit_FunCall(self, node: itir.Expr): + node = self.generic_visit(node) + + if cpm.is_call_to(node.fun, "map_tuple"): + # TODO: we have to duplicate the function here since the domain inference can not handle them yet + f = node.fun.args[0] + tup = node.args[0] + itir_inference.reinfer(tup) + assert isinstance(tup.type, ts.TupleType) + tup_ref = next(self.uids["_ump"]) + + result = im.let(tup_ref, tup)( + im.make_tuple( + *(im.call(f)(im.tuple_get(i, tup_ref)) for i in range(len(tup.type.types))) + ) + ) + itir_inference.reinfer(result) + + return result + return node diff --git a/src/gt4py/next/iterator/type_system/type_synthesizer.py b/src/gt4py/next/iterator/type_system/type_synthesizer.py index 6d77c70375..4406dd9aa8 100644 --- a/src/gt4py/next/iterator/type_system/type_synthesizer.py +++ b/src/gt4py/next/iterator/type_system/type_synthesizer.py @@ -633,6 +633,19 @@ def applied_map( return applied_map +@_register_builtin_type_synthesizer +def map_tuple(op: TypeSynthesizer) -> TypeSynthesizer: + @type_synthesizer + def applied_map( + arg: ts.TupleType, offset_provider_type: common.OffsetProviderType + ) -> ts.TupleType: + return ts.TupleType( + types=[op(arg_, offset_provider_type=offset_provider_type) for arg_ in arg.types] + ) + + return applied_map + + @_register_builtin_type_synthesizer def reduce(op: TypeSynthesizer, init: ts.TypeSpec) -> TypeSynthesizer: @type_synthesizer diff --git a/src/gt4py/next/type_system/type_info.py b/src/gt4py/next/type_system/type_info.py index eb70d15947..69fccd33da 100644 --- a/src/gt4py/next/type_system/type_info.py +++ b/src/gt4py/next/type_system/type_info.py @@ -566,6 +566,14 @@ def is_concretizable(symbol_type: ts.TypeSpec, to_type: ts.TypeSpec) -> bool: or issubclass(type_class(to_type), symbol_type.constraint) ): return True + if isinstance(symbol_type, ts.VarArgType) and isinstance(to_type, ts.VarArgType): + return is_concretizable(symbol_type.element_type, to_type.element_type) + if isinstance(symbol_type, ts.VarArgType) and isinstance(to_type, ts.TupleType): + if len(to_type.types) == 0 or ( + all(type_ == to_type.types[0] for type_ in to_type.types) + and is_concretizable(symbol_type.element_type, to_type.types[0]) + ): + return True elif is_concrete(symbol_type): return symbol_type == to_type return False diff --git a/src/gt4py/next/type_system/type_specifications.py b/src/gt4py/next/type_system/type_specifications.py index 59ac40f0f3..409138d593 100644 --- a/src/gt4py/next/type_system/type_specifications.py +++ b/src/gt4py/next/type_system/type_specifications.py @@ -148,6 +148,15 @@ def __len__(self) -> int: return len(self.types) +class VarArgType(DataType): + """Represents a variable number of arguments of the same type.""" + + element_type: DataType # TODO: maybe also support different DataTypes + + def __str__(self) -> str: + return f"VarArg[{self.element_type}]" + + class AnyPythonType: """Marker type representing any Python type which cannot be used for instantiation. diff --git a/src/gt4py/next/type_system/type_translation.py b/src/gt4py/next/type_system/type_translation.py index 0f145e04aa..0ca020625a 100644 --- a/src/gt4py/next/type_system/type_translation.py +++ b/src/gt4py/next/type_system/type_translation.py @@ -180,8 +180,12 @@ def from_type_hint( case builtins.tuple: if not args: raise ValueError(f"Tuple annotation '{type_hint}' requires at least one argument.") - if Ellipsis in args: - raise ValueError(f"Unbound tuples '{type_hint}' are not allowed.") + if len(args) == 2 and args[1] is Ellipsis: + return ts.VarArgType(element_type=from_type_hint_same_ns(args[0])) + elif Ellipsis in args: + raise ValueError( + f"Vararg tuple annotation '{type_hint}' cannot have more than one argument." + ) tuple_types = [from_type_hint_same_ns(arg) for arg in args] assert all(isinstance(elem, ts.DataType) for elem in tuple_types) return ts.TupleType(types=tuple_types) @@ -321,7 +325,19 @@ def from_value(value: Any) -> ts.TypeSpec: return UnknownPythonObject(value) else: type_ = xtyping.infer_type(value, annotate_callable_kwargs=True) - symbol_type = from_type_hint(type_) + if type_ == type[tuple]: + # TODO: this special casing here is not nice, but infer_type is also called on the annotations where + # we don't want to allow unparameterized tuples (or do we?). + symbol_type = ts.ConstructorType( + definition=ts.FunctionType( + pos_only_args=[ts.DeferredType(constraint=None)], + pos_or_kw_args={}, + kw_only_args={}, + returns=ts.DeferredType(constraint=ts.VarArgType), + ) + ) + else: + symbol_type = from_type_hint(type_) if isinstance(symbol_type, (ts.DataType, ts.CallableType, ts.OffsetType, ts.DimensionType)): return symbol_type diff --git a/tests/next_tests/integration_tests/cases.py b/tests/next_tests/integration_tests/cases.py index 78e6c62781..e723c963de 100644 --- a/tests/next_tests/integration_tests/cases.py +++ b/tests/next_tests/integration_tests/cases.py @@ -603,6 +603,15 @@ def _allocate_from_type( for t in types ) ) + case ts.VarArgType(element_type=element_type): + return tuple( + ( + _allocate_from_type( + case=case, arg_type=t, domain=domain, dtype=dtype, strategy=strategy + ) + for t in [element_type] * 3 # TODO: revisit + ) + ) case ts.NamedCollectionType(types=types) as named_collection_type_spec: container_constructor = ( named_collections.make_named_collection_constructor_from_type_spec( @@ -648,6 +657,8 @@ def get_param_size(param_type: ts.TypeSpec, sizes: dict[gtx.Dimension, int]) -> return sum([get_param_size(t, sizes=sizes) for t in types]) case ts.NamedCollectionType(types=types): return sum([get_param_size(t, sizes=sizes) for t in types]) + case ts.VarArgType(element_type=element_type): + return get_param_size(ts.TupleType(types=[element_type] * 3), sizes) # TODO: revisit case _: raise TypeError(f"Can not get size for parameter of type '{param_type}'.") diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index 8060d5bb36..14f14b3ffb 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -336,6 +336,36 @@ def testee(a: tuple[cases.IField, cases.IJField]) -> cases.IJField: ) +@pytest.mark.uses_tuple_args +def test_tuple_comprehension(cartesian_case): + @gtx.field_operator + def testee( + tracers: tuple[cases.IFloatField, ...], factor: float + ) -> tuple[cases.IFloatField, ...]: + return tuple(tracer * factor for tracer in tracers) + + cases.verify_with_default_data( + cartesian_case, + testee, + ref=lambda t, f: tuple(el * f for el in t), + ) + + +@pytest.mark.uses_tuple_args +def test_tuple_vararg(cartesian_case): + @gtx.field_operator + def testee( + tracers: tuple[cases.IFloatField, ...], factor: float + ) -> tuple[cases.IFloatField, cases.IFloatField]: + return tracers[0] * factor, tracers[1] * factor + + cases.verify_with_default_data( + cartesian_case, + testee, + ref=lambda t, f: tuple(el * f for el in t[:2]), + ) + + @pytest.mark.uses_tuple_args @pytest.mark.xfail(reason="Iterator of tuple approach in lowering does not allow this.") def test_tuple_arg_with_unpromotable_dims(unstructured_case): From 02f881fef3276bd782eea4c98d414053b37987b1 Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Mon, 27 Apr 2026 15:48:50 +0200 Subject: [PATCH 02/24] Introduce GTIR tree_map builtin and transform to make_tuple, also supporting nesting (extracted from #2487) --- src/gt4py/next/ffront/field_operator_ast.py | 12 ---- .../ffront/foast_passes/type_deduction.py | 27 +------- src/gt4py/next/ffront/foast_pretty_printer.py | 2 - src/gt4py/next/ffront/foast_to_gtir.py | 30 ++------ src/gt4py/next/ffront/foast_to_past.py | 8 +-- src/gt4py/next/ffront/func_to_foast.py | 31 ++------- .../next/ffront/past_passes/type_deduction.py | 2 +- src/gt4py/next/iterator/builtins.py | 2 +- src/gt4py/next/iterator/ir_utils/ir_makers.py | 5 ++ .../next/iterator/transforms/pass_manager.py | 6 +- .../iterator/transforms/unroll_map_tuple.py | 47 ------------- .../iterator/transforms/unroll_tree_map.py | 69 +++++++++++++++++++ .../iterator/type_system/type_synthesizer.py | 18 +++-- src/gt4py/next/type_system/type_info.py | 8 --- .../next/type_system/type_specifications.py | 9 --- .../next/type_system/type_translation.py | 22 +----- tests/next_tests/integration_tests/cases.py | 11 --- .../ffront_tests/test_execution.py | 30 -------- .../ffront_tests/test_foast_to_gtir.py | 9 +-- .../transforms_tests/test_unroll_tree_map.py | 43 ++++++++++++ 20 files changed, 159 insertions(+), 232 deletions(-) delete mode 100644 src/gt4py/next/iterator/transforms/unroll_map_tuple.py create mode 100644 src/gt4py/next/iterator/transforms/unroll_tree_map.py create mode 100644 tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_tree_map.py diff --git a/src/gt4py/next/ffront/field_operator_ast.py b/src/gt4py/next/ffront/field_operator_ast.py index 8ee216b96b..95a2588077 100644 --- a/src/gt4py/next/ffront/field_operator_ast.py +++ b/src/gt4py/next/ffront/field_operator_ast.py @@ -123,18 +123,6 @@ class TupleExpr(Expr): elts: list[Expr] -# TODO: give a good error for tuple(... for el in iter if ...) so that users understand that and why we don't support conditionals -# TODO: should this have SymbolTableTrait since target declares a new symbol. Write test that has two comprehensions using the same target name. -class TupleComprehension(Expr): - """ - tuple(element_expr for target in iterable) - """ - - element_expr: Expr - target: DataSymbol # TODO: how about `tuple(el1+el2 for el1, el2 in var_arg)`? - iterable: Expr - - class UnaryOp(Expr): op: dialect_ast_enums.UnaryOperator operand: Expr diff --git a/src/gt4py/next/ffront/foast_passes/type_deduction.py b/src/gt4py/next/ffront/foast_passes/type_deduction.py index 2b33d54cca..11c0bfd88b 100644 --- a/src/gt4py/next/ffront/foast_passes/type_deduction.py +++ b/src/gt4py/next/ffront/foast_passes/type_deduction.py @@ -5,11 +5,10 @@ # # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -import collections + import textwrap from typing import Any, Optional, Sequence, TypeAlias, TypeVar, cast - import gt4py.next.ffront.field_operator_ast as foast from gt4py import eve from gt4py.eve import NodeTranslator, NodeVisitor, traits @@ -429,10 +428,6 @@ def visit_Subscript(self, node: foast.Subscript, **kwargs: Any) -> foast.Subscri f"Tuples need to be indexed with literal integers, got '{node.index}'.", ) from ex new_type = types[index] - case ts.VarArgType(element_type=element_type): - new_type = ( - element_type # TODO: we only temporarily allow any index for vararg types - ) case ts.OffsetType(source=source, target=(target1, target2)): if not target2.kind == DimensionKind.LOCAL: raise errors.DSLError( @@ -679,26 +674,6 @@ def visit_TupleExpr(self, node: foast.TupleExpr, **kwargs: Any) -> foast.TupleEx new_type = ts.TupleType(types=[element.type for element in new_elts]) return foast.TupleExpr(elts=new_elts, type=new_type, location=node.location) - def visit_TupleComprehension( - self, node: foast.TupleComprehension, **kwargs: Any - ) -> foast.TupleComprehension: - symtable: collections.ChainMap = kwargs["symtable"] # todo annotation - iterable = self.visit(node.iterable, **kwargs) - target = self.visit(node.target, **kwargs) - assert isinstance(iterable.type, ts.VarArgType) - target.type = iterable.type.element_type - element_expr = self.visit( - node.element_expr, - **{**kwargs, "symtable": symtable.new_child({node.target.id: target})}, - ) - return foast.TupleComprehension( - element_expr=element_expr, - target=target, - iterable=iterable, - location=node.location, - type=ts.VarArgType(element_type=element_expr.type), - ) - def visit_Call(self, node: foast.Call, **kwargs: Any) -> foast.Call: new_func = self.visit(node.func, **kwargs) new_args = self.visit(node.args, **kwargs) diff --git a/src/gt4py/next/ffront/foast_pretty_printer.py b/src/gt4py/next/ffront/foast_pretty_printer.py index 77495d78f7..8b2e369501 100644 --- a/src/gt4py/next/ffront/foast_pretty_printer.py +++ b/src/gt4py/next/ffront/foast_pretty_printer.py @@ -118,8 +118,6 @@ def apply(cls, node: foast.LocatedNode, **kwargs: Any) -> str: # type: ignore[o TupleExpr = as_fmt("({', '.join(elts)}{',' if len(elts)==1 else ''})") - TupleComprehension = as_fmt("tuple(({element_expr} for {target} in {iterable}))") - UnaryOp = as_fmt("{op}{operand}") def visit_UnaryOp(self, node: foast.UnaryOp, **kwargs: Any) -> str: diff --git a/src/gt4py/next/ffront/foast_to_gtir.py b/src/gt4py/next/ffront/foast_to_gtir.py index 2e587c346e..78b0671db1 100644 --- a/src/gt4py/next/ffront/foast_to_gtir.py +++ b/src/gt4py/next/ffront/foast_to_gtir.py @@ -257,15 +257,6 @@ def visit_Subscript(self, node: foast.Subscript, **kwargs: Any) -> itir.Expr: def visit_TupleExpr(self, node: foast.TupleExpr, **kwargs: Any) -> itir.Expr: return im.make_tuple(*[self.visit(el, **kwargs) for el in node.elts]) - def visit_TupleComprehension(self, node: foast.TupleComprehension, **kwargs: Any) -> itir.Expr: - return im.call( - im.call("map_tuple")( - im.lambda_(self.visit(node.target, **kwargs))( - self.visit(node.element_expr, **kwargs) - ) - ) - )(self.visit(node.iterable, **kwargs)) - def visit_UnaryOp(self, node: foast.UnaryOp, **kwargs: Any) -> itir.Expr: # TODO(tehrengruber): extend iterator ir to support unary operators dtype = type_info.extract_dtype(node.type) @@ -421,23 +412,16 @@ def _visit_where(self, node: foast.Call, **kwargs: Any) -> itir.FunCall: return self._lower_and_map("if_", *node.args) cond_ = self.visit(node.args[0]) + true_ = self.visit(node.args[1]) + false_ = self.visit(node.args[2]) cond_symref_name = f"__cond_{cond_.fingerprint()}" - def create_if( - true_: itir.Expr, false_: itir.Expr, arg_types: tuple[ts.TypeSpec, ts.TypeSpec] - ) -> itir.FunCall: - return _map( - "if_", - (im.ref(cond_symref_name), true_, false_), - (node.args[0].type, *arg_types), + # tree_map(lambda a, b: as_fieldop(if_)(cond_ref, a, b))(true_tup, false_tup) + result = im.tree_map( + im.lambda_("__a", "__b")( + im.op_as_fieldop("if_")(im.ref(cond_symref_name), im.ref("__a"), im.ref("__b")) ) - - result = lowering_utils.process_elements( - create_if, - (self.visit(node.args[1]), self.visit(node.args[2])), - node.type, - arg_types=(node.args[1].type, node.args[2].type), - ) + )(true_, false_) return im.let(cond_symref_name, cond_)(result) diff --git a/src/gt4py/next/ffront/foast_to_past.py b/src/gt4py/next/ffront/foast_to_past.py index c37cba5a78..05b080b70b 100644 --- a/src/gt4py/next/ffront/foast_to_past.py +++ b/src/gt4py/next/ffront/foast_to_past.py @@ -21,7 +21,7 @@ from gt4py.next.ffront.stages import ConcreteFOASTOperatorDef, ConcretePASTProgramDef from gt4py.next.iterator import ir as itir from gt4py.next.otf import toolchain, workflow -from gt4py.next.type_system import type_specifications as ts +from gt4py.next.type_system import type_info, type_specifications as ts @dataclasses.dataclass(frozen=True) @@ -113,9 +113,9 @@ def __call__(self, inp: ConcreteFOASTOperatorDef) -> ConcretePASTProgramDef: *partial_program_type.definition.kw_only_args.keys(), ] assert isinstance(type_, ts.CallableType) - # assert arg_types[-1] == type_info.return_type( - # type_, with_args=list(arg_types), with_kwargs=kwarg_types - # ) + assert arg_types[-1] == type_info.return_type( + type_, with_args=list(arg_types), with_kwargs=kwarg_types + ) assert args_names[-1] == "out" params_decl: list[past.Symbol] = [ diff --git a/src/gt4py/next/ffront/func_to_foast.py b/src/gt4py/next/ffront/func_to_foast.py index adefa7ba9e..ced0ff3905 100644 --- a/src/gt4py/next/ffront/func_to_foast.py +++ b/src/gt4py/next/ffront/func_to_foast.py @@ -337,12 +337,7 @@ def visit_Expr(self, node: ast.Expr) -> foast.Expr: return self.visit(node.value) def visit_Name(self, node: ast.Name, **kwargs: Any) -> foast.Name: - loc = self.get_location(node) - if isinstance(node.ctx, ast.Store): - return foast.DataSymbol(id=node.id, location=loc, type=ts.DeferredType(constraint=None)) - else: - assert isinstance(node.ctx, ast.Load) - return foast.Name(id=node.id, location=loc) + return foast.Name(id=node.id, location=self.get_location(node)) def visit_UnaryOp(self, node: ast.UnaryOp, **kwargs: Any) -> foast.UnaryOp: return foast.UnaryOp( @@ -474,10 +469,8 @@ def visit_NotEq(self, node: ast.NotEq, **kwargs: Any) -> foast.CompareOperator: return foast.CompareOperator.NOTEQ def _verify_builtin_type_constructor(self, node: ast.Call) -> None: - (arg,) = ( - node.args - ) # note for review: the change here is unrelated to the actual pr and just a small cleanup - if node.func.id == "tuple": + if len(node.args) > 0: + arg = node.args[0] if not ( isinstance(arg, ast.Constant) or (isinstance(arg, ast.UnaryOp) and isinstance(arg.operand, ast.Constant)) @@ -491,25 +484,9 @@ def _func_name(self, node: ast.Call) -> str: return node.func.id # type: ignore[attr-defined] # We want this to fail if the attribute does not exist unexpectedly. def visit_Call(self, node: ast.Call, **kwargs: Any) -> foast.Call: + # TODO(tehrengruber): is this still needed or redundant with the checks in type deduction? if isinstance(node.func, ast.Name): func_name = self._func_name(node) - - if func_name == "tuple": - (gen_expr,) = node.args - assert ( - len(gen_expr.generators) == 1 - ) # we don't support (... for ... in ... for ... in ...) - assert ( - gen_expr.generators[0].ifs == [] - ) # we don't support if conditions in comprehensions - return foast.TupleComprehension( - element_expr=self.visit(gen_expr.elt, **kwargs), - target=self.visit(gen_expr.generators[0].target, **kwargs), - iterable=self.visit(gen_expr.generators[0].iter, **kwargs), - location=self.get_location(node), - ) - - # TODO(tehrengruber): is this still needed or redundant with the checks in type deduction? if func_name in fbuiltins.TYPE_BUILTIN_NAMES: self._verify_builtin_type_constructor(node) diff --git a/src/gt4py/next/ffront/past_passes/type_deduction.py b/src/gt4py/next/ffront/past_passes/type_deduction.py index 530d407459..9d021ceb51 100644 --- a/src/gt4py/next/ffront/past_passes/type_deduction.py +++ b/src/gt4py/next/ffront/past_passes/type_deduction.py @@ -248,7 +248,7 @@ def visit_Call(self, node: past.Call, **kwargs: Any) -> past.Call: operator_return_type = type_info.return_type( new_func.type, with_args=arg_types, with_kwargs=kwarg_types ) - if not type_info.is_compatible_type(operator_return_type, new_kwargs["out"].type): + if operator_return_type != new_kwargs["out"].type: raise ValueError( "Expected keyword argument 'out' to be of " f"type '{operator_return_type}', got " diff --git a/src/gt4py/next/iterator/builtins.py b/src/gt4py/next/iterator/builtins.py index 7b24c91884..273dca847f 100644 --- a/src/gt4py/next/iterator/builtins.py +++ b/src/gt4py/next/iterator/builtins.py @@ -498,7 +498,7 @@ def get_domain_range(*args): "lift", "make_const_list", "make_tuple", - "map_tuple", + "tree_map", "map_", # TODO: rename to map_list "named_range", "neighbors", diff --git a/src/gt4py/next/iterator/ir_utils/ir_makers.py b/src/gt4py/next/iterator/ir_utils/ir_makers.py index 4b30e878fe..9545525a94 100644 --- a/src/gt4py/next/iterator/ir_utils/ir_makers.py +++ b/src/gt4py/next/iterator/ir_utils/ir_makers.py @@ -629,6 +629,11 @@ def map_(op): return call(call("map_")(op)) +def tree_map(op): + """Create a `tree_map` call: tree_map(op)(tup1, tup2, ...).""" + return call(call("tree_map")(op)) + + def reduce(op, expr): """Create a `reduce` call.""" return call(call("reduce")(op, expr)) diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index 8825ad00ed..5feba5be9e 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -23,7 +23,7 @@ prune_empty_concat_where, remove_broadcast, symbol_ref_utils, - unroll_map_tuple, + unroll_tree_map, ) from gt4py.next.iterator.transforms.collapse_list_get import CollapseListGet from gt4py.next.iterator.transforms.collapse_tuple import CollapseTuple @@ -177,7 +177,7 @@ def apply_common_transforms( ) # domain inference does not support dynamic offsets yet ir = infer_domain_ops.InferDomainOps.apply(ir) ir = concat_where.canonicalize_domain_argument(ir) - ir = unroll_map_tuple.UnrollMapTuple.apply(ir, uids=uids) + ir = unroll_tree_map.UnrollTreeMap.apply(ir, uids=uids) ir = infer_domain.infer_program( ir, @@ -292,7 +292,7 @@ def apply_fieldview_transforms( ir = infer_domain_ops.InferDomainOps.apply(ir) ir = concat_where.canonicalize_domain_argument(ir) - ir = unroll_map_tuple.UnrollMapTuple.apply(ir, uids=uids) + ir = unroll_tree_map.UnrollTreeMap.apply(ir, uids=uids) ir = ConstantFolding.apply(ir) # type: ignore[assignment] # always an itir.Program ir = infer_domain.infer_program( diff --git a/src/gt4py/next/iterator/transforms/unroll_map_tuple.py b/src/gt4py/next/iterator/transforms/unroll_map_tuple.py deleted file mode 100644 index 66f96d66fa..0000000000 --- a/src/gt4py/next/iterator/transforms/unroll_map_tuple.py +++ /dev/null @@ -1,47 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. -# SPDX-License-Identifier: BSD-3-Clause -import dataclasses - -from gt4py import eve -from gt4py.next import utils -from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im -from gt4py.next.iterator.type_system import inference as itir_inference -from gt4py.next.type_system import type_specifications as ts - - -@dataclasses.dataclass -class UnrollMapTuple(eve.NodeTranslator): - PRESERVED_ANNEX_ATTRS = ("domain",) - - uids: utils.IDGeneratorPool - - @classmethod - def apply(cls, program: itir.Program, *, uids: utils.IDGeneratorPool): - return cls(uids=uids).visit(program) - - def visit_FunCall(self, node: itir.Expr): - node = self.generic_visit(node) - - if cpm.is_call_to(node.fun, "map_tuple"): - # TODO: we have to duplicate the function here since the domain inference can not handle them yet - f = node.fun.args[0] - tup = node.args[0] - itir_inference.reinfer(tup) - assert isinstance(tup.type, ts.TupleType) - tup_ref = next(self.uids["_ump"]) - - result = im.let(tup_ref, tup)( - im.make_tuple( - *(im.call(f)(im.tuple_get(i, tup_ref)) for i in range(len(tup.type.types))) - ) - ) - itir_inference.reinfer(result) - - return result - return node diff --git a/src/gt4py/next/iterator/transforms/unroll_tree_map.py b/src/gt4py/next/iterator/transforms/unroll_tree_map.py new file mode 100644 index 0000000000..c341797216 --- /dev/null +++ b/src/gt4py/next/iterator/transforms/unroll_tree_map.py @@ -0,0 +1,69 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause +import dataclasses + +from gt4py import eve +from gt4py.next import utils +from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im +from gt4py.next.iterator.type_system import inference as itir_inference +from gt4py.next.type_system import type_specifications as ts + + +def _unroll( + f: itir.Expr, + tup_types: list[ts.TupleType], + tup_exprs: list[itir.Expr], +) -> itir.Expr: + """Recursively expand ``tree_map(f)(tup0, tup1, ...)`` into ``make_tuple`` / ``tuple_get``.""" + n = len(tup_types[0].types) + + elements: list[itir.Expr] = [] + for i in range(n): + child_types = [t.types[i] for t in tup_types] + child_exprs = [im.tuple_get(i, e) for e in tup_exprs] + + if all(isinstance(ct, ts.TupleType) for ct in child_types): + elements.append(_unroll(f, child_types, child_exprs)) # type: ignore[arg-type] + else: + elements.append(im.call(f)(*child_exprs)) + + return im.make_tuple(*elements) + + +@dataclasses.dataclass +class UnrollTreeMap(eve.NodeTranslator): + PRESERVED_ANNEX_ATTRS = ("domain",) + + uids: utils.IDGeneratorPool + + @classmethod + def apply(cls, program: itir.Program, *, uids: utils.IDGeneratorPool): + return cls(uids=uids).visit(program) + + def visit_FunCall(self, node: itir.Expr): + node = self.generic_visit(node) + + if not cpm.is_call_to(node.fun, "tree_map"): + return node + + f = node.fun.args[0] + tup_args = node.args + for tup in tup_args: + itir_inference.reinfer(tup) + assert isinstance(tup.type, ts.TupleType) + + tup_refs = [next(self.uids["_utm"]) for _ in tup_args] + body = _unroll(f, [tup.type for tup in tup_args], [im.ref(r) for r in tup_refs]) + + result = body + for ref_name, tup in reversed(list(zip(tup_refs, tup_args))): + result = im.let(ref_name, tup)(result) + + itir_inference.reinfer(result) + return result diff --git a/src/gt4py/next/iterator/type_system/type_synthesizer.py b/src/gt4py/next/iterator/type_system/type_synthesizer.py index 98f3540d91..02b7a52c3b 100644 --- a/src/gt4py/next/iterator/type_system/type_synthesizer.py +++ b/src/gt4py/next/iterator/type_system/type_synthesizer.py @@ -634,14 +634,22 @@ def applied_map( @_register_builtin_type_synthesizer -def map_tuple(op: TypeSynthesizer) -> TypeSynthesizer: +def tree_map(op: TypeSynthesizer) -> TypeSynthesizer: @type_synthesizer def applied_map( - arg: ts.TupleType, offset_provider_type: common.OffsetProviderType + *args: ts.TupleType, offset_provider_type: common.OffsetProviderType ) -> ts.TupleType: - return ts.TupleType( - types=[op(arg_, offset_provider_type=offset_provider_type) for arg_ in arg.types] - ) + def _recurse(*arg_types: ts.TypeSpec) -> ts.TypeSpec: + if isinstance(arg_types[0], ts.TupleType): + return ts.TupleType( + types=[ + _recurse(*(a.types[i] for a in arg_types)) # type: ignore[union-attr] + for i in range(len(arg_types[0].types)) # type: ignore[union-attr] + ] + ) + return op(*arg_types, offset_provider_type=offset_provider_type) + + return _recurse(*args) return applied_map diff --git a/src/gt4py/next/type_system/type_info.py b/src/gt4py/next/type_system/type_info.py index 69fccd33da..eb70d15947 100644 --- a/src/gt4py/next/type_system/type_info.py +++ b/src/gt4py/next/type_system/type_info.py @@ -566,14 +566,6 @@ def is_concretizable(symbol_type: ts.TypeSpec, to_type: ts.TypeSpec) -> bool: or issubclass(type_class(to_type), symbol_type.constraint) ): return True - if isinstance(symbol_type, ts.VarArgType) and isinstance(to_type, ts.VarArgType): - return is_concretizable(symbol_type.element_type, to_type.element_type) - if isinstance(symbol_type, ts.VarArgType) and isinstance(to_type, ts.TupleType): - if len(to_type.types) == 0 or ( - all(type_ == to_type.types[0] for type_ in to_type.types) - and is_concretizable(symbol_type.element_type, to_type.types[0]) - ): - return True elif is_concrete(symbol_type): return symbol_type == to_type return False diff --git a/src/gt4py/next/type_system/type_specifications.py b/src/gt4py/next/type_system/type_specifications.py index 409138d593..59ac40f0f3 100644 --- a/src/gt4py/next/type_system/type_specifications.py +++ b/src/gt4py/next/type_system/type_specifications.py @@ -148,15 +148,6 @@ def __len__(self) -> int: return len(self.types) -class VarArgType(DataType): - """Represents a variable number of arguments of the same type.""" - - element_type: DataType # TODO: maybe also support different DataTypes - - def __str__(self) -> str: - return f"VarArg[{self.element_type}]" - - class AnyPythonType: """Marker type representing any Python type which cannot be used for instantiation. diff --git a/src/gt4py/next/type_system/type_translation.py b/src/gt4py/next/type_system/type_translation.py index 1d7a9aa2f7..3671c5b344 100644 --- a/src/gt4py/next/type_system/type_translation.py +++ b/src/gt4py/next/type_system/type_translation.py @@ -181,12 +181,8 @@ def from_type_hint( case builtins.tuple: if not args: raise ValueError(f"Tuple annotation '{type_hint}' requires at least one argument.") - if len(args) == 2 and args[1] is Ellipsis: - return ts.VarArgType(element_type=from_type_hint_same_ns(args[0])) - elif Ellipsis in args: - raise ValueError( - f"Vararg tuple annotation '{type_hint}' cannot have more than one argument." - ) + if Ellipsis in args: + raise ValueError(f"Unbound tuples '{type_hint}' are not allowed.") tuple_types = [from_type_hint_same_ns(arg) for arg in args] assert all(isinstance(elem, ts.DataType) for elem in tuple_types) return ts.TupleType(types=tuple_types) @@ -330,19 +326,7 @@ def from_value(value: Any) -> ts.TypeSpec: return NamespaceProxy(value) else: type_ = xtyping.infer_type(value, annotate_callable_kwargs=True) - if type_ == type[tuple]: - # TODO: this special casing here is not nice, but infer_type is also called on the annotations where - # we don't want to allow unparameterized tuples (or do we?). - symbol_type = ts.ConstructorType( - definition=ts.FunctionType( - pos_only_args=[ts.DeferredType(constraint=None)], - pos_or_kw_args={}, - kw_only_args={}, - returns=ts.DeferredType(constraint=ts.VarArgType), - ) - ) - else: - symbol_type = from_type_hint(type_) + symbol_type = from_type_hint(type_) if isinstance(symbol_type, (ts.DataType, ts.CallableType, ts.OffsetType, ts.DimensionType)): return symbol_type diff --git a/tests/next_tests/integration_tests/cases.py b/tests/next_tests/integration_tests/cases.py index d65ecefb10..d552a09a2a 100644 --- a/tests/next_tests/integration_tests/cases.py +++ b/tests/next_tests/integration_tests/cases.py @@ -603,15 +603,6 @@ def _allocate_from_type( for t in types ) ) - case ts.VarArgType(element_type=element_type): - return tuple( - ( - _allocate_from_type( - case=case, arg_type=t, domain=domain, dtype=dtype, strategy=strategy - ) - for t in [element_type] * 3 # TODO: revisit - ) - ) case ts.NamedCollectionType(types=types) as named_collection_type_spec: container_constructor = ( named_collections.make_named_collection_constructor_from_type_spec( @@ -657,8 +648,6 @@ def get_param_size(param_type: ts.TypeSpec, sizes: dict[gtx.Dimension, int]) -> return sum([get_param_size(t, sizes=sizes) for t in types]) case ts.NamedCollectionType(types=types): return sum([get_param_size(t, sizes=sizes) for t in types]) - case ts.VarArgType(element_type=element_type): - return get_param_size(ts.TupleType(types=[element_type] * 3), sizes) # TODO: revisit case _: raise TypeError(f"Can not get size for parameter of type '{param_type}'.") diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index 986fa5f5cb..c58ac5f497 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -338,36 +338,6 @@ def testee(a: tuple[cases.IField, cases.IJField]) -> cases.IJField: ) -@pytest.mark.uses_tuple_args -def test_tuple_comprehension(cartesian_case): - @gtx.field_operator - def testee( - tracers: tuple[cases.IFloatField, ...], factor: float - ) -> tuple[cases.IFloatField, ...]: - return tuple(tracer * factor for tracer in tracers) - - cases.verify_with_default_data( - cartesian_case, - testee, - ref=lambda t, f: tuple(el * f for el in t), - ) - - -@pytest.mark.uses_tuple_args -def test_tuple_vararg(cartesian_case): - @gtx.field_operator - def testee( - tracers: tuple[cases.IFloatField, ...], factor: float - ) -> tuple[cases.IFloatField, cases.IFloatField]: - return tracers[0] * factor, tracers[1] * factor - - cases.verify_with_default_data( - cartesian_case, - testee, - ref=lambda t, f: tuple(el * f for el in t[:2]), - ) - - @pytest.mark.uses_tuple_args @pytest.mark.xfail(reason="Iterator of tuple approach in lowering does not allow this.") def test_tuple_arg_with_unpromotable_dims(unstructured_case): diff --git a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py index 8e3bba90b9..e516d7ddbd 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py @@ -207,10 +207,11 @@ def foo( lowered ) # we generate a let for the condition which is removed by inlining for easier testing - reference = im.make_tuple( - im.op_as_fieldop("if_")("a", im.tuple_get(0, "b"), im.tuple_get(0, "c")), - im.op_as_fieldop("if_")("a", im.tuple_get(1, "b"), im.tuple_get(1, "c")), - ) + reference = im.tree_map( # TODO: check if this is what we want + im.lambda_("__a", "__b")( + im.op_as_fieldop("if_")("a", im.ref("__a"), im.ref("__b")) + ) + )("b", "c") assert lowered_inlined.expr == reference diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_tree_map.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_tree_map.py new file mode 100644 index 0000000000..9f1a379236 --- /dev/null +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_tree_map.py @@ -0,0 +1,43 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +from gt4py.next.iterator.ir_utils import ir_makers as im +from gt4py.next.iterator.transforms.unroll_tree_map import _unroll +from gt4py.next.type_system import type_specifications as ts + + +T = ts.ScalarType(kind=ts.ScalarKind.FLOAT64) +TT = ts.TupleType(types=[T, T]) + + +def test_single_arg(): + result = _unroll(im.ref("f"), [TT], [im.ref("t")]) + expected = im.make_tuple(im.call("f")(im.tuple_get(0, "t")), im.call("f")(im.tuple_get(1, "t"))) + assert result == expected + + +def test_multi_arg(): + result = _unroll(im.ref("f"), [TT, TT], [im.ref("a"), im.ref("b")]) + expected = im.make_tuple( + im.call("f")(im.tuple_get(0, "a"), im.tuple_get(0, "b")), + im.call("f")(im.tuple_get(1, "a"), im.tuple_get(1, "b")), + ) + assert result == expected + + +def test_nested(): + outer = ts.TupleType(types=[TT, T]) + result = _unroll(im.ref("f"), [outer], [im.ref("t")]) + expected = im.make_tuple( + im.make_tuple( + im.call("f")(im.tuple_get(0, im.tuple_get(0, "t"))), + im.call("f")(im.tuple_get(1, im.tuple_get(0, "t"))), + ), + im.call("f")(im.tuple_get(1, "t")), + ) + assert result == expected From 0ec4692ac28ecd536e4756d5304a54b4b649a7fb Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Mon, 27 Apr 2026 16:57:07 +0200 Subject: [PATCH 03/24] Run pre-commit and fix some tests --- src/gt4py/next/ffront/foast_to_gtir.py | 1 - src/gt4py/next/iterator/builtins.py | 5 +++++ .../next/iterator/transforms/unroll_tree_map.py | 9 ++++++--- .../next/iterator/type_system/type_synthesizer.py | 13 +++++++------ .../unit_tests/ffront_tests/test_foast_to_gtir.py | 6 ++---- 5 files changed, 20 insertions(+), 14 deletions(-) diff --git a/src/gt4py/next/ffront/foast_to_gtir.py b/src/gt4py/next/ffront/foast_to_gtir.py index 78b0671db1..c341a311b1 100644 --- a/src/gt4py/next/ffront/foast_to_gtir.py +++ b/src/gt4py/next/ffront/foast_to_gtir.py @@ -416,7 +416,6 @@ def _visit_where(self, node: foast.Call, **kwargs: Any) -> itir.FunCall: false_ = self.visit(node.args[2]) cond_symref_name = f"__cond_{cond_.fingerprint()}" - # tree_map(lambda a, b: as_fieldop(if_)(cond_ref, a, b))(true_tup, false_tup) result = im.tree_map( im.lambda_("__a", "__b")( im.op_as_fieldop("if_")(im.ref(cond_symref_name), im.ref("__a"), im.ref("__b")) diff --git a/src/gt4py/next/iterator/builtins.py b/src/gt4py/next/iterator/builtins.py index 273dca847f..b60932eed8 100644 --- a/src/gt4py/next/iterator/builtins.py +++ b/src/gt4py/next/iterator/builtins.py @@ -57,6 +57,11 @@ def map_(*args): raise BackendNotSelectedError() +@builtin_dispatch +def tree_map(*args): + raise BackendNotSelectedError() + + @builtin_dispatch def make_const_list(*args): raise BackendNotSelectedError() diff --git a/src/gt4py/next/iterator/transforms/unroll_tree_map.py b/src/gt4py/next/iterator/transforms/unroll_tree_map.py index c341797216..03d12355ec 100644 --- a/src/gt4py/next/iterator/transforms/unroll_tree_map.py +++ b/src/gt4py/next/iterator/transforms/unroll_tree_map.py @@ -29,7 +29,8 @@ def _unroll( child_exprs = [im.tuple_get(i, e) for e in tup_exprs] if all(isinstance(ct, ts.TupleType) for ct in child_types): - elements.append(_unroll(f, child_types, child_exprs)) # type: ignore[arg-type] + nested_types = [ct for ct in child_types if isinstance(ct, ts.TupleType)] + elements.append(_unroll(f, nested_types, child_exprs)) else: elements.append(im.call(f)(*child_exprs)) @@ -46,7 +47,7 @@ class UnrollTreeMap(eve.NodeTranslator): def apply(cls, program: itir.Program, *, uids: utils.IDGeneratorPool): return cls(uids=uids).visit(program) - def visit_FunCall(self, node: itir.Expr): + def visit_FunCall(self, node: itir.FunCall): node = self.generic_visit(node) if not cpm.is_call_to(node.fun, "tree_map"): @@ -54,12 +55,14 @@ def visit_FunCall(self, node: itir.Expr): f = node.fun.args[0] tup_args = node.args + tup_types: list[ts.TupleType] = [] for tup in tup_args: itir_inference.reinfer(tup) assert isinstance(tup.type, ts.TupleType) + tup_types.append(tup.type) tup_refs = [next(self.uids["_utm"]) for _ in tup_args] - body = _unroll(f, [tup.type for tup in tup_args], [im.ref(r) for r in tup_refs]) + body = _unroll(f, tup_types, [im.ref(r) for r in tup_refs]) result = body for ref_name, tup in reversed(list(zip(tup_refs, tup_args))): diff --git a/src/gt4py/next/iterator/type_system/type_synthesizer.py b/src/gt4py/next/iterator/type_system/type_synthesizer.py index 02b7a52c3b..4d5fe5e6d0 100644 --- a/src/gt4py/next/iterator/type_system/type_synthesizer.py +++ b/src/gt4py/next/iterator/type_system/type_synthesizer.py @@ -633,23 +633,24 @@ def applied_map( return applied_map -@_register_builtin_type_synthesizer -def tree_map(op: TypeSynthesizer) -> TypeSynthesizer: +@_register_builtin_type_synthesizer(fun_names=["tree_map"]) +def _tree_map(op: TypeSynthesizer) -> TypeSynthesizer: @type_synthesizer def applied_map( *args: ts.TupleType, offset_provider_type: common.OffsetProviderType ) -> ts.TupleType: def _recurse(*arg_types: ts.TypeSpec) -> ts.TypeSpec: if isinstance(arg_types[0], ts.TupleType): + tup_types = [a for a in arg_types if isinstance(a, ts.TupleType)] return ts.TupleType( types=[ - _recurse(*(a.types[i] for a in arg_types)) # type: ignore[union-attr] - for i in range(len(arg_types[0].types)) # type: ignore[union-attr] + _recurse(*(a.types[i] for a in tup_types)) + for i in range(len(arg_types[0].types)) ] ) - return op(*arg_types, offset_provider_type=offset_provider_type) + return op(*arg_types, offset_provider_type=offset_provider_type) # type: ignore[return-value] - return _recurse(*args) + return _recurse(*args) # type: ignore[return-value] return applied_map diff --git a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py index e516d7ddbd..726be8051b 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py @@ -207,10 +207,8 @@ def foo( lowered ) # we generate a let for the condition which is removed by inlining for easier testing - reference = im.tree_map( # TODO: check if this is what we want - im.lambda_("__a", "__b")( - im.op_as_fieldop("if_")("a", im.ref("__a"), im.ref("__b")) - ) + reference = im.tree_map( # TODO: check if this is what we want + im.lambda_("__a", "__b")(im.op_as_fieldop("if_")("a", im.ref("__a"), im.ref("__b"))) )("b", "c") assert lowered_inlined.expr == reference From ab84ecc0597d35a0c06724f67d1990c957858768 Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Tue, 28 Apr 2026 12:42:37 +0200 Subject: [PATCH 04/24] Run CollapseTuple after UnrollTreeMap --- .../next/iterator/transforms/pass_manager.py | 30 +++++++++++++++++++ .../ffront_tests/test_foast_to_gtir.py | 2 +- 2 files changed, 31 insertions(+), 1 deletion(-) diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index 5feba5be9e..6b98ccbae6 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -179,6 +179,23 @@ def apply_common_transforms( ir = concat_where.canonicalize_domain_argument(ir) ir = unroll_tree_map.UnrollTreeMap.apply(ir, uids=uids) + # After UnrollTreeMap, collapse `tuple_get(i, let(...)(make_tuple(...)))` patterns so that + # domain inference does not encounter `as_fieldop` nodes inside dead tuple elements + # (which would receive NEVER domain). Do multiple iterations for nested `let`s. + for _ in range(10): + collapsed = ir + ir = CollapseTuple.apply( + ir, + enabled_transformations=( + CollapseTuple.Transformation.PROPAGATE_TUPLE_GET + | CollapseTuple.Transformation.COLLAPSE_TUPLE_GET_MAKE_TUPLE + ), + uids=uids, + offset_provider_type=offset_provider_type, + ) # type: ignore[assignment] # always an itir.Program + if ir == collapsed: + break + ir = infer_domain.infer_program( ir, offset_provider=offset_provider, @@ -293,6 +310,19 @@ def apply_fieldview_transforms( ir = infer_domain_ops.InferDomainOps.apply(ir) ir = concat_where.canonicalize_domain_argument(ir) ir = unroll_tree_map.UnrollTreeMap.apply(ir, uids=uids) + for _ in range(10): + prev = ir + ir = CollapseTuple.apply( + ir, + enabled_transformations=( + CollapseTuple.Transformation.PROPAGATE_TUPLE_GET + | CollapseTuple.Transformation.COLLAPSE_TUPLE_GET_MAKE_TUPLE + ), + uids=uids, + offset_provider_type=offset_provider_type, + ) # type: ignore[assignment] # always an itir.Program + if ir == prev: + break ir = ConstantFolding.apply(ir) # type: ignore[assignment] # always an itir.Program ir = infer_domain.infer_program( diff --git a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py index 726be8051b..bf2978a8f2 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py @@ -207,7 +207,7 @@ def foo( lowered ) # we generate a let for the condition which is removed by inlining for easier testing - reference = im.tree_map( # TODO: check if this is what we want + reference = im.tree_map( im.lambda_("__a", "__b")(im.op_as_fieldop("if_")("a", im.ref("__a"), im.ref("__b"))) )("b", "c") From 152300e75e46def284a98bf8a19bca5dc35435de Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Tue, 28 Apr 2026 15:27:33 +0200 Subject: [PATCH 05/24] Address review comments --- .../next/iterator/transforms/pass_manager.py | 5 ++ .../iterator/transforms/unroll_tree_map.py | 17 +++- .../iterator/type_system/type_synthesizer.py | 26 ++++-- .../iterator_tests/test_type_inference.py | 17 ++++ .../transforms_tests/test_unroll_tree_map.py | 85 +++++++++++++++---- 5 files changed, 124 insertions(+), 26 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index 6b98ccbae6..32d847841b 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -195,6 +195,8 @@ def apply_common_transforms( ) # type: ignore[assignment] # always an itir.Program if ir == collapsed: break + else: + raise RuntimeError("'CollapseTuple' did not converge after `UnrollTreeMap`.") ir = infer_domain.infer_program( ir, @@ -323,6 +325,9 @@ def apply_fieldview_transforms( ) # type: ignore[assignment] # always an itir.Program if ir == prev: break + else: + raise RuntimeError("'CollapseTuple' did not converge after `UnrollTreeMap`.") + ir = ConstantFolding.apply(ir) # type: ignore[assignment] # always an itir.Program ir = infer_domain.infer_program( diff --git a/src/gt4py/next/iterator/transforms/unroll_tree_map.py b/src/gt4py/next/iterator/transforms/unroll_tree_map.py index 03d12355ec..8ef9dad7ec 100644 --- a/src/gt4py/next/iterator/transforms/unroll_tree_map.py +++ b/src/gt4py/next/iterator/transforms/unroll_tree_map.py @@ -21,18 +21,31 @@ def _unroll( tup_exprs: list[itir.Expr], ) -> itir.Expr: """Recursively expand ``tree_map(f)(tup0, tup1, ...)`` into ``make_tuple`` / ``tuple_get``.""" + assert tup_types, "tree_map requires at least one tuple argument." n = len(tup_types[0].types) + if any(len(t.types) != n for t in tup_types[1:]): + raise ValueError( + f"All tree_map arguments must have the same tuple structure at each level, " + f"got {[len(t.types) for t in tup_types]}." + ) elements: list[itir.Expr] = [] for i in range(n): child_types = [t.types[i] for t in tup_types] child_exprs = [im.tuple_get(i, e) for e in tup_exprs] - if all(isinstance(ct, ts.TupleType) for ct in child_types): + all_tuples = all(isinstance(ct, ts.TupleType) for ct in child_types) + all_leaves = all(not isinstance(ct, ts.TupleType) for ct in child_types) + if all_tuples: nested_types = [ct for ct in child_types if isinstance(ct, ts.TupleType)] elements.append(_unroll(f, nested_types, child_exprs)) - else: + elif all_leaves: elements.append(im.call(f)(*child_exprs)) + else: + raise ValueError( + "All tree_map arguments must have the same tree structure " + "(all leaves must be reached simultaneously)." + ) return im.make_tuple(*elements) diff --git a/src/gt4py/next/iterator/type_system/type_synthesizer.py b/src/gt4py/next/iterator/type_system/type_synthesizer.py index 4d5fe5e6d0..38b9ce1bbf 100644 --- a/src/gt4py/next/iterator/type_system/type_synthesizer.py +++ b/src/gt4py/next/iterator/type_system/type_synthesizer.py @@ -639,16 +639,30 @@ def _tree_map(op: TypeSynthesizer) -> TypeSynthesizer: def applied_map( *args: ts.TupleType, offset_provider_type: common.OffsetProviderType ) -> ts.TupleType: + if not args: + raise TypeError("tree_map requires at least one argument.") + def _recurse(*arg_types: ts.TypeSpec) -> ts.TypeSpec: - if isinstance(arg_types[0], ts.TupleType): + all_tuples = all(isinstance(a, ts.TupleType) for a in arg_types) + all_leaves = all(not isinstance(a, ts.TupleType) for a in arg_types) + if all_tuples: tup_types = [a for a in arg_types if isinstance(a, ts.TupleType)] + n = len(tup_types[0].types) + if any(len(t.types) != n for t in tup_types[1:]): + raise TypeError( + f"All tree_map arguments must have the same tuple structure at each level, " + f"got {[len(t.types) for t in tup_types]}." + ) return ts.TupleType( - types=[ - _recurse(*(a.types[i] for a in tup_types)) - for i in range(len(arg_types[0].types)) - ] + types=[_recurse(*(a.types[i] for a in tup_types)) for i in range(n)] + ) + elif all_leaves: + return op(*arg_types, offset_provider_type=offset_provider_type) # type: ignore[return-value] + else: + raise TypeError( + "All tree_map arguments must have the same tree structure " + "(all leaves must be reached simultaneously)." ) - return op(*arg_types, offset_provider_type=offset_provider_type) # type: ignore[return-value] return _recurse(*args) # type: ignore[return-value] diff --git a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py index d73fc1945f..f901a7ce39 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py @@ -155,6 +155,23 @@ def expression_test_cases(): ), ts.ListType(element_type=int_type, offset_type=V2EDim), ), + # tree_map + ( + im.tree_map(im.ref("plus"))( + im.ref("t1", ts.TupleType(types=[int_type, int_type])), + im.ref("t2", ts.TupleType(types=[int_type, int_type])), + ), + ts.TupleType(types=[int_type, int_type]), + ), + ( + im.tree_map(im.ref("not_"))( + im.ref( + "t", + ts.TupleType(types=[bool_type, ts.TupleType(types=[bool_type, bool_type])]), + ), + ), + ts.TupleType(types=[bool_type, ts.TupleType(types=[bool_type, bool_type])]), + ), # reduce (im.reduce("plus", 0)(im.ref("l", int_list_type)), int_type), ( diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_tree_map.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_tree_map.py index 9f1a379236..0c0f348d88 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_tree_map.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_tree_map.py @@ -6,38 +6,87 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause +from gt4py.next import common, utils +from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import ir_makers as im -from gt4py.next.iterator.transforms.unroll_tree_map import _unroll +from gt4py.next.iterator.transforms.unroll_tree_map import UnrollTreeMap from gt4py.next.type_system import type_specifications as ts - +IDim = common.Dimension("IDim") T = ts.ScalarType(kind=ts.ScalarKind.FLOAT64) -TT = ts.TupleType(types=[T, T]) +i_field = ts.FieldType(dims=[IDim], dtype=T) +i_tuple_field = ts.TupleType(types=[i_field, i_field]) +i_nested_tuple_field = ts.TupleType(types=[i_tuple_field, i_field]) +i_domain = im.call("cartesian_domain")(im.named_range(itir.AxisLiteral(value="IDim"), 0, 1)) + + +def _make_program(params: list[itir.Sym], expr: itir.Expr) -> itir.Program: + return itir.Program( + id="testee", + function_definitions=[], + params=[*params, im.sym("out", i_field)], + declarations=[], + body=[ + itir.SetAt( + expr=expr, + domain=i_domain, + target=im.ref("out"), + ) + ], + ) + + +def _neg(): + return im.lambda_("__a")(im.op_as_fieldop("neg")("__a")) -def test_single_arg(): - result = _unroll(im.ref("f"), [TT], [im.ref("t")]) - expected = im.make_tuple(im.call("f")(im.tuple_get(0, "t")), im.call("f")(im.tuple_get(1, "t"))) - assert result == expected + +def _plus(): + return im.lambda_("__a", "__b")(im.op_as_fieldop("plus")("__a", "__b")) def test_multi_arg(): - result = _unroll(im.ref("f"), [TT, TT], [im.ref("a"), im.ref("b")]) - expected = im.make_tuple( - im.call("f")(im.tuple_get(0, "a"), im.tuple_get(0, "b")), - im.call("f")(im.tuple_get(1, "a"), im.tuple_get(1, "b")), + uids = utils.IDGeneratorPool() + program = _make_program( + [im.sym("a", i_tuple_field), im.sym("b", i_tuple_field)], + im.call(im.call("tree_map")(_plus()))( + im.ref("a", i_tuple_field), im.ref("b", i_tuple_field) + ), + ) + result = UnrollTreeMap.apply(program, uids=uids) + + expected = _make_program( + [im.sym("a", i_tuple_field), im.sym("b", i_tuple_field)], + im.let("_utm_0", "a")( + im.let("_utm_1", "b")( + im.make_tuple( + im.call(_plus())(im.tuple_get(0, "_utm_0"), im.tuple_get(0, "_utm_1")), + im.call(_plus())(im.tuple_get(1, "_utm_0"), im.tuple_get(1, "_utm_1")), + ) + ) + ), ) assert result == expected def test_nested(): - outer = ts.TupleType(types=[TT, T]) - result = _unroll(im.ref("f"), [outer], [im.ref("t")]) - expected = im.make_tuple( - im.make_tuple( - im.call("f")(im.tuple_get(0, im.tuple_get(0, "t"))), - im.call("f")(im.tuple_get(1, im.tuple_get(0, "t"))), + uids = utils.IDGeneratorPool() + program = _make_program( + [im.sym("t", i_nested_tuple_field)], + im.call(im.call("tree_map")(_neg()))(im.ref("t", i_nested_tuple_field)), + ) + result = UnrollTreeMap.apply(program, uids=uids) + + expected = _make_program( + [im.sym("t", i_nested_tuple_field)], + im.let("_utm_0", "t")( + im.make_tuple( + im.make_tuple( + im.call(_neg())(im.tuple_get(0, im.tuple_get(0, "_utm_0"))), + im.call(_neg())(im.tuple_get(1, im.tuple_get(0, "_utm_0"))), + ), + im.call(_neg())(im.tuple_get(1, "_utm_0")), + ) ), - im.call("f")(im.tuple_get(1, "t")), ) assert result == expected From d459b0e7a05aa42c18878f3fe10c2ffecd3a6239 Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Tue, 28 Apr 2026 16:00:38 +0200 Subject: [PATCH 06/24] Address further review comments --- .../next/iterator/type_system/type_synthesizer.py | 5 +++++ .../transforms_tests/test_unroll_tree_map.py | 12 +++++++++--- 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/src/gt4py/next/iterator/type_system/type_synthesizer.py b/src/gt4py/next/iterator/type_system/type_synthesizer.py index 38b9ce1bbf..209d6ecf80 100644 --- a/src/gt4py/next/iterator/type_system/type_synthesizer.py +++ b/src/gt4py/next/iterator/type_system/type_synthesizer.py @@ -641,6 +641,11 @@ def applied_map( ) -> ts.TupleType: if not args: raise TypeError("tree_map requires at least one argument.") + if not all(isinstance(a, ts.TupleType) for a in args): + raise TypeError( + "tree_map requires all top-level arguments to be TupleType, " + f"got {[type(a).__name__ for a in args]}." + ) def _recurse(*arg_types: ts.TypeSpec) -> ts.TypeSpec: all_tuples = all(isinstance(a, ts.TupleType) for a in arg_types) diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_tree_map.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_tree_map.py index 0c0f348d88..f0b4165f1a 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_tree_map.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_tree_map.py @@ -21,17 +21,19 @@ i_domain = im.call("cartesian_domain")(im.named_range(itir.AxisLiteral(value="IDim"), 0, 1)) -def _make_program(params: list[itir.Sym], expr: itir.Expr) -> itir.Program: +def _make_program( + params: list[itir.Sym], expr: itir.Expr, out_type: ts.TypeSpec = i_field +) -> itir.Program: return itir.Program( id="testee", function_definitions=[], - params=[*params, im.sym("out", i_field)], + params=[*params, im.sym("out", out_type)], declarations=[], body=[ itir.SetAt( expr=expr, domain=i_domain, - target=im.ref("out"), + target=im.ref("out", out_type), ) ], ) @@ -52,6 +54,7 @@ def test_multi_arg(): im.call(im.call("tree_map")(_plus()))( im.ref("a", i_tuple_field), im.ref("b", i_tuple_field) ), + out_type=i_tuple_field, ) result = UnrollTreeMap.apply(program, uids=uids) @@ -65,6 +68,7 @@ def test_multi_arg(): ) ) ), + out_type=i_tuple_field, ) assert result == expected @@ -74,6 +78,7 @@ def test_nested(): program = _make_program( [im.sym("t", i_nested_tuple_field)], im.call(im.call("tree_map")(_neg()))(im.ref("t", i_nested_tuple_field)), + out_type=i_nested_tuple_field, ) result = UnrollTreeMap.apply(program, uids=uids) @@ -88,5 +93,6 @@ def test_nested(): im.call(_neg())(im.tuple_get(1, "_utm_0")), ) ), + out_type=i_nested_tuple_field, ) assert result == expected From 97af81e1f3a91dc6a048371ba23d664735ea05a5 Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Wed, 29 Apr 2026 15:24:01 +0200 Subject: [PATCH 07/24] Apply review comments --- .../iterator/transforms/unroll_tree_map.py | 55 ++++++------------- .../iterator/type_system/type_synthesizer.py | 36 ++++-------- .../transforms_tests/test_unroll_tree_map.py | 10 ++-- 3 files changed, 30 insertions(+), 71 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/unroll_tree_map.py b/src/gt4py/next/iterator/transforms/unroll_tree_map.py index 8ef9dad7ec..d2b0b7b642 100644 --- a/src/gt4py/next/iterator/transforms/unroll_tree_map.py +++ b/src/gt4py/next/iterator/transforms/unroll_tree_map.py @@ -6,6 +6,7 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause import dataclasses +import functools from gt4py import eve from gt4py.next import utils @@ -15,41 +16,6 @@ from gt4py.next.type_system import type_specifications as ts -def _unroll( - f: itir.Expr, - tup_types: list[ts.TupleType], - tup_exprs: list[itir.Expr], -) -> itir.Expr: - """Recursively expand ``tree_map(f)(tup0, tup1, ...)`` into ``make_tuple`` / ``tuple_get``.""" - assert tup_types, "tree_map requires at least one tuple argument." - n = len(tup_types[0].types) - if any(len(t.types) != n for t in tup_types[1:]): - raise ValueError( - f"All tree_map arguments must have the same tuple structure at each level, " - f"got {[len(t.types) for t in tup_types]}." - ) - - elements: list[itir.Expr] = [] - for i in range(n): - child_types = [t.types[i] for t in tup_types] - child_exprs = [im.tuple_get(i, e) for e in tup_exprs] - - all_tuples = all(isinstance(ct, ts.TupleType) for ct in child_types) - all_leaves = all(not isinstance(ct, ts.TupleType) for ct in child_types) - if all_tuples: - nested_types = [ct for ct in child_types if isinstance(ct, ts.TupleType)] - elements.append(_unroll(f, nested_types, child_exprs)) - elif all_leaves: - elements.append(im.call(f)(*child_exprs)) - else: - raise ValueError( - "All tree_map arguments must have the same tree structure " - "(all leaves must be reached simultaneously)." - ) - - return im.make_tuple(*elements) - - @dataclasses.dataclass class UnrollTreeMap(eve.NodeTranslator): PRESERVED_ANNEX_ATTRS = ("domain",) @@ -75,11 +41,22 @@ def visit_FunCall(self, node: itir.FunCall): tup_types.append(tup.type) tup_refs = [next(self.uids["_utm"]) for _ in tup_args] - body = _unroll(f, tup_types, [im.ref(r) for r in tup_refs]) - result = body - for ref_name, tup in reversed(list(zip(tup_refs, tup_args))): - result = im.let(ref_name, tup)(result) + @utils.tree_map( + collection_type=ts.TupleType, + result_collection_constructor=lambda _, elts: im.make_tuple(*elts), + with_path_arg=True, + ) + def mapper(*args): + *_el_types, path = args + return im.call(f)( + *( + functools.reduce(lambda expr, i: im.tuple_get(i, expr), path, im.ref(ref_name)) + for ref_name in tup_refs + ) + ) + + result = im.let(*zip(tup_refs, tup_args))(mapper(*tup_types)) itir_inference.reinfer(result) return result diff --git a/src/gt4py/next/iterator/type_system/type_synthesizer.py b/src/gt4py/next/iterator/type_system/type_synthesizer.py index 209d6ecf80..6437e23973 100644 --- a/src/gt4py/next/iterator/type_system/type_synthesizer.py +++ b/src/gt4py/next/iterator/type_system/type_synthesizer.py @@ -20,7 +20,6 @@ from gt4py.next.iterator import builtins, ir as itir from gt4py.next.iterator.type_system import type_specifications as it_ts from gt4py.next.type_system import type_info, type_specifications as ts -from gt4py.next.utils import tree_map def _type_synth_arg_cache_key(type_or_synth: TypeOrTypeSynthesizer) -> int: @@ -203,7 +202,7 @@ def if_( pred: ts.ScalarType | ts.DeferredType, true_branch: ts.DataType, false_branch: ts.DataType ) -> ts.DataType: if isinstance(true_branch, ts.TupleType) and isinstance(false_branch, ts.TupleType): - return tree_map( + return utils.tree_map( collection_type=ts.TupleType, result_collection_constructor=lambda _, elts: ts.TupleType(types=[*elts]), )(functools.partial(if_, pred))(true_branch, false_branch) @@ -633,8 +632,8 @@ def applied_map( return applied_map -@_register_builtin_type_synthesizer(fun_names=["tree_map"]) -def _tree_map(op: TypeSynthesizer) -> TypeSynthesizer: +@_register_builtin_type_synthesizer +def tree_map(op: TypeSynthesizer) -> TypeSynthesizer: @type_synthesizer def applied_map( *args: ts.TupleType, offset_provider_type: common.OffsetProviderType @@ -647,29 +646,14 @@ def applied_map( f"got {[type(a).__name__ for a in args]}." ) - def _recurse(*arg_types: ts.TypeSpec) -> ts.TypeSpec: - all_tuples = all(isinstance(a, ts.TupleType) for a in arg_types) - all_leaves = all(not isinstance(a, ts.TupleType) for a in arg_types) - if all_tuples: - tup_types = [a for a in arg_types if isinstance(a, ts.TupleType)] - n = len(tup_types[0].types) - if any(len(t.types) != n for t in tup_types[1:]): - raise TypeError( - f"All tree_map arguments must have the same tuple structure at each level, " - f"got {[len(t.types) for t in tup_types]}." - ) - return ts.TupleType( - types=[_recurse(*(a.types[i] for a in tup_types)) for i in range(n)] - ) - elif all_leaves: - return op(*arg_types, offset_provider_type=offset_provider_type) # type: ignore[return-value] - else: - raise TypeError( - "All tree_map arguments must have the same tree structure " - "(all leaves must be reached simultaneously)." - ) + def leaf_op(*leaf_types: ts.TypeSpec) -> ts.TypeSpec: + return op(*leaf_types, offset_provider_type=offset_provider_type) # type: ignore[return-value] - return _recurse(*args) # type: ignore[return-value] + return utils.tree_map( # type: ignore[return-value] + leaf_op, + collection_type=ts.TupleType, + result_collection_constructor=lambda _, elts: ts.TupleType(types=[*elts]), + )(*args) return applied_map diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_tree_map.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_tree_map.py index f0b4165f1a..3462ef4084 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_tree_map.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_tree_map.py @@ -60,12 +60,10 @@ def test_multi_arg(): expected = _make_program( [im.sym("a", i_tuple_field), im.sym("b", i_tuple_field)], - im.let("_utm_0", "a")( - im.let("_utm_1", "b")( - im.make_tuple( - im.call(_plus())(im.tuple_get(0, "_utm_0"), im.tuple_get(0, "_utm_1")), - im.call(_plus())(im.tuple_get(1, "_utm_0"), im.tuple_get(1, "_utm_1")), - ) + im.let(("_utm_0", "a"), ("_utm_1", "b"))( + im.make_tuple( + im.call(_plus())(im.tuple_get(0, "_utm_0"), im.tuple_get(0, "_utm_1")), + im.call(_plus())(im.tuple_get(1, "_utm_0"), im.tuple_get(1, "_utm_1")), ) ), out_type=i_tuple_field, From 32e5b2df848354048cdb6911f5a0d8c2166d3246 Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Thu, 28 May 2026 11:59:58 +0200 Subject: [PATCH 08/24] Rename map_ -> map_list --- src/gt4py/next/ffront/foast_to_gtir.py | 4 ++-- src/gt4py/next/iterator/builtins.py | 4 ++-- src/gt4py/next/iterator/embedded.py | 4 ++-- .../next/iterator/ir_utils/common_pattern_matcher.py | 2 +- src/gt4py/next/iterator/ir_utils/ir_makers.py | 6 +++--- .../next/iterator/transforms/collapse_list_get.py | 2 +- src/gt4py/next/iterator/transforms/cse.py | 4 ++-- src/gt4py/next/iterator/transforms/fuse_maps.py | 6 +++--- src/gt4py/next/iterator/transforms/trace_shifts.py | 2 +- .../next/iterator/type_system/type_synthesizer.py | 2 +- .../runners/dace/lowering/gtir_dataflow.py | 10 +++++----- .../feature_tests/ffront_tests/test_execution.py | 2 +- .../iterator_tests/test_with_toy_connectivity.py | 6 +++--- .../unit_tests/ffront_tests/test_foast_to_gtir.py | 6 +++--- .../iterator_tests/test_embedded_field_with_list.py | 10 +++++----- .../unit_tests/iterator_tests/test_type_inference.py | 8 ++++---- .../iterator_tests/transforms_tests/test_fuse_maps.py | 4 ++-- .../runners_tests/dace_tests/test_gtir_to_sdfg.py | 6 +++--- 18 files changed, 44 insertions(+), 44 deletions(-) diff --git a/src/gt4py/next/ffront/foast_to_gtir.py b/src/gt4py/next/ffront/foast_to_gtir.py index c341a311b1..a8603d78ee 100644 --- a/src/gt4py/next/ffront/foast_to_gtir.py +++ b/src/gt4py/next/ffront/foast_to_gtir.py @@ -526,7 +526,7 @@ def _map( original_arg_types: tuple[ts.TypeSpec, ...], ) -> itir.FunCall: """ - Mapping includes making the operation an `as_fieldop` (first kind of mapping), but also `itir.map_`ing lists. + Mapping includes making the operation an `as_fieldop` (first kind of mapping), but also `itir.map_list`ing lists. """ if all( isinstance(t, (ts.ScalarType, ts.DimensionType, ts.DomainType)) @@ -539,7 +539,7 @@ def _map( promote_to_list(arg_type)(larg) for arg_type, larg in zip(original_arg_types, lowered_args) ) - op = im.map_(op) + op = im.map_list(op) return im.op_as_fieldop(op)(*lowered_args) diff --git a/src/gt4py/next/iterator/builtins.py b/src/gt4py/next/iterator/builtins.py index b60932eed8..67ab347621 100644 --- a/src/gt4py/next/iterator/builtins.py +++ b/src/gt4py/next/iterator/builtins.py @@ -53,7 +53,7 @@ def neighbors(*args): @builtin_dispatch -def map_(*args): +def map_list(*args): raise BackendNotSelectedError() @@ -504,7 +504,7 @@ def get_domain_range(*args): "make_const_list", "make_tuple", "tree_map", - "map_", # TODO: rename to map_list + "map_list", "named_range", "neighbors", "reduce", diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index a9b36a4624..1052a34c58 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -1455,8 +1455,8 @@ def _get_offset(*lists: _List | _ConstList) -> Optional[runtime.Offset]: raise AssertionError("All lists must have the same offset.") -@builtins.map_.register(EMBEDDED) -def map_(op): +@builtins.map_list.register(EMBEDDED) +def map_list(op): def impl_(*lists): offset = _get_offset(*lists) if offset is None: diff --git a/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py b/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py index da13d20bb6..c0090ed3a2 100644 --- a/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py +++ b/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py @@ -71,7 +71,7 @@ def is_applied_map(arg: itir.Node) -> TypeGuard[_FunCallToFunCallToRef]: isinstance(arg, itir.FunCall) and isinstance(arg.fun, itir.FunCall) and isinstance(arg.fun.fun, itir.SymRef) - and arg.fun.fun.id == "map_" + and arg.fun.fun.id == "map_list" ) diff --git a/src/gt4py/next/iterator/ir_utils/ir_makers.py b/src/gt4py/next/iterator/ir_utils/ir_makers.py index 9545525a94..cf9fc44e3f 100644 --- a/src/gt4py/next/iterator/ir_utils/ir_makers.py +++ b/src/gt4py/next/iterator/ir_utils/ir_makers.py @@ -624,9 +624,9 @@ def index(dim: common.Dimension) -> itir.FunCall: return call("index")(itir.AxisLiteral(value=dim.value, kind=dim.kind)) -def map_(op): - """Create a `map_` call.""" - return call(call("map_")(op)) +def map_list(op): + """Create a `map_list` call.""" + return call(call("map_list")(op)) def tree_map(op): diff --git a/src/gt4py/next/iterator/transforms/collapse_list_get.py b/src/gt4py/next/iterator/transforms/collapse_list_get.py index 4c4219bda4..c951dcfdec 100644 --- a/src/gt4py/next/iterator/transforms/collapse_list_get.py +++ b/src/gt4py/next/iterator/transforms/collapse_list_get.py @@ -44,7 +44,7 @@ def visit_FunCall(self, node: itir.FunCall, **kwargs) -> itir.Node: if cpm.is_call_to(node.args[1], "make_const_list"): return node.args[1].args[0] if cpm.is_applied_map(node.args[1]): - # list_get(0, map_(λ(val_) → foo(val_, int64))(·__sym_1)) + # list_get(0, map_list(λ(val_) → foo(val_, int64))(·__sym_1)) # -> (λ(val_) → foo(val_, int64))(list_get(0, ·__sym_1)) lsts = node.args[1].args assert len(node.args[1].fun.args) == 1 # a single lambda in the map diff --git a/src/gt4py/next/iterator/transforms/cse.py b/src/gt4py/next/iterator/transforms/cse.py index a365cb25e3..ef58e527f2 100644 --- a/src/gt4py/next/iterator/transforms/cse.py +++ b/src/gt4py/next/iterator/transforms/cse.py @@ -95,7 +95,7 @@ def _is_collectable_expr(node: itir.Node) -> bool: if isinstance(node, itir.FunCall): # do not collect (and thus deduplicate in CSE) shift(offsets…) calls. Node must still be # visited, to ensure symbol dependencies are recognized correctly. - # do also not collect reduce, map_ and neighbors nodes if they are left in the IR at this point, this may lead to + # do also not collect reduce, map_list and neighbors nodes if they are left in the IR at this point, this may lead to # conceptual problems (other parts of the tool chain rely on the arguments being present directly # on the reduce FunCall node (connectivity deduction)), as well as problems with the imperative backend # backend (single pass eager depth first visit approach), see also https://github.com/GridTools/gt4py/issues/1795 @@ -104,7 +104,7 @@ def _is_collectable_expr(node: itir.Node) -> bool: # do also not collect index nodes because otherwise the right hand side of SetAts becomes a let statement # instead of an as_fieldop if cpm.is_call_to( - node, ("lift", "shift", "neighbors", "reduce", "map_", "index") + node, ("lift", "shift", "neighbors", "reduce", "map_list", "index") ) or cpm.is_applied_lift(node): return False return True diff --git a/src/gt4py/next/iterator/transforms/fuse_maps.py b/src/gt4py/next/iterator/transforms/fuse_maps.py index 4efbbe718b..69638861ff 100644 --- a/src/gt4py/next/iterator/transforms/fuse_maps.py +++ b/src/gt4py/next/iterator/transforms/fuse_maps.py @@ -18,7 +18,7 @@ @dataclasses.dataclass(frozen=True) class FuseMaps(traits.PreserveLocationVisitor, traits.VisitorWithSymbolTableTrait, NodeTranslator): """ - Fuses nested `map_`s. + Fuses nested `map_list`s. Preconditions: - `FunctionDefinitions` are inlined @@ -29,7 +29,7 @@ class FuseMaps(traits.PreserveLocationVisitor, traits.VisitorWithSymbolTableTrai to map(λ(a, b, c) → f(a, g(b, c)))(a, b, c) - reduce(λ(x, y) → f(x, y), init)(map_(g(z, w))(a, b)) + reduce(λ(x, y) → f(x, y), init)(map_list(g(z, w))(a, b)) to reduce(λ(x, y, z) → f(x, g(y, z)), init)(a, b) """ @@ -93,7 +93,7 @@ def visit_FunCall(self, node: ir.FunCall, **kwargs): new_op = ir.Lambda(params=new_params, expr=new_body) if cpm.is_applied_map(node): return ir.FunCall( - fun=ir.FunCall(fun=ir.SymRef(id="map_"), args=[new_op]), args=new_args + fun=ir.FunCall(fun=ir.SymRef(id="map_list"), args=[new_op]), args=new_args ) else: # is_applied_reduce(node) return ir.FunCall( diff --git a/src/gt4py/next/iterator/transforms/trace_shifts.py b/src/gt4py/next/iterator/transforms/trace_shifts.py index 8173ceebbb..b25c6400d4 100644 --- a/src/gt4py/next/iterator/transforms/trace_shifts.py +++ b/src/gt4py/next/iterator/transforms/trace_shifts.py @@ -259,7 +259,7 @@ def applied_as_fieldop(*args): "scan": _scan, "reduce": _reduce, "neighbors": _neighbors, - "map_": _map, + "map_list": _map, "if_": _if, "make_tuple": _make_tuple, } diff --git a/src/gt4py/next/iterator/type_system/type_synthesizer.py b/src/gt4py/next/iterator/type_system/type_synthesizer.py index 6437e23973..53c72de837 100644 --- a/src/gt4py/next/iterator/type_system/type_synthesizer.py +++ b/src/gt4py/next/iterator/type_system/type_synthesizer.py @@ -614,7 +614,7 @@ def apply_scan( @_register_builtin_type_synthesizer -def map_(op: TypeSynthesizer) -> TypeSynthesizer: +def map_list(op: TypeSynthesizer) -> TypeSynthesizer: @type_synthesizer def applied_map( *args: ts.ListType, offset_provider_type: common.OffsetProviderType diff --git a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_dataflow.py b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_dataflow.py index da590d84e0..dbdf712a2f 100644 --- a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_dataflow.py +++ b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_dataflow.py @@ -1176,7 +1176,7 @@ def _visit_map(self, node: gtir.FunCall) -> ValueExpr: The map operation is applied on the local dimension of input fields. In the example below, the local dimension consists of a list of neighbor values as the first argument, and a list of constant values `1.0`: - `map_(plus)(neighbors(V2E, it), make_const_list(1.0))` + `map_list(plus)(neighbors(V2E, it), make_const_list(1.0))` The `plus` operation is lowered to a tasklet inside a map that computes the domain of the local dimension (in this example, max neighbors in V2E). @@ -1234,7 +1234,7 @@ def _visit_map(self, node: gtir.FunCall) -> ValueExpr: # The dataflow we build in this class has some loose connections on input edges. # These edges are described as set of nodes, that will have to be connected to # external data source nodes passing through the map entry node of the field map. - # Similarly to `neighbors` expressions, the `map_` input edges terminate on view + # Similarly to `neighbors` expressions, the `map_list` input edges terminate on view # nodes (see `_construct_local_view` in the for-loop below), because it is simpler # than representing map-to-map edges (which require memlets with 2 pass-nodes). input_memlets = {} @@ -1261,7 +1261,7 @@ def _visit_map(self, node: gtir.FunCall) -> ValueExpr: result_node = self.state.add_access(result) if conn_type.has_skip_values: - # In case the `map_` input expressions contain skip values, we use + # In case the `map_list` input expressions contain skip values, we use # the connectivity-based offset provider as mask for map computation. conn_data = gtx_dace_args.connectivity_identifier(offset_type.value) conn_desc = self.sdfg.arrays[conn_data] @@ -1764,12 +1764,12 @@ def _visit_generic_builtin(self, node: gtir.FunCall) -> ValueExpr: if isinstance(node.type, ts.ListType): # The only builtin function (so far) handled here that returns a list - # is 'make_const_list'. There are other builtin functions (map_, neighbors) + # is 'make_const_list'. There are other builtin functions (map_list, neighbors) # that return a list but they are handled in specialized visit methods. # This method (the generic visitor for builtin functions) always returns # a single value. This is also the case of 'make_const_list' expression: # it simply broadcasts a scalar on the local domain of another expression, - # for example 'map_(plus)(neighbors(V2Eₒ, it), make_const_list(1.0))'. + # for example 'map_list(plus)(neighbors(V2Eₒ, it), make_const_list(1.0))'. # Therefore we handle `ListType` as a single-element array with shape (1,) # that will be accessed in a map expression on a local domain. assert isinstance(node.type.element_type, ts.ScalarType) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index c58ac5f497..a70fea3b53 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -747,7 +747,7 @@ def testee(a: cases.VField) -> cases.VField: @pytest.mark.uses_unstructured_shift -@pytest.mark.xfail(reason="Not yet supported in lowering, requires `map_`ing of inner reduce op.") +@pytest.mark.xfail(reason="Not yet supported in lowering, requires `map_list`ing of inner reduce op.") def test_nested_reduction_shift_first(unstructured_case): @gtx.field_operator def testee(inp: cases.EField) -> cases.EField: diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py index ff87de7348..b937ae7fe1 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py @@ -16,7 +16,7 @@ lift, list_get, make_const_list, - map_, + map_list, multiplies, neighbors, plus, @@ -101,7 +101,7 @@ def test_sum_edges_to_vertices(program_processor, stencil): @fundef def map_neighbors(in_edges): - return reduce(plus, 0)(map_(plus)(neighbors(V2E, in_edges), neighbors(V2E, in_edges))) + return reduce(plus, 0)(map_list(plus)(neighbors(V2E, in_edges), neighbors(V2E, in_edges))) def test_map_neighbors(program_processor): @@ -123,7 +123,7 @@ def test_map_neighbors(program_processor): @fundef def map_make_const_list(in_edges): - return reduce(plus, 0)(map_(multiplies)(neighbors(V2E, in_edges), make_const_list(2))) + return reduce(plus, 0)(map_list(multiplies)(neighbors(V2E, in_edges), make_const_list(2))) @pytest.mark.uses_constant_fields diff --git a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py index bf2978a8f2..99033c29de 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py @@ -299,7 +299,7 @@ def foo(a: gtx.Field[gtx.Dims[Vertex, V2EDim], float64]): parsed = FieldOperatorParser.apply_to_function(foo) lowered = FieldOperatorLowering.apply(parsed) - reference = im.op_as_fieldop(im.map_(im.lambda_("val")(im.cast_("val", "int32"))))("a") + reference = im.op_as_fieldop(im.map_list(im.lambda_("val")(im.cast_("val", "int32"))))("a") assert lowered.expr == reference @@ -832,9 +832,9 @@ def foo(e1: gtx.Field[[Edge], float64], e2: gtx.Field[[Vertex, V2EDim], float64] parsed = FieldOperatorParser.apply_to_function(foo) lowered = FieldOperatorLowering.apply(parsed) - mapped = im.op_as_fieldop(im.map_("multiplies"))( + mapped = im.op_as_fieldop(im.map_list("multiplies"))( im.op_as_fieldop("make_const_list")(im.literal("1.1", "float64")), - im.op_as_fieldop(im.map_("plus"))(ssa.unique_name("e1_nbh", 0), "e2"), + im.op_as_fieldop(im.map_list("plus"))(ssa.unique_name("e1_nbh", 0), "e2"), ) reference = im.let( diff --git a/tests/next_tests/unit_tests/iterator_tests/test_embedded_field_with_list.py b/tests/next_tests/unit_tests/iterator_tests/test_embedded_field_with_list.py index 01a259fcec..3a6562e6be 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_embedded_field_with_list.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_embedded_field_with_list.py @@ -17,7 +17,7 @@ deref, if_, make_const_list, - map_, + map_list, neighbors, plus, ) @@ -70,7 +70,7 @@ def testee(): def test_write_map_neighbors_and_const_list(): def testee(inp): domain = runtime.UnstructuredDomain({E: range(2)}) - return as_fieldop(lambda x, y: map_(plus)(deref(x), deref(y)), domain)( + return as_fieldop(lambda x, y: map_list(plus)(deref(x), deref(y)), domain)( as_fieldop(lambda it: neighbors(E2V, it), domain)(inp), as_fieldop(lambda: make_const_list(42.0), domain)(), ) @@ -86,7 +86,7 @@ def testee(inp): def test_write_map_conditional_neighbors_and_const_list(): def testee(inp, mask): domain = runtime.UnstructuredDomain({E: range(2)}) - return as_fieldop(lambda m, x, y: map_(if_)(deref(m), deref(x), deref(y)), domain)( + return as_fieldop(lambda m, x, y: map_list(if_)(deref(m), deref(x), deref(y)), domain)( as_fieldop(lambda it: make_const_list(deref(it)), domain)(mask), as_fieldop(lambda it: neighbors(E2V, it), domain)(inp), as_fieldop(lambda it: make_const_list(deref(it)), domain)(42.0), @@ -106,7 +106,7 @@ def testee(inp, mask): def test_write_non_mapped_conditional_neighbors_and_const_list(): """ This test-case demonstrates a non-supported pattern: - Current ITIR requires the `if_` to be `map_`ed, see `test_write_map_conditional_neighbors_and_const_list`. + Current ITIR requires the `if_` to be `map_list`ed, see `test_write_map_conditional_neighbors_and_const_list`. We keep it here for documenting corner cases of the `itir.List` implementation for future discussions. """ @@ -134,7 +134,7 @@ def testee(inp, mask): def test_write_map_const_list_and_const_list(): def testee(): domain = runtime.UnstructuredDomain({E: range(2)}) - return as_fieldop(lambda x, y: map_(plus)(deref(x), deref(y)), domain)( + return as_fieldop(lambda x, y: map_list(plus)(deref(x), deref(y)), domain)( as_fieldop(lambda: make_const_list(1.0), domain)(), as_fieldop(lambda: make_const_list(42.0), domain)(), ) diff --git a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py index f901a7ce39..1c3d6da88e 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py @@ -137,19 +137,19 @@ def expression_test_cases(): # TODO: scan # map ( - im.map_(im.ref("plus"))(im.ref("a", int_list_type), im.ref("b", int_list_type)), + im.map_list(im.ref("plus"))(im.ref("a", int_list_type), im.ref("b", int_list_type)), int_list_type, ), ( - im.map_(im.ref("plus"))(im.call("make_const_list")(1), im.ref("b", int_list_type)), + im.map_list(im.ref("plus"))(im.call("make_const_list")(1), im.ref("b", int_list_type)), int_list_type, ), ( - im.map_(im.ref("plus"))(im.ref("a", int_list_type), im.call("make_const_list")(1)), + im.map_list(im.ref("plus"))(im.ref("a", int_list_type), im.call("make_const_list")(1)), int_list_type, ), ( - im.map_(im.ref("plus"))( + im.map_list(im.ref("plus"))( im.ref("a", int_list_type), im.ref("b", ts.ListType(element_type=int_type, offset_type=V2EDim)), ), diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_maps.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_maps.py index c64ab93b6a..c54de17959 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_maps.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_maps.py @@ -13,11 +13,11 @@ def _map(op: ir.Expr, *args: ir.Expr) -> ir.FunCall: - return ir.FunCall(fun=ir.FunCall(fun=ir.SymRef(id="map_"), args=[op]), args=[*args]) + return ir.FunCall(fun=ir.FunCall(fun=ir.SymRef(id="map_list"), args=[op]), args=[*args]) def _map_p(op: ir.Expr | P, *args: ir.Expr | P) -> P: - return P(ir.FunCall, fun=P(ir.FunCall, fun=ir.SymRef(id="map_"), args=[op]), args=[*args]) + return P(ir.FunCall, fun=P(ir.FunCall, fun=ir.SymRef(id="map_list"), args=[op]), args=[*args]) def _reduce(op: ir.Expr, init: ir.Expr, *args: ir.Expr) -> ir.FunCall: diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py index c91739d4f6..27c93a59fd 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py @@ -1230,7 +1230,7 @@ def test_gtir_neighbors_as_input(): im.reduce("plus", im.literal_from_value(init_value))(im.deref("it")) ), inner_domain, - )(im.op_as_fieldop(im.map_("divides"), inner_domain)("v2e_field", "x")) + )(im.op_as_fieldop(im.map_list("divides"), inner_domain)("v2e_field", "x")) ), domain=outer_domain, target=gtir.SymRef(id="vertices"), @@ -1439,8 +1439,8 @@ def test_gtir_reduce_dot_product(): im.reduce("plus", im.literal_from_value(init_value))(im.deref("it")) ) )( - im.op_as_fieldop(im.map_("plus"))( - im.op_as_fieldop(im.map_("multiplies"))( + im.op_as_fieldop(im.map_list("plus"))( + im.op_as_fieldop(im.map_list("multiplies"))( im.as_fieldop_neighbors("V2E", "edges"), "v2e_field", ), From a7175d77a827d686b035d2cdf1a12a7dd485f754 Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Thu, 28 May 2026 12:04:20 +0200 Subject: [PATCH 09/24] Run pre-commit --- .../feature_tests/ffront_tests/test_execution.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index a70fea3b53..e58c86fb4c 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -747,7 +747,9 @@ def testee(a: cases.VField) -> cases.VField: @pytest.mark.uses_unstructured_shift -@pytest.mark.xfail(reason="Not yet supported in lowering, requires `map_list`ing of inner reduce op.") +@pytest.mark.xfail( + reason="Not yet supported in lowering, requires `map_list`ing of inner reduce op." +) def test_nested_reduction_shift_first(unstructured_case): @gtx.field_operator def testee(inp: cases.EField) -> cases.EField: From 2779fd08b7648459135095250832b5fa183ef43d Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Thu, 28 May 2026 14:23:13 +0200 Subject: [PATCH 10/24] Refactor tree_map_tuple and add map_tuple with unrolling support --- src/gt4py/next/ffront/foast_to_gtir.py | 2 +- src/gt4py/next/iterator/builtins.py | 10 ++- src/gt4py/next/iterator/ir_utils/ir_makers.py | 11 ++- .../iterator/transforms/unroll_tree_map.py | 62 ++++++++++++----- .../iterator/type_system/type_synthesizer.py | 60 +++++++++++------ .../ffront_tests/test_foast_to_gtir.py | 2 +- .../iterator_tests/test_type_inference.py | 6 +- .../transforms_tests/test_unroll_tree_map.py | 67 ++++++++++++++++--- 8 files changed, 164 insertions(+), 56 deletions(-) diff --git a/src/gt4py/next/ffront/foast_to_gtir.py b/src/gt4py/next/ffront/foast_to_gtir.py index a8603d78ee..73674a69da 100644 --- a/src/gt4py/next/ffront/foast_to_gtir.py +++ b/src/gt4py/next/ffront/foast_to_gtir.py @@ -416,7 +416,7 @@ def _visit_where(self, node: foast.Call, **kwargs: Any) -> itir.FunCall: false_ = self.visit(node.args[2]) cond_symref_name = f"__cond_{cond_.fingerprint()}" - result = im.tree_map( + result = im.tree_map_tuple( im.lambda_("__a", "__b")( im.op_as_fieldop("if_")(im.ref(cond_symref_name), im.ref("__a"), im.ref("__b")) ) diff --git a/src/gt4py/next/iterator/builtins.py b/src/gt4py/next/iterator/builtins.py index 67ab347621..d222f100df 100644 --- a/src/gt4py/next/iterator/builtins.py +++ b/src/gt4py/next/iterator/builtins.py @@ -58,7 +58,12 @@ def map_list(*args): @builtin_dispatch -def tree_map(*args): +def tree_map_tuple(*args): + raise BackendNotSelectedError() + + +@builtin_dispatch +def map_tuple(*args): raise BackendNotSelectedError() @@ -503,7 +508,8 @@ def get_domain_range(*args): "lift", "make_const_list", "make_tuple", - "tree_map", + "tree_map_tuple", + "map_tuple", "map_list", "named_range", "neighbors", diff --git a/src/gt4py/next/iterator/ir_utils/ir_makers.py b/src/gt4py/next/iterator/ir_utils/ir_makers.py index cf9fc44e3f..a46c1f3ffd 100644 --- a/src/gt4py/next/iterator/ir_utils/ir_makers.py +++ b/src/gt4py/next/iterator/ir_utils/ir_makers.py @@ -629,9 +629,14 @@ def map_list(op): return call(call("map_list")(op)) -def tree_map(op): - """Create a `tree_map` call: tree_map(op)(tup1, tup2, ...).""" - return call(call("tree_map")(op)) +def tree_map_tuple(op): + """Create a `tree_map_tuple` call: tree_map_tuple(op)(tup1, tup2, ...).""" + return call(call("tree_map_tuple")(op)) + + +def map_tuple(op): + """Create a `map_tuple` call: map_tuple(op)(tup).""" + return call(call("map_tuple")(op)) def reduce(op, expr): diff --git a/src/gt4py/next/iterator/transforms/unroll_tree_map.py b/src/gt4py/next/iterator/transforms/unroll_tree_map.py index d2b0b7b642..66df6e96c8 100644 --- a/src/gt4py/next/iterator/transforms/unroll_tree_map.py +++ b/src/gt4py/next/iterator/transforms/unroll_tree_map.py @@ -16,8 +16,47 @@ from gt4py.next.type_system import type_specifications as ts +def _tree_map_tuple_body( + f: itir.Expr, tup_refs: list[str], tup_types: list[ts.TupleType] +) -> itir.Expr: + """Recursively unroll `tree_map_tuple(f)(t1, ..., tN)` into `make_tuple` calls.""" + + @utils.tree_map( + collection_type=ts.TupleType, + result_collection_constructor=lambda _, elts: im.make_tuple(*elts), + with_path_arg=True, + ) + def mapper(*args): + *_el_types, path = args + return im.call(f)( + *( + functools.reduce(lambda expr, i: im.tuple_get(i, expr), path, im.ref(ref_name)) + for ref_name in tup_refs + ) + ) + + return mapper(*tup_types) + + +def _map_tuple_body(f: itir.Expr, tup_refs: list[str], tup_types: list[ts.TupleType]) -> itir.Expr: + """Unroll `map_tuple(f)(t)` over top-level elements only (no recursion).""" + (ref_name,) = tup_refs + (tup_type,) = tup_types + return im.make_tuple( + *(im.call(f)(im.tuple_get(i, im.ref(ref_name))) for i in range(len(tup_type.types))) + ) + + +_UNROLLERS = { + "tree_map_tuple": _tree_map_tuple_body, + "map_tuple": _map_tuple_body, +} + + @dataclasses.dataclass class UnrollTreeMap(eve.NodeTranslator): + """Unroll tuple-map ITIR builtins (`tree_map_tuple`, `map_tuple`) into `make_tuple`.""" + PRESERVED_ANNEX_ATTRS = ("domain",) uids: utils.IDGeneratorPool @@ -29,11 +68,14 @@ def apply(cls, program: itir.Program, *, uids: utils.IDGeneratorPool): def visit_FunCall(self, node: itir.FunCall): node = self.generic_visit(node) - if not cpm.is_call_to(node.fun, "tree_map"): + builtin_name = next((name for name in _UNROLLERS if cpm.is_call_to(node.fun, name)), None) + if builtin_name is None: return node + assert isinstance(node.fun, itir.FunCall) f = node.fun.args[0] tup_args = node.args + tup_types: list[ts.TupleType] = [] for tup in tup_args: itir_inference.reinfer(tup) @@ -41,22 +83,8 @@ def visit_FunCall(self, node: itir.FunCall): tup_types.append(tup.type) tup_refs = [next(self.uids["_utm"]) for _ in tup_args] + body = _UNROLLERS[builtin_name](f, tup_refs, tup_types) - @utils.tree_map( - collection_type=ts.TupleType, - result_collection_constructor=lambda _, elts: im.make_tuple(*elts), - with_path_arg=True, - ) - def mapper(*args): - *_el_types, path = args - return im.call(f)( - *( - functools.reduce(lambda expr, i: im.tuple_get(i, expr), path, im.ref(ref_name)) - for ref_name in tup_refs - ) - ) - - result = im.let(*zip(tup_refs, tup_args))(mapper(*tup_types)) - + result = im.let(*zip(tup_refs, tup_args))(body) itir_inference.reinfer(result) return result diff --git a/src/gt4py/next/iterator/type_system/type_synthesizer.py b/src/gt4py/next/iterator/type_system/type_synthesizer.py index 53c72de837..cbd17d3120 100644 --- a/src/gt4py/next/iterator/type_system/type_synthesizer.py +++ b/src/gt4py/next/iterator/type_system/type_synthesizer.py @@ -632,30 +632,50 @@ def applied_map( return applied_map -@_register_builtin_type_synthesizer -def tree_map(op: TypeSynthesizer) -> TypeSynthesizer: +def _tuple_map_synthesizer(builtin_name: str, *, recursive: bool) -> TypeSynthesizer: + """Shared implementation for `tree_map_tuple` (recursive) and `map_tuple` (top-level).""" + @type_synthesizer - def applied_map( - *args: ts.TupleType, offset_provider_type: common.OffsetProviderType - ) -> ts.TupleType: - if not args: - raise TypeError("tree_map requires at least one argument.") - if not all(isinstance(a, ts.TupleType) for a in args): - raise TypeError( - "tree_map requires all top-level arguments to be TupleType, " - f"got {[type(a).__name__ for a in args]}." - ) + def factory(op: TypeSynthesizer) -> TypeSynthesizer: + @type_synthesizer + def applied_map( + *args: ts.TupleType, offset_provider_type: common.OffsetProviderType + ) -> ts.TupleType: + if not args: + raise TypeError(f"'{builtin_name}' requires at least one argument.") + if not recursive and len(args) != 1: + raise TypeError(f"'{builtin_name}' requires exactly one argument, got {len(args)}.") + if not all(isinstance(a, ts.TupleType) for a in args): + raise TypeError( + f"'{builtin_name}' requires all top-level arguments to be TupleType, " + f"got {[type(a).__name__ for a in args]}." + ) - def leaf_op(*leaf_types: ts.TypeSpec) -> ts.TypeSpec: - return op(*leaf_types, offset_provider_type=offset_provider_type) # type: ignore[return-value] + def leaf_op(*leaf_types: ts.TypeSpec) -> ts.TypeSpec: + return op(*leaf_types, offset_provider_type=offset_provider_type) # type: ignore[return-value] - return utils.tree_map( # type: ignore[return-value] - leaf_op, - collection_type=ts.TupleType, - result_collection_constructor=lambda _, elts: ts.TupleType(types=[*elts]), - )(*args) + if recursive: + return utils.tree_map( # type: ignore[return-value] + leaf_op, + collection_type=ts.TupleType, + result_collection_constructor=lambda _, elts: ts.TupleType(types=[*elts]), + )(*args) - return applied_map + # Non-recursive: apply `op` once per top-level element. + (arg,) = args + return ts.TupleType(types=[leaf_op(el) for el in arg.types]) + + return applied_map + + return factory + + +tree_map_tuple = _register_builtin_type_synthesizer( + _tuple_map_synthesizer("tree_map_tuple", recursive=True), fun_names=["tree_map_tuple"] +) +map_tuple = _register_builtin_type_synthesizer( + _tuple_map_synthesizer("map_tuple", recursive=False), fun_names=["map_tuple"] +) @_register_builtin_type_synthesizer diff --git a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py index 99033c29de..4c6bf4d34a 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py @@ -207,7 +207,7 @@ def foo( lowered ) # we generate a let for the condition which is removed by inlining for easier testing - reference = im.tree_map( + reference = im.tree_map_tuple( im.lambda_("__a", "__b")(im.op_as_fieldop("if_")("a", im.ref("__a"), im.ref("__b"))) )("b", "c") diff --git a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py index 1c3d6da88e..025e75036a 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py @@ -155,16 +155,16 @@ def expression_test_cases(): ), ts.ListType(element_type=int_type, offset_type=V2EDim), ), - # tree_map + # tree_map_tuple ( - im.tree_map(im.ref("plus"))( + im.tree_map_tuple(im.ref("plus"))( im.ref("t1", ts.TupleType(types=[int_type, int_type])), im.ref("t2", ts.TupleType(types=[int_type, int_type])), ), ts.TupleType(types=[int_type, int_type]), ), ( - im.tree_map(im.ref("not_"))( + im.tree_map_tuple(im.ref("not_"))( im.ref( "t", ts.TupleType(types=[bool_type, ts.TupleType(types=[bool_type, bool_type])]), diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_tree_map.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_tree_map.py index 3462ef4084..1b97b35aa7 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_tree_map.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_tree_map.py @@ -16,7 +16,6 @@ T = ts.ScalarType(kind=ts.ScalarKind.FLOAT64) i_field = ts.FieldType(dims=[IDim], dtype=T) i_tuple_field = ts.TupleType(types=[i_field, i_field]) -i_nested_tuple_field = ts.TupleType(types=[i_tuple_field, i_field]) i_domain = im.call("cartesian_domain")(im.named_range(itir.AxisLiteral(value="IDim"), 0, 1)) @@ -47,11 +46,11 @@ def _plus(): return im.lambda_("__a", "__b")(im.op_as_fieldop("plus")("__a", "__b")) -def test_multi_arg(): +def test_tree_map_tuple_multi_arg(): uids = utils.IDGeneratorPool() program = _make_program( [im.sym("a", i_tuple_field), im.sym("b", i_tuple_field)], - im.call(im.call("tree_map")(_plus()))( + im.call(im.call("tree_map_tuple")(_plus()))( im.ref("a", i_tuple_field), im.ref("b", i_tuple_field) ), out_type=i_tuple_field, @@ -71,26 +70,76 @@ def test_multi_arg(): assert result == expected -def test_nested(): +def test_tree_map_tuple_nested(): uids = utils.IDGeneratorPool() + nested = ts.TupleType(types=[i_tuple_field, i_tuple_field]) program = _make_program( - [im.sym("t", i_nested_tuple_field)], - im.call(im.call("tree_map")(_neg()))(im.ref("t", i_nested_tuple_field)), - out_type=i_nested_tuple_field, + [im.sym("t", nested)], + im.call(im.call("tree_map_tuple")(_neg()))(im.ref("t", nested)), + out_type=nested, ) result = UnrollTreeMap.apply(program, uids=uids) expected = _make_program( - [im.sym("t", i_nested_tuple_field)], + [im.sym("t", nested)], im.let("_utm_0", "t")( im.make_tuple( im.make_tuple( im.call(_neg())(im.tuple_get(0, im.tuple_get(0, "_utm_0"))), im.call(_neg())(im.tuple_get(1, im.tuple_get(0, "_utm_0"))), ), + im.make_tuple( + im.call(_neg())(im.tuple_get(0, im.tuple_get(1, "_utm_0"))), + im.call(_neg())(im.tuple_get(1, im.tuple_get(1, "_utm_0"))), + ), + ) + ), + out_type=nested, + ) + assert result == expected + + +def test_map_tuple_single_arg(): + uids = utils.IDGeneratorPool() + program = _make_program( + [im.sym("t", i_tuple_field)], + im.call(im.call("map_tuple")(_neg()))(im.ref("t", i_tuple_field)), + out_type=i_tuple_field, + ) + result = UnrollTreeMap.apply(program, uids=uids) + + expected = _make_program( + [im.sym("t", i_tuple_field)], + im.let("_utm_0", "t")( + im.make_tuple( + im.call(_neg())(im.tuple_get(0, "_utm_0")), im.call(_neg())(im.tuple_get(1, "_utm_0")), ) ), - out_type=i_nested_tuple_field, + out_type=i_tuple_field, + ) + assert result == expected + + +def test_map_tuple_does_not_recurse(): + uids = utils.IDGeneratorPool() + nested = ts.TupleType(types=[i_tuple_field, i_tuple_field]) + g = im.lambda_("__p")(im.op_as_fieldop("plus")(im.tuple_get(0, "__p"), im.tuple_get(1, "__p"))) + program = _make_program( + [im.sym("t", nested)], + im.call(im.call("map_tuple")(g))(im.ref("t", nested)), + out_type=i_tuple_field, + ) + result = UnrollTreeMap.apply(program, uids=uids) + + expected = _make_program( + [im.sym("t", nested)], + im.let("_utm_0", "t")( + im.make_tuple( + im.call(g)(im.tuple_get(0, "_utm_0")), + im.call(g)(im.tuple_get(1, "_utm_0")), + ) + ), + out_type=i_tuple_field, ) assert result == expected From 80f32734331484724b5e5caddc2569d3765ef581 Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Thu, 28 May 2026 14:27:23 +0200 Subject: [PATCH 11/24] Rename --- src/gt4py/next/iterator/transforms/pass_manager.py | 12 ++++++------ .../{unroll_tree_map.py => unroll_tuple_maps.py} | 2 +- ..._unroll_tree_map.py => test_unroll_tuple_maps.py} | 10 +++++----- 3 files changed, 12 insertions(+), 12 deletions(-) rename src/gt4py/next/iterator/transforms/{unroll_tree_map.py => unroll_tuple_maps.py} (98%) rename tests/next_tests/unit_tests/iterator_tests/transforms_tests/{test_unroll_tree_map.py => test_unroll_tuple_maps.py} (93%) diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index 2ff883c9c8..1c98c448a8 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -24,7 +24,7 @@ prune_empty_concat_where, remove_broadcast, symbol_ref_utils, - unroll_tree_map, + unroll_tuple_maps, ) from gt4py.next.iterator.transforms.collapse_list_get import CollapseListGet from gt4py.next.iterator.transforms.collapse_tuple import CollapseTuple @@ -178,9 +178,9 @@ def apply_common_transforms( ) # domain inference does not support dynamic offsets yet ir = infer_domain_ops.InferDomainOps.apply(ir) ir = concat_where.canonicalize_domain_argument(ir) - ir = unroll_tree_map.UnrollTreeMap.apply(ir, uids=uids) + ir = unroll_tuple_maps.UnrollTupleMaps.apply(ir, uids=uids) - # After UnrollTreeMap, collapse `tuple_get(i, let(...)(make_tuple(...)))` patterns so that + # After UnrollTupleMaps, collapse `tuple_get(i, let(...)(make_tuple(...)))` patterns so that # domain inference does not encounter `as_fieldop` nodes inside dead tuple elements # (which would receive NEVER domain). Do multiple iterations for nested `let`s. for _ in range(10): @@ -197,7 +197,7 @@ def apply_common_transforms( if ir == collapsed: break else: - raise RuntimeError("'CollapseTuple' did not converge after `UnrollTreeMap`.") + raise RuntimeError("'CollapseTuple' did not converge after `UnrollTupleMaps`.") ir = infer_domain.infer_program( ir, @@ -312,7 +312,7 @@ def apply_fieldview_transforms( ir = infer_domain_ops.InferDomainOps.apply(ir) ir = concat_where.canonicalize_domain_argument(ir) - ir = unroll_tree_map.UnrollTreeMap.apply(ir, uids=uids) + ir = unroll_tuple_maps.UnrollTupleMaps.apply(ir, uids=uids) for _ in range(10): prev = ir ir = CollapseTuple.apply( @@ -327,7 +327,7 @@ def apply_fieldview_transforms( if ir == prev: break else: - raise RuntimeError("'CollapseTuple' did not converge after `UnrollTreeMap`.") + raise RuntimeError("'CollapseTuple' did not converge after `UnrollTupleMaps`.") ir = ConstantFolding.apply(ir) # type: ignore[assignment] # always an itir.Program diff --git a/src/gt4py/next/iterator/transforms/unroll_tree_map.py b/src/gt4py/next/iterator/transforms/unroll_tuple_maps.py similarity index 98% rename from src/gt4py/next/iterator/transforms/unroll_tree_map.py rename to src/gt4py/next/iterator/transforms/unroll_tuple_maps.py index 66df6e96c8..e32b09b99f 100644 --- a/src/gt4py/next/iterator/transforms/unroll_tree_map.py +++ b/src/gt4py/next/iterator/transforms/unroll_tuple_maps.py @@ -54,7 +54,7 @@ def _map_tuple_body(f: itir.Expr, tup_refs: list[str], tup_types: list[ts.TupleT @dataclasses.dataclass -class UnrollTreeMap(eve.NodeTranslator): +class UnrollTupleMaps(eve.NodeTranslator): """Unroll tuple-map ITIR builtins (`tree_map_tuple`, `map_tuple`) into `make_tuple`.""" PRESERVED_ANNEX_ATTRS = ("domain",) diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_tree_map.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_tuple_maps.py similarity index 93% rename from tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_tree_map.py rename to tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_tuple_maps.py index 1b97b35aa7..d9357fa785 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_tree_map.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_tuple_maps.py @@ -9,7 +9,7 @@ from gt4py.next import common, utils from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import ir_makers as im -from gt4py.next.iterator.transforms.unroll_tree_map import UnrollTreeMap +from gt4py.next.iterator.transforms.unroll_tuple_maps import UnrollTupleMaps from gt4py.next.type_system import type_specifications as ts IDim = common.Dimension("IDim") @@ -55,7 +55,7 @@ def test_tree_map_tuple_multi_arg(): ), out_type=i_tuple_field, ) - result = UnrollTreeMap.apply(program, uids=uids) + result = UnrollTupleMaps.apply(program, uids=uids) expected = _make_program( [im.sym("a", i_tuple_field), im.sym("b", i_tuple_field)], @@ -78,7 +78,7 @@ def test_tree_map_tuple_nested(): im.call(im.call("tree_map_tuple")(_neg()))(im.ref("t", nested)), out_type=nested, ) - result = UnrollTreeMap.apply(program, uids=uids) + result = UnrollTupleMaps.apply(program, uids=uids) expected = _make_program( [im.sym("t", nested)], @@ -106,7 +106,7 @@ def test_map_tuple_single_arg(): im.call(im.call("map_tuple")(_neg()))(im.ref("t", i_tuple_field)), out_type=i_tuple_field, ) - result = UnrollTreeMap.apply(program, uids=uids) + result = UnrollTupleMaps.apply(program, uids=uids) expected = _make_program( [im.sym("t", i_tuple_field)], @@ -130,7 +130,7 @@ def test_map_tuple_does_not_recurse(): im.call(im.call("map_tuple")(g))(im.ref("t", nested)), out_type=i_tuple_field, ) - result = UnrollTreeMap.apply(program, uids=uids) + result = UnrollTupleMaps.apply(program, uids=uids) expected = _make_program( [im.sym("t", nested)], From 454e15fb7e13274a366f85e005b7aa3562c3b892 Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Thu, 28 May 2026 16:29:54 +0200 Subject: [PATCH 12/24] Minor fix --- src/gt4py/next/iterator/type_system/type_synthesizer.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/gt4py/next/iterator/type_system/type_synthesizer.py b/src/gt4py/next/iterator/type_system/type_synthesizer.py index cbd17d3120..07838ee751 100644 --- a/src/gt4py/next/iterator/type_system/type_synthesizer.py +++ b/src/gt4py/next/iterator/type_system/type_synthesizer.py @@ -632,10 +632,11 @@ def applied_map( return applied_map -def _tuple_map_synthesizer(builtin_name: str, *, recursive: bool) -> TypeSynthesizer: +def _tuple_map_synthesizer( + builtin_name: str, *, recursive: bool +) -> Callable[..., TypeOrTypeSynthesizer]: """Shared implementation for `tree_map_tuple` (recursive) and `map_tuple` (top-level).""" - @type_synthesizer def factory(op: TypeSynthesizer) -> TypeSynthesizer: @type_synthesizer def applied_map( From 31b969a097479167b3c6f7aa89cffaf93a583973 Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Tue, 2 Jun 2026 12:04:13 +0200 Subject: [PATCH 13/24] Remove unnecessary CollapseTuple loop --- .../next/iterator/transforms/pass_manager.py | 38 +------------------ 1 file changed, 2 insertions(+), 36 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index 3a26061d88..47ba25e9b8 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -163,6 +163,7 @@ def apply_common_transforms( ir = inline_fundefs.InlineFundefs().visit(ir) ir = inline_fundefs.prune_unreferenced_fundefs(ir) + ir = unroll_tuple_maps.UnrollTupleMaps.apply(ir, uids=uids) ir = NormalizeShifts().visit(ir) # TODO(tehrengruber): Many iterator test contain lifts that need to be inlined, e.g. @@ -178,26 +179,6 @@ def apply_common_transforms( ) # domain inference does not support dynamic offsets yet ir = infer_domain_ops.InferDomainOps.apply(ir) ir = concat_where.canonicalize_domain_argument(ir) - ir = unroll_tuple_maps.UnrollTupleMaps.apply(ir, uids=uids) - - # After UnrollTupleMaps, collapse `tuple_get(i, let(...)(make_tuple(...)))` patterns so that - # domain inference does not encounter `as_fieldop` nodes inside dead tuple elements - # (which would receive NEVER domain). Do multiple iterations for nested `let`s. - for _ in range(10): - collapsed = ir - ir = CollapseTuple.apply( - ir, - enabled_transformations=( - CollapseTuple.Transformation.PROPAGATE_TUPLE_GET - | CollapseTuple.Transformation.COLLAPSE_TUPLE_GET_MAKE_TUPLE - ), - uids=uids, - offset_provider_type=offset_provider_type, - ) # type: ignore[assignment] # always an itir.Program - if ir == collapsed: - break - else: - raise RuntimeError("'CollapseTuple' did not converge after `UnrollTupleMaps`.") ir = infer_domain.infer_program( ir, @@ -301,6 +282,7 @@ def apply_fieldview_transforms( ir = inline_fundefs.InlineFundefs().visit(ir) ir = inline_fundefs.prune_unreferenced_fundefs(ir) + ir = unroll_tuple_maps.UnrollTupleMaps.apply(ir, uids=uids) # required for dead-code-elimination and `prune_empty_concat_where` pass ir = concat_where.expand_tuple_args(ir, offset_provider_type=offset_provider_type) # type: ignore[assignment] # always an itir.Program ir = dead_code_elimination.dead_code_elimination( @@ -312,22 +294,6 @@ def apply_fieldview_transforms( ir = infer_domain_ops.InferDomainOps.apply(ir) ir = concat_where.canonicalize_domain_argument(ir) - ir = unroll_tuple_maps.UnrollTupleMaps.apply(ir, uids=uids) - for _ in range(10): - prev = ir - ir = CollapseTuple.apply( - ir, - enabled_transformations=( - CollapseTuple.Transformation.PROPAGATE_TUPLE_GET - | CollapseTuple.Transformation.COLLAPSE_TUPLE_GET_MAKE_TUPLE - ), - uids=uids, - offset_provider_type=offset_provider_type, - ) # type: ignore[assignment] # always an itir.Program - if ir == prev: - break - else: - raise RuntimeError("'CollapseTuple' did not converge after `UnrollTupleMaps`.") ir = ConstantFolding.apply(ir) # type: ignore[assignment] # always an itir.Program From c7fc102b709a76e49c4e2a1922c3016607896cbf Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Tue, 2 Jun 2026 20:09:12 +0200 Subject: [PATCH 14/24] Reposition UnrollTupleMaps and simplify CollapseTuple usage --- .../next/iterator/transforms/pass_manager.py | 32 ++++- .../iterator/transforms/unroll_tuple_maps.py | 45 ++++-- .../test_unroll_tuple_maps.py | 136 +++++++++++++++--- 3 files changed, 177 insertions(+), 36 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index 47ba25e9b8..f56b9e6296 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -163,7 +163,6 @@ def apply_common_transforms( ir = inline_fundefs.InlineFundefs().visit(ir) ir = inline_fundefs.prune_unreferenced_fundefs(ir) - ir = unroll_tuple_maps.UnrollTupleMaps.apply(ir, uids=uids) ir = NormalizeShifts().visit(ir) # TODO(tehrengruber): Many iterator test contain lifts that need to be inlined, e.g. @@ -171,6 +170,23 @@ def apply_common_transforms( ir = inline_lifts.InlineLifts().visit(ir) ir = concat_where.expand_tuple_args(ir, offset_provider_type=offset_provider_type) # type: ignore[assignment] # always an itir.Program + # `UnrollTupleMaps` requires fully-inferred tuple types (relies on `reinfer` to see + # nested `TupleType` chains). `expand_tuple_args` runs full type inference, so this is + # the earliest safe position. + ir = unroll_tuple_maps.UnrollTupleMaps.apply(ir, uids=uids) + # `UnrollTupleMaps` collapses `tuple_get(i, make_tuple(...))` patterns on the fly + # for trivial arguments, so no additional `CollapseTuple` cleanup loop is needed. + # A single `CollapseTuple` pass still handles any residual patterns produced when + # arguments had to be let-bound (non-trivial sub-expressions). + ir = CollapseTuple.apply( + ir, + enabled_transformations=( + CollapseTuple.Transformation.PROPAGATE_TUPLE_GET + | CollapseTuple.Transformation.COLLAPSE_TUPLE_GET_MAKE_TUPLE + ), + uids=uids, + offset_provider_type=offset_provider_type, + ) # type: ignore[assignment] # always an itir.Program ir = dead_code_elimination.dead_code_elimination( ir, uids=uids, offset_provider_type=offset_provider_type ) # domain inference does not support dead-code @@ -282,9 +298,21 @@ def apply_fieldview_transforms( ir = inline_fundefs.InlineFundefs().visit(ir) ir = inline_fundefs.prune_unreferenced_fundefs(ir) - ir = unroll_tuple_maps.UnrollTupleMaps.apply(ir, uids=uids) # required for dead-code-elimination and `prune_empty_concat_where` pass ir = concat_where.expand_tuple_args(ir, offset_provider_type=offset_provider_type) # type: ignore[assignment] # always an itir.Program + # `UnrollTupleMaps` requires fully-inferred tuple types; `expand_tuple_args` runs full + # type inference, so this is the earliest safe position. + ir = unroll_tuple_maps.UnrollTupleMaps.apply(ir, uids=uids) + # See note in `apply_common_transforms` about why a single `CollapseTuple` pass suffices. + ir = CollapseTuple.apply( + ir, + enabled_transformations=( + CollapseTuple.Transformation.PROPAGATE_TUPLE_GET + | CollapseTuple.Transformation.COLLAPSE_TUPLE_GET_MAKE_TUPLE + ), + uids=uids, + offset_provider_type=offset_provider_type, + ) # type: ignore[assignment] # always an itir.Program ir = dead_code_elimination.dead_code_elimination( ir, offset_provider_type=offset_provider_type, uids=uids ) diff --git a/src/gt4py/next/iterator/transforms/unroll_tuple_maps.py b/src/gt4py/next/iterator/transforms/unroll_tuple_maps.py index e32b09b99f..9c937219b3 100644 --- a/src/gt4py/next/iterator/transforms/unroll_tuple_maps.py +++ b/src/gt4py/next/iterator/transforms/unroll_tuple_maps.py @@ -16,8 +16,18 @@ from gt4py.next.type_system import type_specifications as ts +def _collapsing_tuple_get(expr: itir.Expr, i: int) -> itir.Expr: + """Like `im.tuple_get`, but collapses immediately when `expr` is a `make_tuple` call. + + Note: argument order is `(expr, i)` to allow use as a `functools.reduce` reducer. + """ + if cpm.is_call_to(expr, "make_tuple"): + return expr.args[i] + return im.tuple_get(i, expr) + + def _tree_map_tuple_body( - f: itir.Expr, tup_refs: list[str], tup_types: list[ts.TupleType] + f: itir.Expr, tup_exprs: list[itir.Expr], tup_types: list[ts.TupleType] ) -> itir.Expr: """Recursively unroll `tree_map_tuple(f)(t1, ..., tN)` into `make_tuple` calls.""" @@ -29,21 +39,20 @@ def _tree_map_tuple_body( def mapper(*args): *_el_types, path = args return im.call(f)( - *( - functools.reduce(lambda expr, i: im.tuple_get(i, expr), path, im.ref(ref_name)) - for ref_name in tup_refs - ) + *(functools.reduce(_collapsing_tuple_get, path, tup_expr) for tup_expr in tup_exprs) ) return mapper(*tup_types) -def _map_tuple_body(f: itir.Expr, tup_refs: list[str], tup_types: list[ts.TupleType]) -> itir.Expr: +def _map_tuple_body( + f: itir.Expr, tup_exprs: list[itir.Expr], tup_types: list[ts.TupleType] +) -> itir.Expr: """Unroll `map_tuple(f)(t)` over top-level elements only (no recursion).""" - (ref_name,) = tup_refs + (tup_expr,) = tup_exprs (tup_type,) = tup_types return im.make_tuple( - *(im.call(f)(im.tuple_get(i, im.ref(ref_name))) for i in range(len(tup_type.types))) + *(im.call(f)(_collapsing_tuple_get(tup_expr, i)) for i in range(len(tup_type.types))) ) @@ -82,9 +91,23 @@ def visit_FunCall(self, node: itir.FunCall): assert isinstance(tup.type, ts.TupleType) tup_types.append(tup.type) - tup_refs = [next(self.uids["_utm"]) for _ in tup_args] - body = _UNROLLERS[builtin_name](f, tup_refs, tup_types) + # For trivial args (those that can be duplicated without cost or side effects), + # we substitute them directly into the body. This avoids leaving behind + # `tuple_get(i, make_tuple(...))` patterns that would otherwise require a + # separate cleanup pass (CollapseTuple). For non-trivial args we still + # introduce a `let` binding to avoid duplicating expensive sub-expressions. + substituted_exprs: list[itir.Expr] = [] + let_bindings: list[tuple[str, itir.Expr]] = [] + for tup in tup_args: + if isinstance(tup, (itir.SymRef, itir.Literal)) or cpm.is_call_to(tup, "make_tuple"): + substituted_exprs.append(tup) + else: + ref_name = next(self.uids["_utm"]) + let_bindings.append((ref_name, tup)) + substituted_exprs.append(im.ref(ref_name)) + + body = _UNROLLERS[builtin_name](f, substituted_exprs, tup_types) - result = im.let(*zip(tup_refs, tup_args))(body) + result = im.let(*let_bindings)(body) if let_bindings else body itir_inference.reinfer(result) return result diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_tuple_maps.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_tuple_maps.py index d9357fa785..593f519cc2 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_tuple_maps.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_tuple_maps.py @@ -59,11 +59,9 @@ def test_tree_map_tuple_multi_arg(): expected = _make_program( [im.sym("a", i_tuple_field), im.sym("b", i_tuple_field)], - im.let(("_utm_0", "a"), ("_utm_1", "b"))( - im.make_tuple( - im.call(_plus())(im.tuple_get(0, "_utm_0"), im.tuple_get(0, "_utm_1")), - im.call(_plus())(im.tuple_get(1, "_utm_0"), im.tuple_get(1, "_utm_1")), - ) + im.make_tuple( + im.call(_plus())(im.tuple_get(0, "a"), im.tuple_get(0, "b")), + im.call(_plus())(im.tuple_get(1, "a"), im.tuple_get(1, "b")), ), out_type=i_tuple_field, ) @@ -82,17 +80,15 @@ def test_tree_map_tuple_nested(): expected = _make_program( [im.sym("t", nested)], - im.let("_utm_0", "t")( + im.make_tuple( im.make_tuple( - im.make_tuple( - im.call(_neg())(im.tuple_get(0, im.tuple_get(0, "_utm_0"))), - im.call(_neg())(im.tuple_get(1, im.tuple_get(0, "_utm_0"))), - ), - im.make_tuple( - im.call(_neg())(im.tuple_get(0, im.tuple_get(1, "_utm_0"))), - im.call(_neg())(im.tuple_get(1, im.tuple_get(1, "_utm_0"))), - ), - ) + im.call(_neg())(im.tuple_get(0, im.tuple_get(0, "t"))), + im.call(_neg())(im.tuple_get(1, im.tuple_get(0, "t"))), + ), + im.make_tuple( + im.call(_neg())(im.tuple_get(0, im.tuple_get(1, "t"))), + im.call(_neg())(im.tuple_get(1, im.tuple_get(1, "t"))), + ), ), out_type=nested, ) @@ -110,11 +106,9 @@ def test_map_tuple_single_arg(): expected = _make_program( [im.sym("t", i_tuple_field)], - im.let("_utm_0", "t")( - im.make_tuple( - im.call(_neg())(im.tuple_get(0, "_utm_0")), - im.call(_neg())(im.tuple_get(1, "_utm_0")), - ) + im.make_tuple( + im.call(_neg())(im.tuple_get(0, "t")), + im.call(_neg())(im.tuple_get(1, "t")), ), out_type=i_tuple_field, ) @@ -134,10 +128,106 @@ def test_map_tuple_does_not_recurse(): expected = _make_program( [im.sym("t", nested)], - im.let("_utm_0", "t")( + im.make_tuple( + im.call(g)(im.tuple_get(0, "t")), + im.call(g)(im.tuple_get(1, "t")), + ), + out_type=i_tuple_field, + ) + assert result == expected + + +def test_make_tuple_arg_is_collapsed(): + """When the input tuple is a `make_tuple` literal, projection should collapse + directly to the element (no residual `tuple_get(make_tuple(...))`).""" + uids = utils.IDGeneratorPool() + program = _make_program( + [im.sym("a", i_field), im.sym("b", i_field)], + im.call(im.call("tree_map_tuple")(_neg()))(im.make_tuple("a", "b")), + out_type=i_tuple_field, + ) + result = UnrollTupleMaps.apply(program, uids=uids) + + expected = _make_program( + [im.sym("a", i_field), im.sym("b", i_field)], + im.make_tuple(im.call(_neg())("a"), im.call(_neg())("b")), + out_type=i_tuple_field, + ) + assert result == expected + + +def test_nested_make_tuple_arg_is_collapsed(): + """A nested `make_tuple` arg should be fully collapsed at every depth: each + `tuple_get(i, make_tuple(...))` along the recursion is folded directly.""" + uids = utils.IDGeneratorPool() + nested = ts.TupleType(types=[i_tuple_field, i_tuple_field]) + program = _make_program( + [ + im.sym("a", i_field), + im.sym("b", i_field), + im.sym("c", i_field), + im.sym("d", i_field), + ], + im.call(im.call("tree_map_tuple")(_neg()))( + im.make_tuple(im.make_tuple("a", "b"), im.make_tuple("c", "d")) + ), + out_type=nested, + ) + result = UnrollTupleMaps.apply(program, uids=uids) + + expected = _make_program( + [ + im.sym("a", i_field), + im.sym("b", i_field), + im.sym("c", i_field), + im.sym("d", i_field), + ], + im.make_tuple( + im.make_tuple(im.call(_neg())("a"), im.call(_neg())("b")), + im.make_tuple(im.call(_neg())("c"), im.call(_neg())("d")), + ), + out_type=nested, + ) + assert result == expected + + +def test_map_tuple_with_make_tuple_arg_is_collapsed(): + """The `make_tuple` short-circuit must also apply for the `map_tuple` builtin.""" + uids = utils.IDGeneratorPool() + program = _make_program( + [im.sym("a", i_field), im.sym("b", i_field)], + im.call(im.call("map_tuple")(_neg()))(im.make_tuple("a", "b")), + out_type=i_tuple_field, + ) + result = UnrollTupleMaps.apply(program, uids=uids) + + expected = _make_program( + [im.sym("a", i_field), im.sym("b", i_field)], + im.make_tuple(im.call(_neg())("a"), im.call(_neg())("b")), + out_type=i_tuple_field, + ) + assert result == expected + + +def test_non_trivial_arg_is_let_bound(): + """Non-trivial (potentially expensive) tuple expressions must still be + let-bound to avoid duplicating work across leaf projections.""" + uids = utils.IDGeneratorPool() + # `f(t)` is a non-trivial expression returning a tuple + f = im.lambda_("__t")(im.ref("__t", i_tuple_field)) + program = _make_program( + [im.sym("t", i_tuple_field)], + im.call(im.call("tree_map_tuple")(_neg()))(im.call(f)(im.ref("t", i_tuple_field))), + out_type=i_tuple_field, + ) + result = UnrollTupleMaps.apply(program, uids=uids) + + expected = _make_program( + [im.sym("t", i_tuple_field)], + im.let("_utm_0", im.call(f)("t"))( im.make_tuple( - im.call(g)(im.tuple_get(0, "_utm_0")), - im.call(g)(im.tuple_get(1, "_utm_0")), + im.call(_neg())(im.tuple_get(0, "_utm_0")), + im.call(_neg())(im.tuple_get(1, "_utm_0")), ) ), out_type=i_tuple_field, From b7f8ba9b85f8f52990f13f5b0d1ce6daec97196f Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Tue, 16 Jun 2026 14:49:07 +0200 Subject: [PATCH 15/24] Refactor tree_map unrolling --- src/gt4py/next/ffront/foast_to_gtir.py | 24 +- src/gt4py/next/ffront/lowering_utils.py | 133 +++++++++- .../next/iterator/transforms/pass_manager.py | 31 --- .../iterator/transforms/unroll_tuple_maps.py | 113 --------- src/gt4py/next/type_system/type_info.py | 34 +++ .../ffront_tests/test_gt4py_builtins.py | 39 +++ .../ffront_tests/test_foast_to_gtir.py | 7 +- .../test_unroll_tuple_maps.py | 235 ------------------ 8 files changed, 226 insertions(+), 390 deletions(-) delete mode 100644 src/gt4py/next/iterator/transforms/unroll_tuple_maps.py delete mode 100644 tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_tuple_maps.py diff --git a/src/gt4py/next/ffront/foast_to_gtir.py b/src/gt4py/next/ffront/foast_to_gtir.py index 73674a69da..35653a36f2 100644 --- a/src/gt4py/next/ffront/foast_to_gtir.py +++ b/src/gt4py/next/ffront/foast_to_gtir.py @@ -412,15 +412,27 @@ def _visit_where(self, node: foast.Call, **kwargs: Any) -> itir.FunCall: return self._lower_and_map("if_", *node.args) cond_ = self.visit(node.args[0]) - true_ = self.visit(node.args[1]) - false_ = self.visit(node.args[2]) cond_symref_name = f"__cond_{cond_.fingerprint()}" - result = im.tree_map_tuple( - im.lambda_("__a", "__b")( - im.op_as_fieldop("if_")(im.ref(cond_symref_name), im.ref("__a"), im.ref("__b")) + def create_if( + true_: itir.Expr, false_: itir.Expr, arg_types: tuple[ts.TypeSpec, ts.TypeSpec] + ) -> itir.FunCall: + # Lower each leaf via `_map` so that the per-leaf, type-dependent decision whether + # to wrap `if_` in `map_list` (and promote the condition with `make_const_list`) + # is taken based on the actual leaf types. A single, uniform leaf (e.g. via + # `tree_map_tuple`) can not do this, see `_map` for details. + return _map( + "if_", + (im.ref(cond_symref_name), true_, false_), + (node.args[0].type, *arg_types), ) - )(true_, false_) + + result = lowering_utils.process_elements( + create_if, + (self.visit(node.args[1]), self.visit(node.args[2])), + node.type, + arg_types=(node.args[1].type, node.args[2].type), + ) return im.let(cond_symref_name, cond_)(result) diff --git a/src/gt4py/next/ffront/lowering_utils.py b/src/gt4py/next/ffront/lowering_utils.py index 8c9434fa26..4c6ce18da4 100644 --- a/src/gt4py/next/ffront/lowering_utils.py +++ b/src/gt4py/next/ffront/lowering_utils.py @@ -6,12 +6,14 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause +import functools from collections.abc import Iterable from typing import Callable, Optional, TypeVar +from gt4py.next import utils from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.ir_utils import ir_makers as im -from gt4py.next.type_system import type_specifications as ts +from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im +from gt4py.next.type_system import type_info, type_specifications as ts # TODO(tehrengruber): The code quality of this function is poor. We should rewrite it. @@ -84,3 +86,130 @@ def _process_elements_impl( result = process_func(*_current_el_exprs) return result + + +def _collapsing_tuple_get(expr: itir.Expr, i: int) -> itir.Expr: + """Like `im.tuple_get`, but collapses immediately when `expr` is a `make_tuple` call. + + Note: argument order is `(expr, i)` to allow use as a `functools.reduce` reducer. + """ + if cpm.is_call_to(expr, "make_tuple"): + return expr.args[i] + return im.tuple_get(i, expr) + + +def _tree_map_tuple_body( + f: itir.Expr, tup_exprs: list[itir.Expr], tup_types: list[ts.TupleType] +) -> itir.Expr: + """Recursively unroll `tree_map_tuple(f)(t1, ..., tN)` into `make_tuple` calls.""" + + @utils.tree_map( + collection_type=ts.TupleType, + result_collection_constructor=lambda _, elts: im.make_tuple(*elts), + with_path_arg=True, + ) + def mapper(*args): + *_el_types, path = args + return im.call(f)( + *(functools.reduce(_collapsing_tuple_get, path, tup_expr) for tup_expr in tup_exprs) + ) + + return mapper(*tup_types) + + +def _map_tuple_body( + f: itir.Expr, tup_exprs: list[itir.Expr], tup_types: list[ts.TupleType] +) -> itir.Expr: + """Unroll `map_tuple(f)(t)` over top-level elements only (no recursion).""" + (tup_expr,) = tup_exprs + (tup_type,) = tup_types + return im.make_tuple( + *(im.call(f)(_collapsing_tuple_get(tup_expr, i)) for i in range(len(tup_type.types))) + ) + + +_UNROLLERS = { + "tree_map_tuple": _tree_map_tuple_body, + "map_tuple": _map_tuple_body, +} + + +def _unroll_tuple_map( + builtin_name: str, + f: itir.Expr, + tup_exprs: Iterable[itir.Expr], + tup_types: Iterable[ts.TypeSpec], + *, + uids: utils.IDGeneratorPool, +) -> itir.Expr: + tup_exprs = list(tup_exprs) + tup_types = list(tup_types) + for tup_type in tup_types: + if not isinstance(tup_type, ts.TupleType): + raise TypeError( + f"'{builtin_name}' requires all arguments to be tuples, got '{tup_type}'." + ) + if not type_info.tuple_structures_match(*tup_types): + raise TypeError( + f"'{builtin_name}' requires all arguments to share the same (nested) tuple " + f"structure, got {[str(t) for t in tup_types]}." + ) + + # For trivial args (those that can be duplicated without cost or side effects), + # we substitute them directly into the body. This avoids leaving behind + # `tuple_get(i, make_tuple(...))` patterns that would otherwise require a + # separate cleanup pass (`CollapseTuple`). For non-trivial args we still + # introduce a `let` binding to avoid duplicating expensive sub-expressions. + substituted_exprs: list[itir.Expr] = [] + let_bindings: list[tuple[str, itir.Expr]] = [] + for tup in tup_exprs: + if isinstance(tup, (itir.SymRef, itir.Literal)) or cpm.is_call_to(tup, "make_tuple"): + substituted_exprs.append(tup) + else: + ref_name = next(uids["__utm"]) + let_bindings.append((ref_name, tup)) + substituted_exprs.append(im.ref(ref_name)) + + body = _UNROLLERS[builtin_name](f, substituted_exprs, tup_types) + return im.let(*let_bindings)(body) if let_bindings else body + + +def unroll_tree_map_tuple( + f: itir.Expr, + tup_exprs: Iterable[itir.Expr], + tup_types: Iterable[ts.TypeSpec], + *, + uids: utils.IDGeneratorPool, +) -> itir.Expr: + """ + Lower ``tree_map_tuple(f)(t1, ..., tN)`` to explicit ``make_tuple`` calls, recursing into + nested tuples and applying ``f`` to each leaf. + + Args: + f: The function to apply at each leaf. + tup_exprs: The (already lowered) tuple argument expressions. + tup_types: The type of each argument in ``tup_exprs``; all must be ``TupleType`` and + share the same (nested) structure. + uids: Used to generate fresh names for `let`-bindings of non-trivial arguments. + """ + return _unroll_tuple_map("tree_map_tuple", f, tup_exprs, tup_types, uids=uids) + + +def unroll_map_tuple( + f: itir.Expr, + tup_expr: itir.Expr, + tup_type: ts.TypeSpec, + *, + uids: utils.IDGeneratorPool, +) -> itir.Expr: + """ + Lower ``map_tuple(f)(t)`` to an explicit ``make_tuple`` call, applying ``f`` to each + top-level element only (no recursion). + + Args: + f: The function to apply to each top-level element. + tup_expr: The (already lowered) tuple argument expression. + tup_type: The type of ``tup_expr``; must be a ``TupleType``. + uids: Used to generate a fresh name for a `let`-binding of a non-trivial argument. + """ + return _unroll_tuple_map("map_tuple", f, (tup_expr,), (tup_type,), uids=uids) diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index f56b9e6296..2baf3d0cc2 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -24,7 +24,6 @@ prune_empty_concat_where, remove_broadcast, symbol_ref_utils, - unroll_tuple_maps, ) from gt4py.next.iterator.transforms.collapse_list_get import CollapseListGet from gt4py.next.iterator.transforms.collapse_tuple import CollapseTuple @@ -170,23 +169,6 @@ def apply_common_transforms( ir = inline_lifts.InlineLifts().visit(ir) ir = concat_where.expand_tuple_args(ir, offset_provider_type=offset_provider_type) # type: ignore[assignment] # always an itir.Program - # `UnrollTupleMaps` requires fully-inferred tuple types (relies on `reinfer` to see - # nested `TupleType` chains). `expand_tuple_args` runs full type inference, so this is - # the earliest safe position. - ir = unroll_tuple_maps.UnrollTupleMaps.apply(ir, uids=uids) - # `UnrollTupleMaps` collapses `tuple_get(i, make_tuple(...))` patterns on the fly - # for trivial arguments, so no additional `CollapseTuple` cleanup loop is needed. - # A single `CollapseTuple` pass still handles any residual patterns produced when - # arguments had to be let-bound (non-trivial sub-expressions). - ir = CollapseTuple.apply( - ir, - enabled_transformations=( - CollapseTuple.Transformation.PROPAGATE_TUPLE_GET - | CollapseTuple.Transformation.COLLAPSE_TUPLE_GET_MAKE_TUPLE - ), - uids=uids, - offset_provider_type=offset_provider_type, - ) # type: ignore[assignment] # always an itir.Program ir = dead_code_elimination.dead_code_elimination( ir, uids=uids, offset_provider_type=offset_provider_type ) # domain inference does not support dead-code @@ -300,19 +282,6 @@ def apply_fieldview_transforms( ir = inline_fundefs.prune_unreferenced_fundefs(ir) # required for dead-code-elimination and `prune_empty_concat_where` pass ir = concat_where.expand_tuple_args(ir, offset_provider_type=offset_provider_type) # type: ignore[assignment] # always an itir.Program - # `UnrollTupleMaps` requires fully-inferred tuple types; `expand_tuple_args` runs full - # type inference, so this is the earliest safe position. - ir = unroll_tuple_maps.UnrollTupleMaps.apply(ir, uids=uids) - # See note in `apply_common_transforms` about why a single `CollapseTuple` pass suffices. - ir = CollapseTuple.apply( - ir, - enabled_transformations=( - CollapseTuple.Transformation.PROPAGATE_TUPLE_GET - | CollapseTuple.Transformation.COLLAPSE_TUPLE_GET_MAKE_TUPLE - ), - uids=uids, - offset_provider_type=offset_provider_type, - ) # type: ignore[assignment] # always an itir.Program ir = dead_code_elimination.dead_code_elimination( ir, offset_provider_type=offset_provider_type, uids=uids ) diff --git a/src/gt4py/next/iterator/transforms/unroll_tuple_maps.py b/src/gt4py/next/iterator/transforms/unroll_tuple_maps.py deleted file mode 100644 index 9c937219b3..0000000000 --- a/src/gt4py/next/iterator/transforms/unroll_tuple_maps.py +++ /dev/null @@ -1,113 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. -# SPDX-License-Identifier: BSD-3-Clause -import dataclasses -import functools - -from gt4py import eve -from gt4py.next import utils -from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im -from gt4py.next.iterator.type_system import inference as itir_inference -from gt4py.next.type_system import type_specifications as ts - - -def _collapsing_tuple_get(expr: itir.Expr, i: int) -> itir.Expr: - """Like `im.tuple_get`, but collapses immediately when `expr` is a `make_tuple` call. - - Note: argument order is `(expr, i)` to allow use as a `functools.reduce` reducer. - """ - if cpm.is_call_to(expr, "make_tuple"): - return expr.args[i] - return im.tuple_get(i, expr) - - -def _tree_map_tuple_body( - f: itir.Expr, tup_exprs: list[itir.Expr], tup_types: list[ts.TupleType] -) -> itir.Expr: - """Recursively unroll `tree_map_tuple(f)(t1, ..., tN)` into `make_tuple` calls.""" - - @utils.tree_map( - collection_type=ts.TupleType, - result_collection_constructor=lambda _, elts: im.make_tuple(*elts), - with_path_arg=True, - ) - def mapper(*args): - *_el_types, path = args - return im.call(f)( - *(functools.reduce(_collapsing_tuple_get, path, tup_expr) for tup_expr in tup_exprs) - ) - - return mapper(*tup_types) - - -def _map_tuple_body( - f: itir.Expr, tup_exprs: list[itir.Expr], tup_types: list[ts.TupleType] -) -> itir.Expr: - """Unroll `map_tuple(f)(t)` over top-level elements only (no recursion).""" - (tup_expr,) = tup_exprs - (tup_type,) = tup_types - return im.make_tuple( - *(im.call(f)(_collapsing_tuple_get(tup_expr, i)) for i in range(len(tup_type.types))) - ) - - -_UNROLLERS = { - "tree_map_tuple": _tree_map_tuple_body, - "map_tuple": _map_tuple_body, -} - - -@dataclasses.dataclass -class UnrollTupleMaps(eve.NodeTranslator): - """Unroll tuple-map ITIR builtins (`tree_map_tuple`, `map_tuple`) into `make_tuple`.""" - - PRESERVED_ANNEX_ATTRS = ("domain",) - - uids: utils.IDGeneratorPool - - @classmethod - def apply(cls, program: itir.Program, *, uids: utils.IDGeneratorPool): - return cls(uids=uids).visit(program) - - def visit_FunCall(self, node: itir.FunCall): - node = self.generic_visit(node) - - builtin_name = next((name for name in _UNROLLERS if cpm.is_call_to(node.fun, name)), None) - if builtin_name is None: - return node - - assert isinstance(node.fun, itir.FunCall) - f = node.fun.args[0] - tup_args = node.args - - tup_types: list[ts.TupleType] = [] - for tup in tup_args: - itir_inference.reinfer(tup) - assert isinstance(tup.type, ts.TupleType) - tup_types.append(tup.type) - - # For trivial args (those that can be duplicated without cost or side effects), - # we substitute them directly into the body. This avoids leaving behind - # `tuple_get(i, make_tuple(...))` patterns that would otherwise require a - # separate cleanup pass (CollapseTuple). For non-trivial args we still - # introduce a `let` binding to avoid duplicating expensive sub-expressions. - substituted_exprs: list[itir.Expr] = [] - let_bindings: list[tuple[str, itir.Expr]] = [] - for tup in tup_args: - if isinstance(tup, (itir.SymRef, itir.Literal)) or cpm.is_call_to(tup, "make_tuple"): - substituted_exprs.append(tup) - else: - ref_name = next(self.uids["_utm"]) - let_bindings.append((ref_name, tup)) - substituted_exprs.append(im.ref(ref_name)) - - body = _UNROLLERS[builtin_name](f, substituted_exprs, tup_types) - - result = im.let(*let_bindings)(body) if let_bindings else body - itir_inference.reinfer(result) - return result diff --git a/src/gt4py/next/type_system/type_info.py b/src/gt4py/next/type_system/type_info.py index eb70d15947..edd40d4735 100644 --- a/src/gt4py/next/type_system/type_info.py +++ b/src/gt4py/next/type_system/type_info.py @@ -442,6 +442,40 @@ def contains_local_field(type_: ts.TypeSpec) -> bool: ) +def tuple_structures_match(*types: ts.TypeSpec) -> bool: + """ + Return if all `types` share the same (nested) tuple structure. + + Only the tuple skeleton is compared; the leaf types are not required to be equal. + + Examples: + --------- + >>> i = ts.ScalarType(kind=ts.ScalarKind.INT32) + >>> f = ts.ScalarType(kind=ts.ScalarKind.FLOAT64) + >>> tuple_structures_match(ts.TupleType(types=[i, i]), ts.TupleType(types=[f, f])) + True + >>> tuple_structures_match(ts.TupleType(types=[i, i]), ts.TupleType(types=[i])) + False + >>> tuple_structures_match( + ... ts.TupleType(types=[i, ts.TupleType(types=[i])]), + ... ts.TupleType(types=[i, i]), + ... ) + False + """ + if not types: + return True + first, *rest = types + for other in rest: + if isinstance(first, ts.TupleType) != isinstance(other, ts.TupleType): + return False + if isinstance(first, ts.TupleType) and isinstance(other, ts.TupleType): + if len(first.types) != len(other.types): + return False + if not all(tuple_structures_match(f, o) for f, o in zip(first.types, other.types)): + return False + return True + + # TODO(tehrengruber): This function has specializations on Iterator types, which are not part of # the general / shared type system. This functionality should be moved to the iterator-only # type system, but we need some sort of multiple dispatch for that. diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py index 22f69e8a6e..035d30ef55 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py @@ -275,6 +275,45 @@ def testee(mask: cases.VBoolField, inp: cases.EField) -> cases.VField: ) +@pytest.mark.uses_unstructured_shift +def test_reduction_expression_with_where_tuple_local_condition(unstructured_case): + # Regression test: a tuple-`where` whose condition is itself a per-neighbour (local) field. + # The condition lowers to a list-typed predicate, so each tuple leaf must be lowered as + # `map_list(if_)` (with the condition promoted via `make_const_list`), exactly like the + # non-tuple path does. A uniform leaf (e.g. `op_as_fieldop("if_")`) would produce invalid IR. + @gtx.field_operator + def testee(a: cases.EField, b: cases.EField) -> cases.VField: + cond = a(V2E) > b(V2E) # per-neighbour (local) bool field + t = where(cond, (a(V2E), b(V2E)), (b(V2E), a(V2E))) + return neighbor_sum(t[0] + t[1], axis=V2EDim) + + v2e_table = unstructured_case.offset_provider["V2E"].asnumpy() + + a = cases.allocate(unstructured_case, testee, "a")() + b = cases.allocate(unstructured_case, testee, "b")() + out = cases.allocate(unstructured_case, testee, cases.RETURN)() + + a_nbh = a.asnumpy()[v2e_table] + b_nbh = b.asnumpy()[v2e_table] + cond_nbh = a_nbh > b_nbh + t0 = np.where(cond_nbh, a_nbh, b_nbh) + t1 = np.where(cond_nbh, b_nbh, a_nbh) + + cases.verify( + unstructured_case, + testee, + a, + b, + out=out, + ref=np.sum( + t0 + t1, + axis=1, + initial=0, + where=v2e_table != common._DEFAULT_SKIP_VALUE, + ), + ) + + @pytest.mark.uses_unstructured_shift def test_reduction_expression_with_where_and_scalar(unstructured_case): @gtx.field_operator diff --git a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py index 4c6bf4d34a..d27ffc6a43 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py @@ -207,9 +207,10 @@ def foo( lowered ) # we generate a let for the condition which is removed by inlining for easier testing - reference = im.tree_map_tuple( - im.lambda_("__a", "__b")(im.op_as_fieldop("if_")("a", im.ref("__a"), im.ref("__b"))) - )("b", "c") + reference = im.make_tuple( + im.op_as_fieldop("if_")("a", im.tuple_get(0, "b"), im.tuple_get(0, "c")), + im.op_as_fieldop("if_")("a", im.tuple_get(1, "b"), im.tuple_get(1, "c")), + ) assert lowered_inlined.expr == reference diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_tuple_maps.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_tuple_maps.py deleted file mode 100644 index 593f519cc2..0000000000 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_tuple_maps.py +++ /dev/null @@ -1,235 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. -# SPDX-License-Identifier: BSD-3-Clause - -from gt4py.next import common, utils -from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.ir_utils import ir_makers as im -from gt4py.next.iterator.transforms.unroll_tuple_maps import UnrollTupleMaps -from gt4py.next.type_system import type_specifications as ts - -IDim = common.Dimension("IDim") -T = ts.ScalarType(kind=ts.ScalarKind.FLOAT64) -i_field = ts.FieldType(dims=[IDim], dtype=T) -i_tuple_field = ts.TupleType(types=[i_field, i_field]) - -i_domain = im.call("cartesian_domain")(im.named_range(itir.AxisLiteral(value="IDim"), 0, 1)) - - -def _make_program( - params: list[itir.Sym], expr: itir.Expr, out_type: ts.TypeSpec = i_field -) -> itir.Program: - return itir.Program( - id="testee", - function_definitions=[], - params=[*params, im.sym("out", out_type)], - declarations=[], - body=[ - itir.SetAt( - expr=expr, - domain=i_domain, - target=im.ref("out", out_type), - ) - ], - ) - - -def _neg(): - return im.lambda_("__a")(im.op_as_fieldop("neg")("__a")) - - -def _plus(): - return im.lambda_("__a", "__b")(im.op_as_fieldop("plus")("__a", "__b")) - - -def test_tree_map_tuple_multi_arg(): - uids = utils.IDGeneratorPool() - program = _make_program( - [im.sym("a", i_tuple_field), im.sym("b", i_tuple_field)], - im.call(im.call("tree_map_tuple")(_plus()))( - im.ref("a", i_tuple_field), im.ref("b", i_tuple_field) - ), - out_type=i_tuple_field, - ) - result = UnrollTupleMaps.apply(program, uids=uids) - - expected = _make_program( - [im.sym("a", i_tuple_field), im.sym("b", i_tuple_field)], - im.make_tuple( - im.call(_plus())(im.tuple_get(0, "a"), im.tuple_get(0, "b")), - im.call(_plus())(im.tuple_get(1, "a"), im.tuple_get(1, "b")), - ), - out_type=i_tuple_field, - ) - assert result == expected - - -def test_tree_map_tuple_nested(): - uids = utils.IDGeneratorPool() - nested = ts.TupleType(types=[i_tuple_field, i_tuple_field]) - program = _make_program( - [im.sym("t", nested)], - im.call(im.call("tree_map_tuple")(_neg()))(im.ref("t", nested)), - out_type=nested, - ) - result = UnrollTupleMaps.apply(program, uids=uids) - - expected = _make_program( - [im.sym("t", nested)], - im.make_tuple( - im.make_tuple( - im.call(_neg())(im.tuple_get(0, im.tuple_get(0, "t"))), - im.call(_neg())(im.tuple_get(1, im.tuple_get(0, "t"))), - ), - im.make_tuple( - im.call(_neg())(im.tuple_get(0, im.tuple_get(1, "t"))), - im.call(_neg())(im.tuple_get(1, im.tuple_get(1, "t"))), - ), - ), - out_type=nested, - ) - assert result == expected - - -def test_map_tuple_single_arg(): - uids = utils.IDGeneratorPool() - program = _make_program( - [im.sym("t", i_tuple_field)], - im.call(im.call("map_tuple")(_neg()))(im.ref("t", i_tuple_field)), - out_type=i_tuple_field, - ) - result = UnrollTupleMaps.apply(program, uids=uids) - - expected = _make_program( - [im.sym("t", i_tuple_field)], - im.make_tuple( - im.call(_neg())(im.tuple_get(0, "t")), - im.call(_neg())(im.tuple_get(1, "t")), - ), - out_type=i_tuple_field, - ) - assert result == expected - - -def test_map_tuple_does_not_recurse(): - uids = utils.IDGeneratorPool() - nested = ts.TupleType(types=[i_tuple_field, i_tuple_field]) - g = im.lambda_("__p")(im.op_as_fieldop("plus")(im.tuple_get(0, "__p"), im.tuple_get(1, "__p"))) - program = _make_program( - [im.sym("t", nested)], - im.call(im.call("map_tuple")(g))(im.ref("t", nested)), - out_type=i_tuple_field, - ) - result = UnrollTupleMaps.apply(program, uids=uids) - - expected = _make_program( - [im.sym("t", nested)], - im.make_tuple( - im.call(g)(im.tuple_get(0, "t")), - im.call(g)(im.tuple_get(1, "t")), - ), - out_type=i_tuple_field, - ) - assert result == expected - - -def test_make_tuple_arg_is_collapsed(): - """When the input tuple is a `make_tuple` literal, projection should collapse - directly to the element (no residual `tuple_get(make_tuple(...))`).""" - uids = utils.IDGeneratorPool() - program = _make_program( - [im.sym("a", i_field), im.sym("b", i_field)], - im.call(im.call("tree_map_tuple")(_neg()))(im.make_tuple("a", "b")), - out_type=i_tuple_field, - ) - result = UnrollTupleMaps.apply(program, uids=uids) - - expected = _make_program( - [im.sym("a", i_field), im.sym("b", i_field)], - im.make_tuple(im.call(_neg())("a"), im.call(_neg())("b")), - out_type=i_tuple_field, - ) - assert result == expected - - -def test_nested_make_tuple_arg_is_collapsed(): - """A nested `make_tuple` arg should be fully collapsed at every depth: each - `tuple_get(i, make_tuple(...))` along the recursion is folded directly.""" - uids = utils.IDGeneratorPool() - nested = ts.TupleType(types=[i_tuple_field, i_tuple_field]) - program = _make_program( - [ - im.sym("a", i_field), - im.sym("b", i_field), - im.sym("c", i_field), - im.sym("d", i_field), - ], - im.call(im.call("tree_map_tuple")(_neg()))( - im.make_tuple(im.make_tuple("a", "b"), im.make_tuple("c", "d")) - ), - out_type=nested, - ) - result = UnrollTupleMaps.apply(program, uids=uids) - - expected = _make_program( - [ - im.sym("a", i_field), - im.sym("b", i_field), - im.sym("c", i_field), - im.sym("d", i_field), - ], - im.make_tuple( - im.make_tuple(im.call(_neg())("a"), im.call(_neg())("b")), - im.make_tuple(im.call(_neg())("c"), im.call(_neg())("d")), - ), - out_type=nested, - ) - assert result == expected - - -def test_map_tuple_with_make_tuple_arg_is_collapsed(): - """The `make_tuple` short-circuit must also apply for the `map_tuple` builtin.""" - uids = utils.IDGeneratorPool() - program = _make_program( - [im.sym("a", i_field), im.sym("b", i_field)], - im.call(im.call("map_tuple")(_neg()))(im.make_tuple("a", "b")), - out_type=i_tuple_field, - ) - result = UnrollTupleMaps.apply(program, uids=uids) - - expected = _make_program( - [im.sym("a", i_field), im.sym("b", i_field)], - im.make_tuple(im.call(_neg())("a"), im.call(_neg())("b")), - out_type=i_tuple_field, - ) - assert result == expected - - -def test_non_trivial_arg_is_let_bound(): - """Non-trivial (potentially expensive) tuple expressions must still be - let-bound to avoid duplicating work across leaf projections.""" - uids = utils.IDGeneratorPool() - # `f(t)` is a non-trivial expression returning a tuple - f = im.lambda_("__t")(im.ref("__t", i_tuple_field)) - program = _make_program( - [im.sym("t", i_tuple_field)], - im.call(im.call("tree_map_tuple")(_neg()))(im.call(f)(im.ref("t", i_tuple_field))), - out_type=i_tuple_field, - ) - result = UnrollTupleMaps.apply(program, uids=uids) - - expected = _make_program( - [im.sym("t", i_tuple_field)], - im.let("_utm_0", im.call(f)("t"))( - im.make_tuple( - im.call(_neg())(im.tuple_get(0, "_utm_0")), - im.call(_neg())(im.tuple_get(1, "_utm_0")), - ) - ), - out_type=i_tuple_field, - ) - assert result == expected From 3d38868367b4fa88b8aeba07dc403af8893960a9 Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Tue, 16 Jun 2026 15:04:48 +0200 Subject: [PATCH 16/24] Cleanup --- src/gt4py/next/ffront/foast_to_gtir.py | 4 ---- src/gt4py/next/ffront/lowering_utils.py | 15 +++++++++------ .../next/iterator/transforms/pass_manager.py | 1 - 3 files changed, 9 insertions(+), 11 deletions(-) diff --git a/src/gt4py/next/ffront/foast_to_gtir.py b/src/gt4py/next/ffront/foast_to_gtir.py index 35653a36f2..202b28f22d 100644 --- a/src/gt4py/next/ffront/foast_to_gtir.py +++ b/src/gt4py/next/ffront/foast_to_gtir.py @@ -417,10 +417,6 @@ def _visit_where(self, node: foast.Call, **kwargs: Any) -> itir.FunCall: def create_if( true_: itir.Expr, false_: itir.Expr, arg_types: tuple[ts.TypeSpec, ts.TypeSpec] ) -> itir.FunCall: - # Lower each leaf via `_map` so that the per-leaf, type-dependent decision whether - # to wrap `if_` in `map_list` (and promote the condition with `make_const_list`) - # is taken based on the actual leaf types. A single, uniform leaf (e.g. via - # `tree_map_tuple`) can not do this, see `_map` for details. return _map( "if_", (im.ref(cond_symref_name), true_, false_), diff --git a/src/gt4py/next/ffront/lowering_utils.py b/src/gt4py/next/ffront/lowering_utils.py index 4c6ce18da4..727df201d6 100644 --- a/src/gt4py/next/ffront/lowering_utils.py +++ b/src/gt4py/next/ffront/lowering_utils.py @@ -108,8 +108,9 @@ def _tree_map_tuple_body( result_collection_constructor=lambda _, elts: im.make_tuple(*elts), with_path_arg=True, ) - def mapper(*args): + def mapper(*args: ts.TypeSpec | tuple[int, ...]) -> itir.Expr: *_el_types, path = args + assert isinstance(path, tuple), "Expected path to be tuple[int, ...]" return im.call(f)( *(functools.reduce(_collapsing_tuple_get, path, tup_expr) for tup_expr in tup_exprs) ) @@ -143,16 +144,18 @@ def _unroll_tuple_map( uids: utils.IDGeneratorPool, ) -> itir.Expr: tup_exprs = list(tup_exprs) - tup_types = list(tup_types) - for tup_type in tup_types: + tup_types_list = list(tup_types) + for tup_type in tup_types_list: if not isinstance(tup_type, ts.TupleType): raise TypeError( f"'{builtin_name}' requires all arguments to be tuples, got '{tup_type}'." ) - if not type_info.tuple_structures_match(*tup_types): + tup_types_validated: list[ts.TupleType] = tup_types_list # type: ignore[assignment] + + if not type_info.tuple_structures_match(*tup_types_validated): raise TypeError( f"'{builtin_name}' requires all arguments to share the same (nested) tuple " - f"structure, got {[str(t) for t in tup_types]}." + f"structure, got {[str(t) for t in tup_types_validated]}." ) # For trivial args (those that can be duplicated without cost or side effects), @@ -170,7 +173,7 @@ def _unroll_tuple_map( let_bindings.append((ref_name, tup)) substituted_exprs.append(im.ref(ref_name)) - body = _UNROLLERS[builtin_name](f, substituted_exprs, tup_types) + body = _UNROLLERS[builtin_name](f, substituted_exprs, tup_types_validated) return im.let(*let_bindings)(body) if let_bindings else body diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index 2baf3d0cc2..8bea65ed29 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -291,7 +291,6 @@ def apply_fieldview_transforms( ir = infer_domain_ops.InferDomainOps.apply(ir) ir = concat_where.canonicalize_domain_argument(ir) - ir = ConstantFolding.apply(ir) # type: ignore[assignment] # always an itir.Program ir = infer_domain.infer_program( From 7d5c86c048b91efaf30f2480a7283851075f1eb6 Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Wed, 17 Jun 2026 15:16:48 +0200 Subject: [PATCH 17/24] Revert "Cleanup" This reverts commit 3d38868367b4fa88b8aeba07dc403af8893960a9. --- src/gt4py/next/ffront/foast_to_gtir.py | 4 ++++ src/gt4py/next/ffront/lowering_utils.py | 15 ++++++--------- .../next/iterator/transforms/pass_manager.py | 1 + 3 files changed, 11 insertions(+), 9 deletions(-) diff --git a/src/gt4py/next/ffront/foast_to_gtir.py b/src/gt4py/next/ffront/foast_to_gtir.py index 202b28f22d..35653a36f2 100644 --- a/src/gt4py/next/ffront/foast_to_gtir.py +++ b/src/gt4py/next/ffront/foast_to_gtir.py @@ -417,6 +417,10 @@ def _visit_where(self, node: foast.Call, **kwargs: Any) -> itir.FunCall: def create_if( true_: itir.Expr, false_: itir.Expr, arg_types: tuple[ts.TypeSpec, ts.TypeSpec] ) -> itir.FunCall: + # Lower each leaf via `_map` so that the per-leaf, type-dependent decision whether + # to wrap `if_` in `map_list` (and promote the condition with `make_const_list`) + # is taken based on the actual leaf types. A single, uniform leaf (e.g. via + # `tree_map_tuple`) can not do this, see `_map` for details. return _map( "if_", (im.ref(cond_symref_name), true_, false_), diff --git a/src/gt4py/next/ffront/lowering_utils.py b/src/gt4py/next/ffront/lowering_utils.py index 727df201d6..4c6ce18da4 100644 --- a/src/gt4py/next/ffront/lowering_utils.py +++ b/src/gt4py/next/ffront/lowering_utils.py @@ -108,9 +108,8 @@ def _tree_map_tuple_body( result_collection_constructor=lambda _, elts: im.make_tuple(*elts), with_path_arg=True, ) - def mapper(*args: ts.TypeSpec | tuple[int, ...]) -> itir.Expr: + def mapper(*args): *_el_types, path = args - assert isinstance(path, tuple), "Expected path to be tuple[int, ...]" return im.call(f)( *(functools.reduce(_collapsing_tuple_get, path, tup_expr) for tup_expr in tup_exprs) ) @@ -144,18 +143,16 @@ def _unroll_tuple_map( uids: utils.IDGeneratorPool, ) -> itir.Expr: tup_exprs = list(tup_exprs) - tup_types_list = list(tup_types) - for tup_type in tup_types_list: + tup_types = list(tup_types) + for tup_type in tup_types: if not isinstance(tup_type, ts.TupleType): raise TypeError( f"'{builtin_name}' requires all arguments to be tuples, got '{tup_type}'." ) - tup_types_validated: list[ts.TupleType] = tup_types_list # type: ignore[assignment] - - if not type_info.tuple_structures_match(*tup_types_validated): + if not type_info.tuple_structures_match(*tup_types): raise TypeError( f"'{builtin_name}' requires all arguments to share the same (nested) tuple " - f"structure, got {[str(t) for t in tup_types_validated]}." + f"structure, got {[str(t) for t in tup_types]}." ) # For trivial args (those that can be duplicated without cost or side effects), @@ -173,7 +170,7 @@ def _unroll_tuple_map( let_bindings.append((ref_name, tup)) substituted_exprs.append(im.ref(ref_name)) - body = _UNROLLERS[builtin_name](f, substituted_exprs, tup_types_validated) + body = _UNROLLERS[builtin_name](f, substituted_exprs, tup_types) return im.let(*let_bindings)(body) if let_bindings else body diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index 8bea65ed29..2baf3d0cc2 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -291,6 +291,7 @@ def apply_fieldview_transforms( ir = infer_domain_ops.InferDomainOps.apply(ir) ir = concat_where.canonicalize_domain_argument(ir) + ir = ConstantFolding.apply(ir) # type: ignore[assignment] # always an itir.Program ir = infer_domain.infer_program( From 747f36ed0997b7d46c77fc586e2540b6a88e3300 Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Wed, 17 Jun 2026 15:16:48 +0200 Subject: [PATCH 18/24] Revert "Refactor tree_map unrolling" This reverts commit b7f8ba9b85f8f52990f13f5b0d1ce6daec97196f. --- src/gt4py/next/ffront/foast_to_gtir.py | 24 +- src/gt4py/next/ffront/lowering_utils.py | 133 +--------- .../next/iterator/transforms/pass_manager.py | 31 +++ .../iterator/transforms/unroll_tuple_maps.py | 113 +++++++++ src/gt4py/next/type_system/type_info.py | 34 --- .../ffront_tests/test_gt4py_builtins.py | 39 --- .../ffront_tests/test_foast_to_gtir.py | 7 +- .../test_unroll_tuple_maps.py | 235 ++++++++++++++++++ 8 files changed, 390 insertions(+), 226 deletions(-) create mode 100644 src/gt4py/next/iterator/transforms/unroll_tuple_maps.py create mode 100644 tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_tuple_maps.py diff --git a/src/gt4py/next/ffront/foast_to_gtir.py b/src/gt4py/next/ffront/foast_to_gtir.py index 35653a36f2..73674a69da 100644 --- a/src/gt4py/next/ffront/foast_to_gtir.py +++ b/src/gt4py/next/ffront/foast_to_gtir.py @@ -412,27 +412,15 @@ def _visit_where(self, node: foast.Call, **kwargs: Any) -> itir.FunCall: return self._lower_and_map("if_", *node.args) cond_ = self.visit(node.args[0]) + true_ = self.visit(node.args[1]) + false_ = self.visit(node.args[2]) cond_symref_name = f"__cond_{cond_.fingerprint()}" - def create_if( - true_: itir.Expr, false_: itir.Expr, arg_types: tuple[ts.TypeSpec, ts.TypeSpec] - ) -> itir.FunCall: - # Lower each leaf via `_map` so that the per-leaf, type-dependent decision whether - # to wrap `if_` in `map_list` (and promote the condition with `make_const_list`) - # is taken based on the actual leaf types. A single, uniform leaf (e.g. via - # `tree_map_tuple`) can not do this, see `_map` for details. - return _map( - "if_", - (im.ref(cond_symref_name), true_, false_), - (node.args[0].type, *arg_types), + result = im.tree_map_tuple( + im.lambda_("__a", "__b")( + im.op_as_fieldop("if_")(im.ref(cond_symref_name), im.ref("__a"), im.ref("__b")) ) - - result = lowering_utils.process_elements( - create_if, - (self.visit(node.args[1]), self.visit(node.args[2])), - node.type, - arg_types=(node.args[1].type, node.args[2].type), - ) + )(true_, false_) return im.let(cond_symref_name, cond_)(result) diff --git a/src/gt4py/next/ffront/lowering_utils.py b/src/gt4py/next/ffront/lowering_utils.py index 4c6ce18da4..8c9434fa26 100644 --- a/src/gt4py/next/ffront/lowering_utils.py +++ b/src/gt4py/next/ffront/lowering_utils.py @@ -6,14 +6,12 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -import functools from collections.abc import Iterable from typing import Callable, Optional, TypeVar -from gt4py.next import utils from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im -from gt4py.next.type_system import type_info, type_specifications as ts +from gt4py.next.iterator.ir_utils import ir_makers as im +from gt4py.next.type_system import type_specifications as ts # TODO(tehrengruber): The code quality of this function is poor. We should rewrite it. @@ -86,130 +84,3 @@ def _process_elements_impl( result = process_func(*_current_el_exprs) return result - - -def _collapsing_tuple_get(expr: itir.Expr, i: int) -> itir.Expr: - """Like `im.tuple_get`, but collapses immediately when `expr` is a `make_tuple` call. - - Note: argument order is `(expr, i)` to allow use as a `functools.reduce` reducer. - """ - if cpm.is_call_to(expr, "make_tuple"): - return expr.args[i] - return im.tuple_get(i, expr) - - -def _tree_map_tuple_body( - f: itir.Expr, tup_exprs: list[itir.Expr], tup_types: list[ts.TupleType] -) -> itir.Expr: - """Recursively unroll `tree_map_tuple(f)(t1, ..., tN)` into `make_tuple` calls.""" - - @utils.tree_map( - collection_type=ts.TupleType, - result_collection_constructor=lambda _, elts: im.make_tuple(*elts), - with_path_arg=True, - ) - def mapper(*args): - *_el_types, path = args - return im.call(f)( - *(functools.reduce(_collapsing_tuple_get, path, tup_expr) for tup_expr in tup_exprs) - ) - - return mapper(*tup_types) - - -def _map_tuple_body( - f: itir.Expr, tup_exprs: list[itir.Expr], tup_types: list[ts.TupleType] -) -> itir.Expr: - """Unroll `map_tuple(f)(t)` over top-level elements only (no recursion).""" - (tup_expr,) = tup_exprs - (tup_type,) = tup_types - return im.make_tuple( - *(im.call(f)(_collapsing_tuple_get(tup_expr, i)) for i in range(len(tup_type.types))) - ) - - -_UNROLLERS = { - "tree_map_tuple": _tree_map_tuple_body, - "map_tuple": _map_tuple_body, -} - - -def _unroll_tuple_map( - builtin_name: str, - f: itir.Expr, - tup_exprs: Iterable[itir.Expr], - tup_types: Iterable[ts.TypeSpec], - *, - uids: utils.IDGeneratorPool, -) -> itir.Expr: - tup_exprs = list(tup_exprs) - tup_types = list(tup_types) - for tup_type in tup_types: - if not isinstance(tup_type, ts.TupleType): - raise TypeError( - f"'{builtin_name}' requires all arguments to be tuples, got '{tup_type}'." - ) - if not type_info.tuple_structures_match(*tup_types): - raise TypeError( - f"'{builtin_name}' requires all arguments to share the same (nested) tuple " - f"structure, got {[str(t) for t in tup_types]}." - ) - - # For trivial args (those that can be duplicated without cost or side effects), - # we substitute them directly into the body. This avoids leaving behind - # `tuple_get(i, make_tuple(...))` patterns that would otherwise require a - # separate cleanup pass (`CollapseTuple`). For non-trivial args we still - # introduce a `let` binding to avoid duplicating expensive sub-expressions. - substituted_exprs: list[itir.Expr] = [] - let_bindings: list[tuple[str, itir.Expr]] = [] - for tup in tup_exprs: - if isinstance(tup, (itir.SymRef, itir.Literal)) or cpm.is_call_to(tup, "make_tuple"): - substituted_exprs.append(tup) - else: - ref_name = next(uids["__utm"]) - let_bindings.append((ref_name, tup)) - substituted_exprs.append(im.ref(ref_name)) - - body = _UNROLLERS[builtin_name](f, substituted_exprs, tup_types) - return im.let(*let_bindings)(body) if let_bindings else body - - -def unroll_tree_map_tuple( - f: itir.Expr, - tup_exprs: Iterable[itir.Expr], - tup_types: Iterable[ts.TypeSpec], - *, - uids: utils.IDGeneratorPool, -) -> itir.Expr: - """ - Lower ``tree_map_tuple(f)(t1, ..., tN)`` to explicit ``make_tuple`` calls, recursing into - nested tuples and applying ``f`` to each leaf. - - Args: - f: The function to apply at each leaf. - tup_exprs: The (already lowered) tuple argument expressions. - tup_types: The type of each argument in ``tup_exprs``; all must be ``TupleType`` and - share the same (nested) structure. - uids: Used to generate fresh names for `let`-bindings of non-trivial arguments. - """ - return _unroll_tuple_map("tree_map_tuple", f, tup_exprs, tup_types, uids=uids) - - -def unroll_map_tuple( - f: itir.Expr, - tup_expr: itir.Expr, - tup_type: ts.TypeSpec, - *, - uids: utils.IDGeneratorPool, -) -> itir.Expr: - """ - Lower ``map_tuple(f)(t)`` to an explicit ``make_tuple`` call, applying ``f`` to each - top-level element only (no recursion). - - Args: - f: The function to apply to each top-level element. - tup_expr: The (already lowered) tuple argument expression. - tup_type: The type of ``tup_expr``; must be a ``TupleType``. - uids: Used to generate a fresh name for a `let`-binding of a non-trivial argument. - """ - return _unroll_tuple_map("map_tuple", f, (tup_expr,), (tup_type,), uids=uids) diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index 2baf3d0cc2..f56b9e6296 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -24,6 +24,7 @@ prune_empty_concat_where, remove_broadcast, symbol_ref_utils, + unroll_tuple_maps, ) from gt4py.next.iterator.transforms.collapse_list_get import CollapseListGet from gt4py.next.iterator.transforms.collapse_tuple import CollapseTuple @@ -169,6 +170,23 @@ def apply_common_transforms( ir = inline_lifts.InlineLifts().visit(ir) ir = concat_where.expand_tuple_args(ir, offset_provider_type=offset_provider_type) # type: ignore[assignment] # always an itir.Program + # `UnrollTupleMaps` requires fully-inferred tuple types (relies on `reinfer` to see + # nested `TupleType` chains). `expand_tuple_args` runs full type inference, so this is + # the earliest safe position. + ir = unroll_tuple_maps.UnrollTupleMaps.apply(ir, uids=uids) + # `UnrollTupleMaps` collapses `tuple_get(i, make_tuple(...))` patterns on the fly + # for trivial arguments, so no additional `CollapseTuple` cleanup loop is needed. + # A single `CollapseTuple` pass still handles any residual patterns produced when + # arguments had to be let-bound (non-trivial sub-expressions). + ir = CollapseTuple.apply( + ir, + enabled_transformations=( + CollapseTuple.Transformation.PROPAGATE_TUPLE_GET + | CollapseTuple.Transformation.COLLAPSE_TUPLE_GET_MAKE_TUPLE + ), + uids=uids, + offset_provider_type=offset_provider_type, + ) # type: ignore[assignment] # always an itir.Program ir = dead_code_elimination.dead_code_elimination( ir, uids=uids, offset_provider_type=offset_provider_type ) # domain inference does not support dead-code @@ -282,6 +300,19 @@ def apply_fieldview_transforms( ir = inline_fundefs.prune_unreferenced_fundefs(ir) # required for dead-code-elimination and `prune_empty_concat_where` pass ir = concat_where.expand_tuple_args(ir, offset_provider_type=offset_provider_type) # type: ignore[assignment] # always an itir.Program + # `UnrollTupleMaps` requires fully-inferred tuple types; `expand_tuple_args` runs full + # type inference, so this is the earliest safe position. + ir = unroll_tuple_maps.UnrollTupleMaps.apply(ir, uids=uids) + # See note in `apply_common_transforms` about why a single `CollapseTuple` pass suffices. + ir = CollapseTuple.apply( + ir, + enabled_transformations=( + CollapseTuple.Transformation.PROPAGATE_TUPLE_GET + | CollapseTuple.Transformation.COLLAPSE_TUPLE_GET_MAKE_TUPLE + ), + uids=uids, + offset_provider_type=offset_provider_type, + ) # type: ignore[assignment] # always an itir.Program ir = dead_code_elimination.dead_code_elimination( ir, offset_provider_type=offset_provider_type, uids=uids ) diff --git a/src/gt4py/next/iterator/transforms/unroll_tuple_maps.py b/src/gt4py/next/iterator/transforms/unroll_tuple_maps.py new file mode 100644 index 0000000000..9c937219b3 --- /dev/null +++ b/src/gt4py/next/iterator/transforms/unroll_tuple_maps.py @@ -0,0 +1,113 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause +import dataclasses +import functools + +from gt4py import eve +from gt4py.next import utils +from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im +from gt4py.next.iterator.type_system import inference as itir_inference +from gt4py.next.type_system import type_specifications as ts + + +def _collapsing_tuple_get(expr: itir.Expr, i: int) -> itir.Expr: + """Like `im.tuple_get`, but collapses immediately when `expr` is a `make_tuple` call. + + Note: argument order is `(expr, i)` to allow use as a `functools.reduce` reducer. + """ + if cpm.is_call_to(expr, "make_tuple"): + return expr.args[i] + return im.tuple_get(i, expr) + + +def _tree_map_tuple_body( + f: itir.Expr, tup_exprs: list[itir.Expr], tup_types: list[ts.TupleType] +) -> itir.Expr: + """Recursively unroll `tree_map_tuple(f)(t1, ..., tN)` into `make_tuple` calls.""" + + @utils.tree_map( + collection_type=ts.TupleType, + result_collection_constructor=lambda _, elts: im.make_tuple(*elts), + with_path_arg=True, + ) + def mapper(*args): + *_el_types, path = args + return im.call(f)( + *(functools.reduce(_collapsing_tuple_get, path, tup_expr) for tup_expr in tup_exprs) + ) + + return mapper(*tup_types) + + +def _map_tuple_body( + f: itir.Expr, tup_exprs: list[itir.Expr], tup_types: list[ts.TupleType] +) -> itir.Expr: + """Unroll `map_tuple(f)(t)` over top-level elements only (no recursion).""" + (tup_expr,) = tup_exprs + (tup_type,) = tup_types + return im.make_tuple( + *(im.call(f)(_collapsing_tuple_get(tup_expr, i)) for i in range(len(tup_type.types))) + ) + + +_UNROLLERS = { + "tree_map_tuple": _tree_map_tuple_body, + "map_tuple": _map_tuple_body, +} + + +@dataclasses.dataclass +class UnrollTupleMaps(eve.NodeTranslator): + """Unroll tuple-map ITIR builtins (`tree_map_tuple`, `map_tuple`) into `make_tuple`.""" + + PRESERVED_ANNEX_ATTRS = ("domain",) + + uids: utils.IDGeneratorPool + + @classmethod + def apply(cls, program: itir.Program, *, uids: utils.IDGeneratorPool): + return cls(uids=uids).visit(program) + + def visit_FunCall(self, node: itir.FunCall): + node = self.generic_visit(node) + + builtin_name = next((name for name in _UNROLLERS if cpm.is_call_to(node.fun, name)), None) + if builtin_name is None: + return node + + assert isinstance(node.fun, itir.FunCall) + f = node.fun.args[0] + tup_args = node.args + + tup_types: list[ts.TupleType] = [] + for tup in tup_args: + itir_inference.reinfer(tup) + assert isinstance(tup.type, ts.TupleType) + tup_types.append(tup.type) + + # For trivial args (those that can be duplicated without cost or side effects), + # we substitute them directly into the body. This avoids leaving behind + # `tuple_get(i, make_tuple(...))` patterns that would otherwise require a + # separate cleanup pass (CollapseTuple). For non-trivial args we still + # introduce a `let` binding to avoid duplicating expensive sub-expressions. + substituted_exprs: list[itir.Expr] = [] + let_bindings: list[tuple[str, itir.Expr]] = [] + for tup in tup_args: + if isinstance(tup, (itir.SymRef, itir.Literal)) or cpm.is_call_to(tup, "make_tuple"): + substituted_exprs.append(tup) + else: + ref_name = next(self.uids["_utm"]) + let_bindings.append((ref_name, tup)) + substituted_exprs.append(im.ref(ref_name)) + + body = _UNROLLERS[builtin_name](f, substituted_exprs, tup_types) + + result = im.let(*let_bindings)(body) if let_bindings else body + itir_inference.reinfer(result) + return result diff --git a/src/gt4py/next/type_system/type_info.py b/src/gt4py/next/type_system/type_info.py index edd40d4735..eb70d15947 100644 --- a/src/gt4py/next/type_system/type_info.py +++ b/src/gt4py/next/type_system/type_info.py @@ -442,40 +442,6 @@ def contains_local_field(type_: ts.TypeSpec) -> bool: ) -def tuple_structures_match(*types: ts.TypeSpec) -> bool: - """ - Return if all `types` share the same (nested) tuple structure. - - Only the tuple skeleton is compared; the leaf types are not required to be equal. - - Examples: - --------- - >>> i = ts.ScalarType(kind=ts.ScalarKind.INT32) - >>> f = ts.ScalarType(kind=ts.ScalarKind.FLOAT64) - >>> tuple_structures_match(ts.TupleType(types=[i, i]), ts.TupleType(types=[f, f])) - True - >>> tuple_structures_match(ts.TupleType(types=[i, i]), ts.TupleType(types=[i])) - False - >>> tuple_structures_match( - ... ts.TupleType(types=[i, ts.TupleType(types=[i])]), - ... ts.TupleType(types=[i, i]), - ... ) - False - """ - if not types: - return True - first, *rest = types - for other in rest: - if isinstance(first, ts.TupleType) != isinstance(other, ts.TupleType): - return False - if isinstance(first, ts.TupleType) and isinstance(other, ts.TupleType): - if len(first.types) != len(other.types): - return False - if not all(tuple_structures_match(f, o) for f, o in zip(first.types, other.types)): - return False - return True - - # TODO(tehrengruber): This function has specializations on Iterator types, which are not part of # the general / shared type system. This functionality should be moved to the iterator-only # type system, but we need some sort of multiple dispatch for that. diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py index 035d30ef55..22f69e8a6e 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py @@ -275,45 +275,6 @@ def testee(mask: cases.VBoolField, inp: cases.EField) -> cases.VField: ) -@pytest.mark.uses_unstructured_shift -def test_reduction_expression_with_where_tuple_local_condition(unstructured_case): - # Regression test: a tuple-`where` whose condition is itself a per-neighbour (local) field. - # The condition lowers to a list-typed predicate, so each tuple leaf must be lowered as - # `map_list(if_)` (with the condition promoted via `make_const_list`), exactly like the - # non-tuple path does. A uniform leaf (e.g. `op_as_fieldop("if_")`) would produce invalid IR. - @gtx.field_operator - def testee(a: cases.EField, b: cases.EField) -> cases.VField: - cond = a(V2E) > b(V2E) # per-neighbour (local) bool field - t = where(cond, (a(V2E), b(V2E)), (b(V2E), a(V2E))) - return neighbor_sum(t[0] + t[1], axis=V2EDim) - - v2e_table = unstructured_case.offset_provider["V2E"].asnumpy() - - a = cases.allocate(unstructured_case, testee, "a")() - b = cases.allocate(unstructured_case, testee, "b")() - out = cases.allocate(unstructured_case, testee, cases.RETURN)() - - a_nbh = a.asnumpy()[v2e_table] - b_nbh = b.asnumpy()[v2e_table] - cond_nbh = a_nbh > b_nbh - t0 = np.where(cond_nbh, a_nbh, b_nbh) - t1 = np.where(cond_nbh, b_nbh, a_nbh) - - cases.verify( - unstructured_case, - testee, - a, - b, - out=out, - ref=np.sum( - t0 + t1, - axis=1, - initial=0, - where=v2e_table != common._DEFAULT_SKIP_VALUE, - ), - ) - - @pytest.mark.uses_unstructured_shift def test_reduction_expression_with_where_and_scalar(unstructured_case): @gtx.field_operator diff --git a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py index d27ffc6a43..4c6bf4d34a 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py @@ -207,10 +207,9 @@ def foo( lowered ) # we generate a let for the condition which is removed by inlining for easier testing - reference = im.make_tuple( - im.op_as_fieldop("if_")("a", im.tuple_get(0, "b"), im.tuple_get(0, "c")), - im.op_as_fieldop("if_")("a", im.tuple_get(1, "b"), im.tuple_get(1, "c")), - ) + reference = im.tree_map_tuple( + im.lambda_("__a", "__b")(im.op_as_fieldop("if_")("a", im.ref("__a"), im.ref("__b"))) + )("b", "c") assert lowered_inlined.expr == reference diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_tuple_maps.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_tuple_maps.py new file mode 100644 index 0000000000..593f519cc2 --- /dev/null +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_tuple_maps.py @@ -0,0 +1,235 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +from gt4py.next import common, utils +from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.ir_utils import ir_makers as im +from gt4py.next.iterator.transforms.unroll_tuple_maps import UnrollTupleMaps +from gt4py.next.type_system import type_specifications as ts + +IDim = common.Dimension("IDim") +T = ts.ScalarType(kind=ts.ScalarKind.FLOAT64) +i_field = ts.FieldType(dims=[IDim], dtype=T) +i_tuple_field = ts.TupleType(types=[i_field, i_field]) + +i_domain = im.call("cartesian_domain")(im.named_range(itir.AxisLiteral(value="IDim"), 0, 1)) + + +def _make_program( + params: list[itir.Sym], expr: itir.Expr, out_type: ts.TypeSpec = i_field +) -> itir.Program: + return itir.Program( + id="testee", + function_definitions=[], + params=[*params, im.sym("out", out_type)], + declarations=[], + body=[ + itir.SetAt( + expr=expr, + domain=i_domain, + target=im.ref("out", out_type), + ) + ], + ) + + +def _neg(): + return im.lambda_("__a")(im.op_as_fieldop("neg")("__a")) + + +def _plus(): + return im.lambda_("__a", "__b")(im.op_as_fieldop("plus")("__a", "__b")) + + +def test_tree_map_tuple_multi_arg(): + uids = utils.IDGeneratorPool() + program = _make_program( + [im.sym("a", i_tuple_field), im.sym("b", i_tuple_field)], + im.call(im.call("tree_map_tuple")(_plus()))( + im.ref("a", i_tuple_field), im.ref("b", i_tuple_field) + ), + out_type=i_tuple_field, + ) + result = UnrollTupleMaps.apply(program, uids=uids) + + expected = _make_program( + [im.sym("a", i_tuple_field), im.sym("b", i_tuple_field)], + im.make_tuple( + im.call(_plus())(im.tuple_get(0, "a"), im.tuple_get(0, "b")), + im.call(_plus())(im.tuple_get(1, "a"), im.tuple_get(1, "b")), + ), + out_type=i_tuple_field, + ) + assert result == expected + + +def test_tree_map_tuple_nested(): + uids = utils.IDGeneratorPool() + nested = ts.TupleType(types=[i_tuple_field, i_tuple_field]) + program = _make_program( + [im.sym("t", nested)], + im.call(im.call("tree_map_tuple")(_neg()))(im.ref("t", nested)), + out_type=nested, + ) + result = UnrollTupleMaps.apply(program, uids=uids) + + expected = _make_program( + [im.sym("t", nested)], + im.make_tuple( + im.make_tuple( + im.call(_neg())(im.tuple_get(0, im.tuple_get(0, "t"))), + im.call(_neg())(im.tuple_get(1, im.tuple_get(0, "t"))), + ), + im.make_tuple( + im.call(_neg())(im.tuple_get(0, im.tuple_get(1, "t"))), + im.call(_neg())(im.tuple_get(1, im.tuple_get(1, "t"))), + ), + ), + out_type=nested, + ) + assert result == expected + + +def test_map_tuple_single_arg(): + uids = utils.IDGeneratorPool() + program = _make_program( + [im.sym("t", i_tuple_field)], + im.call(im.call("map_tuple")(_neg()))(im.ref("t", i_tuple_field)), + out_type=i_tuple_field, + ) + result = UnrollTupleMaps.apply(program, uids=uids) + + expected = _make_program( + [im.sym("t", i_tuple_field)], + im.make_tuple( + im.call(_neg())(im.tuple_get(0, "t")), + im.call(_neg())(im.tuple_get(1, "t")), + ), + out_type=i_tuple_field, + ) + assert result == expected + + +def test_map_tuple_does_not_recurse(): + uids = utils.IDGeneratorPool() + nested = ts.TupleType(types=[i_tuple_field, i_tuple_field]) + g = im.lambda_("__p")(im.op_as_fieldop("plus")(im.tuple_get(0, "__p"), im.tuple_get(1, "__p"))) + program = _make_program( + [im.sym("t", nested)], + im.call(im.call("map_tuple")(g))(im.ref("t", nested)), + out_type=i_tuple_field, + ) + result = UnrollTupleMaps.apply(program, uids=uids) + + expected = _make_program( + [im.sym("t", nested)], + im.make_tuple( + im.call(g)(im.tuple_get(0, "t")), + im.call(g)(im.tuple_get(1, "t")), + ), + out_type=i_tuple_field, + ) + assert result == expected + + +def test_make_tuple_arg_is_collapsed(): + """When the input tuple is a `make_tuple` literal, projection should collapse + directly to the element (no residual `tuple_get(make_tuple(...))`).""" + uids = utils.IDGeneratorPool() + program = _make_program( + [im.sym("a", i_field), im.sym("b", i_field)], + im.call(im.call("tree_map_tuple")(_neg()))(im.make_tuple("a", "b")), + out_type=i_tuple_field, + ) + result = UnrollTupleMaps.apply(program, uids=uids) + + expected = _make_program( + [im.sym("a", i_field), im.sym("b", i_field)], + im.make_tuple(im.call(_neg())("a"), im.call(_neg())("b")), + out_type=i_tuple_field, + ) + assert result == expected + + +def test_nested_make_tuple_arg_is_collapsed(): + """A nested `make_tuple` arg should be fully collapsed at every depth: each + `tuple_get(i, make_tuple(...))` along the recursion is folded directly.""" + uids = utils.IDGeneratorPool() + nested = ts.TupleType(types=[i_tuple_field, i_tuple_field]) + program = _make_program( + [ + im.sym("a", i_field), + im.sym("b", i_field), + im.sym("c", i_field), + im.sym("d", i_field), + ], + im.call(im.call("tree_map_tuple")(_neg()))( + im.make_tuple(im.make_tuple("a", "b"), im.make_tuple("c", "d")) + ), + out_type=nested, + ) + result = UnrollTupleMaps.apply(program, uids=uids) + + expected = _make_program( + [ + im.sym("a", i_field), + im.sym("b", i_field), + im.sym("c", i_field), + im.sym("d", i_field), + ], + im.make_tuple( + im.make_tuple(im.call(_neg())("a"), im.call(_neg())("b")), + im.make_tuple(im.call(_neg())("c"), im.call(_neg())("d")), + ), + out_type=nested, + ) + assert result == expected + + +def test_map_tuple_with_make_tuple_arg_is_collapsed(): + """The `make_tuple` short-circuit must also apply for the `map_tuple` builtin.""" + uids = utils.IDGeneratorPool() + program = _make_program( + [im.sym("a", i_field), im.sym("b", i_field)], + im.call(im.call("map_tuple")(_neg()))(im.make_tuple("a", "b")), + out_type=i_tuple_field, + ) + result = UnrollTupleMaps.apply(program, uids=uids) + + expected = _make_program( + [im.sym("a", i_field), im.sym("b", i_field)], + im.make_tuple(im.call(_neg())("a"), im.call(_neg())("b")), + out_type=i_tuple_field, + ) + assert result == expected + + +def test_non_trivial_arg_is_let_bound(): + """Non-trivial (potentially expensive) tuple expressions must still be + let-bound to avoid duplicating work across leaf projections.""" + uids = utils.IDGeneratorPool() + # `f(t)` is a non-trivial expression returning a tuple + f = im.lambda_("__t")(im.ref("__t", i_tuple_field)) + program = _make_program( + [im.sym("t", i_tuple_field)], + im.call(im.call("tree_map_tuple")(_neg()))(im.call(f)(im.ref("t", i_tuple_field))), + out_type=i_tuple_field, + ) + result = UnrollTupleMaps.apply(program, uids=uids) + + expected = _make_program( + [im.sym("t", i_tuple_field)], + im.let("_utm_0", im.call(f)("t"))( + im.make_tuple( + im.call(_neg())(im.tuple_get(0, "_utm_0")), + im.call(_neg())(im.tuple_get(1, "_utm_0")), + ) + ), + out_type=i_tuple_field, + ) + assert result == expected From b7677005f56a51c389e3c771e7a471bc382a4811 Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Wed, 17 Jun 2026 15:35:19 +0200 Subject: [PATCH 19/24] Cleanup --- src/gt4py/next/iterator/transforms/pass_manager.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index f56b9e6296..1689f02977 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -322,7 +322,6 @@ def apply_fieldview_transforms( ir = infer_domain_ops.InferDomainOps.apply(ir) ir = concat_where.canonicalize_domain_argument(ir) - ir = ConstantFolding.apply(ir) # type: ignore[assignment] # always an itir.Program ir = infer_domain.infer_program( From 7b270a3c729b424e929cce10734d78e7b69f6fbc Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Thu, 18 Jun 2026 13:51:00 +0200 Subject: [PATCH 20/24] Address review comment --- .../iterator/transforms/unroll_tuple_maps.py | 9 +++++++++ .../iterator/type_system/type_synthesizer.py | 14 +++++++++++++ .../ffront_tests/test_reductions.py | 4 +++- .../iterator_tests/test_type_inference.py | 18 +++++++++++++++++ .../test_unroll_tuple_maps.py | 20 +++++++++++++++++++ 5 files changed, 64 insertions(+), 1 deletion(-) diff --git a/src/gt4py/next/iterator/transforms/unroll_tuple_maps.py b/src/gt4py/next/iterator/transforms/unroll_tuple_maps.py index 9c937219b3..d2ecae3f96 100644 --- a/src/gt4py/next/iterator/transforms/unroll_tuple_maps.py +++ b/src/gt4py/next/iterator/transforms/unroll_tuple_maps.py @@ -31,6 +31,15 @@ def _tree_map_tuple_body( ) -> itir.Expr: """Recursively unroll `tree_map_tuple(f)(t1, ..., tN)` into `make_tuple` calls.""" + def tuple_structure(type_: ts.TypeSpec) -> tuple[object, ...] | None: + if isinstance(type_, ts.TupleType): + return tuple(tuple_structure(el_type) for el_type in type_.types) + return None + + expected_structure = tuple_structure(tup_types[0]) + if any(tuple_structure(tup_type) != expected_structure for tup_type in tup_types[1:]): + raise TypeError("'tree_map_tuple' requires all arguments to have the same tuple structure.") + @utils.tree_map( collection_type=ts.TupleType, result_collection_constructor=lambda _, elts: im.make_tuple(*elts), diff --git a/src/gt4py/next/iterator/type_system/type_synthesizer.py b/src/gt4py/next/iterator/type_system/type_synthesizer.py index 9bc5b2d317..111d713ae3 100644 --- a/src/gt4py/next/iterator/type_system/type_synthesizer.py +++ b/src/gt4py/next/iterator/type_system/type_synthesizer.py @@ -631,6 +631,18 @@ def _tuple_map_synthesizer( ) -> Callable[..., TypeOrTypeSynthesizer]: """Shared implementation for `tree_map_tuple` (recursive) and `map_tuple` (top-level).""" + def tuple_structure(type_: ts.TypeSpec) -> tuple[object, ...] | None: + if isinstance(type_, ts.TupleType): + return tuple(tuple_structure(el_type) for el_type in type_.types) + return None + + def ensure_same_tuple_structure(args: tuple[ts.TupleType, ...]) -> None: + expected_structure = tuple_structure(args[0]) + if any(tuple_structure(arg) != expected_structure for arg in args[1:]): + raise TypeError( + f"'{builtin_name}' requires all arguments to have the same tuple structure." + ) + def factory(op: TypeSynthesizer) -> TypeSynthesizer: @type_synthesizer def applied_map( @@ -645,6 +657,8 @@ def applied_map( f"'{builtin_name}' requires all top-level arguments to be TupleType, " f"got {[type(a).__name__ for a in args]}." ) + if recursive: + ensure_same_tuple_structure(args) def leaf_op(*leaf_types: ts.TypeSpec) -> ts.TypeSpec: return op(*leaf_types, offset_provider_type=offset_provider_type) # type: ignore[return-value] diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_reductions.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_reductions.py index a2d495babb..f2d413ff06 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_reductions.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_reductions.py @@ -465,7 +465,9 @@ def testee(a: cases.VField) -> cases.VField: @pytest.mark.uses_unstructured_shift -@pytest.mark.xfail(reason="Not yet supported in lowering, requires `map_list`ing of inner reduce op.") +@pytest.mark.xfail( + reason="Not yet supported in lowering, requires `map_list`ing of inner reduce op." +) def test_nested_reduction_shift_first(unstructured_case): @gtx.field_operator def testee(inp: cases.EField) -> cases.EField: diff --git a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py index e39ece4916..9752600428 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py @@ -324,6 +324,24 @@ def test_expression_type(test_case): assert result.type == expected_type +@pytest.mark.parametrize( + "testee", + [ + im.tree_map_tuple(im.ref("plus"))( + im.ref("t1", ts.TupleType(types=[int_type, int_type, int_type])), + im.ref("t2", ts.TupleType(types=[int_type, int_type])), + ), + im.tree_map_tuple(im.ref("plus"))( + im.ref("t1", ts.TupleType(types=[int_type, ts.TupleType(types=[int_type, int_type])])), + im.ref("t2", ts.TupleType(types=[int_type, int_type])), + ), + ], +) +def test_tree_map_tuple_mismatched_structure_raises_type_error(testee): + with pytest.raises(TypeError, match=r"same tuple structure"): + itir_type_inference.infer(testee, offset_provider_type={}, allow_undeclared_symbols=True) + + @pytest.mark.parametrize( "test_case", [(expr, type_) for expr, type_ in expression_test_cases() if cpm.is_applied_as_fieldop(expr)], diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_tuple_maps.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_tuple_maps.py index 593f519cc2..a0ea585a38 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_tuple_maps.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_tuple_maps.py @@ -6,6 +6,8 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause +import pytest + from gt4py.next import common, utils from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import ir_makers as im @@ -233,3 +235,21 @@ def test_non_trivial_arg_is_let_bound(): out_type=i_tuple_field, ) assert result == expected + + +@pytest.mark.parametrize( + "lhs_type, rhs_type", + [ + (ts.TupleType(types=[i_field, i_field, i_field]), i_tuple_field), + (ts.TupleType(types=[i_field, i_tuple_field]), i_tuple_field), + ], +) +def test_tree_map_tuple_mismatched_structure_raises_type_error(lhs_type, rhs_type): + uids = utils.IDGeneratorPool() + program = _make_program( + [im.sym("a", lhs_type), im.sym("b", rhs_type)], + im.call(im.call("tree_map_tuple")(_plus()))(im.ref("a", lhs_type), im.ref("b", rhs_type)), + ) + + with pytest.raises(TypeError, match=r"same tuple structure"): + UnrollTupleMaps.apply(program, uids=uids) From d3d4e461835df663ba5ab041bb4df31b1d60dde7 Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Fri, 19 Jun 2026 10:53:46 +0200 Subject: [PATCH 21/24] Remove CollapseTuple pass after UnrollTupleMaps --- .../next/iterator/transforms/pass_manager.py | 24 +------------------ 1 file changed, 1 insertion(+), 23 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index 1689f02977..58f9d36958 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -174,19 +174,6 @@ def apply_common_transforms( # nested `TupleType` chains). `expand_tuple_args` runs full type inference, so this is # the earliest safe position. ir = unroll_tuple_maps.UnrollTupleMaps.apply(ir, uids=uids) - # `UnrollTupleMaps` collapses `tuple_get(i, make_tuple(...))` patterns on the fly - # for trivial arguments, so no additional `CollapseTuple` cleanup loop is needed. - # A single `CollapseTuple` pass still handles any residual patterns produced when - # arguments had to be let-bound (non-trivial sub-expressions). - ir = CollapseTuple.apply( - ir, - enabled_transformations=( - CollapseTuple.Transformation.PROPAGATE_TUPLE_GET - | CollapseTuple.Transformation.COLLAPSE_TUPLE_GET_MAKE_TUPLE - ), - uids=uids, - offset_provider_type=offset_provider_type, - ) # type: ignore[assignment] # always an itir.Program ir = dead_code_elimination.dead_code_elimination( ir, uids=uids, offset_provider_type=offset_provider_type ) # domain inference does not support dead-code @@ -303,16 +290,7 @@ def apply_fieldview_transforms( # `UnrollTupleMaps` requires fully-inferred tuple types; `expand_tuple_args` runs full # type inference, so this is the earliest safe position. ir = unroll_tuple_maps.UnrollTupleMaps.apply(ir, uids=uids) - # See note in `apply_common_transforms` about why a single `CollapseTuple` pass suffices. - ir = CollapseTuple.apply( - ir, - enabled_transformations=( - CollapseTuple.Transformation.PROPAGATE_TUPLE_GET - | CollapseTuple.Transformation.COLLAPSE_TUPLE_GET_MAKE_TUPLE - ), - uids=uids, - offset_provider_type=offset_provider_type, - ) # type: ignore[assignment] # always an itir.Program + ir = dead_code_elimination.dead_code_elimination( ir, offset_provider_type=offset_provider_type, uids=uids ) From d0272df3d1bc8251b10052bb5724a7fcdfac20a0 Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Fri, 19 Jun 2026 11:34:35 +0200 Subject: [PATCH 22/24] Remove program wrapper in tests --- .../test_unroll_tuple_maps.py | 195 +++++------------- 1 file changed, 54 insertions(+), 141 deletions(-) diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_tuple_maps.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_tuple_maps.py index 593f519cc2..bd04b60e69 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_tuple_maps.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_tuple_maps.py @@ -10,32 +10,19 @@ from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.iterator.transforms.unroll_tuple_maps import UnrollTupleMaps +from gt4py.next.iterator.type_system import inference as itir_type_inference from gt4py.next.type_system import type_specifications as ts + IDim = common.Dimension("IDim") T = ts.ScalarType(kind=ts.ScalarKind.FLOAT64) i_field = ts.FieldType(dims=[IDim], dtype=T) i_tuple_field = ts.TupleType(types=[i_field, i_field]) -i_domain = im.call("cartesian_domain")(im.named_range(itir.AxisLiteral(value="IDim"), 0, 1)) - - -def _make_program( - params: list[itir.Sym], expr: itir.Expr, out_type: ts.TypeSpec = i_field -) -> itir.Program: - return itir.Program( - id="testee", - function_definitions=[], - params=[*params, im.sym("out", out_type)], - declarations=[], - body=[ - itir.SetAt( - expr=expr, - domain=i_domain, - target=im.ref("out", out_type), - ) - ], - ) + +def _apply(expr: itir.Expr) -> itir.Expr: + expr = itir_type_inference.infer(expr, offset_provider_type={}, allow_undeclared_symbols=True) + return UnrollTupleMaps(uids=utils.IDGeneratorPool()).visit(expr) def _neg(): @@ -47,92 +34,54 @@ def _plus(): def test_tree_map_tuple_multi_arg(): - uids = utils.IDGeneratorPool() - program = _make_program( - [im.sym("a", i_tuple_field), im.sym("b", i_tuple_field)], + result = _apply( im.call(im.call("tree_map_tuple")(_plus()))( im.ref("a", i_tuple_field), im.ref("b", i_tuple_field) - ), - out_type=i_tuple_field, + ) ) - result = UnrollTupleMaps.apply(program, uids=uids) - expected = _make_program( - [im.sym("a", i_tuple_field), im.sym("b", i_tuple_field)], - im.make_tuple( - im.call(_plus())(im.tuple_get(0, "a"), im.tuple_get(0, "b")), - im.call(_plus())(im.tuple_get(1, "a"), im.tuple_get(1, "b")), - ), - out_type=i_tuple_field, + expected = im.make_tuple( + im.call(_plus())(im.tuple_get(0, "a"), im.tuple_get(0, "b")), + im.call(_plus())(im.tuple_get(1, "a"), im.tuple_get(1, "b")), ) assert result == expected def test_tree_map_tuple_nested(): - uids = utils.IDGeneratorPool() nested = ts.TupleType(types=[i_tuple_field, i_tuple_field]) - program = _make_program( - [im.sym("t", nested)], - im.call(im.call("tree_map_tuple")(_neg()))(im.ref("t", nested)), - out_type=nested, - ) - result = UnrollTupleMaps.apply(program, uids=uids) + result = _apply(im.call(im.call("tree_map_tuple")(_neg()))(im.ref("t", nested))) - expected = _make_program( - [im.sym("t", nested)], + expected = im.make_tuple( im.make_tuple( - im.make_tuple( - im.call(_neg())(im.tuple_get(0, im.tuple_get(0, "t"))), - im.call(_neg())(im.tuple_get(1, im.tuple_get(0, "t"))), - ), - im.make_tuple( - im.call(_neg())(im.tuple_get(0, im.tuple_get(1, "t"))), - im.call(_neg())(im.tuple_get(1, im.tuple_get(1, "t"))), - ), + im.call(_neg())(im.tuple_get(0, im.tuple_get(0, "t"))), + im.call(_neg())(im.tuple_get(1, im.tuple_get(0, "t"))), + ), + im.make_tuple( + im.call(_neg())(im.tuple_get(0, im.tuple_get(1, "t"))), + im.call(_neg())(im.tuple_get(1, im.tuple_get(1, "t"))), ), - out_type=nested, ) assert result == expected def test_map_tuple_single_arg(): - uids = utils.IDGeneratorPool() - program = _make_program( - [im.sym("t", i_tuple_field)], - im.call(im.call("map_tuple")(_neg()))(im.ref("t", i_tuple_field)), - out_type=i_tuple_field, - ) - result = UnrollTupleMaps.apply(program, uids=uids) + result = _apply(im.call(im.call("map_tuple")(_neg()))(im.ref("t", i_tuple_field))) - expected = _make_program( - [im.sym("t", i_tuple_field)], - im.make_tuple( - im.call(_neg())(im.tuple_get(0, "t")), - im.call(_neg())(im.tuple_get(1, "t")), - ), - out_type=i_tuple_field, + expected = im.make_tuple( + im.call(_neg())(im.tuple_get(0, "t")), + im.call(_neg())(im.tuple_get(1, "t")), ) assert result == expected def test_map_tuple_does_not_recurse(): - uids = utils.IDGeneratorPool() nested = ts.TupleType(types=[i_tuple_field, i_tuple_field]) g = im.lambda_("__p")(im.op_as_fieldop("plus")(im.tuple_get(0, "__p"), im.tuple_get(1, "__p"))) - program = _make_program( - [im.sym("t", nested)], - im.call(im.call("map_tuple")(g))(im.ref("t", nested)), - out_type=i_tuple_field, - ) - result = UnrollTupleMaps.apply(program, uids=uids) + result = _apply(im.call(im.call("map_tuple")(g))(im.ref("t", nested))) - expected = _make_program( - [im.sym("t", nested)], - im.make_tuple( - im.call(g)(im.tuple_get(0, "t")), - im.call(g)(im.tuple_get(1, "t")), - ), - out_type=i_tuple_field, + expected = im.make_tuple( + im.call(g)(im.tuple_get(0, "t")), + im.call(g)(im.tuple_get(1, "t")), ) assert result == expected @@ -140,96 +89,60 @@ def test_map_tuple_does_not_recurse(): def test_make_tuple_arg_is_collapsed(): """When the input tuple is a `make_tuple` literal, projection should collapse directly to the element (no residual `tuple_get(make_tuple(...))`).""" - uids = utils.IDGeneratorPool() - program = _make_program( - [im.sym("a", i_field), im.sym("b", i_field)], - im.call(im.call("tree_map_tuple")(_neg()))(im.make_tuple("a", "b")), - out_type=i_tuple_field, + result = _apply( + im.call(im.call("tree_map_tuple")(_neg()))( + im.make_tuple(im.ref("a", i_field), im.ref("b", i_field)) + ) ) - result = UnrollTupleMaps.apply(program, uids=uids) - expected = _make_program( - [im.sym("a", i_field), im.sym("b", i_field)], - im.make_tuple(im.call(_neg())("a"), im.call(_neg())("b")), - out_type=i_tuple_field, - ) + expected = im.make_tuple(im.call(_neg())("a"), im.call(_neg())("b")) assert result == expected def test_nested_make_tuple_arg_is_collapsed(): """A nested `make_tuple` arg should be fully collapsed at every depth: each `tuple_get(i, make_tuple(...))` along the recursion is folded directly.""" - uids = utils.IDGeneratorPool() - nested = ts.TupleType(types=[i_tuple_field, i_tuple_field]) - program = _make_program( - [ - im.sym("a", i_field), - im.sym("b", i_field), - im.sym("c", i_field), - im.sym("d", i_field), - ], + result = _apply( im.call(im.call("tree_map_tuple")(_neg()))( - im.make_tuple(im.make_tuple("a", "b"), im.make_tuple("c", "d")) - ), - out_type=nested, + im.make_tuple( + im.make_tuple(im.ref("a", i_field), im.ref("b", i_field)), + im.make_tuple(im.ref("c", i_field), im.ref("d", i_field)), + ) + ) ) - result = UnrollTupleMaps.apply(program, uids=uids) - - expected = _make_program( - [ - im.sym("a", i_field), - im.sym("b", i_field), - im.sym("c", i_field), - im.sym("d", i_field), - ], - im.make_tuple( - im.make_tuple(im.call(_neg())("a"), im.call(_neg())("b")), - im.make_tuple(im.call(_neg())("c"), im.call(_neg())("d")), - ), - out_type=nested, + + expected = im.make_tuple( + im.make_tuple(im.call(_neg())("a"), im.call(_neg())("b")), + im.make_tuple(im.call(_neg())("c"), im.call(_neg())("d")), ) assert result == expected def test_map_tuple_with_make_tuple_arg_is_collapsed(): """The `make_tuple` short-circuit must also apply for the `map_tuple` builtin.""" - uids = utils.IDGeneratorPool() - program = _make_program( - [im.sym("a", i_field), im.sym("b", i_field)], - im.call(im.call("map_tuple")(_neg()))(im.make_tuple("a", "b")), - out_type=i_tuple_field, + result = _apply( + im.call(im.call("map_tuple")(_neg()))( + im.make_tuple(im.ref("a", i_field), im.ref("b", i_field)) + ) ) - result = UnrollTupleMaps.apply(program, uids=uids) - expected = _make_program( - [im.sym("a", i_field), im.sym("b", i_field)], - im.make_tuple(im.call(_neg())("a"), im.call(_neg())("b")), - out_type=i_tuple_field, - ) + expected = im.make_tuple(im.call(_neg())("a"), im.call(_neg())("b")) assert result == expected def test_non_trivial_arg_is_let_bound(): """Non-trivial (potentially expensive) tuple expressions must still be let-bound to avoid duplicating work across leaf projections.""" - uids = utils.IDGeneratorPool() # `f(t)` is a non-trivial expression returning a tuple f = im.lambda_("__t")(im.ref("__t", i_tuple_field)) - program = _make_program( - [im.sym("t", i_tuple_field)], - im.call(im.call("tree_map_tuple")(_neg()))(im.call(f)(im.ref("t", i_tuple_field))), - out_type=i_tuple_field, + result = _apply( + im.call(im.call("tree_map_tuple")(_neg()))(im.call(f)(im.ref("t", i_tuple_field))) ) - result = UnrollTupleMaps.apply(program, uids=uids) - expected = _make_program( - [im.sym("t", i_tuple_field)], - im.let("_utm_0", im.call(f)("t"))( - im.make_tuple( - im.call(_neg())(im.tuple_get(0, "_utm_0")), - im.call(_neg())(im.tuple_get(1, "_utm_0")), - ) - ), - out_type=i_tuple_field, + expected = im.let("_utm_0", im.call(f)("t"))( + im.make_tuple( + im.call(_neg())(im.tuple_get(0, "_utm_0")), + im.call(_neg())(im.tuple_get(1, "_utm_0")), + ) ) assert result == expected From b7bb0b2fbc31c9eb93f393c7182f7415fe837820 Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Fri, 19 Jun 2026 13:02:50 +0200 Subject: [PATCH 23/24] Fix test --- .../transforms_tests/test_unroll_tuple_maps.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_tuple_maps.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_tuple_maps.py index 542a6f68d4..550cb8ac46 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_tuple_maps.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_tuple_maps.py @@ -158,11 +158,7 @@ def test_non_trivial_arg_is_let_bound(): ], ) def test_tree_map_tuple_mismatched_structure_raises_type_error(lhs_type, rhs_type): - uids = utils.IDGeneratorPool() - program = _make_program( - [im.sym("a", lhs_type), im.sym("b", rhs_type)], - im.call(im.call("tree_map_tuple")(_plus()))(im.ref("a", lhs_type), im.ref("b", rhs_type)), - ) + expr = im.call(im.call("tree_map_tuple")(_plus()))(im.ref("a", lhs_type), im.ref("b", rhs_type)) with pytest.raises(TypeError, match=r"same tuple structure"): - UnrollTupleMaps.apply(program, uids=uids) + _apply(expr) From 7d8f56c1c7431d095f6b693641b69f352c90479e Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Wed, 24 Jun 2026 17:29:36 +0200 Subject: [PATCH 24/24] Also allow itir.Expr in UnrollTupleMaps and run tye_inference when necessary --- .../next/iterator/transforms/pass_manager.py | 15 +++++++----- .../iterator/transforms/unroll_tuple_maps.py | 24 ++++++++++++++++--- .../test_unroll_tuple_maps.py | 15 +++++++++--- 3 files changed, 42 insertions(+), 12 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index 58f9d36958..68beba6ec3 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -171,9 +171,10 @@ def apply_common_transforms( ir = concat_where.expand_tuple_args(ir, offset_provider_type=offset_provider_type) # type: ignore[assignment] # always an itir.Program # `UnrollTupleMaps` requires fully-inferred tuple types (relies on `reinfer` to see - # nested `TupleType` chains). `expand_tuple_args` runs full type inference, so this is - # the earliest safe position. - ir = unroll_tuple_maps.UnrollTupleMaps.apply(ir, uids=uids) + # nested `TupleType` chains), so the offset_provider is passed for on-demand inference. + ir = unroll_tuple_maps.UnrollTupleMaps.apply( + ir, uids=uids, offset_provider_type=offset_provider_type + ) ir = dead_code_elimination.dead_code_elimination( ir, uids=uids, offset_provider_type=offset_provider_type ) # domain inference does not support dead-code @@ -287,9 +288,11 @@ def apply_fieldview_transforms( ir = inline_fundefs.prune_unreferenced_fundefs(ir) # required for dead-code-elimination and `prune_empty_concat_where` pass ir = concat_where.expand_tuple_args(ir, offset_provider_type=offset_provider_type) # type: ignore[assignment] # always an itir.Program - # `UnrollTupleMaps` requires fully-inferred tuple types; `expand_tuple_args` runs full - # type inference, so this is the earliest safe position. - ir = unroll_tuple_maps.UnrollTupleMaps.apply(ir, uids=uids) + # `UnrollTupleMaps` requires fully-inferred tuple types, so the offset_provider is passed for + # on-demand inference. + ir = unroll_tuple_maps.UnrollTupleMaps.apply( + ir, uids=uids, offset_provider_type=offset_provider_type + ) ir = dead_code_elimination.dead_code_elimination( ir, offset_provider_type=offset_provider_type, uids=uids diff --git a/src/gt4py/next/iterator/transforms/unroll_tuple_maps.py b/src/gt4py/next/iterator/transforms/unroll_tuple_maps.py index d2ecae3f96..461855faab 100644 --- a/src/gt4py/next/iterator/transforms/unroll_tuple_maps.py +++ b/src/gt4py/next/iterator/transforms/unroll_tuple_maps.py @@ -7,9 +7,10 @@ # SPDX-License-Identifier: BSD-3-Clause import dataclasses import functools +from typing import TypeVar from gt4py import eve -from gt4py.next import utils +from gt4py.next import common, utils from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im from gt4py.next.iterator.type_system import inference as itir_inference @@ -71,6 +72,9 @@ def _map_tuple_body( } +ProgramOrExpr = TypeVar("ProgramOrExpr", bound=itir.Program | itir.Expr) + + @dataclasses.dataclass class UnrollTupleMaps(eve.NodeTranslator): """Unroll tuple-map ITIR builtins (`tree_map_tuple`, `map_tuple`) into `make_tuple`.""" @@ -80,8 +84,22 @@ class UnrollTupleMaps(eve.NodeTranslator): uids: utils.IDGeneratorPool @classmethod - def apply(cls, program: itir.Program, *, uids: utils.IDGeneratorPool): - return cls(uids=uids).visit(program) + def apply( + cls, + node: ProgramOrExpr, + *, + uids: utils.IDGeneratorPool | None = None, + offset_provider_type: common.OffsetProviderType | None = None, + ) -> ProgramOrExpr: + if node.type is None: + node = itir_inference.infer( + node, + offset_provider_type=offset_provider_type or {}, + allow_undeclared_symbols=not isinstance(node, itir.Program), + ) + if uids is None: + uids = utils.IDGeneratorPool() + return cls(uids=uids).visit(node) def visit_FunCall(self, node: itir.FunCall): node = self.generic_visit(node) diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_tuple_maps.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_tuple_maps.py index 550cb8ac46..d81acb4e79 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_tuple_maps.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_tuple_maps.py @@ -12,7 +12,6 @@ from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.iterator.transforms.unroll_tuple_maps import UnrollTupleMaps -from gt4py.next.iterator.type_system import inference as itir_type_inference from gt4py.next.type_system import type_specifications as ts @@ -23,8 +22,7 @@ def _apply(expr: itir.Expr) -> itir.Expr: - expr = itir_type_inference.infer(expr, offset_provider_type={}, allow_undeclared_symbols=True) - return UnrollTupleMaps(uids=utils.IDGeneratorPool()).visit(expr) + return UnrollTupleMaps.apply(expr, uids=utils.IDGeneratorPool(), offset_provider_type={}) def _neg(): @@ -76,6 +74,17 @@ def test_map_tuple_single_arg(): assert result == expected +def test_apply_infers_uninferred_expr(): + expr = im.call(im.call("tree_map_tuple")(_neg()))( + im.make_tuple(im.ref("a", i_field), im.ref("b", i_field)) + ) + + result = UnrollTupleMaps.apply(expr, offset_provider_type={}) + + assert expr.type is None + assert result == im.make_tuple(im.call(_neg())("a"), im.call(_neg())("b")) + + def test_map_tuple_does_not_recurse(): nested = ts.TupleType(types=[i_tuple_field, i_tuple_field]) g = im.lambda_("__p")(im.op_as_fieldop("plus")(im.tuple_get(0, "__p"), im.tuple_get(1, "__p")))