From fb32f5820911a724cbce40016a08c0519c41517f Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Mon, 15 Jun 2026 13:48:03 +0200 Subject: [PATCH 1/9] feat[next]: dtype-generic type system foundations (Stage 0) Add the type-system machinery for dtype-generic operators, inert until the frontend produces it: - ts.TypeVarType (value-constrained scalar type variable); widen FieldType.dtype to ScalarType | ListType | TypeVarType. - type_info.is_generic (recursive), bind_type_vars / substitute_type_vars (dtype-scoped: name -> ScalarType), and TypeVar-aware dtype predicates / promote / extract_dtype / extract_dims. - Route CompiledProgramsPool._is_generic through type_info.is_generic, resolving its TODO (behavior-preserving for scan operators). Widening FieldType.dtype requires narrowing at the concrete-IR lowering sites (dace, global_tmps), where the D4 invariant guarantees no TypeVarType, plus a return-type annotation in FOAST type deduction. Records the decision in ADR 0023. See the implementation roadmap for the PR series. 
--- .../ADRs/next/0023-Dtype-Generic-Operators.md | 146 ++++++++++++ .../ffront/foast_passes/type_deduction.py | 2 +- .../next/iterator/transforms/global_tmps.py | 12 +- src/gt4py/next/otf/compiled_program.py | 19 +- .../runners/dace/lowering/gtir_dataflow.py | 4 +- .../runners/dace/lowering/gtir_to_sdfg.py | 1 + .../lowering/gtir_to_sdfg_concat_where.py | 1 + .../dace/lowering/gtir_to_sdfg_scan.py | 4 +- .../dace/lowering/gtir_to_sdfg_types.py | 6 +- src/gt4py/next/type_system/type_info.py | 223 +++++++++++++++++- .../next/type_system/type_specifications.py | 28 ++- .../type_system_tests/test_type_info.py | 200 ++++++++++++++++ 12 files changed, 610 insertions(+), 36 deletions(-) create mode 100644 docs/development/ADRs/next/0023-Dtype-Generic-Operators.md diff --git a/docs/development/ADRs/next/0023-Dtype-Generic-Operators.md b/docs/development/ADRs/next/0023-Dtype-Generic-Operators.md new file mode 100644 index 0000000000..c0dfba4733 --- /dev/null +++ b/docs/development/ADRs/next/0023-Dtype-Generic-Operators.md @@ -0,0 +1,146 @@ +--- +tags: [] +--- + +# [Dtype-Generic Operators] + +- **Status**: valid +- **Authors**: Hannes Vogt (@havogt) +- **Created**: 2026-06-15 +- **Updated**: 2026-06-15 + +Field operators (and programs calling them) may be **generic in the field dtype**, +spelled with a native value-constrained `typing.TypeVar` so the same annotation is +meaningful to mypy and to the DSL frontend. Each concrete call is specialized +(monomorphized) at call time. + +```python +FloatT = typing.TypeVar("FloatT", gtx.float32, gtx.float64) + + +@gtx.field_operator +def diffusion( + a: gtx.Field[gtx.Dims[I, J], FloatT], b: gtx.Field[gtx.Dims[I, J], FloatT] +) -> gtx.Field[gtx.Dims[I, J], FloatT]: + return a - b +``` + +## Context + +`common.Field` is already a runtime-introspectable generic protocol, so +`Field[Dims[I, J], T]` with a value-constrained `TypeVar` is a valid, mypy-visible +annotation today. 
What was missing is the DSL side: translating such an annotation +into the internal type system and type-checking/lowering operators that use it. The +internal type system had `DeferredType` ("some type, maybe constrained") but no +notion of *identity* — it could not express "the *same* unknown dtype in two +parameters and the return type", which is the essence of generics. The runtime +monomorphization machinery, on the other hand, already existed (grown for scan +operators): `CompiledProgramsPool` keys a per-call specialization cache on the +concrete argument types. + +Prior art (numpy.typing, jaxtyping, Numba, Taichi, Triton, DaCe, two-level type +theory) converges on the same two choices adopted here: a real generic annotation +that static checkers can see, and monomorphization at call time. See the design +investigation for the full survey. + +## Decision + +### User-facing spelling + +A **value-constrained** type parameter inside the real generic `Field` class, +spelled with PEP 695 `def op[T: (float32, float64)](...)` (preferred at the 3.12+ +floor) or the equivalent module-level `TypeVar("T", float32, float64)` (accepted, +produces the same runtime objects). Value-constrained — not `bound=` — because each +use must resolve to exactly one listed type, which makes the dtype predicates +decidable and the variant set finite (eager precompilation possible). `bound=`-only +and unconstrained type variables are rejected with a clear message. + +### `ts.TypeVarType` + +A new `DataType` subclass carrying `name` and `constraints: tuple[ScalarType, ...]`. + +- Subclassing `DataType` lets it fit unchanged into `FieldType.dtype` (widened to + `ScalarType | ListType | TypeVarType`), `TupleType`/`NamedCollectionType` members, + and `foast.Symbol`. +- **Identity is the name**, scoped to one operator signature. Two *distinct* + same-named `TypeVar` objects in one signature are rejected at parse time (with PEP + 695 this is impossible by construction). 
As a frozen eve `DataModel` it gets + deterministic `eq`/`hash`/`content_hash` for cache keys. +- `ts.DeferredType` is **not** replaced. The two mechanisms coexist: `DeferredType` + means "not yet inferred" (and currently also encodes the scan operators' *dims* + genericity); `TypeVarType` means "universally quantified over the constraint set". + A single `type_info.is_generic` predicate recognizes both. + +### Decisions D1–D5 + +- **D1 — Decoration-time body checking with opaque `TypeVarType`.** The body is + type-checked once, at decoration time, with `T` treated as an opaque scalar. + Errors are reported in the user's vocabulary (`T`), not in instantiated terms. + (Rejected: skip-until-instantiation — breaks the decoration-time-errors UX; + finite monomorph-check — duplicates compile work and reports in the wrong + vocabulary. The finite check survives only as a test-suite cross-check.) +- **D2 — Value-constrained TypeVars only.** Finite variant set ⇒ decidable dtype + predicates and eager `.compile()` of all members. `bound=` is a future extension. +- **D3 — Strict no-promotion.** `promote(T, T) = T`; mixing `T` with a concrete + scalar/dtype (including literals: `a * 2.0`) is a decoration-time error naming the + type variable. `astype(x, T)` is the designated remediation (a fast-follow). We + pre-commit to strict-first rather than inheriting Numba-style silent promotion; + this is expected to be the main ergonomics complaint and is revisited via a named + "generic literals" follow-up, not by relaxing the default. +- **D4 — Monomorphize at FOAST level; never lower generic GTIR.** Specialization is + direct type substitution over the typed FOAST, with a full re-run of type + deduction as a soundness backstop under `__debug__`. (Rejected: lowering generic + FOAST and concretizing at GTIR level — `foast_to_gtir` bakes dtypes into literals + and casts; GTIR has no syntax for "dtype of param x".) 
+- **D5 — Binding is a first-class `type_info` utility.** + `bind_type_vars(params, args)` (structural match, consistency + exact-match + checks) and `substitute_type_vars(type_, binding)` (recursion over every TypeSpec). + `accepts_args` keeps its boolean interface; callers needing the binding use the + new API. + +### Monomorphization strategy + +- **Direct operator call with a backend:** `FieldOperator.__call__` → + `CompiledProgramsPool`. The pool already detects generic signatures + (`is_generic`), keys the cache on the full concrete substitution + (`arg_specialization_key`), and forwards concrete types as `CompileTimeArgs`. A + new `foast_specialize` toolchain step (after `func_to_foast`) computes the binding + and substitutes throughout the FOAST tree; everything downstream runs on a + concrete artifact. +- **Generic operator called from a concrete program:** the binding is fully static + at program decoration. The fieldop signature checks perform bind-and-substitute, and a + PAST monomorphization pass (run in `past_to_itir`) recomputes the binding from the + typed call-site args, name-mangles the callee per binding (e.g. + `diffusion__float32`), and swaps in a specialized callable via a new + `GTCallable.__gt_specialize__(binding)`. Two bindings of one operator naturally + become two GTIR `FunctionDefinition`s. +- **Embedded mode:** nearly free — the original Python definition runs on real + fields once decoration tolerates generic signatures. + +### Cache-key story + +The pool's `arg_specialization_key` hashes all argument types, so the full +substitution is in the key — distinct dtypes hit distinct variants. Value-constrained +TypeVars make eager precompilation of all variants possible via the existing +`.compile()` API. + +## Out of scope / deferred (with forward-compatibility notes) + +- **Generic scan operators** — rejected with a clear message (needs `init: T` + coercion semantics). Nothing in the utilities hardcodes `FieldOperatorType`. 
+- **`bound=` TypeVars** — infinite constraint sets; predicates by bound; no eager + precompile. +- **`astype(x, T)` / generic scalar constructors** — the D3 remediation; requires a + `ConstructorType` over `TypeVarType`. +- **Builtin coverage** — `where`, `broadcast`, reductions, `concat_where`, neighbor + fields are audited and widened incrementally; until then a generic argument to an + un-audited builtin is a clear decoration-time error (math builtins already work). +- **Dimension genericity** — a separate effort (the true fix for the scan + `DeferredType`/fabricated-`Dimension` hack). The binding utilities are kept + **dtype-scoped** here: `bind_type_vars`/`substitute_type_vars` map names to + `ScalarType` only, and same-name rejection is specified over dtype type variables + only. Widening the binding environment to dimensions and generalizing same-name + rejection across type-parameter kinds is explicitly deferred to that work. +- **PEP 696 dtype defaults** (unparameterized `Field` means `float64`) and + **mypy-plugin un-blurring** of `float32`/`float64` — later, coordinated with the + `Field` annotation cleanup (gt4py #1415/#1416). 
diff --git a/src/gt4py/next/ffront/foast_passes/type_deduction.py b/src/gt4py/next/ffront/foast_passes/type_deduction.py index 1b090489ff..2b9693cc1f 100644 --- a/src/gt4py/next/ffront/foast_passes/type_deduction.py +++ b/src/gt4py/next/ffront/foast_passes/type_deduction.py @@ -831,7 +831,7 @@ def _visit_math_built_in(self, node: foast.Call, **kwargs: Any) -> foast.Call: print(f"Warning: return type of '{func_name}' might be inconsistent (not implemented).") # deduce return type - return_type: Optional[ts.FieldType | ts.ScalarType] = None + return_type: Optional[ts.FieldType | ts.ScalarType | ts.TypeVarType] = None if ( func_name in fbuiltins.UNARY_MATH_NUMBER_BUILTIN_NAMES + fbuiltins.UNARY_MATH_FP_BUILTIN_NAMES diff --git a/src/gt4py/next/iterator/transforms/global_tmps.py b/src/gt4py/next/iterator/transforms/global_tmps.py index 49d2f3d5b4..7aa8aa6cd5 100644 --- a/src/gt4py/next/iterator/transforms/global_tmps.py +++ b/src/gt4py/next/iterator/transforms/global_tmps.py @@ -181,12 +181,16 @@ def _transform_by_pattern( lambda x: next(uids["__tmp"]), result_collection_constructor=as_tuple, )(tmp_expr.type) + # the lowered IR is concrete, so `extract_dtype` never yields a `TypeVarType` here tmp_dtypes: ( ts.ScalarType | ts.ListType | tuple[ts.ScalarType | ts.ListType | tuple, ...] - ) = type_info.tree_map_type( - type_info.extract_dtype, - result_collection_constructor=as_tuple, - )(tmp_expr.type) + ) = cast( + "ts.ScalarType | ts.ListType | tuple[ts.ScalarType | ts.ListType | tuple, ...]", + type_info.tree_map_type( + type_info.extract_dtype, + result_collection_constructor=as_tuple, + )(tmp_expr.type), + ) tmp_domains: SymbolicDomain | tuple[SymbolicDomain | tuple, ...] 
= tmp_expr.annex.domain diff --git a/src/gt4py/next/otf/compiled_program.py b/src/gt4py/next/otf/compiled_program.py index f09ae16bd9..64b73ab9f2 100644 --- a/src/gt4py/next/otf/compiled_program.py +++ b/src/gt4py/next/otf/compiled_program.py @@ -396,7 +396,8 @@ def __call__( # expensive type deduction for all arguments and not include it in the key. if enable_jit: warnings.warn( - "Calling generic programs / direct calls to scan operators are not optimized. " + "Calling generic programs / operators (e.g. operators with generic dtype " + "or direct calls to scan operators) is not optimized. " "Consider calling a specialized version instead.", stacklevel=3, ) @@ -459,19 +460,11 @@ def _is_generic(self) -> bool: Is the operator or program generic in the sense that it can be called for different argument types. - Right now this is only the case for scan operators. + Right now this is the case for scan operators (genericity communicated via + `DeferredType` parameters created in `type_info.type_in_program_context`) and for + operators with a generic dtype (type variable). """ - # TODO(tehrengruber): This concept does not exist elsewhere and is not properly reflected - # in the type system. For now we just use `DeferredType` to communicate between - # here and `type_info.type_in_program_context`. 
- return any( - isinstance(t, ts.DeferredType) - for t in itertools.chain( - self.program_type.definition.pos_only_args, - self.program_type.definition.pos_or_kw_args.values(), - self.program_type.definition.kw_only_args.values(), - ) - ) + return type_info.is_generic(self.program_type.definition) @functools.cached_property def _args_canonicalizer(self) -> Callable[..., tuple[tuple, dict[str, Any]]]: diff --git a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_dataflow.py b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_dataflow.py index 87a505313f..dfc7e3ab9b 100644 --- a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_dataflow.py +++ b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_dataflow.py @@ -100,7 +100,9 @@ class MemletExpr: @property def gt_dtype(self) -> ts.ScalarType | ts.ListType: - return self.gt_field.dtype + dtype = self.gt_field.dtype + assert isinstance(dtype, (ts.ScalarType, ts.ListType)) + return dtype def __post_init__(self) -> None: if isinstance(self.gt_dtype, ts.ListType): diff --git a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py index 223eff2d79..aacc0ed7bc 100644 --- a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py @@ -896,6 +896,7 @@ def _add_storage( dc_dtype = gtx_dace_args.as_dace_type(gt_type.dtype) all_dims = gt_type.dims else: # for 'ts.ListType' use 'offset_type' as local dimension + assert isinstance(gt_type.dtype, ts.ListType) assert gt_type.dtype.offset_type is not None assert gt_type.dtype.offset_type.kind == gtx_common.DimensionKind.LOCAL assert isinstance(gt_type.dtype.element_type, ts.ScalarType) diff --git a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_concat_where.py 
b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_concat_where.py index ded61af77b..fde144dedc 100644 --- a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_concat_where.py +++ b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_concat_where.py @@ -143,6 +143,7 @@ def testee(interior: cases.IJKField, boundary: cases.IJField) -> cases.IJKField: if isinstance(output_type.dtype, ts.ScalarType): all_dims = gtx_common.order_dimensions(output_type.dims) else: + assert isinstance(output_type.dtype, ts.ListType) assert output_type.dtype.offset_type all_dims = gtx_common.order_dimensions([*output_type.dims, output_type.dtype.offset_type]) diff --git a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_scan.py b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_scan.py index fddcb080b8..da1a813f38 100644 --- a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_scan.py +++ b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_scan.py @@ -622,7 +622,9 @@ def _handle_dataflow_result_of_nested_sdfg( None, dace.Memlet.from_array(outer_dataname, outer_desc), ) - output_expr = gtir_dataflow.ValueExpr(outer_node, inner_data.gt_type.dtype) + output_dtype = inner_data.gt_type.dtype + assert isinstance(output_dtype, (ts.ScalarType, ts.ListType)) + output_expr = gtir_dataflow.ValueExpr(outer_node, output_dtype) return gtir_dataflow.DataflowOutputEdge(outer_ctx.state, output_expr) diff --git a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_types.py b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_types.py index 83cd7660d8..665f35c148 100644 --- a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_types.py +++ b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_types.py @@ -86,9 +86,9 @@ def get_local_view( (dim, dace.symbolic.SymExpr(0) if self.origin is None else 
self.origin[i]) for i, dim in enumerate(self.gt_type.dims) ] - return gtir_dataflow.IteratorExpr( - self.dc_node, self.gt_type.dtype, field_origin, it_indices - ) + dtype = self.gt_type.dtype + assert isinstance(dtype, (ts.ScalarType, ts.ListType)) + return gtir_dataflow.IteratorExpr(self.dc_node, dtype, field_origin, it_indices) raise NotImplementedError(f"Node type {type(self.gt_type)} not supported.") diff --git a/src/gt4py/next/type_system/type_info.py b/src/gt4py/next/type_system/type_info.py index 25dce1c3f5..6e873514ab 100644 --- a/src/gt4py/next/type_system/type_info.py +++ b/src/gt4py/next/type_system/type_info.py @@ -54,6 +54,54 @@ def is_concrete(symbol_type: ts.TypeSpec) -> TypeGuard[ts.TypeSpec]: return False +def is_generic(symbol_type: ts.TypeSpec) -> bool: + """ + Figure out if a type contains parts that are only known when concrete arguments are given. + + A generic (callable) type can be called with arguments of varying types, e.g. the program + context signature of a scan operator. Contrary to :func:`is_concrete` this predicate + recurses into composite types. + + Note: this returns ``True`` for a bare ``astype`` constructor type, whose ``definition`` + carries a ``DeferredType`` by design; callers that only care about *data* arguments must + filter for ``ts.DataType`` themselves. 
+ + >>> is_generic(ts.DeferredType(constraint=None)) + True + + >>> bool_type = ts.ScalarType(kind=ts.ScalarKind.BOOL) + >>> is_generic(bool_type) + False + + >>> is_generic(ts.TupleType(types=[bool_type, ts.DeferredType(constraint=None)])) + True + """ + match symbol_type: + case ts.DeferredType() | ts.TypeVarType(): + return True + case ts.FieldType(dtype=dtype): + return is_generic(dtype) + case ts.ListType(element_type=element_type): + return is_generic(element_type) + case ts.TupleType(types=types) | ts.NamedCollectionType(types=types): + return any(is_generic(t) for t in types) + case ts.FunctionType(): + return any( + is_generic(t) + for t in ( + *symbol_type.pos_only_args, + *symbol_type.pos_or_kw_args.values(), + *symbol_type.kw_only_args.values(), + symbol_type.returns, + ) + ) + # callable type wrappers (e.g. the field operator types in `ffront`) carry their + # signature in a `definition` attribute + if isinstance(definition := getattr(symbol_type, "definition", None), ts.TypeSpec): + return is_generic(definition) + return False + + def type_class(symbol_type: ts.TypeSpec) -> Type[ts.TypeSpec]: """ Determine which class should be used to create a compatible concrete type. @@ -181,9 +229,9 @@ def tree_map_type( ) -def extract_dtype(symbol_type: ts.TypeSpec) -> ts.ScalarType | ts.ListType: +def extract_dtype(symbol_type: ts.TypeSpec) -> ts.ScalarType | ts.ListType | ts.TypeVarType: """ - Extract the data type from ``symbol_type`` if it is either `FieldType` or `ScalarType`. + Extract the data type from ``symbol_type`` if it is `FieldType`, `ScalarType` or `TypeVarType`. Raise an error if no dtype can be found or the result would be ambiguous. 
@@ -201,6 +249,8 @@ def extract_dtype(symbol_type: ts.TypeSpec) -> ts.ScalarType | ts.ListType: return dtype case ts.ScalarType() as dtype: return dtype + case ts.TypeVarType() as dtype: + return dtype raise ValueError(f"Can not unambiguosly extract data type from '{symbol_type}'.") @@ -215,10 +265,17 @@ def _scalar_kinds(scalar_types: tuple[type, ...]) -> frozenset[ts.ScalarKind]: def _is_field_or_scalar_of_kind(symbol_type: ts.TypeSpec, kinds: Collection[ts.ScalarKind]) -> bool: - """Check if ``symbol_type`` is a scalar or a field whose dtype kind is in ``kinds``.""" + """Check if ``symbol_type`` is a scalar or a field whose dtype kind is in ``kinds``. + + A type variable has the property iff all of its constraints have it. + """ + if isinstance(symbol_type, ts.TypeVarType): + return all(_is_field_or_scalar_of_kind(c, kinds) for c in symbol_type.constraints) if not isinstance(symbol_type, (ts.ScalarType, ts.FieldType)): return False dtype = extract_dtype(symbol_type) + if isinstance(dtype, ts.TypeVarType): + return all(_is_field_or_scalar_of_kind(c, kinds) for c in dtype.constraints) return isinstance(dtype, ts.ScalarType) and dtype.kind in kinds @@ -289,7 +346,7 @@ def is_arithmetic_scalar(symbol_type: ts.TypeSpec) -> bool: ... ) False """ - if not isinstance(symbol_type, ts.ScalarType): + if not isinstance(symbol_type, (ts.ScalarType, ts.TypeVarType)): return False return is_arithmetic(symbol_type) @@ -323,6 +380,15 @@ def is_arithmetic(symbol_type: ts.TypeSpec) -> bool: >>> is_arithmetic(ts.FieldType(dims=[], dtype=ts.ScalarType(kind=ts.ScalarKind.INT32))) True """ + # `is_arithmetic` cannot reuse `_is_field_or_scalar_of_kind`'s "all constraints + # share the kind" rule: a type variable is arithmetic if every constraint is + # arithmetic, even when the constraints mix floating point and integral kinds. 
+ if isinstance(symbol_type, ts.TypeVarType): + return all(is_arithmetic(c) for c in symbol_type.constraints) + if isinstance(symbol_type, (ts.ScalarType, ts.FieldType)) and isinstance( + dtype := extract_dtype(symbol_type), ts.TypeVarType + ): + return is_arithmetic(dtype) return is_floating_point(symbol_type) or is_integral(symbol_type) @@ -398,7 +464,7 @@ def extract_dims(symbol_type: ts.TypeSpec) -> list[common.Dimension]: >>> extract_dims(ts.FieldType(dims=[I, J], dtype=ts.ScalarType(kind=ts.ScalarKind.INT64))) [Dimension(value='I', kind=), Dimension(value='J', kind=)] """ - if isinstance(symbol_type, ts.ScalarType): + if isinstance(symbol_type, (ts.ScalarType, ts.TypeVarType)): return [] if isinstance(symbol_type, ts.FieldType): return symbol_type.dims @@ -558,9 +624,130 @@ def is_concretizable(symbol_type: ts.TypeSpec, to_type: ts.TypeSpec) -> bool: return False +def bind_type_vars( + params: Sequence[ts.TypeSpec], args: Sequence[ts.TypeSpec] +) -> dict[str, ts.ScalarType]: + """ + Compute a binding of all type variables in ``params`` by structurally matching ``args``. + + Concrete (non-generic) parts of the parameters are ignored here; mismatches in those are + reported by the regular signature checks. A type variable position only binds if the + corresponding argument provides a concrete scalar dtype; the caller is responsible for + checking that no type variable remained unbound (if required). + + Note: the binding is intentionally ``dtype``-scoped (scalar type variables only); a future + extension for dimension variables (see the dimension-generics design) will widen it. + + Raises: + ValueError: If a type variable would be bound inconsistently or to a dtype that is + not one of its constraints. + + >>> var = ts.TypeVarType(name="T", constraints=(ts.ScalarType(kind=ts.ScalarKind.FLOAT64),)) + >>> I = common.Dimension(value="I") + >>> binding = bind_type_vars( + ... [ts.FieldType(dims=[I], dtype=var)], + ... 
[ts.FieldType(dims=[I], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64))], + ... ) + >>> print(binding["T"]) + float64 + """ + binding: dict[str, ts.ScalarType] = {} + + def bind_var(var: ts.TypeVarType, dtype: ts.TypeSpec) -> None: + if not isinstance(dtype, ts.ScalarType): + return # no concrete dtype available, leave unbound + if dtype not in var.constraints: + raise ValueError( + f"'{dtype}' does not satisfy the constraints of type variable '{var}'." + ) + if (previous := binding.get(var.name)) is not None and previous != dtype: + raise ValueError( + f"Type variable '{var.name}' is bound inconsistently:" + f" '{previous}' and '{dtype}' (all arguments using '{var.name}'" + " must have the same dtype)." + ) + binding[var.name] = dtype + + def bind(param: ts.TypeSpec, arg: ts.TypeSpec) -> None: + match param: + case ts.TypeVarType(): + bind_var(param, arg) + case ts.FieldType(dtype=ts.TypeVarType() as var): + if isinstance(arg, ts.FieldType): + bind_var(var, arg.dtype) + elif isinstance(arg, ts.ScalarType): + # scalar arguments are promoted to zero-dimensional fields + bind_var(var, arg) + case ts.ListType(element_type=element_type) if isinstance(arg, ts.ListType): + bind(element_type, arg.element_type) + case ts.TupleType() | ts.NamedCollectionType() if isinstance( + arg, (ts.TupleType, ts.NamedCollectionType) + ): + for param_el, arg_el in zip(param.types, arg.types): + bind(param_el, arg_el) + + for param, arg in zip(params, args): + bind(param, arg) + return binding + + +def substitute_type_vars( + type_: ts.TypeSpec, binding: xtyping.Mapping[str, ts.ScalarType] +) -> ts.TypeSpec: + """ + Replace all type variables in ``type_`` that are bound in ``binding``. + + Unbound type variables and all other generic parts (e.g. `DeferredType`) are kept as-is. + + >>> var = ts.TypeVarType(name="T", constraints=(ts.ScalarType(kind=ts.ScalarKind.FLOAT64),)) + >>> I = common.Dimension(value="I") + >>> print( + ... substitute_type_vars( + ... 
ts.FieldType(dims=[I], dtype=var), + ... {"T": ts.ScalarType(kind=ts.ScalarKind.FLOAT64)}, + ... ) + ... ) + Field[[I], float64] + """ + if not binding or not is_generic(type_): + return type_ + match type_: + case ts.TypeVarType(name=name): + return binding.get(name, type_) + case ts.FieldType(dims=dims, dtype=dtype): + new_dtype = substitute_type_vars(dtype, binding) + assert isinstance(new_dtype, (ts.ScalarType, ts.ListType, ts.TypeVarType)) + return ts.FieldType(dims=dims, dtype=new_dtype) + case ts.ListType(element_type=element_type, offset_type=offset_type): + new_element_type = substitute_type_vars(element_type, binding) + assert isinstance(new_element_type, ts.DataType) + return ts.ListType(element_type=new_element_type, offset_type=offset_type) + case ts.TupleType(types=types): + return ts.TupleType(types=[substitute_type_vars(t, binding) for t in types]) + case ts.NamedCollectionType(types=types): + return ts.NamedCollectionType( + types=[substitute_type_vars(t, binding) for t in types], + keys=type_.keys, + original_python_type=type_.original_python_type, + ) + case ts.FunctionType(): + return ts.FunctionType( + pos_only_args=[substitute_type_vars(t, binding) for t in type_.pos_only_args], + pos_or_kw_args={ + name: substitute_type_vars(t, binding) + for name, t in type_.pos_or_kw_args.items() + }, + kw_only_args={ + name: substitute_type_vars(t, binding) for name, t in type_.kw_only_args.items() + }, + returns=substitute_type_vars(type_.returns, binding), + ) + return type_ + + def promote( - *types: ts.FieldType | ts.ScalarType, always_field: bool = False -) -> ts.FieldType | ts.ScalarType: + *types: ts.FieldType | ts.ScalarType | ts.TypeVarType, always_field: bool = False +) -> ts.FieldType | ts.ScalarType | ts.TypeVarType: """ Promote a set of field or scalar types to a common type. 
@@ -582,17 +769,29 @@ def promote( >>> promoted.dims == [I, J, K] and promoted.dtype == dtype True """ - if not always_field and all(isinstance(type_, ts.ScalarType) for type_ in types): + if not always_field and all( + isinstance(type_, (ts.ScalarType, ts.TypeVarType)) for type_ in types + ): if not all(type_ == types[0] for type_ in types): + if any(isinstance(type_, ts.TypeVarType) for type_ in types): + distinct_types = "', '".join(str(t) for t in dict.fromkeys(types)) + raise ValueError( + f"Could not promote '{distinct_types}': a generic dtype (type variable)" + " can only be combined with values of the same type variable," + " not with other dtypes." + ) raise ValueError("Could not promote scalars of different dtype (not implemented).") - if not all(type_.shape is None for type_ in types): # type: ignore[union-attr] + if not all(type_.shape is None for type_ in types if isinstance(type_, ts.ScalarType)): raise NotImplementedError("Shape promotion not implemented.") return types[0] - elif all(isinstance(type_, (ts.ScalarType, ts.FieldType)) for type_ in types): + elif all(isinstance(type_, (ts.ScalarType, ts.FieldType, ts.TypeVarType)) for type_ in types): dims = common.promote_dims(*(extract_dims(type_) for type_ in types)) extracted_dtypes = [extract_dtype(type_) for type_ in types] - assert all(isinstance(dtype, ts.ScalarType) for dtype in extracted_dtypes) - dtype = cast(ts.ScalarType, promote(*extracted_dtypes)) # type: ignore[arg-type] # checked is `ScalarType` + assert all(isinstance(dtype, (ts.ScalarType, ts.TypeVarType)) for dtype in extracted_dtypes) + dtype = cast( # type variables promote like scalars (only with themselves) + ts.ScalarType | ts.TypeVarType, + promote(*extracted_dtypes), # type: ignore[arg-type] # checked above + ) return ts.FieldType(dims=dims, dtype=dtype) raise TypeError("Expected a 'FieldType' or 'ScalarType'.") diff --git a/src/gt4py/next/type_system/type_specifications.py b/src/gt4py/next/type_system/type_specifications.py 
index 59ac40f0f3..c7e2a4821b 100644 --- a/src/gt4py/next/type_system/type_specifications.py +++ b/src/gt4py/next/type_system/type_specifications.py @@ -105,6 +105,32 @@ def __str__(self) -> str: return f"{kind_str}{self.shape}" +class TypeVarType(DataType): + """ + A scalar type variable, universally quantified over its constraints. + + Represents the type of a value-constrained Python ``typing.TypeVar`` (e.g. + ``TypeVar("T", float32, float64)``) used in the signature of a generic operator. + Two occurrences with the same ``name`` within one signature denote the same type. + """ + + name: str + constraints: tuple[ScalarType, ...] + + def __str__(self) -> str: + return f"{self.name}: ({' | '.join(map(str, self.constraints))})" + + @eve_datamodels.validator("constraints") + def _constraints_validator( + self, attribute: eve_datamodels.Attribute, constraints: tuple[ScalarType, ...] + ) -> None: + if not constraints: + raise ValueError( + f"Type variable '{self.name}' must be value-constrained, i.e. have at" + " least one constraint." + ) + + class ListType(DataType): """Represents a neighbor list in the ITIR representation. @@ -119,7 +145,7 @@ class ListType(DataType): class FieldType(DataType, CallableType): dims: list[common.Dimension] - dtype: ScalarType | ListType + dtype: ScalarType | ListType | TypeVarType def __str__(self) -> str: dims = "..." 
if self.dims is Ellipsis else f"[{', '.join(dim.value for dim in self.dims)}]" diff --git a/tests/next_tests/unit_tests/type_system_tests/test_type_info.py b/tests/next_tests/unit_tests/type_system_tests/test_type_info.py index 35c3d2eba1..981423fbc9 100644 --- a/tests/next_tests/unit_tests/type_system_tests/test_type_info.py +++ b/tests/next_tests/unit_tests/type_system_tests/test_type_info.py @@ -373,6 +373,206 @@ def test_type_info_basic(symbol_type, expected): assert getattr(type_info, key)(symbol_type) == expected[key] +def is_generic_cases() -> list[tuple[ts.TypeSpec, bool]]: + deferred_type = ts.DeferredType(constraint=None) + float_type = ts.ScalarType(kind=ts.ScalarKind.FLOAT64) + concrete_field_type = ts.FieldType(dims=[TDim], dtype=float_type) + + def function_type(params: list[ts.TypeSpec]) -> ts.FunctionType: + return ts.FunctionType( + pos_only_args=[], + pos_or_kw_args={f"arg{i}": param for i, param in enumerate(params)}, + kw_only_args={}, + returns=ts.VoidType(), + ) + + return [ + (deferred_type, True), + (float_type, False), + (concrete_field_type, False), + (ts.TupleType(types=[float_type, concrete_field_type]), False), + # `DeferredType` nested inside a composite type, e.g. 
the program context signature + # of a scan operator with tuple arguments + (ts.TupleType(types=[float_type, deferred_type]), True), + (function_type([concrete_field_type]), False), + (function_type([deferred_type]), True), + (function_type([ts.TupleType(types=[deferred_type])]), True), + ( + ts_ffront.ProgramType(definition=function_type([deferred_type])), + True, + ), + ( + ts_ffront.FieldOperatorType(definition=function_type([concrete_field_type])), + False, + ), + ] + + +@pytest.mark.parametrize("symbol_type,expected", is_generic_cases()) +def test_is_generic(symbol_type: ts.TypeSpec, expected: bool): + assert type_info.is_generic(symbol_type) == expected + + +float32_type = ts.ScalarType(kind=ts.ScalarKind.FLOAT32) +float64_type = ts.ScalarType(kind=ts.ScalarKind.FLOAT64) +int32_type = ts.ScalarType(kind=ts.ScalarKind.INT32) +bool_type = ts.ScalarType(kind=ts.ScalarKind.BOOL) +float_var = ts.TypeVarType(name="T", constraints=(float32_type, float64_type)) +mixed_var = ts.TypeVarType(name="U", constraints=(float64_type, int32_type)) + + +class TestTypeVarType: + def test_validation(self): + with pytest.raises(ValueError, match="value-constrained"): + ts.TypeVarType(name="T", constraints=()) + + def test_identity_and_hashing(self): + from gt4py.eve import utils as eve_utils + + same_var = ts.TypeVarType(name="T", constraints=(float32_type, float64_type)) + assert float_var == same_var + assert hash(float_var) == hash(same_var) + assert eve_utils.content_hash(float_var) == eve_utils.content_hash(same_var) + assert float_var != ts.TypeVarType(name="S", constraints=(float32_type, float64_type)) + # constraint order is part of the identity (preserved as written) + assert float_var != ts.TypeVarType(name="T", constraints=(float64_type, float32_type)) + + def test_is_generic(self): + assert type_info.is_generic(float_var) + assert type_info.is_generic(ts.FieldType(dims=[TDim], dtype=float_var)) + assert type_info.is_generic( + ts.TupleType(types=[float64_type, 
ts.FieldType(dims=[TDim], dtype=float_var)]) + ) + + @pytest.mark.parametrize( + "predicate,var,expected", + [ + (type_info.is_floating_point, float_var, True), + (type_info.is_floating_point, mixed_var, False), + (type_info.is_integral, float_var, False), + (type_info.is_integral, ts.TypeVarType(name="I", constraints=(int32_type,)), True), + (type_info.is_arithmetic, float_var, True), + (type_info.is_arithmetic, mixed_var, True), + ( + type_info.is_arithmetic, + ts.TypeVarType(name="B", constraints=(bool_type, float64_type)), + False, + ), + (type_info.is_logical, ts.TypeVarType(name="B", constraints=(bool_type,)), True), + (type_info.is_logical, float_var, False), + (type_info.is_arithmetic_scalar, float_var, True), + ], + ) + def test_predicates_evaluate_over_constraints(self, predicate, var, expected): + assert predicate(var) == expected + if predicate is not type_info.is_arithmetic_scalar: # rejects fields by design + assert predicate(ts.FieldType(dims=[TDim], dtype=var)) == expected + + def test_promote_same_var(self): + assert type_info.promote(float_var, float_var) == float_var + promoted = type_info.promote( + ts.FieldType(dims=[TDim], dtype=float_var), ts.FieldType(dims=[TDim], dtype=float_var) + ) + assert promoted == ts.FieldType(dims=[TDim], dtype=float_var) + + def test_promote_var_with_scalar_arg(self): + promoted = type_info.promote(ts.FieldType(dims=[TDim], dtype=float_var), float_var) + assert promoted == ts.FieldType(dims=[TDim], dtype=float_var) + + @pytest.mark.parametrize( + "types", + [ + (float_var, float64_type), + (float_var, mixed_var), + (ts.FieldType(dims=[TDim], dtype=float_var), float64_type), + ( + ts.FieldType(dims=[TDim], dtype=float_var), + ts.FieldType(dims=[TDim], dtype=float64_type), + ), + ], + ) + def test_promote_mixing_error(self, types): + with pytest.raises(ValueError, match="type variable"): + type_info.promote(*types) + + +class TestBindTypeVars: + def test_bind_from_field(self): + binding = type_info.bind_type_vars( + 
[ts.FieldType(dims=[TDim], dtype=float_var)], + [ts.FieldType(dims=[TDim], dtype=float32_type)], + ) + assert binding == {"T": float32_type} + + def test_bind_from_scalar_and_nested(self): + binding = type_info.bind_type_vars( + [ts.TupleType(types=[float_var, ts.FieldType(dims=[TDim], dtype=float_var)])], + [ts.TupleType(types=[float64_type, ts.FieldType(dims=[TDim], dtype=float64_type)])], + ) + assert binding == {"T": float64_type} + + def test_concrete_params_dont_bind(self): + assert ( + type_info.bind_type_vars( + [ts.FieldType(dims=[TDim], dtype=float64_type)], + [ts.FieldType(dims=[TDim], dtype=float32_type)], + ) + == {} + ) + + def test_inconsistent_binding(self): + with pytest.raises(ValueError, match="bound inconsistently"): + type_info.bind_type_vars( + [ + ts.FieldType(dims=[TDim], dtype=float_var), + ts.FieldType(dims=[TDim], dtype=float_var), + ], + [ + ts.FieldType(dims=[TDim], dtype=float32_type), + ts.FieldType(dims=[TDim], dtype=float64_type), + ], + ) + + def test_constraint_violation(self): + with pytest.raises(ValueError, match="constraints"): + type_info.bind_type_vars( + [ts.FieldType(dims=[TDim], dtype=float_var)], + [ts.FieldType(dims=[TDim], dtype=int32_type)], + ) + + +class TestSubstituteTypeVars: + def test_substitute(self): + generic = ts.TupleType( + types=[float_var, ts.FieldType(dims=[TDim], dtype=float_var), int32_type] + ) + substituted = type_info.substitute_type_vars(generic, {"T": float32_type}) + assert substituted == ts.TupleType( + types=[float32_type, ts.FieldType(dims=[TDim], dtype=float32_type), int32_type] + ) + assert not type_info.is_generic(substituted) + + def test_unbound_vars_are_kept(self): + generic = ts.FieldType(dims=[TDim], dtype=float_var) + assert type_info.substitute_type_vars(generic, {"S": float32_type}) == generic + + def test_concrete_is_returned_unchanged(self): + concrete = ts.FieldType(dims=[TDim], dtype=float64_type) + assert type_info.substitute_type_vars(concrete, {"T": float32_type}) is 
concrete + + def test_substitute_function_type(self): + func_type = ts.FunctionType( + pos_only_args=[ts.FieldType(dims=[TDim], dtype=float_var)], + pos_or_kw_args={"a": float_var}, + kw_only_args={}, + returns=ts.FieldType(dims=[TDim], dtype=float_var), + ) + substituted = type_info.substitute_type_vars(func_type, {"T": float64_type}) + assert substituted.pos_only_args[0] == ts.FieldType(dims=[TDim], dtype=float64_type) + assert substituted.pos_or_kw_args["a"] == float64_type + assert substituted.returns == ts.FieldType(dims=[TDim], dtype=float64_type) + + @pytest.mark.parametrize("func_type,args,kwargs,expected,return_type", callable_type_info_cases()) def test_accept_args( func_type: ts.TypeSpec, From 378e6a0cf5b330c1c9a0257af39c7d400e88a36f Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Wed, 17 Jun 2026 11:42:12 +0200 Subject: [PATCH 2/9] refactor[next]: address review on dtype-generic type system - canonicalize TypeVarType.constraints (order carries no meaning for a value-constrained TypeVar) so the type's identity is order-insensitive - unify the type-variable traversals: add _type_params / tree_map_type_params and reexpress is_generic and substitute_type_vars on them; reword the is_generic docstring (deep check, not the negation of is_concrete) - document the non-scalar case in bind_var and the dtype-union narrowing assert in the dace lowering - ADR 0023: tighten, rename the example, drop the external design-investigation reference, and add it to the ADR index under a new Type System section --- .../ADRs/next/0023-Dtype-Generic-Operators.md | 27 ++-- docs/development/ADRs/next/README.md | 4 + .../dace/lowering/gtir_to_sdfg_types.py | 3 + src/gt4py/next/type_system/type_info.py | 139 ++++++++++-------- .../next/type_system/type_specifications.py | 8 +- .../type_system_tests/test_type_info.py | 4 +- 6 files changed, 107 insertions(+), 78 deletions(-) diff --git a/docs/development/ADRs/next/0023-Dtype-Generic-Operators.md 
b/docs/development/ADRs/next/0023-Dtype-Generic-Operators.md index c0dfba4733..33454e033e 100644 --- a/docs/development/ADRs/next/0023-Dtype-Generic-Operators.md +++ b/docs/development/ADRs/next/0023-Dtype-Generic-Operators.md @@ -19,7 +19,7 @@ FloatT = typing.TypeVar("FloatT", gtx.float32, gtx.float64) @gtx.field_operator -def diffusion( +def diff( a: gtx.Field[gtx.Dims[I, J], FloatT], b: gtx.Field[gtx.Dims[I, J], FloatT] ) -> gtx.Field[gtx.Dims[I, J], FloatT]: return a - b @@ -29,19 +29,16 @@ def diffusion( `common.Field` is already a runtime-introspectable generic protocol, so `Field[Dims[I, J], T]` with a value-constrained `TypeVar` is a valid, mypy-visible -annotation today. What was missing is the DSL side: translating such an annotation -into the internal type system and type-checking/lowering operators that use it. The -internal type system had `DeferredType` ("some type, maybe constrained") but no -notion of *identity* — it could not express "the *same* unknown dtype in two -parameters and the return type", which is the essence of generics. The runtime -monomorphization machinery, on the other hand, already existed (grown for scan -operators): `CompiledProgramsPool` keys a per-call specialization cache on the -concrete argument types. - -Prior art (numpy.typing, jaxtyping, Numba, Taichi, Triton, DaCe, two-level type -theory) converges on the same two choices adopted here: a real generic annotation -that static checkers can see, and monomorphization at call time. See the design -investigation for the full survey. +annotation today; what was missing is the DSL side. The internal type system had +`DeferredType` ("some type, maybe constrained") but no notion of *identity* — it +could not express "the *same* unknown dtype in two parameters and the return type", +the essence of generics. The runtime monomorphization machinery already existed +(grown for scan operators): `CompiledProgramsPool` keys a per-call specialization +cache on the concrete argument types. 
+ +Prior art (numpy.typing, jaxtyping, Numba, Taichi, Triton, DaCe) converges on the +two choices adopted here: a real generic annotation that static checkers can see, +and monomorphization at call time. ## Decision @@ -111,7 +108,7 @@ A new `DataType` subclass carrying `name` and `constraints: tuple[ScalarType, .. at program decoration. The fieldop signature checks bind-and-substitute, and a PAST monomorphization pass (run in `past_to_itir`) recomputes the binding from the typed call-site args, name-mangles the callee per binding (e.g. - `diffusion__float32`), and swaps in a specialized callable via a new + `diff__float32`), and swaps in a specialized callable via a new `GTCallable.__gt_specialize__(binding)`. Two bindings of one operator naturally become two GTIR `FunctionDefinition`s. - **Embedded mode:** nearly free — the original Python definition runs on real diff --git a/docs/development/ADRs/next/README.md b/docs/development/ADRs/next/README.md index 6a133aba75..8a153eacd4 100644 --- a/docs/development/ADRs/next/README.md +++ b/docs/development/ADRs/next/README.md @@ -25,6 +25,10 @@ Writing a new ADR is simple: - [0010 - Domain in Field View](0010-Domain_in_Field_View.md) - [0013 - Scalar vs 0d-Fields](0013-Scalar_vs_0d_Fields.md) +### Type System + +- [0023 - Dtype-Generic Operators](0023-Dtype-Generic-Operators.md) + ### Iterator IR #iterator - [0003 - Iterator View Tuple Support for Fields](0003-Iterator_View_Tuple_Support_for_Fields.md) diff --git a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_types.py b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_types.py index 665f35c148..b220732756 100644 --- a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_types.py +++ b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_types.py @@ -87,6 +87,9 @@ def get_local_view( for i, dim in enumerate(self.gt_type.dims) ] dtype = self.gt_type.dtype + # `FieldType.dtype` is widened to 
include `TypeVarType`, but generic operators are + # monomorphized before lowering, so only concrete dtypes reach here. Removing this + # (and the sibling asserts) would need a non-generic field type; deferred for now. assert isinstance(dtype, (ts.ScalarType, ts.ListType)) return gtir_dataflow.IteratorExpr(self.dc_node, dtype, field_origin, it_indices) diff --git a/src/gt4py/next/type_system/type_info.py b/src/gt4py/next/type_system/type_info.py index 6e873514ab..b95b9a95bf 100644 --- a/src/gt4py/next/type_system/type_info.py +++ b/src/gt4py/next/type_system/type_info.py @@ -54,13 +54,43 @@ def is_concrete(symbol_type: ts.TypeSpec) -> TypeGuard[ts.TypeSpec]: return False +def _type_params(symbol_type: ts.TypeSpec) -> tuple[ts.TypeSpec, ...]: + """The immediate type-parameter sub-types of ``symbol_type``. + + These are its dtype, element type, tuple elements, or function argument / return types -- + one definition of where type parameters live, shared by the recursive type-variable + traversals (`is_generic` here, `substitute_type_vars` via `tree_map_type_params`). + """ + match symbol_type: + case ts.FieldType(dtype=dtype): + return (dtype,) + case ts.ListType(element_type=element_type): + return (element_type,) + case ts.TupleType(types=types) | ts.NamedCollectionType(types=types): + return tuple(types) + case ts.FunctionType(): + return ( + *symbol_type.pos_only_args, + *symbol_type.pos_or_kw_args.values(), + *symbol_type.kw_only_args.values(), + symbol_type.returns, + ) + # callable type wrappers (e.g. the field operator types in `ffront`) carry their + # signature in a `definition` attribute + if isinstance(definition := getattr(symbol_type, "definition", None), ts.TypeSpec): + return (definition,) + return () + + def is_generic(symbol_type: ts.TypeSpec) -> bool: """ Figure out if a type contains parts that are only known when concrete arguments are given. A generic (callable) type can be called with arguments of varying types, e.g. 
the program - context signature of a scan operator. Contrary to :func:`is_concrete` this predicate - recurses into composite types. + context signature of a scan operator. This recurses into composite types, reporting ``True`` + if any nested part is a `DeferredType` or `TypeVarType`. It is *not* the negation of + :func:`is_concrete`: the latter is a shallow deduction-progress flag (top-level + `DeferredType` only), so a tuple with a nested `DeferredType` is both concrete and generic. Note: this returns ``True`` for a bare ``astype`` constructor type, whose ``definition`` carries a ``DeferredType`` by design; callers that only care about *data* arguments must @@ -76,30 +106,9 @@ def is_generic(symbol_type: ts.TypeSpec) -> bool: >>> is_generic(ts.TupleType(types=[bool_type, ts.DeferredType(constraint=None)])) True """ - match symbol_type: - case ts.DeferredType() | ts.TypeVarType(): - return True - case ts.FieldType(dtype=dtype): - return is_generic(dtype) - case ts.ListType(element_type=element_type): - return is_generic(element_type) - case ts.TupleType(types=types) | ts.NamedCollectionType(types=types): - return any(is_generic(t) for t in types) - case ts.FunctionType(): - return any( - is_generic(t) - for t in ( - *symbol_type.pos_only_args, - *symbol_type.pos_or_kw_args.values(), - *symbol_type.kw_only_args.values(), - symbol_type.returns, - ) - ) - # callable type wrappers (e.g. the field operator types in `ffront`) carry their - # signature in a `definition` attribute - if isinstance(definition := getattr(symbol_type, "definition", None), ts.TypeSpec): - return is_generic(definition) - return False + if isinstance(symbol_type, (ts.DeferredType, ts.TypeVarType)): + return True + return any(is_generic(p) for p in _type_params(symbol_type)) def type_class(symbol_type: ts.TypeSpec) -> Type[ts.TypeSpec]: @@ -636,7 +645,7 @@ def bind_type_vars( checking that no type variable remained unbound (if required). 
Note: the binding is intentionally ``dtype``-scoped (scalar type variables only); a future - extension for dimension variables (see the dimension-generics design) will widen it. + extension for dimension variables will widen it. Raises: ValueError: If a type variable would be bound inconsistently or to a dtype that is @@ -655,7 +664,10 @@ def bind_type_vars( def bind_var(var: ts.TypeVarType, dtype: ts.TypeSpec) -> None: if not isinstance(dtype, ts.ScalarType): - return # no concrete dtype available, leave unbound + # not a concrete scalar to bind to -- e.g. a `TypeVarType` (operator-from-operator + # call), a `DeferredType` (scan), or a `ListType` (local field). Leave it unbound; + # the caller is responsible for checking that no type variable remained unbound. + return if dtype not in var.constraints: raise ValueError( f"'{dtype}' does not satisfy the constraints of type variable '{var}'." @@ -691,6 +703,42 @@ def bind(param: ts.TypeSpec, arg: ts.TypeSpec) -> None: return binding +def tree_map_type_params( + fun: Callable[[ts.TypeSpec], ts.TypeSpec], symbol_type: ts.TypeSpec +) -> ts.TypeSpec: + """Rebuild ``symbol_type`` applying ``fun`` to each immediate type-parameter sub-type. + + The counterpart of `tree_map_type` for the type-parameter positions enumerated by + `_type_params` (dtype, element type, function argument / return types), which `tree_map_type` + -- a *collection* map -- does not descend into. Leaf types are returned unchanged. 
+ """ + match symbol_type: + case ts.FieldType(dims=dims, dtype=dtype): + new_dtype = fun(dtype) + assert isinstance(new_dtype, (ts.ScalarType, ts.ListType, ts.TypeVarType)) + return ts.FieldType(dims=dims, dtype=new_dtype) + case ts.ListType(element_type=element_type, offset_type=offset_type): + new_element_type = fun(element_type) + assert isinstance(new_element_type, ts.DataType) + return ts.ListType(element_type=new_element_type, offset_type=offset_type) + case ts.TupleType(types=types): + return ts.TupleType(types=[fun(t) for t in types]) + case ts.NamedCollectionType(types=types): + return ts.NamedCollectionType( + types=[fun(t) for t in types], + keys=symbol_type.keys, + original_python_type=symbol_type.original_python_type, + ) + case ts.FunctionType(): + return ts.FunctionType( + pos_only_args=[fun(t) for t in symbol_type.pos_only_args], + pos_or_kw_args={name: fun(t) for name, t in symbol_type.pos_or_kw_args.items()}, + kw_only_args={name: fun(t) for name, t in symbol_type.kw_only_args.items()}, + returns=fun(symbol_type.returns), + ) + return symbol_type + + def substitute_type_vars( type_: ts.TypeSpec, binding: xtyping.Mapping[str, ts.ScalarType] ) -> ts.TypeSpec: @@ -711,38 +759,9 @@ def substitute_type_vars( """ if not binding or not is_generic(type_): return type_ - match type_: - case ts.TypeVarType(name=name): - return binding.get(name, type_) - case ts.FieldType(dims=dims, dtype=dtype): - new_dtype = substitute_type_vars(dtype, binding) - assert isinstance(new_dtype, (ts.ScalarType, ts.ListType, ts.TypeVarType)) - return ts.FieldType(dims=dims, dtype=new_dtype) - case ts.ListType(element_type=element_type, offset_type=offset_type): - new_element_type = substitute_type_vars(element_type, binding) - assert isinstance(new_element_type, ts.DataType) - return ts.ListType(element_type=new_element_type, offset_type=offset_type) - case ts.TupleType(types=types): - return ts.TupleType(types=[substitute_type_vars(t, binding) for t in types]) - case 
ts.NamedCollectionType(types=types): - return ts.NamedCollectionType( - types=[substitute_type_vars(t, binding) for t in types], - keys=type_.keys, - original_python_type=type_.original_python_type, - ) - case ts.FunctionType(): - return ts.FunctionType( - pos_only_args=[substitute_type_vars(t, binding) for t in type_.pos_only_args], - pos_or_kw_args={ - name: substitute_type_vars(t, binding) - for name, t in type_.pos_or_kw_args.items() - }, - kw_only_args={ - name: substitute_type_vars(t, binding) for name, t in type_.kw_only_args.items() - }, - returns=substitute_type_vars(type_.returns, binding), - ) - return type_ + if isinstance(type_, ts.TypeVarType): + return binding.get(type_.name, type_) + return tree_map_type_params(lambda t: substitute_type_vars(t, binding), type_) def promote( diff --git a/src/gt4py/next/type_system/type_specifications.py b/src/gt4py/next/type_system/type_specifications.py index c7e2a4821b..8450bbb7d6 100644 --- a/src/gt4py/next/type_system/type_specifications.py +++ b/src/gt4py/next/type_system/type_specifications.py @@ -105,6 +105,12 @@ def __str__(self) -> str: return f"{kind_str}{self.shape}" +def _canonicalize_constraints(constraints: Sequence[ScalarType]) -> tuple[ScalarType, ...]: + # A value-constrained type variable resolves to exactly one of its constraints, so their + # order carries no meaning; canonicalize it to make `TypeVarType` identity order-insensitive. + return tuple(sorted(constraints, key=lambda c: c.kind)) + + class TypeVarType(DataType): """ A scalar type variable, universally quantified over its constraints. @@ -115,7 +121,7 @@ class TypeVarType(DataType): """ name: str - constraints: tuple[ScalarType, ...] + constraints: tuple[ScalarType, ...] 
= eve_datamodels.field(converter=_canonicalize_constraints) def __str__(self) -> str: return f"{self.name}: ({' | '.join(map(str, self.constraints))})" diff --git a/tests/next_tests/unit_tests/type_system_tests/test_type_info.py b/tests/next_tests/unit_tests/type_system_tests/test_type_info.py index 981423fbc9..85ee3d69ae 100644 --- a/tests/next_tests/unit_tests/type_system_tests/test_type_info.py +++ b/tests/next_tests/unit_tests/type_system_tests/test_type_info.py @@ -434,8 +434,8 @@ def test_identity_and_hashing(self): assert hash(float_var) == hash(same_var) assert eve_utils.content_hash(float_var) == eve_utils.content_hash(same_var) assert float_var != ts.TypeVarType(name="S", constraints=(float32_type, float64_type)) - # constraint order is part of the identity (preserved as written) - assert float_var != ts.TypeVarType(name="T", constraints=(float64_type, float32_type)) + # constraint order is canonicalized, so it is not part of the identity + assert float_var == ts.TypeVarType(name="T", constraints=(float64_type, float32_type)) def test_is_generic(self): assert type_info.is_generic(float_var) From c250e263c2bd0212323196360a7d664af51882cc Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Wed, 17 Jun 2026 12:12:54 +0200 Subject: [PATCH 3/9] docs[next]: tidy dtype-generic docstrings Put the type-variable utility doctests under an 'Examples:' section (matching the rest of type_info.py), make the _type_params summary a sentence, and drop docstring and comment text describing rationale, callers, or future plans rather than behavior. 
--- .../dace/lowering/gtir_to_sdfg_types.py | 3 +- src/gt4py/next/type_system/type_info.py | 79 +++++++++---------- 2 files changed, 37 insertions(+), 45 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_types.py b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_types.py index b220732756..7f6fc45df8 100644 --- a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_types.py +++ b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_types.py @@ -88,8 +88,7 @@ def get_local_view( ] dtype = self.gt_type.dtype # `FieldType.dtype` is widened to include `TypeVarType`, but generic operators are - # monomorphized before lowering, so only concrete dtypes reach here. Removing this - # (and the sibling asserts) would need a non-generic field type; deferred for now. + # monomorphized before lowering, so only concrete dtypes reach here. assert isinstance(dtype, (ts.ScalarType, ts.ListType)) return gtir_dataflow.IteratorExpr(self.dc_node, dtype, field_origin, it_indices) diff --git a/src/gt4py/next/type_system/type_info.py b/src/gt4py/next/type_system/type_info.py index b95b9a95bf..f5502b89a7 100644 --- a/src/gt4py/next/type_system/type_info.py +++ b/src/gt4py/next/type_system/type_info.py @@ -55,11 +55,9 @@ def is_concrete(symbol_type: ts.TypeSpec) -> TypeGuard[ts.TypeSpec]: def _type_params(symbol_type: ts.TypeSpec) -> tuple[ts.TypeSpec, ...]: - """The immediate type-parameter sub-types of ``symbol_type``. + """Return the immediate type-parameter sub-types of ``symbol_type``. - These are its dtype, element type, tuple elements, or function argument / return types -- - one definition of where type parameters live, shared by the recursive type-variable - traversals (`is_generic` here, `substitute_type_vars` via `tree_map_type_params`). + These are its dtype, element type, tuple elements, or function argument / return types. 
""" match symbol_type: case ts.FieldType(dtype=dtype): @@ -86,25 +84,24 @@ def is_generic(symbol_type: ts.TypeSpec) -> bool: """ Figure out if a type contains parts that are only known when concrete arguments are given. - A generic (callable) type can be called with arguments of varying types, e.g. the program - context signature of a scan operator. This recurses into composite types, reporting ``True`` - if any nested part is a `DeferredType` or `TypeVarType`. It is *not* the negation of - :func:`is_concrete`: the latter is a shallow deduction-progress flag (top-level - `DeferredType` only), so a tuple with a nested `DeferredType` is both concrete and generic. + Recurses into composite types, reporting ``True`` if any nested part is a `DeferredType` or + `TypeVarType`. Unlike :func:`is_concrete` (a shallow top-level check), this is deep, so a + tuple with a nested `DeferredType` is both concrete and generic. Note: this returns ``True`` for a bare ``astype`` constructor type, whose ``definition`` carries a ``DeferredType`` by design; callers that only care about *data* arguments must filter for ``ts.DataType`` themselves. - >>> is_generic(ts.DeferredType(constraint=None)) - True + Examples: + >>> is_generic(ts.DeferredType(constraint=None)) + True - >>> bool_type = ts.ScalarType(kind=ts.ScalarKind.BOOL) - >>> is_generic(bool_type) - False + >>> bool_type = ts.ScalarType(kind=ts.ScalarKind.BOOL) + >>> is_generic(bool_type) + False - >>> is_generic(ts.TupleType(types=[bool_type, ts.DeferredType(constraint=None)])) - True + >>> is_generic(ts.TupleType(types=[bool_type, ts.DeferredType(constraint=None)])) + True """ if isinstance(symbol_type, (ts.DeferredType, ts.TypeVarType)): return True @@ -639,26 +636,23 @@ def bind_type_vars( """ Compute a binding of all type variables in ``params`` by structurally matching ``args``. - Concrete (non-generic) parts of the parameters are ignored here; mismatches in those are - reported by the regular signature checks. 
A type variable position only binds if the - corresponding argument provides a concrete scalar dtype; the caller is responsible for - checking that no type variable remained unbound (if required). - - Note: the binding is intentionally ``dtype``-scoped (scalar type variables only); a future - extension for dimension variables will widen it. + Concrete (non-generic) parts of the parameters are ignored; a type variable position binds + only if the corresponding argument provides a concrete scalar dtype. The caller is + responsible for checking that no type variable remained unbound. Raises: ValueError: If a type variable would be bound inconsistently or to a dtype that is not one of its constraints. - >>> var = ts.TypeVarType(name="T", constraints=(ts.ScalarType(kind=ts.ScalarKind.FLOAT64),)) - >>> I = common.Dimension(value="I") - >>> binding = bind_type_vars( - ... [ts.FieldType(dims=[I], dtype=var)], - ... [ts.FieldType(dims=[I], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64))], - ... ) - >>> print(binding["T"]) - float64 + Examples: + >>> var = ts.TypeVarType(name="T", constraints=(ts.ScalarType(kind=ts.ScalarKind.FLOAT64),)) + >>> I = common.Dimension(value="I") + >>> binding = bind_type_vars( + ... [ts.FieldType(dims=[I], dtype=var)], + ... [ts.FieldType(dims=[I], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64))], + ... ) + >>> print(binding["T"]) + float64 """ binding: dict[str, ts.ScalarType] = {} @@ -708,9 +702,7 @@ def tree_map_type_params( ) -> ts.TypeSpec: """Rebuild ``symbol_type`` applying ``fun`` to each immediate type-parameter sub-type. - The counterpart of `tree_map_type` for the type-parameter positions enumerated by - `_type_params` (dtype, element type, function argument / return types), which `tree_map_type` - -- a *collection* map -- does not descend into. Leaf types are returned unchanged. + Leaf types are returned unchanged. 
""" match symbol_type: case ts.FieldType(dims=dims, dtype=dtype): @@ -747,15 +739,16 @@ def substitute_type_vars( Unbound type variables and all other generic parts (e.g. `DeferredType`) are kept as-is. - >>> var = ts.TypeVarType(name="T", constraints=(ts.ScalarType(kind=ts.ScalarKind.FLOAT64),)) - >>> I = common.Dimension(value="I") - >>> print( - ... substitute_type_vars( - ... ts.FieldType(dims=[I], dtype=var), - ... {"T": ts.ScalarType(kind=ts.ScalarKind.FLOAT64)}, - ... ) - ... ) - Field[[I], float64] + Examples: + >>> var = ts.TypeVarType(name="T", constraints=(ts.ScalarType(kind=ts.ScalarKind.FLOAT64),)) + >>> I = common.Dimension(value="I") + >>> print( + ... substitute_type_vars( + ... ts.FieldType(dims=[I], dtype=var), + ... {"T": ts.ScalarType(kind=ts.ScalarKind.FLOAT64)}, + ... ) + ... ) + Field[[I], float64] """ if not binding or not is_generic(type_): return type_ From be62638f6eca7f12fb02f1e3f1833d4e1fe2a500 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Wed, 17 Jun 2026 18:17:10 +0200 Subject: [PATCH 4/9] refactor[next]: make bind_type_vars functional Replace the mutating closure-based walk with pure module-level helpers (_bind_var / _merge_bindings / _bind) that return and merge per-leaf bindings; keeps the tolerant structural recursion (mismatches deferred to the signature checks). --- src/gt4py/next/type_system/type_info.py | 85 +++++++++++++------------ 1 file changed, 44 insertions(+), 41 deletions(-) diff --git a/src/gt4py/next/type_system/type_info.py b/src/gt4py/next/type_system/type_info.py index f5502b89a7..ab6c61fe2c 100644 --- a/src/gt4py/next/type_system/type_info.py +++ b/src/gt4py/next/type_system/type_info.py @@ -630,6 +630,49 @@ def is_concretizable(symbol_type: ts.TypeSpec, to_type: ts.TypeSpec) -> bool: return False +def _bind_var(var: ts.TypeVarType, dtype: ts.TypeSpec) -> dict[str, ts.ScalarType]: + if not isinstance(dtype, ts.ScalarType): + # not a concrete scalar to bind to -- e.g. 
a `TypeVarType` (operator-from-operator + # call), a `DeferredType` (scan), or a `ListType` (local field). Leave it unbound; + # the caller is responsible for checking that no type variable remained unbound. + return {} + if dtype not in var.constraints: + raise ValueError(f"'{dtype}' does not satisfy the constraints of type variable '{var}'.") + return {var.name: dtype} + + +def _merge_bindings(parts: Iterable[dict[str, ts.ScalarType]]) -> dict[str, ts.ScalarType]: + binding: dict[str, ts.ScalarType] = {} + for part in parts: + for name, dtype in part.items(): + if (previous := binding.get(name)) is not None and previous != dtype: + raise ValueError( + f"Type variable '{name}' is bound inconsistently:" + f" '{previous}' and '{dtype}' (all arguments using '{name}'" + " must have the same dtype)." + ) + binding[name] = dtype + return binding + + +def _bind(param: ts.TypeSpec, arg: ts.TypeSpec) -> dict[str, ts.ScalarType]: + match param: + case ts.TypeVarType() as var: + return _bind_var(var, arg) + case ts.FieldType(dtype=ts.TypeVarType() as var): + # scalar arguments are promoted to zero-dimensional fields + return _bind_var(var, arg.dtype if isinstance(arg, ts.FieldType) else arg) + case ts.ListType(element_type=element_type) if isinstance(arg, ts.ListType): + return _bind(element_type, arg.element_type) + case ts.TupleType() | ts.NamedCollectionType() if isinstance( + arg, (ts.TupleType, ts.NamedCollectionType) + ): + # tolerant by design: a structural mismatch (e.g. tuple vs scalar) binds nothing + # here and is reported by the regular signature checks instead. 
+ return _merge_bindings(_bind(p, a) for p, a in zip(param.types, arg.types)) + return {} + + def bind_type_vars( params: Sequence[ts.TypeSpec], args: Sequence[ts.TypeSpec] ) -> dict[str, ts.ScalarType]: @@ -654,47 +697,7 @@ def bind_type_vars( >>> print(binding["T"]) float64 """ - binding: dict[str, ts.ScalarType] = {} - - def bind_var(var: ts.TypeVarType, dtype: ts.TypeSpec) -> None: - if not isinstance(dtype, ts.ScalarType): - # not a concrete scalar to bind to -- e.g. a `TypeVarType` (operator-from-operator - # call), a `DeferredType` (scan), or a `ListType` (local field). Leave it unbound; - # the caller is responsible for checking that no type variable remained unbound. - return - if dtype not in var.constraints: - raise ValueError( - f"'{dtype}' does not satisfy the constraints of type variable '{var}'." - ) - if (previous := binding.get(var.name)) is not None and previous != dtype: - raise ValueError( - f"Type variable '{var.name}' is bound inconsistently:" - f" '{previous}' and '{dtype}' (all arguments using '{var.name}'" - " must have the same dtype)." 
- ) - binding[var.name] = dtype - - def bind(param: ts.TypeSpec, arg: ts.TypeSpec) -> None: - match param: - case ts.TypeVarType(): - bind_var(param, arg) - case ts.FieldType(dtype=ts.TypeVarType() as var): - if isinstance(arg, ts.FieldType): - bind_var(var, arg.dtype) - elif isinstance(arg, ts.ScalarType): - # scalar arguments are promoted to zero-dimensional fields - bind_var(var, arg) - case ts.ListType(element_type=element_type) if isinstance(arg, ts.ListType): - bind(element_type, arg.element_type) - case ts.TupleType() | ts.NamedCollectionType() if isinstance( - arg, (ts.TupleType, ts.NamedCollectionType) - ): - for param_el, arg_el in zip(param.types, arg.types): - bind(param_el, arg_el) - - for param, arg in zip(params, args): - bind(param, arg) - return binding + return _merge_bindings(_bind(param, arg) for param, arg in zip(params, args)) def tree_map_type_params( From 339f1dd5e34d73bd0e47faabffd43a8d6bb37b57 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Fri, 19 Jun 2026 11:52:12 +0200 Subject: [PATCH 5/9] docs[next]: renumber dtype-generic ADR to 0024 Upstream main now occupies the 0023 slot with the Fingerprinting ADR, so move this ADR to the next free number and update the index accordingly. 
--- ...ype-Generic-Operators.md => 0024-Dtype-Generic-Operators.md} | 0 docs/development/ADRs/next/README.md | 2 +- 2 files changed, 1 insertion(+), 1 deletion(-) rename docs/development/ADRs/next/{0023-Dtype-Generic-Operators.md => 0024-Dtype-Generic-Operators.md} (100%) diff --git a/docs/development/ADRs/next/0023-Dtype-Generic-Operators.md b/docs/development/ADRs/next/0024-Dtype-Generic-Operators.md similarity index 100% rename from docs/development/ADRs/next/0023-Dtype-Generic-Operators.md rename to docs/development/ADRs/next/0024-Dtype-Generic-Operators.md diff --git a/docs/development/ADRs/next/README.md b/docs/development/ADRs/next/README.md index 287eea2076..c6b1907a2c 100644 --- a/docs/development/ADRs/next/README.md +++ b/docs/development/ADRs/next/README.md @@ -28,7 +28,7 @@ Writing a new ADR is simple: ### Type System -- [0023 - Dtype-Generic Operators](0023-Dtype-Generic-Operators.md) +- [0024 - Dtype-Generic Operators](0024-Dtype-Generic-Operators.md) ### Iterator IR #iterator From 7368b79e2b5b1c64a7bf8c31d7e8071421eaeb79 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Tue, 23 Jun 2026 20:23:40 +0200 Subject: [PATCH 6/9] refactor[next]: simplify substitute_type_vars via tree_map_type Route substitution through the existing `tree_map_type` rather than a bespoke type-parameter map, handling the dtype / element-type / signature rewrite in the leaf function. Drop the upfront `is_generic` check: the traversal runs unconditionally and reconstructs the tree. Addresses the review comment on the `tree_map_type_params` helper. 
--- src/gt4py/next/type_system/type_info.py | 70 +++++++++---------- .../type_system_tests/test_type_info.py | 2 +- 2 files changed, 33 insertions(+), 39 deletions(-) diff --git a/src/gt4py/next/type_system/type_info.py b/src/gt4py/next/type_system/type_info.py index ab6c61fe2c..558a37af49 100644 --- a/src/gt4py/next/type_system/type_info.py +++ b/src/gt4py/next/type_system/type_info.py @@ -700,40 +700,6 @@ def bind_type_vars( return _merge_bindings(_bind(param, arg) for param, arg in zip(params, args)) -def tree_map_type_params( - fun: Callable[[ts.TypeSpec], ts.TypeSpec], symbol_type: ts.TypeSpec -) -> ts.TypeSpec: - """Rebuild ``symbol_type`` applying ``fun`` to each immediate type-parameter sub-type. - - Leaf types are returned unchanged. - """ - match symbol_type: - case ts.FieldType(dims=dims, dtype=dtype): - new_dtype = fun(dtype) - assert isinstance(new_dtype, (ts.ScalarType, ts.ListType, ts.TypeVarType)) - return ts.FieldType(dims=dims, dtype=new_dtype) - case ts.ListType(element_type=element_type, offset_type=offset_type): - new_element_type = fun(element_type) - assert isinstance(new_element_type, ts.DataType) - return ts.ListType(element_type=new_element_type, offset_type=offset_type) - case ts.TupleType(types=types): - return ts.TupleType(types=[fun(t) for t in types]) - case ts.NamedCollectionType(types=types): - return ts.NamedCollectionType( - types=[fun(t) for t in types], - keys=symbol_type.keys, - original_python_type=symbol_type.original_python_type, - ) - case ts.FunctionType(): - return ts.FunctionType( - pos_only_args=[fun(t) for t in symbol_type.pos_only_args], - pos_or_kw_args={name: fun(t) for name, t in symbol_type.pos_or_kw_args.items()}, - kw_only_args={name: fun(t) for name, t in symbol_type.kw_only_args.items()}, - returns=fun(symbol_type.returns), - ) - return symbol_type - - def substitute_type_vars( type_: ts.TypeSpec, binding: xtyping.Mapping[str, ts.ScalarType] ) -> ts.TypeSpec: @@ -753,11 +719,39 @@ def substitute_type_vars( ... 
) Field[[I], float64] """ - if not binding or not is_generic(type_): + if not binding: return type_ - if isinstance(type_, ts.TypeVarType): - return binding.get(type_.name, type_) - return tree_map_type_params(lambda t: substitute_type_vars(t, binding), type_) + + def substitute_leaf(leaf: ts.TypeSpec) -> ts.TypeSpec: + # `tree_map_type` has already mapped the tuple structure; what is left is to substitute + # inside the primitive constituents, i.e. in their dtype / element type / signature. + match leaf: + case ts.TypeVarType(): + return binding.get(leaf.name, leaf) + case ts.FieldType(dims=dims, dtype=dtype): + new_dtype = substitute_type_vars(dtype, binding) + assert isinstance(new_dtype, (ts.ScalarType, ts.ListType, ts.TypeVarType)) + return ts.FieldType(dims=dims, dtype=new_dtype) + case ts.ListType(element_type=element_type, offset_type=offset_type): + new_element_type = substitute_type_vars(element_type, binding) + assert isinstance(new_element_type, ts.DataType) + return ts.ListType(element_type=new_element_type, offset_type=offset_type) + case ts.FunctionType(): + return ts.FunctionType( + pos_only_args=[substitute_type_vars(a, binding) for a in leaf.pos_only_args], + pos_or_kw_args={ + name: substitute_type_vars(a, binding) + for name, a in leaf.pos_or_kw_args.items() + }, + kw_only_args={ + name: substitute_type_vars(a, binding) + for name, a in leaf.kw_only_args.items() + }, + returns=substitute_type_vars(leaf.returns, binding), + ) + return leaf + + return tree_map_type(substitute_leaf)(type_) def promote( diff --git a/tests/next_tests/unit_tests/type_system_tests/test_type_info.py b/tests/next_tests/unit_tests/type_system_tests/test_type_info.py index 85ee3d69ae..33b8da2348 100644 --- a/tests/next_tests/unit_tests/type_system_tests/test_type_info.py +++ b/tests/next_tests/unit_tests/type_system_tests/test_type_info.py @@ -558,7 +558,7 @@ def test_unbound_vars_are_kept(self): def test_concrete_is_returned_unchanged(self): concrete = 
ts.FieldType(dims=[TDim], dtype=float64_type) - assert type_info.substitute_type_vars(concrete, {"T": float32_type}) is concrete + assert type_info.substitute_type_vars(concrete, {"T": float32_type}) == concrete def test_substitute_function_type(self): func_type = ts.FunctionType( From 8e81c2361985663a87f1706f90b3504094d7e1d7 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Wed, 24 Jun 2026 10:57:33 +0200 Subject: [PATCH 7/9] refactor[next]: collapse is_arithmetic into _is_field_or_scalar_of_kind is_arithmetic hand-rolled its own TypeVarType handling under a comment claiming it could not reuse _is_field_or_scalar_of_kind. That helper takes a kind *set*, so arithmetic is just membership in the union of the floating point and integral kinds; all(constraint in union) is correct even for a type variable mixing float and integral constraints. Route it through the shared helper via _ARITHMETIC_KINDS, removing the duplicated constraint fold. --- src/gt4py/next/type_system/type_info.py | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/src/gt4py/next/type_system/type_info.py b/src/gt4py/next/type_system/type_info.py index 558a37af49..e6040c3d39 100644 --- a/src/gt4py/next/type_system/type_info.py +++ b/src/gt4py/next/type_system/type_info.py @@ -268,6 +268,7 @@ def _scalar_kinds(scalar_types: tuple[type, ...]) -> frozenset[ts.ScalarKind]: _FLOATING_POINT_KINDS: Final[frozenset[ts.ScalarKind]] = _scalar_kinds(core_defs.FLOAT_TYPES) _INTEGRAL_KINDS: Final[frozenset[ts.ScalarKind]] = _scalar_kinds(core_defs.INTEGRAL_TYPES) +_ARITHMETIC_KINDS: Final[frozenset[ts.ScalarKind]] = _FLOATING_POINT_KINDS | _INTEGRAL_KINDS def _is_field_or_scalar_of_kind(symbol_type: ts.TypeSpec, kinds: Collection[ts.ScalarKind]) -> bool: @@ -386,16 +387,7 @@ def is_arithmetic(symbol_type: ts.TypeSpec) -> bool: >>> is_arithmetic(ts.FieldType(dims=[], dtype=ts.ScalarType(kind=ts.ScalarKind.INT32))) True """ - # `is_arithmetic` cannot reuse `_is_field_or_scalar_of_kind`'s 
"all constraints - # share the kind" rule: a type variable is arithmetic if every constraint is - # arithmetic, even when the constraints mix floating point and integral kinds. - if isinstance(symbol_type, ts.TypeVarType): - return all(is_arithmetic(c) for c in symbol_type.constraints) - if isinstance(symbol_type, (ts.ScalarType, ts.FieldType)) and isinstance( - dtype := extract_dtype(symbol_type), ts.TypeVarType - ): - return is_arithmetic(dtype) - return is_floating_point(symbol_type) or is_integral(symbol_type) + return _is_field_or_scalar_of_kind(symbol_type, _ARITHMETIC_KINDS) def arithmetic_bounds(arithmetic_type: ts.ScalarType) -> tuple[np.number, np.number]: From 2a84c7f4e468091215066681a1a5f54bf27dced0 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Wed, 24 Jun 2026 11:14:27 +0200 Subject: [PATCH 8/9] refactor[next]: centralize function-type sub-type traversal The positional/keyword/return sub-type layout of FunctionType was spelled out independently in _type_params, substitute_type_vars and is_compatible_type. Extract it into _function_type_arg_groups (the grouped layout, kept grouped so is_compatible_type still enforces per-group arity), _function_type_children (the flat enumeration) and _map_function_type (structure-preserving rebuild), and route the three call sites through them. 
--- src/gt4py/next/type_system/type_info.py | 64 ++++++++++++++----------- 1 file changed, 36 insertions(+), 28 deletions(-) diff --git a/src/gt4py/next/type_system/type_info.py b/src/gt4py/next/type_system/type_info.py index e6040c3d39..5f0e073756 100644 --- a/src/gt4py/next/type_system/type_info.py +++ b/src/gt4py/next/type_system/type_info.py @@ -54,6 +54,35 @@ def is_concrete(symbol_type: ts.TypeSpec) -> TypeGuard[ts.TypeSpec]: return False +def _function_type_arg_groups( + function_type: ts.FunctionType, +) -> tuple[Sequence[ts.TypeSpec], ...]: + """The positional, keyword and return sub-type groups of a function type, in canonical order.""" + return ( + function_type.pos_only_args, + tuple(function_type.pos_or_kw_args.values()), + tuple(function_type.kw_only_args.values()), + (function_type.returns,), + ) + + +def _function_type_children(function_type: ts.FunctionType) -> tuple[ts.TypeSpec, ...]: + """Return the argument and return sub-types of a function type, in canonical order.""" + return tuple(child for group in _function_type_arg_groups(function_type) for child in group) + + +def _map_function_type( + function_type: ts.FunctionType, transform: Callable[[ts.TypeSpec], ts.TypeSpec] +) -> ts.FunctionType: + """Apply ``transform`` to each argument and return sub-type of a function type.""" + return ts.FunctionType( + pos_only_args=[transform(a) for a in function_type.pos_only_args], + pos_or_kw_args={name: transform(a) for name, a in function_type.pos_or_kw_args.items()}, + kw_only_args={name: transform(a) for name, a in function_type.kw_only_args.items()}, + returns=transform(function_type.returns), + ) + + def _type_params(symbol_type: ts.TypeSpec) -> tuple[ts.TypeSpec, ...]: """Return the immediate type-parameter sub-types of ``symbol_type``. 
@@ -67,12 +96,7 @@ def _type_params(symbol_type: ts.TypeSpec) -> tuple[ts.TypeSpec, ...]: case ts.TupleType(types=types) | ts.NamedCollectionType(types=types): return tuple(types) case ts.FunctionType(): - return ( - *symbol_type.pos_only_args, - *symbol_type.pos_or_kw_args.values(), - *symbol_type.kw_only_args.values(), - symbol_type.returns, - ) + return _function_type_children(symbol_type) # callable type wrappers (e.g. the field operator types in `ffront`) carry their # signature in a `definition` attribute if isinstance(definition := getattr(symbol_type, "definition", None), ts.TypeSpec): @@ -552,17 +576,12 @@ def is_compatible_type(type_a: ts.TypeSpec, type_b: ts.TypeSpec) -> bool: for el_type_a, el_type_b in zip(type_a.types, type_b.types, strict=True): is_compatible &= is_compatible_type(el_type_a, el_type_b) elif isinstance(type_a, ts.FunctionType) and isinstance(type_b, ts.FunctionType): - for arg_a, arg_b in zip(type_a.pos_only_args, type_b.pos_only_args, strict=True): - is_compatible &= is_compatible_type(arg_a, arg_b) - for arg_a, arg_b in zip( - type_a.pos_or_kw_args.values(), type_b.pos_or_kw_args.values(), strict=True - ): - is_compatible &= is_compatible_type(arg_a, arg_b) - for arg_a, arg_b in zip( - type_a.kw_only_args.values(), type_b.kw_only_args.values(), strict=True + # zip per group (not flattened) so a positional/keyword arity mismatch is still caught + for group_a, group_b in zip( + _function_type_arg_groups(type_a), _function_type_arg_groups(type_b), strict=True ): - is_compatible &= is_compatible_type(arg_a, arg_b) - is_compatible &= is_compatible_type(type_a.returns, type_b.returns) + for arg_a, arg_b in zip(group_a, group_b, strict=True): + is_compatible &= is_compatible_type(arg_a, arg_b) else: is_compatible &= is_concretizable(type_a, type_b) @@ -729,18 +748,7 @@ def substitute_leaf(leaf: ts.TypeSpec) -> ts.TypeSpec: assert isinstance(new_element_type, ts.DataType) return ts.ListType(element_type=new_element_type, 
offset_type=offset_type) case ts.FunctionType(): - return ts.FunctionType( - pos_only_args=[substitute_type_vars(a, binding) for a in leaf.pos_only_args], - pos_or_kw_args={ - name: substitute_type_vars(a, binding) - for name, a in leaf.pos_or_kw_args.items() - }, - kw_only_args={ - name: substitute_type_vars(a, binding) - for name, a in leaf.kw_only_args.items() - }, - returns=substitute_type_vars(leaf.returns, binding), - ) + return _map_function_type(leaf, lambda a: substitute_type_vars(a, binding)) return leaf return tree_map_type(substitute_leaf)(type_) From 8e6ca0aefee80afcd13f5260cb5d239408b81ba8 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Wed, 24 Jun 2026 11:33:12 +0200 Subject: [PATCH 9/9] refactor[next]: add is_concrete_dtype TypeIs guard The dace lowering narrows FieldType.dtype past the widened TypeVarType with plain 'assert isinstance(dtype, (ts.ScalarType, ts.ListType))'. Introduce a type_info.is_concrete_dtype TypeIs guard and route those asserts through it, so the narrowing is self-documenting (the dtype is concrete, i.e. not a type variable) and the intent is stated once. 
--- .../runners/dace/lowering/gtir_dataflow.py | 2 +- .../runners/dace/lowering/gtir_to_sdfg_scan.py | 2 +- .../runners/dace/lowering/gtir_to_sdfg_types.py | 6 ++---- src/gt4py/next/type_system/type_info.py | 5 +++++ 4 files changed, 9 insertions(+), 6 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_dataflow.py b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_dataflow.py index d8498b45c3..ed60aec004 100644 --- a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_dataflow.py +++ b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_dataflow.py @@ -105,7 +105,7 @@ class MemletExpr: @property def gt_dtype(self) -> ts.ScalarType | ts.ListType: dtype = self.gt_field.dtype - assert isinstance(dtype, (ts.ScalarType, ts.ListType)) + assert ti.is_concrete_dtype(dtype) return dtype def __post_init__(self) -> None: diff --git a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_scan.py b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_scan.py index da1a813f38..0169e8b149 100644 --- a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_scan.py +++ b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_scan.py @@ -623,7 +623,7 @@ def _handle_dataflow_result_of_nested_sdfg( dace.Memlet.from_array(outer_dataname, outer_desc), ) output_dtype = inner_data.gt_type.dtype - assert isinstance(output_dtype, (ts.ScalarType, ts.ListType)) + assert ti.is_concrete_dtype(output_dtype) output_expr = gtir_dataflow.ValueExpr(outer_node, output_dtype) return gtir_dataflow.DataflowOutputEdge(outer_ctx.state, output_expr) diff --git a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_types.py b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_types.py index 7f6fc45df8..30b9e4c7a3 100644 --- a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_types.py +++ 
b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg_types.py @@ -21,7 +21,7 @@ from gt4py.next.iterator import builtins as gtir_builtins from gt4py.next.program_processors.runners.dace import sdfg_args as gtx_dace_args from gt4py.next.program_processors.runners.dace.lowering import gtir_dataflow, gtir_domain -from gt4py.next.type_system import type_specifications as ts +from gt4py.next.type_system import type_info as ti, type_specifications as ts @dataclasses.dataclass(frozen=True) @@ -87,9 +87,7 @@ def get_local_view( for i, dim in enumerate(self.gt_type.dims) ] dtype = self.gt_type.dtype - # `FieldType.dtype` is widened to include `TypeVarType`, but generic operators are - # monomorphized before lowering, so only concrete dtypes reach here. - assert isinstance(dtype, (ts.ScalarType, ts.ListType)) + assert ti.is_concrete_dtype(dtype) return gtir_dataflow.IteratorExpr(self.dc_node, dtype, field_origin, it_indices) raise NotImplementedError(f"Node type {type(self.gt_type)} not supported.") diff --git a/src/gt4py/next/type_system/type_info.py b/src/gt4py/next/type_system/type_info.py index 5f0e073756..2456489de4 100644 --- a/src/gt4py/next/type_system/type_info.py +++ b/src/gt4py/next/type_system/type_info.py @@ -284,6 +284,11 @@ def extract_dtype(symbol_type: ts.TypeSpec) -> ts.ScalarType | ts.ListType | ts. raise ValueError(f"Can not unambiguosly extract data type from '{symbol_type}'.") +def is_concrete_dtype(dtype: ts.TypeSpec) -> xtyping.TypeIs[ts.ScalarType | ts.ListType]: + """Whether ``dtype`` is a concrete field dtype, i.e. not a (generic) `TypeVarType`.""" + return isinstance(dtype, (ts.ScalarType, ts.ListType)) + + def _scalar_kinds(scalar_types: tuple[type, ...]) -> frozenset[ts.ScalarKind]: # Derived from the canonical scalar-type tuples in `gt4py._core.definitions` so the two # stay in sync; the `int`/`float` builtins collapse onto their fixed-width kind.