diff --git a/iron/common/fusion.py b/iron/common/fusion.py
index 99219848..c198a44f 100644
--- a/iron/common/fusion.py
+++ b/iron/common/fusion.py
@@ -1,26 +1,61 @@
 # SPDX-FileCopyrightText: Copyright (C) 2026 Advanced Micro Devices, Inc. All rights reserved.
 # SPDX-License-Identifier: Apache-2.0
 
+import hashlib
+import logging
+import time
 import numpy as np
 import ml_dtypes
 import pyxrt
 import ctypes
+import torch
 from . import compilation as comp
 from .base import AIEOperatorBase, MLIROperator
 from .utils import XRTSubBuffer
 import aie.utils as aie_utils
+from aie.iron.device import NPU2
 from aie.utils.hostruntime.xrtruntime.tensor import XRTTensor
+from aie.utils.npukernel import NPUKernel
+
+logger = logging.getLogger(__name__)
 
 # Fused Operator
 # ##########################################################################
 
 
 class FusedMLIROperator(AIEOperatorBase):
-    """Operator that fuses multiple MLIROperators into one."""
+    """Operator that fuses multiple MLIROperators into one.
+
+    Args:
+        dispatch: Dispatch strategy for the fused operator.
+            ``"auto"`` (default) selects ``"fused"`` on NPU2 and
+            ``"separate"`` on NPU1. ``"fused"`` uses a single-ELF
+            dispatch (requires NPU2). ``"separate"`` compiles each
+            sub-operator to its own xclbin and invokes them sequentially.
+            ``"reference"`` runs only the per-operator CPU reference
+            implementations (no NPU compilation/dispatch). ``"compare"``
+            runs the ``"separate"`` xclbin path and, after each NPU step,
+            also runs the operator's CPU reference on the NPU-produced
+            inputs and logs the deviation.
+    """
+
+    DISPATCH_MODES = ("auto", "fused", "separate", "reference", "compare")
 
     def __init__(
-        self, name, runlist, input_args, output_args, buffer_sizes=None, *args, **kwargs
+        self,
+        name,
+        runlist,
+        input_args,
+        output_args,
+        buffer_sizes=None,
+        dispatch="auto",
+        *args,
+        **kwargs,
     ):
+        if dispatch not in self.DISPATCH_MODES:
+            raise ValueError(
+                f"dispatch must be one of {self.DISPATCH_MODES!r}, got {dispatch!r}"
+            )
         if not all(
             isinstance(op, MLIROperator) and all(isinstance(buf, str) for buf in bufs)
             for op, *bufs in runlist
@@ -37,6 +72,7 @@ def __init__(
         self.explicit_buffer_sizes = (
             buffer_sizes or {}
         )  # Optional dict: buffer_name -> size_in_bytes
+        self._dispatch = dispatch
 
     def get_kernel_artifacts(self):
         """Collect all kernel artifacts from child operators.
@@ -205,13 +241,42 @@ def add_buffers(buffer_type, args_list):
     def set_up_artifacts(self):
         """Set up the artifact dependency graph for this fused operator.
 
-        Computes the buffer layout first, then builds the fused MLIR artifact
-        and full-ELF artifact and registers them via ``add_artifacts()``.
+        Computes the buffer layout first, then builds the artifacts.
+        The dispatch mode (``"fused"`` vs ``"separate"``) is resolved here
+        when set to ``"auto"``.
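+
+        A hypothetical construction sketch (operator and buffer names are
+        illustrative, not from this patch)::
+
+            op = FusedMLIROperator(
+                "mlp_block", runlist, ["x"], ["out"], dispatch="auto"
+            )
+            op.set_up_artifacts()  # NPU2 resolves to "fused", NPU1 to "separate"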
""" - # Calculate buffer layout before building mlir artifact (used by get_mlir_artifact) + # Calculate buffer layout (used by both paths for get_buffer()) self.subbuffer_layout, self.buffer_sizes, self.slice_info = ( self._calculate_buffer_layout() ) + + is_npu2 = isinstance(aie_utils.get_current_device(), NPU2) + + if self._dispatch == "auto": + self._mode = "fused" if is_npu2 else "separate" + elif self._dispatch == "fused": + if not is_npu2: + raise RuntimeError( + "dispatch='fused' requires NPU2 (Strix); " + "Phoenix/NPU1 does not support full-ELF dispatch" + ) + self._mode = "fused" + else: + self._mode = self._dispatch # "separate", "reference", or "compare" + + # Backwards-compat flag (used by get_callable/params_path). + self._use_full_elf = self._mode == "fused" + + if self._mode == "fused": + self._set_up_full_elf_artifacts() + elif self._mode in ("separate", "compare"): + self._set_up_xclbin_artifacts() + else: + # "reference": no NPU artifacts to compile. + pass + + def _set_up_full_elf_artifacts(self): + """Full-ELF path (NPU2): fuse MLIR into a single ELF.""" operator_name = self.name mlir_artifact = self.get_mlir_artifact() kernel_objects = self.get_kernel_artifacts() @@ -222,6 +287,58 @@ def set_up_artifacts(self): ) self.add_artifacts([full_elf_artifact]) + def _set_up_xclbin_artifacts(self): + """Chained xclbin path (NPU1/Phoenix): separate xclbin per unique operator. + + Mirrors the pattern from ``chain_swiglu_artifacts`` in + ``iron/operators/swiglu_base.py``: each unique operator gets its own + xclbin + insts compiled separately, linked via ``--xclbin-input``. + """ + seen: dict[int, object] = {} + unique_operators = [ + seen.setdefault(id(op), op) + for op, *_ in self.runlist + if id(op) not in seen + ] + + # Short hash to keep xclbin kernel names under 31 chars + # (xclbinutil limits m_name to 64 chars as "name:name") + name_hash = hashlib.sha1(self.name.encode()).hexdigest()[:6] + + artifacts = [] + prev_xclbin = None + self._op_xclbin_map = {} # id(op) -> xclbin artifact + self._op_insts_map = {} # id(op) -> insts artifact + self._op_kernel_name_map = {} # id(op) -> kernel_name + + for idx, op in enumerate(unique_operators): + op_label = f"f{name_hash}_op{idx}" + kernel_id = f"0x{0x901 + idx:x}" + + xclbin, insts = op.get_artifacts(prefix=f"{op_label}_") + # Use list() to avoid mutating the shared extra_flags list + # (get_artifacts may alias the same list between xclbin and insts) + xclbin.extra_flags = list(xclbin.extra_flags) + [ + f"--xclbin-instance-name={op_label}", + f"--xclbin-kernel-id={kernel_id}", + ] + xclbin.kernel_name = op_label + + if prev_xclbin is not None: + xclbin.xclbin_input = prev_xclbin + xclbin.dependencies.add(prev_xclbin) + + artifacts.append(insts) + self._op_xclbin_map[id(op)] = xclbin + self._op_insts_map[id(op)] = insts + self._op_kernel_name_map[id(op)] = op_label + prev_xclbin = xclbin + + # The last xclbin in the chain is the combined xclbin. + artifacts.append(prev_xclbin) + self.combined_xclbin = prev_xclbin + self.add_artifacts(artifacts) + def get_arg_spec(self): raise NotImplementedError( "FusedMLIROperator does not expose a unified arg spec; " @@ -232,9 +349,16 @@ def get_callable(self): """Return a callable that executes the fused operator on the NPU. Returns: - A ``FusedFullELFCallable`` wrapping this operator. + A ``FusedFullELFCallable`` when using fused dispatch, or a + ``FusedXclbinCallable`` when using separate dispatch. 
""" - return FusedFullELFCallable(self) + if self._mode == "fused": + return FusedFullELFCallable(self) + if self._mode == "reference": + return FusedReferenceCallable(self) + if self._mode == "compare": + return FusedCompareCallable(self) + return FusedXclbinCallable(self) def get_layout_for_buffer(self, buffer_name): """Return the (buffer_type, offset, length) layout for a named buffer. @@ -378,3 +502,372 @@ def __call__(self): self.scratch_buffer.buffer_object(), ) self.output_buffer.to("cpu") + + +class FusedXclbinCallable: + """Callable for FusedMLIROperator on NPU1 (Phoenix) using chained xclbins. + + Instead of a single ELF dispatch, each step in the runlist is executed as a + separate ``NPUKernel`` invocation. Buffers are shared (same ``XRTTensor``) + across steps that reference the same buffer name, giving zero-copy handoff + between sequential operators. + """ + + def __init__(self, op): + self.op = op + self.last_elapsed = 0.0 + + combined_xclbin_path = op.combined_xclbin.filename + + # Build an NPUKernel per unique operator + self._op_callable_map = {} # id(op) -> NPUKernel + for op_id, xclbin in op._op_xclbin_map.items(): + insts = op._op_insts_map[op_id] + kernel_name = op._op_kernel_name_map[op_id] + self._op_callable_map[op_id] = NPUKernel( + xclbin_path=combined_xclbin_path, + kernel_name=kernel_name, + insts_path=insts.filename, + ) + + # Allocate one XRTTensor per unique base buffer name. + # Buffers that appear in multiple runlist entries share the same tensor + # (zero-copy between operators). + itemsize = np.dtype(ml_dtypes.bfloat16).itemsize + self._buffers = {} # base buffer name -> XRTTensor + for buf_name in list(op.subbuffer_layout.keys()): + _, _, length = op.subbuffer_layout[buf_name] + self._buffers[buf_name] = XRTTensor( + (max(length, itemsize) // itemsize,), + dtype=ml_dtypes.bfloat16, + ) + + # Pre-build the execution plan: list of (NPUKernel, [XRTTensor args]) + self._execution_plan = [] + for step_op, *buf_names in op.runlist: + kernel = self._op_callable_map[id(step_op)] + args = [] + for buf_name in buf_names: + args.append(self._resolve_buffer(buf_name)) + self._execution_plan.append((kernel, args)) + + # Cache for get_buffer() sub-buffer views (compatible with FusedFullELFCallable API) + self._buffer_cache = {} + + # Expose input/output/scratch buffers for API compatibility with + # FusedFullELFCallable (used by tests for .to("cpu") etc.) + input_buffer_size, output_buffer_size, scratch_buffer_size = op.buffer_sizes + self.input_buffer = XRTTensor( + (max(input_buffer_size, itemsize) // itemsize,), + dtype=ml_dtypes.bfloat16, + ) + self.output_buffer = XRTTensor( + (max(output_buffer_size, itemsize) // itemsize,), + dtype=ml_dtypes.bfloat16, + ) + self.scratch_buffer = XRTTensor( + (max(scratch_buffer_size, itemsize) // itemsize,), + dtype=ml_dtypes.bfloat16, + ) + + def _resolve_buffer(self, buf_name): + """Resolve a buffer name (possibly with slice notation) to an XRTTensor. + + Regular buffer names map directly to an allocated XRTTensor. + Sliced buffer names (e.g. ``queries[0:128]``) create an XRTSubBuffer + view into the parent buffer. 
+ """ + if buf_name in self._buffers: + return self._buffers[buf_name] + + # Sliced buffer: "base_name[start:end]" + if buf_name in self.op.slice_info: + base_name, start_bytes, end_bytes = self.op.slice_info[buf_name] + parent = self._buffers[base_name] + itemsize = np.dtype(ml_dtypes.bfloat16).itemsize + size_bytes = end_bytes - start_bytes + sub = XRTSubBuffer( + parent_bo=parent.buffer_object(), + offset_bytes=start_bytes, + size_bytes=size_bytes, + shape=(size_bytes // itemsize,), + dtype=ml_dtypes.bfloat16, + ) + # Cache so the same slice always returns the same object + self._buffers[buf_name] = sub + return sub + + raise ValueError(f"Unknown buffer '{buf_name}' in fused runlist") + + def get_buffer(self, buffer_name): + """Return an XRTTensor(-like) view for a named buffer. + + Compatible with the ``FusedFullELFCallable.get_buffer()`` API so that + test helpers (``_load_input``, ``_get_output_tensor``, etc.) work + unchanged. + + For the xclbin path, each buffer is its own standalone XRTTensor (or + XRTSubBuffer for sliced buffers), so this just returns the resolved + buffer directly. + """ + if buffer_name in self._buffer_cache: + return self._buffer_cache[buffer_name] + buf = self._resolve_buffer(buffer_name) + self._buffer_cache[buffer_name] = buf + return buf + + def __call__(self): + # Sync all input buffers to device + for buf_name in self.op.input_args: + self._buffers[buf_name].to("npu") + + t0 = time.perf_counter() + for kernel, args in self._execution_plan: + kernel(*args) + self.last_elapsed = time.perf_counter() - t0 + + # Sync all base buffers from device so callers can read results + # (covers both output and scratch buffers) + for buf_name in self.op.subbuffer_layout: + if buf_name not in self.op.input_args: + self._buffers[buf_name].to("cpu") + + +# --------------------------------------------------------------------------- +# Reference and compare dispatch +# --------------------------------------------------------------------------- + + +class _CPUBuffer: + """Minimal buffer adapter compatible with the ``XRTTensor`` API used by + callers (``torch_view``, ``to("npu")``, ``to("cpu")``, ``fill_``). + + Backed by a flat 1D ``torch.bfloat16`` tensor in host memory. All device + sync calls are no-ops. + """ + + def __init__(self, n_elements): + self._t = torch.zeros(n_elements, dtype=torch.bfloat16) + + def torch_view(self): + return self._t + + def to(self, *_args, **_kwargs): + return self + + def fill_(self, value): + self._t.fill_(value) + return self + + def buffer_object(self): + return None + + +def _reshape_for_spec(flat_tensor, spec): + """Slice a flat host buffer to the element count implied by ``spec`` and + reshape it to the operator-declared shape. + + Returns a view (no copy).""" + n = int(np.prod(spec.shape)) if spec.shape else 1 + return flat_tensor[:n].reshape(spec.shape) + + +def _call_reference(step_op, inputs): + """Invoke ``step_op.reference(*inputs)`` if available. + + Returns the reference output tensor, or ``None`` if the operator has no + reference implementation. Propagates other exceptions. + """ + ref_fn = getattr(step_op, "reference", None) + if ref_fn is None: + return None + try: + return ref_fn(*inputs) + except NotImplementedError: + return None + + +class FusedReferenceCallable: + """Pure-CPU evaluation of a fused operator runlist. + + No NPU compilation or dispatch occurs. Each runlist step calls + ``op.reference(*inputs)`` on host-side ``torch.bfloat16`` buffers. 
+
+    Useful for validating the reference implementations themselves and for
+    comparing layer-by-layer expected outputs against NPU output.
+    """
+
+    def __init__(self, op):
+        self.op = op
+        self.last_elapsed = 0.0
+        itemsize = np.dtype(ml_dtypes.bfloat16).itemsize
+
+        self._buffers = {}  # base buffer name -> _CPUBuffer
+        for buf_name, (_, _, length) in op.subbuffer_layout.items():
+            n = max(length, itemsize) // itemsize
+            self._buffers[buf_name] = _CPUBuffer(n)
+
+        # API parity with FusedFullELFCallable / FusedXclbinCallable
+        input_buffer_size, output_buffer_size, scratch_buffer_size = op.buffer_sizes
+        self.input_buffer = _CPUBuffer(max(input_buffer_size, itemsize) // itemsize)
+        self.output_buffer = _CPUBuffer(max(output_buffer_size, itemsize) // itemsize)
+        self.scratch_buffer = _CPUBuffer(max(scratch_buffer_size, itemsize) // itemsize)
+
+        self._buffer_cache = {}
+
+    def _resolve_buffer(self, buf_name):
+        if buf_name in self._buffers:
+            return self._buffers[buf_name]
+        if buf_name in self.op.slice_info:
+            base_name, start_bytes, end_bytes = self.op.slice_info[buf_name]
+            parent = self._buffers[base_name]
+            itemsize = np.dtype(ml_dtypes.bfloat16).itemsize
+            start = start_bytes // itemsize
+            end = end_bytes // itemsize
+            sliced = _CPUBuffer.__new__(_CPUBuffer)
+            sliced._t = parent.torch_view()[start:end]
+            self._buffers[buf_name] = sliced
+            return sliced
+        raise ValueError(f"Unknown buffer '{buf_name}' in fused runlist")
+
+    def get_buffer(self, buffer_name):
+        if buffer_name in self._buffer_cache:
+            return self._buffer_cache[buffer_name]
+        buf = self._resolve_buffer(buffer_name)
+        self._buffer_cache[buffer_name] = buf
+        return buf
+
+    def __call__(self):
+        t0 = time.perf_counter()
+        for step_op, *buf_names in self.op.runlist:
+            arg_specs = step_op.get_arg_spec()
+            if len(arg_specs) != len(buf_names):
+                raise ValueError(
+                    f"Operator {step_op!r} arg-spec count {len(arg_specs)} "
+                    f"does not match runlist buffer count {len(buf_names)}"
+                )
+            *in_names, out_name = buf_names
+            *in_specs, out_spec = arg_specs
+
+            inputs = []
+            for name, spec in zip(in_names, in_specs):
+                flat = self._resolve_buffer(name).torch_view()
+                inputs.append(_reshape_for_spec(flat, spec).clone())
+
+            out = _call_reference(step_op, inputs)
+            if out is None:
+                raise NotImplementedError(
+                    f"Operator {type(step_op).__name__} has no reference "
+                    f"implementation; cannot use dispatch='reference'"
+                )
+
+            out_flat = self._resolve_buffer(out_name).torch_view()
+            n_out = int(np.prod(out_spec.shape)) if out_spec.shape else 1
+            out_flat[:n_out].copy_(out.reshape(-1).to(torch.bfloat16))
+        self.last_elapsed = time.perf_counter() - t0
+
+
+class FusedCompareCallable(FusedXclbinCallable):
+    """Run the separate-xclbin NPU pipeline and, after each step, run the
+    operator's CPU reference on the same (NPU-produced) inputs.
+
+    Logs per-step max-abs and max-rel error. The NPU output is what
+    propagates to the next step on both sides, so each comparison reflects
+    only the deviation of the current operator (no error accumulation).
+    """
+
+    def __init__(self, op, rel_tol=0.05, abs_tol=1e-2):
+        super().__init__(op)
+        self.rel_tol = rel_tol
+        self.abs_tol = abs_tol
+        # Per-step diagnostic records populated on each __call__.
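+        # Each record is a dict with keys "step", "op", "op_name", "inputs",
+        # "output" and "skipped", plus "max_abs", "mean_abs", "max_rel" and
+        # "ref_max" when the step has a reference implementation.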
+        self.last_step_stats = []
+
+    def _read_buffer_to_cpu(self, name, spec):
+        """Sync a device buffer to host and return a reshaped bfloat16 copy
+        (callers cast to ``float32`` as needed)."""
+        buf = self._resolve_buffer(name)
+        buf.to("cpu")
+        flat = buf.torch_view()
+        n = int(np.prod(spec.shape)) if spec.shape else 1
+        return flat[:n].clone().reshape(spec.shape)
+
+    def __call__(self):
+        # Sync inputs to device.
+        for buf_name in self.op.input_args:
+            self._buffers[buf_name].to("npu")
+
+        self.last_step_stats = []
+        t0 = time.perf_counter()
+
+        for step_idx, (kernel, args) in enumerate(self._execution_plan):
+            step_op, *buf_names = self.op.runlist[step_idx]
+            arg_specs = step_op.get_arg_spec()
+            *in_names, out_name = buf_names
+            *in_specs, out_spec = arg_specs
+
+            # Snapshot NPU-side inputs before running the kernel.
+            cpu_inputs = [
+                self._read_buffer_to_cpu(name, spec)
+                for name, spec in zip(in_names, in_specs)
+            ]
+
+            # Run NPU step.
+            kernel(*args)
+
+            # Read NPU output.
+            npu_out = self._read_buffer_to_cpu(out_name, out_spec).to(torch.float32)
+
+            # Run reference on the same inputs.
+            ref_out = _call_reference(step_op, cpu_inputs)
+            stats = {
+                "step": step_idx,
+                "op": type(step_op).__name__,
+                "op_name": getattr(step_op, "name", type(step_op).__name__),
+                "inputs": list(in_names),
+                "output": out_name,
+            }
+            if ref_out is None:
+                stats["skipped"] = True
+                logger.info(
+                    "[compare step %d] %s -> %s: no reference (skipped)",
+                    step_idx,
+                    stats["op"],
+                    out_name,
+                )
+            else:
+                ref_flat = ref_out.reshape(out_spec.shape).to(torch.float32)
+                diff = (npu_out - ref_flat).abs()
+                ref_mag = ref_flat.abs()
+                max_abs = float(diff.max())
+                ref_max = float(ref_mag.max())
+                rel = float((diff / (ref_mag + 1e-6)).max())
+                mean_abs = float(diff.mean())
+                stats.update(
+                    skipped=False,
+                    max_abs=max_abs,
+                    mean_abs=mean_abs,
+                    max_rel=rel,
+                    ref_max=ref_max,
+                )
+                fail = (max_abs > self.abs_tol) and (rel > self.rel_tol)
+                level = logging.WARNING if fail else logging.INFO
+                logger.log(
+                    level,
+                    "[compare step %d] %s -> %s: max_abs=%.4g mean_abs=%.4g max_rel=%.4g ref_max=%.4g%s",
+                    step_idx,
+                    stats["op"],
+                    out_name,
+                    max_abs,
+                    mean_abs,
+                    rel,
+                    ref_max,
+                    " MISMATCH" if fail else "",
+                )
+            self.last_step_stats.append(stats)
+
+        self.last_elapsed = time.perf_counter() - t0
+
+        # Sync all base buffers back so callers can read results.
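+        # Input buffers are deliberately skipped: only buffers the runlist
+        # may have written (outputs and scratch) are read back to host.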
+        for buf_name in self.op.subbuffer_layout:
+            if buf_name not in self.op.input_args:
+                self._buffers[buf_name].to("cpu")
diff --git a/iron/operators/elementwise_add/op.py b/iron/operators/elementwise_add/op.py
index b69ba68d..0995e7fe 100644
--- a/iron/operators/elementwise_add/op.py
+++ b/iron/operators/elementwise_add/op.py
@@ -15,3 +15,8 @@ class ElementwiseAdd(BinaryElementwiseOperator):
     kernel_fn_name: ClassVar[str] = "eltwise_add_bf16_vector"
     kernel_subdir: ClassVar[str] = "generic"
     callback_fn: ClassVar[str] = "my_eltwise_add"
+
+    def reference(self, a, b):
+        import torch
+
+        return (a.to(torch.float32) + b.to(torch.float32)).to(torch.bfloat16)
diff --git a/iron/operators/elementwise_mul/op.py b/iron/operators/elementwise_mul/op.py
index 7e83689b..0c3b244c 100644
--- a/iron/operators/elementwise_mul/op.py
+++ b/iron/operators/elementwise_mul/op.py
@@ -15,3 +15,8 @@ class ElementwiseMul(BinaryElementwiseOperator):
     kernel_fn_name: ClassVar[str] = "eltwise_mul_bf16_vector"
     kernel_subdir: ClassVar[str] = "generic"
    callback_fn: ClassVar[str] = "my_eltwise_mul"
+
+    def reference(self, a, b):
+        import torch
+
+        return (a.to(torch.float32) * b.to(torch.float32)).to(torch.bfloat16)
diff --git a/iron/operators/gemm/op.py b/iron/operators/gemm/op.py
index ac391cc3..f417e7af 100644
--- a/iron/operators/gemm/op.py
+++ b/iron/operators/gemm/op.py
@@ -160,6 +160,19 @@ def get_arg_spec(self):
             ),  # output C
         ]
 
+    def reference(self, A, B):
+        """CPU reference: ``C = A @ B`` honoring ``b_col_maj`` / ``c_col_maj``."""
+        import torch
+
+        A32 = A.to(torch.float32)
+        B32 = B.to(torch.float32)
+        if self.b_col_maj:
+            B32 = B32.transpose(-1, -2)
+        C = A32 @ B32
+        if self.c_col_maj:
+            C = C.transpose(-1, -2)
+        return C.contiguous().to(torch.bfloat16)
+
     def pad_A(self, A_np):
         """Pad A matrix to match operator dimensions (M, K)"""
         M, K = A_np.shape
diff --git a/iron/operators/gemv/op.py b/iron/operators/gemv/op.py
index a21872ad..453fe9b5 100644
--- a/iron/operators/gemv/op.py
+++ b/iron/operators/gemv/op.py
@@ -99,3 +99,11 @@ def get_arg_spec(self):
             AIERuntimeArgSpec("in", batch_dim + (self.K,)),  # vector
             AIERuntimeArgSpec("out", batch_dim + (self.M,)),  # output
         ]
+
+    def reference(self, A, B):
+        """CPU reference: (optionally batched) matrix-vector product."""
+        import torch
+
+        A32 = A.to(torch.float32)
+        B32 = B.to(torch.float32)
+        return (A32 @ B32.unsqueeze(-1)).squeeze(-1).to(torch.bfloat16)
diff --git a/iron/operators/repeat/op.py b/iron/operators/repeat/op.py
index 2c5b9f09..96a85655 100644
--- a/iron/operators/repeat/op.py
+++ b/iron/operators/repeat/op.py
@@ -59,3 +59,7 @@ def get_arg_spec(self):
             AIERuntimeArgSpec("in", (self.rows, self.cols)),
             AIERuntimeArgSpec("out", (self.rows * self.repeat, self.cols)),
         ]
+
+    def reference(self, x):
+        """CPU reference: repeat-interleave along the leading dimension."""
+        return x.repeat_interleave(self.repeat, dim=0)
diff --git a/iron/operators/rms_norm/op.py b/iron/operators/rms_norm/op.py
index 4f7591ea..94791209 100644
--- a/iron/operators/rms_norm/op.py
+++ b/iron/operators/rms_norm/op.py
@@ -124,3 +124,16 @@ def get_arg_spec(self):
                 AIERuntimeArgSpec("out", (self.size // self.tile_size, self.tile_size))
             )
         return specs
+
+    def reference(self, x, w=None):
+        """CPU reference: row-wise RMS normalization, optionally weighted."""
+        import torch
+
+        x32 = x.to(torch.float32)
+        rms = torch.sqrt((x32 * x32).mean(dim=-1, keepdim=True))
+        out = x32 / (rms + 1e-5)
+        if self.weighted:
+            if w is None:
+                raise ValueError("weighted RMSNorm requires weight input")
+            out = out * w.to(torch.float32)
+        return out.to(torch.bfloat16)
diff --git a/iron/operators/rope/op.py b/iron/operators/rope/op.py
index 6a9a3d7e..a9029faf 100644
--- a/iron/operators/rope/op.py
+++ b/iron/operators/rope/op.py
@@ -92,3 +92,37 @@ def get_arg_spec(self):
             AIERuntimeArgSpec("in", (self.angle_rows, self.cols)),  # angles
             AIERuntimeArgSpec("out", (self.rows, self.cols)),  # output
         ]
+
+    def reference(self, x, angles):
+        """CPU reference for RoPE.
+
+        Assumes ``angles`` holds interleaved [cos, sin, cos, sin, ...] pairs
+        along the last dim (length ``cols``). Only ``method_type == 0``
+        (TWO_HALVES) is currently supported.
+
+        ``angles`` may have a different number of rows than ``x``; the angles
+        are tiled along the row dimension when ``rows`` is a multiple of the
+        angle row count, and truncated to the first ``rows`` otherwise."""
+        import torch
+
+        if self.method_type != 0:
+            raise NotImplementedError(
+                f"RoPE reference only supports method_type=0 (TWO_HALVES), "
+                f"got {self.method_type}"
+            )
+        rows, cols = self.rows, self.cols
+        half = cols // 2
+        cos = angles[..., 0::2].to(torch.float32)
+        sin = angles[..., 1::2].to(torch.float32)
+        if cos.shape[0] != rows:
+            if rows % cos.shape[0] == 0:
+                rep = rows // cos.shape[0]
+                cos = cos.repeat(rep, 1)
+                sin = sin.repeat(rep, 1)
+            else:
+                cos = cos[:rows]
+                sin = sin[:rows]
+        x32 = x.to(torch.float32)
+        x1, x2 = x32[..., :half], x32[..., half:]
+        y1 = x1 * cos - x2 * sin
+        y2 = x2 * cos + x1 * sin
+        return torch.cat([y1, y2], dim=-1).to(torch.bfloat16)
diff --git a/iron/operators/silu/op.py b/iron/operators/silu/op.py
index 8b7c9853..63866242 100644
--- a/iron/operators/silu/op.py
+++ b/iron/operators/silu/op.py
@@ -17,3 +17,9 @@ class SiLU(ChanneledUnaryOperator):
     kernel_fn_name: ClassVar[str] = "silu_bf16"
     callback_fn: ClassVar[str] = "my_silu"
     needs_lut_ops: ClassVar[bool] = True
+
+    def reference(self, x):
+        import torch
+
+        x32 = x.to(torch.float32)
+        return (x32 * torch.sigmoid(x32)).to(torch.bfloat16)
diff --git a/iron/operators/softmax/op.py b/iron/operators/softmax/op.py
index 71aec051..462d5e6c 100644
--- a/iron/operators/softmax/op.py
+++ b/iron/operators/softmax/op.py
@@ -98,3 +98,15 @@ def get_arg_spec(self):
             AIERuntimeArgSpec("in", (self.size,)),
             AIERuntimeArgSpec("out", (self.size,)),
         ]
+
+    def reference(self, x):
+        """CPU reference: row-wise softmax over ``cols``.
+
+        Note: ignores the runtime ``vector_size_parameter`` (if any); the
+        reference always softmaxes over the full ``cols``. For decode-style
+        usage with a masked tail, the trailing positions will not match the
+        NPU output."""
+        import torch
+
+        x2 = x.reshape(self.rows, self.cols).to(torch.float32)
+        return torch.softmax(x2, dim=-1).reshape(-1).to(torch.bfloat16)
diff --git a/iron/operators/transpose/op.py b/iron/operators/transpose/op.py
index d37e0101..9eb0f018 100644
--- a/iron/operators/transpose/op.py
+++ b/iron/operators/transpose/op.py
@@ -94,3 +94,7 @@ def get_arg_spec(self):
             AIERuntimeArgSpec("in", (self.M * self.N,)),
             AIERuntimeArgSpec("out", (self.M * self.N,)),
         ]
+
+    def reference(self, x):
+        """CPU reference: 2D transpose of an (M, N) matrix stored row-major."""
+        return x.reshape(self.M, self.N).transpose(0, 1).contiguous().reshape(-1)
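
Usage sketch for the new compare mode (illustrative only; the operator,
runlist, buffer names, and ``x_host`` below are hypothetical):

    fused = FusedMLIROperator(
        "mlp_block", runlist, ["x"], ["out"], dispatch="compare"
    )
    fused.set_up_artifacts()
    fn = fused.get_callable()  # a FusedCompareCallable
    fn.get_buffer("x").torch_view().copy_(x_host)
    fn()  # runs each NPU step, then its CPU reference on the same inputs
    for rec in fn.last_step_stats:
        if not rec.get("skipped"):
            print(rec["op"], rec["max_abs"], rec["max_rel"])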