diff --git a/.github/workflows/yateto-pytest.yml b/.github/workflows/yateto-pytest.yml new file mode 100644 index 0000000..a858c23 --- /dev/null +++ b/.github/workflows/yateto-pytest.yml @@ -0,0 +1,43 @@ +name: yateto-pytest + +# Fast, dependency-free Python-level tests. Runs on every push and +# complements the heavier ``yateto-cpu`` workflow (which builds C++ +# and exercises each GEMM back-end). + +on: push + +jobs: + pytest: + runs-on: ubuntu-24.04 + strategy: + fail-fast: false + matrix: + python-version: ['3.10', '3.11', '3.12', '3.13'] + + steps: + - uses: actions/checkout@v6 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v6 + with: + python-version: ${{ matrix.python-version }} + + - name: Install + run: | + pip install -e . + pip install pytest pytest-cov hypothesis numpy + + - name: Run pytest suite + run: | + python -m pytest tests/pytest/ -v --tb=short \ + --cov=yateto \ + --cov-report=xml \ + --cov-report=term-missing \ + --cov-branch + + - name: Upload coverage + if: matrix.python-version == '3.12' + uses: actions/upload-artifact@v7 + with: + name: python-coverage + path: coverage.xml diff --git a/tests/pytest/conftest.py b/tests/pytest/conftest.py new file mode 100644 index 0000000..73fc4c2 --- /dev/null +++ b/tests/pytest/conftest.py @@ -0,0 +1,101 @@ +""" +Pytest configuration and fixtures for the Yateto Python unit-test suite. + +These tests exercise Yateto's compiler-style pipeline purely in Python +(frontend DSL -> AST passes -> control-flow graph). They intentionally +stop before C++ code generation / compilation, so they are fast, do not +depend on libxsmm / PSpaMM / CxxTest / a C++ toolchain, and can run +everywhere the `yateto` package imports. + +The matching C++/code-gen integration tests live under ``tests/code-gen`` +and are driven by the GitHub Actions workflow ``yateto-cpu.yml`` - we +do not duplicate them here. 
+""" +from __future__ import annotations + +import os +import sys + +import pytest + + +# Make the yateto source tree importable even when the package has not been +# installed via ``pip install -e .`` (e.g. when running the tests locally +# straight from a clone). +_REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")) +if _REPO_ROOT not in sys.path: + sys.path.insert(0, _REPO_ROOT) + +# --------------------------------------------------------------------------- +# Common fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def arch(): + """A host architecture used when passes need alignment info. + + ``dhsw`` = double precision on Haswell. Same default as the example + scripts. We re-set the layout's global alignment reference on every + test so no state leaks between tests. + """ + from yateto import useArchitectureIdentifiedBy + from yateto.memory import DenseMemoryLayout + + a = useArchitectureIdentifiedBy("dhsw") + yield a + # Reset global alignment state to keep tests hermetic. + DenseMemoryLayout.ALIGNMENT_ARCH = None + + +@pytest.fixture +def square_tensors(): + """A handful of 8x8 tensors, useful for most elementwise/matmul tests.""" + from yateto import Tensor + + N = 8 + return { + "N": N, + "A": Tensor("A", (N, N)), + "B": Tensor("B", (N, N)), + "C": Tensor("C", (N, N)), + } + + +@pytest.fixture +def deduced(): + """Helper that runs ``DeduceIndices`` on an AST and returns the result. + + ``DeduceIndices`` is the first mandatory pass after the DSL builds the + tree - without it most other visitors/transformers are not meaningful, + so almost every test needs it. + """ + from yateto.ast.transformer import DeduceIndices + + def _deduce(ast, target=None): + return DeduceIndices(target).visit(ast) + + return _deduce + + +@pytest.fixture +def run_ast_pipeline(arch): + """Push an AST through the middle-end up to the point where flops are + countable. 
Returns the (transformed) AST so tests can inspect it. + """ + from yateto.ast.transformer import ( + DeduceIndices, + EquivalentSparsityPattern, + SetSparsityPattern, + StrengthReduction, + ) + from yateto.ast.cost import BoundingBoxCostEstimator + + def _run(ast): + ast = DeduceIndices().visit(ast) + ast = EquivalentSparsityPattern().visit(ast) + ast = StrengthReduction(BoundingBoxCostEstimator).visit(ast) + ast = SetSparsityPattern().visit(ast) + return ast + + return _run diff --git a/tests/pytest/test_aspp.py b/tests/pytest/test_aspp.py new file mode 100644 index 0000000..e80715b --- /dev/null +++ b/tests/pytest/test_aspp.py @@ -0,0 +1,192 @@ +""" +Tests for ``yateto.aspp`` - abstract sparsity patterns. + +Yateto propagates a sparsity pattern through every AST node, uses it to +perform strength reduction, and finally feeds it to the back-end so dense +GEMM calls can be specialised for zero-filled rows/columns. Two concrete +implementations exist: + +* ``dense`` - a lightweight, shape-only representation +* ``general`` - a numpy-backed bit pattern + +Mixed operations dispatch between them. 
+""" +from __future__ import annotations + +import numpy as np +import pytest + +from yateto import aspp + + +# --------------------------------------------------------------------------- +# dense +# --------------------------------------------------------------------------- + + +class TestDense: + def test_count_nonzero_is_size(self): + d = aspp.dense((3, 4)) + assert d.count_nonzero() == 12 + assert d.size == 12 + + def test_is_dense(self): + assert aspp.dense((3, 4)).is_dense() + + def test_shape(self): + d = aspp.dense((2, 3, 5)) + assert d.shape == (2, 3, 5) + assert d.ndim == 3 + + def test_reshape(self): + d = aspp.dense((2, 3)).reshape((6,)) + assert d.shape == (6,) + + def test_reshape_checks_size(self): + with pytest.raises(AssertionError, match="Size mismatch"): + aspp.dense((2, 3)).reshape((4,)) + + def test_transpose(self): + d = aspp.dense((2, 3)).transposed((1, 0)) + assert d.shape == (3, 2) + + def test_broadcast(self): + d = aspp.dense((2, 3)).broadcast((3, 2)) + # broadcast multiplies each dim by the factor + assert d.shape == (6, 6) + + def test_indexSum_drops_axes(self): + from yateto.ast.indices import Indices + src = Indices("ijk", (2, 3, 4)) + tgt = Indices("ik", (2, 4)) + d = aspp.dense((2, 3, 4)).indexSum(src, tgt) + assert d.shape == (2, 4) + + def test_add_same_shape(self): + result = aspp.dense.add(aspp.dense((2, 3)), aspp.dense((2, 3))) + assert result.shape == (2, 3) + assert result.is_dense() + + def test_add_shape_mismatch_asserts(self): + with pytest.raises(AssertionError): + aspp.dense.add(aspp.dense((2, 3)), aspp.dense((3, 2))) + + def test_einsum_shape_inference(self): + # Classic matmul: (i,j) * (j,k) -> (i,k) + result = aspp.dense.einsum("ij,jk->ik", aspp.dense((3, 4)), aspp.dense((4, 5))) + assert result.shape == (3, 5) + + def test_einsum_rejects_bad_description(self): + with pytest.raises(ValueError, match="not understood"): + aspp.dense.einsum("bogus", aspp.dense((2,)), aspp.dense((2,))) + + def 
test_as_ndarray_is_all_ones(self): + arr = aspp.dense((2, 3)).as_ndarray() + assert arr.shape == (2, 3) + assert arr.dtype == bool + assert arr.all() + + +# --------------------------------------------------------------------------- +# general +# --------------------------------------------------------------------------- + + +class TestGeneral: + def test_basic_count(self): + pattern = np.array([[1, 0], [0, 1]], dtype=bool) + g = aspp.general(pattern) + assert g.count_nonzero() == 2 + assert g.shape == (2, 2) + + def test_is_dense_only_if_fully_filled(self): + g_full = aspp.general(np.ones((2, 2), dtype=bool)) + assert g_full.is_dense() + g_sparse = aspp.general(np.eye(2, dtype=bool)) + assert not g_sparse.is_dense() + + def test_transpose(self): + pattern = np.array([[1, 0, 1], [0, 1, 0]], dtype=bool) + g = aspp.general(pattern).transposed((1, 0)) + # Transposing a 2x3 pattern should produce a 3x2 pattern that is + # element-wise consistent with np.transpose. + assert np.array_equal(g.as_ndarray(), pattern.T) + + def test_einsum_matches_numpy(self): + A = np.array([[1, 0], [1, 1]], dtype=bool) + B = np.array([[1, 1], [0, 1]], dtype=bool) + g = aspp.general.einsum("ij,jk->ik", aspp.general(A), aspp.general(B)) + # Note: numpy einsum on bool does a logical OR / AND, but here we + # compare against the cast-to-bool of a real matmul. + expected = (A.astype(int) @ B.astype(int)) > 0 + assert np.array_equal(g.as_ndarray(), expected) + + def test_nonzero(self): + pattern = np.array([[1, 0], [0, 1]], dtype=bool) + nz = aspp.general(pattern).nonzero() + # numpy's ``.nonzero()`` returns a tuple of arrays + assert len(nz) == 2 + np.testing.assert_array_equal(nz[0], [0, 1]) + np.testing.assert_array_equal(nz[1], [0, 1]) + + def test_nnzbounds_returns_inclusive_bounds_per_axis(self): + pattern = np.zeros((5, 5), dtype=bool) + pattern[1:4, 2:5] = True + g = aspp.general(pattern) + bounds = g.nnzbounds() + # Per-axis (min, max) inclusive bounds of nonzero entries. 
+ assert bounds == [(1, 3), (2, 4)] + + def test_copy_is_independent(self): + pattern = np.eye(3, dtype=bool) + g = aspp.general(pattern) + h = g.copy() + # Mutating the source must not affect the copy. + g.pattern[0, 0] = False + assert h.count_nonzero() == 3 + + +# --------------------------------------------------------------------------- +# Cross-class dispatch (dense/general mixed) +# --------------------------------------------------------------------------- + + +class TestDispatch: + def test_add_dense_and_general_promotes_to_general(self): + d = aspp.dense((2, 2)) + g = aspp.general(np.eye(2, dtype=bool)) + result = aspp.add(d, g) + assert isinstance(result, aspp.general) + # dense contributes "all ones", so the result is all ones + assert result.count_nonzero() == 4 + + def test_add_two_dense_stays_dense(self): + result = aspp.add(aspp.dense((2, 3)), aspp.dense((2, 3))) + assert isinstance(result, aspp.dense) + + def test_add_two_general(self): + a = aspp.general(np.array([[1, 0], [0, 0]], dtype=bool)) + b = aspp.general(np.array([[0, 1], [0, 0]], dtype=bool)) + result = aspp.add(a, b) + assert isinstance(result, aspp.general) + assert result.count_nonzero() == 2 + + def test_einsum_dispatches(self): + # Mixed types in an einsum must not crash - they route through + # ``dispatch`` which converts dense -> general on demand. + d = aspp.dense((3, 4)) + g = aspp.general(np.ones((4, 5), dtype=bool)) + result = aspp.einsum("ij,jk->ik", d, g) + assert result.shape == (3, 5) + + def test_array_equal_across_classes(self): + d = aspp.dense((2, 3)) + g = aspp.general(np.ones((2, 3), dtype=bool)) + assert aspp.array_equal(d, g) + # Different shape -> not equal. + assert not aspp.array_equal(d, aspp.dense((3, 2))) + + def test_array_equal_handles_none(self): + # Yateto occasionally compares ``None`` equivalents. 
+ assert aspp.array_equal(None, None) + assert not aspp.array_equal(None, aspp.dense((2, 2))) diff --git a/tests/pytest/test_ast_node.py b/tests/pytest/test_ast_node.py new file mode 100644 index 0000000..f681220 --- /dev/null +++ b/tests/pytest/test_ast_node.py @@ -0,0 +1,448 @@ +""" +Tests for ``yateto.ast.node`` - the core AST. + +The DSL turns Python expressions into a tree of ``Node`` subclasses using +operator overloading: + + C['ij'] <= A['ik'] * B['kj'] + +becomes + + Assign( + IndexedTensor(C, "ij"), + Einsum( + IndexedTensor(A, "ik"), + IndexedTensor(B, "kj"), + ), + ) + +This module checks that: + +* the DSL really produces the expected tree shape, +* the tree's invariants (no nested ``ScalarMultiplication``, ``Assign`` lhs + must be an ``IndexedTensor``, associative operators absorb their peers, ...) + are enforced, +* the per-node sparsity-pattern / flop-count helpers are correct, +* the specialised nodes used by the middle-end (``Product``, ``IndexSum``, + ``Contraction``, ``LoopOverGEMM``, ``FusedGEMMs``, ``SliceView``, + ``Permute``, ``Broadcast``) behave as advertised. 
+""" +from __future__ import annotations + +import pytest + +from yateto import Tensor +from yateto.ast.indices import Indices +from yateto.ast.node import ( + Add, + Assign, + BinOp, + Broadcast, + Contraction, + Einsum, + FusedGEMMs, + IndexedTensor, + IndexSum, + LoopOverGEMM, + Op, + Permute, + Product, + ScalarMultiplication, + SliceView, + UnaryOp, +) +from yateto.memory import DenseMemoryLayout +from yateto import aspp + + +# --------------------------------------------------------------------------- +# IndexedTensor - leaves of the tree +# --------------------------------------------------------------------------- + + +class TestIndexedTensor: + def test_construction_via_getitem(self): + A = Tensor("A", (3, 4)) + it = A["ij"] + assert isinstance(it, IndexedTensor) + assert it.tensor is A + assert str(it.indices) == "ij" + assert it.indices.shape() == (3, 4) + + def test_index_arity_must_match_tensor_rank(self): + A = Tensor("A", (3, 4)) + with pytest.raises(AssertionError): + A["ijk"] # tensor is rank 2 + + def test_name_delegates_to_tensor(self): + A = Tensor("A", (2, 2)) + assert A["ij"].name() == "A" + + def test_nonZeroFlops_is_zero(self): + # Reading a leaf costs nothing. + assert Tensor("A", (2, 2))["ij"].nonZeroFlops() == 0 + + +# --------------------------------------------------------------------------- +# Einsum - via ``*`` +# --------------------------------------------------------------------------- + + +class TestEinsumBuilding: + def test_mul_creates_einsum(self, square_tensors): + A, B = square_tensors["A"], square_tensors["B"] + expr = A["ik"] * B["kj"] + assert isinstance(expr, Einsum) + assert len(expr) == 2 + + def test_mul_is_left_associative_and_flattens(self, square_tensors): + # ``a * b * c`` should produce a single ``Einsum(a, b, c)`` node, + # not ``Einsum(Einsum(a,b), c)``. This is what ``_binOp`` guarantees. 
+ A, B, C = square_tensors["A"], square_tensors["B"], square_tensors["C"] + expr = A["ij"] * B["jk"] * C["kl"] + assert isinstance(expr, Einsum) + assert len(expr) == 3 + + def test_mul_flattens_right_side_too(self, square_tensors): + A, B, C = square_tensors["A"], square_tensors["B"], square_tensors["C"] + # (A) * (B * C) should also flatten. + right = B["jk"] * C["kl"] + expr = A["ij"] * right + assert isinstance(expr, Einsum) + assert len(expr) == 3 + + +# --------------------------------------------------------------------------- +# ScalarMultiplication - via ``*`` with a float/int +# --------------------------------------------------------------------------- + + +class TestScalarMultiplication: + def test_lhs_scalar(self, square_tensors): + A = square_tensors["A"] + expr = 2.0 * A["ij"] + assert isinstance(expr, ScalarMultiplication) + assert expr.is_constant() + assert expr.scalar() == 2.0 + + def test_rhs_scalar(self, square_tensors): + A = square_tensors["A"] + expr = A["ij"] * 2.0 + assert isinstance(expr, ScalarMultiplication) + assert expr.scalar() == 2.0 + + def test_negation(self, square_tensors): + A = square_tensors["A"] + expr = -A["ij"] + assert isinstance(expr, ScalarMultiplication) + assert expr.scalar() == -1.0 + + def test_nested_scalar_mul_rejected(self, square_tensors): + # ``k1 * (k2 * A)`` is disallowed by design - the user must + # pre-fold scalars into a single coefficient. This keeps the AST + # unambiguous and the code generator simple. + A = square_tensors["A"] + with pytest.raises(ValueError, match="Multiple multiplications"): + 2.0 * (3.0 * A["ij"]) + + def test_scalar_times_einsum_preserves_einsum_child(self, square_tensors): + A, B = square_tensors["A"], square_tensors["B"] + expr = 2.0 * (A["ik"] * B["kj"]) + assert isinstance(expr, ScalarMultiplication) + # The term inside is an Einsum, not a ScalarMultiplication. 
+ assert isinstance(expr.term(), Einsum) + + def test_nonZeroFlops_is_zero_for_pm_one(self, square_tensors, run_ast_pipeline): + A = square_tensors["A"] + B = square_tensors["B"] + expr = Assign(A["ij"], -B["ij"]) + ast = run_ast_pipeline(expr) + # Find the scalar-mul child (the rhs) and check its flops. + rhs = ast.rightTerm() + # ``-1.0`` is a free sign flip. + assert isinstance(rhs, ScalarMultiplication) + assert rhs.nonZeroFlops() == 0 + + +# --------------------------------------------------------------------------- +# Add - via ``+`` +# --------------------------------------------------------------------------- + + +class TestAddBuilding: + def test_add_creates_add_node(self, square_tensors): + A, B = square_tensors["A"], square_tensors["B"] + expr = A["ij"] + B["ij"] + assert isinstance(expr, Add) + + def test_add_flattens(self, square_tensors): + A, B, C = square_tensors["A"], square_tensors["B"], square_tensors["C"] + expr = A["ij"] + B["ij"] + C["ij"] + assert isinstance(expr, Add) + assert len(expr) == 3 + + def test_sub_via_neg(self, square_tensors): + A, B = square_tensors["A"], square_tensors["B"] + expr = A["ij"] - B["ij"] + # ``a - b`` == ``a + (-b)``, i.e. an Add with a ScalarMul(-1) child. 
+ assert isinstance(expr, Add) + assert isinstance(expr[1], ScalarMultiplication) + assert expr[1].scalar() == -1.0 + + def test_add_with_non_node_raises(self, square_tensors): + A = square_tensors["A"] + with pytest.raises(ValueError, match="Cannot add"): + A["ij"] + 5 + + +# --------------------------------------------------------------------------- +# Assign - via ``<=`` +# --------------------------------------------------------------------------- + + +class TestAssign: + def test_assign_builds_kernel(self, square_tensors): + A, B = square_tensors["A"], square_tensors["B"] + kernel = A["ij"] <= B["ij"] + assert isinstance(kernel, Assign) + assert len(kernel) == 2 + + def test_assign_lhs_must_be_indexed_tensor(self, square_tensors): + A, B, C = square_tensors["A"], square_tensors["B"], square_tensors["C"] + # The invariant ("first child of Assign must be an IndexedTensor") + # is enforced inside ``Assign.setChildren`` - i.e. when a later + # transformer pass rewrites the tree - not in the constructor. + # The DSL's ``__le__`` calls the constructor directly, so this + # expression is *accepted* at build time; the check fires only + # when a transformer tries to re-install children. 
+ bad = (A["ij"] * B["ij"]) <= C["ij"] + assert isinstance(bad, Assign) + with pytest.raises(ValueError, match="must be an IndexedTensor"): + bad.setChildren([A["ij"] * B["ij"], C["ij"]]) + + def test_assign_flops_are_zero(self, square_tensors): + A, B = square_tensors["A"], square_tensors["B"] + kernel = A["ij"] <= B["ij"] + assert kernel.nonZeroFlops() == 0 + + +# --------------------------------------------------------------------------- +# SliceView - ``A['ij'].subslice(...)`` / ``.subselect(...)`` +# --------------------------------------------------------------------------- + + +class TestSliceView: + def test_subslice_is_a_sliceview(self, square_tensors): + A = square_tensors["A"] + sv = A["ij"].subslice("i", 1, 4) + assert isinstance(sv, SliceView) + assert sv.index == "i" + assert sv.start == 1 + assert sv.end == 4 + + def test_subselect_is_single_index_slice(self, square_tensors): + A = square_tensors["A"] + sv = A["ij"].subselect("i", 2) + assert isinstance(sv, SliceView) + # subselect(i, 2) should cover [2, 3) + assert sv.start == 2 + assert sv.end == 3 + + def test_name_delegates_through_the_view(self, square_tensors): + A = square_tensors["A"] + assert A["ij"].subslice("i", 0, 3).name() == "A" + + def test_viewed_unwraps_to_indexed_tensor(self, square_tensors): + A = square_tensors["A"] + sv = A["ij"].subslice("i", 0, 3) + inner = sv.viewed() + assert isinstance(inner, IndexedTensor) + assert inner.tensor is A + + +# --------------------------------------------------------------------------- +# Product / IndexSum / Contraction - the "lowered" Einsum +# --------------------------------------------------------------------------- + + +class TestLoweredNodes: + """``StrengthReduction`` decomposes an ``Einsum`` into ``Product`` + + ``IndexSum``; ``FindContractions`` then fuses them into a ``Contraction``. + The tests below construct them directly to pin down their contracts. 
+ """ + + def test_product_merges_indices(self): + a = IndexedTensor(Tensor("A", (3, 4)), "ij") + b = IndexedTensor(Tensor("B", (4, 5)), "jk") + prod = Product(a, b) + # Product keeps every dimension, including the shared "j". + assert set(prod.indices) == {"i", "j", "k"} + + def test_product_rejects_mismatching_shared_dim(self): + a = IndexedTensor(Tensor("A", (3, 4)), "ij") + b = IndexedTensor(Tensor("B", (9, 5)), "jk") # j=9 vs j=4 + with pytest.raises(AssertionError): + Product(a, b) + + def test_indexsum_drops_one_index(self): + a = IndexedTensor(Tensor("A", (3, 4)), "ij") + s = IndexSum(a, "j") + assert str(s.indices) == "i" + # The stored sumIndex knows its size. + assert s.sumIndex().indexSize("j") == 4 + + def test_contraction_matmul(self): + a = IndexedTensor(Tensor("A", (3, 4)), "ij") + b = IndexedTensor(Tensor("B", (4, 5)), "jk") + c = Contraction( + indices=Indices("ik", (3, 5)), + lTerm=a, + rTerm=b, + sumIndices={"j"}, + ) + assert set(c.indices) == {"i", "k"} + assert c.sumIndices == {"j"} + + +# --------------------------------------------------------------------------- +# LoopOverGEMM - the node the Codegen actually emits +# --------------------------------------------------------------------------- + + +class TestLoopOverGEMM: + def _simple_gemm(self): + a = IndexedTensor(Tensor("A", (3, 4)), "ij") + b = IndexedTensor(Tensor("B", (4, 5)), "jk") + indices = Indices("ik", (3, 5)) + return LoopOverGEMM( + indices=indices, + aTerm=a, + bTerm=b, + m=Indices("i", (3,)), + n=Indices("k", (5,)), + k=Indices("j", (4,)), + ) + + def test_pure_gemm_detection(self): + log = self._simple_gemm() + assert log.is_pure_gemm() + + def test_non_pure_gemm_has_extra_dim(self): + # An outer-indexed third dimension breaks the ``is_pure_gemm`` test. 
+ a = IndexedTensor(Tensor("A", (3, 4, 2)), "ijl") + b = IndexedTensor(Tensor("B", (4, 5)), "jk") + indices = Indices("ikl", (3, 5, 2)) + log = LoopOverGEMM( + indices=indices, + aTerm=a, + bTerm=b, + m=Indices("i", (3,)), + n=Indices("k", (5,)), + k=Indices("j", (4,)), + ) + assert not log.is_pure_gemm() + + def test_trans_flags_for_simple_gemm(self): + # Layout: A is (i,j), B is (j,k), result is (i,k). Both operands + # have the GEMM-friendly order, so no transposes are needed. + log = self._simple_gemm() + assert log.transA() is False + assert log.transB() is False + + def test_trans_flag_when_k_precedes_m(self): + # A uses ``ji`` instead of ``ij`` - k-index precedes m-index, so + # LoG must request a transpose on A. + a = IndexedTensor(Tensor("A", (4, 3)), "ji") + b = IndexedTensor(Tensor("B", (4, 5)), "jk") + indices = Indices("ik", (3, 5)) + log = LoopOverGEMM( + indices=indices, + aTerm=a, + bTerm=b, + m=Indices("i", (3,)), + n=Indices("k", (5,)), + k=Indices("j", (4,)), + ) + assert log.transA() is True + assert log.transB() is False + + +# --------------------------------------------------------------------------- +# FusedGEMMs - list-like container of (pure GEMM) LoGs +# --------------------------------------------------------------------------- + + +class TestFusedGEMMs: + def _log(self): + a = IndexedTensor(Tensor("A", (3, 4)), "ij") + b = IndexedTensor(Tensor("B", (4, 5)), "jk") + return LoopOverGEMM( + indices=Indices("ik", (3, 5)), + aTerm=a, bTerm=b, + m=Indices("i", (3,)), + n=Indices("k", (5,)), + k=Indices("j", (4,)), + ) + + def test_is_empty_on_construction(self): + fg = FusedGEMMs() + assert fg.is_empty() + + def test_add_accepts_only_log(self): + fg = FusedGEMMs() + fg.add(self._log()) + assert not fg.is_empty() + assert len(fg.get_children()) == 1 + # Non-LoG child -> rejected. 
+ with pytest.raises(ValueError, match="expected LoopOverGEMM"): + fg.add(IndexedTensor(Tensor("x", (2, 2)), "ij")) + + +# --------------------------------------------------------------------------- +# Permute / Broadcast - required at LoG time to make indices line up +# --------------------------------------------------------------------------- + + +class TestPermuteBroadcast: + def test_permute_requires_same_index_set(self): + a = IndexedTensor(Tensor("A", (3, 4)), "ij") + # Permute cannot introduce or remove indices; it must be a pure reorder. + good = Permute(a, Indices("ji", (4, 3))) + assert set(good.indices) == {"i", "j"} + + def test_broadcast_adds_indices(self): + a = IndexedTensor(Tensor("A", (3,)), "i") + bcst = Broadcast(a, Indices("ij", (3, 4))) + assert set(bcst.indices) == {"i", "j"} + + def test_permute_nonZeroFlops_is_zero(self): + a = IndexedTensor(Tensor("A", (3, 4)), "ij") + p = Permute(a, Indices("ji", (4, 3))) + # Data-movement-only ops are free in the flop accounting. 
+ assert p.nonZeroFlops() == 0 + + +# --------------------------------------------------------------------------- +# Common node invariants via ABCs +# --------------------------------------------------------------------------- + + +class TestNodeAbstractInvariants: + def test_unaryop_term_is_first_child(self): + a = IndexedTensor(Tensor("A", (3, 4)), "ij") + s = IndexSum(a, "j") # a UnaryOp + assert s.term() is s[0] + assert isinstance(s, UnaryOp) + + def test_binop_left_and_right_term(self): + a = IndexedTensor(Tensor("A", (3, 4)), "ij") + b = IndexedTensor(Tensor("B", (4, 5)), "jk") + p = Product(a, b) # BinOp + assert p.leftTerm() is p[0] + assert p.rightTerm() is p[1] + assert isinstance(p, BinOp) + + def test_op_is_iterable_over_children(self, square_tensors): + A, B = square_tensors["A"], square_tensors["B"] + expr = A["ij"] + B["ij"] + assert list(expr) == [expr[0], expr[1]] diff --git a/tests/pytest/test_ast_transformer.py b/tests/pytest/test_ast_transformer.py new file mode 100644 index 0000000..3bb7cd7 --- /dev/null +++ b/tests/pytest/test_ast_transformer.py @@ -0,0 +1,300 @@ +""" +Tests for ``yateto.ast.transformer`` - the AST rewriting passes. + +Transformers are the compiler's middle-end. Each one takes an AST, +walks it, possibly substitutes nodes with different ones, and returns +the new root. They must run in a prescribed order - see +``generator.py::Kernel.prepareUntilCodeGen`` - and many of them only +make sense once the preceding ones have run. + +This file checks each pass in isolation (what invariant does it +establish? what does it assume?) plus a few multi-step compositions +that exercise the real ordering. 
+""" +from __future__ import annotations + +import numpy as np +import pytest + +from yateto import Tensor +from yateto.ast.cost import BoundingBoxCostEstimator, ShapeCostEstimator +from yateto.ast.indices import Indices +from yateto.ast.node import ( + Add, + Assign, + Contraction, + Einsum, + IndexedTensor, + IndexSum, + Product, + ScalarMultiplication, +) +from yateto.ast.transformer import ( + ComputeMemoryLayout, + DeduceIndices, + EquivalentSparsityPattern, + FindContractions, + ImplementContractions, + SetSparsityPattern, + StrengthReduction, +) +from yateto.ast.visitor import FindIndexPermutations +from yateto.ast.transformer import SelectIndexPermutations + + +# --------------------------------------------------------------------------- +# DeduceIndices - the mandatory first pass +# --------------------------------------------------------------------------- + + +class TestDeduceIndices: + def test_matmul_indices_are_deduced(self, square_tensors): + A, B, C = square_tensors["A"], square_tensors["B"], square_tensors["C"] + kernel = C["ij"] <= A["ik"] * B["kj"] + # Before the pass, non-leaf nodes have ``indices = None``. + assert kernel.indices is None + assert kernel.rightTerm().indices is None + + kernel = DeduceIndices().visit(kernel) + # Afterwards, indices are set all the way down. + assert str(kernel.indices) == "ij" + assert str(kernel.rightTerm().indices) == "ij" + + def test_add_requires_same_sizes(self): + # Two addends sharing letters but different sizes -> error. + A = Tensor("A", (3, 4)) + B = Tensor("B", (3, 5)) # j has size 5 vs A's j=4 + C = Tensor("C", (3, 4)) + with pytest.raises(AssertionError): + # mergeStrict inside DeduceIndices.visit_Add catches this. + DeduceIndices().visit(C["ij"] <= A["ij"] + B["ij"]) + + def test_lhs_rhs_index_mismatch_is_flagged(self, square_tensors): + # LHS asks for an index that doesn't appear on RHS. 
+ A, B = square_tensors["A"], square_tensors["B"] + kernel = A["ij"] <= B["jk"] + # The lhs index "i" never appears on the rhs (which only provides + # "j" and "k"), so DeduceIndices must reject the kernel. + with pytest.raises(ValueError): + DeduceIndices().visit(kernel) + + def test_unbound_indices_on_lhs_rejected(self, square_tensors): + # NOTE: despite the test name, this is the *accepted* case: an rhs-only + # index under an Einsum is a contraction - otherwise it's an error. + A = Tensor("A", (3,)) + B = Tensor("B", (3,)) + C = Tensor("C", ()) # scalar lhs + # C[''] = A['i'] * B['i'] -- a dot product: "i" is summed over by + # the Einsum, so DeduceIndices should accept it. + kernel = C[""] <= A["i"] * B["i"] + kernel = DeduceIndices().visit(kernel) + assert str(kernel.indices) == "" + + def test_target_indices_force_permutation(self, square_tensors): + # Passing an explicit target permutes the root. + A, B = square_tensors["A"], square_tensors["B"] + kernel = A["ji"] <= B["ij"] + kernel = DeduceIndices("ji").visit(kernel) + assert str(kernel.indices) == "ji" + + +# --------------------------------------------------------------------------- +# EquivalentSparsityPattern - computes every node's eqspp +# --------------------------------------------------------------------------- + + +class TestEquivalentSparsityPattern: + def test_sets_eqspp_on_every_node(self, deduced, square_tensors): + A, B, C = square_tensors["A"], square_tensors["B"], square_tensors["C"] + kernel = C["ij"] <= A["ik"] * B["kj"] + kernel = deduced(kernel) + # Before the pass, eqspp is None. + assert kernel.eqspp() is None + kernel = EquivalentSparsityPattern().visit(kernel) + # After the pass, every node - including the leaves - has an eqspp. 
+ assert kernel.eqspp() is not None + assert kernel.rightTerm().eqspp() is not None + for child in kernel.rightTerm(): + assert child.eqspp() is not None + + def test_dense_matmul_eqspp_is_dense(self, deduced, square_tensors): + A, B, C = square_tensors["A"], square_tensors["B"], square_tensors["C"] + kernel = C["ij"] <= A["ik"] * B["kj"] + kernel = deduced(kernel) + kernel = EquivalentSparsityPattern().visit(kernel) + assert kernel.eqspp().is_dense() + + def test_sparse_matmul_shrinks_eqspp(self, deduced): + # Matmul of two diagonal matrices has a diagonal pattern. + diag = np.eye(4, dtype=bool) + A = Tensor("A", (4, 4), spp=diag) + B = Tensor("B", (4, 4), spp=diag) + C = Tensor("C", (4, 4)) + kernel = C["ij"] <= A["ik"] * B["kj"] + kernel = deduced(kernel) + kernel = EquivalentSparsityPattern().visit(kernel) + # D * D = D (diagonal) -> 4 nonzeros, not 16. + assert kernel.eqspp().count_nonzero() == 4 + + +# --------------------------------------------------------------------------- +# StrengthReduction - lowers Einsum to Product(...) / IndexSum(...) +# --------------------------------------------------------------------------- + + +class TestStrengthReduction: + def test_einsum_disappears(self, deduced, square_tensors): + A, B, C = square_tensors["A"], square_tensors["B"], square_tensors["C"] + kernel = C["ij"] <= A["ik"] * B["kj"] + kernel = deduced(kernel) + kernel = EquivalentSparsityPattern().visit(kernel) + kernel = StrengthReduction(BoundingBoxCostEstimator).visit(kernel) + + # Einsum is gone; the rhs is now a Product-under-IndexSum tree. + def has(cls, node): + if isinstance(node, cls): + return True + return any(has(cls, c) for c in node) + + assert not has(Einsum, kernel) + assert has(Product, kernel) + assert has(IndexSum, kernel) + + def test_costEstimator_is_a_class_not_an_instance(self, deduced, square_tensors): + # ``StrengthReduction`` calls ``self._costEstimator()`` internally + # - i.e. it instantiates a fresh estimator per Einsum node. 
This + # is an easy trap when users pass an already-constructed one. + A, B, C = square_tensors["A"], square_tensors["B"], square_tensors["C"] + kernel = deduced(C["ij"] <= A["ik"] * B["kj"]) + kernel = EquivalentSparsityPattern().visit(kernel) + # Passing an instance fails with TypeError (not callable). + with pytest.raises(TypeError): + StrengthReduction(BoundingBoxCostEstimator()).visit(kernel) + + def test_ternary_product_is_strength_reduced(self, deduced): + # a * b * c should be ordered by the cost estimator, not flattened + # into a single big product. The exact order depends on sizes. + A = Tensor("A", (2, 3)) + B = Tensor("B", (3, 4)) + C = Tensor("C", (4, 5)) + D = Tensor("D", (2, 5)) + kernel = D["il"] <= A["ij"] * B["jk"] * C["kl"] + kernel = deduced(kernel) + kernel = EquivalentSparsityPattern().visit(kernel) + kernel = StrengthReduction(BoundingBoxCostEstimator).visit(kernel) + # The result should be a binary tree of Products + IndexSums + # (i.e. the Einsum has been fully split into pairwise products). + from yateto.ast.node import Einsum as _E + def has_einsum(node): + if isinstance(node, _E): + return True + return any(has_einsum(c) for c in node) + assert not has_einsum(kernel) + + +# --------------------------------------------------------------------------- +# FindContractions - fuses Product+IndexSum into a single Contraction node +# --------------------------------------------------------------------------- + + +class TestFindContractions: + def test_matmul_becomes_contraction(self, deduced, square_tensors): + A, B, C = square_tensors["A"], square_tensors["B"], square_tensors["C"] + kernel = C["ij"] <= A["ik"] * B["kj"] + kernel = deduced(kernel) + kernel = EquivalentSparsityPattern().visit(kernel) + kernel = StrengthReduction(BoundingBoxCostEstimator).visit(kernel) + kernel = FindContractions().visit(kernel) + + # The rhs should now be a single Contraction node with sumIndices={"k"}. 
+ rhs = kernel.rightTerm() + assert isinstance(rhs, Contraction) + assert rhs.sumIndices == {"k"} + assert set(rhs.indices) == {"i", "j"} + + def test_dot_product_becomes_contraction(self, deduced): + # Vector dot-product: scalar = sum_i A[i] * B[i] + A = Tensor("A", (5,)) + B = Tensor("B", (5,)) + C = Tensor("C", ()) + kernel = C[""] <= A["i"] * B["i"] + kernel = deduced(kernel) + kernel = EquivalentSparsityPattern().visit(kernel) + kernel = StrengthReduction(BoundingBoxCostEstimator).visit(kernel) + kernel = FindContractions().visit(kernel) + rhs = kernel.rightTerm() + assert isinstance(rhs, Contraction) + assert rhs.sumIndices == {"i"} + + +# --------------------------------------------------------------------------- +# SetSparsityPattern +# --------------------------------------------------------------------------- + + +class TestSetSparsityPattern: + def test_populates_eqspp_bottom_up(self, deduced, square_tensors): + A, B, C = square_tensors["A"], square_tensors["B"], square_tensors["C"] + kernel = C["ij"] <= A["ik"] * B["kj"] + kernel = deduced(kernel) + kernel = EquivalentSparsityPattern().visit(kernel) + kernel = StrengthReduction(BoundingBoxCostEstimator).visit(kernel) + # SetSparsityPattern re-computes eqspps using the concrete tree + # (rather than equivalent patterns). After this, nonZeroFlops + # must not crash. + kernel = SetSparsityPattern().visit(kernel) + assert kernel.eqspp() is not None + + def test_indexed_tensor_eqspp_is_preserved(self, square_tensors): + # SetSparsityPattern treats IndexedTensor as a no-op and does not + # overwrite any existing eqspp. 
+ A = Tensor("A", (3, 3)) + it = IndexedTensor(A, "ij") + from yateto.aspp import dense + it.setEqspp(dense((3, 3))) + SetSparsityPattern().visit_IndexedTensor(it) + assert it.eqspp() is not None + + +# --------------------------------------------------------------------------- +# ComputeMemoryLayout +# --------------------------------------------------------------------------- + + +class TestComputeMemoryLayout: + def test_assigns_memory_layouts(self, deduced, square_tensors, arch): + A, B, C = square_tensors["A"], square_tensors["B"], square_tensors["C"] + kernel = C["ij"] <= A["ik"] * B["kj"] + kernel = deduced(kernel) + kernel = EquivalentSparsityPattern().visit(kernel) + kernel = StrengthReduction(BoundingBoxCostEstimator).visit(kernel) + kernel = FindContractions().visit(kernel) + kernel = ComputeMemoryLayout().visit(kernel) + # Every op node now has a memory layout. + assert kernel.rightTerm().memoryLayout() is not None + + +# --------------------------------------------------------------------------- +# SelectIndexPermutations + ImplementContractions -> LoopOverGEMM +# --------------------------------------------------------------------------- + + +class TestLowerToLoG: + def test_contraction_becomes_log(self, deduced, square_tensors, arch): + from yateto.ast.node import LoopOverGEMM + A, B, C = square_tensors["A"], square_tensors["B"], square_tensors["C"] + kernel = C["ij"] <= A["ik"] * B["kj"] + kernel = deduced(kernel) + kernel = EquivalentSparsityPattern().visit(kernel) + kernel = StrengthReduction(BoundingBoxCostEstimator).visit(kernel) + kernel = FindContractions().visit(kernel) + kernel = ComputeMemoryLayout().visit(kernel) + variants = FindIndexPermutations().visit(kernel) + kernel = SelectIndexPermutations(variants).visit(kernel) + kernel = ImplementContractions().visit(kernel) + + # The Contraction node is replaced with a LoopOverGEMM node. 
+ rhs = kernel.rightTerm() + assert isinstance(rhs, LoopOverGEMM) + # The LoG knows which m/n/k dimensions it covers. + assert rhs.is_pure_gemm() diff --git a/tests/pytest/test_ast_visitor.py b/tests/pytest/test_ast_visitor.py new file mode 100644 index 0000000..7e8fb50 --- /dev/null +++ b/tests/pytest/test_ast_visitor.py @@ -0,0 +1,273 @@ +""" +Tests for ``yateto.ast.visitor`` - the read-only AST walkers. + +Visitors implement the classic GoF visitor pattern (``visit_{NodeClass}`` +dispatch). They are the analysis phase of the compiler: ``PrettyPrinter`` +for debugging, ``FindTensors`` / ``FindIndexPermutations`` for metadata +collection, ``ComputeSparsityPattern`` / ``ComputeOptimalFlopCount`` for +numerical-property analysis, ``ComputeIndexSet`` / ``ComputeConstantExpression`` +for the front-end. +""" +from __future__ import annotations + +import numpy as np +import pytest + +from yateto import Tensor +from yateto.ast.indices import Indices +from yateto.ast.node import Add, Assign, Einsum, IndexedTensor, Product +from yateto.ast.transformer import DeduceIndices +from yateto.ast.visitor import ( + CachedVisitor, + ComputeConstantExpression, + ComputeIndexSet, + ComputeOptimalFlopCount, + ComputeSparsityPattern, + FindTensors, + PrettyPrinter, + Visitor, +) + + +# --------------------------------------------------------------------------- +# The base ``Visitor`` class +# --------------------------------------------------------------------------- + + +class TestVisitorDispatch: + """Visitor dispatches on ``visit_{NodeClass}``, falling back to + ``generic_visit`` when no specialised method exists. Custom visitors + rely on this being rock solid.
+ """ + + def test_dispatches_on_node_class_name(self, square_tensors): + A = square_tensors["A"] + seen = [] + + class Recorder(Visitor): + def visit_IndexedTensor(self, node): + seen.append(("IT", node.name())) + + def visit_Assign(self, node): + seen.append(("Assign",)) + self.generic_visit(node) + + kernel = A["ij"] <= A["ij"] + Recorder().visit(kernel) + # First the Assign, then both IndexedTensor children. + assert seen[0] == ("Assign",) + assert ("IT", "A") in seen + + def test_generic_visit_recurses_into_children(self, square_tensors): + A, B = square_tensors["A"], square_tensors["B"] + depth = {"max": 0, "cur": 0} + + class DepthProbe(Visitor): + def generic_visit(self, node): + depth["cur"] += 1 + depth["max"] = max(depth["max"], depth["cur"]) + for c in node: + self.visit(c) + depth["cur"] -= 1 + + DepthProbe().visit(A["ij"] + B["ij"]) + # Add wraps two IndexedTensor leaves -> max depth 2. + assert depth["max"] == 2 + + def test_cached_visitor_reuses_results(self, square_tensors): + A = square_tensors["A"] + calls = [0] + + class CachedCounter(CachedVisitor): + def generic_visit(self, node): + calls[0] += 1 + return id(node) + + v = CachedCounter() + kernel = A["ij"] <= A["ij"] + first = v.visit(kernel) + n_first = calls[0] + # Second visit on the same node returns the cached result without + # re-running ``generic_visit``. + second = v.visit(kernel) + assert second == first + assert calls[0] == n_first + + +# --------------------------------------------------------------------------- +# PrettyPrinter +# --------------------------------------------------------------------------- + + +class TestPrettyPrinter: + def test_prints_tree(self, capsys, square_tensors): + A, B = square_tensors["A"], square_tensors["B"] + kernel = A["ij"] <= B["ij"] + PrettyPrinter().visit(kernel) + out = capsys.readouterr().out + # The tree string contains the Assign root and both children. 
+ assert "Assign" in out + assert "A[ij]" in out + assert "B[ij]" in out + + +# --------------------------------------------------------------------------- +# FindTensors +# --------------------------------------------------------------------------- + + +class TestFindTensors: + def test_collects_leaf_tensors_by_name(self, deduced, square_tensors): + A, B, C = square_tensors["A"], square_tensors["B"], square_tensors["C"] + kernel = C["ij"] <= A["ik"] * B["kj"] + kernel = deduced(kernel) + tensors = FindTensors().visit(kernel) + assert set(tensors.keys()) == {"A", "B", "C"} + # Every value is the original ``Tensor`` object. + assert tensors["A"] is A + assert tensors["B"] is B + assert tensors["C"] is C + + def test_skips_temporary_tensors(self, deduced): + # Temporary tensors must not be exposed as external kernel + # parameters. ``FindTensors`` filters them out. + A = Tensor("A", (3, 3)) + tmp = Tensor("tmp", (3, 3), temporary=True) + kernel = tmp["ij"] <= A["ij"] + kernel = deduced(kernel) + tensors = FindTensors().visit(kernel) + assert "tmp" not in tensors + assert "A" in tensors + + +# --------------------------------------------------------------------------- +# ComputeIndexSet +# --------------------------------------------------------------------------- + + +class TestComputeIndexSet: + def test_union_of_all_child_indices(self, square_tensors): + A, B = square_tensors["A"], square_tensors["B"] + # Two operands with partially overlapping indices. + tree = A["ij"] * B["jk"] + # ComputeIndexSet returns the union. 
+ assert ComputeIndexSet().visit(tree) == {"i", "j", "k"} + + def test_indexed_tensor_is_leaf(self): + it = IndexedTensor(Tensor("A", (3, 4)), "ij") + assert ComputeIndexSet().visit(it) == {"i", "j"} + + +# --------------------------------------------------------------------------- +# ComputeSparsityPattern +# --------------------------------------------------------------------------- + + +class TestComputeSparsityPattern: + """``ComputeSparsityPattern`` is a read-only walker that re-runs each + node's ``computeSparsityPattern`` method. Crucially, it does *not* + work on raw ``Einsum`` nodes - ``Einsum.computeSparsityPattern`` + explicitly raises ``NotImplementedError``. The compiler pipeline + first runs ``EquivalentSparsityPattern`` and then lowers ``Einsum`` to + ``Product`` + ``IndexSum`` via ``StrengthReduction``, which is what + the ``run_ast_pipeline`` fixture does. + """ + + def test_raw_einsum_cannot_be_evaluated(self, deduced, square_tensors): + # Locking this in: ``ComputeSparsityPattern`` on a fresh DSL tree + # crashes because ``Einsum`` refuses to compute its pattern + # on its own. The fix is to run ``EquivalentSparsityPattern`` and + # ``StrengthReduction`` first (the latter lowers Einsum to Product + IndexSum). + A, B, C = square_tensors["A"], square_tensors["B"], square_tensors["C"] + kernel = C["ij"] <= A["ik"] * B["kj"] + kernel = deduced(kernel) + with pytest.raises(NotImplementedError): + ComputeSparsityPattern(useAvailable=False).visit(kernel) + + def test_dense_matmul_produces_dense_result(self, run_ast_pipeline, square_tensors): + A, B, C = square_tensors["A"], square_tensors["B"], square_tensors["C"] + kernel = C["ij"] <= A["ik"] * B["kj"] + kernel = run_ast_pipeline(kernel) + spp = ComputeSparsityPattern(useAvailable=True).visit(kernel) + # Dense inputs -> dense output.
+ assert spp.is_dense() + assert spp.shape == (8, 8) + + def test_sparse_input_shrinks_product(self, run_ast_pipeline): + # Sparse A (only diagonal) times dense B should produce a pattern + # that the strength reducer can exploit. The diagonal-of-A times + # dense-B case activates all rows of the result, so the pattern + # is still fully dense - but the number of *operations* goes + # down, which we exercise in the flop-count tests. + diag = np.eye(4, dtype=bool) + A = Tensor("A", (4, 4), spp=diag) + B = Tensor("B", (4, 4)) + C = Tensor("C", (4, 4)) + kernel = C["ij"] <= A["ik"] * B["kj"] + kernel = run_ast_pipeline(kernel) + spp = ComputeSparsityPattern(useAvailable=True).visit(kernel) + assert spp.count_nonzero() == 16 + + +# --------------------------------------------------------------------------- +# ComputeOptimalFlopCount +# --------------------------------------------------------------------------- + + +class TestComputeOptimalFlopCount: + def test_matmul_8x8_count(self, square_tensors, run_ast_pipeline): + A, B, C = square_tensors["A"], square_tensors["B"], square_tensors["C"] + kernel = C["ij"] <= A["ik"] * B["kj"] + kernel = run_ast_pipeline(kernel) + # 8x8 matmul decomposes into a Product (for each output) plus + # IndexSum over k. After strength reduction the flop count is + # well-defined: 960 for a dense 8x8x8 GEMM in Yateto's accounting. + # + # The exact number is a regression-lock - anything that changes + # it silently is a red flag. + assert ComputeOptimalFlopCount().visit(kernel) == 960 + + def test_add_flops_is_n_for_nxn(self, square_tensors, run_ast_pipeline): + # A 2D elementwise add of NxN matrices costs N*N additions. 
+ A, B, C = square_tensors["A"], square_tensors["B"], square_tensors["C"] + kernel = C["ij"] <= A["ij"] + B["ij"] + kernel = run_ast_pipeline(kernel) + assert ComputeOptimalFlopCount().visit(kernel) == 64 # 8*8 + + def test_leaf_is_zero(self): + it = IndexedTensor(Tensor("A", (3, 3)), "ij") + assert ComputeOptimalFlopCount().visit(it) == 0 + + +# --------------------------------------------------------------------------- +# ComputeConstantExpression - reference evaluation (for compile-time +# constant folding or unit testing) +# --------------------------------------------------------------------------- + + +class TestComputeConstantExpression: + def test_matmul_of_constants(self, deduced): + # Both tensors have concrete values -> the evaluator should + # produce the actual matrix product. + vals_a = {(0, 0): 1.0, (0, 1): 2.0, (1, 0): 3.0, (1, 1): 4.0} + vals_b = {(0, 0): 5.0, (0, 1): 6.0, (1, 0): 7.0, (1, 1): 8.0} + A = Tensor("A", (2, 2), spp=vals_a) + B = Tensor("B", (2, 2), spp=vals_b) + C = Tensor("C", (2, 2)) + + # ComputeConstantExpression walks the *rhs* only. + rhs = A["ik"] * B["kj"] + # Must run DeduceIndices first so that Einsum.indices is known. + rhs = DeduceIndices("ij").visit(rhs) + result = ComputeConstantExpression().visit(rhs) + + expected = np.array([[1, 2], [3, 4]]) @ np.array([[5, 6], [7, 8]]) + np.testing.assert_array_equal(result, expected) + + def test_evaluator_requires_constant_tensors(self, deduced): + # A plain tensor (no values) cannot be evaluated numerically. 
+ A = Tensor("A", (2, 2)) + B = Tensor("B", (2, 2), spp={(0, 0): 1.0, (1, 1): 1.0}) + rhs = DeduceIndices("ij").visit(A["ij"] + B["ij"]) + with pytest.raises(AssertionError, match="constant"): + ComputeConstantExpression().visit(rhs) diff --git a/tests/pytest/test_controlflow.py b/tests/pytest/test_controlflow.py new file mode 100644 index 0000000..eef0a0e --- /dev/null +++ b/tests/pytest/test_controlflow.py @@ -0,0 +1,293 @@ +""" +Tests for ``yateto.controlflow`` - the mini IR between the AST and the +emitted C++. + +After strength reduction and ``ImplementContractions``, the AST is +flattened into a straight-line control-flow graph (no loops / branches +at this level - the DSL doesn't have them). Each program point carries +a ``ProgramAction`` of the shape + + result [+]= [scalar *] term + +where ``term`` is either a single ``Variable`` or an ``Expression`` +(a LoopOverGEMM, a Permute, a Broadcast, ...). Subsequent CFG-level +passes do classic compiler things: liveness analysis, copy +propagation, dead-store elimination, action merging. + +These tests check: + +* ``AST2ControlFlow`` really emits a linear CFG and introduces fresh + temporaries for each intermediate result, +* ``LivenessAnalysis`` annotates every program point with a correct + ``live`` set, +* ``SubstituteForward`` / ``SubstituteBackward`` eliminate trivial + copies, +* ``RemoveEmptyStatements`` drops ``x = x`` lines, +* ``MergeActions`` fuses compatible actions. 
+""" +from __future__ import annotations + +import pytest + +from yateto import Tensor +from yateto.arch import useArchitectureIdentifiedBy +from yateto.ast.cost import BoundingBoxCostEstimator +from yateto.ast.transformer import ( + ComputeMemoryLayout, + DeduceIndices, + EquivalentSparsityPattern, + FindContractions, + ImplementContractions, + SetSparsityPattern, + StrengthReduction, +) +from yateto.ast.visitor import FindIndexPermutations +from yateto.ast.transformer import SelectIndexPermutations +from yateto.controlflow.graph import ( + Expression, + FusedActions, + ProgramAction, + ProgramPoint, + Variable, + VariableView, +) +from yateto.controlflow.transformer import ( + LivenessAnalysis, + MergeActions, + MergeScalarMultiplications, + RemoveEmptyStatements, + SubstituteBackward, + SubstituteForward, +) +from yateto.controlflow.visitor import AST2ControlFlow + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _lower_to_cfg(kernel, arch): + """Run the pipeline up to and including AST2ControlFlow. Returns + the kernel (post-AST passes) and its CFG. 
+ """ + kernel = DeduceIndices().visit(kernel) + kernel = EquivalentSparsityPattern().visit(kernel) + kernel = StrengthReduction(BoundingBoxCostEstimator).visit(kernel) + kernel = FindContractions().visit(kernel) + kernel = ComputeMemoryLayout().visit(kernel) + variants = FindIndexPermutations().visit(kernel) + kernel = SelectIndexPermutations(variants).visit(kernel) + kernel = ImplementContractions().visit(kernel) + kernel = SetSparsityPattern().visit(kernel) + + conv = AST2ControlFlow() + conv.visit(kernel) + cfg = conv.cfg() + return kernel, cfg + + +def _live_var_names(live): + """Return the variable names inside a live set, irrespective of + whether ``live`` is a Python ``set`` (master) or a ``LiveSet`` + wrapper (nonlinearity branch).""" + if hasattr(live, "variables"): + return {v.name for v in live.variables()} + return {v.name for v in live} + + +# --------------------------------------------------------------------------- +# Variable +# --------------------------------------------------------------------------- + + +class TestVariable: + def test_globality_follows_tensor(self, arch): + from yateto.memory import DenseMemoryLayout + # A variable with a non-temporary tensor is global. + T = Tensor("A", (3, 3)) + ml = DenseMemoryLayout((3, 3)) + v = Variable("A", writable=False, memoryLayout=ml, tensor=T) + assert v.isGlobal() + assert not v.isLocal() + + def test_pure_temporary_is_local(self, arch): + from yateto.memory import DenseMemoryLayout + ml = DenseMemoryLayout((3, 3)) + v = Variable("_tmp0", writable=True, memoryLayout=ml, is_temporary=True) + assert v.isLocal() + assert not v.isGlobal() + + def test_hash_is_by_name(self, arch): + from yateto.memory import DenseMemoryLayout + ml = DenseMemoryLayout((3, 3)) + a = Variable("X", True, ml) + b = Variable("X", True, ml) + # Same name -> same hash, insertable into a set without dups. 
+ s = {a, b} + assert len(s) == 1 + + def test_set_writable_only_matches_by_name(self, arch): + from yateto.memory import DenseMemoryLayout + ml = DenseMemoryLayout((3, 3)) + v = Variable("X", False, ml) + v.setWritable("Y") + assert v.writable is False + v.setWritable("X") + assert v.writable is True + + +# --------------------------------------------------------------------------- +# AST2ControlFlow - smoke test on a matmul +# --------------------------------------------------------------------------- + + +class TestAST2ControlFlow: + def test_produces_linear_cfg(self, arch): + A = Tensor("A", (8, 8)) + B = Tensor("B", (8, 8)) + C = Tensor("C", (8, 8)) + kernel = C["ij"] <= A["ik"] * B["kj"] + _, cfg = _lower_to_cfg(kernel, arch) + + # Every program point has no branching structure - it's a straight + # list (plus a terminating sentinel with ``action=None``). + assert all(isinstance(pp, ProgramPoint) for pp in cfg) + assert cfg[-1].action is None # sentinel + # There must be at least one action. + assert any(pp.action is not None for pp in cfg) + + def test_has_action_with_result_and_term(self, arch): + A = Tensor("A", (8, 8)) + B = Tensor("B", (8, 8)) + C = Tensor("C", (8, 8)) + kernel = C["ij"] <= A["ik"] * B["kj"] + _, cfg = _lower_to_cfg(kernel, arch) + + action = next(pp.action for pp in cfg if pp.action is not None) + assert action.result is not None + assert action.term is not None + + def test_temporary_names_are_unique(self, arch): + # Each _tmp name should appear exactly once as a result. 
+ A = Tensor("A", (4, 4)) + B = Tensor("B", (4, 4)) + C = Tensor("C", (4, 4)) + D = Tensor("D", (4, 4)) + kernel = D["il"] <= A["ij"] * B["jk"] * C["kl"] + _, cfg = _lower_to_cfg(kernel, arch) + + tmp_results = [pp.action.result.name for pp in cfg + if pp.action is not None + and pp.action.result.name.startswith("_tmp")] + assert len(tmp_results) == len(set(tmp_results)) + + +# --------------------------------------------------------------------------- +# LivenessAnalysis +# --------------------------------------------------------------------------- + + +class TestLivenessAnalysis: + def test_annotates_every_program_point(self, arch): + A = Tensor("A", (4, 4)) + B = Tensor("B", (4, 4)) + C = Tensor("C", (4, 4)) + kernel = C["ij"] <= A["ik"] * B["kj"] + _, cfg = _lower_to_cfg(kernel, arch) + + # Before: live sets are None. + assert all(pp.live is None for pp in cfg) + cfg = LivenessAnalysis().visit(cfg) + # After: every program point has a live set (possibly empty). + assert all(pp.live is not None for pp in cfg) + + def test_sentinel_has_empty_live_set(self, arch): + A = Tensor("A", (4, 4)) + B = Tensor("B", (4, 4)) + C = Tensor("C", (4, 4)) + kernel = C["ij"] <= A["ik"] * B["kj"] + _, cfg = _lower_to_cfg(kernel, arch) + cfg = LivenessAnalysis().visit(cfg) + # Past the last action, nothing should be live - otherwise the + # kernel would leak. ``_live_var_names`` works both for plain + # Python sets (master) and for the ``LiveSet`` wrapper + # (nonlinearity). + assert _live_var_names(cfg[-1].live) == set() + + def test_inputs_are_live_at_first_use(self, arch): + # A + B: at least one of the two inputs must already be live at the + # first real action, since the kernel reads both of them.
+ A = Tensor("A", (4, 4)) + B = Tensor("B", (4, 4)) + C = Tensor("C", (4, 4)) + kernel = C["ij"] <= A["ij"] + B["ij"] + _, cfg = _lower_to_cfg(kernel, arch) + cfg = LivenessAnalysis().visit(cfg) + + first = next(pp for pp in cfg if pp.action is not None) + live_vars = _live_var_names(first.live) + # At least one of A/B is live at the first action. + assert "A" in live_vars or "B" in live_vars + + +# --------------------------------------------------------------------------- +# SubstituteForward / Backward / RemoveEmptyStatements +# --------------------------------------------------------------------------- + + +class TestCopyPropagation: + def test_pipeline_shrinks_cfg(self, arch): + # A simple identity assign ``C = A`` should be reduced aggressively + # by the CFG passes (the intermediate _tmp variables get folded). + A = Tensor("A", (4, 4)) + C = Tensor("C", (4, 4)) + kernel = C["ij"] <= A["ij"] + _, cfg = _lower_to_cfg(kernel, arch) + cfg = LivenessAnalysis().visit(cfg) + + before = sum(1 for pp in cfg if pp.action is not None) + cfg = SubstituteForward().visit(cfg) + cfg = SubstituteBackward().visit(cfg) + cfg = RemoveEmptyStatements().visit(cfg) + after = sum(1 for pp in cfg if pp.action is not None) + # The pipeline must not grow the CFG. It usually shrinks it. + assert after <= before + + +# --------------------------------------------------------------------------- +# MergeActions +# --------------------------------------------------------------------------- + + +class TestMergeActions: + def test_returns_cfg_with_liveness(self, arch): + A = Tensor("A", (4, 4)) + B = Tensor("B", (4, 4)) + C = Tensor("C", (4, 4)) + kernel = C["ij"] <= A["ik"] * B["kj"] + _, cfg = _lower_to_cfg(kernel, arch) + cfg = LivenessAnalysis().visit(cfg) + cfg = MergeActions().visit(cfg) + # After merging, liveness must still be up to date. 
+ assert all(pp.live is not None for pp in cfg) + + +class TestFusedActions: + def test_is_empty_on_construction(self): + fa = FusedActions() + assert fa.is_empty() + + def test_add_rejects_non_log_term(self, arch): + from yateto.memory import DenseMemoryLayout + ml = DenseMemoryLayout((3, 3)) + # Construct a trivially non-LoG ProgramAction to check the guard. + result = Variable("R", True, ml) + term = Variable("X", False, ml) # a plain Variable, not Expression + action = ProgramAction(result, term, add=False) + + fa = FusedActions() + # The term is not an Expression at all (it's a Variable), so it has + # no ``node`` attribute and ``add`` is expected to raise AttributeError. + with pytest.raises(AttributeError): + fa.add(action) diff --git a/tests/pytest/test_cost.py b/tests/pytest/test_cost.py new file mode 100644 index 0000000..e7996e8 --- /dev/null +++ b/tests/pytest/test_cost.py @@ -0,0 +1,182 @@ +""" +Tests for ``yateto.ast.cost`` - the cost model that drives strength +reduction. + +There are three concrete estimators: + +* ``ShapeCostEstimator`` - counts flops using the declared shape + (dense upper bound). +* ``BoundingBoxCostEstimator`` - counts flops using the nonzero + bounding box (caches per node). +* ``ExactCost`` - counts flops using the actual + equivalent sparsity pattern (most + accurate, also most expensive). + +All estimators subclass ``CostEstimator`` and implement the +``estimate_{NodeClass}`` dispatch pattern. Users of the API typically +hand a *class* (not an instance) to the generator, which then +instantiates a fresh estimator per AST.
+""" +from __future__ import annotations + +import numpy as np +import pytest + +from yateto import Tensor +from yateto.ast.cost import ( + BoundingBoxCostEstimator, + CachedCostEstimator, + CostEstimator, + ExactCost, + ShapeCostEstimator, +) +from yateto.ast.node import IndexedTensor, IndexSum, Product +from yateto.ast.transformer import ( + DeduceIndices, + EquivalentSparsityPattern, + FindContractions, + SetSparsityPattern, + StrengthReduction, +) + + +def _lower_to_product_tree(kernel, estimator_cls=BoundingBoxCostEstimator): + """Helper: run the minimal set of passes so Product / IndexSum nodes + exist and eqspps are set. + """ + kernel = DeduceIndices().visit(kernel) + kernel = EquivalentSparsityPattern().visit(kernel) + kernel = StrengthReduction(estimator_cls).visit(kernel) + kernel = SetSparsityPattern().visit(kernel) + return kernel + + +# --------------------------------------------------------------------------- +# ShapeCostEstimator +# --------------------------------------------------------------------------- + + +class TestShapeCostEstimator: + def test_generic_node_is_free(self): + # Leaves and unopinionated nodes contribute nothing. + it = IndexedTensor(Tensor("A", (3, 4)), "ij") + assert ShapeCostEstimator().generic_estimate(it) == 0 + + def test_matmul_cost_is_ijk(self, square_tensors): + # For an 8x8 matmul built of a Product node (shape i,j,k) + an + # IndexSum over k, the cost is shape-based: + # Product: 8 * 8 * 8 = 512 + # IndexSum: (k-1) * i * j = 7 * 8 * 8 = 448 + # total = 960. Same as ComputeOptimalFlopCount's answer in test_ast_visitor.py + # (the cost model agrees with the flop counter for dense ops). + A, B, C = square_tensors["A"], square_tensors["B"], square_tensors["C"] + kernel = C["ij"] <= A["ik"] * B["kj"] + kernel = _lower_to_product_tree(kernel) + cost = ShapeCostEstimator().estimate(kernel) + assert cost == 960 + + def test_cost_scales_with_shape(self): + # Doubling every dimension (N -> 2N) scales the matmul cost by about 8x.
+ def cost_for(N): + A = Tensor(f"A{N}", (N, N)) + B = Tensor(f"B{N}", (N, N)) + C = Tensor(f"C{N}", (N, N)) + kernel = C["ij"] <= A["ik"] * B["kj"] + kernel = _lower_to_product_tree(kernel) + return ShapeCostEstimator().estimate(kernel) + c4 = cost_for(4) + c8 = cost_for(8) + # Doubling N roughly eight-folds the matmul cost (N^3 term). + assert c8 / c4 >= 7 + + +# --------------------------------------------------------------------------- +# BoundingBoxCostEstimator +# --------------------------------------------------------------------------- + + +class TestBoundingBoxCostEstimator: + def test_dense_matches_shape_cost(self, square_tensors): + A, B, C = square_tensors["A"], square_tensors["B"], square_tensors["C"] + kernel = C["ij"] <= A["ik"] * B["kj"] + kernel = _lower_to_product_tree(kernel) + shape = ShapeCostEstimator().estimate(kernel) + bbox = BoundingBoxCostEstimator().estimate(kernel) + # On a fully dense kernel, bbox-based and shape-based estimators + # agree (bbox == full shape in that case). + assert bbox == shape + + def test_sparse_reduces_cost(self): + # Diagonal A * dense B: the bounding box of the diagonal is still + # the full (i,k)-plane (the diagonal spans from (0,0) to (N-1,N-1)), + # so the *bounding-box* estimator may not detect the saving. + # ``ExactCost`` does (see below). This is a useful distinction + # to lock in with the tests. + N = 4 + diag = np.eye(N, dtype=bool) + A = Tensor("A", (N, N), spp=diag) + B = Tensor("B", (N, N)) + C = Tensor("C", (N, N)) + kernel = C["ij"] <= A["ik"] * B["kj"] + kernel = _lower_to_product_tree(kernel, estimator_cls=ExactCost) + + bb_cost = BoundingBoxCostEstimator().estimate(kernel) + exact_cost = ExactCost().estimate(kernel) + # The exact cost must be strictly lower than the bounding-box + # upper bound. + assert exact_cost < bb_cost + + def test_caches_per_node(self, square_tensors): + # BoundingBoxCostEstimator inherits the cache of CachedCostEstimator.
+ # Estimating the same node twice must return the same value + # without re-doing work. + A, B, C = square_tensors["A"], square_tensors["B"], square_tensors["C"] + kernel = _lower_to_product_tree(C["ij"] <= A["ik"] * B["kj"]) + est = BoundingBoxCostEstimator() + first = est.estimate(kernel) + second = est.estimate(kernel) + assert first == second + + +# --------------------------------------------------------------------------- +# ExactCost +# --------------------------------------------------------------------------- + + +class TestExactCost: + def test_diagonal_times_diagonal_is_cheap(self): + # D * D = D (diagonal). A 4x4 matmul of two diagonals needs + # only N (not N^3) multiply-adds. + N = 4 + diag = np.eye(N, dtype=bool) + A = Tensor("A", (N, N), spp=diag) + B = Tensor("B", (N, N), spp=diag) + C = Tensor("C", (N, N)) + kernel = C["ij"] <= A["ik"] * B["kj"] + kernel = _lower_to_product_tree(kernel, estimator_cls=ExactCost) + cost = ExactCost().estimate(kernel) + # A dense 4x4 matmul costs 112 in our accounting (2*N^3 - N^2 = 128-16=112). + # Diagonal-of-diagonal must be much cheaper - exactly N products + # (the contraction collapses to elementwise pairing). + # Thus, we obtain (FMA - add to zero): 2*N - N = 8 - 4 = 4. + assert cost == 4 + + +# --------------------------------------------------------------------------- +# Abstract base behaviour +# --------------------------------------------------------------------------- + + +class TestAbstractBase: + def test_CostEstimator_requires_generic_estimate(self): + # Subclasses MUST implement generic_estimate - it's abstract. 
+ class Incomplete(CostEstimator): + pass # no generic_estimate + with pytest.raises(TypeError): + Incomplete() + + def test_cached_estimator_is_abstract_too(self): + class Incomplete(CachedCostEstimator): + pass + with pytest.raises(TypeError): + Incomplete() diff --git a/tests/pytest/test_generator.py b/tests/pytest/test_generator.py new file mode 100644 index 0000000..93e8fd5 --- /dev/null +++ b/tests/pytest/test_generator.py @@ -0,0 +1,288 @@ +""" +Tests for ``yateto.generator`` - the orchestrator that ties everything +together. + +``Generator`` is the user-facing top-level object. Users call +``g.add("mykernel", C['ij'] <= A['ik'] * B['kj'])`` to register kernels +and ``g.generate(outdir)`` to spit out C++. Between those two, +``Kernel`` objects carry the ASTs and run them through +``prepareUntilUnitTest`` and ``prepareUntilCodeGen``. + +These tests exercise the Python side of that pipeline - stopping +**before** any C++ emission - so they can run without a compiler +toolchain. The C++ generation side is tested by the GitHub Actions +``yateto-cpu.yml`` workflow.
+""" +from __future__ import annotations + +import os +import tempfile + +import pytest + +from yateto import Tensor, Generator, simpleParameterSpace, parameterSpaceFromRanges +from yateto.generator import Kernel, KernelFamily + + +# --------------------------------------------------------------------------- +# Kernel name validation +# --------------------------------------------------------------------------- + + +class TestKernelNames: + @pytest.mark.parametrize( + "name", + ["matmul", "K1", "my_kernel", "foo0", "A"], + ) + def test_valid_names(self, name): + assert Kernel.isValidName(name) + + @pytest.mark.parametrize( + "name", + ["0foo", "foo(1)", "foo-bar", ""], + ) + def test_invalid_names(self, name): + assert not Kernel.isValidName(name) + + +class TestKernelFamilyNames: + @pytest.mark.parametrize( + "name", + ["family(0)", "family(1)", "kfam(12)"], + ) + def test_valid_family_names(self, name): + # A KernelFamily name must be ``{name}({nat})``. + assert KernelFamily.isValidName(name) + + @pytest.mark.parametrize( + "name", + ["family", "family()", "family(0,1)"], + ) + def test_invalid_family_names(self, name): + # Note: ``family(0,1)`` is a *tensor* group name, not a family + # name. Family names have exactly one parenthesised nat. + assert not KernelFamily.isValidName(name) + + +# --------------------------------------------------------------------------- +# Kernel construction & prepareUntilUnitTest +# --------------------------------------------------------------------------- + + +class TestKernelPreparation: + def test_construction_stores_ast_as_list(self): + A = Tensor("A", (4, 4)) + B = Tensor("B", (4, 4)) + C = Tensor("C", (4, 4)) + kernel = Kernel("k", C["ij"] <= A["ik"] * B["kj"]) + # Internally the AST is always stored as a list (even for a + # single-statement kernel).
+ assert isinstance(kernel.ast, list) + assert len(kernel.ast) == 1 + + def test_construction_accepts_list_of_asts(self): + A = Tensor("A", (4, 4)) + B = Tensor("B", (4, 4)) + C = Tensor("C", (4, 4)) + ast_list = [ + C["ij"] <= A["ij"], + B["ij"] <= C["ij"], + ] + kernel = Kernel("k", ast_list) + assert len(kernel.ast) == 2 + + def test_default_target_is_cpu(self): + A = Tensor("A", (2, 2)) + kernel = Kernel("k", A["ij"] <= A["ij"]) + assert kernel.target == "cpu" + + def test_rejects_invalid_target(self): + A = Tensor("A", (2, 2)) + with pytest.raises(ValueError, match="target platform"): + Kernel("k", A["ij"] <= A["ij"], target="fpga") + + def test_prepareUntilUnitTest_populates_cfg(self, arch): + A = Tensor("A", (4, 4)) + B = Tensor("B", (4, 4)) + C = Tensor("C", (4, 4)) + kernel = Kernel("k", C["ij"] <= A["ik"] * B["kj"]) + assert kernel.cfg is None + kernel.prepareUntilUnitTest() + # After prepare, cfg is populated and each ProgramPoint has a + # live set (LivenessAnalysis has run). + assert kernel.cfg is not None + assert all(pp.live is not None for pp in kernel.cfg) + + def test_prepareUntilCodeGen_populates_nonzero_flops(self, arch): + from yateto.ast.cost import BoundingBoxCostEstimator + A = Tensor("A", (8, 8)) + B = Tensor("B", (8, 8)) + C = Tensor("C", (8, 8)) + kernel = Kernel("k", C["ij"] <= A["ik"] * B["kj"]) + kernel.prepareUntilUnitTest() + kernel.prepareUntilCodeGen(BoundingBoxCostEstimator, enableFusedGemm=False) + # The exact flop count for a dense 8x8 matmul is 960 - same + # value we pinned down in test_ast_visitor. 
+ assert kernel.nonZeroFlops == 960 + + +# --------------------------------------------------------------------------- +# Prefetch argument +# --------------------------------------------------------------------------- + + +class TestKernelPrefetch: + def test_accepts_tensor(self): + A = Tensor("A", (4, 4)) + P = Tensor("P", (4, 4)) + kernel = Kernel("k", A["ij"] <= A["ij"], prefetch=P) + # kernel.prefetch() is stored as a list internally. + assert kernel.prefetch() == [P] + + def test_accepts_list_of_tensors(self): + A = Tensor("A", (4, 4)) + P1 = Tensor("P1", (4, 4)) + P2 = Tensor("P2", (4, 4)) + kernel = Kernel("k", A["ij"] <= A["ij"], prefetch=[P1, P2]) + assert kernel.prefetch() == [P1, P2] + + def test_rejects_invalid_prefetch(self): + A = Tensor("A", (4, 4)) + with pytest.raises(ValueError, match="Prefetch must"): + Kernel("k", A["ij"] <= A["ij"], prefetch="some string") + + +# --------------------------------------------------------------------------- +# Parameter spaces (helpers) +# --------------------------------------------------------------------------- + + +class TestParameterSpace: + def test_simple_parameter_space(self): + # simpleParameterSpace(a, b) -> cartesian product range(a) x range(b) + ps = simpleParameterSpace(2, 3) + assert len(ps) == 6 + assert (0, 0) in ps + assert (1, 2) in ps + assert (1, 3) not in ps + + def test_parameter_space_from_ranges(self): + ps = parameterSpaceFromRanges([0, 2], [1, 3]) + assert sorted(ps) == [(0, 1), (0, 3), (2, 1), (2, 3)] + + +# --------------------------------------------------------------------------- +# Generator - registration +# --------------------------------------------------------------------------- + + +class TestGeneratorRegistration: + def test_add_kernel(self, arch): + A = Tensor("A", (4, 4)) + B = Tensor("B", (4, 4)) + C = Tensor("C", (4, 4)) + g = Generator(arch) + g.add("matmul", C["ij"] <= A["ik"] * B["kj"]) + assert len(g.kernels()) == 1 + assert g.kernels()[0].name == "matmul" + + def 
test_add_family_member_creates_family(self, arch): + A = Tensor("A", (4, 4)) + B = Tensor("B", (4, 4)) + C = Tensor("C", (4, 4)) + g = Generator(arch) + # "foo(0)" is a family-indexed kernel; the generator should detect + # that and create a ``KernelFamily`` rather than a single Kernel. + g.add("foo(0)", C["ij"] <= A["ij"] + B["ij"]) + g.add("foo(1)", C["ij"] <= A["ij"] - B["ij"]) + # foo(0) and foo(1) both belong to the "foo" family. + assert len(g.kernels()) == 2 + # Family dispatches by internal renaming: we don't see "foo(0)" + # as a top-level kernel. + assert not any(k.name == "foo(0)" for k in g.kernels()) + + def test_add_rejects_invalid_kernel_name(self, arch): + A = Tensor("A", (4, 4)) + g = Generator(arch) + with pytest.raises(ValueError, match="Kernel name invalid"): + g.add("0bad", A["ij"] <= A["ij"]) + + def test_arch_is_attached(self, arch): + g = Generator(arch) + assert g.arch() is arch + + +# --------------------------------------------------------------------------- +# KernelFamily +# --------------------------------------------------------------------------- + + +class TestKernelFamily: + def test_linear_dispatch_math(self): + # The family's linear index formula: index = sum_i p_i * stride_i + # Used to map a multi-index group to a single numeric id. 
+ stride = (1, 3, 9) + assert KernelFamily.linear(stride, (0, 0, 0)) == 0 + assert KernelFamily.linear(stride, (1, 0, 0)) == 1 + assert KernelFamily.linear(stride, (0, 1, 0)) == 3 + assert KernelFamily.linear(stride, (1, 1, 1)) == 1 + 3 + 9 + + def test_family_base_name(self): + assert KernelFamily.baseName("foo(3)") == "foo" + + def test_family_group_extraction(self): + assert KernelFamily.group("foo(5)") == 5 + + def test_stride_default(self): + f = KernelFamily() + assert f.stride() == (1,) + + +# --------------------------------------------------------------------------- +# Generator.generate smoke test +# --------------------------------------------------------------------------- + + +class TestGeneratorGenerateSmoke: + """``Generator.generate`` emits C++ files. We don't compile them + here, but the call should run end-to-end without error on a simple + kernel and the expected output files must materialise. + """ + + def test_generate_writes_expected_files(self, arch, tmp_path): + A = Tensor("A", (8, 8)) + B = Tensor("B", (8, 8)) + C = Tensor("C", (8, 8)) + g = Generator(arch) + g.add("matmul", C["ij"] <= A["ik"] * B["kj"]) + g.generate(str(tmp_path)) + + # These are the canonical files emitted by the generator. + # (The exact list is documented in ``Generator.FileNames``.) + expected = { + "tensor.h", + "tensor.cpp", + "init.h", + "init.cpp", + "kernel.h", + "kernel.cpp", + } + present = set(os.listdir(str(tmp_path))) + missing = expected - present + assert not missing, f"missing files: {missing}" + + def test_generate_writes_kernel_class(self, arch, tmp_path): + # Regression: the generated ``kernel.h`` must declare a struct + # named like our kernel. This catches gross breakage in the + # codegen without requiring a C++ compiler. 
+ A = Tensor("A", (8, 8)) + B = Tensor("B", (8, 8)) + C = Tensor("C", (8, 8)) + g = Generator(arch) + g.add("matmul", C["ij"] <= A["ik"] * B["kj"]) + g.generate(str(tmp_path)) + + kernel_h = (tmp_path / "kernel.h").read_text() + # The generator emits a ``struct matmul`` in a ``namespace kernel``. + assert "matmul" in kernel_h + assert "namespace kernel" in kernel_h or "kernel::" in kernel_h diff --git a/tests/pytest/test_import.py b/tests/pytest/test_import.py new file mode 100644 index 0000000..6f3bf21 --- /dev/null +++ b/tests/pytest/test_import.py @@ -0,0 +1,101 @@ +""" +Importability regression tests. + +These tests don't exercise any algorithm - they just import the top-level +``yateto`` package under various conditions and check that the basic +surface area is intact. Boring, but very effective at catching stupid +regressions (a stray ``print`` statement, a broken relative import, a +missing ``__init__``, ...). + +They also act as the canary for environment-level issues. +""" +from __future__ import annotations + +import importlib +import sys +import subprocess +import textwrap + +import pytest + + +# --------------------------------------------------------------------------- +# Top-level API surface +# --------------------------------------------------------------------------- + + +class TestTopLevelAPI: + def test_package_imports(self): + import yateto + assert yateto is not None + + def test_Tensor_exported(self): + from yateto import Tensor # noqa: F401 + + def test_Scalar_exported(self): + from yateto import Scalar # noqa: F401 + + def test_Generator_exported(self): + from yateto import Generator # noqa: F401 + + def test_arch_helpers_exported(self): + from yateto import ( # noqa: F401 + useArchitectureIdentifiedBy, + deriveArchitecture, + HostArchDefinition, + DeviceArchDefinition, + fixArchitectureGlobal, + ) + + def test_parameter_space_helpers_exported(self): + from yateto import simpleParameterSpace, parameterSpaceFromRanges # noqa: F401 + + def 
test_GlobalRoutineCache_exported(self): + from yateto import GlobalRoutineCache # noqa: F401 + +# --------------------------------------------------------------------------- +# Submodule round-trip +# --------------------------------------------------------------------------- + + +class TestSubmodules: + """The yateto package is composed of many submodules. A simple + ``import_module()`` call on each one catches most cases where a + module fails to import due to a typo, a missing dependency, or a + cyclic import. + """ + + @pytest.mark.parametrize( + "modname", + [ + "yateto", + "yateto.aspp", + "yateto.arch", + "yateto.memory", + "yateto.type", + "yateto.generator", + "yateto.ast", + "yateto.ast.node", + "yateto.ast.indices", + "yateto.ast.visitor", + "yateto.ast.transformer", + "yateto.ast.cost", + "yateto.ast.opt", + "yateto.ast.log", + "yateto.controlflow", + "yateto.controlflow.graph", + "yateto.controlflow.visitor", + "yateto.controlflow.transformer", + ], + ) + def test_submodule_imports(self, modname): + mod = importlib.import_module(modname) + assert mod is not None + + def test_reload_roundtrip(self): + # In-process reload check: re-import the + # top-level package and make sure the main types survive. + import yateto + importlib.reload(yateto) + from yateto import Tensor, Scalar, Generator + assert Tensor and Scalar and Generator diff --git a/tests/pytest/test_indices.py b/tests/pytest/test_indices.py new file mode 100644 index 0000000..1035329 --- /dev/null +++ b/tests/pytest/test_indices.py @@ -0,0 +1,295 @@ +""" +Tests for ``yateto.ast.indices`` - ``Indices``, ``Range``, ``BoundingBox``, +``LoGCost``. + +These are the type-theoretic bookkeeping objects of the Einstein-notation +DSL. Nearly every AST transformer and every cost estimator touches them, +so bugs here tend to surface as confusing downstream failures. Worth +nailing down with direct unit tests. 
+""" +from __future__ import annotations + +import pytest + +from yateto.ast.indices import Indices, Range, BoundingBox, LoGCost + + +# --------------------------------------------------------------------------- +# Indices construction and basic invariants +# --------------------------------------------------------------------------- + + +class TestIndicesConstruction: + def test_basic(self): + idx = Indices("ij", (4, 5)) + assert str(idx) == "ij" + assert len(idx) == 2 + assert idx.shape() == (4, 5) + assert idx.indexSize("i") == 4 + assert idx.indexSize("j") == 5 + + def test_empty_indices_for_scalar(self): + idx = Indices("", ()) + assert str(idx) == "" + assert len(idx) == 0 + assert idx.shape() == () + + def test_default_constructor_is_empty(self): + assert len(Indices()) == 0 + + def test_repeated_index_names_rejected(self): + with pytest.raises(AssertionError, match="Repeated indices"): + Indices("ii", (4, 4)) + + def test_shape_length_mismatch_rejected(self): + with pytest.raises(AssertionError, match="do not match tensor shape"): + Indices("ij", (4, 5, 6)) + with pytest.raises(AssertionError, match="do not match tensor shape"): + Indices("ijk", (4, 5)) + + +# --------------------------------------------------------------------------- +# Set-like operations +# --------------------------------------------------------------------------- + + +class TestIndicesSetOps: + def test_intersection_returns_raw_set(self): + a = Indices("ij", (3, 4)) + b = Indices("jk", (4, 5)) + # Intersection is a plain ``set``, not an ``Indices`` object. + # This is how ``Einsum`` / ``Product`` identify contraction indices. 
+ assert a & b == {"j"} + # Commutative + assert b & a == {"j"} + + def test_difference_returns_indices(self): + a = Indices("ijk", (3, 4, 5)) + b = Indices("jk", (4, 5)) + diff = a - b + assert isinstance(diff, Indices) + assert str(diff) == "i" + assert diff.shape() == (3,) + + def test_difference_preserves_order(self): + # Important - strength reduction + LoG rely on this. + a = Indices("abcd", (1, 2, 3, 4)) + assert str(a - Indices("bd", (2, 4))) == "ac" + + def test_le_is_subset_with_matching_sizes(self): + a = Indices("ij", (3, 4)) + big = Indices("ijk", (3, 4, 5)) + assert a <= big + assert not (big <= a) + + def test_le_rejects_mismatched_sizes(self): + a = Indices("ij", (3, 4)) + bad = Indices("ij", (3, 5)) # same letters, different shape + assert not (a <= bad) + + def test_contains(self): + idx = Indices("ij", (3, 4)) + assert "i" in idx + assert "j" in idx + assert "k" not in idx + + +# --------------------------------------------------------------------------- +# Merging / permuting +# --------------------------------------------------------------------------- + + +class TestIndicesMerge: + def test_merged_concatenates(self): + a = Indices("ij", (3, 4)) + b = Indices("kl", (5, 6)) + merged = a.merged(b) + assert str(merged) == "ijkl" + assert merged.shape() == (3, 4, 5, 6) + + def test_merged_allows_duplicate_names_without_check(self): + # ``merged`` is the naive concat; ``mergeStrict`` is the checked + # variant that deduplicates. Exercise both contracts. + a = Indices("ij", (3, 4)) + # Duplicating "i" via ``merged`` triggers the "Repeated indices" + # assert in ``Indices.__init__``. 
+ with pytest.raises(AssertionError): + a.merged(a) + + def test_merge_strict_dedupes(self): + a = Indices("ij", (3, 4)) + b = Indices("jk", (4, 5)) + merged = a.mergeStrict(b) + assert str(merged) == "ijk" + assert merged.shape() == (3, 4, 5) + + def test_merge_strict_rejects_incompatible_shape(self): + a = Indices("ij", (3, 4)) + b = Indices("jk", (9, 5)) # j has different size + with pytest.raises(AssertionError, match="Index merge failed"): + a.mergeStrict(b) + + def test_permuted_reorders(self): + idx = Indices("ijk", (3, 4, 5)) + p = idx.permuted("kij") + assert str(p) == "kij" + assert p.shape() == (5, 3, 4) + # Sizes must stay attached to the correct letters after permutation. + assert p.indexSize("i") == 3 + assert p.indexSize("j") == 4 + assert p.indexSize("k") == 5 + + def test_sorted(self): + idx = Indices("cab", (1, 2, 3)) + s = idx.sorted() + assert str(s) == "abc" + assert s.indexSize("a") == 2 + + +# --------------------------------------------------------------------------- +# Positions +# --------------------------------------------------------------------------- + + +class TestIndicesPositions: + def test_find(self): + idx = Indices("ijk", (3, 4, 5)) + assert idx.find("i") == 0 + assert idx.find("j") == 1 + assert idx.find("k") == 2 + + def test_positions_sorted_by_default(self): + idx = Indices("ijk", (3, 4, 5)) + # Positions are returned sorted by default - this is important for + # LoG (m/n/k indices must be a contiguous range to be fuseable). + assert idx.positions("kj") == [1, 2] + assert idx.positions(["j", "i"]) == [0, 1] + + def test_positions_unsorted(self): + idx = Indices("ijk", (3, 4, 5)) + assert idx.positions("kj", sort=False) == [2, 1] + + +# --------------------------------------------------------------------------- +# Hash / equality +# --------------------------------------------------------------------------- + + +class TestIndicesEquality: + def test_hashable(self): + # Indices are used as dict keys in the generator. 
+ a = Indices("ij", (3, 4)) + b = Indices("ij", (3, 4)) + assert hash(a) == hash(b) + d = {a: 1} + assert d[b] == 1 + + def test_eq_considers_both_names_and_shape(self): + assert Indices("ij", (3, 4)) == Indices("ij", (3, 4)) + assert Indices("ij", (3, 4)) != Indices("ji", (3, 4)) + assert Indices("ij", (3, 4)) != Indices("ij", (3, 5)) + + +# --------------------------------------------------------------------------- +# Range +# --------------------------------------------------------------------------- + + +class TestRange: + def test_size(self): + assert Range(3, 8).size() == 5 + assert Range(0, 0).size() == 0 + + def test_intersection(self): + r = Range(2, 8) & Range(5, 10) + assert r.start == 5 + assert r.stop == 8 + + def test_union(self): + r = Range(2, 4) | Range(6, 10) + assert r.start == 2 + assert r.stop == 10 + + def test_contains(self): + outer = Range(0, 10) + assert Range(2, 5) in outer + assert Range(5, 10) in outer + assert Range(0, 11) not in outer + + def test_iter_enumerates_range(self): + assert list(Range(2, 5)) == [2, 3, 4] + + def test_eq(self): + assert Range(1, 5) == Range(1, 5) + assert Range(1, 5) != Range(1, 6) + + +# --------------------------------------------------------------------------- +# BoundingBox +# --------------------------------------------------------------------------- + + +class TestBoundingBox: + def test_length_and_iter(self): + bb = BoundingBox([Range(0, 3), Range(0, 4)]) + assert len(bb) == 2 + assert [r.size() for r in bb] == [3, 4] + + def test_size_is_volume(self): + assert BoundingBox([Range(0, 3), Range(0, 4)]).size() == 12 + # Empty box has size 1 (the empty product) - needed by scalar ops. 
+ assert BoundingBox([]).size() == 1 + + def test_contains_point(self): + bb = BoundingBox([Range(0, 3), Range(0, 4)]) + assert (1, 2) in bb + # Wrong arity is a rejection, not a crash + assert (1, 2, 3) not in bb + + def test_fromSpp_dense(self): + # Build from a dense sparsity pattern - each dimension's range + # spans the whole shape (with an off-by-one because ``nnzbounds`` + # returns inclusive bounds). + from yateto.aspp import dense + bb = BoundingBox.fromSpp(dense((3, 4))) + assert len(bb) == 2 + assert bb[0].start == 0 and bb[0].stop == 3 + assert bb[1].start == 0 and bb[1].stop == 4 + + +# --------------------------------------------------------------------------- +# LoGCost - the tuple cost model used for GEMM variant selection +# --------------------------------------------------------------------------- + + +class TestLoGCost: + def test_identity_is_zero_cost(self): + c = LoGCost.addIdentity() + assert c == LoGCost(0, 0, 0, 0) + + def test_addition_is_componentwise(self): + a = LoGCost(1, 2, 3, 4) + b = LoGCost(10, 20, 30, 40) + c = a + b + # We can't access internals, but equality with the expected sum works + assert c == LoGCost(11, 22, 33, 44) + + def test_lower_stride_is_cheaper(self): + # Unit-stride (0) beats non-unit (1) + assert LoGCost(0, 0, 0, 0) < LoGCost(1, 0, 0, 0) + + def test_more_fused_indices_is_cheaper(self): + # fused_indices contributes negatively to the cost tuple, i.e. more + # fused indices is better. + assert LoGCost(0, 0, 0, 2) < LoGCost(0, 0, 0, 1) + + def test_fewer_transposes_is_cheaper(self): + assert LoGCost(0, 0, 0, 0) < LoGCost(0, 1, 0, 0) + assert LoGCost(0, 1, 0, 0) < LoGCost(0, 2, 0, 0) + + def test_tiebreak_prefers_fewer_left_transposes(self): + # When the summed-transposes match, the comparator falls back to + # ``_leftTranspose``. A > B iff A.leftTranspose > B.leftTranspose. 
+ a = LoGCost(0, 0, 1, 0) + b = LoGCost(0, 1, 0, 0) + assert a < b diff --git a/tests/pytest/test_memory.py b/tests/pytest/test_memory.py new file mode 100644 index 0000000..2d651c6 --- /dev/null +++ b/tests/pytest/test_memory.py @@ -0,0 +1,214 @@ +""" +Tests for ``yateto.memory`` - the dense memory layout backend. + +``DenseMemoryLayout`` is Yateto's representation of how a tensor is laid +out in C++ memory: shape, bounding box (the "interesting" sub-rectangle +inside the full shape, e.g. the non-zero rows of a sparse matrix), +column-major strides, and optional leading-dimension alignment for +vectorisation. + +The layout drives: + +* address computation inside generated C++ code +* which dimensions may be fused into a single GEMM "m" / "n" / "k" +* whether a tensor can be vectorised along a given axis +* how permutations propagate through strength-reduced AST nodes +""" +from __future__ import annotations + +import pytest + +from yateto.ast.indices import BoundingBox, Indices, Range +from yateto.aspp import dense as aspp_dense, general as aspp_general +from yateto.memory import DenseMemoryLayout + +import numpy as np + + +# --------------------------------------------------------------------------- +# Construction +# --------------------------------------------------------------------------- + + +class TestConstruction: + def test_default_bbox_is_full_shape(self): + ml = DenseMemoryLayout((3, 4)) + assert ml.shape() == (3, 4) + assert ml.bbox() == BoundingBox([Range(0, 3), Range(0, 4)]) + + def test_explicit_bbox(self): + bb = BoundingBox([Range(1, 3), Range(0, 2)]) + ml = DenseMemoryLayout((4, 4), boundingBox=bb) + # Strides are computed from the bbox, not the full shape: only + # the non-zero sub-rectangle of an optimised sparse tensor needs + # storage. 
+ assert ml.stride()[0] == 1 + assert ml.stride()[1] == bb[0].size() + + def test_from_dense_spp(self): + ml = DenseMemoryLayout.fromSpp(aspp_dense((5, 6))) + assert ml.shape() == (5, 6) + # Fully dense -> bounding box equals full shape. + assert ml.bbox() == BoundingBox([Range(0, 5), Range(0, 6)]) + + def test_from_sparse_spp_shrinks_bbox(self): + # Only a small sub-block is nonzero; the layout should contract + # its bounding box accordingly. + arr = np.zeros((5, 6), dtype=bool) + arr[1:3, 2:5] = True + ml = DenseMemoryLayout.fromSpp(aspp_general(arr)) + assert ml.shape() == (5, 6) + assert ml.bbox()[0] == Range(1, 3) + assert ml.bbox()[1] == Range(2, 5) + + +# --------------------------------------------------------------------------- +# Strides and addresses +# --------------------------------------------------------------------------- + + +class TestStridesAndAddresses: + def test_strides_are_column_major(self): + # Yateto uses column-major (Fortran) order so that the leading + # dimension is unit-stride, matching how GEMMs expect matrices. + ml = DenseMemoryLayout((3, 4, 5)) + assert ml.stride() == (1, 3, 12) + + def test_address_linear(self): + ml = DenseMemoryLayout((3, 4)) + # Column-major: linear = i + 3*j + assert ml.address((0, 0)) == 0 + assert ml.address((1, 0)) == 1 + assert ml.address((0, 1)) == 3 + assert ml.address((2, 3)) == 2 + 3 * 3 + + def test_address_with_bbox_offset(self): + # When the bbox doesn't start at zero, the address is relative to + # the bbox start (subtracts bbox.start per dim). + bb = BoundingBox([Range(1, 4), Range(2, 6)]) + ml = DenseMemoryLayout((4, 6), boundingBox=bb) + assert ml.address((1, 2)) == 0 + # (1,2) is the bbox origin - linear address 0. 
+ assert ml.address((2, 2)) == 1 # unit stride along first dim + assert ml.address((1, 3)) == ml.stride()[1] + + def test_address_rejects_out_of_bbox(self): + bb = BoundingBox([Range(1, 4), Range(2, 6)]) + ml = DenseMemoryLayout((4, 6), boundingBox=bb) + with pytest.raises(AssertionError): + ml.address((0, 0)) + + def test_required_reals(self): + # ``requiredReals`` is the allocation size: stride[-1] * bbox[-1].size() + ml = DenseMemoryLayout((3, 4)) + assert ml.requiredReals() == 12 + # Sparse variant with smaller bbox allocates less. + bb = BoundingBox([Range(0, 2), Range(0, 3)]) + ml = DenseMemoryLayout((3, 4), boundingBox=bb) + assert ml.requiredReals() == 6 + + +# --------------------------------------------------------------------------- +# Alignment +# --------------------------------------------------------------------------- + + +class TestAlignment: + def test_alignment_false_without_arch(self): + # Without an architecture registered globally, the layout cannot + # claim aligned-stride and ``mayVectorizeDim`` returns False. + ml = DenseMemoryLayout((5, 3)) + assert ml.alignedStride() is False + assert ml.mayVectorizeDim(0) is False + + def test_aligned_stride_with_arch(self, arch): + # ``dhsw`` has 32B alignment -> 4 aligned doubles. A 5-entry + # leading dim gets extended to 8 when alignStride=True. 
+ ml = DenseMemoryLayout((5, 3), alignStride=True) + assert ml.bbox()[0].size() == 8 # 5 -> aligned up to 8 + assert ml.alignedStride() is True + + def test_may_vectorize_when_dim_is_aligned(self, arch): + ml = DenseMemoryLayout((8, 3)) # 8 is a multiple of 4 + assert ml.mayVectorizeDim(0) is True + + def test_may_not_vectorize_when_dim_misaligned(self, arch): + ml = DenseMemoryLayout((7, 3)) + assert ml.mayVectorizeDim(0) is False + + +# --------------------------------------------------------------------------- +# mayFuse - contiguity predicate for LoopOverGEMM +# --------------------------------------------------------------------------- + + +class TestFusion: + def test_single_position_always_fuses(self): + ml = DenseMemoryLayout((3, 4, 5)) + assert ml.mayFuse([1]) is True + + def test_consecutive_indices_fuse_in_column_major(self): + # In column-major layout, stride[i+1] == shape[i] * stride[i]. + # So consecutive positions fuse - this is what lets Yateto fold + # multi-index GEMM operands into a single "m" dimension. + ml = DenseMemoryLayout((3, 4, 5)) + assert ml.mayFuse([0, 1]) is True + assert ml.mayFuse([1, 2]) is True + assert ml.mayFuse([0, 1, 2]) is True + + def test_non_consecutive_does_not_fuse(self): + ml = DenseMemoryLayout((3, 4, 5)) + # Skipping the middle index breaks contiguity. + assert ml.mayFuse([0, 2]) is False + + +# --------------------------------------------------------------------------- +# Permutation +# --------------------------------------------------------------------------- + + +class TestPermutation: + def test_permute_reorders_shape_and_bbox(self): + bb = BoundingBox([Range(1, 3), Range(0, 4), Range(0, 5)]) + ml = DenseMemoryLayout((3, 4, 5), boundingBox=bb) + pm = ml.permuted((2, 0, 1)) + # Shape and bbox must follow the permutation. 
+ assert pm.shape() == (5, 3, 4) + assert pm.bbox()[0] == Range(0, 5) + assert pm.bbox()[1] == Range(1, 3) + assert pm.bbox()[2] == Range(0, 4) + + +# --------------------------------------------------------------------------- +# Address string formatting (used to emit C++) +# --------------------------------------------------------------------------- + + +class TestAddressString: + def test_address_string_simple(self): + ml = DenseMemoryLayout((3, 4)) + idx = Indices("ij", (3, 4)) + s = ml.addressString(idx) + # Should mention both index letters with their strides. + assert "_i" in s + assert "_j" in s + + def test_scalar_address_string(self): + ml = DenseMemoryLayout(()) # 0-D tensor + idx = Indices("", ()) + assert ml.addressString(idx) == "0" + + +# --------------------------------------------------------------------------- +# Global arch state isolation (via the fixture) +# --------------------------------------------------------------------------- + + +class TestAlignmentArchIsolation: + def test_fixture_resets_arch(self, arch): + assert DenseMemoryLayout.ALIGNMENT_ARCH is not None + + def test_fresh_test_starts_without_arch(self): + # This test uses no ``arch`` fixture and must see a reset global. + # (If this fails, the ``arch`` fixture's teardown is broken.) + assert DenseMemoryLayout.ALIGNMENT_ARCH is None diff --git a/tests/pytest/test_opt.py b/tests/pytest/test_opt.py new file mode 100644 index 0000000..812879a --- /dev/null +++ b/tests/pytest/test_opt.py @@ -0,0 +1,174 @@ +""" +Direct tests for ``yateto.ast.opt.strengthReduction``. + +``strengthReduction`` is Yateto's well-pruned exhaustive search for the +optimal pairing of tensor operands in a multi-way contraction (cf. +Lam et al., 1997 - referenced in the paper). Given a list of terms +and target indices, it returns an AST built from nested +``Product`` / ``IndexSum`` nodes whose total cost (per the supplied +cost estimator) is minimal. 
+ +The transformer ``ast.transformer.StrengthReduction`` wraps this into a +visitor pass. We already exercise that end-to-end in +``test_ast_transformer.py``; here we poke the lower-level function +directly to lock in specific decisions (which pair to contract first, +what shape the tree should have). +""" +from __future__ import annotations + +import pytest + +from yateto import Tensor +from yateto.ast.cost import ShapeCostEstimator +from yateto.ast.indices import Indices +from yateto.ast.node import IndexedTensor, IndexSum, Product +from yateto.ast.opt import strengthReduction + + +def _it(name, shape, letters): + """Shortcut: an IndexedTensor for ``Tensor(name, shape)[letters]``.""" + return IndexedTensor(Tensor(name, shape), letters) + + +def _has(cls, node): + """Recursive instance-of check.""" + if isinstance(node, cls): + return True + return any(_has(cls, child) for child in node) + + +def _count(cls, node): + """Recursive count of nodes of ``cls``.""" + n = 1 if isinstance(node, cls) else 0 + for c in node: + n += _count(cls, c) + return n + + +# --------------------------------------------------------------------------- +# Trivial degenerate cases +# --------------------------------------------------------------------------- + + +class TestDegenerate: + def test_single_term_is_returned_as_is(self): + # With one operand there is nothing to pair, so the returned AST + # is just the single term itself. + a = _it("A", (3, 4), "ij") + tree = strengthReduction([a], Indices("ij", (3, 4)), ShapeCostEstimator()) + assert tree is a + + def test_single_term_with_summation_wraps_in_indexsum(self): + # Summing the only term over index ``j`` (not in target) must + # emit an IndexSum even though there's no Product. + a = _it("A", (3, 4), "ij") + tree = strengthReduction([a], Indices("i", (3,)), ShapeCostEstimator()) + assert isinstance(tree, IndexSum) + # The inner node is the original operand. 
+ assert tree.term() is a + + +# --------------------------------------------------------------------------- +# Classical matmul (two operands) +# --------------------------------------------------------------------------- + + +class TestPairwiseMatmul: + def test_matmul_structure(self): + # A @ B, summing over k. Expected tree shape: + # IndexSum(Product(A, B), k) + A = _it("A", (3, 4), "ik") + B = _it("B", (4, 5), "kj") + tree = strengthReduction([A, B], Indices("ij", (3, 5)), + ShapeCostEstimator()) + + assert isinstance(tree, IndexSum) + assert isinstance(tree.term(), Product) + # Both leaves are preserved. + assert {c.tensor.name() for c in tree.term()} == {"A", "B"} + + +# --------------------------------------------------------------------------- +# Three-way contraction - the estimator decides the pairing +# --------------------------------------------------------------------------- + + +class TestThreeWay: + def test_picks_cheapest_pair_first(self): + # Three matrices: + # A: I x J (2x3) + # B: J x K (3x100) + # C: K x L (100x2) + # target: I x L + # + # The shape-cost estimator charges ``product(shape)`` per + # Product node. Pairing (B, C) first gives a J x L intermediate + # (3*100*2 = 600 weight); pairing (A, B) first gives a I x K + # intermediate (2*3*100 = 600 weight). Both are equally cheap + # for the first Product, but the next Product matters: + # + # ((AB) C) -> I x K * K x L -> 2*100*2 = 400 + # (A (BC)) -> I x J * J x L -> 2*3*2 = 12 + # + # So the second pairing is strictly cheaper. The algorithm + # must prefer it. + A = _it("A", (2, 3), "ij") + B = _it("B", (3, 100), "jk") + C = _it("C", (100, 2), "kl") + tree = strengthReduction([A, B, C], Indices("il", (2, 2)), + ShapeCostEstimator()) + # The outer Product should have A on one side and a BC sub-tree + # (possibly wrapped in an IndexSum) on the other. + # Drill in until we find a Product of Products (i.e. a Product + # whose argument is itself a Product-based subtree). 
+ products = [] + def collect(node): + if isinstance(node, Product): + products.append(node) + for c in node: + collect(c) + collect(tree) + assert products, "No Product node found in the strength-reduced tree" + + # The cheaper pairing keeps the A leaf paired *last* with the + # (BC) intermediate. Equivalently: A is not paired together + # with B directly. Check by looking at the innermost Product + # - which should be BC, not AB. + innermost = min(products, key=lambda p: _count(Product, p)) + inner_names = {c.tensor.name() for c in innermost + if isinstance(c, IndexedTensor)} + # The innermost product is between leaves; neither side is A. + assert "A" not in inner_names, ( + f"Strength reducer picked the more expensive pairing: " + f"innermost Product is {inner_names}" + ) + + +# --------------------------------------------------------------------------- +# Sum-only case (e.g. trace) +# --------------------------------------------------------------------------- + + +class TestReductionOnly: + def test_trace_lowers_to_nested_indexsum(self): + # trace(A) = sum_i A[i,i]. This isn't really a valid Yateto + # input (index letters must be distinct in one tensor), but a + # scalar = sum_ij A[i] * B[j] is: + A = _it("A", (3,), "i") + B = _it("B", (4,), "j") + tree = strengthReduction([A, B], Indices("", ()), + ShapeCostEstimator()) + # Result has the two reduction axes lifted out as IndexSum. + assert _has(IndexSum, tree) + # And exactly one Product. + assert _count(Product, tree) == 1 + + def test_sum_indices_are_all_eliminated(self): + # Any target-free index of A must be summed out before we exit. + A = _it("A", (3, 4), "ij") + # target = "i" -> j is a sum index + tree = strengthReduction([A], Indices("i", (3,)), + ShapeCostEstimator()) + assert isinstance(tree, IndexSum) + # The sum index should be j. 
+ assert str(tree.sumIndex()) == "j" diff --git a/tests/pytest/test_pipeline.py b/tests/pytest/test_pipeline.py new file mode 100644 index 0000000..4e0dd0f --- /dev/null +++ b/tests/pytest/test_pipeline.py @@ -0,0 +1,185 @@ +""" +End-to-end pipeline tests on the actual example kernels shipped with +Yateto. + +These are **integration** tests at the Python level: they load each +example from ``tests/code-gen/*.py`` (the same scripts that the CI +builds and runs with real C++ compilers), and push them through the +full Python pipeline up to ``prepareUntilCodeGen``. They also invoke +``Generator.generate`` into a scratch directory to make sure the C++ +emission itself doesn't crash. + +They **do not** compile or run the generated C++ - that's what the +``yateto-cpu.yml`` GitHub Actions workflow does. Our job here is to +catch Python-side regressions much faster. + +The tests are written to work on ``master``. +""" +from __future__ import annotations + +import importlib.util +import inspect +import os +import sys + +import pytest + +from yateto import Generator, Tensor +from yateto.generator import Kernel + + +TEST_SCRIPTS_DIR = os.path.abspath( + os.path.join(os.path.dirname(__file__), "..", "code-gen") +) + + +def _available_scripts(*candidates): + """Filter a candidate list of example scripts to those that + actually exist in ``tests/code-gen/``. 
+ """ + return [s for s in candidates + if os.path.isfile(os.path.join(TEST_SCRIPTS_DIR, f"{s}.py"))] + + +def _load_example_module(name): + """Load ``tests/code-gen/{name}.py`` as a fresh module.""" + path = os.path.join(TEST_SCRIPTS_DIR, f"{name}.py") + spec = importlib.util.spec_from_file_location(f"_example_{name}", path) + mod = importlib.util.module_from_spec(spec) + sys.modules[spec.name] = mod + spec.loader.exec_module(mod) + return mod + + +# --------------------------------------------------------------------------- +# Candidate scripts that may exist on either branch +# --------------------------------------------------------------------------- + +BASIC_SCRIPTS = _available_scripts("matmul", "minimal", "indices", "slicing") +REGRESSION_SCRIPTS = _available_scripts("regress") # master only + +# --------------------------------------------------------------------------- +# Examples +# --------------------------------------------------------------------------- + + +class TestExampleScripts: + """Each example registers its kernels via an ``add(g)`` entry point. + We feed each example into a fresh ``Generator`` and then push it + through the pipeline. + """ + + @pytest.mark.parametrize( + "script", + BASIC_SCRIPTS + REGRESSION_SCRIPTS, + ) + def test_example_prepares_without_error(self, arch, script): + mod = _load_example_module(script) + g = Generator(arch) + mod.add(g) + + # Every example must successfully run ``prepareUntilUnitTest`` + # on all its kernels. 
+ for kernel in g.kernels(): + kernel.prepareUntilUnitTest() + assert kernel.cfg is not None + + @pytest.mark.parametrize("script", BASIC_SCRIPTS) + def test_example_generates_cpp(self, arch, tmp_path, script, request): + mod = _load_example_module(script) + g = Generator(arch) + mod.add(g) + + g.generate(str(tmp_path)) + out = os.listdir(str(tmp_path)) + # Always produced: + for expected in ("kernel.h", "kernel.cpp", "tensor.h", "init.h"): + assert expected in out, f"{expected} missing for {script}" + + +# --------------------------------------------------------------------------- +# Flop counts (regression) +# --------------------------------------------------------------------------- + + +class TestFlopRegression: + """Lock in the Yateto flop counts for a few canonical kernels. + + These are **regression** tests: if the cost model changes in a way + that alters the counts, the test will flag it. A change here is + not necessarily a bug - it just demands human attention and a + justified update. + """ + + def test_matmul_32x32(self, arch): + from yateto.ast.cost import BoundingBoxCostEstimator + mod = _load_example_module("matmul") + g = Generator(arch) + mod.add(g) + + # matmul.py registers 4 kernels (AB / ATB / ABT / ATBT) - all + # 32x32x32 dense matmuls; Yateto should report the same flop + # count for each. + counts = [] + for kernel in g.kernels(): + kernel.prepareUntilUnitTest() + kernel.prepareUntilCodeGen(BoundingBoxCostEstimator, enableFusedGemm=False) + counts.append(kernel.nonZeroFlops) + + # 2*N^3 - N^2 = 65536 - 1024 = 64512 (Yateto's accounting). 
+ assert all(c == 64512 for c in counts), f"counts={counts}" + + def test_minimal_kernel(self, arch): + from yateto.ast.cost import BoundingBoxCostEstimator + mod = _load_example_module("minimal") + g = Generator(arch) + mod.add(g) + kernel = g.kernels()[0] + kernel.prepareUntilUnitTest() + kernel.prepareUntilCodeGen(BoundingBoxCostEstimator, enableFusedGemm=False) + # The ``minimal`` example's single kernel should come out with + # a sensible (positive) flop count. + assert kernel.nonZeroFlops > 0 + + +# --------------------------------------------------------------------------- +# The regression bundle +# --------------------------------------------------------------------------- + +class TestRegressions: + def test_regress_script_runs(self, arch, tmp_path): + # This script collects bug-fix regressions. Running it end to + # end protects against reintroducing those bugs. + mod = _load_example_module("regress") + g = Generator(arch) + mod.add(g) + g.generate(str(tmp_path)) + +# --------------------------------------------------------------------------- +# Family kernels via addFamily +# --------------------------------------------------------------------------- + + +class TestKernelFamily: + def test_add_family_iterates_over_parameter_space(self, arch): + from yateto import simpleParameterSpace + + A = Tensor("A", (4, 4)) + B = Tensor("B", (4, 4)) + C = Tensor("C", (4, 4)) + + def build(i, j): + # The specific formula doesn't matter - what matters is that + # the family machinery produces one valid AST per parameter. + if (i + j) % 2 == 0: + return C["ij"] <= A["ik"] * B["kj"] + return C["ij"] <= A["ij"] + B["ij"] + + g = Generator(arch) + g.addFamily("fam", simpleParameterSpace(2, 2), build) + # 2x2 parameter space -> 4 kernels. + assert len(g.kernels()) == 4 + # We don't run ``generate`` here because on the nonlinearity + # branch the Broadcast codegen is broken (see TestKnownBugs). 
+ for kernel in g.kernels(): + kernel.prepareUntilUnitTest() diff --git a/tests/pytest/test_type.py b/tests/pytest/test_type.py new file mode 100644 index 0000000..dad5c19 --- /dev/null +++ b/tests/pytest/test_type.py @@ -0,0 +1,302 @@ +""" +Tests for ``yateto.type`` - the frontend types users interact with. + +This is the DSL's **surface** layer: ``Tensor``, ``Scalar``, and ``Collection`` +are the objects that the Python operator-overloading machinery wraps into an +AST. Name validation, grouping, sparsity storage and the shape invariants +belong to this layer - not to any of the downstream passes. +""" +from __future__ import annotations + +import numpy as np +import pytest + +from yateto import Tensor, Scalar +from yateto.type import Collection, IdentifiedType + + +# --------------------------------------------------------------------------- +# Name validation +# --------------------------------------------------------------------------- + + +class TestTensorNames: + """Tensor names follow a strict regexp: ``baseName[(g1,g2,...)]``.""" + + @pytest.mark.parametrize( + "name", + [ + "A", + "AB", + "foo", + "A0", + "A_1", + "A(0)", + "A(1)", + "A(0,1)", + "A(10)", + "A(0,1,2,3)", + ], + ) + def test_valid_names_accepted(self, name): + Tensor(name, (2, 2)) # must not raise + + @pytest.mark.parametrize( + "name", + [ + "", # empty + "0A", # leading digit + "_A", # leading underscore + "A(01)", # leading zero in group index (regex disallows) + "A()", # empty group + "A(,)", # empty group index + "A(-1)", # negative group + "A-B", # invalid char + "A(1,)", # trailing comma + ], + ) + def test_invalid_names_rejected(self, name): + with pytest.raises(ValueError): + Tensor(name, (2, 2)) + + def test_group_extraction(self): + assert Tensor("A", (2, 2)).group() == () + assert Tensor("A(3)", (2, 2)).group() == (3,) + assert Tensor("A(1,2,3)", (2, 2)).group() == (1, 2, 3) + + def test_base_name_extraction(self): + assert Tensor("foo(1,2)", (2, 2)).baseName() == "foo" + assert 
Tensor("foo", (2, 2)).baseName() == "foo" + assert Tensor("foo42", (2, 2)).baseName() == "foo42" + + def test_is_valid_name_classmethod(self): + assert Tensor.isValidName("A(0)") + assert not Tensor.isValidName("0A") + + +# --------------------------------------------------------------------------- +# Shape invariants +# --------------------------------------------------------------------------- + + +class TestTensorShape: + def test_shape_must_be_tuple(self): + with pytest.raises(ValueError, match="shape must be a tuple"): + Tensor("A", [2, 2]) # list, not tuple + with pytest.raises(ValueError, match="shape must be a tuple"): + Tensor("A", 42) # int + + def test_shape_entries_must_be_positive(self): + with pytest.raises(ValueError, match="smaller than 1"): + Tensor("A", (0, 2)) + with pytest.raises(ValueError, match="smaller than 1"): + Tensor("A", (2, -1)) + + def test_shape_is_stored_as_tuple(self): + t = Tensor("A", (3, 4, 5)) + assert t.shape() == (3, 4, 5) + assert isinstance(t.shape(), tuple) + + def test_zero_dimensional_tensor_is_allowed(self): + # A scalar tensor: shape == (). Used by reductions. 
+ t = Tensor("A", ()) + assert t.shape() == () + + +# --------------------------------------------------------------------------- +# Namespacing +# --------------------------------------------------------------------------- + + +class TestTensorNamespace: + def test_default_namespace_is_empty_string(self): + t = Tensor("A", (2, 2)) + assert t.namespace == "" + assert t.prefix() == "" + assert t.nameWithNamespace() == "A" + + def test_explicit_namespace_is_used_as_cpp_prefix(self): + t = Tensor("A", (2, 2), namespace="foo") + assert t.namespace == "foo" + assert t.prefix() == "foo::" + assert t.nameWithNamespace() == "foo::A" + assert t.baseNameWithNamespace() == "foo::A" + + def test_split_base_name_with_namespace(self): + # classmethod that undoes the ``foo::bar`` encoding + prefix, base = IdentifiedType.splitBasename("foo::bar") + assert prefix == "foo::" + assert base == "bar" + prefix, base = IdentifiedType.splitBasename("bar") + assert prefix == "" + assert base == "bar" + + +# --------------------------------------------------------------------------- +# Sparsity patterns +# --------------------------------------------------------------------------- + + +class TestTensorSparsity: + def test_dense_default(self): + t = Tensor("A", (3, 3)) + # Without an explicit spp, the tensor is dense. + assert t.spp().is_dense() + # Dense sparsity reports count_nonzero == total size. + assert t.spp().count_nonzero() == 9 + + def test_dict_sparsity_with_bool_values(self): + spp = {(0, 0): True, (1, 1): True, (2, 2): True} + t = Tensor("A", (3, 3), spp=spp) + assert t.spp().count_nonzero() == 3 + # No numerical values stored, as bool dict means pattern only + assert t.values() is None + + def test_dict_sparsity_with_float_values(self): + # Providing floats means Yateto also records the literal values. 
+ spp = {(0, 0): 1.0, (1, 1): 2.0, (2, 2): 3.0} + t = Tensor("A", (3, 3), spp=spp) + assert t.spp().count_nonzero() == 3 + assert t.values() == {(0, 0): 1.0, (1, 1): 2.0, (2, 2): 3.0} + assert t.is_compute_constant() is True + + def test_ndarray_sparsity(self): + arr = np.eye(4, dtype=bool) + t = Tensor("A", (4, 4), spp=arr) + assert t.spp().count_nonzero() == 4 + + def test_float_ndarray_keeps_values(self): + arr = np.eye(3) # float64 identity + t = Tensor("A", (3, 3), spp=arr) + # Yateto extracts nonzero float values as strings (for codegen). + assert t.values() is not None + assert len(t.values()) == 3 + + def test_values_as_ndarray_roundtrip(self): + spp = {(0, 0): 1.5, (1, 2): -2.0} + t = Tensor("A", (3, 3), spp=spp) + arr = t.values_as_ndarray() + assert arr.shape == (3, 3) + assert arr[0, 0] == pytest.approx(1.5) + assert arr[1, 2] == pytest.approx(-2.0) + # All other entries remain zero. + arr[0, 0] = 0 + arr[1, 2] = 0 + assert np.all(arr == 0) + + def test_bad_sparsity_shape_raises(self): + wrong = np.ones((2, 2), dtype=bool) + with pytest.raises(ValueError): + Tensor("A", (3, 3), spp=wrong) + + def test_is_compute_constant_false_for_plain_tensor(self): + assert Tensor("A", (3, 3)).is_compute_constant() is False + + +# --------------------------------------------------------------------------- +# Tensor identity / hashing / equality +# --------------------------------------------------------------------------- + + +class TestTensorIdentity: + def test_hash_is_name_based(self): + # The hash is built from the tensor name, so two tensors with the + # same name can be put in a set even if they live in different + # scopes. This is what the codegen relies on. + assert hash(Tensor("A", (2, 2))) == hash(Tensor("A", (4, 4))) + + def test_equality_by_name(self): + t1 = Tensor("A", (2, 2)) + t2 = Tensor("A", (2, 2)) + assert t1 == t2 + + def test_equality_across_shapes_asserts(self): + # Yateto's ``__eq__`` asserts same shape/layout when names match - + # i.e. 
two tensors that share a name but differ structurally are + # detected as a bug in the user's code, not silently un-equal. + t1 = Tensor("A", (2, 2)) + t2 = Tensor("A", (3, 3)) + with pytest.raises(AssertionError): + t1 == t2 + + def test_inequality_by_name(self): + assert (Tensor("A", (2, 2)) == Tensor("B", (2, 2))) is False + + +# --------------------------------------------------------------------------- +# Memory layout attached to tensor +# --------------------------------------------------------------------------- + + +class TestTensorMemoryLayout: + def test_memory_layout_is_set(self): + t = Tensor("A", (3, 4)) + ml = t.memoryLayout() + assert ml is not None + assert ml.shape() == (3, 4) + + def test_aligned_stride_requires_arch(self, arch): + # With the arch fixture (sets DenseMemoryLayout's global alignment), + # we can request an aligned leading dimension. + t = Tensor("A", (5, 3), alignStride=True) + assert t.memoryLayout().alignedStride() is True + + +# --------------------------------------------------------------------------- +# Scalar +# --------------------------------------------------------------------------- + + +class TestScalar: + def test_basic(self): + s = Scalar("alpha") + assert s.name() == "alpha" + assert s.baseName() == "alpha" + + def test_scalar_group(self): + s = Scalar("alpha(2)") + assert s.group() == (2,) + + def test_invalid_scalar_name(self): + with pytest.raises(ValueError): + Scalar("0alpha") + + +# --------------------------------------------------------------------------- +# Collection +# --------------------------------------------------------------------------- + + +class TestCollection: + def test_set_and_get(self): + c = Collection() + c["foo"] = Tensor("foo", (2, 2)) + assert "foo" in c + assert c["foo"].name() == "foo" + + def test_containsName(self): + c = Collection() + c["A"] = Tensor("A", (2, 2)) + assert c.containsName("A") + assert not c.containsName("B") + + def test_containsName_rejects_invalid_names(self): + c 
= Collection() + with pytest.raises(ValueError): + c.containsName("0bad") + + def test_group_classmethod(self): + # Collection.group returns a single int for 1-tuple groups, a + # tuple otherwise, and () for non-grouped names. This mirrors how + # ``Generator.add`` dispatches ``A(0)`` vs ``A(0,1)``. + assert Collection.group("A") == () + assert Collection.group("A(3)") == 3 + assert Collection.group("A(1,2)") == (1, 2) + + def test_update_merges_collections(self): + a = Collection() + a["A"] = Tensor("A", (2, 2)) + b = Collection() + b["B"] = Tensor("B", (3, 3)) + a.update(b) + assert "A" in a and "B" in a diff --git a/yateto/arch.py b/yateto/arch.py index ff99cbc..2391fa8 100644 --- a/yateto/arch.py +++ b/yateto/arch.py @@ -57,7 +57,7 @@ def __init__(self, name (str): name of the compute (main) architecture e.g., skx, thunderx2t99, power9 sm_60, sm_61, etc., backend (str): backend name e.g., cpp, cuda, hip, oneapi, hipsycl - precision (str): either 'd' or 's' character which stands for 'double' or 'single' precision + precision (str): either 'd' (or 'f64') or 's' (or 'f32') which stands for 'double'/FP64 or 'single'/FP32 precision, respectively alignment (int): length of a cache line in bytes enablePrefetch (bool): indicates whether the compute (main) architecture supports data prefetching @@ -69,11 +69,13 @@ def __init__(self, self.host_name = host_name self.precision = precision.upper() - if self.precision == 'D': + if self.precision in ('D', 'F64'): + self.precision = 'D' self.bytesPerReal = 8 self.typename = 'double' self.epsilon = 2.22e-16 - elif self.precision == 'S': + elif self.precision in ('S', 'F32'): + self.precision = 'S' self.bytesPerReal = 4 self.typename = 'float' self.epsilon = 1.19e-7 diff --git a/yateto/codegen/copyscaleadd/factory.py b/yateto/codegen/copyscaleadd/factory.py index 8c936f6..707020b 100644 --- a/yateto/codegen/copyscaleadd/factory.py +++ b/yateto/codegen/copyscaleadd/factory.py @@ -3,7 +3,7 @@ from ...gemm_configuration import 
tinytc from .tinytc import CopyScaleAddTinytc -import importlib +import importlib.util gf_spec = importlib.util.find_spec('gemmforge') try: if gf_spec: diff --git a/yateto/generator.py b/yateto/generator.py index f12fe59..7270a1b 100644 --- a/yateto/generator.py +++ b/yateto/generator.py @@ -105,6 +105,9 @@ def prepareUntilCodeGen(self, cost_estimator, enableFusedGemm: bool): self.cfg = FindFusedGemms().visit(self.cfg) self.cfg = LivenessAnalysis().visit(self.cfg) + def prefetch(self): + return self._prefetch + class KernelFamily(object): GROUP_INDEX = r'\((0|[1-9]\d*)\)' VALID_NAME = r'^{}({})$'.format(Kernel.BASE_NAME, GROUP_INDEX) diff --git a/yateto/type.py b/yateto/type.py index 874645e..e18e9f7 100644 --- a/yateto/type.py +++ b/yateto/type.py @@ -108,7 +108,7 @@ def __init__(self, if spp is not None: if isinstance(spp, dict): - if not isinstance(next(iter(spp.values())), bool): + if not isinstance(next(iter(spp.values()), False), bool): self._values = spp npspp = zeros(shape, dtype=bool, order=aspp.general.NUMPY_DEFAULT_ORDER) for multiIndex, value in spp.items(): @@ -136,7 +136,7 @@ def setMemoryLayout(self, memoryLayoutClass, alignStride=False): def _setSparsityPattern(self, spp, setOnlyGroupSpp=False): if spp.shape != self._shape: - raise ValueError(name, 'The given Matrix\'s shape must match the shape specification.') + raise ValueError(self._name, 'The given Matrix\'s shape must match the shape specification.') spp = aspp.general(spp) if not isinstance(spp, aspp.ASpp) else spp if setOnlyGroupSpp == False: self._spp = spp