diff --git a/CHANGELOG.md b/CHANGELOG.md index ca9380e82..4657854b8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -23,6 +23,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added new attribute `OpenGraph.output_cliffords` - Added `clifford` abstract method to `AbstractMeasurement`. Implemented it for `Plane` and `Axis`. +- #515: + - `Pattern.remove_local_clifford_commands` transpiles MBQC+LC patterns into pure MBQC patterns. + - `Clifford.to_opengraph` returns an open graph without local Cliffords that implements the given single-qubit Clifford gate. + +- #515, #519: + - `Pattern.reindex` returns a pattern whose nodes have been re-indexed either by a supplied mapping or, by default, consecutively starting at 0. + - `Pattern.node_mapping` returns a mapping that can be passed to `Pattern.reindex`, with flexibility for specifying how nodes are mapped. + ### Fixed - #454, #481: Ensure `Pattern.minimize_space` only reduces max-space and does not increase it. diff --git a/docs/source/modifier.rst b/docs/source/modifier.rst index 4fe4f5727..2d3b96073 100644 --- a/docs/source/modifier.rst +++ b/docs/source/modifier.rst @@ -36,6 +36,8 @@ Pattern Manipulation .. automethod:: remove_pauli_measurements + .. automethod:: remove_local_clifford_commands + .. automethod:: to_ascii .. automethod:: to_unicode diff --git a/graphix/_db.py b/graphix/_db.py index 36f586ca4..5d7346430 100644 --- a/graphix/_db.py +++ b/graphix/_db.py @@ -8,6 +8,7 @@ from graphix import utils from graphix.fundamentals import Axis, Sign +from graphix.measurements import Measurement from graphix.ops import Ops # 24 unique 1-qubit Clifford gates @@ -226,3 +227,31 @@ class _CMTuple(NamedTuple): ("h", "x", "sdg"), ("h", "x", "s"), ) + +CLIFFORD_PAULI_DECOMPOSITION = ( + (), + (Measurement.X, -Measurement.X), + (-Measurement.X, -Measurement.X), + (-Measurement.X, Measurement.X), + (-Measurement.Y, Measurement.X), + (Measurement.Y, Measurement.X), + (Measurement.X,), + (Measurement.X, -Measurement.Y), + (-Measurement.X,), + (-Measurement.Y, -Measurement.X), + (Measurement.Y, -Measurement.X), + (Measurement.X, -Measurement.X, -Measurement.X), + (Measurement.X, -Measurement.X, Measurement.X), + (-Measurement.X, -Measurement.Y), + (-Measurement.X, Measurement.Y), + (Measurement.X, Measurement.Y), + (Measurement.Y,), + (Measurement.X, -Measurement.X, Measurement.Y), + (Measurement.X, -Measurement.X, -Measurement.Y), + (-Measurement.Y,), + (Measurement.Y, Measurement.Y), + (-Measurement.Y, -Measurement.Y), + (Measurement.Y, -Measurement.Y), + (-Measurement.Y, Measurement.Y), +) +"""Decomposition of every Clifford gate C into sequences of Pauli measurements.""" diff --git a/graphix/clifford.py b/graphix/clifford.py index fc09eff3f..ec5b93fc2 100644 --- a/graphix/clifford.py +++ b/graphix/clifford.py @@ -8,6 +8,7 @@ from enum import Enum from typing import TYPE_CHECKING, Any +import networkx as nx import numpy as np import typing_extensions @@ -18,6 +19,7 @@ CLIFFORD_LABEL, CLIFFORD_MEASURE, CLIFFORD_MUL, + CLIFFORD_PAULI_DECOMPOSITION, CLIFFORD_TO_QASM3, ) from graphix.fundamentals import Axis, ComplexUnit, I @@ -26,6 +28,8 @@ if TYPE_CHECKING: import numpy.typing as npt + from graphix import OpenGraph, PauliMeasurement + @dataclass class Domains: @@ -178,6 +182,19 @@ def commute_domains(self, domains: Domains) -> Domains: raise RuntimeError(f"{gate} should be either I, H, S or Z.") return Domains(s_domain, t_domain) + def to_opengraph(self) -> OpenGraph[PauliMeasurement]: + """Return a local-Clifford-free open graph equivalent to the Clifford gate.""" + from graphix import OpenGraph # noqa: PLC0415 + + decomposition = CLIFFORD_PAULI_DECOMPOSITION[self.value] + n = len(decomposition) + return OpenGraph( + graph=nx.path_graph(n + 1), + input_nodes=[0], + output_nodes=[n], + measurements=dict(enumerate(decomposition)), + ) + Clifford.I = Clifford(0) Clifford.X = Clifford(1) diff --git a/graphix/command.py b/graphix/command.py index 8dcad749e..3b84af840 100644 --- a/graphix/command.py +++ b/graphix/command.py @@ -5,9 +5,13 @@ import dataclasses import enum import logging +from abc import ABC, abstractmethod from enum import Enum from typing import TYPE_CHECKING, ClassVar, Literal, TypeAlias +# override introduced in Python 3.12 +from typing_extensions import override + from graphix import utils from graphix.clifford import Clifford, Domains from graphix.measurements import Measurement @@ -16,6 +20,7 @@ if TYPE_CHECKING: from collections.abc import Callable + from typing import Self Node: TypeAlias = int @@ -44,9 +49,13 @@ def __init_subclass__(cls) -> None: utils.check_kind(cls, {"CommandKind": CommandKind, "Clifford": Clifford}) -class BaseCommand(DataclassReprMixin): +class BaseCommand(ABC, DataclassReprMixin): """Base class for pattern command.""" + @abstractmethod + def reindex(self, f: Callable[[Node], Node]) -> Self: + """Return a command whose nodes have been reindexed using ``f``.""" + @dataclasses.dataclass(repr=False) class BaseN(BaseCommand): @@ -70,6 +79,10 @@ class BaseN(BaseCommand): node: int kind: ClassVar[Literal[CommandKind.N]] = dataclasses.field(default=CommandKind.N, init=False) + @override + def reindex(self, f: Callable[[Node], Node]) -> Self: + return dataclasses.replace(self, node=f(self.node)) + @dataclasses.dataclass(repr=False) class N(BaseN, _KindChecker): @@ -101,6 +114,10 @@ class `M`, with given plane, angles, and domains. The base class `BaseM` allows node: Node kind: ClassVar[Literal[CommandKind.M]] = dataclasses.field(default=CommandKind.M, init=False) + @override + def reindex(self, f: Callable[[Node], Node]) -> BaseM: + return BaseM(f(self.node)) + @dataclasses.dataclass(repr=False) class M(BaseM, _KindChecker): @@ -156,10 +173,13 @@ def map(self, f: Callable[[Measurement], Measurement]) -> M: ------- M The resulting command. - """ return M(self.node, f(self.measurement), self.s_domain, self.t_domain) + @override + def reindex(self, f: Callable[[Node], Node]) -> M: + return M(f(self.node), self.measurement, set(map(f, self.s_domain)), set(map(f, self.t_domain))) + @dataclasses.dataclass(repr=False) class E(_KindChecker, BaseCommand): @@ -174,6 +194,11 @@ class E(_KindChecker, BaseCommand): nodes: tuple[Node, Node] kind: ClassVar[Literal[CommandKind.E]] = dataclasses.field(default=CommandKind.E, init=False) + @override + def reindex(self, f: Callable[[Node], Node]) -> E: + u, v = self.nodes + return E((f(u), f(v))) + @dataclasses.dataclass(repr=False) class C(_KindChecker, BaseCommand): @@ -191,6 +216,10 @@ class C(_KindChecker, BaseCommand): clifford: Clifford kind: ClassVar[Literal[CommandKind.C]] = dataclasses.field(default=CommandKind.C, init=False) + @override + def reindex(self, f: Callable[[Node], Node]) -> C: + return C(f(self.node), self.clifford) + @dataclasses.dataclass(repr=False) class X(_KindChecker, BaseCommand): @@ -208,6 +237,10 @@ class X(_KindChecker, BaseCommand): domain: set[Node] = dataclasses.field(default_factory=set) kind: ClassVar[Literal[CommandKind.X]] = dataclasses.field(default=CommandKind.X, init=False) + @override + def reindex(self, f: Callable[[Node], Node]) -> X: + return X(f(self.node), set(map(f, self.domain))) + @dataclasses.dataclass(repr=False) class Z(_KindChecker, BaseCommand): @@ -225,6 +258,10 @@ class Z(_KindChecker, BaseCommand): domain: set[Node] = dataclasses.field(default_factory=set) kind: ClassVar[Literal[CommandKind.Z]] = dataclasses.field(default=CommandKind.Z, init=False) + @override + def reindex(self, f: Callable[[Node], Node]) -> Z: + return Z(f(self.node), set(map(f, self.domain))) + @dataclasses.dataclass(repr=False) class S(_KindChecker, BaseCommand): @@ -242,6 +279,10 @@ class S(_KindChecker, BaseCommand): domain: set[Node] = dataclasses.field(default_factory=set) kind: ClassVar[Literal[CommandKind.S]] = dataclasses.field(default=CommandKind.S, init=False) + @override + def reindex(self, f: Callable[[Node], Node]) -> S: + return S(f(self.node), set(map(f, self.domain))) + @dataclasses.dataclass(repr=False) class T(_KindChecker, BaseCommand): @@ -255,6 +296,10 @@ class T(_KindChecker, BaseCommand): kind: ClassVar[Literal[CommandKind.T]] = dataclasses.field(default=CommandKind.T, init=False) + @override + def reindex(self, f: Callable[[Node], Node]) -> T: + return self + class Command: """Grouping of all commands for namespace exposure. diff --git a/graphix/optimization.py b/graphix/optimization.py index 99f33ea72..e3981b763 100644 --- a/graphix/optimization.py +++ b/graphix/optimization.py @@ -610,3 +610,43 @@ def decompose_domain( new_pattern.reorder_output_nodes(pattern.output_nodes) return new_pattern + + +def remove_local_clifford_commands(pattern: Pattern) -> Pattern: + """Return an equivalent pattern where local Clifford commands have been replaced by MBQC commands. + + This function transpiles MBQC+LC patterns into MBQC patterns. + """ + from graphix.pattern import Pattern # noqa: PLC0415 + + nodes = pattern.extract_nodes() + if not nodes: + return pattern + max_node = max(nodes) + new_pattern = Pattern(input_nodes=pattern.input_nodes) + mapping: dict[Node, Node] = {} + + def reindex(node: Node) -> Node: + return mapping.get(node, node) + + for cmd in pattern: + match cmd.kind: + case CommandKind.C: + cmd_node = reindex(cmd.node) + clifford_pattern = cmd.clifford.to_opengraph().to_pattern() + (output_node,) = clifford_pattern.output_nodes + # We avoid using `new_pattern.compose` here because + # pattern composition is linear in the size of each + # pattern, which would make transpilation run in + # quadratic time. + # clifford_pattern satisfies the following properties: + # - The set of input nodes is {0}. + # - The output node is the highest-indexed node. + new_pattern.extend(clifford_pattern.reindex(lambda node: cmd_node if node == 0 else node + max_node)) # noqa: B023 + max_node += output_node + mapping[cmd.node] = max_node + case _: + new_cmd = cmd.reindex(reindex) + new_pattern.add(new_cmd) + new_pattern.reorder_output_nodes(map(reindex, pattern.output_nodes)) + return new_pattern diff --git a/graphix/pattern.py b/graphix/pattern.py index bcc18035b..6e08df1b5 100644 --- a/graphix/pattern.py +++ b/graphix/pattern.py @@ -10,7 +10,8 @@ import enum import itertools import warnings -from collections.abc import Iterable, Iterator +from collections.abc import Iterable, Iterator, Mapping +from collections.abc import Set as AbstractSet from dataclasses import dataclass from enum import Enum from pathlib import Path @@ -34,9 +35,8 @@ from graphix.visualization import GraphVisualizer if TYPE_CHECKING: - from collections.abc import Callable, Container, Iterator, Mapping - from collections.abc import Set as AbstractSet - from typing import Any + from collections.abc import Callable, Collection, Container, Iterator + from typing import Any, TypeVar from numpy.random import Generator @@ -56,6 +56,9 @@ from graphix.states import State from graphix.visualization import DrawKwargs + K = TypeVar("K") + V = TypeVar("V") + _BuiltinBackendState = DensityMatrix | Statevec | MBQCTensorNet @@ -162,6 +165,124 @@ def replace(self, cmds: list[CommandType], input_nodes: list[int] | None = None) self.clear() self.extend(cmds) + def node_mapping( + self, + mapping: Mapping[Node, Node] | None = None, + *, + start: int = 0, + avoids: Collection[Node] | None = None, + preserves: Collection[Node] | None = None, + ) -> dict[Node, Node]: + """Compute an injective mapping of node indices. + + The resulting mapping can be passed to :meth:`reindex`. + Fresh indices are allocated in increasing order, starting at + ``start`` and skipping any values already used. + + Parameters + ---------- + mapping: Mapping[Node, Node] | None, optional + An initial (partial) mapping that will be extended. + A ``ValueError`` is raised if the supplied mapping is not + injective. + start: int, optional + The starting index from which fresh indices are allocated. + avoids: Collection[Node] | None, optional + Indices that must not be assigned to a node. + preserves: Collection[Node] | None, optional + Nodes that should not appear in the mapping. + Their index will not be changed by :meth:`reindex`. + + Returns + ------- + dict[Node, Node] + A dictionary that maps each node of ``self`` (except those + listed in ``preserves``) to a distinct integer index. + + Notes + ----- + The resulting dictionary is guaranteed to be an extension of + ``mapping``, i.e., all entries from ``mapping`` appear in the + resulting dictionary. In particular, the constraints specified + in ``avoids`` and ``preserves`` only affect nodes that are + freshly generated by the function; they do not apply to the + nodes already present in ``mapping``. + """ + nodes = self.extract_nodes() + if mapping is None: + result = {} + else: + result = dict(mapping) + unknown_nodes = [node for node in mapping if node not in nodes] + if unknown_nodes: + unknown_nodes_str = ", ".join(map(str, unknown_nodes)) + raise ValueError(f"Some keys in the initial mapping are not nodes: {unknown_nodes_str}") + duplicates = _duplicates_in_mapping(result) + if duplicates: + reason = ", ".join(f"{antecedents} are mapped to {value}" for value, antecedents in duplicates.items()) + raise ValueError(f"Initial mapping is not injective: {reason}") + used_indices = set(avoids or ()) + used_indices.update(result.values()) + if preserves: + used_indices.update(preserves) + preserves_set = preserves if isinstance(preserves, AbstractSet) else set(preserves) + else: + preserves_set = set() + candidate = start + for node in sorted(nodes): + if node in result or node in preserves_set: + continue + while candidate in used_indices: + candidate += 1 + result[node] = candidate + candidate += 1 + return result + + def reindex( + self, mapping: Callable[[Node], Node] | Mapping[Node, Node] | None = None, *, copy: bool = False + ) -> Pattern: + """Return a pattern whose nodes have been re-indexed using ``f``. + + This method does not verify that ``f`` is injective. The + semantic of the pattern is only preserved when ``f`` is + injective. A non-injective mapping can even break the + runnability of the resulting pattern. + + Parameters + ---------- + mapping : Callable[[Node], Node] | Mapping[Node, Node] | None, optional + A function or a mapping that translates the current node + indices to new ones. Indices that are not present in the + mapping are left unchanged. If ``f`` is omitted, a + default mapping is computed by using :meth:`node_mapping`. + + copy : bool, optional + If ``True``, the current pattern remains unchanged and a + new pattern is returned. The default is ``False``, meaning + that changes are performed in place. + + Returns + ------- + Pattern + The re-indexed pattern. Equal to ``self`` if ``copy`` is ``False``. + """ + if mapping is None: + # Suggested in issue #519 + mapping = self.node_mapping() + func: Callable[[Node], Node] = ( + (lambda node: mapping.get(node, node)) if isinstance(mapping, Mapping) else mapping + ) + new_pattern = Pattern(input_nodes=map(func, self.input_nodes)) + for cmd in self: + new_pattern.add(cmd.reindex(func)) + new_pattern.reorder_output_nodes(map(func, self.output_nodes)) + if copy: + return new_pattern + self.__input_nodes = new_pattern.__input_nodes + self.__seq = new_pattern.__seq + self.__output_nodes = new_pattern.__output_nodes + return self + def compose( self, other: Pattern, mapping: Mapping[int, int], preserve_mapping: bool = False ) -> tuple[Pattern, dict[int, int]]: @@ -235,18 +356,12 @@ def compose( ) shift = max((*nodes_p1, *mapping.values())) + 1 - mapping_sequential = { - node: i for i, node in enumerate(sorted(nodes_p2 - mapping.keys()), start=shift) - } # assigns new labels to nodes in other not specified in mapping - - mapping_complete = {**mapping, **mapping_sequential} - - mapped_inputs = [mapping_complete[n] for n in other.input_nodes] - mapped_outputs = [mapping_complete[n] for n in other.output_nodes] + mapping_complete = other.node_mapping(mapping, start=shift) + other = other.reindex(mapping_complete, copy=True) merged = mapping_values_set.intersection(self.__output_nodes) - inputs = self.__input_nodes + [n for n in mapped_inputs if n not in merged] + inputs = self.__input_nodes + [n for n in other.input_nodes if n not in merged] if preserve_mapping and not (len(merged) == len(other.input_nodes) == len(other.output_nodes)): warnings.warn( @@ -256,32 +371,12 @@ def compose( preserve_mapping = False if preserve_mapping: - io_mapping = { - mapping[i]: mapping_complete[o] for i, o in zip(other.input_nodes, other.output_nodes, strict=True) - } + io_mapping = dict(zip(other.input_nodes, other.output_nodes, strict=True)) outputs = [io_mapping[n] if n in merged else n for n in self.__output_nodes] else: - outputs = [n for n in self.__output_nodes if n not in merged] + mapped_outputs - - def update_command(cmd: CommandType) -> CommandType: - # Shallow copy is enough since the mutable attributes of cmd_new susceptible to change are reassigned - cmd_new = copy.copy(cmd) - - if cmd_new.kind is CommandKind.E: - i, j = cmd_new.nodes - cmd_new.nodes = (mapping_complete[i], mapping_complete[j]) - elif cmd_new.kind is not CommandKind.T: - cmd_new.node = mapping_complete[cmd_new.node] - match cmd_new.kind: - case CommandKind.M: - cmd_new.s_domain = {mapping_complete[i] for i in cmd_new.s_domain} - cmd_new.t_domain = {mapping_complete[i] for i in cmd_new.t_domain} - case CommandKind.X | CommandKind.Z | CommandKind.S: - cmd_new.domain = {mapping_complete[i] for i in cmd_new.domain} - - return cmd_new + outputs = [n for n in self.__output_nodes if n not in merged] + other.output_nodes - seq = self.__seq + [update_command(c) for c in other] + seq = [*self.__seq, *other] p = Pattern(input_nodes=inputs, output_nodes=outputs, cmds=seq) @@ -1737,6 +1832,36 @@ def remove_pauli_measurements( self.__output_nodes = pattern.__output_nodes return self + def remove_local_clifford_commands(self, *, copy: bool = False) -> Pattern: + """Return an equivalent pattern where local Clifford commands have been replaced by MBQC commands. + + See :func:`~optimization.remove_local_clifford_commands` + for more information. + + Parameters + ---------- + copy : bool, optional + If ``True``, the current pattern remains unchanged and a + new pattern is returned. The default is ``False``, meaning + that changes are performed in place. + stacklevel : int, optional + Stack level to use for warnings. Defaults to 1, meaning that warnings + are reported at this function's call site. + + Returns + ------- + Pattern + The pattern in which local Clifford commands have been replaced + by MBQC commands. If ``copy`` is ``False``, the result is + ``self``. + """ + new_pattern = optimization.remove_local_clifford_commands(self) + if copy: + return new_pattern + self.__seq = new_pattern.__seq + self.__output_nodes = new_pattern.__output_nodes + return self + class PatternError(Exception): """Exception subclass to handle pattern errors.""" @@ -1855,3 +1980,60 @@ def shift_outcomes(outcomes: Mapping[int, Outcome], signal_dict: Mapping[int, Ab node: toggle_outcome(outcome) if sum(outcomes[i] for i in signal_dict.get(node, [])) % 2 == 1 else outcome for node, outcome in outcomes.items() } + + +def _reverse_mapping(m: Mapping[K, V]) -> dict[V, list[K]]: + """ + Build a reverse mapping of ``m``. + + For each key-value pair ``k → v`` in the input mapping ``m``, the + returned dictionary contains an entry ``v → [k1, k2, …]`` where the + list holds all keys that were associated with ``v`` in ``m``. + The order of keys inside each list corresponds to their first + appearance while iterating over ``m``. + + Parameters + ---------- + m: Mapping[K, V] + The mapping to be inspected. + + Returns + ------- + dict[V, list[K]] + A new dictionary where each distinct value from ``m`` maps to a list + of all keys that originally pointed to that value. + """ + result: dict[V, list[K]] = {} + + for k, v in m.items(): + if (antecedents := result.get(v)) is not None: + antecedents.append(k) + else: + result[v] = [k] + + return result + + +def _duplicates_in_mapping(m: Mapping[K, V]) -> dict[V, list[K]]: + """ + Return a dictionary of values that appear more than once in ``m``. + + For each value that is mapped to by multiple keys, the resulting + dictionary contains an entry ``value → [key1, key2, …]`` where the list + holds all keys that map to that value (the order follows the + iteration order of ``m.items()``). + + Parameters + ---------- + m : Mapping[K, V] + The mapping to be inspected. + + Returns + ------- + dict[V, list[K]] + A dictionary whose keys are the duplicated values and whose values are + lists of the keys that map to them. If the input mapping is injective, + an empty dictionary is returned. + """ + reversed_mapping = _reverse_mapping(m) + return {v: antecedents for v, antecedents in reversed_mapping.items() if len(antecedents) > 1} diff --git a/tests/test_clifford.py b/tests/test_clifford.py index 6b2bd95b1..06bd981bd 100644 --- a/tests/test_clifford.py +++ b/tests/test_clifford.py @@ -11,9 +11,11 @@ import numpy as np import pytest +from graphix import Command, Pattern from graphix.clifford import Clifford from graphix.fundamentals import IXYZ_VALUES, ComplexUnit, Sign from graphix.pauli import Pauli +from graphix.random_objects import rand_state_vector if TYPE_CHECKING: from numpy.random import Generator @@ -94,3 +96,14 @@ def test_try_from_matrix(self, fx_rng: Generator, c: Clifford) -> None: def test_try_from_matrix_ng(self, fx_rng: Generator) -> None: assert Clifford.try_from_matrix(np.zeros((2, 3))) is None assert Clifford.try_from_matrix(fx_rng.normal(size=(2, 2))) is None + + @pytest.mark.parametrize("c", Clifford) + def test_to_pattern(self, fx_rng: Generator, c: Clifford) -> None: + og = c.to_opengraph() + og.to_bloch().extract_causal_flow() + pattern = og.to_pattern() + pattern_ref = Pattern(input_nodes=[0], cmds=[Command.C(0, c)]) + input_state = rand_state_vector(nqubits=1, rng=fx_rng) + state = pattern.simulate_pattern(input_state=input_state, rng=fx_rng) + state_ref = pattern_ref.simulate_pattern(input_state=input_state, rng=fx_rng) + assert state.isclose(state_ref) diff --git a/tests/test_command.py b/tests/test_command.py index 9d48db4f9..45a7647f8 100644 --- a/tests/test_command.py +++ b/tests/test_command.py @@ -1 +1,42 @@ from __future__ import annotations + +from typing import TYPE_CHECKING + +import pytest + +from graphix import Clifford +from graphix.command import BaseM, BaseN, Command + +if TYPE_CHECKING: + from graphix.command import BaseCommand + + +@pytest.mark.parametrize( + ("cmd1", "cmd2"), + [ + (BaseN(0), BaseN(1)), + (BaseN(2), BaseN(2)), + (Command.N(0), Command.N(1)), + (Command.N(2), Command.N(2)), + (BaseM(0), BaseM(1)), + (BaseM(2), BaseM(2)), + (Command.M(0), Command.M(1)), + (Command.M(2), Command.M(2)), + (Command.E((0, 2)), Command.E((1, 2))), + (Command.E((2, 0)), Command.E((2, 1))), + (Command.C(0, Clifford.H), Command.C(1, Clifford.H)), + (Command.C(2, Clifford.S), Command.C(2, Clifford.S)), + (Command.X(0, {2}), Command.X(1, {2})), + (Command.X(2, {0}), Command.X(2, {1})), + (Command.Z(0, {2}), Command.Z(1, {2})), + (Command.Z(2, {0}), Command.Z(2, {1})), + (Command.S(0, {2}), Command.S(1, {2})), + (Command.S(2, {0}), Command.S(2, {1})), + (Command.T(), Command.T()), + ], +) +def test_reindex(cmd1: BaseCommand, cmd2: BaseCommand) -> None: + def reindex(node: int) -> int: + return 1 if node == 0 else node + + assert cmd1.reindex(reindex) == cmd2 diff --git a/tests/test_db.py b/tests/test_db.py index 93c0127f1..8c34862b1 100644 --- a/tests/test_db.py +++ b/tests/test_db.py @@ -1,19 +1,32 @@ from __future__ import annotations import itertools +from typing import TYPE_CHECKING +import networkx as nx import numpy as np import pytest +from graphix import Command, OpenGraph, Pattern, PauliMeasurement from graphix._db import ( CLIFFORD, CLIFFORD_CONJ, CLIFFORD_HSZ_DECOMPOSITION, CLIFFORD_MEASURE, CLIFFORD_MUL, + CLIFFORD_PAULI_DECOMPOSITION, ) from graphix.clifford import Clifford +from graphix.opengraph import OpenGraphError from graphix.ops import Ops +from graphix.random_objects import rand_state_vector + +if TYPE_CHECKING: + from typing import TypeVar + + from numpy.random import Generator + + T = TypeVar("T") class TestCliffordDB: @@ -58,3 +71,63 @@ def test_safety(self, i: int) -> None: with pytest.raises(ValueError): # Cannot create writeable view v.flags.writeable = True + + +def generate_clifford_pauli_decomposition(rng: Generator) -> tuple[tuple[PauliMeasurement, ...], ...]: + """Compute the value of CLIFFORD_PAULI_DECOMPOSITION. + + This function ensures that the length of the decomposition is + optimal by search exhaustively by increasing length. + """ + pauli_measurements = tuple(PauliMeasurement) + input_states = tuple(rand_state_vector(nqubits=1, rng=rng) for _ in range(10)) + clifford_output_states_ref = tuple( + ( + clifford, + tuple( + Pattern(input_nodes=[0], cmds=[Command.C(0, clifford)]).simulate_pattern(input_state=input_state) + for input_state in input_states + ), + ) + for clifford in Clifford + ) + + patterns: list[tuple[PauliMeasurement, ...] | None] = [None] * len(Clifford) + + def explore(n: int) -> None: + graph = nx.path_graph(n + 1) + for measurement_list in itertools.product(pauli_measurements, repeat=n): + measurements = dict(zip(range(n), measurement_list, strict=True)) + og = OpenGraph(graph=graph, input_nodes=[0], output_nodes=[n], measurements=measurements) + try: + pattern = og.to_pattern() + except OpenGraphError: + continue + for clifford, output_states_ref in clifford_output_states_ref: + if patterns[clifford.value] is not None: + continue + if all( + pattern.simulate_pattern(input_state=input_state, rng=rng).isclose(output_state_ref) + for input_state, output_state_ref in zip(input_states, output_states_ref, strict=True) + ): + patterns[clifford.value] = measurement_list + + for n in range(4): + explore(n) + + if any(pattern is None for pattern in patterns): + raise RuntimeError("Local Cliffords are guaranteed to have a decomposition in 3 or less measured nodes.") + + return tuple(map(unwrap, patterns)) + + +def unwrap(v: T | None) -> T: + """Return ``v`` if it is not ``None``, or raise an exception.""" + if v is None: + raise ValueError("Unexpected `None`.") + return v + + +def test_generate_clifford_pauli_decomposition(fx_rng: Generator) -> None: + clifford_pauli_decomposition = generate_clifford_pauli_decomposition(fx_rng) + assert clifford_pauli_decomposition == CLIFFORD_PAULI_DECOMPOSITION diff --git a/tests/test_optimization.py b/tests/test_optimization.py index 42db33200..eeb611ae6 100644 --- a/tests/test_optimization.py +++ b/tests/test_optimization.py @@ -6,12 +6,12 @@ from numpy.random import PCG64, Generator from graphix.clifford import Clifford -from graphix.command import C, CommandKind, E, M, N, X, Z +from graphix.command import C, Command, CommandKind, E, M, N, X, Z from graphix.fundamentals import ANGLE_PI, Plane from graphix.measurements import Measurement from graphix.optimization import StandardizedPattern, remove_useless_domains from graphix.pattern import Pattern -from graphix.random_objects import rand_circuit +from graphix.random_objects import rand_circuit, rand_state_vector from graphix.states import PlanarState if TYPE_CHECKING: @@ -131,3 +131,32 @@ def test_bug_482() -> None: ) output_pattern = StandardizedPattern.from_pattern(input_pattern).to_space_optimal_pattern() assert input_pattern.output_nodes == output_pattern.output_nodes + + +@pytest.mark.parametrize("jumps", range(1, 11)) +def test_remove_local_clifford_commands(fx_bg: PCG64, jumps: int) -> None: + rng = Generator(fx_bg.jumped(jumps)) + nqubits = 4 + depth = 4 + circuit = rand_circuit(nqubits, depth, rng) + pattern = circuit.transpile().pattern + pattern.remove_pauli_measurements() + assert any(cmd.kind == CommandKind.C for cmd in pattern) + new_pattern = pattern.remove_local_clifford_commands(copy=True) + assert not any(cmd.kind == CommandKind.C for cmd in new_pattern) + input_state = rand_state_vector(nqubits, rng=rng) + state_ref = pattern.simulate_pattern(input_state=input_state, rng=rng) + state = new_pattern.simulate_pattern(input_state=input_state, rng=rng) + assert state.isclose(state_ref) + + +def test_remove_local_clifford_commands_edge_cases() -> None: + pattern = Pattern() + pattern.remove_local_clifford_commands(copy=False) + assert list(pattern) == [] + pattern = Pattern(input_nodes=[0, 1], cmds=[Command.C(0, Clifford.H), Command.E((0, 1))]) + pattern.remove_local_clifford_commands(copy=False) + pattern.check_runnability() + pattern = Pattern(input_nodes=[0, 1], cmds=[Command.C(0, Clifford.H), Command.C(0, Clifford.S)]) + pattern.remove_local_clifford_commands(copy=False) + pattern.check_runnability() diff --git a/tests/test_pattern.py b/tests/test_pattern.py index 6a87c4016..2a040794d 100644 --- a/tests/test_pattern.py +++ b/tests/test_pattern.py @@ -1077,6 +1077,31 @@ def test_perform_pauli_pushing(self) -> None: original_pattern.to_bloch().perform_pauli_pushing() assert original_pattern.perform_pauli_pushing(copy=True, standardize=True).is_standard() + def test_node_mapping(self) -> None: + pattern = Pattern(input_nodes=[3], cmds=[N(1), E((1, 3)), M(3), C(1, Clifford.H), X(1, {3}), Z(1, {3})]) + assert pattern.node_mapping() == {1: 0, 3: 1} + assert pattern.node_mapping({3: 1}, start=1) == {1: 2, 3: 1} + assert pattern.node_mapping(start=1, preserves={1}) == {3: 2} + with pytest.raises(ValueError, match="Initial mapping is not injective"): + pattern.node_mapping({1: 0, 3: 0}) + with pytest.raises(ValueError, match="Some keys in the initial mapping are not nodes"): + pattern.node_mapping({2: 0}) + + def test_reindex(self) -> None: + pattern = Pattern(input_nodes=[3], cmds=[N(1), E((1, 3)), M(3), C(1, Clifford.H), X(1, {3}), Z(1, {3})]) + pattern_copy = pattern.copy() + pattern_reindexed = pattern.reindex(copy=True) + assert pattern_reindexed.input_nodes == [1] + assert list(pattern_reindexed) == [N(0), E((0, 1)), M(1), C(0, Clifford.H), X(0, {1}), Z(0, {1})] + assert pattern_reindexed.output_nodes == [0] + assert pattern.input_nodes == pattern_copy.input_nodes + assert list(pattern) == list(pattern_copy) + assert pattern.output_nodes == pattern_copy.output_nodes + pattern.reindex() + assert pattern.input_nodes == pattern_reindexed.input_nodes + assert list(pattern) == list(pattern_reindexed) + assert pattern.output_nodes == pattern_reindexed.output_nodes + def test_extract_opengraph_standardization(self) -> None: p = Pattern(cmds=[N(0), C(0, Clifford.H), M(0, Measurement.XY(0.3))]) og = p.extract_opengraph()