Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
507 changes: 500 additions & 7 deletions iron/common/fusion.py

Large diffs are not rendered by default.

5 changes: 5 additions & 0 deletions iron/operators/elementwise_add/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,8 @@ class ElementwiseAdd(BinaryElementwiseOperator):
kernel_fn_name: ClassVar[str] = "eltwise_add_bf16_vector"
kernel_subdir: ClassVar[str] = "generic"
callback_fn: ClassVar[str] = "my_eltwise_add"

def reference(self, a, b):
    """CPU reference: elementwise add computed in fp32, cast back to bf16."""
    import torch

    lhs = a.to(torch.float32)
    rhs = b.to(torch.float32)
    return torch.add(lhs, rhs).to(torch.bfloat16)
5 changes: 5 additions & 0 deletions iron/operators/elementwise_mul/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,8 @@ class ElementwiseMul(BinaryElementwiseOperator):
kernel_fn_name: ClassVar[str] = "eltwise_mul_bf16_vector"
kernel_subdir: ClassVar[str] = "generic"
callback_fn: ClassVar[str] = "my_eltwise_mul"

def reference(self, a, b):
    """CPU reference: elementwise multiply computed in fp32, cast back to bf16."""
    import torch

    lhs = a.to(torch.float32)
    rhs = b.to(torch.float32)
    return torch.mul(lhs, rhs).to(torch.bfloat16)
13 changes: 13 additions & 0 deletions iron/operators/gemm/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,19 @@ def get_arg_spec(self):
), # output C
]

def reference(self, A, B):
"""CPU reference: ``C = A @ B`` honoring ``b_col_maj`` / ``c_col_maj``."""
import torch

A32 = A.to(torch.float32)
B32 = B.to(torch.float32)
if self.b_col_maj:
B32 = B32.transpose(-1, -2)
C = A32 @ B32
if self.c_col_maj:
C = C.transpose(-1, -2)
return C.contiguous().to(torch.bfloat16)

def pad_A(self, A_np):
"""Pad A matrix to match operator dimensions (M, K)"""
M, K = A_np.shape
Expand Down
8 changes: 8 additions & 0 deletions iron/operators/gemv/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,3 +99,11 @@ def get_arg_spec(self):
AIERuntimeArgSpec("in", batch_dim + (self.K,)), # vector
AIERuntimeArgSpec("out", batch_dim + (self.M,)), # output
]

def reference(self, A, B):
    """CPU reference: (optionally batched) matrix-vector product in fp32,
    cast back to bf16."""
    import torch

    mat = A.to(torch.float32)
    # Lift the vector to a column matrix so matmul handles batching,
    # then drop the singleton column from the result.
    col = B.to(torch.float32).unsqueeze(-1)
    return torch.matmul(mat, col).squeeze(-1).to(torch.bfloat16)
4 changes: 4 additions & 0 deletions iron/operators/repeat/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,3 +59,7 @@ def get_arg_spec(self):
AIERuntimeArgSpec("in", (self.rows, self.cols)),
AIERuntimeArgSpec("out", (self.rows * self.repeat, self.cols)),
]

def reference(self, x):
    """CPU reference: duplicate each leading-dim row ``self.repeat`` times,
    keeping the copies adjacent (repeat-interleave semantics)."""
    copies = self.repeat
    return x.repeat_interleave(copies, dim=0)
13 changes: 13 additions & 0 deletions iron/operators/rms_norm/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,3 +124,16 @@ def get_arg_spec(self):
AIERuntimeArgSpec("out", (self.size // self.tile_size, self.tile_size))
)
return specs

def reference(self, x, w=None):
    """CPU reference: row-wise RMS normalization in fp32, optionally scaled
    by a weight vector, cast back to bf16.

    Raises ``ValueError`` when ``self.weighted`` is set but ``w`` is None.

    NOTE(review): eps is added to the RMS value itself (``x / (rms + 1e-5)``)
    rather than inside the sqrt as in the common formulation — presumably
    this mirrors the NPU kernel; confirm against it.
    """
    import torch

    vals = x.to(torch.float32)
    root_mean_sq = vals.pow(2).mean(dim=-1, keepdim=True).sqrt()
    normed = vals / (root_mean_sq + 1e-5)
    if self.weighted:
        if w is None:
            raise ValueError("weighted RMSNorm requires weight input")
        normed = normed * w.to(torch.float32)
    return normed.to(torch.bfloat16)
34 changes: 34 additions & 0 deletions iron/operators/rope/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,3 +92,37 @@ def get_arg_spec(self):
AIERuntimeArgSpec("in", (self.angle_rows, self.cols)), # angles
AIERuntimeArgSpec("out", (self.rows, self.cols)), # output
]

