Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 45 additions & 19 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,7 @@ async def _execute_and_compare(
num_calls: int,
produce_torch_out: Any,
compare: Any,
coreai_program: Any = None,
dump_path: Path | None = None,
) -> None:
"""The single place we invoke the Core AI runtime.
Expand All @@ -394,28 +395,17 @@ async def _execute_and_compare(
two callbacks let path A (stateful, recompute per call) and path B
(fixed expected output) share this loop without ad-hoc branching.
"""
# Stateful tests can't round-trip for now,
# which has no `--input` JSON path for state — so skip the dump entirely.
should_dump = dump_optests_enabled() and dump_path is not None and not state
try:
for call_idx in range(num_calls):
io_numpy = {}
if call_idx == 0 and dump_optests_enabled():
assert dump_path is not None
for name, arr in state.items():
io_numpy[f"initial_state_{name}"] = arr.numpy()

torch_out = produce_torch_out(call_idx)
rt_outputs = await rt_func(inputs=inputs, state=state)
compare(rt_outputs, torch_out, call_idx)

if call_idx == 0 and dump_optests_enabled():
assert dump_path is not None
for name, arr in inputs.items():
io_numpy[name] = arr.numpy()
for name, arr in state.items():
io_numpy[f"final_state_{name}"] = arr.numpy()
for name, arr in rt_outputs.items():
io_numpy[name] = arr.numpy()

np.savez(dump_path / "test_data.npz", **io_numpy)
if call_idx == 0 and should_dump:
_dump_optest_artifacts(coreai_program, inputs, rt_outputs, dump_path)

except Exception:
# Wipe bytecode and reference IO if the test failed as
Expand All @@ -425,6 +415,39 @@ async def _execute_and_compare(
raise


def _add_npz_entry(io_numpy: dict[str, np.ndarray], key: str, arr: np.ndarray) -> None:
"""Add an array to the npz dict, emitting a bf16 dtype override if needed.

NumPy has no native bf16, so surfaces it as void16 (``|V2``).
"""
io_numpy[key] = arr
if arr.dtype.str == "|V2":
io_numpy[f"_dtype_{key}"] = np.array("bf16")


def _dump_optest_artifacts(
coreai_program: Any,
inputs: dict[str, NDArray],
rt_outputs: dict[str, NDArray],
dump_path: Path,
) -> None:
"""Write a `<testname>.aimodel` + `<testname>_test_data.npz` pair.

Format: aimodel prefix == npz prefix == dump_path
leaf name; npz holds an ``op_name`` scalar plus ``input_<n>`` /
``output_<n>`` keys.
"""
testname = dump_path.name
coreai_program.save_asset(dump_path / f"{testname}.aimodel")

io_numpy: dict[str, np.ndarray] = {"op_name": np.array("main")}
for name, arr in inputs.items():
_add_npz_entry(io_numpy, f"input_{name}", arr.numpy())
for name, arr in rt_outputs.items():
_add_npz_entry(io_numpy, f"output_{name}", arr.numpy())
np.savez(dump_path / f"{testname}_test_data.npz", **io_numpy)


async def _run_with_model(
model: torch.nn.Module,
rt_func: Any,
Expand All @@ -437,6 +460,7 @@ async def _run_with_model(
rtol: float,
atol: float,
metal_inputs: bool = False,
coreai_program: Any = None,
dump_path: Path | None = None,
) -> None:
"""Path A: stateful, multi-call, name-based matching."""
Expand Down Expand Up @@ -478,6 +502,7 @@ def compare(
num_calls=num_calls,
produce_torch_out=produce_torch_out,
compare=compare,
coreai_program=coreai_program,
dump_path=dump_path,
)

Expand All @@ -490,6 +515,7 @@ async def _run_with_program(
rtol: float,
atol: float,
metal_inputs: bool = False,
coreai_program: Any = None,
dump_path: Path | None = None,
) -> None:
"""Path B: pre-converted program, single call, sorted-key matching."""
Expand Down Expand Up @@ -519,6 +545,7 @@ def compare(
num_calls=1,
produce_torch_out=produce_torch_out,
compare=compare,
coreai_program=coreai_program,
dump_path=dump_path,
)

Expand Down Expand Up @@ -592,9 +619,6 @@ async def validate_numerical_output(**kwargs: Any) -> None:
if dump_optests_enabled():
dump_path = _optest_dump_path(get_current_test_id())
dump_path.mkdir(parents=True, exist_ok=True)
model_path = dump_path / "main.AICode.bc"
model_path.unlink(missing_ok=True)
coreai_program._save_bytecode(model_path)

with TemporaryDirectory() as temp_directory:
aimodel_path = Path(temp_directory) / "model.aimodel"
Expand All @@ -616,6 +640,7 @@ async def validate_numerical_output(**kwargs: Any) -> None:
rtol=rtol,
atol=atol,
metal_inputs=metal_inputs,
coreai_program=coreai_program,
dump_path=dump_path,
)
else:
Expand All @@ -626,6 +651,7 @@ async def validate_numerical_output(**kwargs: Any) -> None:
rtol=rtol,
atol=atol,
metal_inputs=metal_inputs,
coreai_program=coreai_program,
dump_path=dump_path,
)

Expand Down