From e075e4b73d8121d8728643b34c0bd5de54a3e5d8 Mon Sep 17 00:00:00 2001 From: Copilot <198982749+Copilot@users.noreply.github.com> Date: Wed, 4 Mar 2026 16:21:37 -0500 Subject: [PATCH 1/3] Add numpy-style docstrings and reorganize API reference (#82) * Initial plan * Initial plan for docstring updates and autodocs reorganization Co-authored-by: ConnorStoneAstro <78555321+ConnorStoneAstro@users.noreply.github.com> * Add numpy-style module-level docstring to caskade __init__.py Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Update docstrings in base.py to comprehensive numpy-style format Improve or add numpy-style docstrings for public API in Node, Memo, and helper functions. Fix incorrect parameter in graphviz docstring and malformed code example in link docstring. No code logic changes. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Add numpy-style docstrings to backend.py Add comprehensive numpy-style docstrings to the Backend class, its public methods and properties, the ArrayLike type alias, and the module-level backend instance. No code logic changes. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Add comprehensive numpy-style docstrings to Param class and valid_shape Add or improve docstrings for all public properties and methods in param.py including valid_shape(), Param.dynamic, Param.pointer, Param.static, Param.node_type, Param.to_dynamic, Param.to_static, Param.to_pointer, Param.shape, Param.batched, Param.batch_shape, Param.group, Param.dtype, Param.device, Param.value, Param.npvalue, Param.cyclic, Param.valid, and Param.is_valid. All docstrings follow numpy-style with Parameters, Returns, and Raises sections as appropriate. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Update docstrings in module.py to numpy-style Add or improve docstrings for Module.__init__, all_params, param_order, dynamic, and static properties. All other docstrings were already adequate. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Update decorators.py docstrings to numpy-style format - forward: Add long description about parameter passing and activation, reorder sections (Parameters, Returns, Examples) - active_cache: Convert **WARNING** to Warnings section, Note to Notes section, Example:: to proper Examples section, use double backticks for inline code references Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * docs: update context.py docstrings to comprehensive numpy-style Update class docstrings for ActiveContext, ValidContext, and OverrideParam to use numpy-style format with Parameters, Raises, and Examples sections. No code logic changes. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Add numpy-style docstrings to collection.py Add comprehensive docstrings to NodeCollection, NodeTuple, and NodeList classes including class-level docs, property docs, and method docs. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Update docstrings in mixins.py to comprehensive numpy-style Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Convert docstrings in utils.py from Google-style to numpy-style Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * docs: expand error and warning docstrings to numpy-style Update docstrings in errors.py and warnings.py to use numpy-style format with brief descriptions of when each error/warning is raised. Add Parameters sections to classes with __init__ methods. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Expand test() docstring to numpy-style with Examples section Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Add numpy-style docstrings to all caskade classes/methods and reorganize autodocs Co-authored-by: ConnorStoneAstro <78555321+ConnorStoneAstro@users.noreply.github.com> * Remove test_graph artifact and add to gitignore Co-authored-by: ConnorStoneAstro <78555321+ConnorStoneAstro@users.noreply.github.com> * Final docstring and autodocs update Co-authored-by: ConnorStoneAstro <78555321+ConnorStoneAstro@users.noreply.github.com> * Fix gitignore: properly add test_graph and remove artifact Co-authored-by: ConnorStoneAstro <78555321+ConnorStoneAstro@users.noreply.github.com> * Address PR review feedback on docstrings for ValidContext, OverrideParam, to_valid/from_valid, and group Co-authored-by: ConnorStoneAstro <78555321+ConnorStoneAstro@users.noreply.github.com> * Change "Raised" to "Issued" for warning description Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * more explicit backend reference in autodoc Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * update topological ordering description to mention recursive nature Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Split autodocs into separate TOC pages; move Node/Backend to bottom; update ArrayLike type Co-authored-by: ConnorStoneAstro <78555321+ConnorStoneAstro@users.noreply.github.com> --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: ConnorStoneAstro <78555321+ConnorStoneAstro@users.noreply.github.com> Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> Co-authored-by: Connor Stone, PhD Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- .gitignore | 5 +- docs/source/_toc.yml | 12 ++ docs/source/api/backend.rst | 7 + docs/source/api/collections.rst | 14 ++ docs/source/api/context.rst | 11 ++ docs/source/api/decorators.rst | 8 ++ docs/source/api/exceptions.rst | 38 ++++++ docs/source/api/memo.rst | 5 + docs/source/api/module.rst | 6 + docs/source/api/node.rst | 6 + docs/source/api/param.rst | 6 + docs/source/api/utilities.rst | 7 + docs/source/api/warnings.rst | 11 ++ docs/source/modules.rst | 12 +- src/caskade/__init__.py | 19 +++ src/caskade/backend.py | 133 +++++++++++++++++- src/caskade/base.py | 227 +++++++++++++++++++++++++++---- src/caskade/collection.py | 84 ++++++++++++ src/caskade/context.py | 81 ++++++++++- src/caskade/decorators.py | 35 +++-- src/caskade/errors.py | 114 ++++++++++++++-- src/caskade/mixins.py | 116 +++++++++++++++- src/caskade/module.py | 47 ++++++- src/caskade/param.py | 234 ++++++++++++++++++++++++++++++-- src/caskade/tests.py | 15 +- src/caskade/utils.py | 55 +++++--- src/caskade/warnings.py | 31 ++++- 27 files changed, 1233 insertions(+), 106 deletions(-) create mode 100644 docs/source/api/backend.rst create mode 100644 docs/source/api/collections.rst create mode 100644 docs/source/api/context.rst create mode 100644 docs/source/api/decorators.rst create mode 100644 docs/source/api/exceptions.rst create mode 100644 docs/source/api/memo.rst create mode 100644 docs/source/api/module.rst create mode 100644 docs/source/api/node.rst create mode 100644 docs/source/api/param.rst create mode 100644 docs/source/api/utilities.rst create mode 100644 docs/source/api/warnings.rst diff --git a/.gitignore b/.gitignore index a56f3b3..589cb21 100644 --- a/.gitignore +++ b/.gitignore @@ -165,4 +165,7 @@ cython_debug/ # HDF5 files **.h5 -**settings.json \ No newline at end of file +**settings.json + +# Test artifacts +test_graph diff --git a/docs/source/_toc.yml b/docs/source/_toc.yml index b523cf9..0b584a4 100644 --- a/docs/source/_toc.yml +++ b/docs/source/_toc.yml @@ -12,6 +12,18 @@ chapters: - file: contributing - file: license - file: modules + sections: + - file: api/module + - file: api/param + - file: api/decorators + - file: api/context + - file: api/collections + - file: api/memo + - file: api/exceptions + - file: api/warnings + - file: api/utilities + - file: api/node + - file: api/backend # - file: frequently_asked_questions # - file: citation # - file: glossary diff --git a/docs/source/api/backend.rst b/docs/source/api/backend.rst new file mode 100644 index 0000000..8702978 --- /dev/null +++ b/docs/source/api/backend.rst @@ -0,0 +1,7 @@ +Backend +======= + +.. autoclass:: caskade.backend.Backend + :members: + +.. autodata:: caskade.backend.backend diff --git a/docs/source/api/collections.rst b/docs/source/api/collections.rst new file mode 100644 index 0000000..6495a5b --- /dev/null +++ b/docs/source/api/collections.rst @@ -0,0 +1,14 @@ +Collections +=========== + +.. autoclass:: caskade.NodeCollection + :members: + :show-inheritance: + +.. autoclass:: caskade.NodeList + :members: + :show-inheritance: + +.. autoclass:: caskade.NodeTuple + :members: + :show-inheritance: diff --git a/docs/source/api/context.rst b/docs/source/api/context.rst new file mode 100644 index 0000000..b64354c --- /dev/null +++ b/docs/source/api/context.rst @@ -0,0 +1,11 @@ +Context Managers +================ + +.. autoclass:: caskade.ActiveContext + :members: + +.. autoclass:: caskade.ValidContext + :members: + +.. autoclass:: caskade.OverrideParam + :members: diff --git a/docs/source/api/decorators.rst b/docs/source/api/decorators.rst new file mode 100644 index 0000000..66e1359 --- /dev/null +++ b/docs/source/api/decorators.rst @@ -0,0 +1,8 @@ +Decorators +========== + +.. autofunction:: caskade.forward + +.. autoclass:: caskade.active_cache + :members: + :show-inheritance: diff --git a/docs/source/api/exceptions.rst b/docs/source/api/exceptions.rst new file mode 100644 index 0000000..302447e --- /dev/null +++ b/docs/source/api/exceptions.rst @@ -0,0 +1,38 @@ +Exceptions +========== + +.. autoclass:: caskade.CaskadeException + :show-inheritance: + +.. autoclass:: caskade.GraphError + :show-inheritance: + +.. autoclass:: caskade.BackendError + :show-inheritance: + +.. autoclass:: caskade.LinkToAttributeError + :show-inheritance: + +.. autoclass:: caskade.NodeConfigurationError + :show-inheritance: + +.. autoclass:: caskade.ParamConfigurationError + :show-inheritance: + +.. autoclass:: caskade.ParamTypeError + :show-inheritance: + +.. autoclass:: caskade.ActiveStateError + :show-inheritance: + +.. autoclass:: caskade.FillParamsError + :show-inheritance: + +.. autoclass:: caskade.FillParamsArrayError + :show-inheritance: + +.. autoclass:: caskade.FillParamsSequenceError + :show-inheritance: + +.. autoclass:: caskade.FillParamsMappingError + :show-inheritance: diff --git a/docs/source/api/memo.rst b/docs/source/api/memo.rst new file mode 100644 index 0000000..e02ab8b --- /dev/null +++ b/docs/source/api/memo.rst @@ -0,0 +1,5 @@ +Graph Communication +=================== + +.. autoclass:: caskade.Memo + :members: diff --git a/docs/source/api/module.rst b/docs/source/api/module.rst new file mode 100644 index 0000000..6127798 --- /dev/null +++ b/docs/source/api/module.rst @@ -0,0 +1,6 @@ +Module +====== + +.. autoclass:: caskade.Module + :members: + :show-inheritance: diff --git a/docs/source/api/node.rst b/docs/source/api/node.rst new file mode 100644 index 0000000..e904226 --- /dev/null +++ b/docs/source/api/node.rst @@ -0,0 +1,6 @@ +Node +==== + +.. autoclass:: caskade.Node + :members: + :show-inheritance: diff --git a/docs/source/api/param.rst b/docs/source/api/param.rst new file mode 100644 index 0000000..cc131e2 --- /dev/null +++ b/docs/source/api/param.rst @@ -0,0 +1,6 @@ +Param +===== + +.. autoclass:: caskade.Param + :members: + :show-inheritance: diff --git a/docs/source/api/utilities.rst b/docs/source/api/utilities.rst new file mode 100644 index 0000000..63ea77b --- /dev/null +++ b/docs/source/api/utilities.rst @@ -0,0 +1,7 @@ +Utilities & Testing +=================== + +.. automodule:: caskade.utils + :members: + +.. autofunction:: caskade.test diff --git a/docs/source/api/warnings.rst b/docs/source/api/warnings.rst new file mode 100644 index 0000000..f6a13fe --- /dev/null +++ b/docs/source/api/warnings.rst @@ -0,0 +1,11 @@ +Warnings +======== + +.. autoclass:: caskade.CaskadeWarning + :show-inheritance: + +.. autoclass:: caskade.InvalidValueWarning + :show-inheritance: + +.. autoclass:: caskade.SaveStateWarning + :show-inheritance: diff --git a/docs/source/modules.rst b/docs/source/modules.rst index f86abc1..a2bad5f 100644 --- a/docs/source/modules.rst +++ b/docs/source/modules.rst @@ -1,9 +1,5 @@ -caskade docstrings -================== +API Reference +============= -Someday I'll make a nicely formatted interface to all this, but for now, here's -a list of all the modules and their functions. You can just search for what you -need here and get more detailed information. - -.. automodule:: caskade - :members: \ No newline at end of file +This section documents the public API of ``caskade``. Browse the pages +below for detailed class and function documentation. \ No newline at end of file diff --git a/src/caskade/__init__.py b/src/caskade/__init__.py index 1021295..58ed274 100644 --- a/src/caskade/__init__.py +++ b/src/caskade/__init__.py @@ -1,3 +1,22 @@ +""" +caskade - Build scientific simulators as directed acyclic graphs. + +Caskade provides a framework for constructing modular scientific simulators +by composing computational steps into a directed acyclic graph (DAG). It +handles parameter management, caching, and context-dependent evaluation. + +Main Public API +--------------- +Node : Base class for building computational graph nodes. +Module : High-level container for assembling simulator components. +Param : Declare and manage parameters within nodes. +forward : Decorator to define the forward computation of a node. +active_cache : Decorator for caching intermediate results. +NodeCollection, NodeList, NodeTuple : Collections of nodes. +ActiveContext, ValidContext, OverrideParam : Context managers for evaluation. +backend : Array backend abstraction (NumPy, PyTorch, etc.). +utils : Utility functions. +""" from ._version import version as VERSION # noqa from .base import Node, Memo diff --git a/src/caskade/backend.py b/src/caskade/backend.py index 7ac1220..2000b91 100644 --- a/src/caskade/backend.py +++ b/src/caskade/backend.py @@ -1,23 +1,63 @@ +"""Backend abstraction for array operations. + +Provides a unified :class:`Backend` class that delegates array creation and +manipulation to one of three libraries: **torch**, **jax**, or **numpy**. +A module-level :data:`backend` instance is created on import and serves as the +primary interface for users. +""" import os import importlib -from typing import Annotated +from typing import Union from torch import Tensor import numpy as np from . import utils -ArrayLike = Annotated[ - Tensor, - "One of: torch.Tensor, numpy.ndarray, jax.numpy.ndarray depending on the chosen backend.", -] +#: Type alias for array types across backends. +#: Resolves to ``torch.Tensor``, ``numpy.ndarray``, or ``jax.numpy.ndarray`` +#: depending on the active backend. +ArrayLike = Union[Tensor, np.ndarray, "jnp.Array"] class Backend: + """Unified interface for array operations across torch, jax, and numpy. + + Provides a single API for creating and manipulating arrays regardless of + the underlying library. Methods such as ``make_array``, ``concatenate``, + ``to``, ``sigmoid``, and ``logit`` are dynamically bound when the backend + is set, delegating to the appropriate library-specific implementation. + + Parameters + ---------- + backend : str, optional + Backend name: ``"torch"``, ``"jax"``, or ``"numpy"``. If ``None``, + reads from the ``CASKADE_BACKEND`` environment variable, defaulting + to ``"torch"``. + + Examples + -------- + Use the module-level ``backend`` instance to switch backends:: + + from caskade import backend + backend.backend = "numpy" + arr = backend.make_array([1.0, 2.0, 3.0]) + """ + def __init__(self, backend=None): + """Initialize the backend. + + Parameters + ---------- + backend : str, optional + Backend name: ``"torch"``, ``"jax"``, or ``"numpy"``. If ``None``, + reads from the ``CASKADE_BACKEND`` environment variable, defaulting + to ``"torch"``. + """ self.backend = backend @property def backend(self): + """str : Name of the active backend (``"torch"``, ``"jax"``, or ``"numpy"``).""" return self._backend @backend.setter @@ -85,6 +125,22 @@ def setup_numpy(self): @property def array_type(self): + """type : The array class for the active backend. + + Returns ``torch.Tensor``, ``jax.numpy.ndarray``, or ``numpy.ndarray`` + depending on the current backend. Useful for ``isinstance`` checks. + + Returns + ------- + type + The array class used by the active backend. + + Examples + -------- + :: + + isinstance(my_array, backend.array_type) + """ return self._array_type() def _make_array_torch(self, array, dtype=None, device=None): @@ -169,18 +225,80 @@ def _to_numpy_numpy(self, array): return array def any(self, array): + """Test whether any element evaluates to ``True``. + + Parameters + ---------- + array : ArrayLike + Input array. + + Returns + ------- + ArrayLike + Scalar result; ``True`` if any element is non-zero. + """ return self.module.any(array) def all(self, array): + """Test whether all elements evaluate to ``True``. + + Parameters + ---------- + array : ArrayLike + Input array. + + Returns + ------- + ArrayLike + Scalar result; ``True`` if every element is non-zero. + """ return self.module.all(array) def log(self, array): + """Compute the natural logarithm element-wise. + + Parameters + ---------- + array : ArrayLike + Input array. + + Returns + ------- + ArrayLike + Element-wise natural logarithm of the input. + """ return self.module.log(array) def exp(self, array): + """Compute the exponential element-wise. + + Parameters + ---------- + array : ArrayLike + Input array. + + Returns + ------- + ArrayLike + Element-wise exponential of the input. + """ return self.module.exp(array) def sum(self, array, axis=None): + """Sum array elements over a given axis. + + Parameters + ---------- + array : ArrayLike + Input array. + axis : int or None, optional + Axis along which to sum. If ``None``, sums all elements. + + Returns + ------- + ArrayLike + Sum of elements. + """ return self.module.sum(array, axis=axis) def _sigmoid_torch(self, array): @@ -202,4 +320,9 @@ def _logit_numpy(self, array): return np.log(array / (1 - array)) +#: Module-level :class:`Backend` instance used as the default entry point. +#: Import and configure this object to switch backends globally:: +#: +#: from caskade import backend +#: backend.backend = "numpy" backend = Backend() diff --git a/src/caskade/base.py b/src/caskade/base.py index 85e3b4c..c46550f 100644 --- a/src/caskade/base.py +++ b/src/caskade/base.py @@ -16,7 +16,22 @@ def attrsetter(obj, attr, value): - """Set an attribute on an object.""" + """ + Set an attribute on an object, supporting nested dot-separated paths. + + If the value is the string ``"NONE"``, it is converted to ``None``. + Dot-separated attribute paths (e.g. ``"a.b.c"``) are resolved + recursively so that the final attribute is set on the correct object. + + Parameters + ---------- + obj : object + The target object on which to set the attribute. + attr : str + The attribute name or dot-separated path (e.g. ``"sub.attr"``). + value : Any + The value to assign. The string ``"NONE"`` is treated as ``None``. + """ if isinstance(value, str) and value == "NONE": value = None if "." in attr: @@ -27,11 +42,31 @@ def attrsetter(obj, attr, value): def is_valid_name(name): + """ + Check whether a string is a valid Python identifier and not a keyword. + + Parameters + ---------- + name : str + The candidate name to validate. + + Returns + ------- + bool + ``True`` if *name* is a valid Python identifier and is not a + reserved keyword, ``False`` otherwise. + """ return name.isidentifier() and not keyword.iskeyword(name) class meta: - """Meta information for a ``Node`` object.""" + """ + Container for meta information attached to a ``Node`` object. + + Each ``Node`` instance carries a ``meta`` attribute that is an instance + of this class. Arbitrary attributes may be set on it to store + auxiliary metadata without polluting the node's own namespace. + """ pass @@ -67,6 +102,21 @@ def __init__( link: Optional[Union["Node", tuple["Node"]]] = None, description: str = "", ): + """ + Initialise a new ``Node``. + + Parameters + ---------- + name : str, optional + Human-readable name for this node. Must be a valid Python + identifier and not a reserved keyword. Defaults to the class + name. + link : Node or tuple of Node, optional + One or more child nodes to link immediately after construction. + Each child is linked using its ``name`` as the key. + description : str, optional + Free-form text describing the purpose of this node. + """ if name is None: name = self.__class__.__name__ if not isinstance(name, str): @@ -89,18 +139,22 @@ def __init__( @property def name(self) -> str: + """str : The name of this node.""" return self._name @property def children(self) -> dict[str, "Node"]: + """dict[str, Node] : Mapping of link keys to child nodes.""" return self._children @property def parents(self) -> set["Node"]: + """set[Node] : Set of parent nodes that link to this node.""" return self._parents @property def subgraphs(self) -> set["Node"]: + """set[Node] : Subset of children linked hierarchically.""" return self._subgraphs def _link(self, key: str, child: "Node"): @@ -149,15 +203,18 @@ def link( Examples -------- - Example making some ``Node`` objects and then linking/unlinking them. + Example making some ``Node`` objects and then linking/unlinking them, demonstrating multiple ways to link/unlink:: - n1 = Node() n2 = Node() + n1 = Node() + n2 = Node() - n1.link("subnode", n2) # may use any str as the key + n1.link("subnode", n2) # may use any str as the key n1.unlink("subnode") - # Alternately, link by object n1.link(n2) n1.unlink(n2) + # Alternatively, link by object + n1.link(n2) + n1.unlink(n2) """ if ( isinstance(key, (tuple, list)) @@ -189,10 +246,18 @@ def hierarchical_link(self, key: str, child: "Node"): Parameters ---------- - key: (str) + key : str The key to link the child node with. - child: (Node) + child : Node The child ``Node`` object to link to. + + Examples + -------- + :: + + parent = Node(name="parent") + child = Node(name="child") + parent.hierarchical_link("child", child) """ self._subgraphs.add(child) @@ -208,7 +273,21 @@ def _unlink(self, key: str): self.update_graph() def unlink(self, key: Union[str, "Node", list, tuple]): - """Unlink the current ``Node`` object from another ``Node`` object which is a child.""" + """ + Unlink one or more child nodes from this node. + + Parameters + ---------- + key : str, Node, list, or tuple + Identifier of the child(ren) to remove. May be a link key + string, the child ``Node`` object itself, or a list/tuple of + keys or nodes to unlink in bulk. + + Raises + ------ + GraphError + If the graph is currently active. + """ if isinstance(key, Node): for node in self.children: if self.children[node] is key: @@ -223,7 +302,16 @@ def unlink(self, key: Union[str, "Node", list, tuple]): def topological_ordering(self) -> tuple["Node"]: """ Return a topological ordering of the graph below the current node. - Uses Iterative Deepening DFS (Post-Order) to resolve dependencies. + + Performs a recursive depth-first search with post-order traversal to + resolve dependencies. The result starts with this node and proceeds + to its descendants in dependency order. + + Returns + ------- + tuple[Node] + All nodes reachable from (and including) this node, ordered so + that every parent appears before its children. """ visited = set() stack = [] @@ -254,17 +342,29 @@ def update_graph(self): @property def active(self) -> bool: + """bool : ``True`` if the node is currently in an active simulation run.""" return any(memo.startswith("active") for memo in self._memos) @property def online(self) -> bool: + """bool : ``True`` if the node is online within a hierarchical sub-graph.""" return any(memo.endswith("_active") for memo in self._memos) @property def memos(self) -> set[str]: + """set[str] : Current set of memo strings held by this node.""" return self._memos def add_memo(self, memo): + """ + Add a memo string and propagate it to all children. + + Parameters + ---------- + memo : str + The memo message to add. Children in ``subgraphs`` receive + the memo with the child name appended (``memo|child_name``). + """ self._memos.add(memo) # Propagate memo to children @@ -272,6 +372,15 @@ def add_memo(self, memo): child.add_memo(memo + (f"|{child.name}" if child in self.subgraphs else "")) def remove_memo(self, memo): + """ + Remove a memo string and propagate removal to all children. + + Parameters + ---------- + memo : str + The memo message to remove. The same propagation rules as + ``add_memo`` apply. + """ self._memos.discard(memo) # Propagate removal to children @@ -405,7 +514,26 @@ def _append_state_hdf5(self, h5group): child._append_state_hdf5(h5group[key]) def append_state(self, saveto: Union[str, "File"]): - """Append the state of the node and its children to an existing HDF5 file.""" + """ + Append the current state to an existing HDF5 file. + + The file must have been previously created by ``save_state`` with + ``appendable=True``. The graph structure in the file is verified + before appending. + + Parameters + ---------- + saveto : str or File + Path to an HDF5 file (``'.h5'`` or ``'.hdf5'``) or an open + HDF5 ``File`` object. + + Raises + ------ + GraphError + If the graph structure no longer matches the file. + NotImplementedError + If the file path does not end with a supported extension. + """ if isinstance(saveto, str): if saveto.endswith(".h5") or saveto.endswith(".hdf5"): with h5py.File(saveto, "a") as h5file: @@ -446,7 +574,28 @@ def _load_state_hdf5(self, h5group, index: int = -1, _done_load: set = None): child._load_state_hdf5(h5group[key], index=index, _done_load=_done_load) def load_state(self, loadfrom: Union[str, "File"], index: int = -1, **kwargs): - """Load the state of the node and its children.""" + """ + Load node state (and children) from an HDF5 file. + + Parameters + ---------- + loadfrom : str or File + Path to an HDF5 file (``'.h5'`` or ``'.hdf5'``) or an open + HDF5 ``File`` object. + index : int, optional + Sample index to load when the file was saved in appendable + mode. Defaults to ``-1`` (last sample). + **kwargs + Additional keyword arguments forwarded to ``h5py.File`` + (e.g. ``driver``). + + Raises + ------ + GraphError + If the graph structure no longer matches the file. + NotImplementedError + If the file path does not end with a supported extension. + """ if isinstance(loadfrom, str): if loadfrom.endswith(".h5") or loadfrom.endswith(".hdf5"): with h5py.File(loadfrom, "r", **{"driver": "core", **kwargs}) as h5file: @@ -465,17 +614,20 @@ def graphviz_style(self): return {"style": "solid", "color": "black", "shape": "circle"} def graphviz(self, saveto: Optional[str] = None) -> "graphviz.Digraph": - """Return a graphviz object representing the graph below the current - node in the DAG. + """ + Return a graphviz ``Digraph`` representing the DAG below this node. Parameters ---------- - top_down: (bool, optional) - Whether to draw the graph top-down (current node at top) or - bottom-up (current node at bottom). Defaults to True. - saveto: (Optional[str], optional) - If provided, save the graph to this file. The file extension - determines the format (e.g. '.pdf', '.png'). Defaults to None. + saveto : str, optional + If provided, save the rendered graph to this file path. The + file extension determines the output format (e.g. ``'.pdf'``, + ``'.png'``). Defaults to ``None``. + + Returns + ------- + graphviz.Digraph + The constructed directed-graph object. """ import graphviz # noqa @@ -514,8 +666,17 @@ def node_str(self): return f"{self.name}|{self.node_type}" def graph_dict(self) -> dict[str, dict]: - """Return a dictionary representation of the graph below the current - node.""" + """ + Return a nested dictionary representation of the graph. + + Each key is a string of the form ``"name|node_type"`` and the value + is a dict containing the same structure for that node's children. + + Returns + ------- + dict[str, dict] + Nested dictionary mirroring the DAG hierarchy. + """ rep = self.node_str graph = { rep: {}, @@ -525,7 +686,27 @@ def graph_dict(self) -> dict[str, dict]: return graph def graph_print(self, dag: dict, depth: int = 0, indent: int = 4, result: str = "") -> str: - """Print the graph dictionary in a human-readable format.""" + """ + Recursively render a graph dictionary as an indented string. + + Parameters + ---------- + dag : dict[str, dict] + A nested dictionary as returned by ``graph_dict``. + depth : int, optional + Current indentation depth (used during recursion). Defaults + to ``0``. + indent : int, optional + Number of spaces per indentation level. Defaults to ``4``. + result : str, optional + Accumulator string (used during recursion). Defaults to + ``""``. + + Returns + ------- + str + A human-readable, indented representation of the graph. + """ for key in dag: result = f"{result}{' ' * indent * depth}{key}\n" result = self.graph_print(dag[key], depth + 1, indent, result) + "\n" diff --git a/src/caskade/collection.py b/src/caskade/collection.py index 5e4eab3..782ceb9 100644 --- a/src/caskade/collection.py +++ b/src/caskade/collection.py @@ -4,6 +4,13 @@ class NodeCollection(Node, GetSetValues): + """Base mixin for collections of nodes that track parameters. + + Provides shared functionality for traversing, querying, and converting + parameters within a graph of nodes. Subclasses such as ``NodeTuple`` and + ``NodeList`` combine this mixin with a standard Python sequence type. + """ + def to_dynamic(self, children_only=True): """Change all parameters to dynamic parameters. @@ -34,20 +41,51 @@ def to_static(self, children_only=True): @property def dynamic_params(self) -> tuple[Param]: + """All dynamic parameters in the graph below this node. + + Returns + ------- + tuple of Param + Dynamic (non-static, non-pointer) parameters found via + topological ordering. + """ T = self.topological_ordering() return tuple(filter(lambda n: isinstance(n, Param) and n.dynamic, T)) @property def dynamic_param_groups(self) -> tuple[int]: + """Sorted unique group identifiers of all dynamic parameters. + + Returns + ------- + tuple of int + Sorted group indices present among the dynamic parameters. + """ return tuple(sorted(set(p.group for p in self.dynamic_params))) @property def static_params(self) -> tuple[Param]: + """All static parameters in the graph below this node. + + Returns + ------- + tuple of Param + Static (non-dynamic, non-pointer) parameters found via + topological ordering. + """ T = self.topological_ordering() return tuple(filter(lambda n: isinstance(n, Param) and n.static, T)) @property def pointer_params(self) -> tuple[Param]: + """All pointer parameters in the graph below this node. + + Returns + ------- + tuple of Param + Parameters that act as pointers to other parameters, found via + topological ordering. + """ T = self.topological_ordering() return tuple(filter(lambda n: isinstance(n, Param) and n.pointer, T)) @@ -59,10 +97,24 @@ def deepcopy(self): @property def dynamic(self): + """Whether any node in this collection has dynamic parameters. + + Returns + ------- + bool + ``True`` if at least one contained node is dynamic. + """ return any(node.dynamic for node in self) @property def static(self): + """Whether all nodes in this collection are static. + + Returns + ------- + bool + ``True`` if no contained node is dynamic. + """ return not self.dynamic def __mul__(self, other): @@ -79,6 +131,19 @@ def __hash__(self): class NodeTuple(NodeCollection, tuple): + """Immutable, ordered collection of nodes. + + Behaves like a standard ``tuple`` but also participates in the caskade + node graph. All elements must be ``Node`` instances and are automatically + linked as children upon construction. + + Parameters + ---------- + iterable : iterable of Node, optional + Nodes to include in the tuple. + name : str, optional + Human-readable name for this collection node. + """ def __init__(self, iterable=None, name=None): tuple.__init__(iterable) @@ -105,6 +170,19 @@ def __add__(self, other): class NodeList(NodeCollection, list): + """Mutable, ordered collection of nodes. + + Behaves like a standard ``list`` but also participates in the caskade + node graph. All elements must be ``Node`` instances. Graph links are + automatically updated whenever the list is modified. + + Parameters + ---------- + iterable : iterable of Node, optional + Nodes to include in the list. Defaults to an empty iterable. + name : str, optional + Human-readable name for this collection node. + """ def __init__(self, iterable=(), name=None): list.__init__(self, iterable) @@ -128,32 +206,38 @@ def _link_nodes(self): self.link(node) def append(self, node): + """Append a node to the list and update graph links.""" self._unlink_nodes() super().append(node) self._link_nodes() def insert(self, index, node): + """Insert a node at the given index and update graph links.""" self._unlink_nodes() super().insert(index, node) self._link_nodes() def extend(self, iterable): + """Extend the list with nodes from an iterable and update graph links.""" self._unlink_nodes() super().extend(iterable) self._link_nodes() def clear(self): + """Remove all nodes from the list and update graph links.""" self._unlink_nodes() super().clear() self._link_nodes() def pop(self, index=-1): + """Remove and return a node at the given index, updating graph links.""" self._unlink_nodes() node = super().pop(index) self._link_nodes() return node def remove(self, value): + """Remove the first occurrence of a node and update graph links.""" self._unlink_nodes() super().remove(value) self._link_nodes() diff --git a/src/caskade/context.py b/src/caskade/context.py index 522fdf7..35ee7d0 100644 --- a/src/caskade/context.py +++ b/src/caskade/context.py @@ -5,8 +5,31 @@ class ActiveContext: """ - Context manager to activate a module for a simulation. Only inside an - ActiveContext is it possible to fill/clear the dynamic and live parameters. + Context manager to activate a module for a simulation. + + Only inside an ``ActiveContext`` is it possible to fill or clear the + dynamic and live parameters. On entry, the module is marked as active + (or its current parameter state is saved if already active). On exit, + the state is restored. + + Parameters + ---------- + module : Module + The module to activate for the duration of the context. + + Raises + ------ + ActiveStateError + If the module is already running a simulation (``module.online`` + is ``True``). + + Examples + -------- + Activate a module, fill parameters, and run a forward pass:: + + with ActiveContext(my_module): + my_module.fill_params(params) + result = my_module.my_forward(x) """ def __init__(self, module: Module): @@ -34,8 +57,27 @@ def __exit__(self, exc_type, exc_value, traceback): class ValidContext: """ - Context manager to set valid values for parameters. Only inside a - ValidContext will parameters automatically be assumed valid. + Context manager that transforms parameter values to an unconstrained space. + + Inside a ``ValidContext``, all parameter values are automatically + mapped into the range ``(-inf, inf)`` via each parameter's + ``to_valid`` / ``from_valid`` transformations. This is useful when + interfacing with samplers or optimizers that expect unconstrained + parameters—any value they propose will be mapped back into the + parameter's original valid range on exit. + + Parameters + ---------- + module : Module + The module whose parameters should be transformed. + + Examples + -------- + Get unconstrained parameter values for use with an optimizer:: + + with ValidContext(my_module): + unconstrained_params = my_module.get_values() + # unconstrained_params live in (-inf, inf) """ def __init__(self, module: Module): @@ -51,8 +93,35 @@ def __exit__(self, exc_type, exc_value, traceback): class OverrideParam: """ - Context manager to override a parameter value. Only inside an - OverrideParam will the parameter be set to the new value. + Context manager to override a parameter value. + + Only inside an ``OverrideParam`` will the parameter be set to the new + value. The original value (and the values of any parent pointer + parameters) are saved on entry and restored on exit. + + Parameters + ---------- + param : Param + The parameter whose value should be temporarily overridden. + value : object + The temporary value to assign to *param*. + + Examples + -------- + Override a parameter inside a ``@forward`` method so that it uses + ``new_value`` regardless of what was passed via ``params``:: + + class MySim(Module): + def __init__(self): + super().__init__() + self.a = Param("a", None) + self.b = Param("b", None) + + @forward + def __call__(self, x, a=None, b=None): + with OverrideParam(self.b, 5.0): + # b will always be 5.0 here, ignoring params + return x + a + self.b.value """ def __init__(self, param: Param, value): diff --git a/src/caskade/decorators.py b/src/caskade/decorators.py index 6a6d6de..f4120d3 100644 --- a/src/caskade/decorators.py +++ b/src/caskade/decorators.py @@ -12,11 +12,20 @@ def forward(method): """ Decorator to define a forward method for a module. + Manages parameter passing and activation for the decorated method. When + called, it automatically fills keyword arguments from the module's + ``Param`` children and handles parameter overrides and active context. + Parameters ---------- method: (Callable) The forward method to be decorated. + Returns + ------- + Callable + The decorated forward method. + Examples -------- Standard usage of the forward decorator:: @@ -35,11 +44,6 @@ def example_func(self, x, b=None): E = ExampleSim(a=1, b=None, c=3) print(E.example_func(4, params=[5])) # Output: 10 - - Returns - ------- - Callable - The decorated forward method. """ # Get arguments from function signature @@ -102,16 +106,21 @@ class active_cache: active simulation run. Once calculated, subsequent calls to the decorated method will return the stored value, ignoring any arguments passed to it. - **WARNING**: - If the method is called multiple times with different arguments in one - simulation, the cached result will still be returned, which may lead to - unexpected behavior. Use with caution! + Warnings + -------- + If the method is called multiple times with different arguments in one + simulation, the cached result will still be returned, which may lead to + unexpected behavior. Use with caution! - Note: - If you are stacking multiple decorators on a method (such as `@forward` - or `@jax.jit`), `@active_cache` MUST be the outermost (top) decorator. + Notes + ----- + If you are stacking multiple decorators on a method (such as ``@forward`` + or ``@jax.jit``), ``@active_cache`` MUST be the outermost (top) decorator. + + Examples + -------- + :: - Example:: class FluxModel(Module): def __init__(self, nodes, x, M): super().__init__() diff --git a/src/caskade/errors.py b/src/caskade/errors.py index cb0fe4f..bf8665f 100644 --- a/src/caskade/errors.py +++ b/src/caskade/errors.py @@ -5,43 +5,104 @@ class CaskadeException(Exception): - """Base class for all exceptions in ``caskade``.""" + """ + Base class for all exceptions in ``caskade``. + + All custom exceptions raised by ``caskade`` inherit from this class, + allowing users to catch any ``caskade``-specific error with a single + except clause. + """ class GraphError(CaskadeException): - """Class for graph exceptions in ``caskade``.""" + """ + Exception for graph-related errors in ``caskade``. + + Raised when an operation on the computational graph is invalid, such as + creating cycles or referencing nonexistent nodes. + """ class BackendError(CaskadeException): - """Class for exceptions related to the backend in ``caskade``.""" + """ + Exception for backend-related errors in ``caskade``. + + Raised when the selected numerical backend encounters an unsupported + operation or configuration issue. + """ class LinkToAttributeError(GraphError): - """Class for exceptions related to linking to an attribute in ``caskade``.""" + """ + Exception raised when linking to an attribute fails. + + Raised when an attempt is made to create a link to a node attribute + that does not exist or is not a valid link target. + """ class NodeConfigurationError(CaskadeException): - """Class for node configuration exceptions in ``caskade``.""" + """ + Exception for node configuration errors in ``caskade``. + + Raised when a node is configured with invalid or incompatible settings. + """ class ParamConfigurationError(NodeConfigurationError): - """Class for parameter configuration exceptions in ``caskade``.""" + """ + Exception for parameter configuration errors in ``caskade``. + + Raised when a parameter is defined with an invalid shape, type, or + constraint. + """ class ParamTypeError(CaskadeException): - """Class for exceptions related to the type of a parameter in ``caskade``.""" + """ + Exception for parameter type errors in ``caskade``. + + Raised when a value assigned to a parameter does not match its + expected type. + """ class ActiveStateError(CaskadeException): - """Class for exceptions related to the active state of a node in ``caskade``.""" + """ + Exception for active-state errors in ``caskade``. + + Raised when an operation requires a node to be in a particular active + state (enabled or disabled) and that condition is not met. + """ class FillParamsError(CaskadeException): - """Class for exceptions related to filling parameters in ``caskade``""" + """ + Base exception for errors when filling parameters in ``caskade``. + + Raised when the input data provided to fill node parameters is + invalid. Subclasses handle specific input types (array, sequence, + mapping). + """ class FillParamsArrayError(FillParamsError): - """Class for exceptions related to filling parameters with ArrayLike objects in ``caskade``.""" + """ + Exception raised when filling parameters with an array fails. + + Raised when the shape of the input array does not match the total + number of flattened parameters registered on a node. + + Parameters + ---------- + name : str + Name of the node whose parameters are being filled. + input_params : ArrayLike + The input array that was provided. + params : tuple of Param + Registered parameters whose shapes are compared against the + input. + """ def __init__(self, name, input_params, params): fullnumel = sum(max(1, prod(p.shape)) for p in params) @@ -59,7 +120,21 @@ def __init__(self, name, input_params, params): class FillParamsSequenceError(FillParamsError): - """Class for exceptions related to filling parameters with a sequence (list, tuple, etc.) in ``caskade``.""" + """ + Exception raised when filling parameters with a sequence fails. + + Raised when the length of the input sequence does not match the + number of dynamic parameters registered on a node. + + Parameters + ---------- + name : str + Name of the node whose parameters are being filled. + input_params : sequence + The input sequence (list, tuple, etc.) that was provided. + dynamic_params : tuple of Param + Registered dynamic parameters expected by the node. + """ def __init__(self, name, input_params, dynamic_params): message = dedent( @@ -74,7 +149,22 @@ def __init__(self, name, input_params, dynamic_params): class FillParamsMappingError(FillParamsError): - """Class for exceptions related to filling parameters with a mapping (dict) in ``caskade``.""" + """ + Exception raised when filling parameters with a mapping fails. + + Raised when a key in the input dictionary does not correspond to any + registered child node. + + Parameters + ---------- + name : str + Name of the node whose parameters are being filled. + children : dict + Dictionary of registered child nodes. + missing_key : str, optional + The key from the input mapping that was not found among the + node's children. + """ def __init__(self, name, children, missing_key=None): message = dedent( diff --git a/src/caskade/mixins.py b/src/caskade/mixins.py index 6b5723d..bddf1f9 100644 --- a/src/caskade/mixins.py +++ b/src/caskade/mixins.py @@ -15,6 +15,14 @@ class GetSetValues: + """Mixin providing methods for getting and setting parameter values. + + Provides array, list, and dict interfaces for reading and writing + dynamic (or static) parameter values on a ``Module`` or + ``NodeCollection``. Also includes helpers for locating parameters + by index and for converting between raw and valid (transformed) + parameter spaces. + """ @property def valid_context(self) -> bool: @@ -115,7 +123,32 @@ def _set_values( def set_values( self, params: Union[ArrayLike, Sequence, Mapping], dynamic=True, attribute="value" ): - """Fill the dynamic values of the module with the input values from params.""" + """Fill parameter values of the module from the provided data. + + Parameters + ---------- + params : Union[ArrayLike, Sequence, Mapping] + Values to assign to the parameters. Accepted formats: + + * **ArrayLike** – a flat (or batched) array whose last dimension + is concatenated parameter values in topological order. + * **Sequence** – one element per parameter, matched by position. + * **Mapping** – keys matching child names, values being the + parameter data (may be nested). + + When multiple dynamic parameter groups exist, ``params`` should + be a sequence of per-group containers. + dynamic : bool, optional + If ``True`` (default), sets dynamic parameters; otherwise sets + static parameters. + attribute : str, optional + The ``Param`` attribute to write to, by default ``"value"``. + + Raises + ------ + ActiveStateError + If the module is currently in an active (tracing) state. + """ if self.active: raise ActiveStateError(f"Cannot fill dynamic values when Module {self.name} is active") @@ -149,6 +182,35 @@ def _check_values(self, param_list, scheme): def get_values( self, scheme="array", dynamic=True, attribute="value", group: Optional[int] = None ) -> Union[ArrayLike, list[ArrayLike], dict[str, Union[dict, ArrayLike]]]: + """Retrieve parameter values from the module. + + Parameters + ---------- + scheme : str, optional + Output format, one of ``"array"`` (default), ``"list"``, or + ``"dict"``. + + * ``"array"`` / ``"tensor"`` – returns a single flat array with + all parameter values concatenated along the last axis. + * ``"list"`` – returns a list of raw parameter values. + * ``"dict"`` – returns a nested dictionary mirroring the graph + structure. + dynamic : bool, optional + If ``True`` (default), retrieves dynamic parameters; otherwise + retrieves static parameters. + attribute : str, optional + The ``Param`` attribute to read from, by default ``"value"``. + group : int or None, optional + Restrict to a specific parameter group. When ``None`` (default) + and multiple groups exist, returns a list of per-group results. + + Returns + ------- + Union[ArrayLike, list[ArrayLike], dict[str, Union[dict, ArrayLike]]] + Parameter values in the format specified by *scheme*. When + multiple groups exist and *group* is ``None``, a list of + per-group results is returned. + """ if len(self.dynamic_param_groups) > 1 and group is None: values = [] for g in self.dynamic_param_groups: @@ -374,7 +436,32 @@ def _transform_params(self, node, init_params, param_list, transform_attr): def to_valid( self, params: Union[ArrayLike, Sequence, Mapping], param_list=None, group=None ) -> Union[ArrayLike, Sequence, Mapping]: - """Convert input params to valid params.""" + """Map parameter values from their natural range to an unconstrained space. + + Takes parameter values that lie within each parameter's valid range + (e.g. 0–1 for an axis ratio) and maps them into the unconstrained + domain ``(-inf, inf)``. The inverse mapping :meth:`from_valid` will + map any value in ``(-inf, inf)`` back into the original valid range. + This is useful for interfacing with samplers and optimizers that + require unconstrained parameters. + + Parameters + ---------- + params : Union[ArrayLike, Sequence, Mapping] + Raw parameter values in any supported format (array, sequence, + or mapping). + param_list : tuple of Param or None, optional + Subset of parameters to transform. Defaults to all dynamic + parameters. + group : int or None, optional + Restrict to a specific parameter group. When ``None`` and + multiple groups exist, all groups are transformed. + + Returns + ------- + Union[ArrayLike, Sequence, Mapping] + Transformed values in the same format as *params*. + """ if param_list is None: param_list = self.dynamic_params if len(self.dynamic_param_groups) > 1: @@ -393,7 +480,30 @@ def to_valid( def from_valid( self, valid_params: Union[ArrayLike, Sequence, Mapping], param_list=None, group=None ) -> Union[ArrayLike, Sequence, Mapping]: - """Convert valid params to input params.""" + """Map parameter values from the unconstrained space back to their natural range. + + Takes values in the unconstrained domain ``(-inf, inf)`` (as + produced by :meth:`to_valid` or proposed by an optimizer/sampler) + and maps them back into each parameter's original valid range + (e.g. 0–1 for an axis ratio). + + Parameters + ---------- + valid_params : Union[ArrayLike, Sequence, Mapping] + Parameter values in the unconstrained ``(-inf, inf)`` domain, + in any supported format (array, sequence, or mapping). + param_list : tuple of Param or None, optional + Subset of parameters to transform. Defaults to all dynamic + parameters. + group : int or None, optional + Restrict to a specific parameter group. When ``None`` and + multiple groups exist, all groups are transformed. + + Returns + ------- + Union[ArrayLike, Sequence, Mapping] + Inverse-transformed values in the same format as *valid_params*. + """ if param_list is None: param_list = self.dynamic_params if len(self.dynamic_param_groups) > 1: diff --git a/src/caskade/module.py b/src/caskade/module.py index 774faa8..64d4fe0 100644 --- a/src/caskade/module.py +++ b/src/caskade/module.py @@ -60,6 +60,17 @@ def otherfun(self, x, c = None): ) # These tuples will not be converted to NodeTuple objects def __init__(self, name: Optional[str] = None, **kwargs): + """ + Initialize a Module node. + + Parameters + ---------- + name : str, optional + The name of this module node. If not provided, a name is + automatically assigned by the base ``Node`` class. + **kwargs + Additional keyword arguments passed to the ``Node`` base class. + """ self.dynamic_params = () self.pointer_params = () self.static_params = () @@ -74,6 +85,14 @@ def graphviz_style(self): @property def all_params(self): + """ + All parameters below this module in the DAG. + + Returns + ------- + tuple of Param + Concatenation of static, dynamic, and pointer parameters. + """ return self.static_params + self.dynamic_params + self.pointer_params def update_graph(self): @@ -94,6 +113,17 @@ def update_graph(self): super().update_graph() def param_order(self): + """ + Return a human-readable string of dynamic parameter ordering. + + Each line corresponds to a parameter group and lists the parameters + in the format ``parent_name: param_name``. + + Returns + ------- + str + Multi-line string describing the dynamic parameter order. + """ res = [] for g in self.dynamic_param_groups: res.append( @@ -109,11 +139,26 @@ def param_order(self): @property def dynamic(self) -> bool: - """Return True if the module has dynamic parameters as direct children.""" + """ + Return True if the module has dynamic parameters as direct children. + + Returns + ------- + bool + True if any direct children are dynamic parameters. + """ return any(isinstance(n, Param) and n.dynamic for n in self.children.values()) @property def static(self) -> bool: + """ + Return True if the module has no dynamic parameters as direct children. + + Returns + ------- + bool + True if none of the direct children are dynamic parameters. + """ return not self.dynamic def to_dynamic(self, children_only=True): diff --git a/src/caskade/param.py b/src/caskade/param.py index 7fe79a4..c9f8f41 100644 --- a/src/caskade/param.py +++ b/src/caskade/param.py @@ -12,6 +12,27 @@ def valid_shape(batch_shape, shape, value_shape): + """Check whether a value's shape is compatible with a parameter's shape. + + Validates that ``value_shape`` is consistent with the declared ``shape`` + and optional ``batch_shape``. Dimensions set to ``None`` in ``shape`` + act as wildcards and match any size. + + Parameters + ---------- + batch_shape : tuple of int or None + Leading batch dimensions, or ``None`` if the parameter is not batched. + shape : tuple of int or None, or None + Expected event dimensions. Individual entries may be ``None`` + (wildcard). If the entire argument is ``None``, any shape is accepted. + value_shape : tuple of int + The actual shape of the value to validate. + + Returns + ------- + bool + ``True`` if the shapes are compatible, ``False`` otherwise. + """ # No shape to compare if shape is None: return True @@ -127,14 +148,37 @@ def __init__( @property def dynamic(self) -> bool: + """Whether this parameter is dynamic. + + Returns + ------- + bool + ``True`` if the parameter's value is provided at runtime. + """ return "dynamic" in self.node_type @property def pointer(self) -> bool: + """Whether this parameter is a pointer. + + Returns + ------- + bool + ``True`` if the parameter points to another ``Param`` or a + callable that is evaluated at runtime. + """ return "pointer" in self.node_type @property def static(self) -> bool: + """Whether this parameter is static. + + Returns + ------- + bool + ``True`` if the parameter holds a fixed value that does not + change at runtime. + """ return "static" in self.node_type @property @@ -170,6 +214,13 @@ def graphviz_style(self): @property def node_type(self): + """The current type of this parameter node. + + Returns + ------- + str + One of ``"static"``, ``"dynamic"``, or ``"pointer"``. + """ return self._node_type @node_type.setter @@ -180,8 +231,26 @@ def node_type(self, value): self.update_graph() def to_dynamic(self, value=NULL): - """Change this parameter to a dynamic parameter. If a value is provided, - this will be set as the dynamic value.""" + """Change this parameter to a dynamic parameter. + + If a value is provided, it is stored as the default dynamic value. + When called without arguments the existing value (if any) is kept. + + Parameters + ---------- + value : ArrayLike, float, int, None, or sentinel, optional + The default value for the dynamic parameter. Must not be a + ``Param`` or callable. By default the current value is retained. + + Raises + ------ + ActiveStateError + If the parameter is currently active. + ParamTypeError + If *value* is a ``Param`` or callable. + ParamConfigurationError + If the value shape does not match the declared shape. + """ # While active no value can be set if self.active: raise ActiveStateError(f"Cannot set parameter {self.name} dynamic value while active.") @@ -214,8 +283,26 @@ def to_dynamic(self, value=NULL): self.is_valid() def to_static(self, value=NULL): - """Change this parameter to a static parameter. If a value is provided - this will be set as the static value.""" + """Change this parameter to a static parameter. + + If a value is provided, it is stored as the fixed static value. + When called without arguments the existing value (if any) is kept. + + Parameters + ---------- + value : ArrayLike, float, int, None, or sentinel, optional + The constant value for the static parameter. Must not be a + ``Param`` or callable. By default the current value is retained. + + Raises + ------ + ActiveStateError + If the parameter is currently active. + ParamTypeError + If *value* is a ``Param`` or callable. + ParamConfigurationError + If the value shape does not match the declared shape. + """ # While active no value can be set if self.active: raise ActiveStateError(f"Cannot set parameter {self.name} static value while active.") @@ -249,11 +336,27 @@ def to_static(self, value=NULL): self.node_type = "static" def to_pointer(self, value, link=()): - """Change this parameter to a pointer parameter. If a value is provided - this will be set as the pointer. Either provide a Param object to point - to its value, or provide a callable function to be called at runtime. It - is also possible to provide a tuple of nodes to link to while creating - the pointer.""" + """Change this parameter to a pointer parameter. + + The parameter's value will be computed at runtime by dereferencing + another ``Param`` or by calling a user-supplied function. + + Parameters + ---------- + value : Param or callable + A ``Param`` whose value will be mirrored, or a callable + ``f(param) -> ArrayLike`` evaluated at runtime. + link : Node or tuple of Node, optional + Additional nodes to link into the graph when creating the + pointer. Defaults to an empty tuple. + + Raises + ------ + ActiveStateError + If the parameter is currently active. + ParamTypeError + If *value* is not a ``Param`` or callable. + """ # While active no value can be set if self.active: raise ActiveStateError(f"Cannot set parameter {self.name} to pointer while active") @@ -273,6 +376,17 @@ def to_pointer(self, value, link=()): @property def shape(self) -> tuple[int, ...]: + """The event (non-batch) shape of the parameter value. + + Wildcard dimensions (``None``) in the declared shape are resolved + using the current value. If no shape was declared, the shape of the + current value is returned directly. + + Returns + ------- + tuple of int + The resolved shape of the parameter. + """ value = self.value # 1. Handle cases where no shape template is defined if self._shape is None: @@ -313,10 +427,28 @@ def shape(self, shape: Optional[Iterable]): @property def batched(self) -> bool: + """Whether this parameter carries batch dimensions. + + Returns + ------- + bool + ``True`` if ``batch_shape`` is non-empty. + """ return len(self.batch_shape) > 0 @property def batch_shape(self) -> tuple[int, ...]: + """The batch dimensions of the parameter value. + + Batch dimensions are the leading dimensions of the value that + precede the event ``shape``. If an explicit batch shape was set it + is returned directly; otherwise it is inferred from the value. + + Returns + ------- + tuple of int + The batch shape, or ``()`` if the parameter is not batched. + """ if self._batch_shape is not None: return self._batch_shape try: @@ -337,6 +469,18 @@ def batch_shape(self, batch_shape: tuple[int]): @property def group(self) -> int: + """The group index of this parameter. + + Parameters that share the same group index are collected together + into a single ``params`` object when calling a simulator's + ``@forward`` method, as well as when using ``get_values`` or + ``set_values``. + + Returns + ------- + int + The group index (default ``0``). + """ return self._group @group.setter @@ -349,6 +493,16 @@ def group(self, group: int): @property def dtype(self) -> Optional[str]: + """The data type of the parameter value. + + If no explicit dtype was set, the dtype is inferred from the + current value. + + Returns + ------- + dtype or None + The data type, or ``None`` if unknown. + """ if self._dtype is None: try: return self.value.dtype @@ -358,6 +512,16 @@ def dtype(self) -> Optional[str]: @property def device(self) -> Optional[str]: + """The device on which the parameter value resides. + + If no explicit device was set, the device is inferred from the + current value. + + Returns + ------- + device or None + The device, or ``None`` if unknown. + """ if self._device is None: try: return self.value.device @@ -367,6 +531,17 @@ def device(self) -> Optional[str]: @property def value(self) -> Union[ArrayLike, None]: + """The current value of the parameter. + + For static and dynamic parameters the stored value is returned. + For pointer parameters the linked callable is evaluated. During an + active simulation the result is cached. + + Returns + ------- + ArrayLike or None + The parameter value, or ``None`` if no value has been set. + """ if self._value is not None: return self._value if self.pointer: @@ -394,6 +569,13 @@ def value(self, value): @property def npvalue(self) -> ndarray: + """The current value converted to a NumPy array. + + Returns + ------- + numpy.ndarray + The value as a NumPy ``ndarray``. + """ return backend.to_numpy(self.value) def to(self, device=None, dtype=None) -> "Param": @@ -429,6 +611,16 @@ def to(self, device=None, dtype=None) -> "Param": @property def cyclic(self) -> bool: + """Whether the parameter has cyclic (periodic) boundary conditions. + + When ``True``, values wrap around the ``valid`` range (e.g. an + angle from 0 to 2π). + + Returns + ------- + bool + ``True`` if the parameter is cyclic. + """ return self._cyclic @cyclic.setter @@ -527,6 +719,14 @@ def _load_state_hdf5(self, h5group, index: int = -1, _done_load: set = None): @property def valid(self) -> tuple[Optional[ArrayLike], Optional[ArrayLike]]: + """The valid range of the parameter value. + + Returns + ------- + tuple of (ArrayLike or None, ArrayLike or None) + ``(lower_bound, upper_bound)``. Either bound may be ``None`` + indicating no constraint on that side. + """ return self._valid @valid.setter @@ -572,7 +772,21 @@ def valid(self, valid: tuple[Union[ArrayLike, float, int, None]]): self.is_valid() def is_valid(self, value=None) -> bool: - """Check if a given value is valid given this parameters allowed (valid) range.""" + """Check whether a value lies within the allowed range. + + Parameters + ---------- + value : ArrayLike or None, optional + The value to check. If ``None`` (default), the parameter's + current value is used. + + Returns + ------- + bool + ``True`` if the value is within the valid range or if no + constraints are set. ``False`` otherwise; a warning is also + emitted. + """ if self.cyclic or self.pointer: return True if value is None: diff --git a/src/caskade/tests.py b/src/caskade/tests.py index 6626e25..1c54646 100644 --- a/src/caskade/tests.py +++ b/src/caskade/tests.py @@ -41,6 +41,19 @@ def __call__(self, d=None, e=None, f=None): def test(): - """Basic integration test of caskade to ensure that the library is functioning correctly.""" + """ + Run a basic integration test to verify caskade is installed and working. + + Exercises core functionality including Module and Param creation, + parameter linking, and forward method execution. + + Examples + -------- + :: + + import caskade + caskade.test() + # Output: Success! + """ _test_full_integration() print("Success!") diff --git a/src/caskade/utils.py b/src/caskade/utils.py index b709650..73860cd 100644 --- a/src/caskade/utils.py +++ b/src/caskade/utils.py @@ -9,14 +9,19 @@ def broadcast_cat_torch(tensors, dim=-1): It behaves like torch.cat, but first broadcasts the tensors to match on all dimensions EXCEPT the concatenation dimension. - Args: - tensors (sequence of Tensors): Tensors to concatenate. - dim (int): The dimension along which to concatenate. - Must be a negative index to ensure consistency across - tensors of different ranks (e.g., -1 for the last dimension). - - Returns: - Tensor: The concatenated tensor. + Parameters + ---------- + tensors : sequence of Tensors + Tensors to concatenate. + dim : int + The dimension along which to concatenate. + Must be a negative index to ensure consistency across + tensors of different ranks (e.g., -1 for the last dimension). + + Returns + ------- + Tensor + The concatenated tensor. """ if not tensors: raise ValueError("tensors argument must be a non-empty sequence") @@ -94,12 +99,17 @@ def broadcast_cat_jax(arrays, dim=-1): Behaves like jnp.concatenate, but first broadcasts the arrays to match on all dimensions EXCEPT the concatenation dimension. - Args: - arrays (sequence of jnp.ndarray): Arrays to concatenate. - dim (int): The dimension along which to concatenate. - - Returns: - jnp.ndarray: The concatenated array. + Parameters + ---------- + arrays : sequence of jnp.ndarray + Arrays to concatenate. + dim : int + The dimension along which to concatenate. + + Returns + ------- + jnp.ndarray + The concatenated array. """ import jax.numpy as jnp @@ -161,12 +171,17 @@ def broadcast_cat_numpy(arrays, dim=-1): Behaves like np.concatenate, but first broadcasts the arrays to match on all dimensions EXCEPT the concatenation dimension. - Args: - arrays (sequence of np.ndarray): Arrays to concatenate. - dim (int): The dimension along which to concatenate. - - Returns: - np.ndarray: The concatenated array. + Parameters + ---------- + arrays : sequence of np.ndarray + Arrays to concatenate. + dim : int + The dimension along which to concatenate. + + Returns + ------- + np.ndarray + The concatenated array. """ if not arrays: raise ValueError("arrays argument must be a non-empty sequence") diff --git a/src/caskade/warnings.py b/src/caskade/warnings.py index 11829c0..03954e2 100644 --- a/src/caskade/warnings.py +++ b/src/caskade/warnings.py @@ -4,11 +4,31 @@ class CaskadeWarning(Warning): - """Base warning for ``caskade``.""" + """ + Base warning class for ``caskade``. + + All custom warnings issued by ``caskade`` inherit from this class, + allowing users to filter or catch any ``caskade``-specific warning. + """ class InvalidValueWarning(CaskadeWarning): - """Warning for values which fall outside the valid range.""" + """ + Warning issued when a parameter value is outside its valid range. + + Indicates that the assigned value may cause errors or unexpected + behavior during computation. + + Parameters + ---------- + name : str + Name of the parameter with the out-of-range value. + value : ArrayLike + The value that was assigned. + valid : tuple + A ``(lower, upper)`` tuple defining the valid range, where + ``None`` represents negative or positive infinity. + """ def __init__(self, name, value, valid): message = dedent( @@ -20,4 +40,9 @@ def __init__(self, name, value, valid): class SaveStateWarning(CaskadeWarning): - """Warning for when an issue occurs when a state is saved.""" + """ + Warning issued when saving state encounters a non-fatal problem. + + Issued when the state serialization completes but with potential data + loss or format issues that the user should be aware of. + """ From 55c32e9c1ed219f24961c657424820350614e68f Mon Sep 17 00:00:00 2001 From: Copilot <198982749+Copilot@users.noreply.github.com> Date: Thu, 5 Mar 2026 09:41:59 -0500 Subject: [PATCH 2/3] feat: Add NodeDict collection for named node mappings (#80) * Initial plan * feat: add NodeDict collection class for dictionary of nodes Co-authored-by: ConnorStoneAstro <78555321+ConnorStoneAstro@users.noreply.github.com> * test: add dynamic property and update-with-kwargs tests for NodeDict Co-authored-by: ConnorStoneAstro <78555321+ConnorStoneAstro@users.noreply.github.com> * docs: add Node Collections section to AdvancedGuide tutorial Co-authored-by: ConnorStoneAstro <78555321+ConnorStoneAstro@users.noreply.github.com> * clean up node collections in advanced notebook * make setdefault behaviour safer to improper inputs Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * improve code rendering in tutorial notebook Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * include name passed to super in example Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * fix: wrap all NodeList/NodeDict mutation methods in try/finally to ensure re-linking Co-authored-by: ConnorStoneAstro <78555321+ConnorStoneAstro@users.noreply.github.com> * add docstrings, handle setdefault error cases * add popitem test --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: ConnorStoneAstro <78555321+ConnorStoneAstro@users.noreply.github.com> Co-authored-by: Connor Stone Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- .gitignore | 3 +- docs/source/notebooks/AdvancedGuide.ipynb | 84 ++++++++++++ src/caskade/__init__.py | 3 +- src/caskade/collection.py | 157 +++++++++++++++++++--- src/caskade/module.py | 5 +- tests/test_collection.py | 142 +++++++++++++++++++ 6 files changed, 372 insertions(+), 22 deletions(-) diff --git a/.gitignore b/.gitignore index 589cb21..ec2786b 100644 --- a/.gitignore +++ b/.gitignore @@ -166,6 +166,5 @@ cython_debug/ **.h5 **settings.json - -# Test artifacts test_graph +test_graph_dict diff --git a/docs/source/notebooks/AdvancedGuide.ipynb b/docs/source/notebooks/AdvancedGuide.ipynb index bd46ab3..9e4fe73 100644 --- a/docs/source/notebooks/AdvancedGuide.ipynb +++ b/docs/source/notebooks/AdvancedGuide.ipynb @@ -702,6 +702,90 @@ "Note that the groups will automatically order themselves by sorting the group integers, so you can easily pick where each group goes in the params. You can even place groups ahead of group 0 by using negative integers." ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Node Collections: `NodeList`, `NodeTuple`, and `NodeDict`\n", + "\n", + "Sometimes you want to group a subset of nodes together without wrapping them inside a new `Module`. `caskade` provides three lightweight collection types for this purpose:\n", + "\n", + "- **`NodeList`** – a mutable, ordered list of nodes (supports `append`, `insert`, `pop`, etc.)\n", + "- **`NodeTuple`** – an immutable, ordered tuple of nodes\n", + "- **`NodeDict`** – a mutable mapping of `str → node`, with attribute-style access\n", + "\n", + "All three behave like their built-in Python counterparts while also being `Node` objects themselves, so they participate fully in the caskade graph.\n", + "\n", + "When you assign a plain Python `list`, `tuple`, or `dict` of nodes as an attribute of a `Module`, `caskade` automatically converts it to the appropriate collection type:\n", + "\n", + "```python\n", + "self.my_params = [p1, p2, p3] # becomes NodeList\n", + "self.my_params = (p1, p2, p3) # becomes NodeTuple\n", + "self.my_params = {'a': p1, 'b': p2} # becomes NodeDict\n", + "```\n", + "\n", + "This means you never have to construct `NodeList` / `NodeTuple` / `NodeDict` explicitly inside a `Module` — just assign the plain collection and caskade handles the rest.\n", + "\n", + "### Use cases\n", + "\n", + "- **Gibbs-style sampling** – keep one `NodeList` of the parameters you currently want to sample (dynamic) and another for the parameters you want to hold fixed (static). Calling `.to_dynamic()` / `.to_static()` on the collection flips the whole subset at once.\n", + "- **Selective saving / updating** – build a `NodeDict` of the parameters you care about and call `get_values()` / `set_values()` on just that subset, without touching the rest of the graph.\n", + "- **Quick inspection** – group modules or params of interest into a collection and iterate over them for printing, plotting, or diagnostics without restructuring the model.\n", + "\n", + "Note that `set_values` and `get_values` also work on Node collections." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Collections can be constructed directly\n", + "G = Gaussian(\"G\", x0=5, y0=5, q=0.5, phi=0.0, sigma=1.0, I0=1.0)\n", + "\n", + "# NodeList: mutable, ordered\n", + "position_params = ck.NodeList([G.x0, G.y0])\n", + "print(\"NodeList:\", position_params)\n", + "\n", + "# NodeTuple: immutable, ordered\n", + "shape_params = ck.NodeTuple((G.q, G.phi, G.sigma))\n", + "print(\"NodeTuple:\", shape_params)\n", + "\n", + "# NodeDict: mutable, keyed\n", + "named_params = ck.NodeDict({\"x0\": G.x0, \"y0\": G.y0, \"sigma\": G.sigma})\n", + "print(\"NodeDict:\", named_params)\n", + "print(\"Attribute access:\", named_params[\"x0\"]) # same as named_params.x0" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Auto-conversion when assigned to a Module attribute\n", + "class GaussianWithCollections(ck.Module):\n", + " def __init__(self, name, submod):\n", + " super().__init__(name)\n", + " self.g = submod\n", + " # plain list/dict – caskade converts them automatically\n", + " self.position = [submod.x0, submod.y0] # -> NodeList\n", + " self.named = {\"q\": submod.q, \"phi\": submod.phi} # -> NodeDict\n", + "\n", + " @ck.forward\n", + " def __call__(self, x, y):\n", + " return self.g(x, y)\n", + "\n", + "\n", + "G2 = Gaussian(\"G2\", x0=5, y0=5, q=0.5, phi=0.0, sigma=1.0, I0=1.0)\n", + "gc = GaussianWithCollections(\"gc\", G2)\n", + "print(type(gc.position)) # NodeList\n", + "print(type(gc.named)) # NodeDict\n", + "gc.position.to_dynamic() # can call collection methods\n", + "display(gc.graphviz())" + ] + }, { "cell_type": "markdown", "metadata": {}, diff --git a/src/caskade/__init__.py b/src/caskade/__init__.py index 58ed274..cb22394 100644 --- a/src/caskade/__init__.py +++ b/src/caskade/__init__.py @@ -25,7 +25,7 @@ from .decorators import forward, active_cache from .module import Module from .param import Param -from .collection import NodeCollection, NodeList, NodeTuple +from .collection import NodeCollection, NodeList, NodeTuple, NodeDict from .tests import test from .errors import ( CaskadeException, @@ -58,6 +58,7 @@ "NodeCollection", "NodeList", "NodeTuple", + "NodeDict", "ActiveContext", "ValidContext", "OverrideParam", diff --git a/src/caskade/collection.py b/src/caskade/collection.py index 782ceb9..c8bdddb 100644 --- a/src/caskade/collection.py +++ b/src/caskade/collection.py @@ -208,39 +208,51 @@ def _link_nodes(self): def append(self, node): """Append a node to the list and update graph links.""" self._unlink_nodes() - super().append(node) - self._link_nodes() + try: + super().append(node) + finally: + self._link_nodes() def insert(self, index, node): """Insert a node at the given index and update graph links.""" self._unlink_nodes() - super().insert(index, node) - self._link_nodes() + try: + super().insert(index, node) + finally: + self._link_nodes() def extend(self, iterable): """Extend the list with nodes from an iterable and update graph links.""" self._unlink_nodes() - super().extend(iterable) - self._link_nodes() + try: + super().extend(iterable) + finally: + self._link_nodes() def clear(self): """Remove all nodes from the list and update graph links.""" self._unlink_nodes() - super().clear() - self._link_nodes() + try: + super().clear() + finally: + self._link_nodes() def pop(self, index=-1): """Remove and return a node at the given index, updating graph links.""" self._unlink_nodes() - node = super().pop(index) - self._link_nodes() + try: + node = super().pop(index) + finally: + self._link_nodes() return node def remove(self, value): """Remove the first occurrence of a node and update graph links.""" self._unlink_nodes() - super().remove(value) - self._link_nodes() + try: + super().remove(value) + finally: + self._link_nodes() def __getitem__(self, key): if isinstance(key, str): @@ -251,13 +263,17 @@ def __getitem__(self, key): def __setitem__(self, key, value): self._unlink_nodes() - super().__setitem__(key, value) - self._link_nodes() + try: + super().__setitem__(key, value) + finally: + self._link_nodes() def __delitem__(self, key): self._unlink_nodes() - super().__delitem__(key) - self._link_nodes() + try: + super().__delitem__(key) + finally: + self._link_nodes() def __add__(self, other): res = super().__add__(other) @@ -265,9 +281,114 @@ def __add__(self, other): def __iadd__(self, other): self._unlink_nodes() - ret = super().__iadd__(other) - self._link_nodes() + try: + ret = super().__iadd__(other) + finally: + self._link_nodes() return ret def __imul__(self, other): raise NotImplementedError + + +class NodeDict(NodeCollection, dict): + """Mutable, keyed collection of nodes. + + Behaves like a standard ``dict`` but also participates in the caskade + node graph. All elements must be ``Node`` instances. Graph links are + automatically updated whenever the dict is modified. + + Parameters + ---------- + mapping : mapping of str to Node, optional + Nodes to include in the dict. Defaults to an empty dict. + name : str, optional + Human-readable name for this collection of nodes. + """ + + def __init__(self, mapping=None, name=None): + if mapping is None: + mapping = {} + dict.__init__(self, mapping) + Node.__init__(self, name=name) + self.node_type = "ndict" + self._link_nodes() + + @property + def graphviz_style(self): + return {"style": "solid", "color": "black", "shape": "component"} + + @property + def dynamic(self): + return any(node.dynamic for node in dict.values(self)) + + def _unlink_nodes(self): + for node in dict.values(self): + self.unlink(node) + + def _link_nodes(self): + for key, node in dict.items(self): + if not isinstance(node, Node): + raise TypeError(f"NodeDict values must be Node objects, not {type(node)}") + self.link(key, node) + + def __getitem__(self, key): + return dict.__getitem__(self, key) + + def __setitem__(self, key, node): + self._unlink_nodes() + try: + dict.__setitem__(self, key, node) + finally: + self._link_nodes() + + def __delitem__(self, key): + self._unlink_nodes() + try: + dict.__delitem__(self, key) + finally: + self._link_nodes() + + def update(self, mapping=None, **kwargs): + """Update the dict with another mapping (i.e. dict) and update graph links.""" + self._unlink_nodes() + try: + if mapping is not None: + dict.update(self, mapping) + if kwargs: + dict.update(self, kwargs) + finally: + self._link_nodes() + + def pop(self, key, *args): + """Remove and return a node from the dict and update graph links.""" + self._unlink_nodes() + try: + node = dict.pop(self, key, *args) + finally: + self._link_nodes() + return node + + def popitem(self): + """Remove and return an arbitrary (key, node) pair from the dict (the last one inserted) and update graph links.""" + self._unlink_nodes() + try: + key, node = dict.popitem(self) + finally: + self._link_nodes() + return key, node + + def clear(self): + """Remove all nodes from the dict and update graph links.""" + self._unlink_nodes() + dict.clear(self) + + def setdefault(self, key, default): + """If key is in the dictionary, return its value. If not, insert key with a value of default and return default. Update graph links.""" + # Preserve dict.setdefault API shape but enforce NodeDict invariants + if key in self: + return self[key] + if not isinstance(default, Node): + raise TypeError(f"NodeDict values must be Node objects, not {type(default)}") + self[key] = default + return default diff --git a/src/caskade/module.py b/src/caskade/module.py index 64d4fe0..146f5fd 100644 --- a/src/caskade/module.py +++ b/src/caskade/module.py @@ -3,7 +3,7 @@ from .backend import ArrayLike from .base import Node from .param import Param -from .collection import NodeTuple, NodeList +from .collection import NodeTuple, NodeList, NodeDict from .mixins import GetSetValues from .errors import ActiveStateError, FillParamsError @@ -281,6 +281,9 @@ def __setattr__(self, key: str, value: Any): ): if len(value) > 0 and all(isinstance(v, Node) for v in value): value = NodeTuple(value, name=key) + elif isinstance(value, dict) and not isinstance(value, NodeDict): + if len(value) > 0 and all(isinstance(v, Node) for v in value.values()): + value = NodeDict(value, name=key) except AttributeError: pass super().__setattr__(key, value) diff --git a/tests/test_collection.py b/tests/test_collection.py index c02e259..f41a2b8 100644 --- a/tests/test_collection.py +++ b/tests/test_collection.py @@ -4,6 +4,7 @@ from caskade import ( NodeList, NodeTuple, + NodeDict, Param, Module, backend, @@ -172,15 +173,19 @@ def test_collection_in_module(): l1 = [Param("ptest1"), Param("ptest2"), Module("mtest1"), Module("mtest2")] t1 = (Param("ptest3"), Param("ptest4"), Module("mtest3"), Module("mtest4")) + d1 = {"ptest5": Param("ptest5"), "ptest6": Param("ptest6"), "mtest5": Module("mtest5")} m1 = Module("test") m1.l = l1 m1.t = t1 + m1.d = d1 assert m1["l"][2] == l1[2] assert m1["t"][2] == t1[2] assert m1.l[3] == l1[3] assert m1.t[3] == t1[3] + assert m1["d"]["ptest5"] == d1["ptest5"] + assert m1.d["mtest5"] == d1["mtest5"] @pytest.mark.parametrize("node_type", [NodeTuple, NodeList]) @@ -277,3 +282,140 @@ def test_valid_tuple(node_tuple, params_type, group): for i in range(len(node_tuple.dynamic_param_groups)): assert backend.module.allclose(init_params[i], round_trip_params[i]) assert backend.module.allclose(init_params[i], final_params[i]) + + +def test_node_dict_creation(): + + # Minimal creation + n1 = NodeDict() + assert n1.name.startswith("NodeDict") + assert len(n1) == 0 + + # Creation with dict of param nodes + params = {"p1": Param("p1"), "p2": Param("p2")} + n2 = NodeDict(params) + assert len(n2) == 2 + assert n2["p1"] is params["p1"] + assert n2.p1 is params["p1"] + assert n2["p2"] is params["p2"] + + # Creation with dict of module nodes + modules = {"m1": Module("m1"), "m2": Module("m2"), "m3": Module("m3")} + n3 = NodeDict(modules) + assert len(n3) == 3 + assert n3["m1"] is modules["m1"] + assert n3.m1 is modules["m1"] + assert n3["m2"] is modules["m2"] + + # Check repr + assert isinstance(repr(n3), str) + assert "[3]" in repr(n3) + + # Check to static/dynamic + n2.to_dynamic(False) + assert len(n2.static_params) == 0 + n2.to_static(False) + assert len(n2.static_params) == 2 + assert len(n2.pointer_params) == 0 + + # Graphviz + graph = n3.graphviz(saveto="test_graph_dict.pdf") + assert graph is not None, "should return a graphviz object" + assert os.path.exists("test_graph_dict.pdf") + os.remove("test_graph_dict.pdf") + + # Check copy + with pytest.raises(NotImplementedError): + n3.copy() + with pytest.raises(NotImplementedError): + n3.deepcopy() + + # Check bad init + with pytest.raises(TypeError): + NodeDict({"bad": 1}) + + +def test_node_dict_manipulation(): + + params = {"p1": Param("p1", 1), "p2": Param("p2", 2)} + modules = {"m1": Module("m1"), "m2": Module("m2"), "m3": Module("m3")} + nd = NodeDict(params) + + # Set item + p3 = Param("p3", 3) + nd["p3"] = p3 + assert len(nd) == 3 + assert nd["p3"] is p3 + assert nd.p3 is p3 + + # Update + nd.update(modules) + assert len(nd) == 6 + assert nd["m1"] is modules["m1"] + assert nd.m1 is modules["m1"] + + # Pop + popped = nd.pop("m3") + assert popped is modules["m3"] + assert len(nd) == 5 + assert "m3" not in nd + + # Del item + del nd["m2"] + assert len(nd) == 4 + assert "m2" not in nd + + # Popitem + key, _ = nd.popitem() + assert key not in nd + + # Clear + nd.clear() + assert len(nd) == 0 + + # Setdefault + nd2 = NodeDict({"p1": Param("p1")}) + p_new = Param("p_new") + nd2.setdefault("new_key", p_new) + assert nd2["new_key"] is p_new + assert len(nd2) == 2 + # setdefault should not overwrite existing + existing = nd2["p1"] + nd2.setdefault("p1", Param("p1_other")) + assert nd2["p1"] is existing + with pytest.raises(TypeError): + nd2.setdefault("bad_key", "not a node") + + # Check to static/dynamic + nd3 = NodeDict({"p1": Param("p1", 1), "p2": Param("p2", 2)}) + nd3.to_dynamic() + nd3.to_static() + + # dynamic property + assert nd3.static + assert not nd3.dynamic + nd3.to_dynamic() + assert nd3.dynamic + assert not nd3.static + + # Update with kwargs + nd4 = NodeDict({"p1": Param("p1")}) + m_kw = Module("mkw") + nd4.update(mkw=m_kw) + assert len(nd4) == 2 + assert nd4["mkw"] is m_kw + assert nd4.mkw is m_kw + + # mul raises NotImplementedError + with pytest.raises(NotImplementedError): + nd3 * 2 + + +def test_node_dict_param_values(): + nd = NodeDict({"p1": Param("p1"), "p2": Param("p2"), "p3": Param("p3")}) + + nd.set_values([1, 2, 3]) + + assert nd["p1"].value.item() == 1.0 + assert nd["p2"].value.item() == 2.0 + assert nd["p3"].value.item() == 3.0 From 2d996abf25647404512ab23303311867f74df772 Mon Sep 17 00:00:00 2001 From: Copilot <198982749+Copilot@users.noreply.github.com> Date: Thu, 5 Mar 2026 19:55:35 -0500 Subject: [PATCH 3/3] Allow `Node.unlink()` to clear all children when called with no arguments (#79) * Initial plan * Make unlink() with no arguments clear all children Co-authored-by: ConnorStoneAstro <78555321+ConnorStoneAstro@users.noreply.github.com> * Address review: use recursive unlink for no-args, move notebook section Co-authored-by: ConnorStoneAstro <78555321+ConnorStoneAstro@users.noreply.github.com> * add check for trying to unlink key that isnt a child * make test for unlink missing key --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: ConnorStoneAstro <78555321+ConnorStoneAstro@users.noreply.github.com> Co-authored-by: Connor Stone --- docs/source/notebooks/BeginnersGuide.ipynb | 38 +++++++++++++++++++++- pyproject.toml | 2 +- src/caskade/base.py | 24 +++++++++----- tests/test_base.py | 10 ++++++ 4 files changed, 64 insertions(+), 10 deletions(-) diff --git a/docs/source/notebooks/BeginnersGuide.ipynb b/docs/source/notebooks/BeginnersGuide.ipynb index bf39e9c..e1e05a4 100644 --- a/docs/source/notebooks/BeginnersGuide.ipynb +++ b/docs/source/notebooks/BeginnersGuide.ipynb @@ -312,6 +312,42 @@ "As you can see, a `pointer` parameter is represented in the graph as a shaded arrow. It will now return the same value as the `x0` parameter in `secondsim`." ] }, + { + "cell_type": "markdown", + "id": "link_unlink_md", + "metadata": {}, + "source": [ + "### Linking and unlinking params\n", + "\n", + "Pointer parameters can be linked to and unlinked from other nodes. ", + "Use `link(node)` to connect a child node, and `unlink(node)` (or a key string) to disconnect a specific child. ", + "Calling `unlink()` with no arguments removes **all** children at once, acting as a convenient clear." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "link_unlink_code", + "metadata": {}, + "outputs": [], + "source": [ + "time_param = ck.Param(\"mytime\") # a standalone param to link against\n", + "shared_x0 = ck.Param(\"shared_x0\", shape=(2,))\n", + "\n", + "# Link a child node using a key or by passing the node directly\n", + "shared_x0.link(\"mytime\", time_param)\n", + "print(\"Children after link:\", list(shared_x0.children))\n", + "\n", + "# Unlink a specific child by key or node reference\n", + "shared_x0.unlink(\"mytime\")\n", + "print(\"Children after unlink(key):\", list(shared_x0.children))\n", + "\n", + "# Re-link and then clear all children at once\n", + "shared_x0.link(\"mytime\", time_param)\n", + "shared_x0.unlink() # removes all children\n", + "print(\"Children after unlink():\", list(shared_x0.children))" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -568,4 +604,4 @@ }, "nbformat": 4, "nbformat_minor": 2 -} +} \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 4ac2230..0f62a30 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,7 +24,7 @@ keywords = [ "pytorch" ] classifiers=[ - "Development Status :: 1 - Planning", + "Development Status :: 5 - Production/Stable", "Intended Audience :: Science/Research", "License :: OSI Approved :: MIT License", "Operating System :: OS Independent", diff --git a/src/caskade/base.py b/src/caskade/base.py index c46550f..7a80a95 100644 --- a/src/caskade/base.py +++ b/src/caskade/base.py @@ -272,31 +272,39 @@ def _unlink(self, key: str): del self.children[key] self.update_graph() - def unlink(self, key: Union[str, "Node", list, tuple]): - """ - Unlink one or more child nodes from this node. + def unlink(self, key: Union[str, "Node", list, tuple, None] = None): + """Unlink one or more ``Node`` objects from this ``Node``. Parameters ---------- - key : str, Node, list, or tuple - Identifier of the child(ren) to remove. May be a link key - string, the child ``Node`` object itself, or a list/tuple of - keys or nodes to unlink in bulk. - + key: (str, Node, list, tuple, or None, optional) + The key, ``Node`` object, or collection of keys/nodes to unlink. + If a string, the child with that key is unlinked. If a ``Node`` + object, the matching child is located and unlinked. If a list or + tuple, each element is unlinked in turn. If ``None`` (the + default), all children are unlinked. + Raises ------ GraphError If the graph is currently active. """ + if key is None: + self.unlink(list(self.children)) + return if isinstance(key, Node): for node in self.children: if self.children[node] is key: key = node break + else: + raise KeyError(f"Node {key.name} not found in parent {self.name}") elif isinstance(key, (tuple, list)): for k in key: self.unlink(k) return + if key not in self.children: + raise KeyError(f"Child key '{key}' not found in parent {self.name}") self.__delattr__(key) def topological_ordering(self) -> tuple["Node"]: diff --git a/tests/test_base.py b/tests/test_base.py index e566889..0929710 100644 --- a/tests/test_base.py +++ b/tests/test_base.py @@ -85,8 +85,18 @@ def test_linking(node_graph): assert e in a.topological_ordering() with pytest.raises(AttributeError): a.e + with pytest.raises(KeyError): + a.unlink(e) + with pytest.raises(KeyError): + a.unlink("e") a.unlink((b, c)) + # Check unlink with no arguments clears all children + a.link(e) + assert len(a.children) > 0 + a.unlink() + assert len(a.children) == 0 + def test_graphviz(node_graph): a, *_ = node_graph