def reference(self, x, angles):
    """CPU reference for RoPE (rotary position embedding).

    ``angles`` holds interleaved [cos, sin, cos, sin, ...] pairs along the
    last dim (length ``cols``). Only ``method_type == 0`` (TWO_HALVES) is
    supported. When ``angles`` has fewer rows than ``x`` and the counts
    divide evenly, the angle rows are tiled to match ``x``; otherwise they
    are truncated to at most ``rows`` rows.

    NOTE(review): the truncation fallback (``cos[:rows]``) can still leave
    fewer angle rows than ``rows`` and then fail to broadcast — confirm
    callers never hit that shape.
    """
    import torch

    if self.method_type != 0:
        raise NotImplementedError(
            f"RoPE reference only supports method_type=0 (TWO_HALVES), "
            f"got {self.method_type}"
        )
    n_rows, n_cols = self.rows, self.cols
    split = n_cols // 2
    cos_t = angles[..., 0::2].to(torch.float32)
    sin_t = angles[..., 1::2].to(torch.float32)
    if cos_t.shape[0] != n_rows:
        if n_rows % cos_t.shape[0] == 0:
            tiles = n_rows // cos_t.shape[0]
            cos_t = cos_t.repeat(tiles, 1)
            sin_t = sin_t.repeat(tiles, 1)
        else:
            cos_t = cos_t[:n_rows]
            sin_t = sin_t[:n_rows]
    vals = x.to(torch.float32)
    lo, hi = vals[..., :split], vals[..., split:]
    rotated_lo = lo * cos_t - hi * sin_t
    rotated_hi = hi * cos_t + lo * sin_t
    return torch.cat([rotated_lo, rotated_hi], dim=-1).to(torch.bfloat16)
6 changes: 6 additions & 0 deletions iron/operators/silu/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,9 @@ class SiLU(ChanneledUnaryOperator):
kernel_fn_name: ClassVar[str] = "silu_bf16"
callback_fn: ClassVar[str] = "my_silu"
needs_lut_ops: ClassVar[bool] = True

def reference(self, x):
    """CPU reference: SiLU (``x * sigmoid(x)``) computed in fp32, cast to bf16."""
    import torch

    fp = x.to(torch.float32)
    gate = torch.sigmoid(fp)
    return fp.mul(gate).to(torch.bfloat16)
12 changes: 12 additions & 0 deletions iron/operators/softmax/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,3 +98,15 @@ def get_arg_spec(self):
AIERuntimeArgSpec("in", (self.size,)),
AIERuntimeArgSpec("out", (self.size,)),
]

def reference(self, x):
    """CPU reference: row-wise softmax over ``cols`` of a flat buffer.

    The flat bf16 input is viewed as (rows, cols), softmaxed along the last
    axis in fp32, then flattened back to bf16. Any runtime
    ``vector_size_parameter`` is ignored: the softmax always spans the full
    ``cols``, so a masked-tail decode case will not match the NPU output.
    """
    import torch

    grid = x.reshape(self.rows, self.cols).to(torch.float32)
    probs = grid.softmax(dim=-1)
    return probs.reshape(-1).to(torch.bfloat16)
4 changes: 4 additions & 0 deletions iron/operators/transpose/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,3 +94,7 @@ def get_arg_spec(self):
AIERuntimeArgSpec("in", (self.M * self.N,)),
AIERuntimeArgSpec("out", (self.M * self.N,)),
]

def reference(self, x):
    """CPU reference: transpose of an (M, N) row-major matrix stored flat,
    returned as a flat (N*M,) row-major buffer."""
    as_matrix = x.reshape(self.M, self.N)
    return as_matrix.t().contiguous().reshape(-1)
Loading