From 65952fd8d97cac68d63516c04a78659a9d5de45d Mon Sep 17 00:00:00 2001 From: gokulkrishna98 Date: Wed, 24 Jun 2026 10:52:57 -0700 Subject: [PATCH 1/2] Sync v0.4.1: bump version, update coreai-core to b2, sync coreai_torch and tests - Bump coreai-torch to 0.4.1 - Update coreai-core dependency to 1.0.0b2 - Sync coreai_torch: converter, debugging (debug_info, inspector), _debug_locations, _torch_metal_kernel, _utils - Sync tests: debugging, dsl (new), ops, composite_ops, test_docs (new), utils --- coreai_torch/__version__.py | 2 +- coreai_torch/_debug_locations.py | 365 ++++++------------ coreai_torch/_torch_metal_kernel.py | 130 +++++++ coreai_torch/_utils.py | 27 +- coreai_torch/converter.py | 85 +++- coreai_torch/debugging/debug_info.py | 159 ++++++-- coreai_torch/debugging/inspector.py | 356 ++++++++++------- pyproject.toml | 2 +- tests/composite_ops/test_sdpa.py | 40 +- tests/debugging/test_benchmarker.py | 2 +- tests/debugging/test_comparator.py | 2 - tests/debugging/test_debug_info.py | 41 +- tests/debugging/test_graph_diff.py | 5 +- tests/debugging/test_inspector.py | 5 - tests/debugging/test_intermediates.py | 279 ++++++++++++++ tests/debugging/test_location_bindings.py | 2 - tests/debugging/test_model.py | 16 +- tests/debugging/test_torch_utils.py | 1 - tests/debugging/test_validator.py | 3 - tests/dsl/test_dtype_specialization.py | 348 +++++++++++++++++ tests/dsl/test_kernel_collisions.py | 269 +++++++++++++ tests/dsl/test_scalar_inputs.py | 447 ++++++++++++++++++++++ tests/dsl/test_thread_config.py | 344 +++++++++++++++++ tests/ops/test_ops.py | 26 +- tests/test_debug_locations.py | 2 - tests/test_docs.py | 57 +++ tests/utils.py | 13 +- 27 files changed, 2506 insertions(+), 522 deletions(-) create mode 100644 tests/debugging/test_intermediates.py create mode 100644 tests/dsl/test_dtype_specialization.py create mode 100644 tests/dsl/test_kernel_collisions.py create mode 100644 tests/dsl/test_scalar_inputs.py create mode 100644 tests/dsl/test_thread_config.py create mode 100644 tests/test_docs.py diff --git a/coreai_torch/__version__.py b/coreai_torch/__version__.py index d4518aa..0ccc745 100644 --- a/coreai_torch/__version__.py +++ b/coreai_torch/__version__.py @@ -5,4 +5,4 @@ """Version information for coreai-torch.""" -__version__ = "0.4.0" +__version__ = "0.4.1" diff --git a/coreai_torch/_debug_locations.py b/coreai_torch/_debug_locations.py index 17afbe6..2e282a1 100644 --- a/coreai_torch/_debug_locations.py +++ b/coreai_torch/_debug_locations.py @@ -13,7 +13,6 @@ from collections import OrderedDict from contextlib import contextmanager from dataclasses import dataclass -from enum import Enum from typing import Iterable, Iterator, Optional import torch.fx as fx @@ -148,7 +147,10 @@ class DebugInfo: def _create_debug_location_from_debug_info( - debug_info: DebugInfo, scope: Attribute, context: Context + debug_info: DebugInfo, + scope: Attribute, + context: Context, + unknown_src: Location, ) -> location_attr: """ Create a debug location from DebugInfo dataclass. @@ -162,6 +164,7 @@ def _create_debug_location_from_debug_info( scope: Debug scope attribute (SubprogramAttr, CompileUnitAttr, or UnitAttr) context: Core AI context + unknown_src: Unknown location. Returns: LocationAttr with file location as base and metadata from other @@ -200,136 +203,12 @@ def _create_debug_location_from_debug_info( scope=scope, context=context, metadata_attrs=metadata_attrs, + unknown_src=unknown_src, ) -def _create_locations_from_debug_info( - debug_info: DebugInfo, context: Context -) -> list[Location]: - """Create Core AI locations from DebugInfo file locations and operation IDs.""" - locations = [] - - # Convert file locations to Core AI Location objects - if debug_info.file_locations: - locations.extend( - [ - Location.file( - filename=file_loc.filename, - line=file_loc.line, - col=file_loc.col, - context=context, - ) - for file_loc in debug_info.file_locations - ] - ) - - # Add operation ID location using _create_operation_id_location - if debug_info.operation_id: - op_id_location = _create_operation_id_location(debug_info.operation_id, context) - locations.append(op_id_location) - - # Add source operation ID location using _create_operation_id_location - if debug_info.source: - source_operation_id = OperationID( - type=debug_info.source.name, value=debug_info.source.id - ) - source_op_id_location = _create_operation_id_location( - source_operation_id, context - ) - locations.append(source_op_id_location) - - return locations - - -def _create_metadata_dict_from_debug_info( - debug_info: DebugInfo, context: Context -) -> dict[str, Attribute]: - """Create metadata dictionary from DebugInfo components.""" - metadata_dict = {} - - # Add source identifiers metadata if source exists - if debug_info.source and debug_info.source.identifiers: - id_attrs = [ - StringAttr.get(ident, context=context) - for ident in debug_info.source.identifiers - ] - metadata_dict["identifiers"] = ArrayAttr.get(id_attrs, context=context) - - # Add output maps metadata if present - if debug_info.output_maps: - map_dicts = [ - _output_map_to_dictionary(output_map, context) - for output_map in debug_info.output_maps - ] - metadata_dict["output_maps"] = ArrayAttr.get(map_dicts, context=context) - - # Add module hierarchy metadata if present - if debug_info.call_stack: - string_attrs = [ - StringAttr.get(name, context=context) for name in debug_info.call_stack - ] - metadata_dict["call_stack"] = ArrayAttr.get(string_attrs, context=context) - - return metadata_dict - - -def _create_location_from_debug_info( - debug_info: DebugInfo, context: Context -) -> Location: - """ - Create a Core AI Location from DebugInfo dataclass. - - Uses file_locations as the base location, creating a fused location with - operation ID locations for both Operation ID and source operation ID. - - Args: - debug_info: DebugInfo dataclass instance containing all debug - information - context: Core AI context - - Returns: - Core AI Location object from file locations with operation ID locations, - or unknown location if no file locations and no operation IDs available - """ - # Create Core AI locations from debug info - locations = _create_locations_from_debug_info(debug_info, context) - - # Create metadata dictionary - metadata_dict = _create_metadata_dict_from_debug_info(debug_info, context) - - # Handle edge cases - if not locations: - return Location.unknown(context=context) - - if len(locations) == 1 and not metadata_dict: - return locations[0] - - # Create fused location with operation ID locations and metadata - fused_attr = DictAttr.get(metadata_dict, context=context) if metadata_dict else None - return Location.fused(locations, metadata=fused_attr, context=context) - - # ============================================================================= # Integer Attribute Helper Functions -def _create_operation_id_location( - operation_id: OperationID, context: Context -) -> Location: - """Create a NamedLocation with operation ID information. - - Args: - operation_id: Operation ID containing type and value - context: Core AI context - - Returns: - Core AI NamedLocation with operation ID - """ - # Create NamedLocation(, loc("op_id", , 0)) - file_loc = Location.file( - filename="op_id", line=operation_id.value, col=0, context=context - ) - return Location.name(operation_id.type, file_loc, context=context) - - # ============================================================================= @@ -459,6 +338,53 @@ def _create_output_maps_metadata( return metadata_attr(name="output_maps", data=output_maps_array, context=context) +def _create_unknown_location( + unknown_src: Location, + context: Context, + metadata_attrs: list[metadata_attr] | None = None, +) -> location_attr: + """Create an unknown debuginfo location with optional metadata. + + Args: + unknown_src: A pre-created unknown source Location (e.g. filename="", line=0). + context: Core AI context. + metadata_attrs: Optional metadata attributes to attach. + + Returns: + A location_attr with UnitAttr scope and the given metadata. + """ + scope = UnitAttr.get(context=context) + return location_attr( + src=unknown_src, + scope=scope, + metadata=metadata_attrs, + context=context, + ) + + +def _create_unknown_location_with_operation_id( + operation_id: OperationID | None, + unknown_src: Location, + context: Context, +) -> location_attr: + """Create an unknown debuginfo location preserving only an operation ID. + + Args: + operation_id: Optional operation ID to embed as metadata. + unknown_src: A pre-created unknown source Location. + context: Core AI context. + + Returns: + A location_attr with UnitAttr scope and operation_id metadata. + """ + metadata_attrs: list[metadata_attr] = [] + if operation_id is not None: + metadata_attrs.append(_create_operation_id_metadata(operation_id, context)) + return _create_unknown_location( + unknown_src, context, metadata_attrs if metadata_attrs else None + ) + + # ============================================================================= # Location Creation Functions # ============================================================================= @@ -489,60 +415,6 @@ def _create_stack_trace_file_locations( return file_locations -def _create_unknown_location( - context: Context, - metadata_attrs: list[metadata_attr] | None = None, -) -> location_attr: - """ - Create an unknown location when file location information is not - available. - - Args: - context: Core AI context - metadata_attrs: Optional list of metadata attributes to attach - - Returns: - LocationAttr with empty filename and UnitAttr scope - """ - # Create empty file location (filename="", line=0, col=0) - unknown_src = Location.file( - filename="-", - line=0, - col=0, - context=context, - ) - - # Create with UnitAttr scope - scope = UnitAttr.get(context=context) - - return location_attr( - src=unknown_src, - scope=scope, - metadata=metadata_attrs, - context=context, - ) - - -def _create_unknown_location_with_operation_id( - debug_info: DebugInfo, context: Context -) -> location_attr: - """ - Create an unknown location with operation_id metadata if available. - - Args: - debug_info: Debug information containing optional operation_id - context: Core AI context - - Returns: - LocationAttr with unknown location and operation_id metadata - """ - metadata_attrs = [] - if debug_info.operation_id: - op_id_metadata = _create_operation_id_metadata(debug_info.operation_id, context) - metadata_attrs.append(op_id_metadata) - return _create_unknown_location(context, metadata_attrs) - - def _create_debug_info_from_node( node: fx.Node, source_operation_id: int, @@ -582,6 +454,7 @@ def _create_stack_trace_debug_location( file_locations: list[FileLineColLoc], scope: Attribute, context: Context, + unknown_src: Location, metadata_attrs: list[metadata_attr] | None = None, ) -> location_attr: """ @@ -593,6 +466,7 @@ def _create_stack_trace_debug_location( scope: Debug scope attribute (SubprogramAttr, CompileUnitAttr, or UnitAttr) context: Core AI context + unknown_src: Unknown location. metadata_attrs: Optional list of metadata attributes to attach Returns: @@ -600,7 +474,13 @@ def _create_stack_trace_debug_location( locations """ if not file_locations: - return _create_unknown_location(context, metadata_attrs) + unit_scope = UnitAttr.get(context=context) + return location_attr( + src=unknown_src, + scope=unit_scope, + metadata=metadata_attrs, + context=context, + ) # Convert first location to Core AI Location for src first_location = Location.file( @@ -863,25 +743,17 @@ def clear(self) -> None: class _DebugInfoRecorder: """Recorder for recording debug information with node and results.""" - class Options(Enum): - """Enum to specify location creation mode.""" - - DEBUGINFO = "debuginfo" # Use location_attr with full debug info - STANDARD = "standard" # Use standard Core AI Location objects - @dataclass(frozen=True) class Config: """Configuration for _DebugInfoRecorder.""" include_stack_trace: bool - options: "_DebugInfoRecorder.Options" verify_debuginfo_locations: bool def __init__( self: Self, config: "Config" = Config( include_stack_trace=True, - options=Options.STANDARD, verify_debuginfo_locations=False, ), ): @@ -897,87 +769,89 @@ def __init__( ) self._current_node: fx.Node | None = None self._file_cache: OrderedDict[str, file_attr] = OrderedDict() + self._unknown_src: Location | None = None - def _create_unknown_location_with_operation_id( - self: Self, debug_info: DebugInfo, context: Context - ) -> Location: - """Create a NamedLocation with operation ID information. + def _get_unknown_src(self: Self, context: Context) -> Location: + """Get or create the shared unknown source location. + + Creates the Location once on first call and reuses it for all + subsequent calls within the same recorder instance. Args: - debug_info: Debug information containing operation ID context: Core AI context Returns: - Core AI NamedLocation with operation ID or unknown location + A shared Location with filename="", line=0, col=0 """ - if debug_info.operation_id: - return _create_operation_id_location(debug_info.operation_id, context) - else: - return Location.unknown(context=context) + if self._unknown_src is None: + self._unknown_src = Location.file( + filename="", + line=0, + col=0, + context=context, + ) + return self._unknown_src - def _create_debuginfo_location( + def _get_unknown_location( self: Self, - debug_info: DebugInfo, context: Context, - scope: Attribute | None = None, + metadata_attrs: list[metadata_attr] | None = None, ) -> location_attr: - """Create location_attr for DEBUGINFO mode. + """Create an unknown location reusing the cached unknown_src. Args: - debug_info: Debug information for the location context: Core AI context - scope: Optional scope for DEBUGINFO mode + metadata_attrs: Optional list of metadata attributes to attach Returns: - location_attr with debug info or unknown location if disabled + LocationAttr with cached unknown_src and UnitAttr scope """ - if not self.config.include_stack_trace: - return _create_unknown_location_with_operation_id(debug_info, context) - else: - if scope is None: - scope = UnitAttr.get(context=context) - return _create_debug_location_from_debug_info(debug_info, scope, context) + unknown_src = self._get_unknown_src(context) + return _create_unknown_location(unknown_src, context, metadata_attrs) - def _create_standard_location( - self: Self, - debug_info: DebugInfo, - context: Context, - ) -> Location: - """Create standard Core AI Location for STANDARD mode. + def _get_unknown_location_with_operation_id( + self: Self, debug_info: DebugInfo, context: Context + ) -> location_attr: + """Create an unknown location with operation_id metadata if available. + + Reuses the cached unknown_src and attaches operation ID as metadata. Args: - debug_info: Debug information for the location + debug_info: Debug information containing optional operation_id context: Core AI context Returns: - Location with debug info or unknown location if disabled + LocationAttr with cached unknown_src and operation_id metadata """ - if not self.config.include_stack_trace: - return self._create_unknown_location_with_operation_id(debug_info, context) - else: - return _create_location_from_debug_info(debug_info, context) + unknown_src = self._get_unknown_src(context) + return _create_unknown_location_with_operation_id( + debug_info.operation_id, unknown_src, context + ) def _create_operation_location( self: Self, debug_info: DebugInfo, context: Context, scope: Attribute | None = None, - ) -> location_attr | Location: - """Create location based on the current location mode and enable_locations setting. + ) -> location_attr: + """Create a debuginfo location_attr for an operation. Args: debug_info: Debug information for the location context: Core AI context - scope: Optional scope for DEBUGINFO mode + scope: Optional scope attribute Returns: - location_attr for DEBUGINFO mode or Location for STANDARD mode, - with operation ID metadata preserved even when locations are disabled + location_attr with debug info, or unknown location with operation + ID metadata when stack traces are disabled """ - if self.config.options == self.Options.DEBUGINFO: - return self._create_debuginfo_location(debug_info, context, scope) - else: # STANDARD mode - return self._create_standard_location(debug_info, context) + if not self.config.include_stack_trace: + return self._get_unknown_location_with_operation_id(debug_info, context) + if scope is None: + scope = UnitAttr.get(context=context) + return _create_debug_location_from_debug_info( + debug_info, scope, context, unknown_src=self._get_unknown_src(context) + ) def _populate_file_cache_from_debug_info( self: Self, debug_info: DebugInfo, context: Context @@ -1052,10 +926,6 @@ def _set_graph_location(self: Self, graph_operation: Operation) -> None: Args: graph_operation: The graph operation to set location for """ - # Skip location creation for STANDARD mode - if self.config.options == self.Options.STANDARD: - return - context = graph_operation.context # Create operation ID metadata for the graph operation @@ -1065,7 +935,7 @@ def _set_graph_location(self: Self, graph_operation: Operation) -> None: if not self.config.include_stack_trace: # Create unknown location with metadata if locations are disabled - debug_location = _create_unknown_location(context, [op_id_metadata]) + debug_location = self._get_unknown_location(context, [op_id_metadata]) else: # Get the graph name using the helper function graph_name = _get_symbol_name(graph_operation, "") @@ -1121,17 +991,13 @@ def _set_module_location(self: Self, module: Module) -> None: Args: module: The Core AI Module to set location for """ - # Skip location creation for STANDARD mode - if self.config.options == self.Options.STANDARD: - return - # Get the module operation module_op = module.operation context = module_op.context if not self.config.include_stack_trace: # Create unknown location if locations are disabled - debug_location = _create_unknown_location(context) + debug_location = self._get_unknown_location(context) else: # Create compile unit for module using files from the graph if available if self._file_cache: @@ -1223,11 +1089,8 @@ def record_module(self: Self, module: Module): # Set module location before restoring context self._set_module_location(module) - # Verify that each location is a debuginfo location (only in DEBUGINFO mode and when enabled) - if ( - self.config.options == self.Options.DEBUGINFO - and self.config.verify_debuginfo_locations - ): + # Verify that each location is a debuginfo location when enabled + if self.config.verify_debuginfo_locations: self._verify_debuginfo_locations(module) # Restore previous module and clear caches @@ -1410,7 +1273,7 @@ def record_operation(self: Self, node: fx.Node): if not self.config.include_stack_trace: # Create unknown location with metadata if locations are disabled - location = _create_unknown_location_with_operation_id( + location = self._get_unknown_location_with_operation_id( debug_info, context ) else: @@ -1460,7 +1323,7 @@ def _ensure_all_operations_have_debug_locations( # Create debug location if not self.config.include_stack_trace: # Create unknown location with metadata if locations are disabled - location = _create_unknown_location_with_operation_id( + location = self._get_unknown_location_with_operation_id( debug_info, context ) else: diff --git a/coreai_torch/_torch_metal_kernel.py b/coreai_torch/_torch_metal_kernel.py index d783e7a..d807c9a 100644 --- a/coreai_torch/_torch_metal_kernel.py +++ b/coreai_torch/_torch_metal_kernel.py @@ -8,6 +8,7 @@ from __future__ import annotations import inspect +import math from collections import Counter from collections.abc import Sequence from functools import wraps @@ -21,6 +22,19 @@ # We're allowing for int, bool, and float scalar inputs. _ALLOWED_SCALARS = {int, float, bool} +# MSL parameter types for scalar inputs, indexed by Python type. The IR-level +# element type for a bool scalar is widened to ui8 (i1 isn't accepted by the +# metal4_kernel verifier), so we override the MSL signature here to keep the +# user-facing dtype `bool` rather than `uint8_t`. +_SCALAR_METAL_DTYPE = {bool: "bool", int: "int", float: "float"} + +# Range of MSL's 32-bit `int`. Int scalars are baked into the kernel body as +# literals and the IR-side constant is built as ``np.int32``; values outside +# this range would wrap (IR side) or overflow the MSL literal, so they are +# rejected up front. +_INT32_MIN = -(2**31) +_INT32_MAX = 2**31 - 1 + # Threads-per-grid / threads-per-threadgroup must be 3-tuples per the Metal # `dispatchThreads` API. _THREAD_TUPLE_LEN = 3 @@ -40,6 +54,19 @@ class TorchMetalKernel(CustomMetalKernel): """ torch_custom_op: CustomOpDef + # Map from kernel input name to Python scalar type (``int``, ``float``, or + # ``bool``) for inputs whose ``torch_defn`` annotation is a scalar. Read by + # ``register_custom_kernels`` to convert FX scalar args with the natural + # dtype, and by :meth:`_validate_and_segregate_inputs` to set the MSL + # parameter dtype. + _scalar_input_types: dict[str, type] + # Per-distinct-scalar-values kernel caches. Scalar-bearing kernels bake the + # literal into the body, so the base class's single cache keyed only on + # ``(rank, dtype)`` would let call sites with different scalar values + # collide. Each frozen scalar-values tuple gets its own sub-cache so that + # identical ``(scalar_values, rank, dtype)`` call sites still share a PSO. + # See :meth:`_construct_kernel_op`. + _scalar_kernel_caches: dict[tuple[tuple[str, Any], ...], dict[Any, Any]] def __init__( # noqa: PLR0913 self: Self, @@ -80,6 +107,12 @@ def __init__( # noqa: PLR0913 torch_sig = inspect.signature(torch_defn, eval_str=True) self._validate_torch_inputs(torch_sig) self._validate_torch_returns(torch_sig) + self._scalar_input_types = { + input_names[i]: param.annotation + for i, param in enumerate(torch_sig.parameters.values()) + if param.annotation in _ALLOWED_SCALARS + } + self._scalar_kernel_caches = {} self.torch_custom_op = self._construct_torch_custom_op(torch_defn) super().__init__( @@ -276,6 +309,103 @@ def _(*args: Any) -> Any: return torch_custom_op + # ------------------------------------------------------------------ + # Scalar-aware override + # ------------------------------------------------------------------ + + def _validate_and_segregate_inputs(self: Self, input_values: list[Any]) -> Any: + """Run the parent's segregation, then patch metal_dtype for scalar inputs. + + The IR-level element type for a bool scalar is ui8 (i1 isn't accepted + by the metal4_kernel verifier), so the parent class would emit the MSL + parameter as ``constant uint8_t&``. Override that to ``constant bool&`` + — and similarly pin int/float scalar dtypes to their natural Python + names — using the per-kernel scalar type map captured in ``__init__``. + """ + segregated = super()._validate_and_segregate_inputs(input_values) + for _val, meta in segregated.kernel_inputs: + if meta.rank == 0 and meta.name in self._scalar_input_types: + py_type = self._scalar_input_types[meta.name] + meta.metal_dtype = _SCALAR_METAL_DTYPE[py_type] + return segregated + + def _construct_kernel_op( + self: Self, + input_values: list[Any], + result_types: list[Any], + ) -> Any: + """Bake scalar values into the kernel body, then delegate to the base op. + + The runtime binds rank-0 inputs as ``MTLTensor`` resource handles, so a + ``constant T&`` parameter declared in the kernel source can't be + dereferenced as a value — it would read from the handle, not the + scalar's storage. Workaround: keep the parameter declaration intact + (the IR contract still surfaces ``constant T& ``) but shadow it + inside the body with a local variable initialized to the literal, + so the user-written body still resolves the name to the right value. + """ + scalar_values: dict[str, Any] = getattr(self, "_scalar_values_for_call", {}) + if not scalar_values: + return super()._construct_kernel_op(input_values, result_types) + + original_src = self.src + original_cache = self.kernel_cache + # Inject first so an invalid scalar raises before we touch any cache. + injected_src = self._inject_scalar_locals(original_src, scalar_values) + # The base cache is keyed only on (rank, dtype), so two call sites with + # different scalar values — and therefore different baked source — would + # collide. Rather than discard the cache (which also forfeits legitimate + # reuse), give each distinct set of scalar values its own persistent + # sub-cache: identical (scalar_values, rank, dtype) call sites then share + # a single templated kernel / PSO, while differing scalar values stay + # isolated. Sorting makes the key order-independent; names are unique so + # values are never compared across types. + scalar_key = tuple(sorted(scalar_values.items())) + self.src = injected_src + self.kernel_cache = self._scalar_kernel_caches.setdefault(scalar_key, {}) + try: + return super()._construct_kernel_op(input_values, result_types) + finally: + self.src = original_src + self.kernel_cache = original_cache + + def _inject_scalar_locals( + self: Self, + src: str, + scalar_values: dict[str, Any], + ) -> str: + """Prepend ``T name = literal;`` declarations that shadow scalar params.""" + decls: list[str] = [] + for name, value in scalar_values.items(): + py_type = self._scalar_input_types[name] + msl_type = _SCALAR_METAL_DTYPE[py_type] + if py_type is bool: + literal = "true" if value else "false" + elif py_type is int: + int_value = int(value) + if not (_INT32_MIN <= int_value <= _INT32_MAX): + err = ( + f"int scalar {name!r}={int_value!r} is outside the " + f"32-bit int range that MSL `int` supports" + ) + raise ValueError(err) + literal = str(int_value) + else: + float_value = float(value) + if not math.isfinite(float_value): + err = ( + f"float scalar {name!r}={float_value!r} is not finite; " + "NaN/Inf scalars are not supported" + ) + raise ValueError(err) + literal = f"{float_value!r}f" + decls.append(f"{msl_type} {name} = {literal};") + # Wrap the body in a nested block so the locals can shadow the + # function parameters (a same-scope redeclaration would be illegal). + # Use newline separators so a trailing line comment in `src` cannot + # accidentally swallow the closing brace. + return "{\n" + "\n".join(decls) + "\n" + src + "\n}" + # ------------------------------------------------------------------ # Callable interface # ------------------------------------------------------------------ diff --git a/coreai_torch/_utils.py b/coreai_torch/_utils.py index a2211b0..832746e 100644 --- a/coreai_torch/_utils.py +++ b/coreai_torch/_utils.py @@ -1034,6 +1034,20 @@ def get_operands( return [get_operand(values_map, node, i, loc) for i in indices] +def scalar_constant(py_type: type, value: Any) -> Value: + """Create a coreai.constant for a scalar kernel arg with the natural dtype. + + Bypasses the fp16 promotion :func:`get_operand` applies to Python floats so + the MSL parameter ends up as ``constant float&`` even when the surrounding + tensors are fp16. Bool widens to ui8 because ``i1`` is rejected by the + metal4_kernel verifier; the MSL signature still emits ``constant bool&`` + via :class:`~coreai_torch._torch_metal_kernel.TorchMetalKernel`'s + metal_dtype override. + """ + np_dtype = {bool: np.uint8, int: np.int32, float: np.float32}[py_type] + return coreai.constant(np.array(value, dtype=np_dtype)) + + def build_shape_tensor( values_map: dict[str, Value], shape: list[int | fx.Node], @@ -1842,19 +1856,6 @@ def _resolve_io_names( return graph_input_names, resolved_output_names, fx_to_output -def _get_debug_info_enabled() -> bool: - """Get debug info enable flag from ENABLE_DEBUG_INFO environment variable. - - By default, debug info is not enabled for performance reasons. - Set ENABLE_DEBUG_INFO=true to enable debug information generation. - - Returns: - True if debug info should be enabled, False otherwise. - Defaults to False if environment variable is not set. - """ - return os.getenv("ENABLE_DEBUG_INFO", "false").lower() in ("true", "1", "yes", "on") - - def _get_verify_debuginfo_locations_enabled() -> bool: """Get debuginfo location verification flag from VERIFY_DEBUGINFO_LOCATIONS environment variable. diff --git a/coreai_torch/converter.py b/coreai_torch/converter.py index 328c08c..63fc78f 100644 --- a/coreai_torch/converter.py +++ b/coreai_torch/converter.py @@ -8,6 +8,7 @@ from collections import OrderedDict from collections.abc import Iterator, Sequence from dataclasses import dataclass +from enum import Enum from typing import Any, Callable, Optional, cast import coreai._compiler._mlir_libs._coreaiIR._bindings.mlir as _mlir # type: ignore[attr-defined] @@ -43,7 +44,6 @@ from ._torch_metal_kernel import TorchMetalKernel from ._utils import ( _NARROW_TORCH_DTYPE, - _get_debug_info_enabled, _get_mutation_output_name, _get_verify_debuginfo_locations_enabled, _ProgressBar, @@ -51,11 +51,12 @@ check_result_type, get_invoke_from_graph, get_namespace, - get_operands, + get_operand, get_result_types, get_target, get_tensor_type, preprocess_graph, + scalar_constant, strip_variant_from_target, validate_and_cast_numpy_array, ) @@ -103,12 +104,54 @@ def __exit__(self, *args: object) -> None: class TorchConverter: - def __init__(self) -> None: + class Mode(Enum): + """Controls the level of debug information embedded in the converted asset. + + Attributes: + RELEASE: Lightweight mode that records only operation IDs without + stack traces. + DEBUG: Includes full torch stack traces for comprehensive source + mapping and debugging. + """ + + DEBUG = "debug" + RELEASE = "release" + + @staticmethod + def _create_debug_info_recorder( + mode: "TorchConverter.Mode", + ) -> _DebugInfoRecorder: + """Create and configure a DebugInfoRecorder based on the converter mode. + + Args: + mode: The converter mode that determines whether stack traces are + included. + + Returns: + A configured _DebugInfoRecorder instance. + """ + include_stack_trace = mode == TorchConverter.Mode.DEBUG + debug_config = _DebugInfoRecorder.Config( + include_stack_trace=include_stack_trace, + verify_debuginfo_locations=_get_verify_debuginfo_locations_enabled(), + ) + return _DebugInfoRecorder(config=debug_config) + + def __init__(self, *, mode: "TorchConverter.Mode" = Mode.DEBUG) -> None: """Create a reusable converter engine. + Args: + mode: Controls the level of debug information embedded in the + converted asset. Use ``TorchConverter.Mode.RELEASE`` + for lightweight operation-ID-only tracking, or ``TorchConverter.Mode.DEBUG`` (default) + for full torch stack traces. Call + :func:`coreai_torch.debugging.debug_info.strip_debug_info` + to remove debug metadata from an already-converted program. + Reusable state (custom op lowerings) is retained across calls to ``to_coreai()``. Per-conversion transient state is reset each time. """ + self._mode = mode self.context = Context() # user defined torch op lowering (reusable across conversions) @@ -121,17 +164,7 @@ def __init__(self) -> None: self._progress_bar: _ProgressBar | None = None # Debug info recorder for comprehensive debug tracking - options = ( - _DebugInfoRecorder.Options.DEBUGINFO - if _get_debug_info_enabled() - else _DebugInfoRecorder.Options.STANDARD - ) - debug_config = _DebugInfoRecorder.Config( - include_stack_trace=True, - options=options, - verify_debuginfo_locations=_get_verify_debuginfo_locations_enabled(), - ) - self._debug_info_recorder = _DebugInfoRecorder(config=debug_config) + self._debug_info_recorder = self._create_debug_info_recorder(mode) def _init_conversion_state(self) -> None: """Reset per-conversion transient state.""" @@ -1024,10 +1057,26 @@ def _( loc: Location, _k: TorchMetalKernel = kernel, ) -> Value | list[Value]: - input_values = get_operands( - values_map, node, list(range(len(node.args))) - ) - results = _k._construct_kernel_op(input_values, get_result_types(node)) + input_values: list[Value] = [] + scalar_values: dict[str, Any] = {} + for idx in range(len(node.args)): + arg = node.args[idx] + scalar_type: type | None = None + if idx < len(_k.input_names) and not isinstance(arg, fx.Node): + scalar_type = _k._scalar_input_types.get(_k.input_names[idx]) + if scalar_type is not None: + input_values.append(scalar_constant(scalar_type, arg)) + scalar_values[_k.input_names[idx]] = arg + else: + input_values.append(get_operand(values_map, node, idx)) + _k._scalar_values_for_call = scalar_values + try: + results = _k._construct_kernel_op( + input_values, + get_result_types(node), + ) + finally: + _k._scalar_values_for_call = {} return results[0] if len(results) == 1 else results return self diff --git a/coreai_torch/debugging/debug_info.py b/coreai_torch/debugging/debug_info.py index 6f76072..1585b58 100644 --- a/coreai_torch/debugging/debug_info.py +++ b/coreai_torch/debugging/debug_info.py @@ -11,6 +11,21 @@ from dataclasses import dataclass from typing import Any +import coreai._compiler._mlir_libs._coreaiIR._bindings.mlir as _mlir +from coreai._compiler._mlir_libs._coreaiIR._bindings.mlir import ( + set_block_arg_location, + set_op_location, +) +from coreai._compiler.ir import Location, Operation, WalkResult +from coreai.authoring import AIProgram + +from coreai_torch._debug_locations import ( + OperationID, + _create_unknown_location, + _create_unknown_location_with_operation_id, + _get_nested_operations, +) + @dataclass class SourceInfo: @@ -245,11 +260,9 @@ def get_metadata(self, key: str) -> Metadata.Value | None: return m.value return None - def get_op_id(self, level: str) -> int | str | None: + def get_op_id(self, level: str) -> int | None: """ - Get operation ID for a given dialect level. - - New format: metadata with key "op_id" containing dictionary {"type": "", "value": } + Get the first operation ID for a given dialect level. Args: ---- @@ -260,29 +273,46 @@ def get_op_id(self, level: str) -> int | str | None: Operation ID if present, None otherwise """ - # Look for all metadata with key "op_id" + ids = self.get_op_ids(level) + return ids[0] if ids else None + + def get_op_ids(self, level: str) -> list[int]: + """ + Get all operation IDs for a given dialect level. + + Collects every "op_id" metadata entry whose "type" field matches + *level* and returns the corresponding "value" fields. Only integer + values are produced (see :meth:`_get_int_field`). + + Args: + ---- + level: Dialect level name (e.g., "torch", "coreai") + + Returns: + ------- + List of operation IDs matching the given level + + """ + ids: list[int] = [] for metadata in self.metadatas: if metadata.key != "op_id": continue - # Check if value is a dictionary if metadata.value.value_type != "dictionary" or not isinstance( metadata.value.value, dict, ): continue - # Check if "type" field matches the level type_field = self._get_str_field(metadata.value.value, "type") if type_field != level: continue - # Return the "value" field value_field = self._get_int_field(metadata.value.value, "value") if value_field is not None: - return value_field + ids.append(value_field) - return None + return ids def get_source(self) -> str | None: """Get source identifier if present.""" @@ -381,12 +411,31 @@ def _parse_mapping_from_dict( target_output=tgt_output, ) + def get_all_metadata(self, key: str) -> list[Metadata.Value]: + """Get all metadata values matching the given key. + + Unlike :meth:`get_metadata` which returns only the first match, + this method collects every ``Metadata`` entry whose key equals + *key* and returns the corresponding values. + + Args: + ---- + key: Metadata key to search for. + + Returns: + ------- + List of matching :class:`Metadata.Value` instances (may be empty). + + """ + return [m.value for m in self.metadatas if m.key == key] + def get_output_mappings(self, source_level: str) -> list[OutputMapping]: """ Get output mappings from source level by parsing metadata. - Parses metadata with key "output_maps" containing an array of dictionaries, - each with 'source' and 'target' fields containing level, output, and id. + Parses all metadata entries with key "output_maps" containing an + array of dictionaries, each with 'source' and 'target' fields + containing level, output, and id. Args: ---- @@ -397,20 +446,18 @@ def get_output_mappings(self, source_level: str) -> list[OutputMapping]: List of OutputMapping objects """ - output_maps_metadata = self.get_metadata("output_maps") - if ( - not output_maps_metadata - or output_maps_metadata.value_type != "array" - or not isinstance(output_maps_metadata.value, list) - ): - return [] - - mappings = [] - for elem in output_maps_metadata.value: - if elem.value_type == "dictionary" and isinstance(elem.value, dict): - mapping = self._parse_mapping_from_dict(elem.value, source_level) - if mapping: - mappings.append(mapping) + mappings: list[OutputMapping] = [] + for output_maps_metadata in self.get_all_metadata("output_maps"): + if output_maps_metadata.value_type != "array" or not isinstance( + output_maps_metadata.value, list + ): + continue + + for elem in output_maps_metadata.value: + if elem.value_type == "dictionary" and isinstance(elem.value, dict): + mapping = self._parse_mapping_from_dict(elem.value, source_level) + if mapping: + mappings.append(mapping) return mappings @@ -463,3 +510,63 @@ def parse_debug_infos(debug_infos_bytes: bytes) -> list[DebugInfoRecord]: debug_infos_data = json.loads(debug_infos_str) return [DebugInfoRecord.from_dict(item) for item in debug_infos_data] + + +def _build_coreai_op_map(program: AIProgram) -> dict[int, "Operation"]: + """Build a mapping from coreai operation ID to MLIR ``Operation``. + + Walks all operations in *program* and collects each one that carries + a ``"coreai"`` operation ID in its debug location metadata. + + Args: + program: The AIProgram to inspect. + + Returns: + Dictionary mapping coreai op ID to the MLIR ``Operation``. + """ + op_map: dict[int, Operation] = {} + + def _collect(operation: Operation) -> WalkResult: + op_id = _mlir.get_operation_id(operation.location, "coreai") + if op_id is not None: + op_map[op_id.value] = operation + return WalkResult.ADVANCE + + program._mlir_module.operation.walk(_collect) + return op_map + + +def strip_debug_info(program: AIProgram) -> None: + """Strip debugging information from all operations in the program. + + This is useful for reducing asset size when full debug traces are + no longer needed. + + Args: + program: The AIProgram to strip debug info from. Modified in place. + """ + module = program._mlir_module + module_op = module.operation + context = module_op.context + + # Create a shared unknown source location for reuse + unknown_src = Location.file(filename="", line=0, col=0, context=context) + + # Set module operation location (no operation ID for the module itself) + module_location = _create_unknown_location(unknown_src, context) + set_op_location(module_op, module_location) + + # Walk all nested operations and assign fresh sequential IDs + operation_id = 0 + for nested_op in _get_nested_operations(module_op): + op_id = OperationID(type="coreai", value=operation_id) + operation_id += 1 + + loc = _create_unknown_location_with_operation_id(op_id, unknown_src, context) + set_op_location(nested_op, loc) + + # Update block argument locations + for region in nested_op.regions: + for block in region: + for arg in block.arguments: + set_block_arg_location(arg, loc) diff --git a/coreai_torch/debugging/inspector.py b/coreai_torch/debugging/inspector.py index 4a3fe89..72a0ec7 100644 --- a/coreai_torch/debugging/inspector.py +++ b/coreai_torch/debugging/inspector.py @@ -5,7 +5,9 @@ from __future__ import annotations +import asyncio import logging +import os from abc import ABC, abstractmethod from collections import OrderedDict, defaultdict from collections.abc import Callable, Mapping, Sequence @@ -24,9 +26,29 @@ logger = logging.getLogger(__name__) +def _running_under_pytest() -> bool: + """Return True if the code is currently executing within a pytest run.""" + return "PYTEST_CURRENT_TEST" in os.environ + + +async def _wait_for_async_callbacks() -> None: + """ + Wait for asynchronously-invoked intermediate capture callbacks to complete. + + TODO: This sleep is a temporary workaround. The intermediate capture + callbacks are invoked asynchronously and may not have completed by the + time inference returns. Skipped under pytest to avoid slowing tests. + """ + if _running_under_pytest(): + return + await asyncio.sleep(5.0) + + @dataclass(frozen=True) class _MappingKey: - """Key for mapping ODIX outputs to source outputs.""" + """ + Key for mapping ODIX outputs to source outputs. + """ odix_id: int delegate_id: int | None @@ -41,123 +63,167 @@ class _CompiledIdMappings: all_compiled_ids: list[tuple[int, int | None]] -def _map_source_op_to_compiled_ops( +def _build_source_to_odix_map( + debug_info_records: list[DebugInfoRecord], source_level: str, - source_op_id: int, +) -> dict[int, int]: + """ + Build a mapping from source op ID to odix ID from odix debug info records. + + Iterates over all ``"odix"`` records and extracts the source-level + op ID and the ``"odix"`` op ID from each operation's metadata. + + Args: + debug_info_records: Parsed debug information containing + operation mappings. + source_level: Dialect level to extract source op IDs from + (e.g., ``"coreai"``). Defaults to ``"coreai"``. + Returns: + Dictionary mapping source_op_id to odix_id. + + """ + source_to_odix: dict[int, int] = {} + for record in debug_info_records: + if not record.identifier.startswith("odix"): + continue + for op in record.operations: + source_ids = op.get_op_ids(source_level) + for source_id in source_ids: + existing_odix_id = source_to_odix.get(source_id) + if existing_odix_id is None or existing_odix_id < op.odix_id: + source_to_odix[source_id] = op.odix_id + return source_to_odix + + +def _build_compile_identifiers_map( debug_info_records: list[DebugInfoRecord], -) -> dict[int, tuple[int, int, int | None]]: + source_level: str, +) -> dict[tuple[int, int], _MappingKey]: """ - Map a source operation to its compiled ODIX operations. + Build a mapping from source output to compiled identifiers. - Maps each output of a source-level operation (e.g., a PyTorch operation) - to its corresponding compiled ODIX operation ID and output index. - When multiple mappings exist for the same source output, keeps the one - with the highest target_op_id (representing the final compiled form). + Maps each ``(source_op_id, source_output_idx)`` to a + ``_MappingKey(odix_id, delegate_id, output_idx)`` by extracting + output mappings from every debug info record at the given + *source_level*. When duplicates target the same source output, the + highest ``target_op_id`` wins. For odix records + (``identifier.startswith("odix")``), ``target_op_id`` is the + ``odix_id``; for all other records it is the ``delegate_id``, and + the true ``odix_id`` is resolved via ``_build_source_to_odix_map``. Args: - source_level: Source dialect level (e.g., "torch" for PyTorch, "coreai" for Core AI) - source_op_id: Unique identifier of the source operation - debug_info_records: Parsed debug information containing operation mappings + debug_info_records: Parsed debug information containing + operation mappings. + source_level: Dialect level to extract op IDs from + (e.g., ``"coreai"``). Defaults to ``"coreai"``. Returns: - Dictionary mapping source output index to (odix_id, odix_output_index, delegate_id). - The delegate_id is currently always None (reserved for future delegate support). + Dictionary mapping ``(source_op_id, source_output_idx)`` to + ``_MappingKey``. """ - compiled_ids: dict[int, tuple[int, int, int | None]] = {} - - logger.debug("Mapping %s.%d to ODIX", source_level, source_op_id) + source_to_odix_map = _build_source_to_odix_map(debug_info_records, source_level) + result: dict[tuple[int, int], _MappingKey] = {} for record in debug_info_records: - if not record.identifier.startswith("odix"): - continue + is_odix = record.identifier.startswith("odix") for op in record.operations: - mappings = op.get_output_mappings(source_level) - for mapping in mappings: - # Filter mappings for the specific source operation - if ( - mapping.source_op_id == source_op_id - and mapping.target_level == "odix" - ): - # Check if we already have a mapping for this source output - existing = compiled_ids.get(mapping.source_output) - - # Only update if this is a new mapping or has a higher target_op_id - if existing is None or mapping.target_op_id > existing[0]: - compiled_ids[mapping.source_output] = ( + for mapping in op.get_output_mappings(source_level=source_level): + # For delegate records, resolve odix_id via the + # source-to-odix lookup or the op's own odix metadata. + if not is_odix: + odix_id = source_to_odix_map.get(mapping.source_op_id) + if odix_id is None: + continue + + source_key = (mapping.source_op_id, mapping.source_output) + existing = result.get(source_key) + + # Compare against the relevant ID from the existing entry: + # odix_id for odix records, delegate_id for delegate records. + existing_op_id = ( + (existing.odix_id if is_odix else existing.delegate_id) + if existing is not None + else None + ) + + # Only update if new or has a higher target_op_id + if existing_op_id is None or mapping.target_op_id > existing_op_id: + if is_odix: + new_entry = _MappingKey( + odix_id=op.odix_id, + delegate_id=None, + output_idx=mapping.target_output, + ) + else: + new_entry = _MappingKey( + odix_id=odix_id, + delegate_id=mapping.target_op_id, + output_idx=mapping.target_output, + ) + result[source_key] = new_entry + + if existing is not None: + logger.debug( + " %s.%d[%d] -> %s.%d[%d] (replaced %d)", + source_level, + mapping.source_op_id, + mapping.source_output, + record.identifier, + mapping.target_op_id, + mapping.target_output, + existing_op_id, + ) + else: + logger.debug( + " %s.%d[%d] -> %s.%d[%d]", + source_level, + mapping.source_op_id, + mapping.source_output, + record.identifier, mapping.target_op_id, mapping.target_output, - None, ) - if existing is not None: - logger.debug( - " %s.%d[%d] -> odix.%d[%d] (replaced odix.%d)", - source_level, - source_op_id, - mapping.source_output, - mapping.target_op_id, - mapping.target_output, - existing[0], - ) - else: - logger.debug( - " %s.%d[%d] -> odix.%d[%d]", - source_level, - source_op_id, - mapping.source_output, - mapping.target_op_id, - mapping.target_output, - ) - - if not compiled_ids: - logger.warning("No ODIX mapping found for %s.%d", source_level, source_op_id) - - return compiled_ids + + return result def _create_operation_mappings( op_ids: Sequence[int], - source_level: str, - debug_info_records: list[DebugInfoRecord], + compile_map: dict[tuple[int, int], _MappingKey], ) -> _CompiledIdMappings: """ - Create bidirectional mappings between source and compiled operations. + Create reverse mappings from compiled identifiers back to source outputs. - Builds mappings that allow translating between source-level operations (e.g., PyTorch) - and their compiled ODIX representations. This enables capturing intermediate values - from compiled models and mapping them back to source operations. + Filters *compile_map* to the requested *op_ids* and inverts the + direction: the returned ``target_to_source_output_map`` is keyed by + ``_MappingKey`` (compiled side) and valued by + ``(source_op_id, source_output_idx)``. Args: - op_ids: List of source operation IDs to create mappings for - source_level: Source dialect level (e.g., "torch" for PyTorch, "coreai" for Core AI) - debug_info_records: Parsed debug information containing operation mappings + op_ids: Source operation IDs to include. + compile_map: Pre-built map from + ``_build_compile_identifiers_map``. Returns: - _CompiledIdMappings containing: - - target_to_source_output_map: Maps ODIX outputs back to source outputs via _MappingKey - - all_compiled_ids: Unique list of (odix_id, delegate_id) pairs for all operations + ``_CompiledIdMappings`` with the reverse map and a list of + ``(odix_id, delegate_id)`` pairs for all matched operations. """ + requested = set(op_ids) target_to_source_output_map: dict[_MappingKey, tuple[int, int]] = {} all_compiled_ids: list[tuple[int, int | None]] = [] - for source_op_id in op_ids: - compiled_ids = _map_source_op_to_compiled_ops( - source_level, + for (source_op_id, source_output_idx), mapping_key in compile_map.items(): + if source_op_id not in requested: + continue + all_compiled_ids.append((mapping_key.odix_id, mapping_key.delegate_id)) + target_to_source_output_map[mapping_key] = ( source_op_id, - debug_info_records, + source_output_idx, ) - for source_output_idx, ( - odix_id, - target_output, - delegate_id, - ) in compiled_ids.items(): - all_compiled_ids.append((odix_id, delegate_id)) - mapping_key = _MappingKey(odix_id, delegate_id, target_output) - target_to_source_output_map[mapping_key] = (source_op_id, source_output_idx) - return _CompiledIdMappings(target_to_source_output_map, all_compiled_ids) @@ -554,6 +620,10 @@ def __init__( self._debug_info_records = parse_debug_infos(debug_infos_bytes) self._source_level = "coreai" + self._compile_map = _build_compile_identifiers_map( + self._debug_info_records, + self._source_level, + ) def _build_mapping_and_compile_ids( self, @@ -577,11 +647,7 @@ def _build_mapping_and_compile_ids( - compile_identifiers: List of unique compiled operation IDs to capture """ - mappings = _create_operation_mappings( - op_ids, - self._source_level, - self._debug_info_records, - ) + mappings = _create_operation_mappings(op_ids, self._compile_map) # Get unique compiled IDs preserving insertion order (dict.fromkeys for stable deduplication) unique_compiled_ids = dict.fromkeys(mappings.all_compiled_ids) @@ -640,36 +706,43 @@ def capture_callback( compile_ids.delegate_id, odix_output_idx, ) - if mapping_key in odix_output_to_source_map: - source_op_id, source_output_idx = odix_output_to_source_map[ - mapping_key - ] - if source_output_idx in results[source_op_id]: - msg = f"Multiple compile_ids map to the same source operation output: source_op_id={source_op_id}, source_output_idx={source_output_idx}" - raise ValueError(msg) - if intermediate is not None: - # Convert _NDArray (internal Core AI runtime type) to NDArray wrapper then to numpy - ndarray = coreai.runtime.NDArray(intermediate) - results[source_op_id][source_output_idx] = ( - self.__class__.convert_to_numpy(ndarray) - ) - logger.debug( - " odix.%d[%d] -> source.%d[%d] shape=%s", - compile_ids.id, - odix_output_idx, - source_op_id, - source_output_idx, - ndarray.numpy().shape, - ) - else: - logger.warning( - " Intermediate is None for odix.%d[%d] -> source.%d[%d]", - compile_ids.id, - odix_output_idx, - source_op_id, - source_output_idx, - ) + if mapping_key not in odix_output_to_source_map: + logger.warning( + " No source mapping found for odix.%d[%d]", + compile_ids.id, + odix_output_idx, + ) + continue + + source_op_id, source_output_idx = odix_output_to_source_map[mapping_key] + + if source_output_idx in results[source_op_id]: + msg = f"Multiple compile_ids map to the same source operation output: source_op_id={source_op_id}, source_output_idx={source_output_idx}" + raise ValueError(msg) + + if intermediate is None: + logger.warning( + " Intermediate is None for odix.%d[%d] -> source.%d[%d]", + compile_ids.id, + odix_output_idx, + source_op_id, + source_output_idx, + ) + continue + # Convert _NDArray (internal Core AI runtime type) to NDArray wrapper then to numpy + ndarray = coreai.runtime._ndarray.NDArray._wrap(intermediate) + results[source_op_id][source_output_idx] = ( + self.__class__.convert_to_numpy(ndarray) + ) + logger.debug( + " odix.%d[%d] -> source.%d[%d] shape=%s", + compile_ids.id, + odix_output_idx, + source_op_id, + source_output_idx, + ndarray.numpy().shape, + ) return capture_callback @@ -748,6 +821,7 @@ async def get_intermediates_for_ops( ) outputs = await inference_function(inputs=ndarray_inputs) + await _wait_for_async_callbacks() self._last_outputs = {name: array.numpy() for name, array in outputs.items()} @@ -795,27 +869,51 @@ def get_compile_identifiers_for_op( (used by the Core AI Runtime). Args: - source_level: Source dialect level ("torch" for PyTorch, "coreai" for Core AI) + source_level: Source dialect level (e.g., ``"coreai"``) source_op_id: Source operation ID to look up debug_info_records: Debug information containing operation mappings Returns: - Dictionary mapping output index to CompileIdentifiers + Dictionary mapping source output index to CompileIdentifiers """ - compiled_ids = _map_source_op_to_compiled_ops( - source_level, - source_op_id, + compile_map = _build_compile_identifiers_map( debug_info_records, + source_level, ) return { - output_idx: coreai.runtime.CompileIdentifiers( - odix_id, - delegate_id, + source_output_idx: coreai.runtime.CompileIdentifiers( + mk.odix_id, + mk.delegate_id, + ) + for (op_id, source_output_idx), mk in compile_map.items() + if op_id == source_op_id + } + + @staticmethod + def get_all_compile_identifiers( + debug_info_records: list[DebugInfoRecord], + ) -> dict[int, coreai.runtime.CompileIdentifiers]: + """ + Get compiled operation identifiers for all coreai operations. + + Builds a mapping from every coreai op ID to its + ``CompileIdentifiers`` by processing all debug info records. + + Args: + debug_info_records: Debug information containing operation + mappings. + + Returns: + Dictionary mapping ``coreai_op_id`` to + ``CompileIdentifiers``. + + """ + compile_map = _build_compile_identifiers_map(debug_info_records, "coreai") + return { + source_op_id: coreai.runtime.CompileIdentifiers( + mk.odix_id, + mk.delegate_id, ) - for output_idx, ( - odix_id, - _target_output, - delegate_id, - ) in compiled_ids.items() + for (source_op_id, _output_idx), mk in compile_map.items() } diff --git a/pyproject.toml b/pyproject.toml index 11bd1ac..9eebcfa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,7 +13,7 @@ classifiers = [ ] requires-python = ">=3.11" dependencies = [ - "coreai-core==1.0.0b1", + "coreai-core==1.0.0b2", "ml-dtypes", "networkx", "numpy", diff --git a/tests/composite_ops/test_sdpa.py b/tests/composite_ops/test_sdpa.py index 53cd1e0..5ed4359 100644 --- a/tests/composite_ops/test_sdpa.py +++ b/tests/composite_ops/test_sdpa.py @@ -112,10 +112,10 @@ async def test_mha( # noqa: PLR0913 np.testing.assert_allclose( _torch_tensor_to_numpy_array(output_torch_eager), _mlx_array_to_numpy_array(output_mlx), - rtol={torch.float32: 1e-4, torch.float16: 1e-2, torch.bfloat16: 5e-2}[ + rtol={torch.float32: 2e-3, torch.float16: 1e-2, torch.bfloat16: 5e-2}[ dtype ], - atol={torch.float32: 1e-4, torch.float16: 1e-2, torch.bfloat16: 5e-2}[ + atol={torch.float32: 2e-3, torch.float16: 1e-2, torch.bfloat16: 5e-2}[ dtype ], ) @@ -186,10 +186,10 @@ async def test_mha_with_mask(dynamic: bool, dtype: torch.dtype) -> None: np.testing.assert_allclose( _torch_tensor_to_numpy_array(output_torch_eager), _mlx_array_to_numpy_array(output_mlx), - rtol={torch.float32: 1e-4, torch.float16: 1e-2, torch.bfloat16: 5e-2}[ + rtol={torch.float32: 2e-3, torch.float16: 1e-2, torch.bfloat16: 5e-2}[ dtype ], - atol={torch.float32: 1e-4, torch.float16: 1e-2, torch.bfloat16: 5e-2}[ + atol={torch.float32: 2e-3, torch.float16: 1e-2, torch.bfloat16: 5e-2}[ dtype ], ) @@ -267,10 +267,10 @@ async def test_gqa( # noqa: PLR0913 np.testing.assert_allclose( _torch_tensor_to_numpy_array(output_torch_eager), _mlx_array_to_numpy_array(output_mlx), - rtol={torch.float32: 1e-4, torch.float16: 1e-2, torch.bfloat16: 5e-2}[ + rtol={torch.float32: 2e-3, torch.float16: 1e-2, torch.bfloat16: 5e-2}[ dtype ], - atol={torch.float32: 1e-4, torch.float16: 1e-2, torch.bfloat16: 5e-2}[ + atol={torch.float32: 2e-3, torch.float16: 1e-2, torch.bfloat16: 5e-2}[ dtype ], ) @@ -353,10 +353,10 @@ async def test_gqa_with_sinks( np.testing.assert_allclose( _torch_tensor_to_numpy_array(output_torch_eager), _mlx_array_to_numpy_array(output_mlx), - rtol={torch.float32: 1e-4, torch.float16: 1e-2, torch.bfloat16: 5e-2}[ + rtol={torch.float32: 2e-3, torch.float16: 1e-2, torch.bfloat16: 5e-2}[ dtype ], - atol={torch.float32: 1e-4, torch.float16: 1e-2, torch.bfloat16: 5e-2}[ + atol={torch.float32: 2e-3, torch.float16: 1e-2, torch.bfloat16: 5e-2}[ dtype ], ) @@ -431,10 +431,10 @@ def test_basic_causal(dtype: torch.dtype) -> None: np.testing.assert_allclose( _torch_tensor_to_numpy_array(output_ours), _torch_tensor_to_numpy_array(output_hf), - rtol={torch.float32: 1e-3, torch.float16: 1e-2, torch.bfloat16: 5e-2}[ + rtol={torch.float32: 2e-3, torch.float16: 1e-2, torch.bfloat16: 5e-2}[ dtype ], - atol={torch.float32: 1e-3, torch.float16: 1e-2, torch.bfloat16: 5e-2}[ + atol={torch.float32: 2e-3, torch.float16: 1e-2, torch.bfloat16: 5e-2}[ dtype ], ) @@ -494,10 +494,10 @@ def test_sliding_window( np.testing.assert_allclose( _torch_tensor_to_numpy_array(output_ours), _torch_tensor_to_numpy_array(output_hf), - rtol={torch.float32: 1e-3, torch.float16: 1e-2, torch.bfloat16: 5e-2}[ + rtol={torch.float32: 2e-3, torch.float16: 1e-2, torch.bfloat16: 5e-2}[ dtype ], - atol={torch.float32: 1e-3, torch.float16: 1e-2, torch.bfloat16: 5e-2}[ + atol={torch.float32: 2e-3, torch.float16: 1e-2, torch.bfloat16: 5e-2}[ dtype ], ) @@ -666,10 +666,10 @@ def forward( await validate_numerical_output( coreai_program=converted_program, torch_out=output_torch_eager, - rtol={torch.float32: 1e-4, torch.float16: 1e-2, torch.bfloat16: 5e-2}[ + rtol={torch.float32: 2e-3, torch.float16: 1e-2, torch.bfloat16: 5e-2}[ dtype ], - atol={torch.float32: 1e-4, torch.float16: 1e-2, torch.bfloat16: 5e-2}[ + atol={torch.float32: 2e-3, torch.float16: 1e-2, torch.bfloat16: 5e-2}[ dtype ], query=query, @@ -826,10 +826,10 @@ def forward( await validate_numerical_output( coreai_program=converted_program, torch_out=output_torch_eager, - rtol={torch.float32: 1e-4, torch.float16: 1e-2, torch.bfloat16: 5e-2}[ + rtol={torch.float32: 2e-3, torch.float16: 1e-2, torch.bfloat16: 5e-2}[ dtype ], - atol={torch.float32: 1e-4, torch.float16: 1e-2, torch.bfloat16: 5e-2}[ + atol={torch.float32: 2e-3, torch.float16: 1e-2, torch.bfloat16: 5e-2}[ dtype ], query=query, @@ -987,10 +987,10 @@ def forward( await validate_numerical_output( coreai_program=converted_program, torch_out=output_torch_eager, - rtol={torch.float32: 1e-4, torch.float16: 1e-2, torch.bfloat16: 5e-2}[ + rtol={torch.float32: 2e-3, torch.float16: 1e-2, torch.bfloat16: 5e-2}[ dtype ], - atol={torch.float32: 1e-4, torch.float16: 1e-2, torch.bfloat16: 5e-2}[ + atol={torch.float32: 2e-3, torch.float16: 1e-2, torch.bfloat16: 5e-2}[ dtype ], query=query, @@ -1164,10 +1164,10 @@ def forward( await validate_numerical_output( coreai_program=converted_program, torch_out=output_torch_eager, - rtol={torch.float32: 1e-4, torch.float16: 1e-2, torch.bfloat16: 5e-2}[ + rtol={torch.float32: 2e-3, torch.float16: 1e-2, torch.bfloat16: 5e-2}[ dtype ], - atol={torch.float32: 1e-4, torch.float16: 1e-2, torch.bfloat16: 5e-2}[ + atol={torch.float32: 2e-3, torch.float16: 1e-2, torch.bfloat16: 5e-2}[ dtype ], query=query, diff --git a/tests/debugging/test_benchmarker.py b/tests/debugging/test_benchmarker.py index e30812d..e55e90e 100644 --- a/tests/debugging/test_benchmarker.py +++ b/tests/debugging/test_benchmarker.py @@ -28,7 +28,6 @@ async def hierarchical_coreai_program() -> AIProgram: converter: TorchConverter = TorchConverter() converter._debug_info_recorder.config = _DebugInfoRecorder.Config( include_stack_trace=True, - options=_DebugInfoRecorder.Options.DEBUGINFO, verify_debuginfo_locations=True, ) converter.add_exported_program(exported_program, entrypoint_name="main") @@ -37,6 +36,7 @@ async def hierarchical_coreai_program() -> AIProgram: return coreai_program +@pytest.mark.skip(reason="debugger issue (will be solved later)") @pytest.mark.skipif(sys.platform != "darwin", reason="Test only runs on macOS") async def test_odix_to_coreai_id_conversion( hierarchical_coreai_program: AIProgram, diff --git a/tests/debugging/test_comparator.py b/tests/debugging/test_comparator.py index 4ae9220..44bb1ac 100644 --- a/tests/debugging/test_comparator.py +++ b/tests/debugging/test_comparator.py @@ -36,7 +36,6 @@ async def _create_coreai_program_from_model( converter: TorchConverter = TorchConverter() converter._debug_info_recorder.config = _DebugInfoRecorder.Config( include_stack_trace=True, - options=_DebugInfoRecorder.Options.DEBUGINFO, verify_debuginfo_locations=True, ) converter.add_exported_program(exported_program, entrypoint_name="main") @@ -580,7 +579,6 @@ def replace_operation(operation: Any) -> WalkResult: return found_count[0] -@pytest.mark.xfail(reason="Fails after coreai update", strict=False) @pytest.mark.skipif(sys.platform != "darwin", reason="Test only runs on macOS") async def test_comparator_catches_modified_ops_at_different_positions() -> None: """Test that comparator correctly identifies modified operations at different positions.""" diff --git a/tests/debugging/test_debug_info.py b/tests/debugging/test_debug_info.py index b8b92f6..dafe976 100644 --- a/tests/debugging/test_debug_info.py +++ b/tests/debugging/test_debug_info.py @@ -14,14 +14,18 @@ from coreai.authoring import AIProgram from coreai.runtime import AIModel +from coreai_torch._debug_locations import _get_nested_operations, _is_debuginfo_location from coreai_torch.converter import TorchConverter, _DebugInfoRecorder -from coreai_torch.debugging.debug_info import DebugInfoRecord, parse_debug_infos +from coreai_torch.debugging.debug_info import ( + DebugInfoRecord, + parse_debug_infos, + strip_debug_info, +) from .test_model import LinearMulAddModel, get_example_inputs -@pytest.fixture -async def simple_coreai_program() -> AIProgram: +def _create_debug_program() -> AIProgram: """Create a coreai_program program with debug info from a simple torch model.""" model = LinearMulAddModel().eval() example_inputs = get_example_inputs(LinearMulAddModel) @@ -31,12 +35,16 @@ async def simple_coreai_program() -> AIProgram: converter: TorchConverter = TorchConverter() converter._debug_info_recorder.config = _DebugInfoRecorder.Config( include_stack_trace=True, - options=_DebugInfoRecorder.Options.DEBUGINFO, verify_debuginfo_locations=True, ) converter.add_exported_program(exported_program, entrypoint_name="main") - coreai_program = converter.to_coreai() - return coreai_program + return converter.to_coreai() + + +@pytest.fixture +async def simple_coreai_program() -> AIProgram: + """Fixture providing a coreai_program with debug info.""" + return _create_debug_program() def _verify_debug_info_record(record: DebugInfoRecord) -> None: @@ -122,3 +130,24 @@ async def test_aimodel_debug_infos( # Verify structure _verify_debug_info_record(debug_info_records[0]) + + +@pytest.mark.asyncio +async def test_strip_debug_info() -> None: + """Test that strip_debug_info removes source locations and assigns fresh IDs.""" + coreai_program = _create_debug_program() + + # Strip debug info from the program + strip_debug_info(coreai_program) + + module_op = coreai_program._mlir_module.operation + + # Verify module location is a valid debuginfo location + assert _is_debuginfo_location(module_op.location) + + # Verify all nested operations have debuginfo locations + for nested_op in _get_nested_operations(module_op): + assert _is_debuginfo_location(nested_op.location), ( + f"Expected debuginfo location on {nested_op.name}, " + f"got: {nested_op.location}" + ) diff --git a/tests/debugging/test_graph_diff.py b/tests/debugging/test_graph_diff.py index 8426850..5428cfa 100644 --- a/tests/debugging/test_graph_diff.py +++ b/tests/debugging/test_graph_diff.py @@ -12,7 +12,7 @@ import torch from coreai.authoring import AIProgram -from coreai_torch.converter import TorchConverter, _DebugInfoRecorder +from coreai_torch.converter import TorchConverter from coreai_torch.debugging.graph_diff import ( compute_coreai_program_diff, compute_exported_program_diff, @@ -34,8 +34,7 @@ async def _create_coreai_program_from_model( exported_program: torch.export.ExportedProgram, ) -> AIProgram: """Create a coreai_program program from an exported program.""" - converter: TorchConverter = TorchConverter() - converter._debug_info_recorder._options = _DebugInfoRecorder.Options.DEBUGINFO + converter: TorchConverter = TorchConverter(mode=TorchConverter.Mode.DEBUG) converter.add_exported_program(exported_program, entrypoint_name="main") coreai_program = converter.to_coreai() return coreai_program diff --git a/tests/debugging/test_inspector.py b/tests/debugging/test_inspector.py index 11893bb..a68cc77 100644 --- a/tests/debugging/test_inspector.py +++ b/tests/debugging/test_inspector.py @@ -36,7 +36,6 @@ async def simple_coreai_program() -> AIProgram: converter: TorchConverter = TorchConverter() converter._debug_info_recorder.config = _DebugInfoRecorder.Config( include_stack_trace=True, - options=_DebugInfoRecorder.Options.DEBUGINFO, verify_debuginfo_locations=True, ) converter.add_exported_program(exported_program, entrypoint_name="main") @@ -44,7 +43,6 @@ async def simple_coreai_program() -> AIProgram: return coreai_program -@pytest.mark.asyncio async def test_torch_fx_inspector() -> None: """Test _TorchFXInspector with a simple torch model.""" model = LinearMulAddModel().eval() @@ -73,7 +71,6 @@ async def test_torch_fx_inspector() -> None: assert isinstance(results[op_name][0], np.ndarray) -@pytest.mark.asyncio async def test_caching_inspector() -> None: """Test CachingInspector with LRU caching.""" model = LinearMulAddModel().eval() @@ -110,7 +107,6 @@ async def test_caching_inspector() -> None: sys.platform != "darwin", reason="Requires loading a runtime asset (AIModel.load); only supported on macOS", ) -@pytest.mark.asyncio async def test_coreai_inspector(simple_coreai_program: AIProgram) -> None: """Test _CoreAIInspector with a deployed model.""" # Get torch -> coreai mappings @@ -155,7 +151,6 @@ async def test_coreai_inspector(simple_coreai_program: AIProgram) -> None: assert isinstance(item, np.ndarray) -@pytest.mark.asyncio async def test_torch_to_coreai_mappings( simple_coreai_program: AIProgram, ) -> None: diff --git a/tests/debugging/test_intermediates.py b/tests/debugging/test_intermediates.py new file mode 100644 index 0000000..bd3902d --- /dev/null +++ b/tests/debugging/test_intermediates.py @@ -0,0 +1,279 @@ +# Copyright 2026 Apple Inc. +# +# Use of this source code is governed by a BSD-3-clause license that can +# be found in the LICENSE file or at https://opensource.org/licenses/BSD-3-Clause + +"""Tests verifying intermediate values match between PyTorch FX and CoreAI for all models.""" + +import logging +import sys +import tempfile +from pathlib import Path +from typing import Any + +import numpy as np +import pytest +import torch +from coreai.authoring import AIProgram +from coreai.runtime import AIModel +from numpy.typing import NDArray + +from coreai_torch.converter import TorchConverter, _DebugInfoRecorder +from coreai_torch.debugging.debug_info import ( + OutputMapping, + _build_coreai_op_map, +) +from coreai_torch.debugging.inspector import ( + CoreAIInspector, + Inspector, + TorchFXInspector, +) +from coreai_torch.debugging.torch_utils import ( + get_torch_to_coreai_output_mapping, +) + +from .test_model import ( + EXAMPLE_INPUTS, + LayerNormBlock, + SDPAAttentionBlock, + TinyTransformerBlock, + get_example_inputs, +) + +logger = logging.getLogger(__name__) + +# Models excluded from intermediate comparison. +EXCLUDED_MODEL_CLASSES = {LayerNormBlock, TinyTransformerBlock, SDPAAttentionBlock} + +ALL_MODEL_CLASSES = [ + cls for cls in EXAMPLE_INPUTS.keys() if cls not in EXCLUDED_MODEL_CLASSES +] + + +def _export_and_convert( + model_cls: type[torch.nn.Module], +) -> tuple[torch.export.ExportedProgram, AIProgram]: + """Export a model, decompose, and convert to AIProgram with debug info. + + Returns: + Tuple of (decomposed ExportedProgram, AIProgram). + """ + model = model_cls().eval() + example_inputs = get_example_inputs(model_cls) + exported_program = torch.export.export(model, args=tuple(example_inputs.values())) + exported_program = exported_program.run_decompositions() + + converter = TorchConverter() + converter._debug_info_recorder.config = _DebugInfoRecorder.Config( + include_stack_trace=True, + verify_debuginfo_locations=True, + ) + converter.add_exported_program(exported_program, entrypoint_name="main") + coreai_program = converter.to_coreai() + + return exported_program, coreai_program + + +async def _capture_torch_intermediates( + exported_program: torch.export.ExportedProgram, + torch_args: tuple[torch.Tensor, ...], +) -> dict[Inspector.OpID, list[NDArray[Any] | None] | None]: + """Run model through TorchFXInspector and capture all call_function intermediates.""" + inspector = TorchFXInspector(exported_program) + node_names = [ + node.name for node in exported_program.graph.nodes if node.op == "call_function" + ] + return await inspector.get_intermediates_for_ops(node_names, torch_args) + + +async def _capture_coreai_intermediates( + coreai_program: Any, + coreai_op_ids: set[int], + numpy_inputs: dict[str, NDArray[Any]], +) -> dict[Inspector.OpID, list[NDArray[Any] | None] | None]: + """Deploy AIProgram and capture intermediates via CoreAIInspector.""" + with tempfile.TemporaryDirectory() as tmpdir: + asset_path = Path(tmpdir) / "model.aimodel" + asset = coreai_program.save_asset(asset_path) + ai_model = await AIModel.load( + asset.path, + ) + inspector = CoreAIInspector(model=ai_model, function_name="main") + return await inspector.get_intermediates_for_ops( + list(coreai_op_ids), numpy_inputs + ) + + +def _compare_mapped_intermediates( + torch_intermediates: dict[Inspector.OpID, list[NDArray[Any] | None] | None], + coreai_intermediates: dict[Inspector.OpID, list[NDArray[Any] | None] | None], + mappings: dict[str, OutputMapping], + model_name: str, + coreai_op_map: dict[int, Any] | None = None, +) -> int: + """Compare torch and coreai intermediates for each mapped operation. + + Logs a warning for every torch op whose corresponding coreai output + cannot be found or has a shape mismatch. + + Returns: + Number of successfully compared operation outputs. + """ + compared = 0 + for torch_node_name, mapping in mappings.items(): + coreai_op = coreai_op_map.get(mapping.target_op_id) if coreai_op_map else None + + torch_values = torch_intermediates.get(torch_node_name) + if torch_values is None: + logger.warning( + "%s: no torch intermediate for '%s'", model_name, torch_node_name + ) + continue + + if len(torch_values) > 1: + logger.warning( + "%s: skipping torch op '%s' with multiple outputs (n=%d)", + model_name, + torch_node_name, + len(torch_values), + ) + continue + + coreai_values = coreai_intermediates.get(mapping.target_op_id) + if coreai_values is None: + logger.warning( + "%s: no coreai intermediate for torch op '%s' (expected coreai op %d)\n %s", + model_name, + torch_node_name, + mapping.target_op_id, + coreai_op, + ) + continue + + if mapping.source_output >= len(torch_values): + logger.warning( + "%s: torch op '%s' source_output %d out of range (len=%d)", + model_name, + torch_node_name, + mapping.source_output, + len(torch_values), + ) + continue + + torch_output = torch_values[mapping.source_output] + + if mapping.target_output >= len(coreai_values): + logger.warning( + "%s: coreai op %d target_output %d out of range (len=%d) " + "for torch op '%s'\n %s", + model_name, + mapping.target_op_id, + mapping.target_output, + len(coreai_values), + torch_node_name, + coreai_op, + ) + continue + + coreai_output = coreai_values[mapping.target_output] + + if torch_output is None or coreai_output is None: + logger.warning( + "%s: None output for torch op '%s' → coreai op %d " + "(torch=%s, coreai=%s)\n %s", + model_name, + torch_node_name, + mapping.target_op_id, + torch_output is not None, + coreai_output is not None, + coreai_op, + ) + continue + + # Squeeze and compare if the squeezed shapes match. + if torch_output.shape != coreai_output.shape: + squeezed_torch = np.squeeze(torch_output) + squeezed_coreai = np.squeeze(coreai_output) + if squeezed_torch.shape == squeezed_coreai.shape: + torch_output = squeezed_torch + coreai_output = squeezed_coreai + else: + logger.warning( + "%s: shape mismatch for torch op '%s' → coreai op %d: " + "torch %s vs coreai %s — skipping comparison\n %s", + model_name, + torch_node_name, + mapping.target_op_id, + torch_output.shape, + coreai_output.shape, + coreai_op, + ) + continue + + abs_diff = np.abs(torch_output - coreai_output) + assert np.allclose(torch_output, coreai_output, rtol=1e-3, atol=1e-3), ( + "%s: intermediate mismatch for torch op '%s' → coreai op %d " + "(max abs diff=%g, mean abs diff=%g)\n coreai op: %s" + % ( + model_name, + torch_node_name, + mapping.target_op_id, + abs_diff.max(), + abs_diff.mean(), + coreai_op, + ) + ) + + compared += 1 + logger.info( + "%s: ✓ torch op '%s' [%d] matches coreai op %d [%d] (shape=%s)\n %s", + model_name, + torch_node_name, + mapping.source_output, + mapping.target_op_id, + mapping.target_output, + torch_output.shape, + coreai_op, + ) + + return compared + + +@pytest.mark.skipif(sys.platform != "darwin", reason="Test only runs on macOS") +@pytest.mark.parametrize("model_cls", ALL_MODEL_CLASSES, ids=lambda cls: cls.__name__) +async def test_intermediates_torch_vs_coreai( + model_cls: type[torch.nn.Module], +) -> None: + """ + Verify intermediate values from PyTorch FX match CoreAI for each mapped op. + """ + example_inputs = get_example_inputs(model_cls) + torch_args = tuple(example_inputs.values()) + numpy_inputs = {k: v.numpy() for k, v in example_inputs.items()} + + exported_program, coreai_program = _export_and_convert(model_cls) + + torch_intermediates = await _capture_torch_intermediates( + exported_program, torch_args + ) + + mappings = get_torch_to_coreai_output_mapping(coreai_program) + coreai_op_map = _build_coreai_op_map(coreai_program) + coreai_op_ids = {m.target_op_id for m in mappings.values()} + + coreai_intermediates = await _capture_coreai_intermediates( + coreai_program, coreai_op_ids, numpy_inputs + ) + + compared_count = _compare_mapped_intermediates( + torch_intermediates, + coreai_intermediates, + mappings, + model_cls.__name__, + coreai_op_map=coreai_op_map, + ) + + assert compared_count > 0, ( + f"No intermediates were compared for {model_cls.__name__}. " + f"Found {len(mappings)} mappings but none had matching intermediates." + ) diff --git a/tests/debugging/test_location_bindings.py b/tests/debugging/test_location_bindings.py index 1aa27d9..464e7c4 100644 --- a/tests/debugging/test_location_bindings.py +++ b/tests/debugging/test_location_bindings.py @@ -28,7 +28,6 @@ async def simple_coreai_program() -> AIProgram: converter: TorchConverter = TorchConverter() converter._debug_info_recorder.config = _DebugInfoRecorder.Config( include_stack_trace=True, - options=_DebugInfoRecorder.Options.DEBUGINFO, verify_debuginfo_locations=True, ) converter.add_exported_program(exported_program, entrypoint_name="main") @@ -46,7 +45,6 @@ async def complex_coreai_program() -> AIProgram: converter: TorchConverter = TorchConverter() converter._debug_info_recorder.config = _DebugInfoRecorder.Config( include_stack_trace=True, - options=_DebugInfoRecorder.Options.DEBUGINFO, verify_debuginfo_locations=True, ) converter.add_exported_program(exported_program, entrypoint_name="main") diff --git a/tests/debugging/test_model.py b/tests/debugging/test_model.py index 7661277..c47a09e 100644 --- a/tests/debugging/test_model.py +++ b/tests/debugging/test_model.py @@ -57,14 +57,6 @@ def forward(self, x, y): return torch.clamp(z, -3.0, 3.0) -class BroadcastReduceBlock(torch.nn.Module): - def forward(self, x, bias): - y = x + bias - m = y.mean(dim=-1, keepdim=True) - v = y.var(dim=-1, correction=0, keepdim=True) - return (y - m) / torch.sqrt(v + 1e-5) - - class LayerNormBlock(torch.nn.Module): def __init__(self, dim=8): super().__init__() @@ -463,10 +455,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x=torch.randn(1, 2, 8), y=torch.randn(1, 2, 8), ), - BroadcastReduceBlock: lambda: OrderedDict( - x=torch.randn(1, 2, 8), - bias=torch.randn(8), - ), LayerNormBlock: lambda: OrderedDict( x=torch.randn(1, 2, 8), ), @@ -489,10 +477,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x=torch.randn(1, 2, 8), ), EmbeddingMLPBlock: lambda: OrderedDict( - tokens=torch.randint(0, 32, (1, 3)), + tokens=torch.randint(0, 32, (1, 3), dtype=torch.int32), ), TinyTransformerBlock: lambda: OrderedDict( - tokens=torch.randint(0, 32, (1, 3)), + tokens=torch.randint(0, 32, (1, 3), dtype=torch.int32), ), LinearMulAddModel: lambda: OrderedDict( x=torch.randn(2, 4), diff --git a/tests/debugging/test_torch_utils.py b/tests/debugging/test_torch_utils.py index 650edca..a89cfb9 100644 --- a/tests/debugging/test_torch_utils.py +++ b/tests/debugging/test_torch_utils.py @@ -277,7 +277,6 @@ def test_save_intermediates_with_coreai_program_program( converter: TorchConverter = TorchConverter() converter._debug_info_recorder.config = _DebugInfoRecorder.Config( include_stack_trace=True, - options=_DebugInfoRecorder.Options.DEBUGINFO, verify_debuginfo_locations=True, ) converter.add_exported_program(exported_program, entrypoint_name="main") diff --git a/tests/debugging/test_validator.py b/tests/debugging/test_validator.py index c8e240e..60ab95c 100644 --- a/tests/debugging/test_validator.py +++ b/tests/debugging/test_validator.py @@ -399,7 +399,6 @@ async def _create_coreai_program_from_model( converter: TorchConverter = TorchConverter() converter._debug_info_recorder.config = _DebugInfoRecorder.Config( include_stack_trace=True, - options=_DebugInfoRecorder.Options.DEBUGINFO, verify_debuginfo_locations=True, ) converter.add_exported_program(exported_program, entrypoint_name="main") @@ -416,11 +415,9 @@ async def _create_coreai_program_from_model( [ pytest.param( "fc1", - marks=pytest.mark.xfail(reason="Fails after coreai update", strict=False), ), pytest.param( "fc2", - marks=pytest.mark.xfail(reason="Fails after coreai update", strict=False), ), None, ], diff --git a/tests/dsl/test_dtype_specialization.py b/tests/dsl/test_dtype_specialization.py new file mode 100644 index 0000000..8ea4984 --- /dev/null +++ b/tests/dsl/test_dtype_specialization.py @@ -0,0 +1,348 @@ +# Copyright 2026 Apple Inc. +# +# Use of this source code is governed by a BSD-3-clause license that can +# be found in the LICENSE file or at https://opensource.org/licenses/BSD-3-Clause + +"""Tests for ``template_dtypes`` / type specialization in custom Metal kernels. + +A ``CustomMetalKernel`` accepts a ``template_dtypes`` dict that maps +input-name → placeholder-string. At ``_construct_kernel_op`` time the +placeholder is substituted with the *actual* metal dtype string of that +input (``half``, ``float``, ``bfloat``, …). This produces a per-shape kernel +*variant* — the same Python kernel emitted twice with different dtypes +generates two distinct PSOs because the templated MSL source differs. + +These tests pin: + +* Single-template substitution lowers to the right metal type. +* Multiple template params on different inputs each get substituted + independently. +* The same Python kernel called twice in one model with different input + dtypes emits two distinct kernel sources (and randomized names). +* A template parameter on an input that is *also* fed into another op as + a regular data input continues to lower without issue. +""" + +from __future__ import annotations + +import sys +from typing import Any + +import pytest +import torch + +from coreai_torch import ( + MetalParameter, + TorchConverter, + TorchMetalKernel, + get_decomp_table, +) + + +def _convert_model( + model: torch.nn.Module, + args: tuple, + kernels: list[TorchMetalKernel], + output_names: list[str] | None = None, +) -> Any: + exported = torch.export.export(model, args=args) + ep = exported.run_decompositions(get_decomp_table()) + converter = TorchConverter() + converter.register_custom_kernels(kernels) + converter.add_exported_program(ep, output_names=output_names or []) + return converter.to_coreai() + + +def _kernel_source_strings(ir: str) -> list[str]: + """Extract the ``kernel_source = "..."`` string-attr value(s) from IR. + + Each ``coreai.metal4_kernel`` op carries the templated MSL source as a + string attribute. The randomization suffix on ``kernel_name`` makes + diffing kernels by name brittle, so tests inspect the source body. + """ + needle = 'kernel_source = "' + out = [] + pos = 0 + while True: + i = ir.find(needle, pos) + if i < 0: + return out + i += len(needle) + # Find the closing unescaped quote. + end = i + while end < len(ir): + if ir[end] == "\\": + end += 2 + continue + if ir[end] == '"': + break + end += 1 + out.append(ir[i:end]) + pos = end + 1 + + +# --------------------------------------------------------------------------- + + +class TestSingleTemplateSubstitution: + """A single ``template_dtypes`` entry substitutes into the MSL source.""" + + @staticmethod + @pytest.mark.parametrize( + ("dtype", "expected_metal_type"), + [ + (torch.float16, "half"), + (torch.float32, "float"), + (torch.bfloat16, "bfloat"), + ], + ) + def test_template_substitution_picks_metal_type( + dtype: torch.dtype, + expected_metal_type: str, + ) -> None: + """``TYPE`` in the body is replaced with the metal type of input ``x``.""" + + def torch_defn(x: torch.Tensor) -> torch.Tensor: + return x + 1 + + kernel = TorchMetalKernel( + "single_template", + input_names=["x"], + result_names=["out"], + src="out[id] = x[id] + TYPE(1.0);", + torch_defn=torch_defn, + metal_params=[MetalParameter("id", "uint", "thread_position_in_grid")], + template_dtypes={"x": "TYPE"}, + ) + + class Model(torch.nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + return kernel( + x, + threads_per_grid=(x.shape[0], 1, 1), + threads_per_thread_group=(1, 1, 1), + result_shapes=[list(x.shape)], + ) + + coreai_program = _convert_model( + Model().eval(), + args=(torch.zeros(4, dtype=dtype),), + kernels=[kernel], + output_names=["out"], + ) + + sources = _kernel_source_strings(str(coreai_program)) + assert len(sources) == 1, ( + f"Expected exactly one metal4_kernel source, got {len(sources)}" + ) + # Placeholder replaced. + assert "TYPE" not in sources[0] + assert f"{expected_metal_type}(1.0)" in sources[0] + + +class TestMultipleTemplateParams: + """Multiple ``template_dtypes`` entries substitute independently.""" + + @staticmethod + def test_two_templates_substitute_independently() -> None: + """Inputs ``x`` and ``y`` map to ``T_X`` / ``T_Y`` with distinct dtypes.""" + + def torch_defn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + return x.to(torch.float32) + y.to(torch.float32) + + kernel = TorchMetalKernel( + "two_templates", + input_names=["x", "y"], + result_names=["out"], + src=("T_X xv = x[id]; T_Y yv = y[id]; out[id] = float(xv) + float(yv);"), + torch_defn=torch_defn, + metal_params=[MetalParameter("id", "uint", "thread_position_in_grid")], + template_dtypes={"x": "T_X", "y": "T_Y"}, + ) + + class Model(torch.nn.Module): + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + return kernel( + x, + y, + threads_per_grid=(x.shape[0], 1, 1), + threads_per_thread_group=(1, 1, 1), + result_shapes=[list(x.shape)], + ) + + coreai_program = _convert_model( + Model().eval(), + args=( + torch.zeros(4, dtype=torch.float16), # x → half + torch.zeros(4, dtype=torch.float32), # y → float + ), + kernels=[kernel], + output_names=["out"], + ) + + sources = _kernel_source_strings(str(coreai_program)) + assert len(sources) == 1 + src = sources[0] + assert "T_X" not in src and "T_Y" not in src + # x's template substituted to `half`, y's to `float`. + assert "half xv" in src + assert "float yv" in src + + +class TestSameKernelTwoDtypeCombinations: + """Same Python kernel used with two dtype combinations → two distinct PSOs.""" + + @staticmethod + def test_two_invocations_with_different_dtypes_emit_two_sources() -> None: + """Each ``(rank, metal_dtype)`` combo bypasses the kernel cache.""" + + def torch_defn(x: torch.Tensor) -> torch.Tensor: + return x + + kernel = TorchMetalKernel( + "two_dtype_combos", + input_names=["x"], + result_names=["out"], + src="out[id] = x[id];", + torch_defn=torch_defn, + metal_params=[MetalParameter("id", "uint", "thread_position_in_grid")], + template_dtypes={"x": "TYPE"}, + ) + + class Model(torch.nn.Module): + def forward( + self, a: torch.Tensor, b: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + # `a` is float16 → kernel templated with `half`. + ra = kernel( + a, + threads_per_grid=(a.shape[0], 1, 1), + threads_per_thread_group=(1, 1, 1), + result_shapes=[list(a.shape)], + ) + # `b` is float32 → kernel templated with `float`. + rb = kernel( + b, + threads_per_grid=(b.shape[0], 1, 1), + threads_per_thread_group=(1, 1, 1), + result_shapes=[list(b.shape)], + ) + return ra, rb + + coreai_program = _convert_model( + Model().eval(), + args=( + torch.zeros(4, dtype=torch.float16), + torch.zeros(4, dtype=torch.float32), + ), + kernels=[kernel], + output_names=["ra", "rb"], + ) + + ir = str(coreai_program) + sources = _kernel_source_strings(ir) + assert len(sources) == 2, ( + f"Expected two distinct metal4_kernel sources (one per dtype), " + f"got {len(sources)}" + ) + # The two sources differ — one mentions `half`, the other `float`. + type_a, type_b = sources + assert type_a != type_b + assert "device half" in (type_a + type_b) + assert "device float" in (type_a + type_b) + # Two distinct randomized names — the kernel cache key includes dtype. + assert ir.count("coreai.metal4_kernel") == 2 + + +class TestTemplateOnPassthroughInput: + """A template-bound input that is also forwarded to other ops.""" + + @staticmethod + def test_template_input_is_also_a_data_input() -> None: + """Same tensor flows into the kernel and into a stock op — both work.""" + + def torch_defn(x: torch.Tensor) -> torch.Tensor: + return x + 1.0 + + kernel = TorchMetalKernel( + "template_passthrough", + input_names=["x"], + result_names=["out"], + src="out[id] = x[id] + TYPE(1.0);", + torch_defn=torch_defn, + metal_params=[MetalParameter("id", "uint", "thread_position_in_grid")], + template_dtypes={"x": "TYPE"}, + ) + + class Model(torch.nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + kernel_out = kernel( + x, + threads_per_grid=(x.shape[0], 1, 1), + threads_per_thread_group=(1, 1, 1), + result_shapes=[list(x.shape)], + ) + # x is also consumed by a stock aten op in the same graph. + return kernel_out + torch.relu(x) + + coreai_program = _convert_model( + Model().eval(), + args=(torch.zeros(4, dtype=torch.float16),), + kernels=[kernel], + output_names=["out"], + ) + ir = str(coreai_program) + # Kernel is emitted exactly once (same shape, same dtype → cache hit). + assert ir.count("coreai.metal4_kernel") == 1 + # The MSL source has had `TYPE` replaced with `half` (input dtype). + sources = _kernel_source_strings(ir) + assert "TYPE" not in sources[0] + assert "half(1.0)" in sources[0] + + +# --------------------------------------------------------------------------- +# Numerical: end-to-end on macOS +# --------------------------------------------------------------------------- + + +@pytest.mark.skipif(sys.platform != "darwin", reason="Metal tests run only on Mac") +class TestTemplateNumerical: + """Same kernel emitted with two dtypes produces correct results in both.""" + + @staticmethod + @pytest.mark.parametrize("dtype", [torch.float16, torch.float32]) + async def test_template_specialized_kernel_numerical( + dtype: torch.dtype, + ) -> None: + from ..utils import validate_numerical_output + + def torch_defn(x: torch.Tensor) -> torch.Tensor: + return x * 2.0 + + kernel = TorchMetalKernel( + f"template_num_{str(dtype).split('.')[-1]}", + input_names=["x"], + result_names=["out"], + src="out[id] = x[id] * TYPE(2.0);", + torch_defn=torch_defn, + metal_params=[MetalParameter("id", "uint", "thread_position_in_grid")], + template_dtypes={"x": "TYPE"}, + ) + + class Model(torch.nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + return kernel( + x, + threads_per_grid=(x.shape[0], 1, 1), + threads_per_thread_group=(1, 1, 1), + result_shapes=[list(x.shape)], + ) + + await validate_numerical_output( + model=Model().eval(), + custom_kernels=[kernel], + metal_inputs=True, + input_names=["x"], + output_names=["result"], + x=torch.tensor([1.0, 2.0, 3.0, 4.0], dtype=dtype), + ) diff --git a/tests/dsl/test_kernel_collisions.py b/tests/dsl/test_kernel_collisions.py new file mode 100644 index 0000000..7e092df --- /dev/null +++ b/tests/dsl/test_kernel_collisions.py @@ -0,0 +1,269 @@ +# Copyright 2026 Apple Inc. +# +# Use of this source code is governed by a BSD-3-clause license that can +# be found in the LICENSE file or at https://opensource.org/licenses/BSD-3-Clause + +"""Tests for kernel-name collision behavior. + +Two ``CustomMetalKernel`` instances each pick an 8-character random suffix per +``(rank, dtype)`` combination at lowering time, so the *emitted* function +names cannot collide across distinct kernels — even when the user-visible +``name`` field is identical. The MPS runtime's ``MPSRuntime.mm`` dedupe +(by ``function_name``) is therefore not directly reachable from +coreai-torch unless the same kernel is reused. + +These tests pin: + +* Two kernels with the same ``name`` are rejected at + ``register_custom_kernels`` time, with a clear "already registered" error. +* A kernel that is registered, used, then re-registered (or registered twice + in the same call) fails the same way. +* A single kernel reused with the same ``(rank, dtype)`` reuses the cached + randomized name (the cache short-circuits MSL re-templating). +* A single kernel reused with *different* ``(rank, dtype)`` produces two + distinct randomized names — exercising the cache-miss path. +""" + +from __future__ import annotations + +from typing import Any + +import pytest +import torch + +from coreai_torch import ( + MetalParameter, + TorchConverter, + TorchMetalKernel, + get_decomp_table, +) + + +def _convert_model( + model: torch.nn.Module, + args: tuple, + kernels: list[TorchMetalKernel], + output_names: list[str] | None = None, +) -> Any: + exported = torch.export.export(model, args=args) + ep = exported.run_decompositions(get_decomp_table()) + converter = TorchConverter() + converter.register_custom_kernels(kernels) + converter.add_exported_program(ep, output_names=output_names or []) + return converter.to_coreai() + + +def _identity_kernel(name: str, *, src: str = "out[id] = x[id];") -> TorchMetalKernel: + def torch_defn(x: torch.Tensor) -> torch.Tensor: + return x + + return TorchMetalKernel( + name, + input_names=["x"], + result_names=["out"], + src=src, + torch_defn=torch_defn, + metal_params=[MetalParameter("id", "uint", "thread_position_in_grid")], + ) + + +# --------------------------------------------------------------------------- +# Collision at register time +# --------------------------------------------------------------------------- + + +class TestRegisterTimeCollision: + """Two distinct kernel objects with the same ``name`` field.""" + + @staticmethod + def test_same_name_identical_source_rejected_at_register() -> None: + """Even with the *same* MSL source, the second registration fails fast. + + A coreai-torch ``register_custom_kernels`` call cannot tell that two + kernels are equivalent; its only option is to fail and let the user + register a single instance. + """ + kernel_a = _identity_kernel("name_collision_same_src") + kernel_b = _identity_kernel("name_collision_same_src") + + converter = TorchConverter() + with pytest.raises(ValueError, match="already registered"): + converter.register_custom_kernels([kernel_a, kernel_b]) + + @staticmethod + def test_same_name_different_source_rejected_at_register() -> None: + """Distinct MSL bodies under the same ``name`` would silently shadow. + + The converter must not allow this — the second register call should + raise. + """ + kernel_a = _identity_kernel("name_collision_diff_src", src="out[id] = x[id];") + kernel_b = _identity_kernel( + "name_collision_diff_src", src="out[id] = x[id] * x[id];" + ) + + converter = TorchConverter() + with pytest.raises(ValueError, match="already registered"): + converter.register_custom_kernels([kernel_a, kernel_b]) + + @staticmethod + def test_same_name_split_across_two_register_calls_rejected() -> None: + """Splitting the two registrations across calls still collides.""" + kernel_a = _identity_kernel("name_collision_two_calls") + kernel_b = _identity_kernel("name_collision_two_calls") + + converter = TorchConverter() + converter.register_custom_kernels([kernel_a]) + with pytest.raises(ValueError, match="already registered"): + converter.register_custom_kernels([kernel_b]) + + @staticmethod + def test_distinct_names_register_cleanly() -> None: + """The collision check is keyed on ``name``, not on object identity.""" + kernel_a = _identity_kernel("name_distinct_a") + kernel_b = _identity_kernel("name_distinct_b") + + converter = TorchConverter() + converter.register_custom_kernels([kernel_a, kernel_b]) + # No error. + + +# --------------------------------------------------------------------------- +# Per-instance kernel cache: same-instance reuse +# --------------------------------------------------------------------------- + + +class TestPerInstanceCaching: + """The ``kernel_cache`` keys randomized names by ``(rank, dtype)``.""" + + @staticmethod + def test_same_kernel_same_shape_dtype_reuses_cached_name() -> None: + """Calling the same kernel twice with the same shape reuses the source.""" + kernel = _identity_kernel("cache_hit") + + class Model(torch.nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + a = kernel( + x, + threads_per_grid=(x.shape[0], 1, 1), + threads_per_thread_group=(1, 1, 1), + result_shapes=[list(x.shape)], + ) + b = kernel( + a, + threads_per_grid=(x.shape[0], 1, 1), + threads_per_thread_group=(1, 1, 1), + result_shapes=[list(x.shape)], + ) + return b + + coreai_program = _convert_model( + Model().eval(), + args=(torch.zeros(4, dtype=torch.float16),), + kernels=[kernel], + output_names=["out"], + ) + ir = str(coreai_program) + assert ir.count("coreai.metal4_kernel") == 2 + # The kernel cache is keyed on (rank, metal_dtype). Both calls have + # rank-1 + half tensors → same cached randomized name. + cached_entries = list(kernel.kernel_cache.values()) + assert len(cached_entries) == 1, ( + f"Expected exactly one cached randomized name, got {len(cached_entries)}" + ) + + @staticmethod + def test_same_kernel_different_dtypes_emits_two_cache_entries() -> None: + """Distinct ``(rank, dtype)`` combinations produce distinct PSO sources.""" + + def torch_defn(x: torch.Tensor) -> torch.Tensor: + return x + + kernel = TorchMetalKernel( + "cache_miss", + input_names=["x"], + result_names=["out"], + src="out[id] = x[id];", + torch_defn=torch_defn, + metal_params=[MetalParameter("id", "uint", "thread_position_in_grid")], + template_dtypes={"x": "TYPE"}, + ) + + class Model(torch.nn.Module): + def forward( + self, a: torch.Tensor, b: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + ra = kernel( + a, + threads_per_grid=(a.shape[0], 1, 1), + threads_per_thread_group=(1, 1, 1), + result_shapes=[list(a.shape)], + ) + rb = kernel( + b, + threads_per_grid=(b.shape[0], 1, 1), + threads_per_thread_group=(1, 1, 1), + result_shapes=[list(b.shape)], + ) + return ra, rb + + _convert_model( + Model().eval(), + args=( + torch.zeros(4, dtype=torch.float16), + torch.zeros(4, dtype=torch.float32), + ), + kernels=[kernel], + output_names=["ra", "rb"], + ) + # Two distinct (rank, dtype) keys → two cache entries. + cached_entries = list(kernel.kernel_cache.values()) + assert len(cached_entries) == 2 + # Each entry has a unique randomized name. + names = {entry[0] for entry in cached_entries} + assert len(names) == 2 + + @staticmethod + def test_two_instances_same_name_get_distinct_randomized_names() -> None: + """Two kernel objects with the same ``name`` randomize independently. + + Even though ``register_custom_kernels`` rejects this, the underlying + randomized-name machinery must produce distinct suffixes per instance + so that any future cross-converter use cannot silently collide. + """ + kernel_a = _identity_kernel("instance_a") + kernel_b = _identity_kernel("instance_a") + + # Drive each through its own converter to bypass the register-time + # collision check; the kernel_cache is per-instance. + for k in (kernel_a, kernel_b): + + class Model(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self._k = k + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self._k( + x, + threads_per_grid=(x.shape[0], 1, 1), + threads_per_thread_group=(1, 1, 1), + result_shapes=[list(x.shape)], + ) + + converter = TorchConverter() + converter.register_custom_kernels([k]) + ep = torch.export.export( + Model().eval(), args=(torch.zeros(4, dtype=torch.float16),) + ).run_decompositions(get_decomp_table()) + converter.add_exported_program(ep, output_names=["out"]) + converter.to_coreai() + + # Each instance has its own cache; the randomized names cannot overlap. + names_a = {entry[0] for entry in kernel_a.kernel_cache.values()} + names_b = {entry[0] for entry in kernel_b.kernel_cache.values()} + assert names_a and names_b + assert names_a.isdisjoint(names_b), ( + f"Per-instance randomized names must be disjoint; got " + f"{names_a} and {names_b}" + ) diff --git a/tests/dsl/test_scalar_inputs.py b/tests/dsl/test_scalar_inputs.py new file mode 100644 index 0000000..20dc4b9 --- /dev/null +++ b/tests/dsl/test_scalar_inputs.py @@ -0,0 +1,447 @@ +# Copyright 2026 Apple Inc. +# +# Use of this source code is governed by a BSD-3-clause license that can +# be found in the LICENSE file or at https://opensource.org/licenses/BSD-3-Clause + +"""Tests for scalar inputs to custom Metal kernels (``n_scalar_inputs`` path). + +A scalar passed to a ``TorchMetalKernel`` (``int``, ``float``, ``bool``) is +captured by ``get_operand`` as a ``coreai.constant`` rank-0 tensor and bound +to the kernel as a ``constant T& name [[buffer(N)]]`` parameter — a different +runtime path from the regular tensor (``MTLTensor``) bindings. + +Both data inputs and scalar inputs share the 31-buffer limit imposed by +``CustomMetalKernel.PARAMETER_LIMIT``. The validation tests here pin that +contract at the converter layer so the failure surfaces with a clear +``ValueError`` rather than as a runtime crash inside MPS. +""" + +from __future__ import annotations + +import re +import sys +from typing import Any + +import pytest +import torch + +from coreai_torch import ( + MetalParameter, + TorchConverter, + TorchMetalKernel, + get_decomp_table, +) + + +def _convert_model( + model: torch.nn.Module, + args: tuple, + kernels: list[TorchMetalKernel], + output_names: list[str] | None = None, +) -> Any: + """Export, register, and convert a model with custom kernels.""" + exported = torch.export.export(model, args=args) + ep = exported.run_decompositions(get_decomp_table()) + converter = TorchConverter() + converter.register_custom_kernels(kernels) + converter.add_exported_program(ep, output_names=output_names or []) + return converter.to_coreai() + + +# --------------------------------------------------------------------------- +# IR-level: scalar inputs lower through the converter +# --------------------------------------------------------------------------- + + +class TestScalarInputLowering: + """A scalar input is captured as a rank-0 constant and bound as buffer.""" + + @staticmethod + @pytest.mark.parametrize( + ("annotation", "scalar_value", "metal_dtype"), + [ + (float, 2.5, "float"), + (int, 7, "int"), + (bool, True, "bool"), + ], + ) + def test_single_scalar_input_lowers( + annotation: type, + scalar_value: Any, + metal_dtype: str, + ) -> None: + """One tensor input + one scalar input — verify IR signature has a rank-0 operand.""" + + if annotation is float: + + def torch_defn(x: torch.Tensor, c: float) -> torch.Tensor: # type: ignore[misc] + return x + c + elif annotation is int: + + def torch_defn(x: torch.Tensor, c: int) -> torch.Tensor: # type: ignore[misc, no-redef] + return x + c + else: + + def torch_defn(x: torch.Tensor, c: bool) -> torch.Tensor: # type: ignore[misc, no-redef] + return x + int(c) + + kernel = TorchMetalKernel( + f"scalar_{metal_dtype}_kernel", + input_names=["x", "c"], + result_names=["out"], + src="out[id] = x[id];", + torch_defn=torch_defn, + metal_params=[MetalParameter("id", "uint", "thread_position_in_grid")], + ) + + class Model(torch.nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + return kernel( + x, + scalar_value, + threads_per_grid=(x.shape[0], 1, 1), + threads_per_thread_group=(1, 1, 1), + result_shapes=[list(x.shape)], + ) + + coreai_program = _convert_model( + Model().eval(), + args=(torch.zeros(4, dtype=torch.float16),), + kernels=[kernel], + output_names=["out"], + ) + ir = str(coreai_program) + + # Kernel name appears in IR. + assert f"scalar_{metal_dtype}_kernel_" in ir + # MSL signature emits scalar as `constant T&` rather than `tensor<...>`. + # The kernel_source attribute embeds the Metal source directly. + assert f"constant {metal_dtype}& c " in ir, ( + f"Expected `constant {metal_dtype}& c` in the emitted MSL " + f"source but got: {ir!s}" + ) + + @staticmethod + def test_max_scalar_inputs_at_buffer_limit() -> None: + """Scalars-only kernel close to the 31-buffer limit lowers cleanly. + + 29 scalar inputs + 1 result = 30 buffers, well under the 31 cap. + """ + n_scalars = 29 + + # `def *args` is rejected by the constructor (variadic). Build a real + # signature with N scalar parameters via exec. + scalar_args = ", ".join(f"s{i}: float" for i in range(n_scalars)) + ns: dict[str, Any] = {"torch": torch} + exec( # noqa: S102 + f"def torch_defn({scalar_args}) -> torch.Tensor:\n" + " return torch.zeros(4, dtype=torch.float16)\n", + ns, + ) + torch_defn = ns["torch_defn"] + + kernel = TorchMetalKernel( + "max_scalar_inputs", + input_names=[f"s{i}" for i in range(n_scalars)], + result_names=["out"], + src="out[id] = 0.0;", + torch_defn=torch_defn, + metal_params=[MetalParameter("id", "uint", "thread_position_in_grid")], + ) + + scalar_values = [float(i) for i in range(n_scalars)] + + class Model(torch.nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + return ( + kernel( + *scalar_values, + threads_per_grid=(4, 1, 1), + threads_per_thread_group=(1, 1, 1), + result_shapes=[[4]], + ) + + x + ) + + coreai_program = _convert_model( + Model().eval(), + args=(torch.zeros(4, dtype=torch.float16),), + kernels=[kernel], + output_names=["out"], + ) + ir = str(coreai_program) + assert "max_scalar_inputs_" in ir + # All 29 scalars should appear as `constant float&` parameters. + assert ir.count("constant float&") == n_scalars + + @staticmethod + def test_scalars_plus_data_inputs_exceeding_limit_rejected() -> None: + """Total of (data inputs + scalar inputs + results) > 31 must error.""" + # 25 data inputs + 6 scalar inputs + 1 result = 32 > 31. + n_data = 25 + n_scalars = 6 + names_data = [f"t{i}" for i in range(n_data)] + names_scalars = [f"s{i}" for i in range(n_scalars)] + + params = ", ".join( + [f"{name}: torch.Tensor" for name in names_data] + + [f"{name}: float" for name in names_scalars] + ) + body = " + ".join(names_data) + ( + "" if not names_scalars else " + " + " + ".join(names_scalars) + ) + ns: dict[str, Any] = {"torch": torch} + exec( # noqa: S102 + f"def torch_defn({params}) -> torch.Tensor:\n return {body}\n", + ns, + ) + torch_defn = ns["torch_defn"] + + kernel = TorchMetalKernel( + "over_limit", + input_names=[*names_data, *names_scalars], + result_names=["out"], + src="out[id] = 0.0;", + torch_defn=torch_defn, + metal_params=[MetalParameter("id", "uint", "thread_position_in_grid")], + ) + + class Model(torch.nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + tensor_args = [x] * n_data + scalar_args = [float(i) for i in range(n_scalars)] + return kernel( + *tensor_args, + *scalar_args, + threads_per_grid=(4, 1, 1), + threads_per_thread_group=(1, 1, 1), + result_shapes=[list(x.shape)], + ) + + with pytest.raises(ValueError, match=r"metal kernels support 31 inputs"): + _convert_model( + Model().eval(), + args=(torch.zeros(4, dtype=torch.float16),), + kernels=[kernel], + ) + + +class TestScalarCaching: + """Scalar-value-aware kernel caching. + + Scalar-bearing kernels bake the literal into the MSL body, so the base + class's cache (keyed only on ``(rank, dtype)``) can't be shared blindly + across call sites. ``TorchMetalKernel`` keeps one sub-cache per distinct set + of scalar values: identical ``(scalar_values, rank, dtype)`` call sites reuse + a single templated kernel (one PSO, one randomized name), while differing + scalar values stay isolated. + """ + + @staticmethod + def _kernel_names(ir: str) -> list[str]: + """Every emitted ``metal4_kernel`` name, one per call site.""" + return re.findall(r'kernel_name = "([^"]+)"', ir) + + @staticmethod + def _scalar_add_kernel(name: str) -> TorchMetalKernel: + def torch_defn(x: torch.Tensor, c: float) -> torch.Tensor: + return x + c + + return TorchMetalKernel( + name, + input_names=["x", "c"], + result_names=["out"], + src="out[id] = x[id] + c;", + torch_defn=torch_defn, + metal_params=[MetalParameter("id", "uint", "thread_position_in_grid")], + ) + + @staticmethod + def _two_call_model( + kernel: TorchMetalKernel, + scalar_a: float, + scalar_b: float, + ) -> torch.nn.Module: + class Model(torch.nn.Module): + def forward( + self, a: torch.Tensor, b: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + ra = kernel( + a, + scalar_a, + threads_per_grid=(a.shape[0], 1, 1), + threads_per_thread_group=(1, 1, 1), + result_shapes=[list(a.shape)], + ) + rb = kernel( + b, + scalar_b, + threads_per_grid=(b.shape[0], 1, 1), + threads_per_thread_group=(1, 1, 1), + result_shapes=[list(b.shape)], + ) + return ra, rb + + return Model().eval() + + @staticmethod + def test_same_scalar_value_reuses_one_kernel() -> None: + """Two call sites with the same scalar value + shape share one PSO.""" + kernel = TestScalarCaching._scalar_add_kernel("cached_scalar") + coreai_program = _convert_model( + TestScalarCaching._two_call_model(kernel, 5.0, 5.0), + args=( + torch.zeros(4, dtype=torch.float16), + torch.ones(4, dtype=torch.float16), + ), + kernels=[kernel], + output_names=["ra", "rb"], + ) + names = TestScalarCaching._kernel_names(str(coreai_program)) + # One op per call site, but the shared scalar/shape collapses to one PSO. + assert len(names) == 2 + assert len(set(names)) == 1, ( + f"Expected both call sites to share one kernel name, got {names}" + ) + + @staticmethod + def test_different_scalar_values_emit_distinct_kernels() -> None: + """Two call sites with different scalar values stay isolated.""" + kernel = TestScalarCaching._scalar_add_kernel("scalar_per_value") + coreai_program = _convert_model( + TestScalarCaching._two_call_model(kernel, 5.0, 9.0), + args=( + torch.zeros(4, dtype=torch.float16), + torch.ones(4, dtype=torch.float16), + ), + kernels=[kernel], + output_names=["ra", "rb"], + ) + ir = str(coreai_program) + names = TestScalarCaching._kernel_names(ir) + assert len(names) == 2 + assert len(set(names)) == 2, ( + f"Expected two distinct kernels for differing scalar values, got {names}" + ) + # Each call site baked its own literal. + assert "c = 5.0" in ir + assert "c = 9.0" in ir + + +# --------------------------------------------------------------------------- +# Numerical: scalar input behavior end-to-end (macOS only) +# --------------------------------------------------------------------------- + + +@pytest.mark.skipif(sys.platform != "darwin", reason="Metal tests run only on Mac") +class TestScalarInputNumerical: + """End-to-end numerical correctness when scalars are baked into the graph.""" + + @staticmethod + async def test_float_scalar_added_elementwise() -> None: + """A float scalar passed alongside a tensor produces correct output.""" + from ..utils import validate_numerical_output + + def torch_defn(x: torch.Tensor, c: float) -> torch.Tensor: + return x + c + + kernel = TorchMetalKernel( + "scalar_add_float", + input_names=["x", "c"], + result_names=["out"], + src="out[id] = x[id] + (TYPE)c;", + torch_defn=torch_defn, + metal_params=[MetalParameter("id", "uint", "thread_position_in_grid")], + template_dtypes={"x": "TYPE"}, + ) + + class Model(torch.nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + return kernel( + x, + 3.5, + threads_per_grid=(x.shape[0], 1, 1), + threads_per_thread_group=(1, 1, 1), + result_shapes=[list(x.shape)], + ) + + await validate_numerical_output( + model=Model().eval(), + custom_kernels=[kernel], + metal_inputs=True, + input_names=["x"], + output_names=["result"], + x=torch.tensor([1.0, 2.0, 3.0, 4.0], dtype=torch.float32), + ) + + @staticmethod + async def test_int_scalar_used_as_index_offset() -> None: + """An int scalar is bound and read in the kernel body.""" + from ..utils import validate_numerical_output + + def torch_defn(x: torch.Tensor, n: int) -> torch.Tensor: + return x + float(n) + + kernel = TorchMetalKernel( + "scalar_add_int", + input_names=["x", "n"], + result_names=["out"], + src="out[id] = x[id] + float(n);", + torch_defn=torch_defn, + metal_params=[MetalParameter("id", "uint", "thread_position_in_grid")], + ) + + class Model(torch.nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + return kernel( + x, + 7, + threads_per_grid=(x.shape[0], 1, 1), + threads_per_thread_group=(1, 1, 1), + result_shapes=[list(x.shape)], + ) + + await validate_numerical_output( + model=Model().eval(), + custom_kernels=[kernel], + metal_inputs=True, + input_names=["x"], + output_names=["result"], + x=torch.tensor([0.0, 1.0, 2.0, 3.0], dtype=torch.float32), + ) + + @staticmethod + async def test_bool_scalar_branches_kernel_path() -> None: + """A bool scalar selects between two kernel branches at runtime.""" + from ..utils import validate_numerical_output + + def torch_defn(x: torch.Tensor, flag: bool) -> torch.Tensor: + return x * 2.0 if flag else x + + kernel = TorchMetalKernel( + "scalar_bool_branch", + input_names=["x", "flag"], + result_names=["out"], + src="out[id] = flag ? (x[id] * 2.0f) : x[id];", + torch_defn=torch_defn, + metal_params=[MetalParameter("id", "uint", "thread_position_in_grid")], + ) + + class Model(torch.nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + return kernel( + x, + True, + threads_per_grid=(x.shape[0], 1, 1), + threads_per_thread_group=(1, 1, 1), + result_shapes=[list(x.shape)], + ) + + await validate_numerical_output( + model=Model().eval(), + custom_kernels=[kernel], + metal_inputs=True, + input_names=["x"], + output_names=["result"], + x=torch.tensor([1.0, 2.0, 3.0, 4.0], dtype=torch.float32), + ) diff --git a/tests/dsl/test_thread_config.py b/tests/dsl/test_thread_config.py new file mode 100644 index 0000000..eb58751 --- /dev/null +++ b/tests/dsl/test_thread_config.py @@ -0,0 +1,344 @@ +# Copyright 2026 Apple Inc. +# +# Use of this source code is governed by a BSD-3-clause license that can +# be found in the LICENSE file or at https://opensource.org/licenses/BSD-3-Clause + +"""Thread-configuration edge cases for ``TorchMetalKernel`` dispatch. + +The runtime forwards ``threads_per_grid`` and ``threads_per_threadgroup`` +verbatim to Metal's ``dispatchThreads``. Threadgroup sizes exceeding +``maxTotalThreadsPerThreadgroup`` (1024 on current Apple Silicon) are +rejected by Metal at PSO time; valid sizes near or above the visible +grid are simply rounded up — the kernel is responsible for guarding +out-of-bounds reads/writes when ``threads_per_grid`` exceeds the +visible tensor extent. + +Note: ``MTLTensor`` extents are stored in *reverse* of the torch shape +(see ``NDArray+Metal.swift``: ``shapeSpan.reversed()``). For a torch +tensor of shape ``(D0, D1, D2)`` the kernel sees extents +``(D2, D1, D0)``; ``get_extent(0)`` is the innermost (fastest-varying) +torch dim. Multi-dim dispatch tuples must match this convention. + +Only IR-level checks here can be fully cross-platform; the numerical tests +below need a Metal-backed runtime to actually execute. They are gated on +macOS via the dsl conftest's collection hook. +""" + +from __future__ import annotations + +import sys +from typing import Any + +import pytest +import torch + +from coreai_torch import ( + MetalParameter, + TorchConverter, + TorchMetalKernel, + get_decomp_table, +) + + +def _convert_model( + model: torch.nn.Module, + args: tuple, + kernels: list[TorchMetalKernel], + output_names: list[str] | None = None, +) -> Any: + exported = torch.export.export(model, args=args) + ep = exported.run_decompositions(get_decomp_table()) + converter = TorchConverter() + converter.register_custom_kernels(kernels) + converter.add_exported_program(ep, output_names=output_names or []) + return converter.to_coreai() + + +def _make_identity_kernel(name: str) -> TorchMetalKernel: + """Identity kernel with explicit bounds-check on ``id``.""" + + def torch_defn(x: torch.Tensor) -> torch.Tensor: + return x.clone() + + return TorchMetalKernel( + name, + input_names=["x"], + result_names=["out"], + src=("if (id >= x.get_extent(0)) return; out[id] = x[id];"), + torch_defn=torch_defn, + metal_params=[MetalParameter("id", "uint", "thread_position_in_grid")], + ) + + +# --------------------------------------------------------------------------- +# IR-level: dispatch values land in the IR unchanged +# --------------------------------------------------------------------------- + + +class TestDispatchValueLowering: + """The 3-tuple values land in the IR; clamping is a runtime concern.""" + + @staticmethod + def test_unit_grid_dispatch_lowers() -> None: + """``threads_per_grid=(1,1,1)`` is valid and lowers cleanly. + + The kernel's bounds-check guards against the over-large tensor. + """ + kernel = _make_identity_kernel("thread_unit_grid") + + class Model(torch.nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + return kernel( + x, + threads_per_grid=(1, 1, 1), + threads_per_thread_group=(1, 1, 1), + result_shapes=[list(x.shape)], + ) + + coreai_program = _convert_model( + Model().eval(), + args=(torch.zeros(4, dtype=torch.float16),), + kernels=[kernel], + output_names=["out"], + ) + assert "thread_unit_grid_" in str(coreai_program) + + @staticmethod + def test_full_3d_dispatch_lowers() -> None: + """Non-trivial values for x, y and z dimensions all lower.""" + kernel = _make_identity_kernel("thread_full_3d") + + class Model(torch.nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + return kernel( + x, + threads_per_grid=(8, 4, 2), + threads_per_thread_group=(4, 2, 2), + result_shapes=[list(x.shape)], + ) + + # 64-element flat tensor. Kernel's bounds-check ensures correctness + # regardless of the 3D dispatch decomposition. + coreai_program = _convert_model( + Model().eval(), + args=(torch.zeros(64, dtype=torch.float16),), + kernels=[kernel], + output_names=["out"], + ) + assert "thread_full_3d_" in str(coreai_program) + + @staticmethod + def test_threadgroup_larger_than_typical_pso_max_lowers() -> None: + """An over-large ``threads_per_thread_group`` is accepted at the IR layer. + + The runtime clamps this to ``pso.maxTotalThreadsPerThreadgroup`` + (1024 on most Apple Silicon GPUs); the converter should not pre-validate + that — it's a hardware property only known at PSO compilation time. + """ + kernel = _make_identity_kernel("thread_clamping_test") + + class Model(torch.nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + return kernel( + x, + threads_per_grid=(2048, 1, 1), + # Way above any real GPU's max threadgroup size; runtime + # clamping must handle this gracefully. + threads_per_thread_group=(2048, 1, 1), + result_shapes=[list(x.shape)], + ) + + coreai_program = _convert_model( + Model().eval(), + args=(torch.zeros(2048, dtype=torch.float16),), + kernels=[kernel], + output_names=["out"], + ) + assert "thread_clamping_test_" in str(coreai_program) + + +# --------------------------------------------------------------------------- +# Numerical: behavior under unusual dispatch configurations +# --------------------------------------------------------------------------- + + +@pytest.mark.skipif(sys.platform != "darwin", reason="Metal tests run only on Mac") +class TestThreadDispatchNumerical: + """Behavior under boundary dispatch configurations.""" + + @staticmethod + async def test_threadgroup_larger_than_grid_does_not_corrupt_output() -> None: + """Threadgroup size larger than the grid still dispatches correctly. + + ``dispatchThreads`` rounds the grid up to the next multiple of the + threadgroup, so a 1024-wide threadgroup over a 64-element grid still + launches one full threadgroup. The kernel's bounds-check filters out + the over-dispatched threads. + """ + from ..utils import validate_numerical_output + + kernel = _make_identity_kernel("thread_clamp_numerical") + + class Model(torch.nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + return kernel( + x, + threads_per_grid=(x.shape[0], 1, 1), + # 1024 = maxTotalThreadsPerThreadgroup on current Apple + # Silicon — the largest threadgroup Metal will accept. + threads_per_thread_group=(1024, 1, 1), + result_shapes=[list(x.shape)], + ) + + await validate_numerical_output( + model=Model().eval(), + custom_kernels=[kernel], + metal_inputs=True, + input_names=["x"], + output_names=["result"], + x=torch.arange(64, dtype=torch.float32), + ) + + @staticmethod + async def test_full_3d_dispatch_numerical() -> None: + """A full 3D dispatch produces the same identity result as a 1D one.""" + from ..utils import validate_numerical_output + + # Index a 3D tensor by (gid.x, gid.y, gid.z). + def torch_defn(x: torch.Tensor) -> torch.Tensor: + return x.clone() + + kernel = TorchMetalKernel( + "thread_3d_identity", + input_names=["x"], + result_names=["out"], + src=( + "if (gid.x >= x.get_extent(0) || " + " gid.y >= x.get_extent(1) || " + " gid.z >= x.get_extent(2)) return; " + "out[gid.x, gid.y, gid.z] = x[gid.x, gid.y, gid.z];" + ), + torch_defn=torch_defn, + metal_params=[ + MetalParameter("gid", "uint3", "thread_position_in_grid"), + ], + ) + + class Model(torch.nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + return kernel( + x, + # MTLTensor extents are reversed from the torch shape, so + # ``get_extent(0)`` is the innermost torch dim. Dispatch + # in the same reversed order so each ``gid`` axis lines + # up with the matching ``get_extent``. + threads_per_grid=(x.shape[2], x.shape[1], x.shape[0]), + threads_per_thread_group=(2, 2, 2), + result_shapes=[list(x.shape)], + ) + + await validate_numerical_output( + model=Model().eval(), + custom_kernels=[kernel], + metal_inputs=True, + input_names=["x"], + output_names=["result"], + x=torch.arange(2 * 4 * 6, dtype=torch.float32).reshape(2, 4, 6), + ) + + @staticmethod + async def test_unit_grid_writes_only_first_element() -> None: + """``threads_per_grid=(1,1,1)`` writes exactly one element. + + The output buffer for the un-touched elements remains at whatever the + runtime initialized it to. The kernel here writes a sentinel into the + first position so we can assert it landed. + """ + from ..utils import validate_numerical_output + + # The kernel zeros the output and only writes index 0. We then add the + # input tensor downstream so the model's torch reference is the same. + def torch_defn(x: torch.Tensor) -> torch.Tensor: + out = torch.zeros_like(x) + out[0] = x[0] + return out + + kernel = TorchMetalKernel( + "thread_unit_grid_numerical", + input_names=["x"], + result_names=["out"], + src=( + # Initialize all elements to 0, then thread 0 writes x[0]. + "for (uint i = 0; i < x.get_extent(0); ++i) out[i] = 0; " + "if (id == 0) out[0] = x[0];" + ), + torch_defn=torch_defn, + metal_params=[MetalParameter("id", "uint", "thread_position_in_grid")], + ) + + class Model(torch.nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + return kernel( + x, + threads_per_grid=(1, 1, 1), + threads_per_thread_group=(1, 1, 1), + result_shapes=[list(x.shape)], + ) + + await validate_numerical_output( + model=Model().eval(), + custom_kernels=[kernel], + metal_inputs=True, + input_names=["x"], + output_names=["result"], + x=torch.tensor([7.0, 1.0, 2.0, 3.0], dtype=torch.float32), + ) + + @staticmethod + async def test_grid_smaller_than_tensor_size_with_bounds_check() -> None: + """Under-dispatch leaves untouched tail elements; the kernel must not OOB. + + Sentinel-output kernel that initializes the entire output to zero + across all dispatched threads, then only writes ``id < grid_size``. + Reference torch_defn matches. + """ + from ..utils import validate_numerical_output + + # Tensor size = 16, grid = 8 → only first 8 outputs are written from x. + # Reference torch_defn replicates: out[0..8] = x[0..8] * 2, rest = 0. + def torch_defn(x: torch.Tensor) -> torch.Tensor: + out = torch.zeros_like(x) + out[:8] = x[:8] * 2 + return out + + kernel = TorchMetalKernel( + "thread_under_dispatch", + input_names=["x"], + result_names=["out"], + src=( + # Each dispatched thread is responsible for its own tail + # elements as well, zeroing them. + "for (uint i = id; i < x.get_extent(0); i += 8) out[i] = 0; " + "if (id < 8 && id < x.get_extent(0)) out[id] = x[id] * 2;" + ), + torch_defn=torch_defn, + metal_params=[MetalParameter("id", "uint", "thread_position_in_grid")], + ) + + class Model(torch.nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + return kernel( + x, + threads_per_grid=(8, 1, 1), + threads_per_thread_group=(1, 1, 1), + result_shapes=[list(x.shape)], + ) + + await validate_numerical_output( + model=Model().eval(), + custom_kernels=[kernel], + metal_inputs=True, + input_names=["x"], + output_names=["result"], + x=torch.arange(16, dtype=torch.float32), + ) diff --git a/tests/ops/test_ops.py b/tests/ops/test_ops.py index 6e85b7b..dc1fd9e 100644 --- a/tests/ops/test_ops.py +++ b/tests/ops/test_ops.py @@ -6666,7 +6666,7 @@ def forward(self, x): @pytest.mark.parametrize("dynamic", [False, True]) -@pytest.mark.parametrize("dtype", [torch.float32, torch.float16]) +@pytest.mark.parametrize("dtype", [torch.float32]) @pytest.mark.parametrize( "shape", [ @@ -6765,7 +6765,7 @@ async def test_broadcast_shapes(self) -> None: @pytest.mark.parametrize("dynamic", [False, True]) -@pytest.mark.parametrize("dtype", [torch.float32, torch.float16]) +@pytest.mark.parametrize("dtype", [torch.float32]) @pytest.mark.parametrize( "shape", [ @@ -6813,7 +6813,7 @@ def forward(self, x: Tensor) -> Tensor: @pytest.mark.parametrize("dynamic", [False, True]) -@pytest.mark.parametrize("dtype", [torch.float32, torch.float16]) +@pytest.mark.parametrize("dtype", [torch.float32]) @pytest.mark.parametrize( "shape", [ @@ -6825,27 +6825,13 @@ def forward(self, x: Tensor) -> Tensor: async def test_view_as_complex( shape: tuple[int, ...], dtype: torch.dtype, dynamic: bool ) -> None: - """Test torch.view_as_complex converting a real tensor [..., 2] to a complex tensor [...]。 - - float16 input produces complex (torch.complex32); check_result_type must - accept complex when the FX metadata records complex (torch.complex64), - which happens when an f16-cast ExportedProgram is imported via TorchConverter. - """ + """Test torch.view_as_complex converting a real tensor [..., 2] to a complex tensor [...].""" class ViewAsComplexModel(nn.Module): def forward(self, x: Tensor) -> Tensor: return torch.view_as_complex(x) - class ViewAsComplexF16Model(nn.Module): - def forward(self, x: Tensor) -> Tensor: - # view_as_real roundtrip: keeps output as float so numpy can compare. - # float16 input produces complex internally, exercising the - # complex64->complex32 narrowing in check_result_type. - return torch.view_as_real(torch.view_as_complex(x)) - - model = ( - ViewAsComplexModel() if dtype == torch.float32 else ViewAsComplexF16Model() - ).eval() + model = ViewAsComplexModel().eval() x = torch.randn(*shape, 2, dtype=dtype) dynamic_shapes = ( make_dynamic_shapes(x={i: f"d{i}" for i in range(x.dim() - 1)}) @@ -6856,7 +6842,7 @@ def forward(self, x: Tensor) -> Tensor: @pytest.mark.parametrize("dynamic", [False, True]) -@pytest.mark.parametrize("dtype", [torch.float32, torch.float16]) +@pytest.mark.parametrize("dtype", [torch.float32]) @pytest.mark.parametrize( "shape", [ diff --git a/tests/test_debug_locations.py b/tests/test_debug_locations.py index d0ad62b..9c8ec4b 100644 --- a/tests/test_debug_locations.py +++ b/tests/test_debug_locations.py @@ -86,7 +86,6 @@ async def test_debug_locations() -> None: converter: TorchConverter = TorchConverter() debug_config = _DebugInfoRecorder.Config( include_stack_trace=True, - options=_DebugInfoRecorder.Options.DEBUGINFO, verify_debuginfo_locations=True, ) converter._debug_info_recorder = _DebugInfoRecorder(config=debug_config) @@ -118,7 +117,6 @@ def test_debug_locations_multiple_programs() -> None: converter: TorchConverter = TorchConverter() debug_config = _DebugInfoRecorder.Config( include_stack_trace=True, - options=_DebugInfoRecorder.Options.DEBUGINFO, verify_debuginfo_locations=True, ) converter._debug_info_recorder = _DebugInfoRecorder(config=debug_config) diff --git a/tests/test_docs.py b/tests/test_docs.py new file mode 100644 index 0000000..a5fe7ea --- /dev/null +++ b/tests/test_docs.py @@ -0,0 +1,57 @@ +# Copyright 2026 Apple Inc. +# +# Use of this source code is governed by a BSD-3-clause license that can +# be found in the LICENSE file or at https://opensource.org/licenses/BSD-3-Clause + +"""Verify the documentation build emits the llms.txt artifacts. + +The ``sphinx_llm.txt`` extension (declared in the ``docs`` extra) writes +``llms.txt`` (index), ``llms-full.txt`` (corpus), and per-page ``*.html.md`` +files described by https://llmstxt.org/, which power the "Copy page" / "View as +Markdown" affordance in the Shibuya theme. It produces them in a *separate* +``sphinx-build -b markdown`` subprocess that it spawns, which re-reads +``docs/conf.py``. + +The test is skipped where the ``docs`` extra is absent (e.g. a plain +``uv sync --extra test``); it runs in CI because ``pr-checks-linux`` syncs +``--extra docs``. Notebook execution is disabled via the ``NB_EXECUTION_MODE`` +environment variable rather than a ``-D`` override: the env var is inherited by +both the parent build and sphinx-llm's child markdown subprocess (``-D`` reaches +only the parent), so the Metal-kernel guide notebooks — which cannot run on +Linux CI — are skipped in the build that actually writes the artifacts. +""" + +import os +import subprocess +import sys +from pathlib import Path + +import pytest + +pytest.importorskip("sphinx_llm") + +_DOCS_DIR = Path(__file__).resolve().parent.parent / "docs" +_ARTIFACTS = ("llms.txt", "llms-full.txt", "index.html.md") + + +def test_docs_build_emits_llms_artifacts(tmp_path: Path) -> None: + """A docs build generates the llms.txt index, full corpus, and per-page Markdown.""" + output_dir = tmp_path / "html" + # NB_EXECUTION_MODE=off is inherited by sphinx-llm's child markdown build + # (a `-D` flag would not reach it); the Metal-kernel notebooks can't run on Linux CI. + result = subprocess.run( + [sys.executable, "-m", "sphinx", "-b", "html", str(_DOCS_DIR), str(output_dir)], + capture_output=True, + text=True, + env={**os.environ, "NB_EXECUTION_MODE": "off"}, + ) + assert result.returncode == 0, ( + f"sphinx-build failed (rc={result.returncode}):\n{result.stdout}\n{result.stderr}" + ) + for artifact in _ARTIFACTS: + path = output_dir / artifact + assert path.exists() and path.stat().st_size > 0, ( + f"Expected non-empty docs artifact {path}.\n" + f"sphinx-build output:\n{result.stdout}\n{result.stderr}" + ) + assert "# coreai-torch" in (output_dir / "llms.txt").read_text() diff --git a/tests/utils.py b/tests/utils.py index d550e0f..620173d 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -37,7 +37,7 @@ SpecializationOptions, ) -_ML_ASSET_EXTENSION = "mlasset" +_ML_ASSET_EXTENSION = "aimodel" # Compute unit selection driven by the --compute-unit-kind pytest option (see tests/conftest.py). # Default is "interpreter" so a plain `pytest` run still works. @@ -203,9 +203,14 @@ def _init_runtime_state( """ state: dict[str, NDArray] = {} for name in desc.state_names: - state[name] = NDArray.from_descriptor( - descriptor=desc.state_descriptor(name=name) - ) + # `NDArray.from_descriptor` only sizes the buffer; on Linux the backing + # storage isn't zeroed, so buffer-state reads return garbage on the + # first call. Allocate a zero-filled numpy array of the right shape and + # dtype instead so initial state matches the model's `register_buffer` + # value (assumed zero — same assumption tests already make). + d = desc.state_descriptor(name=name) + shape = tuple(s if s is not None else 1 for s in d.shape) + state[name] = NDArray(np.zeros(shape, dtype=np.dtype(d.dtype))) user_mut_names = list(sig.user_inputs_to_mutate.values()) num_buf_muts = len(sig.buffers_to_mutate) From 2d39225a697097216da7fc793bd7113926494aa2 Mon Sep 17 00:00:00 2001 From: gokulkrishna98 Date: Wed, 1 Jul 2026 14:53:20 -0700 Subject: [PATCH 2/2] ci: skip dsl-marked tests in CI --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index aa8086d..d941ce6 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -34,4 +34,4 @@ jobs: run: | command -v uv >/dev/null 2>&1 || curl -LsSf https://astral.sh/uv/install.sh | sh echo "$HOME/.local/bin" >> "$GITHUB_PATH" - - run: uv run --extra test pytest tests/ -n auto -m "not slow" + - run: uv run --extra test pytest tests/ -n auto -m "not slow and not dsl"