From d5581624fe567c99194173fca61a2dbb02139b6c Mon Sep 17 00:00:00 2001 From: Jack Feser Date: Wed, 6 May 2026 16:21:46 -0400 Subject: [PATCH 1/6] Add monoid module (#653) * add monoid module * clean up * fix doctest * fix * wip * remove incorrect rule * add disjoint set tests and fix bug * lint * drop jax monoid defs * drop incorrect comment * add assert * reduce nondeterminism and add assertions * fix inconsistent stream numbering and missing constant factors --- effectful/internals/disjoint_set.py | 99 +++++ effectful/ops/monoid.py | 556 +++++++++++++++++++++++++++ effectful/ops/syntax.py | 78 ++++ pyproject.toml | 1 + tests/_monoid_helpers.py | 85 ++++ tests/test_internals_disjoint_set.py | 124 ++++++ tests/test_ops_monoid.py | 518 +++++++++++++++++++++++++ 7 files changed, 1461 insertions(+) create mode 100644 effectful/internals/disjoint_set.py create mode 100644 effectful/ops/monoid.py create mode 100644 tests/_monoid_helpers.py create mode 100644 tests/test_internals_disjoint_set.py create mode 100644 tests/test_ops_monoid.py diff --git a/effectful/internals/disjoint_set.py b/effectful/internals/disjoint_set.py new file mode 100644 index 000000000..73b5c5c52 --- /dev/null +++ b/effectful/internals/disjoint_set.py @@ -0,0 +1,99 @@ +class DisjointSet: + """Disjoint Set Union (Union-Find) data structure. + + Maintains a collection of disjoint sets over the integers 0..n-1, + supporting near-constant-time union and find operations via + path compression and union by rank. + + The amortized time complexity per operation is O(α(n)), where α + is the inverse Ackermann function (effectively constant for any + practical n). + + Example: + >>> dsu = DisjointSet(5) + >>> dsu.union(0, 1) + True + >>> dsu.union(1, 2) + True + >>> dsu.find(0) == dsu.find(2) + True + >>> dsu.find(0) == dsu.find(3) + False + """ + + def __init__(self, n): + """Initialize n singleton sets: {0}, {1}, ..., {n-1}. + + Args: + n: The number of elements. Elements are labeled 0..n-1. + """ + self.parent = list(range(n)) + self.rank = [0] * n + + def _validate(self, x): + if x < 0 or x >= len(self.parent): + raise IndexError(f"Element {x} out of bounds") + + def find(self, x): + """Return the representative (root) of the set containing x. + + Two elements belong to the same set if and only if they have + the same representative. Applies path compression: every node + traversed is re-parented directly to its grandparent, flattening + the tree to speed up future queries. + + Args: + x: The element to look up. + + Returns: + The root element of x's set. + """ + self._validate(x) + while self.parent[x] != x: + self.parent[x] = self.parent[self.parent[x]] # path compression + x = self.parent[x] + return x + + def union(self, *elements): + """Merge the sets containing all given elements into one. + + Accepts any number of elements and unions them all together. + Uses union by rank: shallower trees are attached under the root + of the deeper one, keeping the combined tree shallow. + + Args: + *elements: Two or more elements to merge into a single set. + Calling with 0 or 1 elements is a no-op and returns False. + + Returns: + True if any merging occurred (i.e., at least two of the + elements were in different sets); False if all elements + were already in the same set or fewer than 2 were given. + """ + if len(elements) < 2: + return False + + merged = False + first = elements[0] + + for y in elements[1:]: + if self._union_pair(first, y): + merged = True + + return merged + + def _union_pair(self, x, y): + rx = self.find(x) + ry = self.find(y) + + if rx == ry: + return False + + if self.rank[rx] < self.rank[ry]: + rx, ry = ry, rx + + self.parent[ry] = rx + if self.rank[rx] == self.rank[ry]: + self.rank[rx] += 1 + + return True diff --git a/effectful/ops/monoid.py b/effectful/ops/monoid.py new file mode 100644 index 000000000..58a10ba3d --- /dev/null +++ b/effectful/ops/monoid.py @@ -0,0 +1,556 @@ +import collections.abc +import functools +import itertools +import numbers +import typing +from collections import Counter, defaultdict +from collections.abc import Callable, Generator, Iterable, Iterator, Mapping, Sequence +from dataclasses import dataclass +from graphlib import TopologicalSorter +from typing import Annotated, Any + +from effectful.internals.disjoint_set import DisjointSet +from effectful.ops.semantics import coproduct, evaluate, fvsof, fwd, handler +from effectful.ops.syntax import ( + ObjectInterpretation, + Scoped, + _NumberTerm, + defdata, + implements, + iter_, + syntactic_eq, + syntactic_hash, +) +from effectful.ops.types import Interpretation, NotHandled, Operation, Term + +# Note: The streams value type should be something like Iterable[T], but some of +# our target stream types (e.g. jax.Array) are not subtypes of Iterable +type Streams[T] = Mapping[Operation[[], T], Any] + +type Body[T] = ( + Iterable[T] + | Callable[..., Body[T]] + | T + | Mapping[Any, Body[T]] + | Interpretation[T, Body[T]] +) + + +def order_streams[T](streams: Streams[T]) -> Iterable[tuple[Operation[[], T], Any]]: + """Determine an order to evaluate the streams based on their dependencies""" + stream_vars = set(streams.keys()) + dependencies = {k: fvsof(v) & stream_vars for k, v in streams.items()} + topo = TopologicalSorter(dependencies) + topo.prepare() + while topo.is_active(): + node_group = topo.get_ready() + for op in sorted(node_group): + yield (op, streams[op]) + topo.done(*node_group) + + +class Monoid[T]: + kernel: Operation[[T, T], T] + identity: T + + def __init__(self, kernel: Callable[[T, T], T], identity: T): + self.identity = identity + self.kernel = ( + kernel if isinstance(kernel, Operation) else Operation.define(kernel) + ) + + def __repr__(self): + return f"{type(self)}({self.kernel}, {self.identity})" + + @Operation.define + def plus[S: Body[T]](self, *args: S) -> S: + """Monoid addition with broadcasting over common collection types, + callables, and interpretations. + + """ + if not args: + return typing.cast(S, self.identity) + + if any(isinstance(x, Term) for x in args): + return typing.cast(S, defdata(self.plus, *args)) + + return self._plus(*args) + + @functools.singledispatchmethod + def _plus[S](self, *args: S) -> S: + return typing.cast(S, functools.reduce(self.kernel, args, self.identity)) + + @_plus.register(Sequence) + def _(self, *args): + return type(args[0])(self.plus(*vs) for vs in zip(*args, strict=True)) + + @_plus.register(Mapping) + def _(self, *args): + if isinstance(args[0], Interpretation): + keys = args[0].keys() + + for b in args[1:]: + if not isinstance(b, Interpretation): + raise TypeError(f"Expected interpretation but got {b}") + + b_keys = b.keys() + if not keys == b_keys: + raise ValueError( + f"Expected interpretation of {keys} but got {b_keys}" + ) + + result = {k: self.plus(*(handler(b)(b[k]) for b in args)) for k in keys} + return result + + for b in args[1:]: + if not isinstance(b, Mapping): + raise TypeError(f"Expected mapping but got {b}") + + all_values = collections.defaultdict(list) + for d in args: + for k, v in d.items(): + all_values[k].append(v) + result = {k: self.plus(*vs) for (k, vs) in all_values.items()} + return result + + @Operation.define + @functools.singledispatchmethod + def reduce[A, B, U: Body]( + self, + body: Annotated[U, Scoped[A | B]], + streams: Annotated[Streams, Scoped[A]], + ) -> Annotated[U, Scoped[B]]: + if callable(body): + return typing.cast(U, lambda *a, **k: self.reduce(body(*a, **k), streams)) + + def generator(loop_order) -> Iterator[Interpretation]: + if len(loop_order) == 0: + return + + stream_key = loop_order[0][0] + stream_values = evaluate(streams[stream_key]) + stream_values_iter = iter(stream_values) # type: ignore[arg-type] + + # If we try to iterate and get a term instead of a real + # iterator, give up + if isinstance(stream_values_iter, Term) and stream_values_iter.op is iter_: + raise NotHandled + + if len(loop_order) == 1: + for val in stream_values_iter: + yield {stream_key: functools.partial(lambda v: v, val)} + else: + for val in stream_values_iter: + intp: Interpretation = { + stream_key: functools.partial(lambda v: v, val) + } + with handler(intp): + for intp2 in generator(loop_order[1:]): + yield coproduct(intp, intp2) + + loop_order = list(order_streams(streams)) + try: + return self.plus( + *(handler(intp)(evaluate)(body) for intp in generator(loop_order)) + ) + except NotHandled: + return typing.cast(U, defdata(self.reduce, body, streams)) + + @reduce.register # type: ignore[attr-defined] + def _(self, body: Mapping, streams): + return {k: self.reduce(v, streams) for (k, v) in body.items()} + + @reduce.register # type: ignore[attr-defined] + def _(self, body: Sequence, streams): + return type(body)(self.reduce(x, streams) for x in body) # type:ignore[call-arg] + + @reduce.register # type: ignore[attr-defined] + def _(self, body: Generator, streams): + return (self.reduce(x, streams) for x in body) + + +class IdempotentMonoid[T](Monoid[T]): + @Operation.define + def plus[S: Body[T]](self, *args: S) -> S: + return super().plus(*args) + + @Operation.define + def reduce[A, B, U: Body]( + self, + body: Annotated[U, Scoped[A | B]], + streams: Annotated[Streams, Scoped[A]], + ) -> Annotated[U, Scoped[B]]: + return super().reduce(body, streams) + + +class CommutativeMonoid[T](Monoid[T]): + @Operation.define + def plus[S: Body[T]](self, *args: S) -> S: + return super().plus(*args) + + @Operation.define + def reduce[A, B, U: Body]( + self, + body: Annotated[U, Scoped[A | B]], + streams: Annotated[Streams, Scoped[A]], + ) -> Annotated[U, Scoped[B]]: + return super().reduce(body, streams) + + +class CommutativeMonoidWithZero[T](CommutativeMonoid[T]): + zero: T + + def __init__(self, kernel: Callable[[T, T], T], identity: T, zero: T): + super().__init__(kernel, identity) + self.zero = zero + + def __repr__(self): + return f"{type(self)}({self.kernel}, {self.identity}, {self.zero})" + + @Operation.define + def plus[S: Body[T]](self, *args: S) -> S: + return super().plus(*args) + + @Operation.define + def reduce[A, B, U: Body]( + self, + body: Annotated[U, Scoped[A | B]], + streams: Annotated[Streams, Scoped[A]], + ) -> Annotated[U, Scoped[B]]: + return super().reduce(body, streams) + + +class Semilattice[T](IdempotentMonoid[T], CommutativeMonoid[T]): + @Operation.define + def plus[S: Body[T]](self, *args: S) -> S: + return super().plus(*args) + + @Operation.define + def reduce[A, B, U: Body]( + self, + body: Annotated[U, Scoped[A | B]], + streams: Annotated[Streams, Scoped[A]], + ) -> Annotated[U, Scoped[B]]: + return super().reduce(body, streams) + + +@Operation.define +def _arg_min[T]( + a: tuple[numbers.Number, T | None], b: tuple[numbers.Number, T | None] +) -> tuple[numbers.Number, T | None]: + if isinstance(a[0], Term) or isinstance(b[0], Term): + raise NotHandled + return b if b[0] < a[0] else a # type: ignore + + +@Operation.define +def _arg_max[T]( + a: tuple[numbers.Number, T | None], b: tuple[numbers.Number, T | None] +) -> tuple[numbers.Number, T | None]: + if isinstance(a[0], Term) or isinstance(b[0], Term): + raise NotHandled + return b if b[0] > a[0] else a # type: ignore + + +Min = Semilattice(kernel=min, identity=float("inf")) +Max = Semilattice(kernel=max, identity=float("-inf")) +ArgMin = Monoid(kernel=_arg_min, identity=(float("inf"), None)) +ArgMax = Monoid(kernel=_arg_max, identity=(float("-inf"), None)) +Sum = CommutativeMonoid(kernel=_NumberTerm.__add__, identity=0) +Product = CommutativeMonoidWithZero(kernel=_NumberTerm.__mul__, identity=1, zero=0) + + +@dataclass +class _ExtensibleBinaryRelation[S, T]: + tuples: set[tuple[S, T]] + + def register(self, s: S, t: T) -> None: + self.tuples.add((s, t)) + + def __call__(self, s: S, t: T) -> bool: + return (s, t) in self.tuples + + +distributes_over = _ExtensibleBinaryRelation( + { + (Max.plus, Min.plus), + (Min.plus, Max.plus), + (Sum.plus, Min.plus), + (Sum.plus, Max.plus), + (Product.plus, Sum.plus), + } +) + + +class PlusEmpty(ObjectInterpretation): + """plus() = 0""" + + @implements(Monoid.plus) + def plus(self, monoid, *args): + if not args: + return monoid.identity + return fwd() + + +class PlusSingle(ObjectInterpretation): + """plus(x) = x""" + + @implements(Monoid.plus) + def plus(self, _, *args): + if len(args) == 1: + return args[0] + return fwd() + + +class PlusIdentity(ObjectInterpretation): + """x₁ + ... + 0 + ... + xₙ = x₁ + ... + xₙ""" + + @implements(Monoid.plus) + def plus(self, monoid, *args): + if any(x is monoid.identity for x in args): + return monoid.plus(*(x for x in args if x is not monoid.identity)) + return fwd() + + +class PlusAssoc(ObjectInterpretation): + """x + (y + z) = (x + y) + z = x + y + z""" + + @implements(Monoid.plus) + def plus(self, monoid, *args): + if any(isinstance(x, Term) and x.op is monoid.plus for x in args): + flat_args = itertools.chain.from_iterable( + t.args if isinstance(t, Term) and t.op is monoid.plus else (t,) + for t in args + ) + assert len(args) > 0 + return monoid.plus(*flat_args) + return fwd() + + +class PlusDistr(ObjectInterpretation): + """x + (y * z) = x * y + x * z""" + + @implements(Monoid.plus) + def plus(self, monoid, *args): + if any( + isinstance(x, Term) and distributes_over(monoid.plus, x.op) for x in args + ): + non_terms = [] + + # group terms by head operation + by_head_op = defaultdict(list) + for t in args: + if isinstance(t, Term): + by_head_op[t.op].append(t) + else: + non_terms.append(t) + + # distribute over each group + progress = False + final_sum = [] + for op, terms in by_head_op.items(): + if ( + len(terms) > 1 + and distributes_over(monoid.plus, op) + and not distributes_over(op, monoid.plus) + ): + progress = True + term_args = (t.args for t in terms) + dist_terms = ( + monoid.plus(*args) for args in itertools.product(*term_args) + ) + final_sum.append(op(*dist_terms)) + else: + final_sum += terms + if progress: + return monoid.plus(*non_terms, *final_sum) + return fwd() + + +class PlusZero(ObjectInterpretation): + """x₁ * ... * 0 * ... * xₙ = 0""" + + @implements(CommutativeMonoidWithZero.plus) + def plus(self, monoid, *args): + if any(x is monoid.zero for x in args): + return monoid.zero + return fwd() + + +class PlusConsecutiveDups(ObjectInterpretation): + """x ⊕ x ⊕ y = x ⊕ y""" + + @implements(IdempotentMonoid.plus) + def plus(self, monoid, *args): + dedup_args = ( + args[i] + for i in range(len(args)) + if i == 0 or not syntactic_eq(args[i - 1], args[i]) + ) + return fwd(monoid, *dedup_args) + + +class PlusDups(ObjectInterpretation): + """x ⊕ y ⊕ x = x ⊕ y""" + + @dataclass + class _HashableTerm: + term: Term + + def __eq__(self, other): + return syntactic_eq(self, other) + + def __hash__(self): + return syntactic_hash(self) + + @implements(Semilattice.plus) + def plus(self, monoid, *args): + # elim dups + args_count = Counter(self._HashableTerm(t) for t in args) + if len(args_count) < len(args): + dedup_args = [] + for t in args: + ht = self._HashableTerm(t) + if ht in args_count: + dedup_args.append(t) + del args_count[ht] + return fwd(monoid, *dedup_args) + return fwd() + + +NormalizePlusIntp = functools.reduce( + coproduct, + typing.cast( + list[Interpretation], + [ + PlusEmpty(), + PlusSingle(), + PlusIdentity(), + PlusAssoc(), + PlusDistr(), + PlusZero(), + PlusConsecutiveDups(), + PlusDups(), + ], + ), +) + + +class ReduceNoStreams(ObjectInterpretation): + """Implements the identity + reduce(R, ∅, body) = 0 + """ + + @implements(Monoid.reduce) + def reduce(self, monoid, _, streams): + if len(streams) == 0: + return monoid.identity + return fwd() + + +class ReduceFusion(ObjectInterpretation): + """Implements the identity + reduce(R, S1, reduce(R, S2, body)) = reduce(R, S1 ∪ S2, body) + """ + + @implements(Monoid.reduce) + def reduce(self, monoid, body, streams): + if isinstance(body, Term) and body.op == monoid.reduce: + return monoid.reduce(body.args[0], streams | body.args[1]) + return fwd() + + +class ReduceSplit(ObjectInterpretation): + """Implements the identity + reduce(R, S, b1 + ... + bn) = reduce(R, S, b1) + ... + reduce(R, S, bn) + """ + + @implements(CommutativeMonoid.reduce) + def reduce(self, monoid, body, streams): + if isinstance(body, Term) and body.op == monoid.plus: + return monoid.plus(*(monoid.reduce(x, streams) for x in body.args)) + return fwd() + + +class ReduceFactorization(ObjectInterpretation): + """ + Implements factorization of independent terms. + For example, when having two independent distributions, + we can rewrite their marginalization as: + ∫p(x)⋅q(y)dxdy => ∫p(x)dx ⋅ ∫q(y)dy + + More specifically, in terms of reduces we are performing: + reduce(R, (S₁ × ... × Sₖ) , A₁ * ... * Aₖ) + => reduce(R, S₁, A₁) * ... * reduce(R, Sₖ, Aₖ) + where free(Aᵢ) ∩ free(Aⱼ) ∩ S = ∅ + and free(Aᵢ) ∩ S ⊆ Sᵢ + """ + + @implements(CommutativeMonoid.reduce) + def reduce(self, monoid, body, streams): + if isinstance(body, Term) and distributes_over(body.op, monoid.plus): + stream_vars = set(streams.keys()) + factors = [(arg, fvsof(arg)) for arg in body.args] + stream_ids = {v: i for (i, v) in enumerate(stream_vars)} + ds = DisjointSet(len(streams)) + + # streams are in the same partition as their dependencies + for stream_var, stream_id in stream_ids.items(): + stream_body = streams[stream_var] + deps = sorted([stream_ids[v] for v in fvsof(stream_body) & stream_vars]) + ds.union(stream_id, *deps) + + # factors are in the same partition as their dependencies + for factor, factor_fvs in factors: + factor_streams = sorted( + [stream_ids[v] for v in (factor_fvs & stream_vars)] + ) + ds.union(*factor_streams) + + placed_streams = set() + new_reduces = [] + for stream_key in streams: + if stream_key in placed_streams: + continue + + partition = ds.find(stream_ids[stream_key]) + partition_streams = { + k: v + for (k, v) in streams.items() + if ds.find(stream_ids[k]) == partition + } + partition_stream_keys = set(partition_streams.keys()) + + partition_factors = [ + t for t in factors if (t[1] & partition_stream_keys) + ] + + assert all( + (t[1] & stream_vars) <= partition_stream_keys + for t in partition_factors + ), "partition contains all streams required by factor" + + partition_term = body.op(*(t[0] for t in partition_factors)) + new_reduces.append((partition_term, partition_streams)) + placed_streams |= partition_stream_keys + + constant_factors = [t for (t, fvs) in factors if not (fvs & stream_vars)] + + if len(new_reduces) > 1: + result = body.op( + *constant_factors, *(monoid.reduce(*args) for args in new_reduces) + ) + return result + + return fwd() + + +NormalizeReduceIntp = functools.reduce( + coproduct, + typing.cast( + list[Interpretation], + [ReduceNoStreams(), ReduceFusion(), ReduceSplit(), ReduceFactorization()], + ), +) + +NormalizeIntp = coproduct(NormalizePlusIntp, NormalizeReduceIntp) diff --git a/effectful/ops/syntax.py b/effectful/ops/syntax.py index 764016752..8fb12598f 100644 --- a/effectful/ops/syntax.py +++ b/effectful/ops/syntax.py @@ -852,6 +852,84 @@ def _(x: object, other) -> bool: return x == other +@_CustomSingleDispatchCallable +def syntactic_hash(__dispatch: Callable[[type], Callable[[Any], int]], x) -> int: + """Structural hash compatible with :func:`syntactic_eq`. + + Guarantees that ``syntactic_eq(x, y)`` implies + ``syntactic_hash(x) == syntactic_hash(y)``. + + :param x: A term. + :returns: An integer hash. + """ + if dataclasses.is_dataclass(x) and not isinstance(x, type): + return hash( + ( + "dataclass", + type(x), + syntactic_hash( + { + field.name: getattr(x, field.name) + for field in dataclasses.fields(x) + } + ), + ) + ) + else: + return __dispatch(type(x))(x) + + +@syntactic_hash.register +def _(x: Term) -> int: + return hash( + ( + "term", + x.op, + len(x.args), + tuple(syntactic_hash(a) for a in x.args), + # sort kwargs so order doesn't affect the hash + tuple((k, syntactic_hash(x.kwargs[k])) for k in sorted(x.kwargs)), + ) + ) + + +@syntactic_hash.register +def _(x: collections.abc.Mapping) -> int: + # XOR over (key_hash, value_hash) pairs — order-independent, + # matching the set-based comparison in syntactic_eq's Mapping branch. + acc = 0 + for k in x: + acc ^= hash((hash(k), syntactic_hash(x[k]))) + return hash(("mapping", acc)) + + +@syntactic_hash.register +def _(x: collections.abc.Sequence) -> int: + if ( + isinstance(x, tuple) + and hasattr(x, "_fields") + and all(hasattr(x, f) for f in x._fields) + ): + return hash( + ( + "namedtuple", + type(x), + tuple(syntactic_hash(getattr(x, f)) for f in x._fields), + ) + ) + else: + # Use the abstract Sequence tag (not type(x)) because syntactic_eq + # treats any two Sequences of equal length and elementwise-equal + # contents as equal — e.g. [1,2] and (1,2) compare equal. + return hash(("sequence", len(x), tuple(syntactic_hash(a) for a in x))) + + +@syntactic_hash.register(object) +@syntactic_hash.register(str | bytes) +def _(x: object) -> int: + return hash(x) + + class ObjectInterpretation[T, V](collections.abc.Mapping): """A helper superclass for defining an ``Interpretation`` of many :class:`~effectful.ops.types.Operation` instances with shared state or behavior. diff --git a/pyproject.toml b/pyproject.toml index d565403f2..685aaf55f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -71,6 +71,7 @@ test = [ "pytest-cov", "pytest-xdist", "pytest-benchmark", + "hypothesis", "mypy", "ruff", "nbval", diff --git a/tests/_monoid_helpers.py b/tests/_monoid_helpers.py new file mode 100644 index 000000000..4532ae72d --- /dev/null +++ b/tests/_monoid_helpers.py @@ -0,0 +1,85 @@ +from collections.abc import Callable, Mapping, Sequence +from typing import Any, get_args, get_origin + +from hypothesis import strategies as st + +from effectful.ops.syntax import deffn +from effectful.ops.types import Operation + + +def _value_strategy_for(annotation: Any) -> st.SearchStrategy[Any]: + """Strategy for the value an *0-arg* Operation should return.""" + if annotation is int: + return st.integers() + if annotation is float: + return st.floats(allow_nan=False) + if get_origin(annotation) is list and get_args(annotation) == (int,): + return st.lists(st.integers()) + raise NotImplementedError( + f"No value strategy for return annotation {annotation!r}; " + "supported: int, list[int]" + ) + + +_UNARY_INT_FNS: list[Callable[[int], int]] = [ + lambda x: x, + lambda x: x + 1, + lambda x: x - 1, + lambda x: -x, + lambda x: 2 * x, + lambda x: 3 * x + 1, +] + +_BINARY_INT_FNS: list[Callable[[int, int], int]] = [ + lambda x, y: x + y, + lambda x, y: x - y, + lambda x, y: x * y, + lambda x, y: x + 2 * y, + lambda x, y: 2 * x - y, +] + +_UNARY_LIST_FNS: list[Callable[[int], list[int]]] = [ + lambda _x: [], + lambda x: [x], + lambda x: [x, x + 1], + lambda x: [x, -x], + lambda x: [0, x, x + 1], +] + + +def _strategy_for_op(op: Operation) -> st.SearchStrategy[Callable[..., Any]]: + """Pick a strategy producing a callable suitable for binding `op` in an + interpretation. Inspects the operation's signature. + """ + sig = op.__signature__ + params = list(sig.parameters.values()) + ret = sig.return_annotation + param_types = tuple(p.annotation for p in params) + + if not params: + return _value_strategy_for(ret).map(deffn) + if ret is int and param_types == (int,): + return st.sampled_from(_UNARY_INT_FNS) + if ret is int and param_types == (int, int): + return st.sampled_from(_BINARY_INT_FNS) + if get_origin(ret) is list and get_args(ret) == (int,) and param_types == (int,): + return st.sampled_from(_UNARY_LIST_FNS) + raise NotImplementedError( + f"Function-typed free var must return int or list[int]; got {ret!r} for {op}" + ) + + +@st.composite +def random_interpretation( + draw: st.DrawFn, free_vars: Sequence[Operation] +) -> Mapping[Operation, Callable[..., Any]]: + """Draw an Interpretation binding every Operation in `case.free_vars` to + a randomly chosen value/callable. Keys are Operation identities. + """ + intp: dict[Operation, Callable[..., Any]] = {} + for op in free_vars: + intp[op] = draw(_strategy_for_op(op)) + return intp + + +__all__ = ["random_interpretation"] diff --git a/tests/test_internals_disjoint_set.py b/tests/test_internals_disjoint_set.py new file mode 100644 index 000000000..808b8d25d --- /dev/null +++ b/tests/test_internals_disjoint_set.py @@ -0,0 +1,124 @@ +import random + +import pytest + +from effectful.internals.disjoint_set import DisjointSet + + +@pytest.fixture +def dsu(): + return DisjointSet(10) + + +def test_initial_state(dsu): + for i in range(10): + assert dsu.find(i) == i + + +def test_simple_union(dsu): + assert dsu.union(1, 2) is True + assert dsu.find(1) == dsu.find(2) + + +def test_union_idempotent(dsu): + dsu.union(1, 2) + assert dsu.union(1, 2) is False + + +def test_union_chain(dsu): + dsu.union(1, 2) + dsu.union(2, 3) + assert dsu.find(1) == dsu.find(3) + + +def test_union_multiple_elements_all_connected(dsu): + dsu.union(1, 2, 3, 4, 5) + roots = {dsu.find(i) for i in [1, 2, 3, 4, 5]} + assert len(roots) == 1 + + +def test_union_multiple_elements_partial_overlap(dsu): + dsu.union(1, 2) + dsu.union(3, 4) + dsu.union(2, 3, 5) + + roots = {dsu.find(i) for i in [1, 2, 3, 4, 5]} + assert len(roots) == 1 + + +def test_union_multiple_elements_with_existing_connections(dsu): + dsu.union(1, 2) + dsu.union(2, 3) + dsu.union(3, 4, 5, 6) + + roots = {dsu.find(i) for i in [1, 2, 3, 4, 5, 6]} + assert len(roots) == 1 + + +def test_union_single_element(dsu): + assert dsu.union(1) is False + + +def test_union_no_elements(dsu): + assert dsu.union() is False + + +def test_union_self(dsu): + assert dsu.union(3, 3) is False + assert dsu.find(3) == 3 + + +def test_transitivity(dsu): + dsu.union(1, 2) + dsu.union(2, 3) + dsu.union(3, 4) + assert dsu.find(1) == dsu.find(4) + + +def test_disjoint_sets_remain_separate(dsu): + dsu.union(1, 2) + dsu.union(3, 4) + assert dsu.find(1) != dsu.find(3) + + +def test_randomized_unions(): + n = 50 + dsu = DisjointSet(n) + + groups = [{i} for i in range(n)] + + def find_group(x): + for g in groups: + if x in g: + return g + + for _ in range(100): + elems = random.sample(range(n), random.randint(2, 5)) + dsu.union(*elems) + + # merge ground-truth groups + merged = set() + for e in elems: + merged |= find_group(e) + + groups = [g for g in groups if g.isdisjoint(merged)] + groups.append(merged) + + # verify structure matches ground truth + for g in groups: + roots = {dsu.find(x) for x in g} + assert len(roots) == 1 + + +def test_path_compression_effect(): + dsu = DisjointSet(6) + dsu.union(0, 1) + dsu.union(1, 2) + dsu.union(2, 3) + dsu.union(3, 4) + + # Trigger compression + root_before = dsu.find(4) + root_after = dsu.find(4) + + assert root_before == root_after diff --git a/tests/test_ops_monoid.py b/tests/test_ops_monoid.py new file mode 100644 index 000000000..a22928cca --- /dev/null +++ b/tests/test_ops_monoid.py @@ -0,0 +1,518 @@ +import functools +import itertools + +import pytest +from hypothesis import given, settings +from hypothesis import strategies as st + +from effectful.internals.runtime import interpreter +from effectful.ops.monoid import Max, Min, NormalizeIntp, Product, Semilattice, Sum +from effectful.ops.semantics import apply, evaluate, fvsof, handler +from effectful.ops.syntax import _BaseTerm, defdata, syntactic_eq +from effectful.ops.types import NotHandled, Operation +from tests._monoid_helpers import random_interpretation + +_INT = st.integers(min_value=-100, max_value=100) + +ALL_MONOIDS = [ + pytest.param(Sum, id="Sum"), + pytest.param(Product, id="Product"), + pytest.param(Min, id="Min"), + pytest.param(Max, id="Max"), +] + +COMMUTATIVE = [ + pytest.param(Sum, id="Sum"), + pytest.param(Product, id="Product"), + pytest.param(Min, id="Min"), + pytest.param(Max, id="Max"), +] + +IDEMPOTENT = [ + pytest.param(Min, id="Min"), + pytest.param(Max, id="Max"), +] + +WITH_ZERO = [ + pytest.param(Product, id="Product"), +] + + +def define_vars(*names, typ=int): + if len(names) == 1: + return Operation.define(typ, name=names[0]) + return tuple(Operation.define(typ, name=n) for n in names) + + +@functools.cache +def _canonical_op(idx: int) -> Operation: + """Globally cached canonical Operation, keyed by encounter index. + + Cached so that two independent canonicalize runs return the same + Operation object for the same index — letting ``syntactic_eq`` + compare canonical forms by Operation identity. + """ + return Operation.define(int, name=f"__cv_{idx}") + + +def syntactic_eq_alpha(x, y) -> bool: + """Alpha-equivalence-respecting variant of ``syntactic_eq``. + + Walks each expression bottom-up with :func:`evaluate` and renames + every bound variable to a deterministic canonical Operation. The + canonical names are assigned by a counter that increments in + ``evaluate``'s natural traversal order, so two alpha-equivalent + expressions canonicalize to syntactically identical results. + """ + return syntactic_eq(_canonicalize(x), _canonicalize(y)) + + +def _canonicalize(expr): + counter = itertools.count() + + def _passthrough(op, *args, **kwargs): + return defdata(op, *args, **kwargs) + + def _substitute(arg, renaming): + """Apply a bound-variable renaming using ``evaluate`` for traversal.""" + if not renaming: + return arg + with interpreter({apply: _passthrough, **renaming}): + return evaluate(arg) + + def _bound_var_order(args, kwargs, bound_set): + """Return bound variables in deterministic encounter order.""" + seen: list[Operation] = [] + seen_set: set[Operation] = set() + + def _capture(op, *a, **kw): + if op in bound_set and op not in seen_set: + seen.append(op) + seen_set.add(op) + return defdata(op, *a, **kw) + + # ``evaluate`` walks Terms, lists, tuples, mappings, dataclasses, + # etc. for free; the apply handler captures bound vars used as + # ``x()`` anywhere in the body. + with interpreter({apply: _capture}): + evaluate((args, kwargs)) + + # Binders bypass the apply handler. Pick them up with a small structural + # walk that visits dict keys too. + def _walk_bare(obj): + if isinstance(obj, Operation): + if obj in bound_set and obj not in seen_set: + seen.append(obj) + seen_set.add(obj) + elif isinstance(obj, dict): + for k, v in obj.items(): + _walk_bare(k) + _walk_bare(v) + elif isinstance(obj, list | set | frozenset | tuple): + for v in obj: + _walk_bare(v) + + _walk_bare((args, kwargs)) + return seen + + def _apply_canonical(op, *args, **kwargs): + bindings = op.__fvs_rule__(*args, **kwargs) + all_bound: set[Operation] = set().union( + *bindings.args, *bindings.kwargs.values() + ) + if not all_bound: + return defdata(op, *args, **kwargs) + + order = _bound_var_order(args, kwargs, all_bound) + canonical = {var: _canonical_op(next(counter)) for var in order} + assert all_bound <= set(order) + + new_args = tuple( + _substitute( + arg, {v: canonical[v] for v in bindings.args[i] if v in canonical} + ) + for i, arg in enumerate(args) + ) + new_kwargs = { + k: _substitute( + v, + {var: canonical[var] for var in bindings.kwargs[k] if var in canonical}, + ) + for k, v in kwargs.items() + } + + # avoid the renaming from defdata + return _BaseTerm(op, *new_args, **new_kwargs) + + with interpreter({apply: _apply_canonical}): + return evaluate(expr) + + +@pytest.mark.parametrize("monoid", ALL_MONOIDS) +@given(a=_INT, b=_INT, c=_INT) +@settings(max_examples=50, deadline=None) +def test_associativity(monoid, a, b, c): + left = monoid.plus(monoid.plus(a, b), c) + right = monoid.plus(a, monoid.plus(b, c)) + assert left == right + + +@pytest.mark.parametrize("monoid", ALL_MONOIDS) +@given(a=_INT) +@settings(max_examples=50, deadline=None) +def test_identity(monoid, a): + assert monoid.plus(monoid.identity, a) == a + assert monoid.plus(a, monoid.identity) == a + + +@pytest.mark.parametrize("monoid", COMMUTATIVE) +@given(a=_INT, b=_INT) +@settings(max_examples=50, deadline=None) +def test_commutativity(monoid, a, b): + assert monoid.plus(a, b) == monoid.plus(b, a) + + +@pytest.mark.parametrize("monoid", IDEMPOTENT) +@given(a=_INT) +@settings(max_examples=50, deadline=None) +def test_idempotence(monoid, a): + assert monoid.plus(a, a) == a + + +@pytest.mark.parametrize("monoid", WITH_ZERO) +@given(a=_INT) +@settings(max_examples=50, deadline=None) +def test_zero_absorbs(monoid, a): + assert monoid.plus(monoid.zero, a) == monoid.zero + assert monoid.plus(a, monoid.zero) == monoid.zero + + +def _check_pair(lhs, rhs, *, free_vars=[], max_examples: int = 25) -> None: + """Run structural + semantic checks on a TermPair.""" + with handler(NormalizeIntp): + norm = evaluate(lhs) + + assert syntactic_eq_alpha(norm, rhs) + + @given(intp=random_interpretation(free_vars)) + @settings(max_examples=max_examples, deadline=None) + def _check_semantics(intp): + with handler(intp): + lhs_val = evaluate(lhs) + rhs_val = evaluate(rhs) + assert lhs_val == rhs_val + + _check_semantics() + + +@pytest.mark.parametrize("monoid", ALL_MONOIDS) +def test_plus_empty(monoid): + _check_pair(lhs=monoid.plus(), rhs=monoid.identity) + + +@pytest.mark.parametrize("monoid", ALL_MONOIDS) +def test_plus_single(monoid): + x = define_vars("x", typ=type(monoid.identity)) + _check_pair(lhs=monoid.plus(x()), rhs=x(), free_vars=[x]) + + +@pytest.mark.parametrize("monoid", ALL_MONOIDS) +def test_plus_identity_right(monoid): + x = define_vars("x", typ=type(monoid.identity)) + _check_pair(lhs=monoid.plus(x(), monoid.identity), rhs=x(), free_vars=[x]) + + +@pytest.mark.parametrize("monoid", ALL_MONOIDS) +def test_plus_identity_left(monoid): + x = define_vars("x", typ=type(monoid.identity)) + _check_pair(lhs=monoid.plus(monoid.identity, x()), rhs=x(), free_vars=[x]) + + +@pytest.mark.parametrize("monoid", ALL_MONOIDS) +def test_plus_assoc_right(monoid): + x, y, z = define_vars("x", "y", "z", typ=type(monoid.identity)) + _check_pair( + lhs=monoid.plus(x(), monoid.plus(y(), z())), + rhs=monoid.plus(x(), y(), z()), + free_vars=[x, y, z], + ) + + +@pytest.mark.parametrize("monoid", ALL_MONOIDS) +def test_plus_assoc_left(monoid): + x, y, z = define_vars("x", "y", "z", typ=type(monoid.identity)) + _check_pair( + lhs=monoid.plus(monoid.plus(x(), y()), z()), + rhs=monoid.plus(x(), y(), z()), + free_vars=[x, y, z], + ) + + +@pytest.mark.parametrize("monoid", ALL_MONOIDS) +def test_plus_sequence(monoid): + a, b, c, d = define_vars("a", "b", "c", "d", typ=type(monoid.identity)) + _check_pair( + lhs=monoid.plus([a(), b()], [c(), d()]), + rhs=[monoid.plus(a(), c()), monoid.plus(b(), d())], + free_vars=[a, b, c, d], + ) + + +@pytest.mark.parametrize("monoid", ALL_MONOIDS) +def test_plus_mapping(monoid): + a, b, c, d = define_vars("a", "b", "c", "d", typ=type(monoid.identity)) + _check_pair( + lhs=monoid.plus({"x": a(), "y": b()}, {"x": c(), "z": d()}), + rhs={"x": monoid.plus(a(), c()), "y": b(), "z": d()}, + free_vars=[a, b, c, d], + ) + + +def test_plus_distributes(): + a, b, c, d = define_vars("a", "b", "c", "d") + lhs = Product.plus(Sum.plus(a(), b()), Sum.plus(c(), d())) + rhs = Sum.plus( + Product.plus(a(), c()), + Product.plus(a(), d()), + Product.plus(b(), c()), + Product.plus(b(), d()), + ) + _check_pair(lhs=lhs, rhs=rhs, free_vars=[a, b, c, d]) + + +def test_plus_distributes_constant(): + a, b, c, d = define_vars("a", "b", "c", "d") + lhs = Product.plus(Sum.plus(a(), b()), Sum.plus(c(), d()), 5) + rhs = Product.plus( + 5, + Sum.plus( + Product.plus(a(), c()), + Product.plus(a(), d()), + Product.plus(b(), c()), + Product.plus(b(), d()), + ), + ) + _check_pair(lhs=lhs, rhs=rhs, free_vars=[a, b, c, d]) + + +def test_plus_distributes_multiple(): + a, b, c, d = define_vars("a", "b", "c", "d") + lhs = Sum.plus( + Min.plus(a(), b()), + Min.plus(c(), d()), + Max.plus(a(), b()), + Max.plus(c(), d()), + ) + rhs = Sum.plus( + Min.plus( + Sum.plus(a(), c()), + Sum.plus(a(), d()), + Sum.plus(b(), c()), + Sum.plus(b(), d()), + ), + Max.plus( + Sum.plus(a(), c()), + Sum.plus(a(), d()), + Sum.plus(b(), c()), + Sum.plus(b(), d()), + ), + ) + _check_pair(lhs=lhs, rhs=rhs, free_vars=[a, b, c, d]) + + +@pytest.mark.parametrize("monoid", IDEMPOTENT) +def test_plus_idempotent_consecutive(monoid): + """``a, a, b → a, b`` — only consecutive duplicates collapse.""" + a, b = define_vars("a", "b") + lhs = monoid.plus(a(), a(), b()) + return _check_pair(lhs=lhs, rhs=monoid.plus(a(), b()), free_vars=[a, b]) + + +@pytest.mark.parametrize("monoid", IDEMPOTENT) +def test_plus_idempotent_non_consecutive(monoid): + """``a, b, a`` — Semilattice (Min/Max) collapses via commutative + PlusDups; plain IdempotentMonoid leaves it as-is (consecutive-only).""" + a, b = define_vars("a", "b") + lhs = monoid.plus(a(), b(), a()) + if isinstance(monoid, Semilattice): + rhs = monoid.plus(a(), b()) + else: + rhs = monoid.plus(a(), b(), a()) + _check_pair(lhs=lhs, rhs=rhs, free_vars=[a, b]) + + +def test_plus_commutative_idempotent_long(): + """Long alternation collapses via commutative dedup (Min/Max only).""" + a, b = define_vars("a", "b") + lhs = Min.plus(a(), b(), a(), b(), b(), a(), a()) + _check_pair(lhs=lhs, rhs=Min.plus(a(), b()), free_vars=[a, b]) + + +@pytest.mark.parametrize("monoid", WITH_ZERO) +def test_plus_zero(monoid): + a = define_vars("a") + lhs_right = monoid.plus(a(), monoid.zero) + lhs_left = monoid.plus(monoid.zero, a()) + _check_pair(lhs=lhs_right, rhs=monoid.zero, free_vars=[a]) + _check_pair(lhs=lhs_left, rhs=monoid.zero, free_vars=[a]) + + +@pytest.mark.parametrize("monoid", ALL_MONOIDS) +def test_reduce_body_sequence(monoid): + x = Operation.define(int, name="x") + X = Operation.define(list[int], name="X") + + @Operation.define + def f(_x: int) -> int: + raise NotHandled + + g = Operation.define(f, name="g") + + lhs = monoid.reduce([f(x()), g(x())], {x: X()}) + rhs = [monoid.reduce(f(x()), {x: X()}), monoid.reduce(g(x()), {x: X()})] + + _check_pair(lhs=lhs, rhs=rhs, free_vars=[X, f, g]) + + +@pytest.mark.parametrize("monoid", ALL_MONOIDS) +def test_reduce_body_sequence_2(monoid): + x, y = define_vars("x", "y") + X, Y = define_vars("X", "Y", typ=list[int]) + + @Operation.define + def f(_x: int) -> int: + raise NotHandled + + g = Operation.define(f, name="g") + + lhs = monoid.reduce([f(x()), g(y())], {x: X(), y: Y()}) + rhs = [ + monoid.reduce(f(x()), {x: X(), y: Y()}), + monoid.reduce(g(y()), {x: X(), y: Y()}), + ] + + _check_pair(lhs=lhs, rhs=rhs, free_vars=[X, Y, f, g]) + + +@pytest.mark.parametrize("monoid", ALL_MONOIDS) +def test_reduce_body_mapping(monoid): + x = Operation.define(int, name="x") + X = Operation.define(list[int], name="X") + + @Operation.define + def f(_x: int) -> int: + raise NotHandled + + g = Operation.define(f, name="g") + + lhs = monoid.reduce({"a": f(x()), "b": g(x())}, {x: X()}) + rhs = { + "a": monoid.reduce(f(x()), {x: X()}), + "b": monoid.reduce(g(x()), {x: X()}), + } + _check_pair(lhs=lhs, rhs=rhs, free_vars=[X, f, g]) + + +@pytest.mark.parametrize("monoid", ALL_MONOIDS) +def test_reduce_no_streams(monoid): + a = define_vars("a") + lhs = monoid.reduce(a(), {}) + rhs = monoid.identity + + _check_pair(lhs=lhs, rhs=rhs, free_vars=[a]) + + +@pytest.mark.parametrize("monoid", ALL_MONOIDS) +def test_reduce_reduce(monoid): + a, b = define_vars("a", "b") + A, B = define_vars("A", "B", typ=list[int]) + + @Operation.define + def f(_x: int, _y: int) -> int: + raise NotHandled + + lhs = monoid.reduce(monoid.reduce(f(a(), b()), {a: A()}), {b: B()}) + rhs = monoid.reduce(f(a(), b()), {a: A(), b: B()}) + + _check_pair(lhs=lhs, rhs=rhs, free_vars=[A, B, f]) + + +@pytest.mark.parametrize("monoid", COMMUTATIVE) +def test_reduce_plus(monoid): + a, b = define_vars("a", "b") + A, B = define_vars("A", "B", typ=list[int]) + lhs = monoid.reduce(monoid.plus(a(), b()), {a: A(), b: B()}) + rhs = monoid.plus( + monoid.reduce(a(), {a: A(), b: B()}), + monoid.reduce(b(), {a: A(), b: B()}), + ) + _check_pair(lhs=lhs, rhs=rhs, free_vars=[A, B]) + + +def test_reduce_independent_1(): + a, b = define_vars("a", "b") + A, B = define_vars("A", "B", typ=list[int]) + lhs = Sum.reduce(Product.plus(a(), b()), {a: A(), b: B()}) + rhs = Product.plus(Sum.reduce(a(), {a: A()}), Sum.reduce(b(), {b: B()})) + _check_pair(lhs=lhs, rhs=rhs, free_vars=[A, B]) + + +def test_reduce_independent_2(): + a, b, c = define_vars("a", "b", "c") + A, B, C = define_vars("A", "B", "C", typ=list[int]) + + @Operation.define + def f(_x: int, _y: int) -> int: + raise NotHandled + + lhs = Sum.reduce(Product.plus(a(), b(), f(b(), c())), {a: A(), b: B(), c: C()}) + rhs = Product.plus( + Sum.reduce(a(), {a: A()}), + Sum.reduce(Product.plus(b(), f(b(), c())), {b: B(), c: C()}), + ) + _check_pair(lhs=lhs, rhs=rhs, free_vars=[A, B, C, f]) + + +def test_reduce_independent_3_negative(): + """Stream `b` depends on `a` (b: g(a())), so the proposed factorization + is unsound — the normalizer must NOT apply it.""" + a, b, c = define_vars("a", "b", "c") + A, C = define_vars("A", "C", typ=list[int]) + + @Operation.define + def f(_x: int, _y: int) -> int: + raise NotHandled + + @Operation.define + def g(_x: int) -> list[int]: + raise NotHandled + + with handler(NormalizeIntp): + lhs = Sum.reduce( + Product.plus(a(), b(), f(b(), c())), {a: A(), b: g(a()), c: C()} + ) + bogus_rhs = Product.plus( + Sum.reduce(a(), {a: A()}), + Sum.reduce(Product.plus(b(), f(b(), c())), {b: g(a()), c: C()}), + ) + assert fvsof(bogus_rhs) != fvsof(lhs) + # Structural-only negative check: the normalizer correctly refused to apply + # the bogus factorization. + assert not syntactic_eq_alpha(lhs, bogus_rhs) + + +def test_reduce_independent_4(): + a, b, c = define_vars("a", "b", "c") + A, B, C = define_vars("A", "B", "C", typ=list[int]) + + @Operation.define + def f(_x: int, _y: int) -> int: + raise NotHandled + + lhs = Sum.reduce(Product.plus(a(), b(), f(b(), c()), 7), {a: A(), b: B(), c: C()}) + rhs = Product.plus( + 7, + Sum.reduce(a(), {a: A()}), + Sum.reduce(Product.plus(b(), f(b(), c())), {b: B(), c: C()}), + ) + _check_pair(lhs=lhs, rhs=rhs, free_vars=[A, B, C, f]) From 698716431d83cdc6dc10bed589a21503b0a97da5 Mon Sep 17 00:00:00 2001 From: Jack Feser Date: Tue, 12 May 2026 14:57:58 -0400 Subject: [PATCH 2/6] Add inversion from `weighted` (#655) * Add monoid module (#653) * add monoid module * clean up * fix doctest * fix * wip * remove incorrect rule * add disjoint set tests and fix bug * lint * drop jax monoid defs * drop incorrect comment * add assert * reduce nondeterminism and add assertions * fix inconsistent stream numbering and missing constant factors * wip * cleanup * fix rule * wip * fix bug * cleanup * lin --- effectful/internals/product_n.py | 2 +- effectful/ops/monoid.py | 162 ++++++++++++++++++++++++++-- effectful/ops/semantics.py | 1 + effectful/ops/types.py | 5 +- tests/_monoid_helpers.py | 14 +-- tests/test_handlers_llm_provider.py | 2 +- tests/test_ops_monoid.py | 146 ++++++++++++++++++++++--- tests/test_ops_syntax.py | 1 - 8 files changed, 300 insertions(+), 33 deletions(-) diff --git a/effectful/internals/product_n.py b/effectful/internals/product_n.py index 4b8bd2a81..87a9c6a42 100644 --- a/effectful/internals/product_n.py +++ b/effectful/internals/product_n.py @@ -69,7 +69,7 @@ def map_structure(func, expr): else: return type(expr)(map_structure(func, tuple(expr.items()))) elif isinstance(expr, collections.abc.Sequence): - if isinstance(expr, str | bytes): + if isinstance(expr, str | bytes | range): return expr elif ( isinstance(expr, tuple) diff --git a/effectful/ops/monoid.py b/effectful/ops/monoid.py index 58a10ba3d..ad83de47b 100644 --- a/effectful/ops/monoid.py +++ b/effectful/ops/monoid.py @@ -4,24 +4,26 @@ import numbers import typing from collections import Counter, defaultdict -from collections.abc import Callable, Generator, Iterable, Iterator, Mapping, Sequence +from collections.abc import Callable, Generator, Iterable, Iterator, Mapping from dataclasses import dataclass from graphlib import TopologicalSorter from typing import Annotated, Any from effectful.internals.disjoint_set import DisjointSet +from effectful.internals.runtime import interpreter from effectful.ops.semantics import coproduct, evaluate, fvsof, fwd, handler from effectful.ops.syntax import ( ObjectInterpretation, Scoped, _NumberTerm, defdata, + deffn, implements, iter_, syntactic_eq, syntactic_hash, ) -from effectful.ops.types import Interpretation, NotHandled, Operation, Term +from effectful.ops.types import Expr, Interpretation, NotHandled, Operation, Term # Note: The streams value type should be something like Iterable[T], but some of # our target stream types (e.g. jax.Array) are not subtypes of Iterable @@ -80,9 +82,13 @@ def plus[S: Body[T]](self, *args: S) -> S: def _plus[S](self, *args: S) -> S: return typing.cast(S, functools.reduce(self.kernel, args, self.identity)) - @_plus.register(Sequence) + @_plus.register(tuple) def _(self, *args): - return type(args[0])(self.plus(*vs) for vs in zip(*args, strict=True)) + return tuple(self.plus(*vs) for vs in zip(*args, strict=True)) + + @_plus.register(Generator) + def _(self, *args): + return (self.plus(*vs) for vs in zip(*args, strict=True)) @_plus.register(Mapping) def _(self, *args): @@ -161,8 +167,8 @@ def _(self, body: Mapping, streams): return {k: self.reduce(v, streams) for (k, v) in body.items()} @reduce.register # type: ignore[attr-defined] - def _(self, body: Sequence, streams): - return type(body)(self.reduce(x, streams) for x in body) # type:ignore[call-arg] + def _(self, body: tuple, streams): + return tuple(self.reduce(x, streams) for x in body) @reduce.register # type: ignore[attr-defined] def _(self, body: Generator, streams): @@ -252,12 +258,26 @@ def _arg_max[T]( return b if b[0] > a[0] else a # type: ignore +@Operation.define +def product[T]( + a: Iterable[tuple[T, ...] | T], b: Iterable[tuple[T, ...] | T] +) -> Iterable[tuple[T, ...]]: + if isinstance(a, Term) or isinstance(b, Term): + raise NotHandled + + def to_tuple(x): + return x if isinstance(x, tuple) else (x,) + + return [to_tuple(x) + to_tuple(y) for (x, y) in itertools.product(a, b)] + + Min = Semilattice(kernel=min, identity=float("inf")) Max = Semilattice(kernel=max, identity=float("-inf")) ArgMin = Monoid(kernel=_arg_min, identity=(float("inf"), None)) ArgMax = Monoid(kernel=_arg_max, identity=(float("-inf"), None)) Sum = CommutativeMonoid(kernel=_NumberTerm.__add__, identity=0) Product = CommutativeMonoidWithZero(kernel=_NumberTerm.__mul__, identity=1, zero=0) +CartesianProduct = Monoid(kernel=product, identity=[()]) @dataclass @@ -545,11 +565,139 @@ def reduce(self, monoid, body, streams): return fwd() +def inner_stream( + streams: dict[Operation, Expr], +) -> Iterable[tuple[dict[Operation, Expr], Operation, Expr]]: + """Returns the streams that can be ordered innermost in the loop nest as + well as the remaining streams in the nest. + + """ + stream_vars = set(streams.keys()) + + no_dependents = set() + succ = defaultdict(set) + for k, v in streams.items(): + preds = fvsof(v) & stream_vars + if preds: + for pred in preds: + succ[pred].add(k) + else: + no_dependents.add(k) + + topo = TopologicalSorter(succ) + topo.prepare() + return ( + ({k: v for (k, v) in streams.items() if k != op}, op, streams[op]) + for op in set(topo.get_ready()) | no_dependents + ) + + +def match_reduce(term: Term) -> tuple | None: + reduce_args = None + + def set_reduce_args(*args, **kwargs): + nonlocal reduce_args + reduce_args = args + + with interpreter({Monoid.reduce: set_reduce_args}): + term.op(*term.args, **term.kwargs) + return reduce_args + + +class ReduceDistributeCartesianProduct(ObjectInterpretation): + """Eliminates a reduce over a cartesian product. + ∑_x₁ ∑_x₂ ... ∑_xₙ ∏_i f(xᵢ) = ∏_i ∑_xᵢ f(xᵢ) + This transform is also called inversion in the lifting + literature (e.g. [1]). + + More specifically, this transform implements the identity + reduce(⨁, reduce(⨂, body2, {vv: v()}), {v: reduce(×, body1, S1)} ∪ S2) + = reduce(⨁, reduce(⨂, reduce(⨁, body2, {vv: body1}), S1), S2) + where × is the cartesian product and ⨂ distributes over ⨁. + + Note: This could be generalized to grouped inversion [2]. + + [1] Braz, Rd, Eyal Amir, and Dan Roth. "Lifted first-order + probabilistic inference." IJCAI. 2005. + [2] Taghipour, Nima, et al. "Completeness results for lifted + variable elimination." AISTATS. 2013. + """ + + @implements(CommutativeMonoid.reduce) + def reduce(self, sum_monoid: Monoid, sum_body, sum_streams): + if not (isinstance(sum_body, Term)): + return fwd() + + # body is a product or multiplication of products + if distributes_over(sum_body.op, sum_monoid.plus): + prod_reduces = sum_body.args + else: + prod_reduces = [sum_body] + + products: list[tuple[Monoid, Callable, Operation, Term]] = [] + for prod_reduce in prod_reduces: + prod_args = match_reduce(prod_reduce) + if prod_args is None: + return fwd() + (prod_monoid, prod_body, prod_streams) = prod_args + if not ( + distributes_over(prod_monoid.plus, sum_monoid.plus) + and (len(products) == 0 or products[-1][0] == prod_monoid) + ): + return fwd() + + if len(prod_streams) > 1 or len(prod_streams) == 0: + return fwd() + (prod_op, prod_stream) = next(iter(prod_streams.items())) + products.append( + (prod_monoid, deffn(prod_body, prod_op), prod_op, prod_stream) + ) + + assert len(products) > 0 + + for outer_sum_streams, cprod_op, cprod_term in inner_stream(sum_streams): + if not ( + isinstance(cprod_term, Term) + and cprod_term.op == CartesianProduct.reduce + ): + continue + (cprod_body, cprod_streams) = cprod_term.args + + if not all( + prod_stream.op == cprod_op for (_, _, _, prod_stream) in products + ): + continue + + prod_op = Operation.define(products[0][2]) + prod_monoid = products[0][0] + inner_sum = sum_monoid.reduce( + prod_monoid.plus( + *(prod_body(prod_op()) for (_, prod_body, _, _) in products) + ), + {prod_op: cprod_body}, + ) + prod = prod_monoid.reduce(inner_sum, cprod_streams) + outer_sum = ( + sum_monoid.reduce(prod, outer_sum_streams) + if outer_sum_streams + else prod + ) + return outer_sum + + return fwd() + + NormalizeReduceIntp = functools.reduce( coproduct, typing.cast( list[Interpretation], - [ReduceNoStreams(), ReduceFusion(), ReduceSplit(), ReduceFactorization()], + [ + ReduceNoStreams(), + ReduceFusion(), + ReduceSplit(), + ReduceFactorization(), + ReduceDistributeCartesianProduct(), + ], ), ) diff --git a/effectful/ops/semantics.py b/effectful/ops/semantics.py index f7678fd24..8fd62bcd5 100644 --- a/effectful/ops/semantics.py +++ b/effectful/ops/semantics.py @@ -209,6 +209,7 @@ def evaluate[T]( @evaluate.register(object) @evaluate.register(str) @evaluate.register(bytes) +@evaluate.register(range) def _evaluate_object[T](expr: T, **kwargs) -> T: if dataclasses.is_dataclass(expr) and not isinstance(expr, type): return typing.cast( diff --git a/effectful/ops/types.py b/effectful/ops/types.py index 40c1f4af5..d24be9745 100644 --- a/effectful/ops/types.py +++ b/effectful/ops/types.py @@ -488,7 +488,10 @@ def _instance_op(instance, *args, **kwargs): else: return default_result - instance_op = self.define(types.MethodType(_instance_op, instance)) + name = ("" if owner is None else f"{owner.__name__}_") + self.__name__ + instance_op = self.define( + types.MethodType(_instance_op, instance), name=name + ) instance.__dict__[self._name_on_instance] = instance_op return instance_op elif instance is not None: diff --git a/tests/_monoid_helpers.py b/tests/_monoid_helpers.py index 4532ae72d..9b311b257 100644 --- a/tests/_monoid_helpers.py +++ b/tests/_monoid_helpers.py @@ -14,14 +14,14 @@ def _value_strategy_for(annotation: Any) -> st.SearchStrategy[Any]: if annotation is float: return st.floats(allow_nan=False) if get_origin(annotation) is list and get_args(annotation) == (int,): - return st.lists(st.integers()) + return st.lists(st.integers(), max_size=2) raise NotImplementedError( f"No value strategy for return annotation {annotation!r}; " "supported: int, list[int]" ) -_UNARY_INT_FNS: list[Callable[[int], int]] = [ +_UNARY_NUM_FNS: list[Callable[[int], int]] = [ lambda x: x, lambda x: x + 1, lambda x: x - 1, @@ -30,7 +30,7 @@ def _value_strategy_for(annotation: Any) -> st.SearchStrategy[Any]: lambda x: 3 * x + 1, ] -_BINARY_INT_FNS: list[Callable[[int, int], int]] = [ +_BINARY_NUM_FNS: list[Callable[[int, int], int]] = [ lambda x, y: x + y, lambda x, y: x - y, lambda x, y: x * y, @@ -58,10 +58,10 @@ def _strategy_for_op(op: Operation) -> st.SearchStrategy[Callable[..., Any]]: if not params: return _value_strategy_for(ret).map(deffn) - if ret is int and param_types == (int,): - return st.sampled_from(_UNARY_INT_FNS) - if ret is int and param_types == (int, int): - return st.sampled_from(_BINARY_INT_FNS) + if ret in (int, float) and param_types == (int,): + return st.sampled_from(_UNARY_NUM_FNS) + if ret in (int, float) and param_types == (int, int): + return st.sampled_from(_BINARY_NUM_FNS) if get_origin(ret) is list and get_args(ret) == (int,) and param_types == (int,): return st.sampled_from(_UNARY_LIST_FNS) raise NotImplementedError( diff --git a/tests/test_handlers_llm_provider.py b/tests/test_handlers_llm_provider.py index b56fd7bbd..9a2983901 100644 --- a/tests/test_handlers_llm_provider.py +++ b/tests/test_handlers_llm_provider.py @@ -240,7 +240,7 @@ def test_agent_tool_names_are_valid_integration(): agent = _ToolNameAgent() template = agent.ask tools = template.tools - expected_helper_tool_name = f"self__{agent.helper.__name__}" + expected_helper_tool_name = "self__helper" assert tools assert expected_helper_tool_name in tools assert all(re.fullmatch(r"[a-zA-Z0-9_-]+", name) for name in tools) diff --git a/tests/test_ops_monoid.py b/tests/test_ops_monoid.py index a22928cca..e73a9a7b2 100644 --- a/tests/test_ops_monoid.py +++ b/tests/test_ops_monoid.py @@ -1,12 +1,23 @@ import functools import itertools +import typing import pytest from hypothesis import given, settings from hypothesis import strategies as st from effectful.internals.runtime import interpreter -from effectful.ops.monoid import Max, Min, NormalizeIntp, Product, Semilattice, Sum +from effectful.ops.monoid import ( + CartesianProduct, + Max, + Min, + Monoid, + NormalizeIntp, + Product, + Semilattice, + Sum, + distributes_over, +) from effectful.ops.semantics import apply, evaluate, fvsof, handler from effectful.ops.syntax import _BaseTerm, defdata, syntactic_eq from effectful.ops.types import NotHandled, Operation @@ -37,6 +48,18 @@ pytest.param(Product, id="Product"), ] +# Pairs (outer, inner) such that inner distributes over outer — i.e. the lifting +# identity ``outer(inner(body, A), CartesianProduct...) == inner(outer(body, D), ...)`` +# is valid for that semiring pair. +MONOID_PAIRS = [ + pytest.param(o.values[0], i.values[0], id=f"{o.id}-{i.id}") + for o in ALL_MONOIDS + for i in ALL_MONOIDS + if distributes_over( + typing.cast(Monoid, i.values[0]).plus, typing.cast(Monoid, o.values[0]).plus + ) +] + def define_vars(*names, typ=int): if len(names) == 1: @@ -70,14 +93,11 @@ def syntactic_eq_alpha(x, y) -> bool: def _canonicalize(expr): counter = itertools.count() - def _passthrough(op, *args, **kwargs): - return defdata(op, *args, **kwargs) - def _substitute(arg, renaming): """Apply a bound-variable renaming using ``evaluate`` for traversal.""" if not renaming: return arg - with interpreter({apply: _passthrough, **renaming}): + with interpreter({apply: _BaseTerm, **renaming}): return evaluate(arg) def _bound_var_order(args, kwargs, bound_set): @@ -121,7 +141,7 @@ def _apply_canonical(op, *args, **kwargs): *bindings.args, *bindings.kwargs.values() ) if not all_bound: - return defdata(op, *args, **kwargs) + return _BaseTerm(op, *args, **kwargs) order = _bound_var_order(args, kwargs, all_bound) canonical = {var: _canonical_op(next(counter)) for var in order} @@ -252,8 +272,8 @@ def test_plus_assoc_left(monoid): def test_plus_sequence(monoid): a, b, c, d = define_vars("a", "b", "c", "d", typ=type(monoid.identity)) _check_pair( - lhs=monoid.plus([a(), b()], [c(), d()]), - rhs=[monoid.plus(a(), c()), monoid.plus(b(), d())], + lhs=monoid.plus((a(), b()), (c(), d())), + rhs=(monoid.plus(a(), c()), monoid.plus(b(), d())), free_vars=[a, b, c, d], ) @@ -368,8 +388,8 @@ def f(_x: int) -> int: g = Operation.define(f, name="g") - lhs = monoid.reduce([f(x()), g(x())], {x: X()}) - rhs = [monoid.reduce(f(x()), {x: X()}), monoid.reduce(g(x()), {x: X()})] + lhs = monoid.reduce((f(x()), g(x())), {x: X()}) + rhs = (monoid.reduce(f(x()), {x: X()}), monoid.reduce(g(x()), {x: X()})) _check_pair(lhs=lhs, rhs=rhs, free_vars=[X, f, g]) @@ -385,11 +405,11 @@ def f(_x: int) -> int: g = Operation.define(f, name="g") - lhs = monoid.reduce([f(x()), g(y())], {x: X(), y: Y()}) - rhs = [ + lhs = monoid.reduce((f(x()), g(y())), {x: X(), y: Y()}) + rhs = ( monoid.reduce(f(x()), {x: X(), y: Y()}), monoid.reduce(g(y()), {x: X(), y: Y()}), - ] + ) _check_pair(lhs=lhs, rhs=rhs, free_vars=[X, Y, f, g]) @@ -496,8 +516,6 @@ def g(_x: int) -> list[int]: Sum.reduce(Product.plus(b(), f(b(), c())), {b: g(a()), c: C()}), ) assert fvsof(bogus_rhs) != fvsof(lhs) - # Structural-only negative check: the normalizer correctly refused to apply - # the bogus factorization. assert not syntactic_eq_alpha(lhs, bogus_rhs) @@ -516,3 +534,101 @@ def f(_x: int, _y: int) -> int: Sum.reduce(Product.plus(b(), f(b(), c())), {b: B(), c: C()}), ) _check_pair(lhs=lhs, rhs=rhs, free_vars=[A, B, C, f]) + + +@pytest.mark.parametrize("outer,inner", MONOID_PAIRS) +def test_reduce_lifted_1(outer, inner): + a, i = define_vars("a", "i") + A, N, A_domain = define_vars("A", "N", "A_domain", typ=list[int]) + + @Operation.define + def f(_: int) -> float: + raise NotHandled + + term1 = outer.reduce( + inner.reduce(f(a()), {a: A()}), + {A: CartesianProduct.reduce(A_domain(), {i: N()})}, + ) + term2 = inner.reduce(outer.reduce(f(a()), {a: A_domain()}), {i: N()}) + _check_pair(lhs=term1, rhs=term2, free_vars=[N, A_domain, f]) + + +def test_reduce_cartesian_1(): + a, i = define_vars("a", "i") + A = define_vars("A", typ=list[int]) + + term1 = Sum.reduce( + Product.reduce(a(), {a: []}), + {A: CartesianProduct.reduce([], {i: []})}, + ) + term2 = Product.reduce(Sum.reduce(a(), {a: []}), {i: []}) + assert term1 == term2 + + +def test_reduce_cartesian_2(): + a, i = define_vars("a", "i") + A = define_vars("A", typ=list[int]) + + term1 = Sum.reduce( + Product.reduce(a(), {a: A()}), + {A: CartesianProduct.reduce([(0,)], {i: [0]})}, + ) + term2 = Product.reduce(Sum.reduce(a(), {a: [0]}), {i: [0]}) + assert term1 == term2 + + +@pytest.mark.parametrize("outer,inner", MONOID_PAIRS) +def test_reduce_lifted_multi_index(outer, inner): + a, i, j = define_vars("a", "i", "j") + A, N, M, A_domain = define_vars("A", "N", "M", "A_domain", typ=list[int]) + + @Operation.define + def f(_: int) -> float: + raise NotHandled + + term1 = outer.reduce( + inner.reduce(f(a()), {a: A()}), + {A: CartesianProduct.reduce(A_domain(), {i: N(), j: M()})}, + ) + term2 = inner.reduce( + outer.reduce(f(a()), {a: A_domain()}), + {i: N(), j: M()}, + ) + _check_pair(lhs=term1, rhs=term2, free_vars=[N, M, A_domain, f]) + + +@pytest.mark.parametrize("outer,inner", MONOID_PAIRS) +def test_reduce_lifted_2(outer, inner): + """The worked example on page 396 of 'Lifted Variable Elimination: + Decoupling the Operators from the Constraint Language'. + + """ + a, i, s, t = define_vars("a", "i", "s", "t") + A, N, T = define_vars("A", "N", "T", typ=list[int]) + + @Operation.define + def A_domain(_i: int) -> list[int]: + raise NotHandled + + @Operation.define + def f1(_a: int, _s: int) -> float: + raise NotHandled + + @Operation.define + def f2(_t: int, _a: int) -> float: + raise NotHandled + + term1 = outer.reduce( + inner.reduce(inner.plus(f1(a(), s()), f2(t(), a())), {a: A()}), + {A: CartesianProduct.reduce(A_domain(i()), {i: N()}), t: T()}, + ) + + term2 = outer.reduce( + inner.reduce( + outer.reduce(inner.plus(f1(a(), s()), f2(t(), a())), {a: A_domain(i())}), + {i: N()}, + ), + {t: T()}, + ) + + _check_pair(lhs=term1, rhs=term2, free_vars=[a, i, s, t, A, N, T, A_domain, f1, f2]) diff --git a/tests/test_ops_syntax.py b/tests/test_ops_syntax.py index 185b6132e..1f5c47763 100644 --- a/tests/test_ops_syntax.py +++ b/tests/test_ops_syntax.py @@ -489,7 +489,6 @@ def _(self, x: bool) -> bool: ) assert isinstance(term_float, Term) - assert term_float.op.__name__ == "my_singledispatch" assert term_float.args == (1.5,) assert term_float.kwargs == {} From 557777d2caec1e655bb3a701c737a75ce967bf5d Mon Sep 17 00:00:00 2001 From: Jack Feser Date: Tue, 12 May 2026 16:42:40 -0400 Subject: [PATCH 3/6] Refactor `monoid.py` to remove class structure (#661) * Add monoid module (#653) * add monoid module * clean up * fix doctest * fix * wip * remove incorrect rule * add disjoint set tests and fix bug * lint * drop jax monoid defs * drop incorrect comment * add assert * reduce nondeterminism and add assertions * fix inconsistent stream numbering and missing constant factors * wip * cleanup * fix rule * wip * fix bug * cleanup * lin * wip * fix tests * format * lint * wip --- effectful/ops/monoid.py | 255 ++++++++++++++++++--------------------- effectful/ops/types.py | 76 ++++++++++++ tests/test_ops_monoid.py | 6 +- 3 files changed, 197 insertions(+), 140 deletions(-) diff --git a/effectful/ops/monoid.py b/effectful/ops/monoid.py index ad83de47b..0d6e230c0 100644 --- a/effectful/ops/monoid.py +++ b/effectful/ops/monoid.py @@ -10,20 +10,25 @@ from typing import Annotated, Any from effectful.internals.disjoint_set import DisjointSet -from effectful.internals.runtime import interpreter from effectful.ops.semantics import coproduct, evaluate, fvsof, fwd, handler from effectful.ops.syntax import ( ObjectInterpretation, Scoped, _NumberTerm, - defdata, deffn, implements, iter_, syntactic_eq, syntactic_hash, ) -from effectful.ops.types import Expr, Interpretation, NotHandled, Operation, Term +from effectful.ops.types import ( + Expr, + Interpretation, + NotHandled, + Operation, + Term, + _CustomSingleDispatchMethod, +) # Note: The streams value type should be something like Iterable[T], but some of # our target stream types (e.g. jax.Array) are not subtypes of Iterable @@ -64,60 +69,57 @@ def __init__(self, kernel: Callable[[T, T], T], identity: T): def __repr__(self): return f"{type(self)}({self.kernel}, {self.identity})" + def __eq__(self, other): + return id(self) == id(other) + + def __hash__(self): + return hash(id(self)) + @Operation.define - def plus[S: Body[T]](self, *args: S) -> S: + @_CustomSingleDispatchMethod + def plus[S](self, dispatch, *args: S) -> S: """Monoid addition with broadcasting over common collection types, callables, and interpretations. - """ if not args: return typing.cast(S, self.identity) + return dispatch(type(args[0]))(self, *args) + @plus.register(object) # type: ignore[attr-defined] + def _(self, *args): if any(isinstance(x, Term) for x in args): - return typing.cast(S, defdata(self.plus, *args)) + raise NotHandled + return functools.reduce(self.kernel, args, self.identity) - return self._plus(*args) - - @functools.singledispatchmethod - def _plus[S](self, *args: S) -> S: - return typing.cast(S, functools.reduce(self.kernel, args, self.identity)) - - @_plus.register(tuple) + @plus.register(tuple) # type: ignore[attr-defined] def _(self, *args): return tuple(self.plus(*vs) for vs in zip(*args, strict=True)) - @_plus.register(Generator) + @plus.register(Generator) # type: ignore[attr-defined] def _(self, *args): return (self.plus(*vs) for vs in zip(*args, strict=True)) - @_plus.register(Mapping) + @plus.register(Mapping) # type: ignore[attr-defined] def _(self, *args): if isinstance(args[0], Interpretation): keys = args[0].keys() - for b in args[1:]: if not isinstance(b, Interpretation): raise TypeError(f"Expected interpretation but got {b}") - - b_keys = b.keys() - if not keys == b_keys: + if not keys == b.keys(): raise ValueError( - f"Expected interpretation of {keys} but got {b_keys}" + f"Expected interpretation of {keys} but got {b.keys()}" ) - - result = {k: self.plus(*(handler(b)(b[k]) for b in args)) for k in keys} - return result + return {k: self.plus(*(handler(b)(b[k]) for b in args)) for k in keys} for b in args[1:]: if not isinstance(b, Mapping): raise TypeError(f"Expected mapping but got {b}") - all_values = collections.defaultdict(list) for d in args: for k, v in d.items(): all_values[k].append(v) - result = {k: self.plus(*vs) for (k, vs) in all_values.items()} - return result + return {k: self.plus(*vs) for (k, vs) in all_values.items()} @Operation.define @functools.singledispatchmethod @@ -155,12 +157,9 @@ def generator(loop_order) -> Iterator[Interpretation]: yield coproduct(intp, intp2) loop_order = list(order_streams(streams)) - try: - return self.plus( - *(handler(intp)(evaluate)(body) for intp in generator(loop_order)) - ) - except NotHandled: - return typing.cast(U, defdata(self.reduce, body, streams)) + return self.plus( + *(handler(intp)(evaluate)(body) for intp in generator(loop_order)) + ) @reduce.register # type: ignore[attr-defined] def _(self, body: Mapping, streams): @@ -175,35 +174,19 @@ def _(self, body: Generator, streams): return (self.reduce(x, streams) for x in body) -class IdempotentMonoid[T](Monoid[T]): - @Operation.define - def plus[S: Body[T]](self, *args: S) -> S: - return super().plus(*args) - - @Operation.define - def reduce[A, B, U: Body]( - self, - body: Annotated[U, Scoped[A | B]], - streams: Annotated[Streams, Scoped[A]], - ) -> Annotated[U, Scoped[B]]: - return super().reduce(body, streams) +def _is_monoid_plus(op: Operation) -> bool: + """True if ``op`` is the ``plus`` operation of some :class:`Monoid`.""" + owner = getattr(op, "__self__", None) + return isinstance(owner, Monoid) and op is owner.plus -class CommutativeMonoid[T](Monoid[T]): - @Operation.define - def plus[S: Body[T]](self, *args: S) -> S: - return super().plus(*args) - - @Operation.define - def reduce[A, B, U: Body]( - self, - body: Annotated[U, Scoped[A | B]], - streams: Annotated[Streams, Scoped[A]], - ) -> Annotated[U, Scoped[B]]: - return super().reduce(body, streams) +def _is_monoid_reduce(op: Operation) -> bool: + """True if ``op`` is the ``reduce`` operation of some :class:`Monoid`.""" + owner = getattr(op, "__self__", None) + return isinstance(owner, Monoid) and op is owner.reduce -class CommutativeMonoidWithZero[T](CommutativeMonoid[T]): +class MonoidWithZero[T](Monoid[T]): zero: T def __init__(self, kernel: Callable[[T, T], T], identity: T, zero: T): @@ -213,32 +196,6 @@ def __init__(self, kernel: Callable[[T, T], T], identity: T, zero: T): def __repr__(self): return f"{type(self)}({self.kernel}, {self.identity}, {self.zero})" - @Operation.define - def plus[S: Body[T]](self, *args: S) -> S: - return super().plus(*args) - - @Operation.define - def reduce[A, B, U: Body]( - self, - body: Annotated[U, Scoped[A | B]], - streams: Annotated[Streams, Scoped[A]], - ) -> Annotated[U, Scoped[B]]: - return super().reduce(body, streams) - - -class Semilattice[T](IdempotentMonoid[T], CommutativeMonoid[T]): - @Operation.define - def plus[S: Body[T]](self, *args: S) -> S: - return super().plus(*args) - - @Operation.define - def reduce[A, B, U: Body]( - self, - body: Annotated[U, Scoped[A | B]], - streams: Annotated[Streams, Scoped[A]], - ) -> Annotated[U, Scoped[B]]: - return super().reduce(body, streams) - @Operation.define def _arg_min[T]( @@ -271,15 +228,30 @@ def to_tuple(x): return [to_tuple(x) + to_tuple(y) for (x, y) in itertools.product(a, b)] -Min = Semilattice(kernel=min, identity=float("inf")) -Max = Semilattice(kernel=max, identity=float("-inf")) +Min = Monoid(kernel=min, identity=float("inf")) +Max = Monoid(kernel=max, identity=float("-inf")) ArgMin = Monoid(kernel=_arg_min, identity=(float("inf"), None)) ArgMax = Monoid(kernel=_arg_max, identity=(float("-inf"), None)) -Sum = CommutativeMonoid(kernel=_NumberTerm.__add__, identity=0) -Product = CommutativeMonoidWithZero(kernel=_NumberTerm.__mul__, identity=1, zero=0) +Sum = Monoid(kernel=_NumberTerm.__add__, identity=0) +Product = MonoidWithZero(kernel=_NumberTerm.__mul__, identity=1, zero=0) CartesianProduct = Monoid(kernel=product, identity=[()]) +@dataclass +class _ExtensiblePredicate[T]: + elems: set[T] + + def register(self, t: T) -> None: + self.elems.add(t) + + def __call__(self, t: T) -> bool: + return t in self.elems + + +is_commutative = _ExtensiblePredicate({Max, Min, Sum, Product}) +is_idempotent = _ExtensiblePredicate({Max, Min}) + + @dataclass class _ExtensibleBinaryRelation[S, T]: tuples: set[tuple[S, T]] @@ -292,13 +264,7 @@ def __call__(self, s: S, t: T) -> bool: distributes_over = _ExtensibleBinaryRelation( - { - (Max.plus, Min.plus), - (Min.plus, Max.plus), - (Sum.plus, Min.plus), - (Sum.plus, Max.plus), - (Product.plus, Sum.plus), - } + {(Max, Min), (Min, Max), (Sum, Min), (Sum, Max), (Product, Sum)} ) @@ -337,10 +303,12 @@ class PlusAssoc(ObjectInterpretation): @implements(Monoid.plus) def plus(self, monoid, *args): - if any(isinstance(x, Term) and x.op is monoid.plus for x in args): + def is_nested_plus(x): + return isinstance(x, Term) and x.op is monoid.plus + + if any(is_nested_plus(x) for x in args): flat_args = itertools.chain.from_iterable( - t.args if isinstance(t, Term) and t.op is monoid.plus else (t,) - for t in args + t.args if is_nested_plus(t) else (t,) for t in args ) assert len(args) > 0 return monoid.plus(*flat_args) @@ -353,33 +321,36 @@ class PlusDistr(ObjectInterpretation): @implements(Monoid.plus) def plus(self, monoid, *args): if any( - isinstance(x, Term) and distributes_over(monoid.plus, x.op) for x in args + isinstance(x, Term) + and _is_monoid_plus(x.op) + and distributes_over(monoid, x.op.__self__) + for x in args ): non_terms = [] - # group terms by head operation - by_head_op = defaultdict(list) + # group terms by their monoid + by_monoid: dict[Monoid, list[Term]] = defaultdict(list) for t in args: - if isinstance(t, Term): - by_head_op[t.op].append(t) + if isinstance(t, Term) and _is_monoid_plus(t.op): + by_monoid[t.op.__self__].append(t) else: non_terms.append(t) # distribute over each group progress = False final_sum = [] - for op, terms in by_head_op.items(): + for m, terms in by_monoid.items(): if ( len(terms) > 1 - and distributes_over(monoid.plus, op) - and not distributes_over(op, monoid.plus) + and distributes_over(monoid, m) + and not distributes_over(m, monoid) ): progress = True term_args = (t.args for t in terms) dist_terms = ( monoid.plus(*args) for args in itertools.product(*term_args) ) - final_sum.append(op(*dist_terms)) + final_sum.append(m.plus(*dist_terms)) else: final_sum += terms if progress: @@ -390,8 +361,10 @@ def plus(self, monoid, *args): class PlusZero(ObjectInterpretation): """x₁ * ... * 0 * ... * xₙ = 0""" - @implements(CommutativeMonoidWithZero.plus) + @implements(Monoid.plus) def plus(self, monoid, *args): + if not (isinstance(monoid, MonoidWithZero)): + return fwd() if any(x is monoid.zero for x in args): return monoid.zero return fwd() @@ -400,8 +373,11 @@ def plus(self, monoid, *args): class PlusConsecutiveDups(ObjectInterpretation): """x ⊕ x ⊕ y = x ⊕ y""" - @implements(IdempotentMonoid.plus) + @implements(Monoid.plus) def plus(self, monoid, *args): + if not is_idempotent(monoid): + return fwd() + dedup_args = ( args[i] for i in range(len(args)) @@ -423,8 +399,11 @@ def __eq__(self, other): def __hash__(self): return syntactic_hash(self) - @implements(Semilattice.plus) + @implements(Monoid.plus) def plus(self, monoid, *args): + if not (is_idempotent(monoid) and is_commutative(monoid)): + return fwd() + # elim dups args_count = Counter(self._HashableTerm(t) for t in args) if len(args_count) < len(args): @@ -475,7 +454,7 @@ class ReduceFusion(ObjectInterpretation): @implements(Monoid.reduce) def reduce(self, monoid, body, streams): - if isinstance(body, Term) and body.op == monoid.reduce: + if isinstance(body, Term) and body.op is monoid.reduce: return monoid.reduce(body.args[0], streams | body.args[1]) return fwd() @@ -485,9 +464,11 @@ class ReduceSplit(ObjectInterpretation): reduce(R, S, b1 + ... + bn) = reduce(R, S, b1) + ... + reduce(R, S, bn) """ - @implements(CommutativeMonoid.reduce) + @implements(Monoid.reduce) def reduce(self, monoid, body, streams): - if isinstance(body, Term) and body.op == monoid.plus: + if not is_commutative(monoid): + return fwd() + if isinstance(body, Term) and body.op is monoid.plus: return monoid.plus(*(monoid.reduce(x, streams) for x in body.args)) return fwd() @@ -506,9 +487,16 @@ class ReduceFactorization(ObjectInterpretation): and free(Aᵢ) ∩ S ⊆ Sᵢ """ - @implements(CommutativeMonoid.reduce) + @implements(Monoid.reduce) def reduce(self, monoid, body, streams): - if isinstance(body, Term) and distributes_over(body.op, monoid.plus): + if not is_commutative(monoid): + return fwd() + if ( + isinstance(body, Term) + and _is_monoid_plus(body.op) + and distributes_over(body.op.__self__, monoid) + ): + inner_monoid: Monoid = body.op.__self__ stream_vars = set(streams.keys()) factors = [(arg, fvsof(arg)) for arg in body.args] stream_ids = {v: i for (i, v) in enumerate(stream_vars)} @@ -521,7 +509,7 @@ def reduce(self, monoid, body, streams): ds.union(stream_id, *deps) # factors are in the same partition as their dependencies - for factor, factor_fvs in factors: + for _, factor_fvs in factors: factor_streams = sorted( [stream_ids[v] for v in (factor_fvs & stream_vars)] ) @@ -550,14 +538,14 @@ def reduce(self, monoid, body, streams): for t in partition_factors ), "partition contains all streams required by factor" - partition_term = body.op(*(t[0] for t in partition_factors)) + partition_term = inner_monoid.plus(*(t[0] for t in partition_factors)) new_reduces.append((partition_term, partition_streams)) placed_streams |= partition_stream_keys constant_factors = [t for (t, fvs) in factors if not (fvs & stream_vars)] if len(new_reduces) > 1: - result = body.op( + result = inner_monoid.plus( *constant_factors, *(monoid.reduce(*args) for args in new_reduces) ) return result @@ -592,18 +580,6 @@ def inner_stream( ) -def match_reduce(term: Term) -> tuple | None: - reduce_args = None - - def set_reduce_args(*args, **kwargs): - nonlocal reduce_args - reduce_args = args - - with interpreter({Monoid.reduce: set_reduce_args}): - term.op(*term.args, **term.kwargs) - return reduce_args - - class ReduceDistributeCartesianProduct(ObjectInterpretation): """Eliminates a reduce over a cartesian product. ∑_x₁ ∑_x₂ ... ∑_xₙ ∏_i f(xᵢ) = ∏_i ∑_xᵢ f(xᵢ) @@ -623,25 +599,30 @@ class ReduceDistributeCartesianProduct(ObjectInterpretation): variable elimination." AISTATS. 2013. """ - @implements(CommutativeMonoid.reduce) + @implements(Monoid.reduce) def reduce(self, sum_monoid: Monoid, sum_body, sum_streams): - if not (isinstance(sum_body, Term)): + if not (is_commutative(sum_monoid) and isinstance(sum_body, Term)): return fwd() # body is a product or multiplication of products - if distributes_over(sum_body.op, sum_monoid.plus): + if _is_monoid_plus(sum_body.op) and distributes_over( + sum_body.op.__self__, sum_monoid + ): prod_reduces = sum_body.args else: prod_reduces = [sum_body] products: list[tuple[Monoid, Callable, Operation, Term]] = [] for prod_reduce in prod_reduces: - prod_args = match_reduce(prod_reduce) - if prod_args is None: + if not ( + isinstance(prod_reduce, Term) and _is_monoid_reduce(prod_reduce.op) + ): return fwd() - (prod_monoid, prod_body, prod_streams) = prod_args + prod_monoid: Monoid = prod_reduce.op.__self__ + prod_body = prod_reduce.args[0] + prod_streams = typing.cast(Mapping, prod_reduce.args[1]) if not ( - distributes_over(prod_monoid.plus, sum_monoid.plus) + distributes_over(prod_monoid, sum_monoid) and (len(products) == 0 or products[-1][0] == prod_monoid) ): return fwd() @@ -658,7 +639,7 @@ def reduce(self, sum_monoid: Monoid, sum_body, sum_streams): for outer_sum_streams, cprod_op, cprod_term in inner_stream(sum_streams): if not ( isinstance(cprod_term, Term) - and cprod_term.op == CartesianProduct.reduce + and cprod_term.op is CartesianProduct.reduce ): continue (cprod_body, cprod_streams) = cprod_term.args diff --git a/effectful/ops/types.py b/effectful/ops/types.py index d24be9745..c68e0d46c 100644 --- a/effectful/ops/types.py +++ b/effectful/ops/types.py @@ -42,6 +42,59 @@ def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T: return self.func(self.dispatch, *args, **kwargs) +class _CustomSingleDispatchMethod[**P, **Q, S, T]: + """Method analog of :class:`_CustomSingleDispatchCallable`. + + The wrapped function has signature ``(self, dispatch, *args, **kwargs)``, + where ``dispatch`` is :meth:`functools.singledispatch.dispatch`. As a + descriptor, it binds ``self`` on attribute access, so callers invoke it + as ``instance.method(*args, **kwargs)``. + """ + + def __init__( + self, + func: Callable[Concatenate[Any, Callable[[type], Callable[Q, S]], P], T], + ): + self.func = func + self._registry = functools.singledispatch(func) + self.__signature__ = inspect.signature( + functools.partial(func, None, None) # type: ignore[arg-type] + ) + functools.update_wrapper(self, func) # type: ignore[arg-type] + + @property + def dispatch(self): + return self._registry.dispatch + + @property + def register(self): + return self._registry.register + + def __get__(self, instance, owner=None): + if instance is None: + return self + return _BoundCustomSingleDispatchMethod(self, instance) + + +class _BoundCustomSingleDispatchMethod: + __slots__ = ("_method", "_instance") + + def __init__(self, method: _CustomSingleDispatchMethod, instance: Any): + self._method = method + self._instance = instance + + @property + def dispatch(self): + return self._method.dispatch + + @property + def register(self): + return self._method.register + + def __call__(self, *args, **kwargs): + return self._method.func(self._instance, self._method.dispatch, *args, **kwargs) + + class _ClassMethodOpDescriptor(classmethod): def __init__(self, define, *args, **kwargs): super().__init__(*args, **kwargs) @@ -311,6 +364,15 @@ def func(*args, **kwargs): return typing.cast(Operation[P, T], cls.define(func, **kwargs)) + @define.register(types.MethodType) + @classmethod + def _define_methodtype[**P, T]( + cls, t: Callable[P, T], *, name: str | None = None + ) -> "Operation[P, T]": + op = cls._define_callable(t, name=name) + op.__self__ = t.__self__ # type: ignore[attr-defined] + return typing.cast("Operation[P, T]", op) + @define.register(staticmethod) @classmethod def _define_staticmethod[**P, T](cls, t: "staticmethod[P, T]", **kwargs): @@ -350,6 +412,20 @@ def func(*args, **kwargs): op.register = default._registry.register # type: ignore[attr-defined] return op + @define.register(_CustomSingleDispatchMethod) + @classmethod + def _define_customsingledispatchmethod( + cls, default: _CustomSingleDispatchMethod, **kwargs + ): + @functools.wraps(default.func) + def _wrapper(obj, *args, **kwargs): + return default.__get__(obj)(*args, **kwargs) + + op = cls.define(_wrapper, **kwargs) + op.register = default.register # type: ignore[attr-defined] + op.dispatch = default.dispatch # type: ignore[attr-defined] + return op + @typing.final def __default_rule__(self, *args: Q.args, **kwargs: Q.kwargs) -> "Expr[V]": """The default rule is used when the operation is not handled. diff --git a/tests/test_ops_monoid.py b/tests/test_ops_monoid.py index e73a9a7b2..d881869ac 100644 --- a/tests/test_ops_monoid.py +++ b/tests/test_ops_monoid.py @@ -14,9 +14,9 @@ Monoid, NormalizeIntp, Product, - Semilattice, Sum, distributes_over, + is_commutative, ) from effectful.ops.semantics import apply, evaluate, fvsof, handler from effectful.ops.syntax import _BaseTerm, defdata, syntactic_eq @@ -56,7 +56,7 @@ for o in ALL_MONOIDS for i in ALL_MONOIDS if distributes_over( - typing.cast(Monoid, i.values[0]).plus, typing.cast(Monoid, o.values[0]).plus + typing.cast(Monoid, i.values[0]), typing.cast(Monoid, o.values[0]) ) ] @@ -354,7 +354,7 @@ def test_plus_idempotent_non_consecutive(monoid): PlusDups; plain IdempotentMonoid leaves it as-is (consecutive-only).""" a, b = define_vars("a", "b") lhs = monoid.plus(a(), b(), a()) - if isinstance(monoid, Semilattice): + if is_commutative(monoid): rhs = monoid.plus(a(), b()) else: rhs = monoid.plus(a(), b(), a()) From 800b8bc4fc6968ccd6f507a91aee29c802a86d2a Mon Sep 17 00:00:00 2001 From: Jack Feser Date: Thu, 21 May 2026 15:11:59 -0400 Subject: [PATCH 4/6] Add `jax` array monoids and reduction rule (#658) * Add monoid module (#653) * add monoid module * clean up * fix doctest * fix * wip * remove incorrect rule * add disjoint set tests and fix bug * lint * drop jax monoid defs * drop incorrect comment * add assert * reduce nondeterminism and add assertions * fix inconsistent stream numbering and missing constant factors * wip * cleanup * wip * fix rule * wip * fix bug * cleanup * lin * wip * fix tests * format * lint * wip * wip * wip * wip * wip * wip * wip * wip * drop runtime typed dict lifting * wip * format * reorganize * stop using string dicts to avoid unification issue * wip * wip * wip * wip * wip * use check_rewrite in jax tests * lint * fix bugs --- effectful/handlers/jax/_handlers.py | 7 + effectful/handlers/jax/monoid.py | 162 +++++++ effectful/ops/monoid.py | 486 +++++++++++-------- effectful/ops/syntax.py | 2 + tests/_monoid_helpers.py | 284 ++++++++++- tests/test_handlers_jax_monoid.py | 96 ++++ tests/test_ops_monoid.py | 718 +++++++++++++++------------- 7 files changed, 1206 insertions(+), 549 deletions(-) create mode 100644 effectful/handlers/jax/monoid.py create mode 100644 tests/test_handlers_jax_monoid.py diff --git a/effectful/handlers/jax/_handlers.py b/effectful/handlers/jax/_handlers.py index 308cdb76e..c5d104233 100644 --- a/effectful/handlers/jax/_handlers.py +++ b/effectful/handlers/jax/_handlers.py @@ -19,6 +19,7 @@ deffn, defop, syntactic_eq, + syntactic_hash, ) from effectful.ops.types import Expr, NotHandled, Operation, Term @@ -277,3 +278,9 @@ def _(x: jax.Array, other) -> bool: and x.shape == other.shape and bool((jnp.asarray(x) == jnp.asarray(other)).all()) ) + + +@syntactic_hash.register(jax.Array) +def _(x: jax.Array) -> int: + # Concrete arrays aren't hashable; hash by shape, dtype, and bytes. + return hash(("jax.Array", x.shape, str(x.dtype), bytes(jax.numpy.asarray(x)))) diff --git a/effectful/handlers/jax/monoid.py b/effectful/handlers/jax/monoid.py new file mode 100644 index 000000000..a406cda5b --- /dev/null +++ b/effectful/handlers/jax/monoid.py @@ -0,0 +1,162 @@ +import functools + +import jax + +import effectful.handlers.jax.numpy as jnp +from effectful.handlers.jax import bind_dims, unbind_dims +from effectful.handlers.jax.scipy.special import logsumexp +from effectful.ops.monoid import ( + CartesianProduct, + Max, + Min, + Monoid, + NormalizeIntp, + Product, + Sum, + outer_stream, +) +from effectful.ops.semantics import evaluate, fvsof, fwd, handler, typeof +from effectful.ops.syntax import ObjectInterpretation, deffn, implements +from effectful.ops.types import Operation + + +def cartesian_prod(x, y): + if x.ndim == 1: + x = x[:, None] + if y.ndim == 1: + y = y[:, None] + nx, dx = x.shape + ny, dy = y.shape + # Broadcast into (nx, ny, dx+dy), then flatten the first two axes + x_b = jnp.broadcast_to(x[:, None, :], (nx, ny, dx)) + y_b = jnp.broadcast_to(y[None, :, :], (nx, ny, dy)) + return jnp.concatenate([x_b, y_b], axis=-1).reshape(nx * ny, dx + dy) + + +LogSumExp = Monoid(name="LogSumExp", identity=jnp.asarray(float("-inf"))) + + +def _jax_args(args): + """True iff ``args`` is non-empty and every arg is a concrete + :class:`jax.Array` (no Terms). + """ + typs = (typeof(a) for a in args) + return ( + bool(args) + and any(issubclass(t, jax.Array) for t in typs) + and all(issubclass(t, jax.typing.ArrayLike) for t in typs) + ) + + +class SumPlusJax(ObjectInterpretation): + @implements(Sum.plus) + def plus(self, *args): + if not _jax_args(args): + return fwd() + return functools.reduce(jnp.add, args) + + +class ProductPlusJax(ObjectInterpretation): + @implements(Product.plus) + def plus(self, *args): + if not _jax_args(args): + return fwd() + return functools.reduce(jnp.multiply, args) + + +class MinPlusJax(ObjectInterpretation): + @implements(Min.plus) + def plus(self, *args): + if not _jax_args(args): + return fwd() + return functools.reduce(jnp.minimum, args) + + +class MaxPlusJax(ObjectInterpretation): + @implements(Max.plus) + def plus(self, *args): + if not _jax_args(args): + return fwd() + return functools.reduce(jnp.maximum, args) + + +class LogSumExpPlusJax(ObjectInterpretation): + @implements(LogSumExp.plus) + def plus(self, *args): + if not _jax_args(args): + return fwd() + return functools.reduce(jnp.logaddexp, args) + + +class CartesianProductPlusJax(ObjectInterpretation): + @implements(CartesianProduct.plus) + def plus(self, *args): + # Skip identity ``[()]`` args; short-circuit on zero ``[]``. Both + # sentinels arrive as Python lists alongside jax-array factors, so + # check for them explicitly before composing. + if not any(isinstance(a, jax.Array) for a in args): + return fwd() + result = None + for a in args: + if a is CartesianProduct.zero: + return CartesianProduct.zero + if a is CartesianProduct.identity: + continue + if not isinstance(a, jax.Array): + return fwd() + result = a if result is None else cartesian_prod(result, a) + return result if result is not None else CartesianProduct.identity + + +ARRAY_REDUCTORS = { + Sum: jnp.sum, + Product: jnp.prod, + Min: jnp.min, + Max: jnp.max, + LogSumExp: logsumexp, +} + + +class ArrayReduce(ObjectInterpretation): + @implements(Monoid.reduce) + def reduce(self, monoid, body, streams): + if monoid not in ARRAY_REDUCTORS or typeof(body) is not jax.Array: + return fwd() + if not streams: + return monoid.identity + + reductor = ARRAY_REDUCTORS[monoid] + index = Operation.define(jax.Array) + for stream_key, stream_body, streams_tail in outer_stream(streams): + if not issubclass(typeof(stream_body), jax.Array): + continue + + if stream_key in fvsof(body): + with handler({stream_key: deffn(unbind_dims(stream_body, index))}): + eval_body = evaluate(body) + eval_streams_tail = evaluate(streams_tail) + assert isinstance(eval_streams_tail, dict) + reduce_tail = ( + monoid.reduce(eval_body, eval_streams_tail) + if len(eval_streams_tail) > 0 + else eval_body + ) + return reductor(bind_dims(reduce_tail, index), axis=0) + else: + # TODO: In this case, the stream is unused in the body. The body + # should be multiplied by the length of the stream. The current + # behavior is not efficient. + return fwd() + + return fwd() + + +NormalizeIntp.extend( + ArrayReduce(), + SumPlusJax(), + ProductPlusJax(), + MinPlusJax(), + MaxPlusJax(), + LogSumExpPlusJax(), + CartesianProductPlusJax(), +) diff --git a/effectful/ops/monoid.py b/effectful/ops/monoid.py index 0d6e230c0..70bb50022 100644 --- a/effectful/ops/monoid.py +++ b/effectful/ops/monoid.py @@ -1,10 +1,10 @@ import collections.abc import functools import itertools -import numbers +import operator import typing -from collections import Counter, defaultdict -from collections.abc import Callable, Generator, Iterable, Iterator, Mapping +from collections import Counter, UserDict, defaultdict +from collections.abc import Callable, Generator, Iterable, Mapping from dataclasses import dataclass from graphlib import TopologicalSorter from typing import Annotated, Any @@ -14,21 +14,13 @@ from effectful.ops.syntax import ( ObjectInterpretation, Scoped, - _NumberTerm, deffn, implements, iter_, syntactic_eq, syntactic_hash, ) -from effectful.ops.types import ( - Expr, - Interpretation, - NotHandled, - Operation, - Term, - _CustomSingleDispatchMethod, -) +from effectful.ops.types import Expr, Interpretation, NotHandled, Operation, Term # Note: The streams value type should be something like Iterable[T], but some of # our target stream types (e.g. jax.Array) are not subtypes of Iterable @@ -43,31 +35,35 @@ ) -def order_streams[T](streams: Streams[T]) -> Iterable[tuple[Operation[[], T], Any]]: - """Determine an order to evaluate the streams based on their dependencies""" +def outer_stream( + streams: Streams, +) -> Iterable[tuple[Operation, Expr, dict[Operation, Expr]]]: + """Returns the streams that can be ordered outermost in the loop nest as + well as the remaining streams in the nest. + + """ stream_vars = set(streams.keys()) - dependencies = {k: fvsof(v) & stream_vars for k, v in streams.items()} - topo = TopologicalSorter(dependencies) + pred = {k: fvsof(v) & stream_vars for k, v in streams.items()} + topo = TopologicalSorter(pred) topo.prepare() - while topo.is_active(): - node_group = topo.get_ready() - for op in sorted(node_group): - yield (op, streams[op]) - topo.done(*node_group) + return ( + (op, streams[op], {k: v for (k, v) in streams.items() if k != op}) + for op in topo.get_ready() + ) class Monoid[T]: - kernel: Operation[[T, T], T] + """A monoid with ``plus`` and ``reduce`` :class:`Operation` s.""" + + _name: str identity: T - def __init__(self, kernel: Callable[[T, T], T], identity: T): + def __init__(self, identity: T, name: str): + self._name = name self.identity = identity - self.kernel = ( - kernel if isinstance(kernel, Operation) else Operation.define(kernel) - ) def __repr__(self): - return f"{type(self)}({self.kernel}, {self.identity})" + return f"Monoid({self._name!r})" def __eq__(self, other): return id(self) == id(other) @@ -75,166 +71,63 @@ def __eq__(self, other): def __hash__(self): return hash(id(self)) + # the weak typing allows us to write monoid.plus(monoid.identity, ) + # and monoid.plus(monoid.identity, ) @Operation.define - @_CustomSingleDispatchMethod - def plus[S](self, dispatch, *args: S) -> S: - """Monoid addition with broadcasting over common collection types, - callables, and interpretations. + def plus(self, *args: Any) -> Any: + """Monoid addition. Handlers supply per-monoid and broadcasting + behavior; the default rule only handles empty / Term cases. """ if not args: - return typing.cast(S, self.identity) - return dispatch(type(args[0]))(self, *args) - - @plus.register(object) # type: ignore[attr-defined] - def _(self, *args): - if any(isinstance(x, Term) for x in args): - raise NotHandled - return functools.reduce(self.kernel, args, self.identity) - - @plus.register(tuple) # type: ignore[attr-defined] - def _(self, *args): - return tuple(self.plus(*vs) for vs in zip(*args, strict=True)) - - @plus.register(Generator) # type: ignore[attr-defined] - def _(self, *args): - return (self.plus(*vs) for vs in zip(*args, strict=True)) - - @plus.register(Mapping) # type: ignore[attr-defined] - def _(self, *args): - if isinstance(args[0], Interpretation): - keys = args[0].keys() - for b in args[1:]: - if not isinstance(b, Interpretation): - raise TypeError(f"Expected interpretation but got {b}") - if not keys == b.keys(): - raise ValueError( - f"Expected interpretation of {keys} but got {b.keys()}" - ) - return {k: self.plus(*(handler(b)(b[k]) for b in args)) for k in keys} - - for b in args[1:]: - if not isinstance(b, Mapping): - raise TypeError(f"Expected mapping but got {b}") - all_values = collections.defaultdict(list) - for d in args: - for k, v in d.items(): - all_values[k].append(v) - return {k: self.plus(*vs) for (k, vs) in all_values.items()} + return self.identity + raise NotHandled @Operation.define - @functools.singledispatchmethod def reduce[A, B, U: Body]( self, body: Annotated[U, Scoped[A | B]], streams: Annotated[Streams, Scoped[A]], ) -> Annotated[U, Scoped[B]]: - if callable(body): - return typing.cast(U, lambda *a, **k: self.reduce(body(*a, **k), streams)) - - def generator(loop_order) -> Iterator[Interpretation]: - if len(loop_order) == 0: - return - - stream_key = loop_order[0][0] - stream_values = evaluate(streams[stream_key]) - stream_values_iter = iter(stream_values) # type: ignore[arg-type] - - # If we try to iterate and get a term instead of a real - # iterator, give up + """Reduce ``body`` over ``streams``. Handlers supply per-monoid and + broadcasting behavior; the default rule only handles the empty-stream + case. + """ + for stream_key, stream_body, streams_tail in outer_stream(streams): + if isinstance(stream_body, Term): + continue + stream_values_iter = iter(stream_body) if isinstance(stream_values_iter, Term) and stream_values_iter.op is iter_: - raise NotHandled - - if len(loop_order) == 1: - for val in stream_values_iter: - yield {stream_key: functools.partial(lambda v: v, val)} - else: - for val in stream_values_iter: - intp: Interpretation = { - stream_key: functools.partial(lambda v: v, val) - } - with handler(intp): - for intp2 in generator(loop_order[1:]): - yield coproduct(intp, intp2) - - loop_order = list(order_streams(streams)) - return self.plus( - *(handler(intp)(evaluate)(body) for intp in generator(loop_order)) - ) - - @reduce.register # type: ignore[attr-defined] - def _(self, body: Mapping, streams): - return {k: self.reduce(v, streams) for (k, v) in body.items()} - - @reduce.register # type: ignore[attr-defined] - def _(self, body: tuple, streams): - return tuple(self.reduce(x, streams) for x in body) - - @reduce.register # type: ignore[attr-defined] - def _(self, body: Generator, streams): - return (self.reduce(x, streams) for x in body) - - -def _is_monoid_plus(op: Operation) -> bool: - """True if ``op`` is the ``plus`` operation of some :class:`Monoid`.""" - owner = getattr(op, "__self__", None) - return isinstance(owner, Monoid) and op is owner.plus - - -def _is_monoid_reduce(op: Operation) -> bool: - """True if ``op`` is the ``reduce`` operation of some :class:`Monoid`.""" - owner = getattr(op, "__self__", None) - return isinstance(owner, Monoid) and op is owner.reduce + continue + new_reduces = [] + for stream_val in stream_values_iter: + with handler({stream_key: deffn(stream_val)}): + eval_args = evaluate((body, streams_tail)) + assert isinstance(eval_args, tuple) + new_reduces.append( + self.reduce(*eval_args) if streams_tail else eval_args[0] + ) + return self.plus(*new_reduces) + raise NotHandled class MonoidWithZero[T](Monoid[T]): zero: T - def __init__(self, kernel: Callable[[T, T], T], identity: T, zero: T): - super().__init__(kernel, identity) + def __init__(self, name: str, identity: T, zero: T): + super().__init__(name=name, identity=identity) self.zero = zero - def __repr__(self): - return f"{type(self)}({self.kernel}, {self.identity}, {self.zero})" - - -@Operation.define -def _arg_min[T]( - a: tuple[numbers.Number, T | None], b: tuple[numbers.Number, T | None] -) -> tuple[numbers.Number, T | None]: - if isinstance(a[0], Term) or isinstance(b[0], Term): - raise NotHandled - return b if b[0] < a[0] else a # type: ignore - - -@Operation.define -def _arg_max[T]( - a: tuple[numbers.Number, T | None], b: tuple[numbers.Number, T | None] -) -> tuple[numbers.Number, T | None]: - if isinstance(a[0], Term) or isinstance(b[0], Term): - raise NotHandled - return b if b[0] > a[0] else a # type: ignore - - -@Operation.define -def product[T]( - a: Iterable[tuple[T, ...] | T], b: Iterable[tuple[T, ...] | T] -) -> Iterable[tuple[T, ...]]: - if isinstance(a, Term) or isinstance(b, Term): - raise NotHandled - - def to_tuple(x): - return x if isinstance(x, tuple) else (x,) - return [to_tuple(x) + to_tuple(y) for (x, y) in itertools.product(a, b)] - - -Min = Monoid(kernel=min, identity=float("inf")) -Max = Monoid(kernel=max, identity=float("-inf")) -ArgMin = Monoid(kernel=_arg_min, identity=(float("inf"), None)) -ArgMax = Monoid(kernel=_arg_max, identity=(float("-inf"), None)) -Sum = Monoid(kernel=_NumberTerm.__add__, identity=0) -Product = MonoidWithZero(kernel=_NumberTerm.__mul__, identity=1, zero=0) -CartesianProduct = Monoid(kernel=product, identity=[()]) +Min = Monoid(name="Min", identity=float("inf")) +Max = Monoid(name="Max", identity=-float("inf")) +ArgMin = Monoid(name="ArgMin", identity=(Min.identity, None)) +ArgMax = Monoid(name="ArgMax", identity=(Max.identity, None)) +Sum = Monoid(name="Sum", identity=0) +Product = MonoidWithZero(name="Product", identity=1, zero=0) +# CartesianProduct values are "two-level indexable" (rows × positions). The +# identity ``[()]`` is one row of zero positions (composing with it preserves +# shape); the zero ``[]`` is no rows (absorbs under product). +CartesianProduct = MonoidWithZero(name="CartesianProduct", identity=[()], zero=[]) @dataclass @@ -268,6 +161,18 @@ def __call__(self, s: S, t: T) -> bool: ) +def _is_monoid_plus(op: Operation) -> bool: + """True if ``op`` is the ``plus`` operation of some :class:`Monoid`.""" + owner = getattr(op, "__self__", None) + return isinstance(owner, Monoid) and op is owner.plus + + +def _is_monoid_reduce(op: Operation) -> bool: + """True if ``op`` is the ``reduce`` operation of some :class:`Monoid`.""" + owner = getattr(op, "__self__", None) + return isinstance(owner, Monoid) and op is owner.reduce + + class PlusEmpty(ObjectInterpretation): """plus() = 0""" @@ -319,7 +224,7 @@ class PlusDistr(ObjectInterpretation): """x + (y * z) = x * y + x * z""" @implements(Monoid.plus) - def plus(self, monoid, *args): + def plus(self, monoid: Monoid, *args): if any( isinstance(x, Term) and _is_monoid_plus(x.op) @@ -417,24 +322,6 @@ def plus(self, monoid, *args): return fwd() -NormalizePlusIntp = functools.reduce( - coproduct, - typing.cast( - list[Interpretation], - [ - PlusEmpty(), - PlusSingle(), - PlusIdentity(), - PlusAssoc(), - PlusDistr(), - PlusZero(), - PlusConsecutiveDups(), - PlusDups(), - ], - ), -) - - class ReduceNoStreams(ObjectInterpretation): """Implements the identity reduce(R, ∅, body) = 0 @@ -668,18 +555,217 @@ def reduce(self, sum_monoid: Monoid, sum_body, sum_streams): return fwd() -NormalizeReduceIntp = functools.reduce( - coproduct, - typing.cast( - list[Interpretation], - [ - ReduceNoStreams(), - ReduceFusion(), - ReduceSplit(), - ReduceFactorization(), - ReduceDistributeCartesianProduct(), - ], - ), +class MonoidOverCallable(ObjectInterpretation): + """``monoid.reduce(f, streams) = lambda *a: monoid.reduce(f(*a), streams)``.""" + + @implements(Monoid.reduce) + def reduce(self, monoid, body, streams): + if isinstance(body, Term) or not isinstance(body, Callable): + return fwd() + return lambda *a, **k: monoid.reduce(body(*a, **k), streams) + + @implements(Monoid.plus) + def plus(self, monoid, *args): + if not args or any( + isinstance(arg, Term) or not isinstance(arg, Callable) for arg in args + ): + return fwd() + return lambda *a, **k: monoid.plus(*(arg(*a, **k) for arg in args)) + + +class MonoidOverMapping(ObjectInterpretation): + """``monoid.reduce({k: v_k}, streams) = {k: monoid.reduce(v_k, streams)}``.""" + + @implements(Monoid.reduce) + def reduce(self, monoid, body, streams): + if isinstance(body, Term) or not isinstance(body, Mapping): + return fwd() + return {k: monoid.reduce(v, streams) for (k, v) in body.items()} + + @implements(Monoid.plus) + def plus(self, monoid, *args): + if not args or not isinstance(args[0], Mapping): + return fwd() + + if isinstance(args[0], Interpretation): + keys = args[0].keys() + for b in args[1:]: + if not isinstance(b, Interpretation): + raise TypeError(f"Expected interpretation but got {b}") + if not keys == b.keys(): + raise ValueError( + f"Expected interpretation of {keys} but got {b.keys()}" + ) + return {k: monoid.plus(*(handler(b)(b[k]) for b in args)) for k in keys} + + for b in args[1:]: + if not isinstance(b, Mapping): + raise TypeError(f"Expected mapping but got {b}") + all_values = collections.defaultdict(list) + for d in args: + for k, v in d.items(): + all_values[k].append(v) + return {k: monoid.plus(*vs) for (k, vs) in all_values.items()} + + +def _scalar_args(args): + """True iff ``args`` is non-empty and every arg is a concrete int/float.""" + return ( + bool(args) + and not any(isinstance(x, Term) for x in args) + and all(isinstance(x, int | float) for x in args) + ) + + +class SumPlus(ObjectInterpretation): + """Scalar implementation of :data:`Sum`.""" + + @implements(Sum.plus) + def plus(self, *args): + if not _scalar_args(args): + return fwd() + return sum(args) + + +class MinPlus(ObjectInterpretation): + """Scalar implementation of :data:`Min`.""" + + @implements(Min.plus) + def plus(self, *args): + if not _scalar_args(args): + return fwd() + return min(args) + + +class MaxPlus(ObjectInterpretation): + """Scalar implementation of :data:`Max`.""" + + @implements(Max.plus) + def plus(self, *args): + if not _scalar_args(args): + return fwd() + return max(args) + + +class ProductPlus(ObjectInterpretation): + """Scalar implementation of :data:`Product`.""" + + @implements(Product.plus) + def plus(self, *args): + if not _scalar_args(args): + return fwd() + return functools.reduce(operator.mul, args) + + +class ArgMinPlus(ObjectInterpretation): + """Scalar score implementation of :data:`ArgMin`.""" + + @implements(ArgMin.plus) + def plus(self, *args): + if not args or not all(isinstance(a, tuple) for a in args): + return fwd() + if any(isinstance(a[0], Term) for a in args): + return fwd() + if not all(isinstance(a[0], int | float) for a in args): + return fwd() + return min(args, key=lambda a: a[0]) + + +class ArgMaxPlus(ObjectInterpretation): + """Scalar score implementation of :data:`ArgMax`.""" + + @implements(ArgMax.plus) + def plus(self, *args): + if not args or not all(isinstance(a, tuple) for a in args): + return fwd() + if any(isinstance(a[0], Term) for a in args): + return fwd() + if not all(isinstance(a[0], int | float) for a in args): + return fwd() + return max(args, key=lambda a: a[0]) + + +class CartesianProductPlus(ObjectInterpretation): + """Pure-Python implementation of :data:`CartesianProduct`.""" + + @implements(CartesianProduct.plus) + def plus(self, *args): + if not args: + return fwd() + if any(isinstance(x, Term) for x in args): + return fwd() + if not all(isinstance(x, Iterable) for x in args): + return fwd() + + def to_tuple(x): + return x if isinstance(x, tuple) else (x,) + + return [ + sum((to_tuple(v) for v in vals), ()) for vals in itertools.product(*args) + ] + + +is_scalar = _ExtensiblePredicate({Min, Max, Sum, Product}) + + +class MonoidOverSequence(ObjectInterpretation): + @implements(Monoid.plus) + def plus(self, monoid, *args): + if ( + not is_scalar(monoid) + or not args + or not isinstance(args[0], tuple | list | Generator) + ): + return fwd() + zipped = zip(*args, strict=True) + result = (monoid.plus(*vs) for vs in zipped) + if isinstance(args[0], tuple | list): + return type(args[0])(result) + return result + + @implements(Monoid.reduce) + def reduce(self, monoid, body, streams): + if not is_scalar(monoid) or not isinstance(body, tuple | list | Generator): + return fwd() + result = (monoid.reduce(x, streams) for x in body) + if isinstance(body, tuple | list): + return type(body)(result) + return result + + +class _ExtensibleInterpretation(UserDict, Interpretation): + def extend(self, *intps: Interpretation) -> typing.Self: + for intp in intps: + self.data = coproduct(self.data, intp) # type: ignore[assignment] + return self + + +NormalizeIntp = _ExtensibleInterpretation().extend( + MonoidOverSequence(), + MonoidOverMapping(), + MonoidOverCallable(), + ReduceNoStreams(), + ReduceFusion(), + ReduceSplit(), + ReduceFactorization(), + ReduceDistributeCartesianProduct(), + PlusEmpty(), + PlusSingle(), + PlusIdentity(), + PlusAssoc(), + PlusDistr(), + PlusZero(), + PlusConsecutiveDups(), + PlusDups(), + SumPlus(), + MinPlus(), + MaxPlus(), + ProductPlus(), + ArgMinPlus(), + ArgMaxPlus(), + CartesianProductPlus(), ) +"""``NormalizeIntp``applies pure-Term rewrites (associativity, distributivity, +identity elimination, fusion, factorization, etc.). -NormalizeIntp = coproduct(NormalizePlusIntp, NormalizeReduceIntp) +""" diff --git a/effectful/ops/syntax.py b/effectful/ops/syntax.py index 8fb12598f..5ea04fcb8 100644 --- a/effectful/ops/syntax.py +++ b/effectful/ops/syntax.py @@ -849,6 +849,8 @@ def _(x: collections.abc.Sequence, other) -> bool: @syntactic_eq.register(object) @syntactic_eq.register(str | bytes) def _(x: object, other) -> bool: + if isinstance(other, Term): # Terms often override __eq__ + return False return x == other diff --git a/tests/_monoid_helpers.py b/tests/_monoid_helpers.py index 9b311b257..f15103e30 100644 --- a/tests/_monoid_helpers.py +++ b/tests/_monoid_helpers.py @@ -1,23 +1,60 @@ +import itertools from collections.abc import Callable, Mapping, Sequence +from dataclasses import dataclass from typing import Any, get_args, get_origin +import jax +from hypothesis import given, settings from hypothesis import strategies as st -from effectful.ops.syntax import deffn -from effectful.ops.types import Operation +import effectful.handlers.jax.numpy as _jnp +from effectful.internals.runtime import interpreter +from effectful.ops.monoid import NormalizeIntp +from effectful.ops.semantics import apply, evaluate, handler +from effectful.ops.syntax import _BaseTerm, defdata, deffn, syntactic_eq +from effectful.ops.types import NotHandled, Operation, Term + +_JAX_ARRAY_SHAPE = (2,) + + +def _jax_array_value_strategy() -> st.SearchStrategy[jax.Array]: + return st.lists( + st.integers(min_value=-5, max_value=5), + min_size=_JAX_ARRAY_SHAPE[0], + max_size=_JAX_ARRAY_SHAPE[0], + ).map(lambda xs: jax.numpy.asarray(xs, dtype=jax.numpy.float32)) + + +# Unary jax fns map a scalar to a 1-D array (analogous to ``_UNARY_LIST_FNS`` +# for ints). Uses the effectful-wrapped jnp so named-dim broadcasting works. +_UNARY_JAX_FNS: list[Callable[[jax.Array], jax.Array]] = [ + lambda a: _jnp.stack([a, a + 1]), + lambda a: _jnp.stack([a, -a]), + lambda a: _jnp.stack([a, a + 1, 2 * a]), +] + +_BINARY_JAX_FNS: list[Callable[[jax.Array, jax.Array], jax.Array]] = [ + lambda a, b: a + b, + lambda a, b: a - b, + lambda a, b: a * b, +] def _value_strategy_for(annotation: Any) -> st.SearchStrategy[Any]: """Strategy for the value an *0-arg* Operation should return.""" if annotation is int: - return st.integers() + return st.integers(min_value=-100, max_value=100) if annotation is float: return st.floats(allow_nan=False) if get_origin(annotation) is list and get_args(annotation) == (int,): - return st.lists(st.integers(), max_size=2) + return st.lists(st.integers(min_value=-100, max_value=100), max_size=2) + if annotation is jax.Array: + return _jax_array_value_strategy() + if get_origin(annotation) is list and get_args(annotation) == (jax.Array,): + return st.lists(_jax_array_value_strategy(), max_size=2) raise NotImplementedError( f"No value strategy for return annotation {annotation!r}; " - "supported: int, list[int]" + "supported: int, list[int], jax.Array, list[jax.Array]" ) @@ -46,6 +83,13 @@ def _value_strategy_for(annotation: Any) -> st.SearchStrategy[Any]: lambda x: [0, x, x + 1], ] +_UNARY_JAX_LIST_FNS: list[Callable[[jax.Array], list[jax.Array]]] = [ + lambda _x: [], + lambda x: [x], + lambda x: [x, x + 1], + lambda x: [x, -x], +] + def _strategy_for_op(op: Operation) -> st.SearchStrategy[Callable[..., Any]]: """Pick a strategy producing a callable suitable for binding `op` in an @@ -64,8 +108,18 @@ def _strategy_for_op(op: Operation) -> st.SearchStrategy[Callable[..., Any]]: return st.sampled_from(_BINARY_NUM_FNS) if get_origin(ret) is list and get_args(ret) == (int,) and param_types == (int,): return st.sampled_from(_UNARY_LIST_FNS) + if ret is jax.Array and param_types == (jax.Array,): + return st.sampled_from(_UNARY_JAX_FNS) + if ret is jax.Array and param_types == (jax.Array, jax.Array): + return st.sampled_from(_BINARY_JAX_FNS) + if ( + get_origin(ret) is list + and get_args(ret) == (jax.Array,) + and param_types == (jax.Array,) + ): + return st.sampled_from(_UNARY_JAX_LIST_FNS) raise NotImplementedError( - f"Function-typed free var must return int or list[int]; got {ret!r} for {op}" + f"No callable strategy for free var with return {ret!r}, params {param_types!r}" ) @@ -82,4 +136,220 @@ def random_interpretation( return intp -__all__ = ["random_interpretation"] +def define_vars(*names, typ=int): + if len(names) == 1: + return Operation.define(typ, name=names[0]) + return tuple(Operation.define(typ, name=n) for n in names) + + +def syntactic_eq_alpha(x, y) -> bool: + """Alpha-equivalence-respecting variant of ``syntactic_eq``. + + Walks each expression bottom-up with :func:`evaluate` and renames + every bound variable to a deterministic canonical Operation. The + canonical names are assigned by a counter that increments in + ``evaluate``'s natural traversal order, so two alpha-equivalent + expressions canonicalize to syntactically identical results. + """ + + _op_cache: dict[int, Operation] = {} + + def _canonical_op(idx: int, op: Operation) -> Operation: + """Cached canonical Operation, keyed by encounter index. + + Cached so that two independent canonicalize runs return the same + Operation object for the same index — letting ``syntactic_eq`` + compare canonical forms by Operation identity. + """ + if idx in _op_cache: + return _op_cache[idx] + + op = Operation.define(op, name=f"__cv_{idx}") + _op_cache[idx] = op + return op + + cx = _canonicalize(x, _canonical_op) + cy = _canonicalize(y, _canonical_op) + return syntactic_eq(cx, cy) + + +def _canonicalize(expr, _canonical_op): + counter = itertools.count() + + def _substitute(arg, renaming): + """Apply a bound-variable renaming using ``evaluate`` for traversal.""" + if not renaming: + return arg + with interpreter({apply: _BaseTerm, **renaming}): + return evaluate(arg) + + def _bound_var_order(args, kwargs, bound_set: set[Operation]) -> list[Operation]: + """Return bound variables in deterministic encounter order.""" + seen: list[Operation] = [] + seen_set: set[Operation] = set() + + def _capture(op, *a, **kw): + if op in bound_set and op not in seen_set: + seen.append(op) + seen_set.add(op) + return defdata(op, *a, **kw) + + # ``evaluate`` walks Terms, lists, tuples, mappings, dataclasses, + # etc. for free; the apply handler captures bound vars used as + # ``x()`` anywhere in the body. + with interpreter({apply: _capture}): + evaluate((args, kwargs)) + + # Binders bypass the apply handler. Pick them up with a small structural + # walk that visits dict keys too. + def _walk_bare(obj): + if isinstance(obj, Operation): + if obj in bound_set and obj not in seen_set: + seen.append(obj) + seen_set.add(obj) + elif isinstance(obj, dict): + for k, v in obj.items(): + _walk_bare(k) + _walk_bare(v) + elif isinstance(obj, list | set | frozenset | tuple): + for v in obj: + _walk_bare(v) + + _walk_bare((args, kwargs)) + return seen + + def _apply_canonical(op, *args, **kwargs) -> Term: + bindings = op.__fvs_rule__(*args, **kwargs) + all_bound: set[Operation] = set().union( + *bindings.args, *bindings.kwargs.values() + ) + if not all_bound: + return _BaseTerm(op, *args, **kwargs) + + order = _bound_var_order(args, kwargs, all_bound) + canonical = {var: _canonical_op(next(counter), var) for var in order} + assert all_bound <= set(order) + + new_args = tuple( + _substitute( + arg, {v: canonical[v] for v in bindings.args[i] if v in canonical} + ) + for i, arg in enumerate(args) + ) + new_kwargs = { + k: _substitute( + v, + {var: canonical[var] for var in bindings.kwargs[k] if var in canonical}, + ) + for k, v in kwargs.items() + } + + # avoid the renaming from defdata + return _BaseTerm(op, *new_args, **new_kwargs) + + with interpreter({apply: _apply_canonical}): + return evaluate(expr) + + +@dataclass(frozen=True) +class Backend: + """A value-domain spec used to share monoid tests across int and jax.Array + backends. Provides the concrete value type, the hypothesis strategy for + drawing scalars in property tests, and an equality predicate that works + for that domain. + """ + + name: str + scalar_typ: Any + stream_typ: Any + scalar_strategy: st.SearchStrategy[Any] + eq: Callable[[Any, Any], bool] + + def fresh_op(self, name: str, n_args: int = 1, ret: str = "scalar") -> Operation: + """Build a fresh, unhandled Operation whose parameter and return + annotations are derived from this backend. + + ``ret`` is ``"scalar"`` for a scalar return or ``"stream"`` for a + stream-of-scalar return. The operation has ``n_args`` parameters, + each of type ``scalar_typ``. + """ + scalar = self.scalar_typ + out = self.stream_typ if ret == "stream" else scalar + params = ", ".join(f"_a{i}" for i in range(n_args)) + ns: dict[str, Any] = {"NotHandled": NotHandled} + exec(f"def _fn({params}):\n raise NotHandled\n", ns) + fn = ns["_fn"] + fn.__annotations__ = { + **{f"_a{i}": scalar for i in range(n_args)}, + "return": out, + } + return Operation.define(fn, name=name) + + +def _int_eq(a: Any, b: Any) -> bool: + return not isinstance(a, Term) and not isinstance(b, Term) and a == b + + +def _jax_eq(a: Any, b: Any) -> bool: + def _leaf_eq(x: Any, y: Any) -> bool: + return bool(jax.numpy.all(jax.numpy.isclose(x, y, equal_nan=True))) + + try: + leaves = jax.tree.leaves(jax.tree.map(_leaf_eq, a, b)) + except (ValueError, TypeError): + return False + return all(leaves) + + +def check_rewrite( + lhs, + rhs, + rule, + *, + backend: Backend, + free_vars=[], + max_examples: int = 25, + deadline=None, +) -> None: + with handler(rule): + norm = evaluate(lhs) + assert syntactic_eq_alpha(norm, rhs) + + @given(intp=random_interpretation(free_vars)) + @settings(max_examples=max_examples, deadline=deadline) + def _check_semantics(intp): + with handler(NormalizeIntp), handler(intp): + lhs_val = evaluate(lhs) + rhs_val = evaluate(rhs) + assert backend.eq(lhs_val, rhs_val) + + _check_semantics() + + +INT_BACKEND = Backend( + name="int", + scalar_typ=int, + stream_typ=list[int], + scalar_strategy=st.integers(min_value=-100, max_value=100), + eq=_int_eq, +) + + +JAX_BACKEND = Backend( + name="jax", + scalar_typ=jax.Array, + stream_typ=jax.Array, + scalar_strategy=_jax_array_value_strategy(), + eq=_jax_eq, +) + + +__all__ = [ + "Backend", + "INT_BACKEND", + "JAX_BACKEND", + "random_interpretation", + "define_vars", + "syntactic_eq_alpha", + "check_rewrite", +] diff --git a/tests/test_handlers_jax_monoid.py b/tests/test_handlers_jax_monoid.py new file mode 100644 index 000000000..35d041fe2 --- /dev/null +++ b/tests/test_handlers_jax_monoid.py @@ -0,0 +1,96 @@ +import jax +import pytest + +import effectful.handlers.jax.numpy as jnp +from effectful.handlers.jax import bind_dims, unbind_dims +from effectful.handlers.jax.monoid import ArrayReduce, LogSumExp +from effectful.handlers.jax.scipy.special import logsumexp +from effectful.ops.monoid import Max, Min, Product, Sum +from tests._monoid_helpers import JAX_BACKEND, Backend, check_rewrite, define_vars + +MONOIDS = [ + pytest.param(Sum, jnp.sum, id="Sum"), + pytest.param(Product, jnp.prod, id="Product"), + pytest.param(Min, jnp.min, id="Min"), + pytest.param(Max, jnp.max, id="Max"), + pytest.param(LogSumExp, logsumexp, id="LogSumExp"), +] + + +@pytest.fixture +def backend() -> Backend: + return JAX_BACKEND + + +@pytest.mark.parametrize("monoid,reductor", MONOIDS) +def test_reduce_array_1(monoid, reductor, backend: Backend): + (x, k) = define_vars("x", "k", typ=jax.Array) + X = define_vars("X", typ=backend.stream_typ) + + lhs = monoid.reduce(x(), {x: X()}) + rhs = reductor(bind_dims(unbind_dims(X(), k), k), axis=0) + + check_rewrite( + lhs=lhs, rhs=rhs, rule=ArrayReduce(), backend=backend, free_vars=[x, X, k] + ) + + +@pytest.mark.parametrize("monoid,reductor", MONOIDS) +def test_reduce_array_2(monoid, reductor, backend: Backend): + (x, y, k1, k2) = define_vars("x", "y", "k1", "k2", typ=backend.scalar_typ) + (X, Y) = define_vars("X", "Y", typ=backend.stream_typ) + f = backend.fresh_op("f", n_args=2, ret="scalar") + + lhs = monoid.reduce(f(x(), y()), {x: X(), y: Y()}) + rhs = reductor( + bind_dims( + reductor( + bind_dims(f(unbind_dims(X(), k1), unbind_dims(Y(), k2)), k2), + axis=0, + ), + k1, + ), + axis=0, + ) + + check_rewrite( + lhs=lhs, + rhs=rhs, + rule=ArrayReduce(), + backend=backend, + free_vars=[x, y, k1, k2, X, Y, f], + ) + + +@pytest.mark.parametrize("monoid,reductor", MONOIDS) +def test_reduce_array_3(monoid, reductor, backend: Backend): + """Stream `y` is `g(x())` — depends on the bound element of X. The reducer + must inline ``g`` along the same named dim used to unbind `x`.""" + (x, y, k1, k2) = define_vars("x", "y", "k1", "k2", typ=backend.scalar_typ) + X = define_vars("X", typ=backend.stream_typ) + + f = backend.fresh_op("f", n_args=2, ret="scalar") + g = backend.fresh_op("g", n_args=1, ret="stream") + + lhs = monoid.reduce(f(x(), y()), {x: X(), y: g(x())}) + rhs = reductor( + bind_dims( + reductor( + bind_dims( + f(unbind_dims(X(), k1), unbind_dims(g(unbind_dims(X(), k1)), k2)), + k2, + ), + axis=0, + ), + k1, + ), + axis=0, + ) + + check_rewrite( + lhs=lhs, + rhs=rhs, + rule=ArrayReduce(), + backend=backend, + free_vars=[x, y, k1, k2, X, f, g], + ) diff --git a/tests/test_ops_monoid.py b/tests/test_ops_monoid.py index d881869ac..c7ee7567c 100644 --- a/tests/test_ops_monoid.py +++ b/tests/test_ops_monoid.py @@ -1,29 +1,51 @@ -import functools -import itertools import typing import pytest -from hypothesis import given, settings +from hypothesis import HealthCheck, given, settings from hypothesis import strategies as st -from effectful.internals.runtime import interpreter +import effectful.handlers.jax.monoid # noqa: F401 from effectful.ops.monoid import ( CartesianProduct, Max, Min, Monoid, + MonoidOverMapping, + MonoidOverSequence, NormalizeIntp, + PlusAssoc, + PlusConsecutiveDups, + PlusDistr, + PlusDups, + PlusEmpty, + PlusIdentity, + PlusSingle, + PlusZero, Product, + ReduceDistributeCartesianProduct, + ReduceFactorization, + ReduceFusion, + ReduceNoStreams, + ReduceSplit, Sum, distributes_over, - is_commutative, ) -from effectful.ops.semantics import apply, evaluate, fvsof, handler -from effectful.ops.syntax import _BaseTerm, defdata, syntactic_eq -from effectful.ops.types import NotHandled, Operation -from tests._monoid_helpers import random_interpretation +from effectful.ops.semantics import fvsof, handler +from effectful.ops.types import Operation +from tests._monoid_helpers import ( + INT_BACKEND, + JAX_BACKEND, + Backend, + check_rewrite, + define_vars, + syntactic_eq_alpha, +) + + +@pytest.fixture(params=[INT_BACKEND, JAX_BACKEND], ids=["int", "jax"]) +def backend(request) -> Backend: + return request.param -_INT = st.integers(min_value=-100, max_value=100) ALL_MONOIDS = [ pytest.param(Sum, id="Sum"), @@ -61,247 +83,183 @@ ] -def define_vars(*names, typ=int): - if len(names) == 1: - return Operation.define(typ, name=names[0]) - return tuple(Operation.define(typ, name=n) for n in names) - - -@functools.cache -def _canonical_op(idx: int) -> Operation: - """Globally cached canonical Operation, keyed by encounter index. - - Cached so that two independent canonicalize runs return the same - Operation object for the same index — letting ``syntactic_eq`` - compare canonical forms by Operation identity. - """ - return Operation.define(int, name=f"__cv_{idx}") - - -def syntactic_eq_alpha(x, y) -> bool: - """Alpha-equivalence-respecting variant of ``syntactic_eq``. - - Walks each expression bottom-up with :func:`evaluate` and renames - every bound variable to a deterministic canonical Operation. The - canonical names are assigned by a counter that increments in - ``evaluate``'s natural traversal order, so two alpha-equivalent - expressions canonicalize to syntactically identical results. - """ - return syntactic_eq(_canonicalize(x), _canonicalize(y)) - - -def _canonicalize(expr): - counter = itertools.count() - - def _substitute(arg, renaming): - """Apply a bound-variable renaming using ``evaluate`` for traversal.""" - if not renaming: - return arg - with interpreter({apply: _BaseTerm, **renaming}): - return evaluate(arg) - - def _bound_var_order(args, kwargs, bound_set): - """Return bound variables in deterministic encounter order.""" - seen: list[Operation] = [] - seen_set: set[Operation] = set() - - def _capture(op, *a, **kw): - if op in bound_set and op not in seen_set: - seen.append(op) - seen_set.add(op) - return defdata(op, *a, **kw) - - # ``evaluate`` walks Terms, lists, tuples, mappings, dataclasses, - # etc. for free; the apply handler captures bound vars used as - # ``x()`` anywhere in the body. - with interpreter({apply: _capture}): - evaluate((args, kwargs)) - - # Binders bypass the apply handler. Pick them up with a small structural - # walk that visits dict keys too. - def _walk_bare(obj): - if isinstance(obj, Operation): - if obj in bound_set and obj not in seen_set: - seen.append(obj) - seen_set.add(obj) - elif isinstance(obj, dict): - for k, v in obj.items(): - _walk_bare(k) - _walk_bare(v) - elif isinstance(obj, list | set | frozenset | tuple): - for v in obj: - _walk_bare(v) - - _walk_bare((args, kwargs)) - return seen - - def _apply_canonical(op, *args, **kwargs): - bindings = op.__fvs_rule__(*args, **kwargs) - all_bound: set[Operation] = set().union( - *bindings.args, *bindings.kwargs.values() - ) - if not all_bound: - return _BaseTerm(op, *args, **kwargs) - - order = _bound_var_order(args, kwargs, all_bound) - canonical = {var: _canonical_op(next(counter)) for var in order} - assert all_bound <= set(order) - - new_args = tuple( - _substitute( - arg, {v: canonical[v] for v in bindings.args[i] if v in canonical} - ) - for i, arg in enumerate(args) - ) - new_kwargs = { - k: _substitute( - v, - {var: canonical[var] for var in bindings.kwargs[k] if var in canonical}, - ) - for k, v in kwargs.items() - } - - # avoid the renaming from defdata - return _BaseTerm(op, *new_args, **new_kwargs) - - with interpreter({apply: _apply_canonical}): - return evaluate(expr) - - @pytest.mark.parametrize("monoid", ALL_MONOIDS) -@given(a=_INT, b=_INT, c=_INT) -@settings(max_examples=50, deadline=None) -def test_associativity(monoid, a, b, c): - left = monoid.plus(monoid.plus(a, b), c) - right = monoid.plus(a, monoid.plus(b, c)) - assert left == right +@given(data=st.data()) +@settings( + max_examples=50, + deadline=None, + suppress_health_check=[HealthCheck.function_scoped_fixture], +) +def test_associativity(monoid, backend, data): + a = data.draw(backend.scalar_strategy) + b = data.draw(backend.scalar_strategy) + c = data.draw(backend.scalar_strategy) + with handler(NormalizeIntp): + left = monoid.plus(monoid.plus(a, b), c) + right = monoid.plus(a, monoid.plus(b, c)) + assert backend.eq(left, right) @pytest.mark.parametrize("monoid", ALL_MONOIDS) -@given(a=_INT) -@settings(max_examples=50, deadline=None) -def test_identity(monoid, a): - assert monoid.plus(monoid.identity, a) == a - assert monoid.plus(a, monoid.identity) == a +@given(data=st.data()) +@settings( + max_examples=50, + deadline=None, + suppress_health_check=[HealthCheck.function_scoped_fixture], +) +def test_identity(monoid, backend, data): + a = data.draw(backend.scalar_strategy) + with handler(NormalizeIntp): + assert backend.eq(monoid.plus(monoid.identity, a), a) + assert backend.eq(monoid.plus(a, monoid.identity), a) @pytest.mark.parametrize("monoid", COMMUTATIVE) -@given(a=_INT, b=_INT) -@settings(max_examples=50, deadline=None) -def test_commutativity(monoid, a, b): - assert monoid.plus(a, b) == monoid.plus(b, a) +@given(data=st.data()) +@settings( + max_examples=50, + deadline=None, + suppress_health_check=[HealthCheck.function_scoped_fixture], +) +def test_commutativity(monoid, backend, data): + a = data.draw(backend.scalar_strategy) + b = data.draw(backend.scalar_strategy) + with handler(NormalizeIntp): + assert backend.eq(monoid.plus(a, b), monoid.plus(b, a)) @pytest.mark.parametrize("monoid", IDEMPOTENT) -@given(a=_INT) -@settings(max_examples=50, deadline=None) -def test_idempotence(monoid, a): - assert monoid.plus(a, a) == a +@given(data=st.data()) +@settings( + max_examples=50, + deadline=None, + suppress_health_check=[HealthCheck.function_scoped_fixture], +) +def test_idempotence(monoid, backend, data): + a = data.draw(backend.scalar_strategy) + with handler(NormalizeIntp): + assert backend.eq(monoid.plus(a, a), a) @pytest.mark.parametrize("monoid", WITH_ZERO) -@given(a=_INT) -@settings(max_examples=50, deadline=None) -def test_zero_absorbs(monoid, a): - assert monoid.plus(monoid.zero, a) == monoid.zero - assert monoid.plus(a, monoid.zero) == monoid.zero - - -def _check_pair(lhs, rhs, *, free_vars=[], max_examples: int = 25) -> None: - """Run structural + semantic checks on a TermPair.""" +@given(data=st.data()) +@settings( + max_examples=50, + deadline=None, + suppress_health_check=[HealthCheck.function_scoped_fixture], +) +def test_zero_absorbs(monoid, backend, data): + a = data.draw(backend.scalar_strategy) with handler(NormalizeIntp): - norm = evaluate(lhs) - - assert syntactic_eq_alpha(norm, rhs) + assert backend.eq(monoid.plus(monoid.zero, a), monoid.zero) + assert backend.eq(monoid.plus(a, monoid.zero), monoid.zero) - @given(intp=random_interpretation(free_vars)) - @settings(max_examples=max_examples, deadline=None) - def _check_semantics(intp): - with handler(intp): - lhs_val = evaluate(lhs) - rhs_val = evaluate(rhs) - assert lhs_val == rhs_val - _check_semantics() +@pytest.mark.parametrize("monoid", ALL_MONOIDS) +def test_plus_empty(monoid, backend): + check_rewrite( + lhs=monoid.plus(), rhs=monoid.identity, rule=PlusEmpty(), backend=backend + ) @pytest.mark.parametrize("monoid", ALL_MONOIDS) -def test_plus_empty(monoid): - _check_pair(lhs=monoid.plus(), rhs=monoid.identity) +def test_plus_single(monoid, backend): + x = define_vars("x", typ=backend.scalar_typ) + check_rewrite( + lhs=monoid.plus(x()), rhs=x(), rule=PlusSingle(), backend=backend, free_vars=[x] + ) @pytest.mark.parametrize("monoid", ALL_MONOIDS) -def test_plus_single(monoid): - x = define_vars("x", typ=type(monoid.identity)) - _check_pair(lhs=monoid.plus(x()), rhs=x(), free_vars=[x]) +def test_plus_identity_right(monoid, backend): + x = define_vars("x", typ=backend.scalar_typ) + lhs = monoid.plus(x(), monoid.identity) + rhs = monoid.plus(x()) -@pytest.mark.parametrize("monoid", ALL_MONOIDS) -def test_plus_identity_right(monoid): - x = define_vars("x", typ=type(monoid.identity)) - _check_pair(lhs=monoid.plus(x(), monoid.identity), rhs=x(), free_vars=[x]) + check_rewrite(lhs=lhs, rhs=rhs, rule=PlusIdentity(), backend=backend, free_vars=[x]) @pytest.mark.parametrize("monoid", ALL_MONOIDS) -def test_plus_identity_left(monoid): - x = define_vars("x", typ=type(monoid.identity)) - _check_pair(lhs=monoid.plus(monoid.identity, x()), rhs=x(), free_vars=[x]) +def test_plus_identity_left(monoid, backend): + x = define_vars("x", typ=backend.scalar_typ) + + lhs = monoid.plus(monoid.identity, x()) + rhs = monoid.plus(x()) + + check_rewrite(lhs=lhs, rhs=rhs, rule=PlusIdentity(), backend=backend, free_vars=[x]) @pytest.mark.parametrize("monoid", ALL_MONOIDS) -def test_plus_assoc_right(monoid): - x, y, z = define_vars("x", "y", "z", typ=type(monoid.identity)) - _check_pair( +def test_plus_assoc_right(monoid, backend): + x, y, z = define_vars("x", "y", "z", typ=backend.scalar_typ) + check_rewrite( lhs=monoid.plus(x(), monoid.plus(y(), z())), rhs=monoid.plus(x(), y(), z()), + rule=PlusAssoc(), + backend=backend, free_vars=[x, y, z], ) @pytest.mark.parametrize("monoid", ALL_MONOIDS) -def test_plus_assoc_left(monoid): - x, y, z = define_vars("x", "y", "z", typ=type(monoid.identity)) - _check_pair( +def test_plus_assoc_left(monoid, backend): + x, y, z = define_vars("x", "y", "z", typ=backend.scalar_typ) + check_rewrite( lhs=monoid.plus(monoid.plus(x(), y()), z()), rhs=monoid.plus(x(), y(), z()), + rule=PlusAssoc(), + backend=backend, free_vars=[x, y, z], ) @pytest.mark.parametrize("monoid", ALL_MONOIDS) -def test_plus_sequence(monoid): - a, b, c, d = define_vars("a", "b", "c", "d", typ=type(monoid.identity)) - _check_pair( +def test_plus_sequence(monoid, backend): + a, b, c, d = define_vars("a", "b", "c", "d", typ=backend.scalar_typ) + check_rewrite( lhs=monoid.plus((a(), b()), (c(), d())), rhs=(monoid.plus(a(), c()), monoid.plus(b(), d())), + rule=MonoidOverSequence(), + backend=backend, free_vars=[a, b, c, d], ) @pytest.mark.parametrize("monoid", ALL_MONOIDS) -def test_plus_mapping(monoid): - a, b, c, d = define_vars("a", "b", "c", "d", typ=type(monoid.identity)) - _check_pair( - lhs=monoid.plus({"x": a(), "y": b()}, {"x": c(), "z": d()}), - rhs={"x": monoid.plus(a(), c()), "y": b(), "z": d()}, +def test_plus_mapping(monoid, backend): + a, b, c, d = define_vars("a", "b", "c", "d", typ=backend.scalar_typ) + + lhs = monoid.plus({0: a(), 1: b()}, {0: c(), 2: d()}) + rhs = {0: monoid.plus(a(), c()), 1: monoid.plus(b()), 2: monoid.plus(d())} + + check_rewrite( + lhs=lhs, + rhs=rhs, + rule=MonoidOverMapping(), + backend=backend, free_vars=[a, b, c, d], ) -def test_plus_distributes(): - a, b, c, d = define_vars("a", "b", "c", "d") +def test_plus_distributes(backend): + a, b, c, d = define_vars("a", "b", "c", "d", typ=backend.scalar_typ) lhs = Product.plus(Sum.plus(a(), b()), Sum.plus(c(), d())) - rhs = Sum.plus( - Product.plus(a(), c()), - Product.plus(a(), d()), - Product.plus(b(), c()), - Product.plus(b(), d()), + rhs = Product.plus( + Sum.plus( + Product.plus(a(), c()), + Product.plus(a(), d()), + Product.plus(b(), c()), + Product.plus(b(), d()), + ) + ) + check_rewrite( + lhs=lhs, rhs=rhs, rule=PlusDistr(), backend=backend, free_vars=[a, b, c, d] ) - _check_pair(lhs=lhs, rhs=rhs, free_vars=[a, b, c, d]) -def test_plus_distributes_constant(): - a, b, c, d = define_vars("a", "b", "c", "d") +def test_plus_distributes_constant(backend): + a, b, c, d = define_vars("a", "b", "c", "d", typ=backend.scalar_typ) lhs = Product.plus(Sum.plus(a(), b()), Sum.plus(c(), d()), 5) rhs = Product.plus( 5, @@ -312,11 +270,13 @@ def test_plus_distributes_constant(): Product.plus(b(), d()), ), ) - _check_pair(lhs=lhs, rhs=rhs, free_vars=[a, b, c, d]) + check_rewrite( + lhs=lhs, rhs=rhs, rule=PlusDistr(), backend=backend, free_vars=[a, b, c, d] + ) -def test_plus_distributes_multiple(): - a, b, c, d = define_vars("a", "b", "c", "d") +def test_plus_distributes_multiple(backend): + a, b, c, d = define_vars("a", "b", "c", "d", typ=backend.scalar_typ) lhs = Sum.plus( Min.plus(a(), b()), Min.plus(c(), d()), @@ -337,72 +297,123 @@ def test_plus_distributes_multiple(): Sum.plus(b(), d()), ), ) - _check_pair(lhs=lhs, rhs=rhs, free_vars=[a, b, c, d]) + check_rewrite( + lhs=lhs, rhs=rhs, rule=PlusDistr(), backend=backend, free_vars=[a, b, c, d] + ) @pytest.mark.parametrize("monoid", IDEMPOTENT) -def test_plus_idempotent_consecutive(monoid): +def test_plus_idempotent_consecutive(monoid, backend): """``a, a, b → a, b`` — only consecutive duplicates collapse.""" - a, b = define_vars("a", "b") + a, b = define_vars("a", "b", typ=backend.scalar_typ) lhs = monoid.plus(a(), a(), b()) - return _check_pair(lhs=lhs, rhs=monoid.plus(a(), b()), free_vars=[a, b]) + return check_rewrite( + lhs=lhs, + rhs=monoid.plus(a(), b()), + rule=PlusConsecutiveDups(), + backend=backend, + free_vars=[a, b], + ) @pytest.mark.parametrize("monoid", IDEMPOTENT) -def test_plus_idempotent_non_consecutive(monoid): +def test_plus_idempotent_non_consecutive(monoid, backend): """``a, b, a`` — Semilattice (Min/Max) collapses via commutative - PlusDups; plain IdempotentMonoid leaves it as-is (consecutive-only).""" - a, b = define_vars("a", "b") + PlusDups.""" + a, b = define_vars("a", "b", typ=backend.scalar_typ) lhs = monoid.plus(a(), b(), a()) - if is_commutative(monoid): - rhs = monoid.plus(a(), b()) - else: - rhs = monoid.plus(a(), b(), a()) - _check_pair(lhs=lhs, rhs=rhs, free_vars=[a, b]) + rhs = monoid.plus(a(), b()) + check_rewrite(lhs=lhs, rhs=rhs, rule=PlusDups(), backend=backend, free_vars=[a, b]) -def test_plus_commutative_idempotent_long(): +@pytest.mark.parametrize("monoid", [Min, Max]) +def test_plus_commutative_idempotent_long(monoid, backend): """Long alternation collapses via commutative dedup (Min/Max only).""" - a, b = define_vars("a", "b") - lhs = Min.plus(a(), b(), a(), b(), b(), a(), a()) - _check_pair(lhs=lhs, rhs=Min.plus(a(), b()), free_vars=[a, b]) + a, b = define_vars("a", "b", typ=backend.scalar_typ) + lhs = monoid.plus(a(), b(), a(), b(), b(), a(), a()) + rhs = monoid.plus(a(), b()) + check_rewrite(lhs=lhs, rhs=rhs, rule=PlusDups(), backend=backend, free_vars=[a, b]) @pytest.mark.parametrize("monoid", WITH_ZERO) -def test_plus_zero(monoid): - a = define_vars("a") +def test_plus_zero(monoid, backend): + a = define_vars("a", typ=backend.scalar_typ) lhs_right = monoid.plus(a(), monoid.zero) lhs_left = monoid.plus(monoid.zero, a()) - _check_pair(lhs=lhs_right, rhs=monoid.zero, free_vars=[a]) - _check_pair(lhs=lhs_left, rhs=monoid.zero, free_vars=[a]) + rhs = monoid.zero + check_rewrite( + lhs=lhs_right, rhs=rhs, rule=PlusZero(), backend=backend, free_vars=[a] + ) + check_rewrite( + lhs=lhs_left, rhs=rhs, rule=PlusZero(), backend=backend, free_vars=[a] + ) + + +@pytest.mark.parametrize("monoid", ALL_MONOIDS) +def test_partial_1(monoid, backend): + x, y = define_vars("x", "y", typ=backend.scalar_typ) + lhs = monoid.reduce(x(), {x: []}) + rhs = monoid.identity + check_rewrite(lhs=lhs, rhs=rhs, rule={}, backend=backend, free_vars=[x, y]) + + +@pytest.mark.parametrize("monoid", ALL_MONOIDS) +def test_partial_2(monoid, backend): + x, y = define_vars("x", "y", typ=backend.scalar_typ) + Y = define_vars("Y", typ=backend.stream_typ) + + lhs = monoid.reduce(x(), {y: Y(), x: []}) + rhs = monoid.identity + + check_rewrite(lhs=lhs, rhs=rhs, rule={}, backend=backend, free_vars=[x, y, Y]) @pytest.mark.parametrize("monoid", ALL_MONOIDS) -def test_reduce_body_sequence(monoid): - x = Operation.define(int, name="x") - X = Operation.define(list[int], name="X") +def test_partial_3(monoid, backend): + x, y, a, b = define_vars("x", "y", "a", "b", typ=backend.scalar_typ) + Y = define_vars("Y", typ=backend.stream_typ) + + lhs = monoid.reduce(x(), {y: Y(), x: [a(), b()]}) + rhs = monoid.plus(monoid.reduce(a(), {y: Y()}), monoid.reduce(b(), {y: Y()})) + + check_rewrite(lhs=lhs, rhs=rhs, rule={}, backend=backend, free_vars=[x, y, a, b, Y]) + + +@pytest.mark.parametrize("monoid", ALL_MONOIDS) +def test_partial_4(monoid, backend): + x, y, a, b = define_vars("x", "y", "a", "b", typ=backend.scalar_typ) + f = backend.fresh_op("f", n_args=1, ret="stream") + + lhs = monoid.reduce(x(), {y: f(x()), x: [a(), b()]}) + rhs = monoid.plus(monoid.reduce(a(), {y: f(a())}), monoid.reduce(b(), {y: f(b())})) - @Operation.define - def f(_x: int) -> int: - raise NotHandled + check_rewrite(lhs=lhs, rhs=rhs, rule={}, backend=backend, free_vars=[x, y, a, b, f]) + +@pytest.mark.parametrize("monoid", ALL_MONOIDS) +def test_reduce_body_sequence(monoid, backend): + x = Operation.define(backend.scalar_typ, name="x") + X = Operation.define(backend.stream_typ, name="X") + f = backend.fresh_op("f", n_args=1, ret="scalar") g = Operation.define(f, name="g") lhs = monoid.reduce((f(x()), g(x())), {x: X()}) rhs = (monoid.reduce(f(x()), {x: X()}), monoid.reduce(g(x()), {x: X()})) - _check_pair(lhs=lhs, rhs=rhs, free_vars=[X, f, g]) + check_rewrite( + lhs=lhs, + rhs=rhs, + rule=MonoidOverSequence(), + backend=backend, + free_vars=[X, f, g], + ) @pytest.mark.parametrize("monoid", ALL_MONOIDS) -def test_reduce_body_sequence_2(monoid): - x, y = define_vars("x", "y") - X, Y = define_vars("X", "Y", typ=list[int]) - - @Operation.define - def f(_x: int) -> int: - raise NotHandled - +def test_reduce_body_sequence_2(monoid, backend): + x, y = define_vars("x", "y", typ=backend.scalar_typ) + X, Y = define_vars("X", "Y", typ=backend.stream_typ) + f = backend.fresh_op("f", n_args=1, ret="scalar") g = Operation.define(f, name="g") lhs = monoid.reduce((f(x()), g(y())), {x: X(), y: Y()}) @@ -411,103 +422,115 @@ def f(_x: int) -> int: monoid.reduce(g(y()), {x: X(), y: Y()}), ) - _check_pair(lhs=lhs, rhs=rhs, free_vars=[X, Y, f, g]) + check_rewrite( + lhs=lhs, + rhs=rhs, + rule=MonoidOverSequence(), + backend=backend, + free_vars=[X, Y, f, g], + ) @pytest.mark.parametrize("monoid", ALL_MONOIDS) -def test_reduce_body_mapping(monoid): - x = Operation.define(int, name="x") - X = Operation.define(list[int], name="X") - - @Operation.define - def f(_x: int) -> int: - raise NotHandled - +def test_reduce_body_mapping(monoid, backend): + x = Operation.define(backend.scalar_typ, name="x") + X = Operation.define(backend.stream_typ, name="X") + f = backend.fresh_op("f", n_args=1, ret="scalar") g = Operation.define(f, name="g") - lhs = monoid.reduce({"a": f(x()), "b": g(x())}, {x: X()}) + lhs = monoid.reduce({0: f(x()), 1: g(x())}, {x: X()}) rhs = { - "a": monoid.reduce(f(x()), {x: X()}), - "b": monoid.reduce(g(x()), {x: X()}), + 0: monoid.reduce(f(x()), {x: X()}), + 1: monoid.reduce(g(x()), {x: X()}), } - _check_pair(lhs=lhs, rhs=rhs, free_vars=[X, f, g]) + check_rewrite( + lhs=lhs, + rhs=rhs, + rule=MonoidOverMapping(), + backend=backend, + free_vars=[X, f, g], + ) @pytest.mark.parametrize("monoid", ALL_MONOIDS) -def test_reduce_no_streams(monoid): - a = define_vars("a") +def test_reduce_no_streams(monoid, backend): + a = define_vars("a", typ=backend.scalar_typ) lhs = monoid.reduce(a(), {}) rhs = monoid.identity - _check_pair(lhs=lhs, rhs=rhs, free_vars=[a]) + check_rewrite( + lhs=lhs, rhs=rhs, rule=ReduceNoStreams(), backend=backend, free_vars=[a] + ) @pytest.mark.parametrize("monoid", ALL_MONOIDS) -def test_reduce_reduce(monoid): - a, b = define_vars("a", "b") - A, B = define_vars("A", "B", typ=list[int]) - - @Operation.define - def f(_x: int, _y: int) -> int: - raise NotHandled +def test_reduce_reduce(monoid, backend): + a, b = define_vars("a", "b", typ=backend.scalar_typ) + A, B = define_vars("A", "B", typ=backend.stream_typ) + f = backend.fresh_op("f", n_args=2, ret="scalar") lhs = monoid.reduce(monoid.reduce(f(a(), b()), {a: A()}), {b: B()}) rhs = monoid.reduce(f(a(), b()), {a: A(), b: B()}) - _check_pair(lhs=lhs, rhs=rhs, free_vars=[A, B, f]) + check_rewrite( + lhs=lhs, rhs=rhs, rule=ReduceFusion(), backend=backend, free_vars=[A, B, f] + ) @pytest.mark.parametrize("monoid", COMMUTATIVE) -def test_reduce_plus(monoid): - a, b = define_vars("a", "b") - A, B = define_vars("A", "B", typ=list[int]) +def test_reduce_plus(monoid, backend): + a, b = define_vars("a", "b", typ=backend.scalar_typ) + A, B = define_vars("A", "B", typ=backend.stream_typ) lhs = monoid.reduce(monoid.plus(a(), b()), {a: A(), b: B()}) rhs = monoid.plus( monoid.reduce(a(), {a: A(), b: B()}), monoid.reduce(b(), {a: A(), b: B()}), ) - _check_pair(lhs=lhs, rhs=rhs, free_vars=[A, B]) + check_rewrite( + lhs=lhs, rhs=rhs, rule=ReduceSplit(), backend=backend, free_vars=[A, B] + ) -def test_reduce_independent_1(): - a, b = define_vars("a", "b") - A, B = define_vars("A", "B", typ=list[int]) +def test_reduce_independent_1(backend): + a, b = define_vars("a", "b", typ=backend.scalar_typ) + A, B = define_vars("A", "B", typ=backend.stream_typ) lhs = Sum.reduce(Product.plus(a(), b()), {a: A(), b: B()}) - rhs = Product.plus(Sum.reduce(a(), {a: A()}), Sum.reduce(b(), {b: B()})) - _check_pair(lhs=lhs, rhs=rhs, free_vars=[A, B]) - + rhs = Product.plus( + Sum.reduce(Product.plus(a()), {a: A()}), Sum.reduce(Product.plus(b()), {b: B()}) + ) + check_rewrite( + lhs=lhs, rhs=rhs, rule=ReduceFactorization(), backend=backend, free_vars=[A, B] + ) -def test_reduce_independent_2(): - a, b, c = define_vars("a", "b", "c") - A, B, C = define_vars("A", "B", "C", typ=list[int]) - @Operation.define - def f(_x: int, _y: int) -> int: - raise NotHandled +def test_reduce_independent_2(backend): + a, b, c = define_vars("a", "b", "c", typ=backend.scalar_typ) + A, B, C = define_vars("A", "B", "C", typ=backend.stream_typ) + f = backend.fresh_op("f", n_args=2, ret="scalar") lhs = Sum.reduce(Product.plus(a(), b(), f(b(), c())), {a: A(), b: B(), c: C()}) rhs = Product.plus( - Sum.reduce(a(), {a: A()}), + Sum.reduce(Product.plus(a()), {a: A()}), Sum.reduce(Product.plus(b(), f(b(), c())), {b: B(), c: C()}), ) - _check_pair(lhs=lhs, rhs=rhs, free_vars=[A, B, C, f]) + check_rewrite( + lhs=lhs, + rhs=rhs, + rule=ReduceFactorization(), + backend=backend, + free_vars=[A, B, C, f], + ) -def test_reduce_independent_3_negative(): +def test_reduce_independent_3_negative(backend): """Stream `b` depends on `a` (b: g(a())), so the proposed factorization is unsound — the normalizer must NOT apply it.""" - a, b, c = define_vars("a", "b", "c") - A, C = define_vars("A", "C", typ=list[int]) - - @Operation.define - def f(_x: int, _y: int) -> int: - raise NotHandled + a, b, c = define_vars("a", "b", "c", typ=backend.scalar_typ) + A, C = define_vars("A", "C", typ=backend.stream_typ) + f = backend.fresh_op("f", n_args=2, ret="scalar") + g = backend.fresh_op("g", n_args=1, ret="stream") - @Operation.define - def g(_x: int) -> list[int]: - raise NotHandled - - with handler(NormalizeIntp): + with handler(ReduceFactorization()): # ty:ignore[invalid-argument-type] lhs = Sum.reduce( Product.plus(a(), b(), f(b(), c())), {a: A(), b: g(a()), c: C()} ) @@ -519,104 +542,107 @@ def g(_x: int) -> list[int]: assert not syntactic_eq_alpha(lhs, bogus_rhs) -def test_reduce_independent_4(): - a, b, c = define_vars("a", "b", "c") - A, B, C = define_vars("A", "B", "C", typ=list[int]) - - @Operation.define - def f(_x: int, _y: int) -> int: - raise NotHandled +def test_reduce_independent_4(backend): + a, b, c = define_vars("a", "b", "c", typ=backend.scalar_typ) + A, B, C = define_vars("A", "B", "C", typ=backend.stream_typ) + f = backend.fresh_op("f", n_args=2, ret="scalar") lhs = Sum.reduce(Product.plus(a(), b(), f(b(), c()), 7), {a: A(), b: B(), c: C()}) rhs = Product.plus( 7, - Sum.reduce(a(), {a: A()}), + Sum.reduce(Product.plus(a()), {a: A()}), Sum.reduce(Product.plus(b(), f(b(), c())), {b: B(), c: C()}), ) - _check_pair(lhs=lhs, rhs=rhs, free_vars=[A, B, C, f]) + check_rewrite( + lhs=lhs, + rhs=rhs, + rule=ReduceFactorization(), + backend=backend, + free_vars=[A, B, C, f], + ) @pytest.mark.parametrize("outer,inner", MONOID_PAIRS) -def test_reduce_lifted_1(outer, inner): - a, i = define_vars("a", "i") - A, N, A_domain = define_vars("A", "N", "A_domain", typ=list[int]) - - @Operation.define - def f(_: int) -> float: - raise NotHandled +def test_reduce_lifted_1(outer, inner, backend): + a, i = define_vars("a", "i", typ=backend.scalar_typ) + A, N, A_domain = define_vars("A", "N", "A_domain", typ=backend.stream_typ) + f = backend.fresh_op("f", n_args=1, ret="scalar") term1 = outer.reduce( inner.reduce(f(a()), {a: A()}), {A: CartesianProduct.reduce(A_domain(), {i: N()})}, ) - term2 = inner.reduce(outer.reduce(f(a()), {a: A_domain()}), {i: N()}) - _check_pair(lhs=term1, rhs=term2, free_vars=[N, A_domain, f]) + term2 = inner.reduce(outer.reduce(inner.plus(f(a())), {a: A_domain()}), {i: N()}) + + check_rewrite( + lhs=term1, + rhs=term2, + rule=ReduceDistributeCartesianProduct(), + backend=backend, + free_vars=[N, A_domain, f], + ) def test_reduce_cartesian_1(): - a, i = define_vars("a", "i") - A = define_vars("A", typ=list[int]) + a, i = define_vars("a", "i", typ=int) + A = define_vars("A", typ=tuple[int]) - term1 = Sum.reduce( - Product.reduce(a(), {a: []}), - {A: CartesianProduct.reduce([], {i: []})}, - ) - term2 = Product.reduce(Sum.reduce(a(), {a: []}), {i: []}) + with handler(NormalizeIntp): + term1 = Sum.reduce( + Product.reduce(a(), {a: []}), + {A: CartesianProduct.reduce([], {i: []})}, + ) + term2 = Product.reduce(Sum.reduce(a(), {a: []}), {i: []}) assert term1 == term2 def test_reduce_cartesian_2(): - a, i = define_vars("a", "i") - A = define_vars("A", typ=list[int]) + a, i = define_vars("a", "i", typ=int) + A = define_vars("A", typ=tuple[int]) - term1 = Sum.reduce( - Product.reduce(a(), {a: A()}), - {A: CartesianProduct.reduce([(0,)], {i: [0]})}, - ) - term2 = Product.reduce(Sum.reduce(a(), {a: [0]}), {i: [0]}) + with handler(NormalizeIntp): + term1 = Sum.reduce( + Product.reduce(a(), {a: A()}), + {A: CartesianProduct.reduce([(0,)], {i: [0]})}, + ) + term2 = Product.reduce(Sum.reduce(a(), {a: [0]}), {i: [0]}) assert term1 == term2 @pytest.mark.parametrize("outer,inner", MONOID_PAIRS) -def test_reduce_lifted_multi_index(outer, inner): - a, i, j = define_vars("a", "i", "j") - A, N, M, A_domain = define_vars("A", "N", "M", "A_domain", typ=list[int]) - - @Operation.define - def f(_: int) -> float: - raise NotHandled +def test_reduce_lifted_multi_index(outer, inner, backend): + a, i, j = define_vars("a", "i", "j", typ=backend.scalar_typ) + A, N, M, A_domain = define_vars("A", "N", "M", "A_domain", typ=backend.stream_typ) + f = backend.fresh_op("f", n_args=1, ret="scalar") term1 = outer.reduce( inner.reduce(f(a()), {a: A()}), {A: CartesianProduct.reduce(A_domain(), {i: N(), j: M()})}, ) term2 = inner.reduce( - outer.reduce(f(a()), {a: A_domain()}), + outer.reduce(inner.plus(f(a())), {a: A_domain()}), {i: N(), j: M()}, ) - _check_pair(lhs=term1, rhs=term2, free_vars=[N, M, A_domain, f]) + check_rewrite( + lhs=term1, + rhs=term2, + rule=ReduceDistributeCartesianProduct(), + backend=backend, + free_vars=[N, M, A_domain, f], + ) @pytest.mark.parametrize("outer,inner", MONOID_PAIRS) -def test_reduce_lifted_2(outer, inner): +def test_reduce_lifted_2(outer, inner, backend): """The worked example on page 396 of 'Lifted Variable Elimination: Decoupling the Operators from the Constraint Language'. """ - a, i, s, t = define_vars("a", "i", "s", "t") - A, N, T = define_vars("A", "N", "T", typ=list[int]) - - @Operation.define - def A_domain(_i: int) -> list[int]: - raise NotHandled - - @Operation.define - def f1(_a: int, _s: int) -> float: - raise NotHandled - - @Operation.define - def f2(_t: int, _a: int) -> float: - raise NotHandled + a, i, s, t = define_vars("a", "i", "s", "t", typ=backend.scalar_typ) + A, N, T = define_vars("A", "N", "T", typ=backend.stream_typ) + A_domain = backend.fresh_op("A_domain", n_args=1, ret="stream") + f1 = backend.fresh_op("f1", n_args=2, ret="scalar") + f2 = backend.fresh_op("f2", n_args=2, ret="scalar") term1 = outer.reduce( inner.reduce(inner.plus(f1(a(), s()), f2(t(), a())), {a: A()}), @@ -625,10 +651,18 @@ def f2(_t: int, _a: int) -> float: term2 = outer.reduce( inner.reduce( - outer.reduce(inner.plus(f1(a(), s()), f2(t(), a())), {a: A_domain(i())}), + outer.reduce( + inner.plus(inner.plus(f1(a(), s()), f2(t(), a()))), {a: A_domain(i())} + ), {i: N()}, ), {t: T()}, ) - _check_pair(lhs=term1, rhs=term2, free_vars=[a, i, s, t, A, N, T, A_domain, f1, f2]) + check_rewrite( + lhs=term1, + rhs=term2, + rule=ReduceDistributeCartesianProduct(), + backend=backend, + free_vars=[a, i, s, t, A, N, T, A_domain, f1, f2], + ) From 74555f208dd3536ed292400735157376b94ec8c7 Mon Sep 17 00:00:00 2001 From: Jack Feser Date: Fri, 22 May 2026 12:51:04 -0400 Subject: [PATCH 5/6] Add `delta` terms for array construction in `handlers.jax.monoid` (#663) * Add monoid module (#653) * add monoid module * clean up * fix doctest * fix * wip * remove incorrect rule * add disjoint set tests and fix bug * lint * drop jax monoid defs * drop incorrect comment * add assert * reduce nondeterminism and add assertions * fix inconsistent stream numbering and missing constant factors * wip * cleanup * wip * wip * fix rule * wip * fix bug * cleanup * lin * wip * fix tests * format * lint * wip * wip * wip * wip * wip * wip * wip * wip * drop runtime typed dict lifting * wip * format * reorganize * stop using string dicts to avoid unification issue * wip * wip * wip * wip * wip * use check_rewrite in jax tests * lint * wip * fix bugs * comment on not implemented cases * format * simplify * lint * add matmul test --- effectful/handlers/jax/_handlers.py | 6 + effectful/handlers/jax/_terms.py | 20 +-- effectful/handlers/jax/monoid.py | 255 +++++++++++++++++++++++++++- effectful/ops/monoid.py | 6 +- tests/test_handlers_jax_monoid.py | 177 ++++++++++++++++++- 5 files changed, 442 insertions(+), 22 deletions(-) diff --git a/effectful/handlers/jax/_handlers.py b/effectful/handlers/jax/_handlers.py index c5d104233..91fba369f 100644 --- a/effectful/handlers/jax/_handlers.py +++ b/effectful/handlers/jax/_handlers.py @@ -87,6 +87,12 @@ def _partial_eval(t: Expr[jax.Array]) -> Expr[jax.Array]: if not sized_fvs: return t + # if any dimension is zero sized, the result is empty + if any(size == 0 for size in sized_fvs.values()): + key = tuple(sized_fvs.keys()) + shape = tuple(sized_fvs[k] for k in key) + return jax_getitem(jnp.empty(shape), key) + def _is_eager(t): return not isinstance(t, Term) or t.op in sized_fvs or is_eager_array(t) diff --git a/effectful/handlers/jax/_terms.py b/effectful/handlers/jax/_terms.py index 812062931..c88fe9341 100644 --- a/effectful/handlers/jax/_terms.py +++ b/effectful/handlers/jax/_terms.py @@ -8,7 +8,6 @@ import effectful.handlers.jax.numpy as jnp from effectful.handlers.jax._handlers import ( IndexElement, - _partial_eval, _register_jax_op, bind_dims, jax_getitem, @@ -451,28 +450,15 @@ def _bind_dims_array(t: jax.Array, *args: Operation[[], jax.Array]) -> jax.Array >>> bind_dims(t, b, a).shape (3, 2) """ - - def _evaluate(expr): - if isinstance(expr, Term): - (args, kwargs) = jax.tree.map(_evaluate, (expr.args, expr.kwargs)) - return _partial_eval(expr) - if not jax.tree_util.treedef_is_leaf(jax.tree.structure(expr)): - return jax.tree.map(_evaluate, expr) - return expr - if not isinstance(t, Term): return t - result = _evaluate(t) - if not isinstance(result, Term) or not args: - return result - # ensure that the result is a jax_getitem with an array as the first argument - if not (result.op is jax_getitem and isinstance(result.args[0], jax.Array)): + if not (t.op is jax_getitem and isinstance(t.args[0], jax.Array)): raise NotHandled - array = result.args[0] - dims = result.args[1] + array = t.args[0] + dims = t.args[1] assert isinstance(dims, Sequence) # ensure that the order is a subset of the named dimensions diff --git a/effectful/handlers/jax/monoid.py b/effectful/handlers/jax/monoid.py index a406cda5b..42d7866ec 100644 --- a/effectful/handlers/jax/monoid.py +++ b/effectful/handlers/jax/monoid.py @@ -1,4 +1,6 @@ import functools +import typing +from collections.abc import Iterable import jax @@ -12,12 +14,13 @@ Monoid, NormalizeIntp, Product, + Streams, Sum, outer_stream, ) from effectful.ops.semantics import evaluate, fvsof, fwd, handler, typeof from effectful.ops.syntax import ObjectInterpretation, deffn, implements -from effectful.ops.types import Operation +from effectful.ops.types import Interpretation, NotHandled, Operation, Term def cartesian_prod(x, y): @@ -151,8 +154,258 @@ def reduce(self, monoid, body, streams): return fwd() +@Operation.define +def delta(_index: tuple[int, ...], _weight: jax.Array) -> jax.Array: + raise NotHandled + + +py_range = range + + +@Operation.define +def range(*args: int) -> Iterable[jax.Array]: + raise NotHandled + + +def _range_start(term: Term): + assert term.op == range + if len(term.args) < 2: + return 0 + return term.args[0] + + +def _range_stop(term: Term): + assert term.op == range + if len(term.args) < 2: + return term.args[0] + return term.args[1] + + +def _range_step(term: Term): + assert term.op == range + if len(term.args) < 3: + return 1 + return term.args[2] + + +def _is_simple_range(term: Term) -> bool: + if term.op != range: + return False + + start = _range_start(term) + step = _range_step(term) + return ( + not isinstance(start, Term) + and start == 0 + and not isinstance(step, Term) + and step == 1 + ) + + +class ReduceDeltaIndependent(ObjectInterpretation): + """Eliminate a Delta that has independent, dense index arguments. + + reduce(M, streams, delta((), body)) ≡ reduce(M, streams, body) + + reduce(M, streams ∪ {v: range(N)}, delta(idx' ++ (v(),), body)) + ═══════════════════════════════════════════════════════════════════════════ + reduce(M, streams, delta(idx', bind_dims(body[v() := unbind_dims(streams[v], fv)], fv))) + + Not yet supported: + + - **Strided index streams** (``range(0, N, k)`` for ``k != 1``): the + premise ``_is_simple_range`` requires ``start == 0`` and ``step == 1``. + A strided extension would substitute ``v() := unbind_dims(jnp.arange( + start, stop, step), fv)`` and otherwise follow the same shape — the + change is purely in the recognised range form, the bind/unbind cycle + below is unchanged. + - **Non-zero start** (``range(a, b, 1)`` with ``a != 0``): same template + as the strided case; only the recognised range form changes. + - **Non-bare index expressions** (``delta((2*v(),), w)``, + ``delta((f(v()),), w)``, etc.): currently requires the final index + entry to be a bare call ``v()`` of a stream var op. Generalizing to + arbitrary index expressions is a scatter, not a bind: materialize the + index expression and the weight separately over ``v``, then + ``jnp.zeros(N).at[indices].set(values)`` (for Sum; analogous for + other monoids using ``.add``/``.min``/``.max``/...). This is a + different leaf operation from ``bind_dims`` and warrants a sibling + rule rather than an extension of this one. + """ + + @implements(Monoid.reduce) + def _(self, monoid: Monoid, body, streams: Streams): + if not (isinstance(body, Term) and body.op == delta): + return fwd() + + indices, weight = body.args + assert isinstance(indices, tuple) + + if not indices: + return monoid.reduce(weight, streams) + + head_indices, tail_index = indices[:-1], indices[-1] + if not (isinstance(tail_index, Term) and tail_index.op in streams): + return fwd() + + tail_op: Operation = tail_index.op + tail_stream = streams[tail_op] + if not (isinstance(tail_stream, Term) and _is_simple_range(tail_stream)): + return fwd() + + fresh_op = Operation.define(tail_op) + indices = jnp.arange(_range_stop(tail_stream)) + if isinstance(indices, jax.Array) and len(indices) == 0: + return monoid.identity + + fresh_stream = unbind_dims(indices, fresh_op) + subst_intp = typing.cast(Interpretation, {tail_op: deffn(fresh_stream)}) + fresh_body = bind_dims(handler(subst_intp)(evaluate)(weight), fresh_op) + fresh_streams = {k: v for (k, v) in streams.items() if k != tail_op} + return monoid.reduce(delta(head_indices, fresh_body), fresh_streams) + + +class ReduceDependentRangeMask(ObjectInterpretation): + """Eliminate a dependent range by masking. + + reduce(M, streams ∪ {u: range(N), v: range(u())}, body) + ═══════════════════════════════════════════════════════════════════════════ + reduce(M, streams ∪ {u: range(N), v: range(N)}, where(v() < u(), body, M.identity)) + + Currently recognises only the lower-triangular form ``v: range(u())``: + constant start of 0, dependent stop equal to a bare call of another + stream var. + + Not yet supported: + + - **Upper-triangular** (``v: range(u(), N)`` — constant stop, dependent + start): bbox becomes ``range(0, N)`` (or ``range(0, bbox_N)``), guard + becomes ``v() >= u()``. Same shape of rewrite as lower-tri; differs + only in which side of the range carries the stream-var reference and + in the predicate direction. + - **Banded** (``v: range(u() - k, u() + k + 1)`` — two-sided dependent + bounds with constant width): bbox is ``range(0, N + k)`` (or similar + bounded by both endpoints' extents), guard is + ``(v() >= u() - k) & (v() < u() + k + 1)``. Needs both-sides + affine-bound recognition. + - **Strided dependent** (``v: range(0, u(), k)`` for ``k != 1``): bbox + stays ``range(0, N)`` and guard becomes + ``(v() < u()) & (v() % k == 0)`` (or equivalent), or alternatively + embed in a smaller bbox ``range(0, ceil(N/k))`` and remap the index. + - **Affine bounds** (``v: range(a*u() + b, c*u() + d)`` for affine + coefficients): bbox computed from ``ub(c*u() + d)`` over ``u``'s + range; guard is the conjunction of the two affine constraints. This + subsumes the upper/banded/strided cases under one affine recogniser. + - **Multi-stream-var dependent** (``v: range(u() + w())`` referencing + more than one outer stream var): bbox is the affine combination over + both referents' ranges; guard threads through all dependencies. + - **Reverse-order dependent ranges**: e.g. ``v: range(u(), 0, -1)``; + needs to handle negative step and the corresponding reverse + enumeration. + """ + + @implements(Monoid.reduce) + def _(self, monoid: Monoid, body, streams: Streams): + stream_vars = set(streams.keys()) + + # streams of the form k: range(X) + simple_ranges = { + k: v + for (k, v) in streams.items() + if isinstance(v, Term) and _is_simple_range(v) + } + for u, u_stream in simple_ranges.items(): + if fvsof(u_stream) & stream_vars: + continue + + for v, v_stream in simple_ranges.items(): + if ( + isinstance(v_stream, Term) + and isinstance(_range_stop(v_stream), Term) + and _range_stop(v_stream).op == u + ): + fresh_streams = { + a: (u_stream if a == v else b) for (a, b) in streams.items() + } + + # there are other commuting rules for delta that we do not + # currently include + if isinstance(body, Term) and body.op == delta: + fresh_body = delta( + body.args[0], + jnp.where(v() < u(), body.args[1], monoid.identity), # type: ignore[arg-type] + ) + else: + fresh_body = jnp.where(v() < u(), body, monoid.identity) + + return monoid.reduce(fresh_body, fresh_streams) + + return fwd() + + +class ReduceRange(ObjectInterpretation): + """Replace concrete-range stream values with materialized ``jnp.arange``. + + reduce(M, streams ∪ {v: range(a, b, s)}, body) + ≡ reduce(M, streams ∪ {v: jnp.arange(a, b, s)}, body) + + when ``a``, ``b``, ``s`` are concrete and ``body`` is not a delta term. + Delegates the actual reduction to whichever handler picks up the + materialized ``jax.Array`` streams. + """ + + @implements(Monoid.reduce) + def _(self, monoid: Monoid, body, streams: Streams): + if isinstance(body, Term) and body.op == delta: + return fwd() + + new_streams: dict = {} + any_replaced = False + for k, v in streams.items(): + if isinstance(v, Term) and v.op == range: + new_streams[k] = jnp.arange( + _range_start(v), _range_stop(v), _range_step(v) + ) + any_replaced = True + else: + new_streams[k] = v + + if not any_replaced: + return fwd() + return monoid.reduce(body, new_streams) + + +# Cross-cutting delta rules not yet implemented: +# +# - **Delta-commuting** (DC-hoist): for any pure op ``f`` (no Scoped binders +# that intersect a delta's index ops), push delta outward: +# f(args..., delta(idx, body), args...) +# ≡ delta(idx, f(args..., body, args...)) +# This normalizes delta to the outermost position so the reduce rules can +# pattern-match ``isinstance(body, Term) and body.op == delta`` cleanly. +# The soundness condition is mechanical via ``op.__fvs_rule__``: refuse to +# commute when a non-delta arg's scope binds any op in the delta's idx. +# +# - **Delta-merging** (DC-merge): under a pure binary op ``f`` (or +# generalized n-ary), merge multiple deltas when their index tuples are +# subsequence-compatible: +# f(delta(idx_a, v), delta(idx_b, w)) ≡ delta(idx_max, f(v, w)) +# where ``idx_max`` is the longer of ``idx_a``, ``idx_b`` and ``idx_a`` is +# a subsequence of ``idx_b`` (or vice versa). Refuse to fire when neither +# is a subsequence of the other, since that would silently insert an +# outer-product broadcast. +# +# - **Empty-domain detection at the term level**: currently size-0 named +# dims must be resolved by leaf consumers (``bind_dims``, reductors with +# ``initial=monoid.identity``). The empty-domain check is intentionally +# NOT a rule on its own — rewrites stay size-polymorphic and leaf ops +# carry the burden. See the conversation in monoid.py's history for why. + + NormalizeIntp.extend( ArrayReduce(), + ReduceRange(), + ReduceDeltaIndependent(), + ReduceDependentRangeMask(), SumPlusJax(), ProductPlusJax(), MinPlusJax(), diff --git a/effectful/ops/monoid.py b/effectful/ops/monoid.py index 70bb50022..c9231510c 100644 --- a/effectful/ops/monoid.py +++ b/effectful/ops/monoid.py @@ -16,7 +16,6 @@ Scoped, deffn, implements, - iter_, syntactic_eq, syntactic_hash, ) @@ -96,8 +95,11 @@ def reduce[A, B, U: Body]( if isinstance(stream_body, Term): continue stream_values_iter = iter(stream_body) - if isinstance(stream_values_iter, Term) and stream_values_iter.op is iter_: + + # if we iterate and get a term instead of a real iterator, skip + if isinstance(stream_values_iter, Term): continue + new_reduces = [] for stream_val in stream_values_iter: with handler({stream_key: deffn(stream_val)}): diff --git a/tests/test_handlers_jax_monoid.py b/tests/test_handlers_jax_monoid.py index 35d041fe2..fe888ad43 100644 --- a/tests/test_handlers_jax_monoid.py +++ b/tests/test_handlers_jax_monoid.py @@ -1,11 +1,20 @@ import jax import pytest +from jax import random as random import effectful.handlers.jax.numpy as jnp from effectful.handlers.jax import bind_dims, unbind_dims -from effectful.handlers.jax.monoid import ArrayReduce, LogSumExp +from effectful.handlers.jax.monoid import ( + ArrayReduce, + LogSumExp, + ReduceDeltaIndependent, + ReduceDependentRangeMask, + delta, +) +from effectful.handlers.jax.monoid import range as Range from effectful.handlers.jax.scipy.special import logsumexp -from effectful.ops.monoid import Max, Min, Product, Sum +from effectful.ops.monoid import Max, Min, NormalizeIntp, Product, Sum +from effectful.ops.semantics import handler from tests._monoid_helpers import JAX_BACKEND, Backend, check_rewrite, define_vars MONOIDS = [ @@ -94,3 +103,167 @@ def test_reduce_array_3(monoid, reductor, backend: Backend): backend=backend, free_vars=[x, y, k1, k2, X, f, g], ) + + +# --------------------------------------------------------------------------- +# Delta rules. All tests use the operation form ``delta(idx, body)`` rather +# than the ``Delta`` dataclass; the delta op is the user-facing surface. +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("monoid,reductor", MONOIDS) +def test_reduce_delta_empty(monoid, reductor, backend: Backend): + """An empty-index delta unwraps to its body. + + reduce(M, streams, delta((), body)) ≡ reduce(M, streams, body) + """ + x = define_vars("x", typ=backend.scalar_typ) + X = define_vars("X", typ=backend.stream_typ) + + lhs = monoid.reduce(delta((), x()), {x: X()}) + rhs = monoid.reduce(x(), {x: X()}) + + check_rewrite( + lhs=lhs, + rhs=rhs, + rule=ReduceDeltaIndependent(), + backend=backend, + free_vars=[x, X], + ) + + +@pytest.mark.parametrize("monoid,reductor", MONOIDS) +def test_reduce_delta_independent_one(monoid, reductor, backend: Backend): + """One R1 step: peel the final preserved index off a delta. + + reduce(M, {y: Y()}, delta((y(),), f(y()))) + ≡ reduce(M, {}, delta((), bind_dims(f(unbind_dims(Y(), k)), k))) + """ + (y, k) = define_vars("y", "k", typ=backend.scalar_typ) + Y = define_vars("Y", typ=backend.stream_typ) + f = backend.fresh_op("f", n_args=1, ret="scalar") + + # We use a concrete range here instead of an abstract one, because + # unbind_dims is undefined on empty arrays (and the rewrite produces a + # different rhs in this case) + lhs = monoid.reduce(delta((y(),), f(y())), {y: Range(3)}) + rhs = monoid.reduce(bind_dims(f(unbind_dims(jnp.arange(3), k)), k), {}) + + check_rewrite( + lhs=lhs, + rhs=rhs, + rule=ReduceDeltaIndependent(), + backend=backend, + free_vars=[y, k, Y, f], + ) + + +@pytest.mark.parametrize("monoid,reductor", MONOIDS) +def test_reduce_delta_independent_preserves_others(monoid, reductor, backend: Backend): + """R1 peels only the final index. Streams not matching the peeled index op + stay untouched, as do earlier entries in the index tuple. + + reduce(M, {x: X(), y: Y()}, delta((x(), y()), f(x(), y()))) + ≡ reduce(M, {x: X()}, delta((x(),), bind_dims(f(x(), unbind_dims(Y(), k)), k))) + """ + (x, y, k) = define_vars("x", "y", "k", typ=backend.scalar_typ) + f = backend.fresh_op("f", n_args=2, ret="scalar") + + lhs = monoid.reduce(delta((x(), y()), f(x(), y())), {x: Range(2), y: Range(3)}) + rhs = monoid.reduce( + bind_dims( + bind_dims( + f(unbind_dims(jnp.arange(2), x), unbind_dims(jnp.arange(3), k)), k + ), + x, + ), + {}, + ) + + check_rewrite( + lhs=lhs, + rhs=rhs, + rule=ReduceDeltaIndependent(), + backend=backend, + free_vars=[f], + ) + + +@pytest.mark.parametrize("monoid,reductor", MONOIDS) +def test_reduce_dependent_range_mask(monoid, reductor, backend: Backend): + """A dependent range stream gets rewritten to the referent's bbox stream, + with the original constraint folded into the body as a where-guard. + + reduce(M, {u: range(0, N, 1), v: range(0, u(), 1)}, body) + ≡ reduce(M, {u: range(0, N, 1), v: range(0, N, 1)}, where(v() < u(), body, M.identity)) + """ + (u, v) = define_vars("u", "v", typ=backend.scalar_typ) + N = 5 + f = backend.fresh_op("f", n_args=2, ret="scalar") + + body = f(u(), v()) + + lhs = monoid.reduce(body, {u: Range(0, N, 1), v: Range(0, u(), 1)}) + rhs = monoid.reduce( + jnp.where(v() < u(), body, monoid.identity), + {u: Range(0, N, 1), v: Range(0, N, 1)}, + ) + + check_rewrite( + lhs=lhs, + rhs=rhs, + rule=ReduceDependentRangeMask(), + backend=backend, + free_vars=[u, v, f], + ) + + +@pytest.mark.parametrize("monoid,reductor", MONOIDS) +def test_reduce_dependent_range_mask_delta_body(monoid, reductor, backend: Backend): + """When the body is a delta term, R4 folds the constraint into the delta's + weight while leaving its index tuple untouched. + + reduce(M, {u: range(N), v: range(u())}, delta((u(), v()), w)) + ≡ reduce(M, {u: range(N), v: range(N)}, + delta((u(), v()), where(v() < u(), w, M.identity))) + """ + (u, v) = define_vars("u", "v", typ=backend.scalar_typ) + N = 5 + f = backend.fresh_op("f", n_args=2, ret="scalar") + + weight = f(u(), v()) + idx = (u(), v()) + + lhs = monoid.reduce(delta(idx, weight), {u: Range(0, N, 1), v: Range(0, u(), 1)}) + rhs = monoid.reduce( + delta(idx, jnp.where(v() < u(), weight, monoid.identity)), + {u: Range(0, N, 1), v: Range(0, N, 1)}, + ) + + check_rewrite( + lhs=lhs, + rhs=rhs, + rule=ReduceDependentRangeMask(), + backend=backend, + free_vars=[u, v, f], + ) + + +def test_reduce_matmul(): + key = jax.random.PRNGKey(0) + # Define dimensions + B, I, J, K = 2, 3, 4, 5 + + # Create sample matrices + X = random.normal(key, (B, I, J)) + Y = random.normal(key, (B, J, K)) + (b, i, j, k) = define_vars("b", "i", "j", "k", typ=jax.Array) + + with handler(NormalizeIntp): + actual = Sum.reduce( + delta((b(), i(), k()), unbind_dims(X, b, i, j) * unbind_dims(Y, b, j, k)), + {b: Range(B), i: Range(I), j: Range(J), k: Range(K)}, + ) + + expected = jnp.einsum("bij,bjk->bik", X, Y) + assert jnp.allclose(actual, expected) From f71eb542374ffb47d1a0cb0ec6e75642a8d51c44 Mon Sep 17 00:00:00 2001 From: Jack Feser Date: Tue, 2 Jun 2026 13:37:52 -0400 Subject: [PATCH 6/6] Add weighted streams (#665) * more precise stream type * add tests for weighted rules * add reduction rule for weighted streams and tests * add test to demo expectation * add numpyro monoid module * add quadrature * add tests * wip * refactor tests * wip * test composition of lifting and weighting * drop numpyro changes * drop unused ops * lint * make weighted a Monoid method * fix typing of jax arrays * change weighted typing to take callable * fix test * fix test * resolve type aliases before dispatching * wip * wip * remove typeof_full * wip * wip * wip * format * refactor test harness * drop unused test --- effectful/handlers/jax/_terms.py | 7 + effectful/handlers/jax/monoid.py | 15 +- effectful/internals/unification.py | 11 + effectful/ops/monoid.py | 125 +++++- effectful/ops/semantics.py | 9 +- tests/_monoid_helpers.py | 466 ++++++++++++---------- tests/test_handlers_jax_monoid.py | 184 ++++----- tests/test_internals_unification.py | 7 + tests/test_ops_monoid.py | 587 ++++++++++++++++------------ 9 files changed, 845 insertions(+), 566 deletions(-) diff --git a/effectful/handlers/jax/_terms.py b/effectful/handlers/jax/_terms.py index c88fe9341..05a5390e7 100644 --- a/effectful/handlers/jax/_terms.py +++ b/effectful/handlers/jax/_terms.py @@ -14,10 +14,17 @@ unbind_dims, ) from effectful.internals.tensor_utils import _desugar_tensor_index +from effectful.internals.unification import Box, nested_type from effectful.ops.syntax import defdata from effectful.ops.types import Expr, NotHandled, Operation, Term +@nested_type.register(jax.Array) +@nested_type.register(jax._src.core.Tracer) +def _(value): + return Box(jax.Array) + + class _IndexUpdateHelper: """Helper class to implement array-style .at[index].set() updates for effectful arrays.""" diff --git a/effectful/handlers/jax/monoid.py b/effectful/handlers/jax/monoid.py index 42d7866ec..3f6273be3 100644 --- a/effectful/handlers/jax/monoid.py +++ b/effectful/handlers/jax/monoid.py @@ -16,6 +16,7 @@ Product, Streams, Sum, + distributes_over, outer_stream, ) from effectful.ops.semantics import evaluate, fvsof, fwd, handler, typeof @@ -38,6 +39,10 @@ def cartesian_prod(x, y): LogSumExp = Monoid(name="LogSumExp", identity=jnp.asarray(float("-inf"))) +# ``Sum`` in log space is multiplication, which distributes over ``LogSumExp``: +# a + logsumexp(b, c) = logsumexp(a + b, a + c) +distributes_over.register(Sum, LogSumExp) + def _jax_args(args): """True iff ``args`` is non-empty and every arg is a concrete @@ -108,7 +113,15 @@ def plus(self, *args): if not isinstance(a, jax.Array): return fwd() result = a if result is None else cartesian_prod(result, a) - return result if result is not None else CartesianProduct.identity + if result is None: + return CartesianProduct.identity + # CartesianProduct values are streams of rows. ``cartesian_prod`` + # already lifts 1D inputs to 2D, but a single-array call seeds + # ``result = a`` unchanged — promote so the rank invariant holds for + # every array-path return. + if result.ndim == 1: + result = result[:, None] + return result ARRAY_REDUCTORS = { diff --git a/effectful/internals/unification.py b/effectful/internals/unification.py index e425bba6c..2eadaeab6 100644 --- a/effectful/internals/unification.py +++ b/effectful/internals/unification.py @@ -556,6 +556,17 @@ def _unify_generic(typ, subtyp, subs: Substitutions) -> Substitutions: and issubclass(subtyp, typing.get_origin(typ)) ): return subs # implicit expansion to subtyp[Any] + elif isinstance(typ, GenericAlias): + # Special case for treating arrays as iterables of arrays + try: + import jax + + if typing.get_origin(typ) is collections.abc.Iterable and issubclass( + subtyp, jax.Array + ): + return unify(typing.get_args(typ)[0], jax.Array, subs) + except ImportError: + pass raise TypeError(f"Cannot unify generic type {typ} with {subtyp} given {subs}.") diff --git a/effectful/ops/monoid.py b/effectful/ops/monoid.py index c9231510c..76351fa62 100644 --- a/effectful/ops/monoid.py +++ b/effectful/ops/monoid.py @@ -10,7 +10,14 @@ from typing import Annotated, Any from effectful.internals.disjoint_set import DisjointSet -from effectful.ops.semantics import coproduct, evaluate, fvsof, fwd, handler +from effectful.ops.semantics import ( + coproduct, + evaluate, + fvsof, + fwd, + handler, + typeof, +) from effectful.ops.syntax import ( ObjectInterpretation, Scoped, @@ -19,11 +26,17 @@ syntactic_eq, syntactic_hash, ) -from effectful.ops.types import Expr, Interpretation, NotHandled, Operation, Term +from effectful.ops.types import ( + Expr, + Interpretation, + NotHandled, + Operation, + Term, +) + +type Stream[T] = Iterable[T] -# Note: The streams value type should be something like Iterable[T], but some of -# our target stream types (e.g. jax.Array) are not subtypes of Iterable -type Streams[T] = Mapping[Operation[[], T], Any] +type Streams = Mapping[Operation[[], Any], Stream[Any]] type Body[T] = ( Iterable[T] @@ -34,9 +47,7 @@ ) -def outer_stream( - streams: Streams, -) -> Iterable[tuple[Operation, Expr, dict[Operation, Expr]]]: +def outer_stream(streams: Streams) -> Iterable[tuple[Operation, Stream, Streams]]: """Returns the streams that can be ordered outermost in the loop nest as well as the remaining streams in the nest. @@ -51,13 +62,13 @@ def outer_stream( ) -class Monoid[T]: +class Monoid[W]: """A monoid with ``plus`` and ``reduce`` :class:`Operation` s.""" _name: str - identity: T + identity: W - def __init__(self, identity: T, name: str): + def __init__(self, identity: W, name: str): self._name = name self.identity = identity @@ -111,6 +122,18 @@ def reduce[A, B, U: Body]( return self.plus(*new_reduces) raise NotHandled + @Operation.define + def weighted[T]( + self, stream: Stream[T], weight: Callable[[T], W] | Operation[[T], W] + ) -> Stream[T]: + """A stream paired with a per-element weight. ``var`` is an + :class:`Operation` standing for "an element of ``stream``"; ``weight`` + is an expression that uses ``var`` and evaluates to the weight of that + element. + + """ + raise NotHandled + class MonoidWithZero[T](Monoid[T]): zero: T @@ -175,6 +198,12 @@ def _is_monoid_reduce(op: Operation) -> bool: return isinstance(owner, Monoid) and op is owner.reduce +def _is_monoid_weighted(op: Operation) -> bool: + """True if ``op`` is the ``weighted`` operation of some :class:`Monoid`.""" + owner = getattr(op, "__self__", None) + return isinstance(owner, Monoid) and op is owner.weighted + + class PlusEmpty(ObjectInterpretation): """plus() = 0""" @@ -557,6 +586,78 @@ def reduce(self, sum_monoid: Monoid, sum_body, sum_streams): return fwd() +class ReduceWeightedStream(ObjectInterpretation): + """reduce(M, body, {x: WM.weighted(s, v, w), ...}) = reduce(M, WM.plus(w[v:=x()], body), {x: s, ...}) + + requires distributes_over(WM, M). + + The substitution ``v -> x`` is done by beta-reducing ``deffn(w, v)`` on + ``x()`` — symbolic, no Python dispatch on the weight expression. + """ + + @implements(Monoid.reduce) + def reduce(self, monoid, body, streams): + for k, v in streams.items(): + if isinstance(v, Term) and _is_monoid_weighted(v.op): + v_stream, v_weight = v.args + v_monoid = v.op.__self__ + if not distributes_over(v_monoid, monoid): + continue + w_at_k = v_weight(k()) + weighted_body = v_monoid.plus(w_at_k, body) + new_streams = {**streams, k: v_stream} + return monoid.reduce(weighted_body, new_streams) + return fwd() + + +class ReduceCartesianWeightedStream(ObjectInterpretation): + """``CartesianProduct.reduce`` over a :func:`weighted` body whose + ``weight`` is independent of the plate (product-index) streams:: + + CartesianProduct.reduce(M.weighted(s, w), plates) + = M.weighted( + CartesianProduct.reduce(s, plates), + deffn(M.reduce(w, {e: row()}), row), + ) + + Reuses ``body``'s element binder ``e`` (already typed by construction); + introduces a fresh ``row`` binder typed as ``Iterable[elem_type]``. + + Only fires when ``w`` is independent of the plate vars. + """ + + @Operation.define + @staticmethod + def _iterable_elem[T](iter: Iterable[T]) -> T: + raise NotHandled + + @implements(Monoid.reduce) + def reduce(self, monoid, body, streams): + if monoid is not CartesianProduct: + return fwd() + if not (isinstance(body, Term) and _is_monoid_weighted(body.op)): + return fwd() + + s, w = body.args + if not isinstance(s, Term) and len(s) == 0: + return CartesianProduct.reduce([], streams) + + if set(streams.keys()) & fvsof(w): + return fwd() + + elem_typ = typeof(self._iterable_elem(s)) + elem_op = Operation.define(elem_typ, name="elem") + row_op = Operation.define(Iterable[elem_typ], name="row") + + weight_monoid = body.op.__self__ + joint_weight = deffn( + weight_monoid.reduce(w(elem_op()), {elem_op: row_op()}), row_op + ) + joint_stream = CartesianProduct.reduce(s, streams) + + return weight_monoid.weighted(joint_stream, joint_weight) + + class MonoidOverCallable(ObjectInterpretation): """``monoid.reduce(f, streams) = lambda *a: monoid.reduce(f(*a), streams)``.""" @@ -751,6 +852,8 @@ def extend(self, *intps: Interpretation) -> typing.Self: ReduceSplit(), ReduceFactorization(), ReduceDistributeCartesianProduct(), + ReduceWeightedStream(), + ReduceCartesianWeightedStream(), PlusEmpty(), PlusSingle(), PlusIdentity(), diff --git a/effectful/ops/semantics.py b/effectful/ops/semantics.py index 8fd62bcd5..acfdf9fdb 100644 --- a/effectful/ops/semantics.py +++ b/effectful/ops/semantics.py @@ -287,6 +287,13 @@ def _evaluate_list_view(expr, **kwargs): def _simple_type(tp: type) -> type: """Convert a type object into a type that can be dispatched on.""" + + def _resolve_aliases(tp: type) -> type: + tp = typing.get_origin(tp) or tp + if isinstance(tp, typing.TypeAliasType): + return _resolve_aliases(tp.__value__) + return tp + if isinstance(tp, typing.TypeVar): tp = ( tp.__bound__ @@ -304,7 +311,7 @@ def _simple_type(tp: type) -> type: tp = functools.reduce(operator.or_, (type(arg) for arg in args)) if isinstance(tp, types.UnionType): raise TypeError(f"Union types are not supported: {tp}") - return typing.get_origin(tp) or tp + return _resolve_aliases(tp) def typeof[T](term: Expr[T]) -> type[T]: diff --git a/tests/_monoid_helpers.py b/tests/_monoid_helpers.py index f15103e30..f8089bec7 100644 --- a/tests/_monoid_helpers.py +++ b/tests/_monoid_helpers.py @@ -1,146 +1,22 @@ +import builtins import itertools -from collections.abc import Callable, Mapping, Sequence -from dataclasses import dataclass -from typing import Any, get_args, get_origin +import typing +from abc import ABC, abstractmethod +from collections.abc import Callable, Mapping +from typing import Any, Literal, overload import jax from hypothesis import given, settings from hypothesis import strategies as st +from hypothesis.strategies import SearchStrategy import effectful.handlers.jax.numpy as _jnp from effectful.internals.runtime import interpreter -from effectful.ops.monoid import NormalizeIntp -from effectful.ops.semantics import apply, evaluate, handler +from effectful.ops.monoid import NormalizeIntp, Stream, _is_monoid_weighted +from effectful.ops.semantics import apply, evaluate, fvsof, handler from effectful.ops.syntax import _BaseTerm, defdata, deffn, syntactic_eq from effectful.ops.types import NotHandled, Operation, Term -_JAX_ARRAY_SHAPE = (2,) - - -def _jax_array_value_strategy() -> st.SearchStrategy[jax.Array]: - return st.lists( - st.integers(min_value=-5, max_value=5), - min_size=_JAX_ARRAY_SHAPE[0], - max_size=_JAX_ARRAY_SHAPE[0], - ).map(lambda xs: jax.numpy.asarray(xs, dtype=jax.numpy.float32)) - - -# Unary jax fns map a scalar to a 1-D array (analogous to ``_UNARY_LIST_FNS`` -# for ints). Uses the effectful-wrapped jnp so named-dim broadcasting works. -_UNARY_JAX_FNS: list[Callable[[jax.Array], jax.Array]] = [ - lambda a: _jnp.stack([a, a + 1]), - lambda a: _jnp.stack([a, -a]), - lambda a: _jnp.stack([a, a + 1, 2 * a]), -] - -_BINARY_JAX_FNS: list[Callable[[jax.Array, jax.Array], jax.Array]] = [ - lambda a, b: a + b, - lambda a, b: a - b, - lambda a, b: a * b, -] - - -def _value_strategy_for(annotation: Any) -> st.SearchStrategy[Any]: - """Strategy for the value an *0-arg* Operation should return.""" - if annotation is int: - return st.integers(min_value=-100, max_value=100) - if annotation is float: - return st.floats(allow_nan=False) - if get_origin(annotation) is list and get_args(annotation) == (int,): - return st.lists(st.integers(min_value=-100, max_value=100), max_size=2) - if annotation is jax.Array: - return _jax_array_value_strategy() - if get_origin(annotation) is list and get_args(annotation) == (jax.Array,): - return st.lists(_jax_array_value_strategy(), max_size=2) - raise NotImplementedError( - f"No value strategy for return annotation {annotation!r}; " - "supported: int, list[int], jax.Array, list[jax.Array]" - ) - - -_UNARY_NUM_FNS: list[Callable[[int], int]] = [ - lambda x: x, - lambda x: x + 1, - lambda x: x - 1, - lambda x: -x, - lambda x: 2 * x, - lambda x: 3 * x + 1, -] - -_BINARY_NUM_FNS: list[Callable[[int, int], int]] = [ - lambda x, y: x + y, - lambda x, y: x - y, - lambda x, y: x * y, - lambda x, y: x + 2 * y, - lambda x, y: 2 * x - y, -] - -_UNARY_LIST_FNS: list[Callable[[int], list[int]]] = [ - lambda _x: [], - lambda x: [x], - lambda x: [x, x + 1], - lambda x: [x, -x], - lambda x: [0, x, x + 1], -] - -_UNARY_JAX_LIST_FNS: list[Callable[[jax.Array], list[jax.Array]]] = [ - lambda _x: [], - lambda x: [x], - lambda x: [x, x + 1], - lambda x: [x, -x], -] - - -def _strategy_for_op(op: Operation) -> st.SearchStrategy[Callable[..., Any]]: - """Pick a strategy producing a callable suitable for binding `op` in an - interpretation. Inspects the operation's signature. - """ - sig = op.__signature__ - params = list(sig.parameters.values()) - ret = sig.return_annotation - param_types = tuple(p.annotation for p in params) - - if not params: - return _value_strategy_for(ret).map(deffn) - if ret in (int, float) and param_types == (int,): - return st.sampled_from(_UNARY_NUM_FNS) - if ret in (int, float) and param_types == (int, int): - return st.sampled_from(_BINARY_NUM_FNS) - if get_origin(ret) is list and get_args(ret) == (int,) and param_types == (int,): - return st.sampled_from(_UNARY_LIST_FNS) - if ret is jax.Array and param_types == (jax.Array,): - return st.sampled_from(_UNARY_JAX_FNS) - if ret is jax.Array and param_types == (jax.Array, jax.Array): - return st.sampled_from(_BINARY_JAX_FNS) - if ( - get_origin(ret) is list - and get_args(ret) == (jax.Array,) - and param_types == (jax.Array,) - ): - return st.sampled_from(_UNARY_JAX_LIST_FNS) - raise NotImplementedError( - f"No callable strategy for free var with return {ret!r}, params {param_types!r}" - ) - - -@st.composite -def random_interpretation( - draw: st.DrawFn, free_vars: Sequence[Operation] -) -> Mapping[Operation, Callable[..., Any]]: - """Draw an Interpretation binding every Operation in `case.free_vars` to - a randomly chosen value/callable. Keys are Operation identities. - """ - intp: dict[Operation, Callable[..., Any]] = {} - for op in free_vars: - intp[op] = draw(_strategy_for_op(op)) - return intp - - -def define_vars(*names, typ=int): - if len(names) == 1: - return Operation.define(typ, name=names[0]) - return tuple(Operation.define(typ, name=n) for n in names) - def syntactic_eq_alpha(x, y) -> bool: """Alpha-equivalence-respecting variant of ``syntactic_eq``. @@ -251,8 +127,7 @@ def _apply_canonical(op, *args, **kwargs) -> Term: return evaluate(expr) -@dataclass(frozen=True) -class Backend: +class Backend(ABC): """A value-domain spec used to share monoid tests across int and jax.Array backends. Provides the concrete value type, the hypothesis strategy for drawing scalars in property tests, and an equality predicate that works @@ -262,10 +137,29 @@ class Backend: name: str scalar_typ: Any stream_typ: Any - scalar_strategy: st.SearchStrategy[Any] - eq: Callable[[Any, Any], bool] - - def fresh_op(self, name: str, n_args: int = 1, ret: str = "scalar") -> Operation: + strategy_for_op: dict[Operation, st.SearchStrategy[Callable[..., Any]]] + + def __init__(self): + self.strategy_for_op = {} + + @abstractmethod + def eq(self, a: Any, b: Any) -> bool: + raise NotImplementedError + + @abstractmethod + def strategy( + self, + arg_types: tuple[type, ...] = (), + ret: Literal["scalar", "stream"] = "scalar", + ) -> SearchStrategy: + raise NotImplementedError + + def _fresh_op( + self, + name: str, + arg_types: tuple[type, ...] = (), + ret: Literal["scalar", "stream"] = "scalar", + ) -> Operation: """Build a fresh, unhandled Operation whose parameter and return annotations are derived from this backend. @@ -275,81 +169,245 @@ def fresh_op(self, name: str, n_args: int = 1, ret: str = "scalar") -> Operation """ scalar = self.scalar_typ out = self.stream_typ if ret == "stream" else scalar - params = ", ".join(f"_a{i}" for i in range(n_args)) + params = ", ".join(f"_a{i}" for i in range(len(arg_types))) ns: dict[str, Any] = {"NotHandled": NotHandled} exec(f"def _fn({params}):\n raise NotHandled\n", ns) fn = ns["_fn"] fn.__annotations__ = { - **{f"_a{i}": scalar for i in range(n_args)}, + **{f"_a{i}": t for i, t in enumerate(arg_types)}, "return": out, } - return Operation.define(fn, name=name) + op = Operation.define(fn, name=name) + self.strategy_for_op[op] = self.strategy(arg_types, ret) + return op + + @overload + def define_vars(self, name: str, /, **kwargs) -> Operation: ... + + @overload + def define_vars( + self, n1: str, n2: str, /, *names: str, **kwargs + ) -> tuple[Operation, ...]: ... + + def define_vars(self, *names: str, **kwargs) -> Operation | tuple[Operation, ...]: # type: ignore[misc] + if len(names) == 1: + return self._fresh_op(names[0], **kwargs) + return tuple(self._fresh_op(n, **kwargs) for n in names) + + def check_rewrite( + self, + lhs, + rhs, + rule, + *, + max_examples: int = 25, + deadline=None, + normalize=NormalizeIntp, + ) -> None: + with handler(rule): + norm = evaluate(lhs) + assert syntactic_eq_alpha(norm, rhs) + + fvs = fvsof(lhs) | fvsof(rhs) + + @st.composite + def random_interpretation( + draw: st.DrawFn, + ) -> Mapping[Operation, Callable[..., Any]]: + """Draw an Interpretation binding every Operation in `free_vars` to + a randomly chosen value/callable. Keys are Operation identities. + """ + intp: dict[Operation, Callable[..., Any]] = {} + for op, strategy in self.strategy_for_op.items(): + if op in fvs: + intp[op] = draw(strategy) + return intp + + @given(intp=random_interpretation()) + @settings( + max_examples=max_examples, deadline=deadline, report_multiple_bugs=False + ) + def _check_semantics(intp): + with handler(normalize), handler(intp): + lhs_val = evaluate(lhs) + rhs_val = evaluate(rhs) + assert self.eq(lhs_val, rhs_val) + _check_semantics() -def _int_eq(a: Any, b: Any) -> bool: - return not isinstance(a, Term) and not isinstance(b, Term) and a == b +def _is_weighted(x: Any) -> bool: + return isinstance(x, Term) and _is_monoid_weighted(x.op) -def _jax_eq(a: Any, b: Any) -> bool: - def _leaf_eq(x: Any, y: Any) -> bool: - return bool(jax.numpy.all(jax.numpy.isclose(x, y, equal_nan=True))) - try: - leaves = jax.tree.leaves(jax.tree.map(_leaf_eq, a, b)) - except (ValueError, TypeError): +def _weight_pairs(x: Any, monoid: Any) -> list[tuple[Any, Any]] | None: + """Return ``(element, weight)`` pairs for a stream. + + A weighted-monoid Term yields each element paired with its weight. A plain + (unweighted) stream yields each element paired with ``monoid.identity`` -- + the no-op weight -- so an unweighted stream compares equal to a weighted one + exactly when every weight reduces to the identity (e.g. ``[()]`` vs a + weighted ``[()]`` whose single empty row reduces to the identity, and, more + generally, whenever both streams are empty). Returns ``None`` for a + non-stream Term, which never compares equal to a weighted stream. + """ + if isinstance(x, Term): + if not _is_monoid_weighted(x.op): + return None + stream, weight = x.args + assert not isinstance(stream, Term) + return [(e, typing.cast(Callable, weight)(e)) for e in stream] + return [(e, monoid.identity) for e in x] + + +def _weighted_stream_eq(a, b, leaf_eq: Callable[[Any, Any], bool]) -> bool: + monoids = {x.op.__self__ for x in (a, b) if _is_weighted(x)} + # distinct weight monoids can never be equal + if len(monoids) != 1: return False - return all(leaves) - - -def check_rewrite( - lhs, - rhs, - rule, - *, - backend: Backend, - free_vars=[], - max_examples: int = 25, - deadline=None, -) -> None: - with handler(rule): - norm = evaluate(lhs) - assert syntactic_eq_alpha(norm, rhs) - - @given(intp=random_interpretation(free_vars)) - @settings(max_examples=max_examples, deadline=deadline) - def _check_semantics(intp): - with handler(NormalizeIntp), handler(intp): - lhs_val = evaluate(lhs) - rhs_val = evaluate(rhs) - assert backend.eq(lhs_val, rhs_val) - - _check_semantics() - - -INT_BACKEND = Backend( - name="int", - scalar_typ=int, - stream_typ=list[int], - scalar_strategy=st.integers(min_value=-100, max_value=100), - eq=_int_eq, -) - - -JAX_BACKEND = Backend( - name="jax", - scalar_typ=jax.Array, - stream_typ=jax.Array, - scalar_strategy=_jax_array_value_strategy(), - eq=_jax_eq, -) - - -__all__ = [ - "Backend", - "INT_BACKEND", - "JAX_BACKEND", - "random_interpretation", - "define_vars", - "syntactic_eq_alpha", - "check_rewrite", -] + monoid = next(iter(monoids)) + + a_pairs = _weight_pairs(a, monoid) + b_pairs = _weight_pairs(b, monoid) + if a_pairs is None or b_pairs is None or len(a_pairs) != len(b_pairs): + return False + for (ea, wa), (eb, wb) in zip(a_pairs, b_pairs): + if not leaf_eq(ea, eb) or not leaf_eq(wa, wb): + return False + return True + + +class IntBackend(Backend): + name = "int" + scalar_typ = int + stream_typ = Stream[int] + + _unary_num_fns: list[Callable[[int], int]] = [ + lambda x: x, + lambda x: x + 1, + lambda x: x - 1, + lambda x: -x, + lambda x: 2 * x, + lambda x: 3 * x + 1, + ] + + _binary_num_fns: list[Callable[[int, int], int]] = [ + lambda x, y: x + y, + lambda x, y: x - y, + lambda x, y: x * y, + lambda x, y: x + 2 * y, + lambda x, y: 2 * x - y, + ] + + _unary_list_fns: list[Callable[[int], list[int]]] = [ + lambda _x: [], + lambda x: [x], + lambda x: [x, x + 1], + lambda x: [x, -x], + lambda x: [0, x, x + 1], + ] + + def strategy( + self, + arg_types: tuple[type, ...] = (), + ret: Literal["scalar", "stream"] = "scalar", + ) -> SearchStrategy: + match arg_types, ret: + case (), "scalar": + return st.integers(min_value=-100, max_value=100).map(deffn) + case (), "stream": + scalars = st.integers(min_value=-100, max_value=100) + return st.lists(scalars, max_size=2).map(deffn) + case (builtins.int,), "scalar": + return st.sampled_from(self._unary_num_fns) + case (builtins.int, builtins.int), "scalar": + return st.sampled_from(self._binary_num_fns) + case (builtins.int,), "stream": + return st.sampled_from(self._unary_list_fns) + raise NotImplementedError( + f"No int strategy for op with return {ret!r} and {arg_types} args" + ) + + def eq(self, a: Any, b: Any) -> bool: + if _is_weighted(a) or _is_weighted(b): + return _weighted_stream_eq(a, b, self.eq) + return not isinstance(a, Term) and not isinstance(b, Term) and a == b + + +class JaxBackend(Backend): + name = "jax" + scalar_typ = jax.Array + stream_typ = jax.Array + + _unary_jax_scalar_fns: list[Callable[[jax.Array], jax.Array]] = [ + lambda a: a, + lambda a: a + 1, + lambda a: a - 1, + lambda a: -a, + lambda a: 2 * a, + ] + + _unary_jax_stream_fns: list[Callable[[jax.Array], Stream[jax.Array]]] = [ + lambda a: _jnp.stack([a, a + 1]), + lambda a: _jnp.stack([a, -a]), + lambda a: _jnp.stack([a, a + 1, 2 * a]), + ] + + _binary_jax_scalar_fns: list[Callable[[jax.Array, jax.Array], jax.Array]] = [ + lambda a, b: a + b, + lambda a, b: a - b, + lambda a, b: a * b, + ] + + def strategy( + self, + arg_types: tuple[type, ...] = (), + ret: Literal["scalar", "stream"] = "scalar", + ) -> st.SearchStrategy[Callable]: + match arg_types, ret: + case (), "scalar": + return ( + st.lists( + st.integers(min_value=-5, max_value=5), + min_size=2, + max_size=2, + ) + .map(lambda xs: jax.numpy.asarray(xs, dtype=jax.numpy.float32)) + .map(deffn) + ) + case (), "stream": + return ( + st.lists( + st.integers(min_value=-5, max_value=5), + min_size=1, + max_size=2, + ) + .map(lambda xs: jax.numpy.asarray(xs, dtype=jax.numpy.float32)) + .map(deffn) + ) + case (jax.Array,), "scalar": + return st.sampled_from(self._unary_jax_scalar_fns) + case (jax.Array, jax.Array), "scalar": + return st.sampled_from(self._binary_jax_scalar_fns) + case (jax.Array,), "stream": + return st.sampled_from(self._unary_jax_stream_fns) + + raise NotImplementedError( + f"No jax strategy for op with return {ret!r} and {arg_types} args" + ) + + def eq(self, a: Any, b: Any) -> bool: + if _is_weighted(a) or _is_weighted(b): + return _weighted_stream_eq(a, b, self.eq) + + def _leaf_eq(x: Any, y: Any) -> bool: + return bool(jax.numpy.all(jax.numpy.isclose(x, y, equal_nan=True))) + + try: + leaves = jax.tree.leaves(jax.tree.map(_leaf_eq, a, b)) + except (ValueError, TypeError): + return False + return all(leaves) + + +__all__ = ["Backend", "IntBackend", "JaxBackend", "syntactic_eq_alpha"] diff --git a/tests/test_handlers_jax_monoid.py b/tests/test_handlers_jax_monoid.py index fe888ad43..18df84018 100644 --- a/tests/test_handlers_jax_monoid.py +++ b/tests/test_handlers_jax_monoid.py @@ -1,3 +1,6 @@ +import functools +import typing + import jax import pytest from jax import random as random @@ -7,15 +10,24 @@ from effectful.handlers.jax.monoid import ( ArrayReduce, LogSumExp, + ProductPlusJax, ReduceDeltaIndependent, ReduceDependentRangeMask, delta, ) from effectful.handlers.jax.monoid import range as Range from effectful.handlers.jax.scipy.special import logsumexp -from effectful.ops.monoid import Max, Min, NormalizeIntp, Product, Sum -from effectful.ops.semantics import handler -from tests._monoid_helpers import JAX_BACKEND, Backend, check_rewrite, define_vars +from effectful.ops.monoid import ( + Max, + Min, + NormalizeIntp, + Product, + ReduceWeightedStream, + Sum, +) +from effectful.ops.semantics import coproduct, handler +from effectful.ops.types import Interpretation +from tests._monoid_helpers import JaxBackend MONOIDS = [ pytest.param(Sum, jnp.sum, id="Sum"), @@ -27,28 +39,27 @@ @pytest.fixture -def backend() -> Backend: - return JAX_BACKEND +def backend() -> JaxBackend: + return JaxBackend() @pytest.mark.parametrize("monoid,reductor", MONOIDS) -def test_reduce_array_1(monoid, reductor, backend: Backend): - (x, k) = define_vars("x", "k", typ=jax.Array) - X = define_vars("X", typ=backend.stream_typ) +def test_reduce_array_1(monoid, reductor, backend: JaxBackend): + (x, k) = backend.define_vars("x", "k", ret="scalar") + X = backend.define_vars("X", ret="stream") lhs = monoid.reduce(x(), {x: X()}) rhs = reductor(bind_dims(unbind_dims(X(), k), k), axis=0) - - check_rewrite( - lhs=lhs, rhs=rhs, rule=ArrayReduce(), backend=backend, free_vars=[x, X, k] - ) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ArrayReduce()) @pytest.mark.parametrize("monoid,reductor", MONOIDS) -def test_reduce_array_2(monoid, reductor, backend: Backend): - (x, y, k1, k2) = define_vars("x", "y", "k1", "k2", typ=backend.scalar_typ) - (X, Y) = define_vars("X", "Y", typ=backend.stream_typ) - f = backend.fresh_op("f", n_args=2, ret="scalar") +def test_reduce_array_2(monoid, reductor, backend: JaxBackend): + (x, y, k1, k2) = backend.define_vars("x", "y", "k1", "k2", ret="scalar") + (X, Y) = backend.define_vars("X", "Y", ret="stream") + f = backend.define_vars( + "f", arg_types=(backend.scalar_typ, backend.scalar_typ), ret="scalar" + ) lhs = monoid.reduce(f(x(), y()), {x: X(), y: Y()}) rhs = reductor( @@ -61,25 +72,20 @@ def test_reduce_array_2(monoid, reductor, backend: Backend): ), axis=0, ) - - check_rewrite( - lhs=lhs, - rhs=rhs, - rule=ArrayReduce(), - backend=backend, - free_vars=[x, y, k1, k2, X, Y, f], - ) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ArrayReduce()) @pytest.mark.parametrize("monoid,reductor", MONOIDS) -def test_reduce_array_3(monoid, reductor, backend: Backend): +def test_reduce_array_3(monoid, reductor, backend: JaxBackend): """Stream `y` is `g(x())` — depends on the bound element of X. The reducer must inline ``g`` along the same named dim used to unbind `x`.""" - (x, y, k1, k2) = define_vars("x", "y", "k1", "k2", typ=backend.scalar_typ) - X = define_vars("X", typ=backend.stream_typ) + (x, y, k1, k2) = backend.define_vars("x", "y", "k1", "k2", ret="scalar") + X = backend.define_vars("X", ret="stream") - f = backend.fresh_op("f", n_args=2, ret="scalar") - g = backend.fresh_op("g", n_args=1, ret="stream") + f = backend.define_vars( + "f", arg_types=[backend.scalar_typ, backend.scalar_typ], ret="scalar" + ) + g = backend.define_vars("g", arg_types=[backend.scalar_typ], ret="stream") lhs = monoid.reduce(f(x(), y()), {x: X(), y: g(x())}) rhs = reductor( @@ -95,13 +101,37 @@ def test_reduce_array_3(monoid, reductor, backend: Backend): ), axis=0, ) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ArrayReduce()) - check_rewrite( + +def test_jax_weighted_reduce(backend: JaxBackend): + """Sum over a single stream with ``Product`` weights lowers to + ``jnp.sum(w(X) * body(X))`` under ``NormalizeIntp`` ∘ ``ArrayReduce``. + + Verifies that the desugaring rule composes cleanly with the JAX lowering + so existing handlers need no changes to support weighted streams. + + """ + (x, k) = backend.define_vars("x", "k", ret="scalar") + X = backend.define_vars("X", ret="stream") + body = backend.define_vars("body", arg_types=[backend.scalar_typ], ret="scalar") + w = backend.define_vars("w", arg_types=[backend.scalar_typ], ret="scalar") + + ws = Product.weighted(X(), w) + lhs = Sum.reduce(body(x()), {x: ws}) + rhs = jnp.sum( + bind_dims(w(unbind_dims(X(), k)) * body(unbind_dims(X(), k)), k), axis=0 + ) + backend.check_rewrite( lhs=lhs, rhs=rhs, - rule=ArrayReduce(), - backend=backend, - free_vars=[x, y, k1, k2, X, f, g], + rule=functools.reduce( + coproduct, + typing.cast( + list[Interpretation], + [ReduceWeightedStream(), ArrayReduce(), ProductPlusJax()], + ), + ), ) @@ -112,62 +142,51 @@ def test_reduce_array_3(monoid, reductor, backend: Backend): @pytest.mark.parametrize("monoid,reductor", MONOIDS) -def test_reduce_delta_empty(monoid, reductor, backend: Backend): +def test_reduce_delta_empty(monoid, reductor, backend: JaxBackend): """An empty-index delta unwraps to its body. reduce(M, streams, delta((), body)) ≡ reduce(M, streams, body) """ - x = define_vars("x", typ=backend.scalar_typ) - X = define_vars("X", typ=backend.stream_typ) + x = backend.define_vars("x", ret="scalar") + X = backend.define_vars("X", ret="stream") lhs = monoid.reduce(delta((), x()), {x: X()}) rhs = monoid.reduce(x(), {x: X()}) - - check_rewrite( - lhs=lhs, - rhs=rhs, - rule=ReduceDeltaIndependent(), - backend=backend, - free_vars=[x, X], - ) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ReduceDeltaIndependent()) @pytest.mark.parametrize("monoid,reductor", MONOIDS) -def test_reduce_delta_independent_one(monoid, reductor, backend: Backend): +def test_reduce_delta_independent_one(monoid, reductor, backend: JaxBackend): """One R1 step: peel the final preserved index off a delta. reduce(M, {y: Y()}, delta((y(),), f(y()))) ≡ reduce(M, {}, delta((), bind_dims(f(unbind_dims(Y(), k)), k))) """ - (y, k) = define_vars("y", "k", typ=backend.scalar_typ) - Y = define_vars("Y", typ=backend.stream_typ) - f = backend.fresh_op("f", n_args=1, ret="scalar") + (y, k) = backend.define_vars("y", "k", ret="scalar") + f = backend.define_vars("f", arg_types=[backend.scalar_typ], ret="scalar") # We use a concrete range here instead of an abstract one, because # unbind_dims is undefined on empty arrays (and the rewrite produces a # different rhs in this case) lhs = monoid.reduce(delta((y(),), f(y())), {y: Range(3)}) rhs = monoid.reduce(bind_dims(f(unbind_dims(jnp.arange(3), k)), k), {}) - - check_rewrite( - lhs=lhs, - rhs=rhs, - rule=ReduceDeltaIndependent(), - backend=backend, - free_vars=[y, k, Y, f], - ) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ReduceDeltaIndependent()) @pytest.mark.parametrize("monoid,reductor", MONOIDS) -def test_reduce_delta_independent_preserves_others(monoid, reductor, backend: Backend): +def test_reduce_delta_independent_preserves_others( + monoid, reductor, backend: JaxBackend +): """R1 peels only the final index. Streams not matching the peeled index op stay untouched, as do earlier entries in the index tuple. reduce(M, {x: X(), y: Y()}, delta((x(), y()), f(x(), y()))) ≡ reduce(M, {x: X()}, delta((x(),), bind_dims(f(x(), unbind_dims(Y(), k)), k))) """ - (x, y, k) = define_vars("x", "y", "k", typ=backend.scalar_typ) - f = backend.fresh_op("f", n_args=2, ret="scalar") + (x, y, k) = backend.define_vars("x", "y", "k", ret="scalar") + f = backend.define_vars( + "f", arg_types=[backend.scalar_typ, backend.scalar_typ], ret="scalar" + ) lhs = monoid.reduce(delta((x(), y()), f(x(), y())), {x: Range(2), y: Range(3)}) rhs = monoid.reduce( @@ -179,27 +198,22 @@ def test_reduce_delta_independent_preserves_others(monoid, reductor, backend: Ba ), {}, ) - - check_rewrite( - lhs=lhs, - rhs=rhs, - rule=ReduceDeltaIndependent(), - backend=backend, - free_vars=[f], - ) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ReduceDeltaIndependent()) @pytest.mark.parametrize("monoid,reductor", MONOIDS) -def test_reduce_dependent_range_mask(monoid, reductor, backend: Backend): +def test_reduce_dependent_range_mask(monoid, reductor, backend: JaxBackend): """A dependent range stream gets rewritten to the referent's bbox stream, with the original constraint folded into the body as a where-guard. reduce(M, {u: range(0, N, 1), v: range(0, u(), 1)}, body) ≡ reduce(M, {u: range(0, N, 1), v: range(0, N, 1)}, where(v() < u(), body, M.identity)) """ - (u, v) = define_vars("u", "v", typ=backend.scalar_typ) + (u, v) = backend.define_vars("u", "v", ret="scalar") N = 5 - f = backend.fresh_op("f", n_args=2, ret="scalar") + f = backend.define_vars( + "f", arg_types=[backend.scalar_typ, backend.scalar_typ], ret="scalar" + ) body = f(u(), v()) @@ -208,18 +222,11 @@ def test_reduce_dependent_range_mask(monoid, reductor, backend: Backend): jnp.where(v() < u(), body, monoid.identity), {u: Range(0, N, 1), v: Range(0, N, 1)}, ) - - check_rewrite( - lhs=lhs, - rhs=rhs, - rule=ReduceDependentRangeMask(), - backend=backend, - free_vars=[u, v, f], - ) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ReduceDependentRangeMask()) @pytest.mark.parametrize("monoid,reductor", MONOIDS) -def test_reduce_dependent_range_mask_delta_body(monoid, reductor, backend: Backend): +def test_reduce_dependent_range_mask_delta_body(monoid, reductor, backend: JaxBackend): """When the body is a delta term, R4 folds the constraint into the delta's weight while leaving its index tuple untouched. @@ -227,9 +234,11 @@ def test_reduce_dependent_range_mask_delta_body(monoid, reductor, backend: Backe ≡ reduce(M, {u: range(N), v: range(N)}, delta((u(), v()), where(v() < u(), w, M.identity))) """ - (u, v) = define_vars("u", "v", typ=backend.scalar_typ) + (u, v) = backend.define_vars("u", "v", ret="scalar") N = 5 - f = backend.fresh_op("f", n_args=2, ret="scalar") + f = backend.define_vars( + "f", arg_types=[backend.scalar_typ, backend.scalar_typ], ret="scalar" + ) weight = f(u(), v()) idx = (u(), v()) @@ -239,17 +248,10 @@ def test_reduce_dependent_range_mask_delta_body(monoid, reductor, backend: Backe delta(idx, jnp.where(v() < u(), weight, monoid.identity)), {u: Range(0, N, 1), v: Range(0, N, 1)}, ) - - check_rewrite( - lhs=lhs, - rhs=rhs, - rule=ReduceDependentRangeMask(), - backend=backend, - free_vars=[u, v, f], - ) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ReduceDependentRangeMask()) -def test_reduce_matmul(): +def test_reduce_matmul(backend: JaxBackend): key = jax.random.PRNGKey(0) # Define dimensions B, I, J, K = 2, 3, 4, 5 @@ -257,7 +259,7 @@ def test_reduce_matmul(): # Create sample matrices X = random.normal(key, (B, I, J)) Y = random.normal(key, (B, J, K)) - (b, i, j, k) = define_vars("b", "i", "j", "k", typ=jax.Array) + (b, i, j, k) = backend.define_vars("b", "i", "j", "k", ret="scalar") with handler(NormalizeIntp): actual = Sum.reduce( diff --git a/tests/test_internals_unification.py b/tests/test_internals_unification.py index 8b93976fc..8abfdc5ac 100644 --- a/tests/test_internals_unification.py +++ b/tests/test_internals_unification.py @@ -1900,3 +1900,10 @@ class Info(typing.TypedDict): subs = unify(collections.abc.Mapping, Info) assert subs == {} + + +def test_unify_jax_array_iterable(): + import jax + + subs = unify(collections.abc.Iterable[T], jax.Array) + assert subs == {T: jax.Array} diff --git a/tests/test_ops_monoid.py b/tests/test_ops_monoid.py index c7ee7567c..fcd72f064 100644 --- a/tests/test_ops_monoid.py +++ b/tests/test_ops_monoid.py @@ -1,10 +1,13 @@ +import math import typing +from collections.abc import Iterable import pytest from hypothesis import HealthCheck, given, settings from hypothesis import strategies as st import effectful.handlers.jax.monoid # noqa: F401 +import effectful.handlers.jax.numpy as jnp from effectful.ops.monoid import ( CartesianProduct, Max, @@ -22,29 +25,25 @@ PlusSingle, PlusZero, Product, + ReduceCartesianWeightedStream, ReduceDistributeCartesianProduct, ReduceFactorization, ReduceFusion, ReduceNoStreams, ReduceSplit, + ReduceWeightedStream, Sum, distributes_over, ) -from effectful.ops.semantics import fvsof, handler -from effectful.ops.types import Operation -from tests._monoid_helpers import ( - INT_BACKEND, - JAX_BACKEND, - Backend, - check_rewrite, - define_vars, - syntactic_eq_alpha, -) +from effectful.ops.semantics import coproduct, evaluate, fvsof, handler +from effectful.ops.syntax import deffn +from effectful.ops.types import NotHandled, Operation, Term +from tests._monoid_helpers import Backend, IntBackend, JaxBackend, syntactic_eq_alpha -@pytest.fixture(params=[INT_BACKEND, JAX_BACKEND], ids=["int", "jax"]) +@pytest.fixture(params=[IntBackend, JaxBackend], ids=["int", "jax"]) def backend(request) -> Backend: - return request.param + return request.param() ALL_MONOIDS = [ @@ -90,10 +89,10 @@ def backend(request) -> Backend: deadline=None, suppress_health_check=[HealthCheck.function_scoped_fixture], ) -def test_associativity(monoid, backend, data): - a = data.draw(backend.scalar_strategy) - b = data.draw(backend.scalar_strategy) - c = data.draw(backend.scalar_strategy) +def test_associativity(monoid, backend: Backend, data): + a = data.draw(backend.strategy(ret="scalar"))() + b = data.draw(backend.strategy(ret="scalar"))() + c = data.draw(backend.strategy(ret="scalar"))() with handler(NormalizeIntp): left = monoid.plus(monoid.plus(a, b), c) right = monoid.plus(a, monoid.plus(b, c)) @@ -107,8 +106,8 @@ def test_associativity(monoid, backend, data): deadline=None, suppress_health_check=[HealthCheck.function_scoped_fixture], ) -def test_identity(monoid, backend, data): - a = data.draw(backend.scalar_strategy) +def test_identity(monoid, backend: Backend, data): + a = data.draw(backend.strategy(ret="scalar"))() with handler(NormalizeIntp): assert backend.eq(monoid.plus(monoid.identity, a), a) assert backend.eq(monoid.plus(a, monoid.identity), a) @@ -121,9 +120,9 @@ def test_identity(monoid, backend, data): deadline=None, suppress_health_check=[HealthCheck.function_scoped_fixture], ) -def test_commutativity(monoid, backend, data): - a = data.draw(backend.scalar_strategy) - b = data.draw(backend.scalar_strategy) +def test_commutativity(monoid, backend: Backend, data): + a = data.draw(backend.strategy(ret="scalar"))() + b = data.draw(backend.strategy(ret="scalar"))() with handler(NormalizeIntp): assert backend.eq(monoid.plus(a, b), monoid.plus(b, a)) @@ -135,8 +134,8 @@ def test_commutativity(monoid, backend, data): deadline=None, suppress_health_check=[HealthCheck.function_scoped_fixture], ) -def test_idempotence(monoid, backend, data): - a = data.draw(backend.scalar_strategy) +def test_idempotence(monoid, backend: Backend, data): + a = data.draw(backend.strategy(ret="scalar"))() with handler(NormalizeIntp): assert backend.eq(monoid.plus(a, a), a) @@ -148,102 +147,86 @@ def test_idempotence(monoid, backend, data): deadline=None, suppress_health_check=[HealthCheck.function_scoped_fixture], ) -def test_zero_absorbs(monoid, backend, data): - a = data.draw(backend.scalar_strategy) +def test_zero_absorbs(monoid, backend: Backend, data): + a = data.draw(backend.strategy(ret="scalar"))() with handler(NormalizeIntp): assert backend.eq(monoid.plus(monoid.zero, a), monoid.zero) assert backend.eq(monoid.plus(a, monoid.zero), monoid.zero) @pytest.mark.parametrize("monoid", ALL_MONOIDS) -def test_plus_empty(monoid, backend): - check_rewrite( - lhs=monoid.plus(), rhs=monoid.identity, rule=PlusEmpty(), backend=backend - ) +def test_plus_empty(monoid, backend: Backend): + backend.check_rewrite(lhs=monoid.plus(), rhs=monoid.identity, rule=PlusEmpty()) @pytest.mark.parametrize("monoid", ALL_MONOIDS) -def test_plus_single(monoid, backend): - x = define_vars("x", typ=backend.scalar_typ) - check_rewrite( - lhs=monoid.plus(x()), rhs=x(), rule=PlusSingle(), backend=backend, free_vars=[x] - ) +def test_plus_single(monoid, backend: Backend): + x = backend.define_vars("x", ret="scalar") + backend.check_rewrite(lhs=monoid.plus(x()), rhs=x(), rule=PlusSingle()) @pytest.mark.parametrize("monoid", ALL_MONOIDS) -def test_plus_identity_right(monoid, backend): - x = define_vars("x", typ=backend.scalar_typ) +def test_plus_identity_right(monoid, backend: Backend): + x = backend.define_vars("x", ret="scalar") lhs = monoid.plus(x(), monoid.identity) rhs = monoid.plus(x()) - check_rewrite(lhs=lhs, rhs=rhs, rule=PlusIdentity(), backend=backend, free_vars=[x]) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=PlusIdentity()) @pytest.mark.parametrize("monoid", ALL_MONOIDS) -def test_plus_identity_left(monoid, backend): - x = define_vars("x", typ=backend.scalar_typ) +def test_plus_identity_left(monoid, backend: Backend): + x = backend.define_vars("x", ret="scalar") lhs = monoid.plus(monoid.identity, x()) rhs = monoid.plus(x()) - check_rewrite(lhs=lhs, rhs=rhs, rule=PlusIdentity(), backend=backend, free_vars=[x]) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=PlusIdentity()) @pytest.mark.parametrize("monoid", ALL_MONOIDS) -def test_plus_assoc_right(monoid, backend): - x, y, z = define_vars("x", "y", "z", typ=backend.scalar_typ) - check_rewrite( +def test_plus_assoc_right(monoid, backend: Backend): + x, y, z = backend.define_vars("x", "y", "z", ret="scalar") + backend.check_rewrite( lhs=monoid.plus(x(), monoid.plus(y(), z())), rhs=monoid.plus(x(), y(), z()), rule=PlusAssoc(), - backend=backend, - free_vars=[x, y, z], ) @pytest.mark.parametrize("monoid", ALL_MONOIDS) -def test_plus_assoc_left(monoid, backend): - x, y, z = define_vars("x", "y", "z", typ=backend.scalar_typ) - check_rewrite( +def test_plus_assoc_left(monoid, backend: Backend): + x, y, z = backend.define_vars("x", "y", "z", ret="scalar") + backend.check_rewrite( lhs=monoid.plus(monoid.plus(x(), y()), z()), rhs=monoid.plus(x(), y(), z()), rule=PlusAssoc(), - backend=backend, - free_vars=[x, y, z], ) @pytest.mark.parametrize("monoid", ALL_MONOIDS) -def test_plus_sequence(monoid, backend): - a, b, c, d = define_vars("a", "b", "c", "d", typ=backend.scalar_typ) - check_rewrite( +def test_plus_sequence(monoid, backend: Backend): + a, b, c, d = backend.define_vars("a", "b", "c", "d", ret="scalar") + backend.check_rewrite( lhs=monoid.plus((a(), b()), (c(), d())), rhs=(monoid.plus(a(), c()), monoid.plus(b(), d())), rule=MonoidOverSequence(), - backend=backend, - free_vars=[a, b, c, d], ) @pytest.mark.parametrize("monoid", ALL_MONOIDS) -def test_plus_mapping(monoid, backend): - a, b, c, d = define_vars("a", "b", "c", "d", typ=backend.scalar_typ) +def test_plus_mapping(monoid, backend: Backend): + a, b, c, d = backend.define_vars("a", "b", "c", "d", ret="scalar") lhs = monoid.plus({0: a(), 1: b()}, {0: c(), 2: d()}) rhs = {0: monoid.plus(a(), c()), 1: monoid.plus(b()), 2: monoid.plus(d())} - check_rewrite( - lhs=lhs, - rhs=rhs, - rule=MonoidOverMapping(), - backend=backend, - free_vars=[a, b, c, d], - ) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=MonoidOverMapping()) -def test_plus_distributes(backend): - a, b, c, d = define_vars("a", "b", "c", "d", typ=backend.scalar_typ) +def test_plus_distributes(backend: Backend): + a, b, c, d = backend.define_vars("a", "b", "c", "d", ret="scalar") lhs = Product.plus(Sum.plus(a(), b()), Sum.plus(c(), d())) rhs = Product.plus( Sum.plus( @@ -253,13 +236,11 @@ def test_plus_distributes(backend): Product.plus(b(), d()), ) ) - check_rewrite( - lhs=lhs, rhs=rhs, rule=PlusDistr(), backend=backend, free_vars=[a, b, c, d] - ) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=PlusDistr()) -def test_plus_distributes_constant(backend): - a, b, c, d = define_vars("a", "b", "c", "d", typ=backend.scalar_typ) +def test_plus_distributes_constant(backend: Backend): + a, b, c, d = backend.define_vars("a", "b", "c", "d", ret="scalar") lhs = Product.plus(Sum.plus(a(), b()), Sum.plus(c(), d()), 5) rhs = Product.plus( 5, @@ -270,13 +251,11 @@ def test_plus_distributes_constant(backend): Product.plus(b(), d()), ), ) - check_rewrite( - lhs=lhs, rhs=rhs, rule=PlusDistr(), backend=backend, free_vars=[a, b, c, d] - ) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=PlusDistr()) -def test_plus_distributes_multiple(backend): - a, b, c, d = define_vars("a", "b", "c", "d", typ=backend.scalar_typ) +def test_plus_distributes_multiple(backend: Backend): + a, b, c, d = backend.define_vars("a", "b", "c", "d", ret="scalar") lhs = Sum.plus( Min.plus(a(), b()), Min.plus(c(), d()), @@ -297,238 +276,195 @@ def test_plus_distributes_multiple(backend): Sum.plus(b(), d()), ), ) - check_rewrite( - lhs=lhs, rhs=rhs, rule=PlusDistr(), backend=backend, free_vars=[a, b, c, d] - ) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=PlusDistr()) @pytest.mark.parametrize("monoid", IDEMPOTENT) -def test_plus_idempotent_consecutive(monoid, backend): +def test_plus_idempotent_consecutive(monoid, backend: Backend): """``a, a, b → a, b`` — only consecutive duplicates collapse.""" - a, b = define_vars("a", "b", typ=backend.scalar_typ) + a, b = backend.define_vars("a", "b", ret="scalar") lhs = monoid.plus(a(), a(), b()) - return check_rewrite( - lhs=lhs, - rhs=monoid.plus(a(), b()), - rule=PlusConsecutiveDups(), - backend=backend, - free_vars=[a, b], + return backend.check_rewrite( + lhs=lhs, rhs=monoid.plus(a(), b()), rule=PlusConsecutiveDups() ) @pytest.mark.parametrize("monoid", IDEMPOTENT) -def test_plus_idempotent_non_consecutive(monoid, backend): +def test_plus_idempotent_non_consecutive(monoid, backend: Backend): """``a, b, a`` — Semilattice (Min/Max) collapses via commutative PlusDups.""" - a, b = define_vars("a", "b", typ=backend.scalar_typ) + a, b = backend.define_vars("a", "b", ret="scalar") lhs = monoid.plus(a(), b(), a()) rhs = monoid.plus(a(), b()) - check_rewrite(lhs=lhs, rhs=rhs, rule=PlusDups(), backend=backend, free_vars=[a, b]) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=PlusDups()) @pytest.mark.parametrize("monoid", [Min, Max]) -def test_plus_commutative_idempotent_long(monoid, backend): +def test_plus_commutative_idempotent_long(monoid, backend: Backend): """Long alternation collapses via commutative dedup (Min/Max only).""" - a, b = define_vars("a", "b", typ=backend.scalar_typ) + a, b = backend.define_vars("a", "b", ret="scalar") lhs = monoid.plus(a(), b(), a(), b(), b(), a(), a()) rhs = monoid.plus(a(), b()) - check_rewrite(lhs=lhs, rhs=rhs, rule=PlusDups(), backend=backend, free_vars=[a, b]) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=PlusDups()) @pytest.mark.parametrize("monoid", WITH_ZERO) -def test_plus_zero(monoid, backend): - a = define_vars("a", typ=backend.scalar_typ) +def test_plus_zero(monoid, backend: Backend): + a = backend.define_vars("a", ret="scalar") lhs_right = monoid.plus(a(), monoid.zero) lhs_left = monoid.plus(monoid.zero, a()) rhs = monoid.zero - check_rewrite( - lhs=lhs_right, rhs=rhs, rule=PlusZero(), backend=backend, free_vars=[a] - ) - check_rewrite( - lhs=lhs_left, rhs=rhs, rule=PlusZero(), backend=backend, free_vars=[a] - ) + backend.check_rewrite(lhs=lhs_right, rhs=rhs, rule=PlusZero()) + backend.check_rewrite(lhs=lhs_left, rhs=rhs, rule=PlusZero()) @pytest.mark.parametrize("monoid", ALL_MONOIDS) -def test_partial_1(monoid, backend): - x, y = define_vars("x", "y", typ=backend.scalar_typ) +def test_partial_1(monoid, backend: Backend): + x = backend.define_vars("x", ret="scalar") lhs = monoid.reduce(x(), {x: []}) rhs = monoid.identity - check_rewrite(lhs=lhs, rhs=rhs, rule={}, backend=backend, free_vars=[x, y]) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule={}) @pytest.mark.parametrize("monoid", ALL_MONOIDS) -def test_partial_2(monoid, backend): - x, y = define_vars("x", "y", typ=backend.scalar_typ) - Y = define_vars("Y", typ=backend.stream_typ) +def test_partial_2(monoid, backend: Backend): + x, y = backend.define_vars("x", "y", ret="scalar") + Y = backend.define_vars("Y", ret="stream") lhs = monoid.reduce(x(), {y: Y(), x: []}) rhs = monoid.identity - - check_rewrite(lhs=lhs, rhs=rhs, rule={}, backend=backend, free_vars=[x, y, Y]) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule={}) @pytest.mark.parametrize("monoid", ALL_MONOIDS) -def test_partial_3(monoid, backend): - x, y, a, b = define_vars("x", "y", "a", "b", typ=backend.scalar_typ) - Y = define_vars("Y", typ=backend.stream_typ) +def test_partial_3(monoid, backend: Backend): + x, y, a, b = backend.define_vars("x", "y", "a", "b", ret="scalar") + Y = backend.define_vars("Y", ret="stream") lhs = monoid.reduce(x(), {y: Y(), x: [a(), b()]}) rhs = monoid.plus(monoid.reduce(a(), {y: Y()}), monoid.reduce(b(), {y: Y()})) - - check_rewrite(lhs=lhs, rhs=rhs, rule={}, backend=backend, free_vars=[x, y, a, b, Y]) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule={}) @pytest.mark.parametrize("monoid", ALL_MONOIDS) -def test_partial_4(monoid, backend): - x, y, a, b = define_vars("x", "y", "a", "b", typ=backend.scalar_typ) - f = backend.fresh_op("f", n_args=1, ret="stream") +def test_partial_4(monoid, backend: Backend): + x, y, a, b = backend.define_vars("x", "y", "a", "b", ret="scalar") + f = backend.define_vars("f", arg_types=(backend.scalar_typ,), ret="stream") lhs = monoid.reduce(x(), {y: f(x()), x: [a(), b()]}) rhs = monoid.plus(monoid.reduce(a(), {y: f(a())}), monoid.reduce(b(), {y: f(b())})) - - check_rewrite(lhs=lhs, rhs=rhs, rule={}, backend=backend, free_vars=[x, y, a, b, f]) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule={}) @pytest.mark.parametrize("monoid", ALL_MONOIDS) -def test_reduce_body_sequence(monoid, backend): - x = Operation.define(backend.scalar_typ, name="x") - X = Operation.define(backend.stream_typ, name="X") - f = backend.fresh_op("f", n_args=1, ret="scalar") - g = Operation.define(f, name="g") +def test_reduce_body_sequence(monoid, backend: Backend): + x = backend.define_vars("x", ret="scalar") + X = backend.define_vars("X", ret="stream") + f, g = backend.define_vars("f", "g", arg_types=(backend.scalar_typ,), ret="scalar") lhs = monoid.reduce((f(x()), g(x())), {x: X()}) rhs = (monoid.reduce(f(x()), {x: X()}), monoid.reduce(g(x()), {x: X()})) - - check_rewrite( - lhs=lhs, - rhs=rhs, - rule=MonoidOverSequence(), - backend=backend, - free_vars=[X, f, g], - ) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=MonoidOverSequence()) @pytest.mark.parametrize("monoid", ALL_MONOIDS) -def test_reduce_body_sequence_2(monoid, backend): - x, y = define_vars("x", "y", typ=backend.scalar_typ) - X, Y = define_vars("X", "Y", typ=backend.stream_typ) - f = backend.fresh_op("f", n_args=1, ret="scalar") - g = Operation.define(f, name="g") +def test_reduce_body_sequence_2(monoid, backend: Backend): + x, y = backend.define_vars("x", "y", ret="scalar") + X, Y = backend.define_vars("X", "Y", ret="stream") + f, g = backend.define_vars("f", "g", arg_types=(backend.scalar_typ,), ret="scalar") lhs = monoid.reduce((f(x()), g(y())), {x: X(), y: Y()}) rhs = ( monoid.reduce(f(x()), {x: X(), y: Y()}), monoid.reduce(g(y()), {x: X(), y: Y()}), ) - - check_rewrite( - lhs=lhs, - rhs=rhs, - rule=MonoidOverSequence(), - backend=backend, - free_vars=[X, Y, f, g], - ) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=MonoidOverSequence()) @pytest.mark.parametrize("monoid", ALL_MONOIDS) -def test_reduce_body_mapping(monoid, backend): - x = Operation.define(backend.scalar_typ, name="x") - X = Operation.define(backend.stream_typ, name="X") - f = backend.fresh_op("f", n_args=1, ret="scalar") - g = Operation.define(f, name="g") +def test_reduce_body_mapping(monoid, backend: Backend): + x = backend.define_vars("x", ret="scalar") + X = backend.define_vars("X", ret="stream") + f, g = backend.define_vars("f", "g", arg_types=(backend.scalar_typ,), ret="scalar") lhs = monoid.reduce({0: f(x()), 1: g(x())}, {x: X()}) rhs = { 0: monoid.reduce(f(x()), {x: X()}), 1: monoid.reduce(g(x()), {x: X()}), } - check_rewrite( - lhs=lhs, - rhs=rhs, - rule=MonoidOverMapping(), - backend=backend, - free_vars=[X, f, g], - ) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=MonoidOverMapping()) @pytest.mark.parametrize("monoid", ALL_MONOIDS) -def test_reduce_no_streams(monoid, backend): - a = define_vars("a", typ=backend.scalar_typ) +def test_reduce_no_streams(monoid, backend: Backend): + a = backend.define_vars("a", ret="scalar") + lhs = monoid.reduce(a(), {}) rhs = monoid.identity - - check_rewrite( - lhs=lhs, rhs=rhs, rule=ReduceNoStreams(), backend=backend, free_vars=[a] - ) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ReduceNoStreams()) @pytest.mark.parametrize("monoid", ALL_MONOIDS) -def test_reduce_reduce(monoid, backend): - a, b = define_vars("a", "b", typ=backend.scalar_typ) - A, B = define_vars("A", "B", typ=backend.stream_typ) - f = backend.fresh_op("f", n_args=2, ret="scalar") +def test_reduce_reduce(monoid, backend: Backend): + a, b = backend.define_vars("a", "b", ret="scalar") + A, B = backend.define_vars("A", "B", ret="stream") + f = backend.define_vars( + "f", arg_types=(backend.scalar_typ, backend.scalar_typ), ret="scalar" + ) lhs = monoid.reduce(monoid.reduce(f(a(), b()), {a: A()}), {b: B()}) rhs = monoid.reduce(f(a(), b()), {a: A(), b: B()}) - - check_rewrite( - lhs=lhs, rhs=rhs, rule=ReduceFusion(), backend=backend, free_vars=[A, B, f] - ) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ReduceFusion()) @pytest.mark.parametrize("monoid", COMMUTATIVE) -def test_reduce_plus(monoid, backend): - a, b = define_vars("a", "b", typ=backend.scalar_typ) - A, B = define_vars("A", "B", typ=backend.stream_typ) +def test_reduce_plus(monoid, backend: Backend): + a, b = backend.define_vars("a", "b", ret="scalar") + A, B = backend.define_vars("A", "B", ret="stream") + lhs = monoid.reduce(monoid.plus(a(), b()), {a: A(), b: B()}) rhs = monoid.plus( monoid.reduce(a(), {a: A(), b: B()}), monoid.reduce(b(), {a: A(), b: B()}), ) - check_rewrite( - lhs=lhs, rhs=rhs, rule=ReduceSplit(), backend=backend, free_vars=[A, B] - ) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ReduceSplit()) -def test_reduce_independent_1(backend): - a, b = define_vars("a", "b", typ=backend.scalar_typ) - A, B = define_vars("A", "B", typ=backend.stream_typ) +def test_reduce_independent_1(backend: Backend): + a, b = backend.define_vars("a", "b", ret="scalar") + A, B = backend.define_vars("A", "B", ret="stream") + lhs = Sum.reduce(Product.plus(a(), b()), {a: A(), b: B()}) rhs = Product.plus( Sum.reduce(Product.plus(a()), {a: A()}), Sum.reduce(Product.plus(b()), {b: B()}) ) - check_rewrite( - lhs=lhs, rhs=rhs, rule=ReduceFactorization(), backend=backend, free_vars=[A, B] - ) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ReduceFactorization()) -def test_reduce_independent_2(backend): - a, b, c = define_vars("a", "b", "c", typ=backend.scalar_typ) - A, B, C = define_vars("A", "B", "C", typ=backend.stream_typ) - f = backend.fresh_op("f", n_args=2, ret="scalar") +def test_reduce_independent_2(backend: Backend): + a, b, c = backend.define_vars("a", "b", "c", ret="scalar") + A, B, C = backend.define_vars("A", "B", "C", ret="stream") + f = backend.define_vars( + "f", arg_types=(backend.scalar_typ, backend.scalar_typ), ret="scalar" + ) lhs = Sum.reduce(Product.plus(a(), b(), f(b(), c())), {a: A(), b: B(), c: C()}) rhs = Product.plus( Sum.reduce(Product.plus(a()), {a: A()}), Sum.reduce(Product.plus(b(), f(b(), c())), {b: B(), c: C()}), ) - check_rewrite( - lhs=lhs, - rhs=rhs, - rule=ReduceFactorization(), - backend=backend, - free_vars=[A, B, C, f], - ) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ReduceFactorization()) -def test_reduce_independent_3_negative(backend): +def test_reduce_independent_3_negative(backend: Backend): """Stream `b` depends on `a` (b: g(a())), so the proposed factorization is unsound — the normalizer must NOT apply it.""" - a, b, c = define_vars("a", "b", "c", typ=backend.scalar_typ) - A, C = define_vars("A", "C", typ=backend.stream_typ) - f = backend.fresh_op("f", n_args=2, ret="scalar") - g = backend.fresh_op("g", n_args=1, ret="stream") + a, b, c = backend.define_vars("a", "b", "c", ret="scalar") + A, C = backend.define_vars("A", "C", ret="stream") + f = backend.define_vars( + "f", arg_types=(backend.scalar_typ, backend.scalar_typ), ret="scalar" + ) + g = backend.define_vars("g", arg_types=(backend.scalar_typ,), ret="stream") with handler(ReduceFactorization()): # ty:ignore[invalid-argument-type] lhs = Sum.reduce( @@ -542,10 +478,12 @@ def test_reduce_independent_3_negative(backend): assert not syntactic_eq_alpha(lhs, bogus_rhs) -def test_reduce_independent_4(backend): - a, b, c = define_vars("a", "b", "c", typ=backend.scalar_typ) - A, B, C = define_vars("A", "B", "C", typ=backend.stream_typ) - f = backend.fresh_op("f", n_args=2, ret="scalar") +def test_reduce_independent_4(backend: Backend): + a, b, c = backend.define_vars("a", "b", "c", ret="scalar") + A, B, C = backend.define_vars("A", "B", "C", ret="stream") + f = backend.define_vars( + "f", arg_types=(backend.scalar_typ, backend.scalar_typ), ret="scalar" + ) lhs = Sum.reduce(Product.plus(a(), b(), f(b(), c()), 7), {a: A(), b: B(), c: C()}) rhs = Product.plus( @@ -553,39 +491,44 @@ def test_reduce_independent_4(backend): Sum.reduce(Product.plus(a()), {a: A()}), Sum.reduce(Product.plus(b(), f(b(), c())), {b: B(), c: C()}), ) - check_rewrite( - lhs=lhs, - rhs=rhs, - rule=ReduceFactorization(), - backend=backend, - free_vars=[A, B, C, f], - ) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ReduceFactorization()) + + +def test_reduce_cartesian_3(): + backend = JaxBackend() + i = backend.define_vars("i", ret="scalar") + + with handler(NormalizeIntp): + value = CartesianProduct.reduce(jnp.zeros(2), {i: jnp.arange(3)}) + assert value.shape == (2**3, 3) + + with handler(NormalizeIntp): + value = CartesianProduct.reduce(jnp.zeros(2), {i: jnp.arange(1)}) + assert value.shape == (2**1, 1) + + with handler(NormalizeIntp): + value = CartesianProduct.reduce(jnp.zeros(1), {i: jnp.arange(3)}) + assert value.shape == (1**3, 3) @pytest.mark.parametrize("outer,inner", MONOID_PAIRS) -def test_reduce_lifted_1(outer, inner, backend): - a, i = define_vars("a", "i", typ=backend.scalar_typ) - A, N, A_domain = define_vars("A", "N", "A_domain", typ=backend.stream_typ) - f = backend.fresh_op("f", n_args=1, ret="scalar") +def test_reduce_lifted_1(outer, inner, backend: Backend): + a, i = backend.define_vars("a", "i", ret="scalar") + A, N, A_domain = backend.define_vars("A", "N", "A_domain", ret="stream") + f = backend.define_vars("f", arg_types=(backend.scalar_typ,), ret="scalar") - term1 = outer.reduce( + lhs = outer.reduce( inner.reduce(f(a()), {a: A()}), {A: CartesianProduct.reduce(A_domain(), {i: N()})}, ) - term2 = inner.reduce(outer.reduce(inner.plus(f(a())), {a: A_domain()}), {i: N()}) - - check_rewrite( - lhs=term1, - rhs=term2, - rule=ReduceDistributeCartesianProduct(), - backend=backend, - free_vars=[N, A_domain, f], - ) + rhs = inner.reduce(outer.reduce(inner.plus(f(a())), {a: A_domain()}), {i: N()}) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ReduceDistributeCartesianProduct()) def test_reduce_cartesian_1(): - a, i = define_vars("a", "i", typ=int) - A = define_vars("A", typ=tuple[int]) + backend = IntBackend() + a, i = backend.define_vars("a", "i", ret="scalar") + A = backend.define_vars("A", ret="stream") with handler(NormalizeIntp): term1 = Sum.reduce( @@ -597,8 +540,9 @@ def test_reduce_cartesian_1(): def test_reduce_cartesian_2(): - a, i = define_vars("a", "i", typ=int) - A = define_vars("A", typ=tuple[int]) + backend = IntBackend() + a, i = backend.define_vars("a", "i", ret="scalar") + A = backend.define_vars("A", ret="stream") with handler(NormalizeIntp): term1 = Sum.reduce( @@ -610,46 +554,41 @@ def test_reduce_cartesian_2(): @pytest.mark.parametrize("outer,inner", MONOID_PAIRS) -def test_reduce_lifted_multi_index(outer, inner, backend): - a, i, j = define_vars("a", "i", "j", typ=backend.scalar_typ) - A, N, M, A_domain = define_vars("A", "N", "M", "A_domain", typ=backend.stream_typ) - f = backend.fresh_op("f", n_args=1, ret="scalar") +def test_reduce_lifted_multi_index(outer, inner, backend: Backend): + a, i, j = backend.define_vars("a", "i", "j", ret="scalar") + A, N, M, A_domain = backend.define_vars("A", "N", "M", "A_domain", ret="stream") + f = backend.define_vars("f", arg_types=(backend.scalar_typ,), ret="scalar") - term1 = outer.reduce( + lhs = outer.reduce( inner.reduce(f(a()), {a: A()}), {A: CartesianProduct.reduce(A_domain(), {i: N(), j: M()})}, ) - term2 = inner.reduce( - outer.reduce(inner.plus(f(a())), {a: A_domain()}), - {i: N(), j: M()}, - ) - check_rewrite( - lhs=term1, - rhs=term2, - rule=ReduceDistributeCartesianProduct(), - backend=backend, - free_vars=[N, M, A_domain, f], + rhs = inner.reduce( + outer.reduce(inner.plus(f(a())), {a: A_domain()}), {i: N(), j: M()} ) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ReduceDistributeCartesianProduct()) @pytest.mark.parametrize("outer,inner", MONOID_PAIRS) -def test_reduce_lifted_2(outer, inner, backend): +def test_reduce_lifted_2(outer, inner, backend: Backend): """The worked example on page 396 of 'Lifted Variable Elimination: Decoupling the Operators from the Constraint Language'. """ - a, i, s, t = define_vars("a", "i", "s", "t", typ=backend.scalar_typ) - A, N, T = define_vars("A", "N", "T", typ=backend.stream_typ) - A_domain = backend.fresh_op("A_domain", n_args=1, ret="stream") - f1 = backend.fresh_op("f1", n_args=2, ret="scalar") - f2 = backend.fresh_op("f2", n_args=2, ret="scalar") + a, i, s, t = backend.define_vars("a", "i", "s", "t", ret="scalar") + A, N, T = backend.define_vars("A", "N", "T", ret="stream") + A_domain = backend.define_vars( + "A_domain", arg_types=(backend.scalar_typ,), ret="stream" + ) + f1, f2 = backend.define_vars( + "f1", "f2", arg_types=(backend.scalar_typ, backend.scalar_typ), ret="scalar" + ) - term1 = outer.reduce( + lhs = outer.reduce( inner.reduce(inner.plus(f1(a(), s()), f2(t(), a())), {a: A()}), {A: CartesianProduct.reduce(A_domain(i()), {i: N()}), t: T()}, ) - - term2 = outer.reduce( + rhs = outer.reduce( inner.reduce( outer.reduce( inner.plus(inner.plus(f1(a(), s()), f2(t(), a()))), {a: A_domain(i())} @@ -658,11 +597,143 @@ def test_reduce_lifted_2(outer, inner, backend): ), {t: T()}, ) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ReduceDistributeCartesianProduct()) + + +# --------------------------------------------------------------------------- +# Weighted streams +# --------------------------------------------------------------------------- + + +def test_reduce_single_weighted_stream(backend: Backend): + """Single weighted stream desugars: + Sum.reduce(body, {a: WS(A, w, Product)}) + = Sum.reduce(Product.plus(w(a), body), {a: A}) + """ + a = backend.define_vars("a", ret="scalar") + A = backend.define_vars("A", ret="stream") + body, w = backend.define_vars( + "body", "w", arg_types=(backend.scalar_typ,), ret="scalar" + ) + + lhs = Sum.reduce(body(a()), {a: Product.weighted(A(), w)}) + rhs = Sum.reduce(Product.plus(w(a()), body(a())), {a: A()}) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ReduceWeightedStream()) + + +def test_reduce_weighted_factorization(backend: Backend): + """Two independent weighted streams under Sum with Product weights factor: + Sum.reduce(f(a)*g(b), {a: Product.weighted(A, a, w_a), b: Product.weighted(B, b, w_b)}) + = (Sum.reduce(w_a(a)*f(a), {a: A})) * (Sum.reduce(w_b(b)*g(b), {b: B})) + + Exercises chaining of ``ReduceWeightedStream`` with ``ReduceFactorization`` + inside ``NormalizeIntp``. + """ + a, b = backend.define_vars("a", "b", ret="scalar") + A, B = backend.define_vars("A", "B", ret="stream") + f, g, w_a, w_b = backend.define_vars( + "f", "g", "w_a", "w_b", arg_types=(backend.scalar_typ,), ret="scalar" + ) + + lhs = Sum.reduce( + Product.plus(f(a()), g(b())), + {a: Product.weighted(A(), w_a), b: Product.weighted(B(), w_b)}, + ) + rhs = Product.plus( + Sum.reduce(Product.plus(w_a(a()), Product.plus(f(a()))), {a: A()}), + Sum.reduce(Product.plus(w_b(b()), Product.plus(g(b()))), {b: B()}), + ) + backend.check_rewrite( + lhs=lhs, rhs=rhs, rule=coproduct(ReduceWeightedStream(), ReduceFactorization()) + ) + + +def test_reduce_cartesian_weighted_stream(backend: Backend): + """``CartesianProduct.reduce`` over a ``WeightedStream`` body whose weight + is independent of the plate var rewrites to a single joint + ``WeightedStream``: + + CartesianProduct.reduce(M.weighted(s, e, w(e)), {p: P}) + = M.weighted(CartesianProduct.reduce(s, {p: P}), row, M.reduce(w(e), {e: row()})) + """ + p, e_var = backend.define_vars("p", "e_var", ret="scalar") + S, P = backend.define_vars("S", "P", ret="stream") + w = backend.define_vars("w", arg_types=(backend.scalar_typ,), ret="scalar") + + lhs = CartesianProduct.reduce(Product.weighted(S(), w), {p: P()}) + row_var = Operation.define(Iterable[backend.scalar_typ], name="row") # type: ignore[name-defined] + rhs = Product.weighted( + CartesianProduct.reduce(S(), {p: P()}), + deffn(Product.reduce(w(e_var()), {e_var: row_var()}), row_var), + ) + backend.check_rewrite(lhs=lhs, rhs=rhs, rule=ReduceCartesianWeightedStream()) + + +def test_lift_weighted_cartesian(backend: Backend): + """Compose ``ReduceCartesianWeightedStream`` + ``ReduceWeightedStream`` + + ``ReduceDistributeCartesianProduct`` on a Sum-of-Product-of-weighted shape: + + Sum.reduce( + Product.reduce(body(a()), {a: A()}), + {A: CartesianProduct.reduce(Product.weighted(S, e, w(e)), {p: P})}, + ) + + The inner ``weighted`` becomes a joint ``weighted`` (rule 1), lifts its + per-element weight into the outer Sum body (rule 2), and the lifted form + matches the inversion pattern (rule 3), yielding:: + + Product.reduce( + Sum.reduce(Product.plus(w(a()), body(a())), {a: S}), + {p: P}, + ) + """ + a, p = backend.define_vars("a", "p", ret="scalar") + A, S, P = backend.define_vars("A", "S", "P", ret="stream") + body, w = backend.define_vars( + "body", "w", arg_types=(backend.scalar_typ,), ret="scalar" + ) - check_rewrite( - lhs=term1, - rhs=term2, - rule=ReduceDistributeCartesianProduct(), - backend=backend, - free_vars=[a, i, s, t, A, N, T, A_domain, f1, f2], + lhs = Sum.reduce( + Product.reduce(body(a()), {a: A()}), + {A: CartesianProduct.reduce(Product.weighted(S(), w), {p: P()})}, + ) + rhs = Product.reduce( + Sum.reduce(Product.plus(w(a()), body(a())), {a: S()}), {p: P()} ) + backend.check_rewrite( + lhs=lhs, + rhs=rhs, + rule=coproduct( + coproduct(ReduceWeightedStream(), ReduceCartesianWeightedStream()), + ReduceDistributeCartesianProduct(), + ), + ) + + +def test_weighted_expectation_demo(): + """Demo: compute E[f(X)] = Σ_x w(x)·f(x) via a weighted reduce. + + X ranges over [1, 2, 3, 4] with weights w(x) = x/10 (a valid distribution + since the weights sum to 1) and f(x) = x*x. Expected value: + 0.1·1 + 0.2·4 + 0.3·9 + 0.4·16 = 10.0 + """ + weights = {1: 0.1, 2: 0.2, 3: 0.3, 4: 0.4} + + def _w(v: int) -> float: + if isinstance(v, Term): + raise NotHandled + return weights[v] + + def _f(v: int) -> float: + if isinstance(v, Term): + raise NotHandled + return float(v * v) + + a = Operation.define(int, name="a") + w = Operation.define(_w, name="w") + f = Operation.define(_f, name="f") + + with handler(NormalizeIntp): + result = evaluate(Sum.reduce(f(a()), {a: Product.weighted([1, 2, 3, 4], w)})) + + assert math.isclose(result, 10.0)