diff --git a/python/lsst/pipe/base/__init__.py b/python/lsst/pipe/base/__init__.py
index 74339da90..51652a714 100644
--- a/python/lsst/pipe/base/__init__.py
+++ b/python/lsst/pipe/base/__init__.py
@@ -1,4 +1,4 @@
-from . import automatic_connection_constants, connectionTypes, pipelineIR
+from . import automatic_connection_constants, connectionTypes, pipeline_graph, pipelineIR
from ._dataset_handle import *
from ._instrument import *
from ._observation_dimension_packer import *
@@ -11,6 +11,10 @@
from .graph import *
from .graphBuilder import *
from .pipeline import *
+
+# We import the main PipelineGraph type and the module (above), but we don't
+# lift all symbols to package scope.
+from .pipeline_graph import PipelineGraph
from .pipelineTask import *
from .struct import *
from .task import *
diff --git a/python/lsst/pipe/base/pipeTools.py b/python/lsst/pipe/base/pipeTools.py
index 0a6f249f1..3c6bbe74e 100644
--- a/python/lsst/pipe/base/pipeTools.py
+++ b/python/lsst/pipe/base/pipeTools.py
@@ -2,7 +2,7 @@
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
-# (http://www.lsst.org).
+# (http://www.lsst.org).XS
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
@@ -27,30 +27,17 @@
# No one should do import * from this module
__all__ = ["isPipelineOrdered", "orderPipeline"]
-# -------------------------------
-# Imports of standard modules --
-# -------------------------------
-import itertools
from collections.abc import Iterable
from typing import TYPE_CHECKING
-# -----------------------------
-# Imports for other modules --
-# -----------------------------
-from .connections import iterConnections
+from .pipeline import Pipeline, TaskDef
+
+# Exceptions re-exported here for backwards compatibility.
+from .pipeline_graph import DuplicateOutputError, PipelineDataCycleError, PipelineGraph # noqa: F401
if TYPE_CHECKING:
- from .pipeline import Pipeline, TaskDef
from .taskFactory import TaskFactory
-# ----------------------------------
-# Local non-exported definitions --
-# ----------------------------------
-
-# ------------------------
-# Exported definitions --
-# ------------------------
-
class MissingTaskFactoryError(Exception):
"""Exception raised when client fails to provide TaskFactory instance."""
@@ -58,20 +45,6 @@ class MissingTaskFactoryError(Exception):
pass
-class DuplicateOutputError(Exception):
- """Exception raised when Pipeline has more than one task for the same
- output.
- """
-
- pass
-
-
-class PipelineDataCycleError(Exception):
- """Exception raised when Pipeline has data dependency cycle."""
-
- pass
-
-
def isPipelineOrdered(pipeline: Pipeline | Iterable[TaskDef], taskFactory: TaskFactory | None = None) -> bool:
"""Check whether tasks in pipeline are correctly ordered.
@@ -80,15 +53,15 @@ def isPipelineOrdered(pipeline: Pipeline | Iterable[TaskDef], taskFactory: TaskF
Parameters
----------
- pipeline : `pipe.base.Pipeline`
+ pipeline : `Pipeline` or `collections.abc.Iterable` [ `TaskDef` ]
Pipeline description.
- taskFactory: `pipe.base.TaskFactory`, optional
- Instance of an object which knows how to import task classes. It is
- only used if pipeline task definitions do not define task classes.
+ taskFactory: `TaskFactory`, optional
+ Ignored; present only for backwards compatibility.
Returns
-------
- True for correctly ordered pipeline, False otherwise.
+ is_ordered : `bool`
+ True for correctly ordered pipeline, False otherwise.
Raises
------
@@ -96,118 +69,50 @@ def isPipelineOrdered(pipeline: Pipeline | Iterable[TaskDef], taskFactory: TaskF
Raised when task class cannot be imported.
DuplicateOutputError
Raised when there is more than one producer for a dataset type.
- MissingTaskFactoryError
- Raised when TaskFactory is needed but not provided.
"""
- # Build a map of DatasetType name to producer's index in a pipeline
- producerIndex = {}
- for idx, taskDef in enumerate(pipeline):
- for attr in iterConnections(taskDef.connections, "outputs"):
- if attr.name in producerIndex:
- raise DuplicateOutputError(
- "DatasetType `{}' appears more than once as output".format(attr.name)
- )
- producerIndex[attr.name] = idx
-
- # check all inputs that are also someone's outputs
- for idx, taskDef in enumerate(pipeline):
- # get task input DatasetTypes, this can only be done via class method
- inputs = {name: getattr(taskDef.connections, name) for name in taskDef.connections.inputs}
- for dsTypeDescr in inputs.values():
- # all pre-existing datasets have effective index -1
- prodIdx = producerIndex.get(dsTypeDescr.name, -1)
- if prodIdx >= idx:
- # not good, producer is downstream
- return False
-
+ if isinstance(pipeline, Pipeline):
+ graph = pipeline.to_graph()
+ else:
+ graph = PipelineGraph()
+ for task_def in pipeline:
+ graph.add_task(task_def.label, task_def.taskClass, task_def.config, task_def.connections)
+ # Can't use graph.is_sorted because that requires sorted dataset type names
+ # as well as sorted tasks.
+ tasks_xgraph = graph.make_task_xgraph()
+ seen: set[str] = set()
+ for task_label in tasks_xgraph:
+ successors = set(tasks_xgraph.successors(task_label))
+ if not successors.isdisjoint(seen):
+ return False
+ seen.add(task_label)
return True
-def orderPipeline(pipeline: list[TaskDef]) -> list[TaskDef]:
+def orderPipeline(pipeline: Pipeline | Iterable[TaskDef]) -> list[TaskDef]:
"""Re-order tasks in pipeline to satisfy data dependencies.
- When possible new ordering keeps original relative order of the tasks.
-
Parameters
----------
- pipeline : `list` of `pipe.base.TaskDef`
+ pipeline : `Pipeline` or `collections.abc.Iterable` [ `TaskDef` ]
Pipeline description.
Returns
-------
- Correctly ordered pipeline (`list` of `pipe.base.TaskDef` objects).
+ ordered : `list` [ `TaskDef` ]
+ Correctly ordered pipeline.
Raises
------
- `DuplicateOutputError` is raised when there is more than one producer for a
- dataset type.
- `PipelineDataCycleError` is also raised when pipeline has dependency
- cycles. `MissingTaskFactoryError` is raised when `TaskFactory` is needed
- but not provided.
+ DuplicateOutputError
+ Raised when there is more than one producer for a dataset type.
+ PipelineDataCycleError
+ Raised when the pipeline has dependency cycles.
"""
- # This is a modified version of Kahn's algorithm that preserves order
-
- # build mapping of the tasks to their inputs and outputs
- inputs = {} # maps task index to its input DatasetType names
- outputs = {} # maps task index to its output DatasetType names
- allInputs = set() # all inputs of all tasks
- allOutputs = set() # all outputs of all tasks
- dsTypeTaskLabels: dict[str, str] = {} # maps DatasetType name to the label of its parent task
- for idx, taskDef in enumerate(pipeline):
- # task outputs
- dsMap = {name: getattr(taskDef.connections, name) for name in taskDef.connections.outputs}
- for dsTypeDescr in dsMap.values():
- if dsTypeDescr.name in allOutputs:
- raise DuplicateOutputError(
- f"DatasetType `{dsTypeDescr.name}' in task `{taskDef.label}' already appears as an "
- f"output in task `{dsTypeTaskLabels[dsTypeDescr.name]}'."
- )
- dsTypeTaskLabels[dsTypeDescr.name] = taskDef.label
- outputs[idx] = set(dsTypeDescr.name for dsTypeDescr in dsMap.values())
- allOutputs.update(outputs[idx])
-
- # task inputs
- connectionInputs = itertools.chain(taskDef.connections.inputs, taskDef.connections.prerequisiteInputs)
- inputs[idx] = set(getattr(taskDef.connections, name).name for name in connectionInputs)
- allInputs.update(inputs[idx])
-
- # for simplicity add pseudo-node which is a producer for all pre-existing
- # inputs, its index is -1
- preExisting = allInputs - allOutputs
- outputs[-1] = preExisting
-
- # Set of nodes with no incoming edges, initially set to pseudo-node
- queue = [-1]
- result = []
- while queue:
- # move to final list, drop -1
- idx = queue.pop(0)
- if idx >= 0:
- result.append(idx)
-
- # remove task outputs from other tasks inputs
- thisTaskOutputs = outputs.get(idx, set())
- for taskInputs in inputs.values():
- taskInputs -= thisTaskOutputs
-
- # find all nodes with no incoming edges and move them to the queue
- topNodes = [key for key, value in inputs.items() if not value]
- queue += topNodes
- for key in topNodes:
- del inputs[key]
-
- # keep queue ordered
- queue.sort()
-
- # if there is something left it means cycles
- if inputs:
- # format it in usable way
- loops = []
- for idx, inputNames in inputs.items():
- taskName = pipeline[idx].label
- outputNames = outputs[idx]
- edge = " {} -> {} -> {}".format(inputNames, taskName, outputNames)
- loops.append(edge)
- raise PipelineDataCycleError("Pipeline has data cycles:\n" + "\n".join(loops))
-
- return [pipeline[idx] for idx in result]
+ if isinstance(pipeline, Pipeline):
+ graph = pipeline.to_graph()
+ else:
+ graph = PipelineGraph()
+ for task_def in pipeline:
+ graph.add_task(task_def.label, task_def.taskClass, task_def.config, task_def.connections)
+ graph.sort()
+ return list(graph._iter_task_defs())
diff --git a/python/lsst/pipe/base/pipeline.py b/python/lsst/pipe/base/pipeline.py
index 781defcca..a23fbd95f 100644
--- a/python/lsst/pipe/base/pipeline.py
+++ b/python/lsst/pipe/base/pipeline.py
@@ -53,14 +53,12 @@
from lsst.utils.introspection import get_full_type_name
from . import automatic_connection_constants as acc
-from . import pipelineIR, pipeTools
+from . import pipeline_graph, pipelineIR
from ._instrument import Instrument as PipeBaseInstrument
-from ._task_metadata import TaskMetadata
from .config import PipelineTaskConfig
-from .connections import iterConnections
+from .connections import PipelineTaskConnections, iterConnections
from .connectionTypes import Input
from .pipelineTask import PipelineTask
-from .task import _TASK_METADATA_TYPE
if TYPE_CHECKING: # Imports needed only for type annotations; may be circular.
from lsst.obs.base import Instrument
@@ -126,6 +124,11 @@ class TaskDef:
Task label, usually a short string unique in a pipeline. If not
provided, ``taskClass`` must be, and ``taskClass._DefaultName`` will
be used.
+ connections : `PipelineTaskConnections`, optional
+ Object that describes the dataset types used by the task. If not
+ provided, one will be constructed from the given configuration. If
+ provided, it is assumed that ``config`` has already been validated
+ and frozen.
"""
def __init__(
@@ -134,6 +137,7 @@ def __init__(
config: PipelineTaskConfig | None = None,
taskClass: type[PipelineTask] | None = None,
label: str | None = None,
+ connections: PipelineTaskConnections | None = None,
):
if taskName is None:
if taskClass is None:
@@ -150,16 +154,20 @@ def __init__(
raise ValueError("`taskClass` must be provided if `label` is not.")
label = taskClass._DefaultName
self.taskName = taskName
- try:
- config.validate()
- except Exception:
- _LOG.error("Configuration validation failed for task %s (%s)", label, taskName)
- raise
- config.freeze()
+ if connections is None:
+ # If we don't have connections yet, assume the config hasn't been
+ # validated yet.
+ try:
+ config.validate()
+ except Exception:
+ _LOG.error("Configuration validation failed for task %s (%s)", label, taskName)
+ raise
+ config.freeze()
+ connections = config.connections.ConnectionsClass(config=config)
self.config = config
self.taskClass = taskClass
self.label = label
- self.connections = config.connections.ConnectionsClass(config=config)
+ self.connections = connections
@property
def configDatasetName(self) -> str:
@@ -739,6 +747,47 @@ def write_to_uri(self, uri: ResourcePathExpression) -> None:
"""
self._pipelineIR.write_to_uri(uri)
+ def to_graph(self) -> pipeline_graph.PipelineGraph:
+ """Construct a pipeline graph from this pipeline.
+
+ Constructing a graph applies all configuration overrides, freezes all
+ configuration, checks all contracts, and checks for dataset type
+ consistency between tasks (as much as possible without access to a data
+ repository). It cannot be reversed.
+
+ Returns
+ -------
+ graph : `pipeline_graph.PipelineGraph`
+ Representation of the pipeline as a graph.
+ """
+ instrument_class_name = self._pipelineIR.instrument
+ data_id = {}
+ if instrument_class_name is not None:
+ instrument_class = doImportType(instrument_class_name)
+ if instrument_class is not None:
+ data_id["instrument"] = instrument_class.getName()
+ graph = pipeline_graph.PipelineGraph(data_id=data_id)
+ graph.description = self._pipelineIR.description
+ for label in self._pipelineIR.tasks:
+ self._add_task_to_graph(label, graph)
+ if self._pipelineIR.contracts is not None:
+ label_to_config = {x.label: x.config for x in graph.tasks.values()}
+ for contract in self._pipelineIR.contracts:
+ # execute this in its own line so it can raise a good error
+ # message if there was problems with the eval
+ success = eval(contract.contract, None, label_to_config)
+ if not success:
+ extra_info = f": {contract.msg}" if contract.msg is not None else ""
+ raise pipelineIR.ContractError(
+ f"Contract(s) '{contract.contract}' were not satisfied{extra_info}"
+ )
+ for label, subset in self._pipelineIR.labeled_subsets.items():
+ graph.add_task_subset(
+ label, subset.subset, subset.description if subset.description is not None else ""
+ )
+ graph.sort()
+ return graph
+
def toExpandedPipeline(self) -> Generator[TaskDef, None, None]:
r"""Return a generator of `TaskDef`\s which can be used to create
quantum graphs.
@@ -755,31 +804,22 @@ def toExpandedPipeline(self) -> Generator[TaskDef, None, None]:
If a dataId is supplied in a config block. This is in place for
future use
"""
- taskDefs = []
- for label in self._pipelineIR.tasks:
- taskDefs.append(self._buildTaskDef(label))
-
- # lets evaluate the contracts
- if self._pipelineIR.contracts is not None:
- label_to_config = {x.label: x.config for x in taskDefs}
- for contract in self._pipelineIR.contracts:
- # execute this in its own line so it can raise a good error
- # message if there was problems with the eval
- success = eval(contract.contract, None, label_to_config)
- if not success:
- extra_info = f": {contract.msg}" if contract.msg is not None else ""
- raise pipelineIR.ContractError(
- f"Contract(s) '{contract.contract}' were not satisfied{extra_info}"
- )
+ yield from self.to_graph()._iter_task_defs()
- taskDefs = sorted(taskDefs, key=lambda x: x.label)
- yield from pipeTools.orderPipeline(taskDefs)
+ def _add_task_to_graph(self, label: str, graph: pipeline_graph.PipelineGraph) -> None:
+ """Add a single task from this pipeline to a pipeline graph that is
+ under construction.
- def _buildTaskDef(self, label: str) -> TaskDef:
+ Parameters
+ ----------
+ label : `str`
+ Label for the task to be added.
+ graph : `pipeline_graph.PipelineGraph`
+ Graph to add the task to.
+ """
if (taskIR := self._pipelineIR.tasks.get(label)) is None:
raise NameError(f"Label {label} does not appear in this pipeline")
taskClass: type[PipelineTask] = doImportType(taskIR.klass)
- taskName = get_full_type_name(taskClass)
config = taskClass.ConfigClass()
instrument: PipeBaseInstrument | None = None
if (instrumentName := self._pipelineIR.instrument) is not None:
@@ -792,13 +832,19 @@ def _buildTaskDef(self, label: str) -> TaskDef:
self._pipelineIR.parameters,
label,
)
- return TaskDef(taskName=taskName, config=config, taskClass=taskClass, label=label)
+ graph.add_task(label, taskClass, config)
def __iter__(self) -> Generator[TaskDef, None, None]:
return self.toExpandedPipeline()
def __getitem__(self, item: str) -> TaskDef:
- return self._buildTaskDef(item)
+ # Making a whole graph and then making a TaskDef from that is pretty
+ # backwards, but I'm hoping to deprecate this method shortly in favor
+ # of making the graph explicitly and working with its node objects.
+ graph = pipeline_graph.PipelineGraph()
+ self._add_task_to_graph(item, graph)
+ (result,) = graph._iter_task_defs()
+ return result
def __len__(self) -> int:
return len(self._pipelineIR.tasks)
@@ -1072,7 +1118,7 @@ def makeDatasetTypesSet(
DatasetType(
taskDef.configDatasetName,
registry.dimensions.empty,
- storageClass="Config",
+ storageClass=acc.CONFIG_INIT_OUTPUT_STORAGE_CLASS,
)
)
initOutputs.freeze()
@@ -1090,7 +1136,7 @@ def makeDatasetTypesSet(
current = registry.getDatasetType(taskDef.metadataDatasetName)
except KeyError:
# No previous definition so use the default.
- storageClass = "TaskMetadata" if _TASK_METADATA_TYPE is TaskMetadata else "PropertySet"
+ storageClass = acc.METADATA_OUTPUT_STORAGE_CLASS
else:
storageClass = current.storageClass.name
outputs.update({DatasetType(taskDef.metadataDatasetName, dimensions, storageClass)})
@@ -1098,7 +1144,15 @@ def makeDatasetTypesSet(
if taskDef.logOutputDatasetName is not None:
# Log output dimensions correspond to a task quantum.
dimensions = registry.dimensions.extract(taskDef.connections.dimensions)
- outputs.update({DatasetType(taskDef.logOutputDatasetName, dimensions, "ButlerLogRecords")})
+ outputs.update(
+ {
+ DatasetType(
+ taskDef.logOutputDatasetName,
+ dimensions,
+ acc.LOG_OUTPUT_STORAGE_CLASS,
+ )
+ }
+ )
outputs.freeze()
diff --git a/python/lsst/pipe/base/pipeline_graph/__init__.py b/python/lsst/pipe/base/pipeline_graph/__init__.py
new file mode 100644
index 000000000..3cf7a8101
--- /dev/null
+++ b/python/lsst/pipe/base/pipeline_graph/__init__.py
@@ -0,0 +1,29 @@
+# This file is part of pipe_base.
+#
+# Developed for the LSST Data Management System.
+# This product includes software developed by the LSST Project
+# (http://www.lsst.org).
+# See the COPYRIGHT file at the top-level directory of this distribution
+# for details of code ownership.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see .
+from __future__ import annotations
+
+from ._dataset_types import *
+from ._edges import *
+from ._exceptions import *
+from ._nodes import *
+from ._pipeline_graph import *
+from ._task_subsets import *
+from ._tasks import *
diff --git a/python/lsst/pipe/base/pipeline_graph/_dataset_types.py b/python/lsst/pipe/base/pipeline_graph/_dataset_types.py
new file mode 100644
index 000000000..7cc3dd56c
--- /dev/null
+++ b/python/lsst/pipe/base/pipeline_graph/_dataset_types.py
@@ -0,0 +1,214 @@
+# This file is part of pipe_base.
+#
+# Developed for the LSST Data Management System.
+# This product includes software developed by the LSST Project
+# (http://www.lsst.org).
+# See the COPYRIGHT file at the top-level directory of this distribution
+# for details of code ownership.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see .
+from __future__ import annotations
+
+__all__ = ("DatasetTypeNode",)
+
+import dataclasses
+from typing import TYPE_CHECKING, Any
+
+import networkx
+from lsst.daf.butler import DatasetRef, DatasetType, DimensionGraph, Registry, StorageClass
+from lsst.daf.butler.registry import MissingDatasetTypeError
+
+from ._exceptions import DuplicateOutputError
+from ._nodes import NodeKey, NodeType
+
+if TYPE_CHECKING:
+ from ._edges import ReadEdge, WriteEdge
+
+
+@dataclasses.dataclass(frozen=True, eq=False)
+class DatasetTypeNode:
+ """A node in a pipeline graph that represents a resolved dataset type.
+
+ Notes
+ -----
+ A dataset type node represents a common definition of the dataset type
+ across the entire graph - it is never a component, and the storage class is
+ the registry dataset type's storage class or (if there isn't one) the one
+ defined by the producing task.
+
+ Dataset type nodes are intentionally not equality comparable, since there
+ are many different (and useful) ways to compare these objects with no clear
+ winner as the most obvious behavior.
+ """
+
+ dataset_type: DatasetType
+ """Common definition of this dataset type for the graph.
+ """
+
+ is_initial_query_constraint: bool
+ """Whether this dataset should be included as a constraint in the initial
+ query for data IDs in QuantumGraph generation.
+
+ This is only `True` for dataset types that are overall regular inputs, and
+ only if none of those input connections had ``deferQueryConstraint=True``.
+ """
+
+ is_prerequisite: bool
+ """Whether this dataset type is a prerequisite input that must exist in
+ the Registry before graph creation.
+ """
+
+ @classmethod
+ def _from_edges(
+ cls, key: NodeKey, xgraph: networkx.MultiDiGraph, registry: Registry, previous: DatasetTypeNode | None
+ ) -> DatasetTypeNode:
+ """Construct a dataset type node from its edges.
+
+ Parameters
+ ----------
+ key : `NodeKey`
+ Named tuple that holds the dataset type and serves as the node
+ object in the internal networkx graph.
+ xgraph : `networkx.MultiDiGraph`
+ The internal networkx graph.
+ registry : `lsst.daf.butler.Registry`
+ Registry client for the data repository. Only used to get
+ dataset type definitions and the dimension universe.
+ previous : `DatasetTypeNode` or `None`
+ Previous node for this dataset type.
+
+ Returns
+ -------
+ node : `DatasetTypeNode`
+ Node consistent with all edges pointing to it and the data
+ repository.
+ """
+ try:
+ dataset_type = registry.getDatasetType(key.name)
+ is_registered = True
+ except MissingDatasetTypeError:
+ dataset_type = None
+ is_registered = False
+ if previous is not None and previous.dataset_type == dataset_type:
+ # This node was already resolved (with exactly the same edges
+ # contributing, since we clear resolutions when edges are added or
+ # removed). The only thing that might have changed was the
+ # definition in the registry, and it didn't.
+ return previous
+ is_initial_query_constraint = True
+ is_prerequisite: bool | None = None
+ producer: str | None = None
+ write_edge: WriteEdge
+ for _, _, write_edge in xgraph.in_edges(key, data="instance"): # will iterate zero or one time
+ if producer is not None:
+ raise DuplicateOutputError(
+ f"Dataset type {key.name!r} is produced by both {write_edge.task_label!r} "
+ f"and {producer!r}."
+ )
+ producer = write_edge.task_label
+ dataset_type = write_edge._resolve_dataset_type(dataset_type, universe=registry.dimensions)
+ is_prerequisite = False
+ is_initial_query_constraint = False
+ read_edge: ReadEdge
+ consumers: list[str] = []
+ read_edges = list(read_edge for _, _, read_edge in xgraph.out_edges(key, data="instance"))
+ # Put edges that are not component datasets before any edges that are.
+ read_edges.sort(key=lambda read_edge: read_edge.component is not None)
+ for read_edge in read_edges:
+ dataset_type, is_initial_query_constraint, is_prerequisite = read_edge._resolve_dataset_type(
+ current=dataset_type,
+ universe=registry.dimensions,
+ is_initial_query_constraint=is_initial_query_constraint,
+ is_prerequisite=is_prerequisite,
+ is_registered=is_registered,
+ producer=producer,
+ consumers=consumers,
+ )
+ consumers.append(read_edge.task_label)
+ assert dataset_type is not None, "Graph structure guarantees at least one edge."
+ assert is_prerequisite is not None, "Having at least one edge guarantees is_prerequisite is known."
+ return DatasetTypeNode(
+ dataset_type=dataset_type,
+ is_initial_query_constraint=is_initial_query_constraint,
+ is_prerequisite=is_prerequisite,
+ )
+
+ @property
+ def name(self) -> str:
+ """Name of the dataset type.
+
+ This is always the parent dataset type, never that of a component.
+ """
+ return self.dataset_type.name
+
+ @property
+ def key(self) -> NodeKey:
+ """Key that identifies this dataset type in internal and exported
+ networkx graphs.
+ """
+ return NodeKey(NodeType.DATASET_TYPE, self.dataset_type.name)
+
+ @property
+ def dimensions(self) -> DimensionGraph:
+ """Dimensions of the dataset type."""
+ return self.dataset_type.dimensions
+
+ @property
+ def storage_class_name(self) -> str:
+ """String name of the storage class for this dataset type."""
+ return self.dataset_type.storageClass_name
+
+ @property
+ def storage_class(self) -> StorageClass:
+ """Storage class for this dataset type."""
+ return self.dataset_type.storageClass
+
+ def __repr__(self) -> str:
+ return f"{self.name} ({self.storage_class_name}, {self.dimensions})"
+
+ def generalize_ref(self, ref: DatasetRef) -> DatasetRef:
+ """Convert a `~lsst.daf.butler.DatasetRef` with the dataset type
+ associated with some task to one with the common dataset type defined
+ by this node.
+
+ Parameters
+ ----------
+ ref : `lsst.daf.butler.DatasetRef`
+ Reference whose dataset type is convertible to this node's, either
+ because it is a component with the node's dataset type as its
+ parent, or because it has a compatible storage class.
+
+ Returns
+ -------
+ ref : `lsst.daf.butler.DatasetRef`
+ Reference with exactly this node's dataset type.
+ """
+ if ref.isComponent():
+ ref = ref.makeCompositeRef()
+ if ref.datasetType.storageClass_name != self.dataset_type.storageClass_name:
+ return ref.overrideStorageClass(self.dataset_type.storageClass_name)
+ return ref
+
+ def _to_xgraph_state(self) -> dict[str, Any]:
+ """Convert this node's attributes into a dictionary suitable for use
+ in exported networkx graphs.
+ """
+ return {
+ "dataset_type": self.dataset_type,
+ "is_initial_query_constraint": self.is_initial_query_constraint,
+ "is_prerequisite": self.is_prerequisite,
+ "dimensions": self.dataset_type.dimensions,
+ "storage_class_name": self.dataset_type.storageClass_name,
+ "bipartite": NodeType.DATASET_TYPE.bipartite,
+ }
diff --git a/python/lsst/pipe/base/pipeline_graph/_edges.py b/python/lsst/pipe/base/pipeline_graph/_edges.py
new file mode 100644
index 000000000..10ea6b11b
--- /dev/null
+++ b/python/lsst/pipe/base/pipeline_graph/_edges.py
@@ -0,0 +1,714 @@
+# This file is part of pipe_base.
+#
+# Developed for the LSST Data Management System.
+# This product includes software developed by the LSST Project
+# (http://www.lsst.org).
+# See the COPYRIGHT file at the top-level directory of this distribution
+# for details of code ownership.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see .
+from __future__ import annotations
+
+__all__ = ("Edge", "ReadEdge", "WriteEdge")
+
+from abc import ABC, abstractmethod
+from collections.abc import Mapping, Sequence
+from typing import Any, ClassVar, TypeVar
+
+from lsst.daf.butler import DatasetRef, DatasetType, DimensionUniverse, SkyPixDimension
+from lsst.daf.butler.registry import MissingDatasetTypeError
+from lsst.utils.classes import immutable
+
+from ..connectionTypes import BaseConnection
+from ._exceptions import ConnectionTypeConsistencyError, IncompatibleDatasetTypeError
+from ._nodes import NodeKey, NodeType
+
+_S = TypeVar("_S", bound="Edge")
+
+
+@immutable
+class Edge(ABC):
+ """Base class for edges in a pipeline graph.
+
+ This represents the link between a task node and an input or output dataset
+ type.
+
+ Parameters
+ ----------
+ task_key : `NodeKey`
+ Key for the task node this edge is connected to.
+ dataset_type_key : `NodeKey`
+ Key for the dataset type node this edge is connected to.
+ storage_class_name : `str`
+ Name of the dataset type's storage class as seen by the task.
+ connection_name : `str`
+ Internal name for the connection as seen by the task.
+ is_calibration : `bool`
+ Whether this dataset type can be included in
+ `~lsst.daf.butler.CollectionType.CALIBRATION` collections.
+ raw_dimensions : `frozenset` [ `str` ]
+ Raw dimensions from the connection definition.
+ """
+
+ def __init__(
+ self,
+ *,
+ task_key: NodeKey,
+ dataset_type_key: NodeKey,
+ storage_class_name: str,
+ connection_name: str,
+ is_calibration: bool,
+ raw_dimensions: frozenset[str],
+ ):
+ self.task_key = task_key
+ self.dataset_type_key = dataset_type_key
+ self.connection_name = connection_name
+ self.storage_class_name = storage_class_name
+ self.is_calibration = is_calibration
+ self.raw_dimensions = raw_dimensions
+
+ INIT_TO_TASK_NAME: ClassVar[str] = "INIT"
+ """Edge key for the special edge that connects a task init node to the
+ task node itself (for regular edges, this would be the connection name).
+ """
+
+ task_key: NodeKey
+ """Task part of the key for this edge in networkx graphs."""
+
+ dataset_type_key: NodeKey
+ """Task part of the key for this edge in networkx graphs."""
+
+ connection_name: str
+ """Name used by the task to refer to this dataset type."""
+
+ storage_class_name: str
+ """Storage class expected by this task.
+
+ If `ReadEdge.component` is not `None`, this is the component storage class,
+ not the parent storage class.
+ """
+
+ is_calibration: bool
+ """Whether this dataset type can be included in
+ `~lsst.daf.butler.CollectionType.CALIBRATION` collections.
+ """
+
+ raw_dimensions: frozenset[str]
+ """Raw dimensions in the task declaration.
+
+ This can only be used safely for partial comparisons: two edges with the
+ same ``raw_dimensions`` (and the same parent dataset type name) always have
+ the same resolved dimensions, but edges with different ``raw_dimensions``
+ may also have the same resolvd dimensions.
+ """
+
+ @property
+ def is_init(self) -> bool:
+ """Whether this dataset is read or written when the task is
+ constructed, not when it is run.
+ """
+ return self.task_key.node_type is NodeType.TASK_INIT
+
+ @property
+ def task_label(self) -> str:
+ """Label of the task."""
+ return str(self.task_key)
+
+ @property
+ def parent_dataset_type_name(self) -> str:
+ """Name of the parent dataset type.
+
+ All dataset type nodes in a pipeline graph are for parent dataset
+ types; components are represented by additional `ReadEdge` state.
+ """
+ return str(self.dataset_type_key)
+
+ @property
+ @abstractmethod
+ def nodes(self) -> tuple[NodeKey, NodeKey]:
+ """The directed pair of `NodeKey` instances this edge connects.
+
+ This tuple is ordered in the same direction as the pipeline flow:
+ `task_key` precedes `dataset_type_key` for writes, and the
+ reverse is true for reads.
+ """
+ raise NotImplementedError()
+
+ @property
+ def key(self) -> tuple[NodeKey, NodeKey, str]:
+ """Ordered tuple of node keys and connection name that uniquely
+ identifies this edge in a pipeline graph.
+ """
+ return self.nodes + (self.connection_name,)
+
+ def __repr__(self) -> str:
+ return f"{self.nodes[0]} -> {self.nodes[1]} ({self.connection_name})"
+
+ @property
+ def dataset_type_name(self) -> str:
+ """Dataset type name seen by the task.
+
+ This defaults to the parent dataset type name, which is appropriate
+ for all writes and most reads.
+ """
+ return self.parent_dataset_type_name
+
+ def diff(self: _S, other: _S, connection_type: str = "connection") -> list[str]:
+ """Compare this edge to another one from a possibly-different
+ configuration of the same task label.
+
+ Parameters
+ ----------
+ other : `Edge`
+ Another edge of the same type to compare to.
+ connection_type : `str`
+ Human-readable name of the connection type of this edge (e.g.
+ "init input", "output") for use in returned messages.
+
+ Returns
+ -------
+ differences : `list` [ `str` ]
+ List of string messages describing differences between ``self`` and
+ ``other``. Will be empty if ``self == other`` or if the only
+ difference is in the task label or connection name (which are not
+ checked). Messages will use 'A' to refer to ``self`` and 'B' to
+ refer to ``other``.
+ """
+ result = []
+ if self.dataset_type_name != other.dataset_type_name:
+ result.append(
+ f"{connection_type.capitalize()} {self.connection_name!r} has dataset type "
+ f"{self.dataset_type_name!r} in A, but {other.dataset_type_name!r} in B."
+ )
+ if self.storage_class_name != other.storage_class_name:
+ result.append(
+ f"{connection_type.capitalize()} {self.connection_name!r} has storage class "
+ f"{self.storage_class_name!r} in A, but {other.storage_class_name!r} in B."
+ )
+ if self.raw_dimensions != other.raw_dimensions:
+ result.append(
+ f"{connection_type.capitalize()} {self.connection_name!r} has raw dimensions "
+ f"{set(self.raw_dimensions)} in A, but {set(other.raw_dimensions)} in B "
+ "(differences in raw dimensions may not lead to differences in resolved dimensions, "
+ "but this cannot be checked without re-resolving the dataset type)."
+ )
+ if self.is_calibration != other.is_calibration:
+ result.append(
+ f"{connection_type.capitalize()} {self.connection_name!r} is marked as a calibration "
+ f"{'in A but not in B' if self.is_calibration else 'in B but not in A'}."
+ )
+ return result
+
+ @abstractmethod
+ def adapt_dataset_type(self, dataset_type: DatasetType) -> DatasetType:
+ """Transform the graph's definition of a dataset type (parent, with the
+ registry or producer's storage class) to the one seen by this task.
+ """
+ raise NotImplementedError()
+
+ @abstractmethod
+ def adapt_dataset_ref(self, ref: DatasetRef) -> DatasetRef:
+ """Transform the graph's definition of a dataset reference (parent
+ dataset type, with the registry or producer's storage class) to the one
+ seen by this task.
+ """
+ raise NotImplementedError()
+
+ def _to_xgraph_state(self) -> dict[str, Any]:
+ """Convert this edges's attributes into a dictionary suitable for use
+ in exported networkx graphs.
+ """
+ return {
+ "parent_dataset_type_name": self.parent_dataset_type_name,
+ "storage_class_name": self.storage_class_name,
+ "is_init": bool,
+ }
+
+
+class ReadEdge(Edge):
+ """Representation of an input connection (including init-inputs and
+ prerequisites) in a pipeline graph.
+
+ Parameters
+ ----------
+ dataset_type_key : `NodeKey`
+ Key for the dataset type node this edge is connected to. This should
+ hold the parent dataset type name for component dataset types.
+ task_key : `NodeKey`
+ Key for the task node this edge is connected to.
+ storage_class_name : `str`
+ Name of the dataset type's storage class as seen by the task.
+ connection_name : `str`
+ Internal name for the connection as seen by the task.
+ is_calibration : `bool`
+ Whether this dataset type can be included in
+ `~lsst.daf.butler.CollectionType.CALIBRATION` collections.
+ raw_dimensions : `frozenset` [ `str` ]
+ Raw dimensions from the connection definition.
+ is_prerequisite : `bool`
+ Whether this dataset must be present in the data repository prior to
+ `QuantumGraph` generation.
+ component : `str` or `None`
+ Component of the dataset type requested by the task.
+ defer_query_constraint : `bool`
+ If `True`, by default do not include this dataset type's existence as a
+ constraint on the initial data ID query in QuantumGraph generation.
+
+ Notes
+ -----
+ When included in an exported `networkx` graph (e.g.
+ `PipelineGraph.make_xgraph`), read edges set the following edge attributes:
+
+ - ``parent_dataset_type_name``
+ - ``storage_class_name``
+ - ``is_init``
+ - ``component``
+ - ``is_prerequisite``
+
+ As with `ReadEdge` instance attributes, these descriptions of dataset types
+ are those specific to a task, and may differ from the graph's resolved
+ dataset type or (if `PipelineGraph.resolve` has not been called) there may
+ not even be a consistent definition of the dataset type.
+ """
+
+ def __init__(
+ self,
+ dataset_type_key: NodeKey,
+ task_key: NodeKey,
+ *,
+ storage_class_name: str,
+ connection_name: str,
+ is_calibration: bool,
+ raw_dimensions: frozenset[str],
+ is_prerequisite: bool,
+ component: str | None,
+ defer_query_constraint: bool,
+ ):
+ super().__init__(
+ task_key=task_key,
+ dataset_type_key=dataset_type_key,
+ storage_class_name=storage_class_name,
+ connection_name=connection_name,
+ raw_dimensions=raw_dimensions,
+ is_calibration=is_calibration,
+ )
+ self.is_prerequisite = is_prerequisite
+ self.component = component
+ self.defer_query_constraint = defer_query_constraint
+
+ component: str | None
+ """Component to add to `parent_dataset_type_name` to form the dataset type
+ name seen by this task.
+ """
+
+ is_prerequisite: bool
+ """Whether this dataset must be present in the data repository prior to
+ `QuantumGraph` generation.
+ """
+
+ defer_query_constraint: bool
+ """If `True`, by default do not include this dataset type's existence as a
+ constraint on the initial data ID query in QuantumGraph generation.
+ """
+
+ @property
+ def nodes(self) -> tuple[NodeKey, NodeKey]:
+ # Docstring inherited.
+ return (self.dataset_type_key, self.task_key)
+
+ @property
+ def dataset_type_name(self) -> str:
+ """Complete dataset type name, as seen by the task."""
+ if self.component is not None:
+ return f"{self.parent_dataset_type_name}.{self.component}"
+ return self.parent_dataset_type_name
+
+ def diff(self: ReadEdge, other: ReadEdge, connection_type: str = "connection") -> list[str]:
+ # Docstring inherited.
+ result = super().diff(other, connection_type)
+ if self.defer_query_constraint != other.defer_query_constraint:
+ result.append(
+ f"{connection_type.capitalize()} {self.connection_name!r} is marked as a deferred query "
+ f"constraint {'in A but not in B' if self.defer_query_constraint else 'in B but not in A'}."
+ )
+ return result
+
+ def adapt_dataset_type(self, dataset_type: DatasetType) -> DatasetType:
+ # Docstring inherited.
+ if self.component is not None:
+ assert (
+ self.storage_class_name == dataset_type.storageClass.allComponents()[self.component].name
+ ), "components with storage class overrides are not supported"
+ return dataset_type.makeComponentDatasetType(self.component)
+ if self.storage_class_name != dataset_type.storageClass_name:
+ return dataset_type.overrideStorageClass(self.storage_class_name)
+ return dataset_type
+
+ def adapt_dataset_ref(self, ref: DatasetRef) -> DatasetRef:
+ # Docstring inherited.
+ if self.component is not None:
+ assert (
+ self.storage_class_name == ref.datasetType.storageClass.allComponents()[self.component].name
+ ), "components with storage class overrides are not supported"
+ return ref.makeComponentRef(self.component)
+ if self.storage_class_name != ref.datasetType.storageClass_name:
+ return ref.overrideStorageClass(self.storage_class_name)
+ return ref
+
+ @classmethod
+ def _from_connection_map(
+ cls,
+ task_key: NodeKey,
+ connection_name: str,
+ connection_map: Mapping[str, BaseConnection],
+ is_prerequisite: bool = False,
+ ) -> ReadEdge:
+ """Construct a `ReadEdge` instance from a `.BaseConnection` object.
+
+ Parameters
+ ----------
+ task_key : `NodeKey`
+ Key for the associated task node or task init node.
+ connection_name : `str`
+ Internal name for the connection as seen by the task,.
+ connection_map : Mapping [ `str`, `.BaseConnection` ]
+ Mapping of post-configuration object to draw dataset type
+ information from, keyed by connection name.
+ is_prerequisite : `bool`, optional
+ Whether this dataset must be present in the data repository prior
+ to `QuantumGraph` generation.
+
+ Returns
+ -------
+ edge : `ReadEdge`
+ New edge instance.
+ """
+ connection = connection_map[connection_name]
+ parent_dataset_type_name, component = DatasetType.splitDatasetTypeName(connection.name)
+ return cls(
+ dataset_type_key=NodeKey(NodeType.DATASET_TYPE, parent_dataset_type_name),
+ task_key=task_key,
+ component=component,
+ storage_class_name=connection.storageClass,
+ # InitInput connections don't have .isCalibration.
+ is_calibration=getattr(connection, "isCalibration", False),
+ is_prerequisite=is_prerequisite,
+ connection_name=connection_name,
+ # InitInput connections don't have a .dimensions because they
+ # always have empty dimensions.
+ raw_dimensions=frozenset(getattr(connection, "dimensions", frozenset())),
+ # PrerequisiteInput and InitInput connections don't have a
+ # .eferQueryConstraints, because they never constrain the initial
+ # data ID query.
+ defer_query_constraint=getattr(connection, "deferQueryConstraint", False),
+ )
+
+ def _resolve_dataset_type(
+ self,
+ *,
+ current: DatasetType | None,
+ is_initial_query_constraint: bool,
+ is_prerequisite: bool | None,
+ universe: DimensionUniverse,
+ producer: str | None,
+ consumers: Sequence[str],
+ is_registered: bool,
+ ) -> tuple[DatasetType, bool, bool]:
+ """Participate in the construction of the `DatasetTypeNode` object
+ associated with this edge.
+
+ Parameters
+ ----------
+ current : `lsst.daf.butler.DatasetType` or `None`
+ The current graph-wide `DatasetType`, or `None`. This will always
+ be the registry's definition of the parent dataset type, if one
+ exists. If not, it will be the dataset type definition from the
+ task in the graph that writes it, if there is one. If there is no
+ such task, this will be `None`.
+ is_initial_query_constraint : `bool`
+ Whether this dataset type is currently marked as a constraint on
+ the initial data ID query in QuantumGraph generation.
+ is_prerequisite : `bool` | None`
+ Whether this dataset type is marked as a prerequisite input in all
+ edges processed so far. `None` if this is the first edge.
+ universe : `lsst.daf.butler.DimensionUniverse`
+ Object that holds all dimension definitions.
+ producer : `str` or `None`
+ The label of the task that produces this dataset type in the
+ pipeline, or `None` if it is an overall input.
+ consumers : `Sequence` [ `str` ]
+ Labels for other consuming tasks that have already participated in
+ this dataset type's resolution.
+ is_registered : `bool`
+ Whether a registration for this dataset type was found in the
+ data repository.
+
+ Returns
+ -------
+ dataset_type : `DatasetType`
+ The updated graph-wide dataset type. If ``current`` was provided,
+ this must be equal to it.
+ is_initial_query_constraint : `bool`
+ If `True`, this dataset type should be included as a constraint in
+ the initial data ID query during QuantumGraph generation; this
+ requires that ``is_initial_query_constraint`` also be `True` on
+ input.
+ is_prerequisite : `bool`
+ Whether this dataset type is marked as a prerequisite input in this
+ task and all other edges processed so far.
+
+ Raises
+ ------
+ MissingDatasetTypeError
+ Raised if ``current is None`` and this edge cannot define one on
+ its own.
+ IncompatibleDatasetTypeError
+ Raised if ``current is not None`` and this edge's definition is not
+ compatible with it.
+ ConnectionTypeConsistencyError
+ Raised if a prerequisite input for one task appears as a different
+ kind of connection in any other task.
+ """
+ if "skypix" in self.raw_dimensions:
+ if current is None:
+ raise MissingDatasetTypeError(
+ f"DatasetType '{self.dataset_type_name}' referenced by "
+ f"{self.task_label!r} uses 'skypix' as a dimension "
+ f"placeholder, but has not been registered with the data repository. "
+ f"Note that reference catalog names are now used as the dataset "
+ f"type name instead of 'ref_cat'."
+ )
+ rest1 = set(universe.extract(self.raw_dimensions - set(["skypix"])).names)
+ rest2 = set(dim.name for dim in current.dimensions if not isinstance(dim, SkyPixDimension))
+ if rest1 != rest2:
+ raise IncompatibleDatasetTypeError(
+ f"Non-skypix dimensions for dataset type {self.dataset_type_name} declared in "
+ f"connections ({rest1}) are inconsistent with those in "
+ f"registry's version of this dataset ({rest2})."
+ )
+ dimensions = current.dimensions
+ else:
+ dimensions = universe.extract(self.raw_dimensions)
+ is_initial_query_constraint = is_initial_query_constraint and not self.defer_query_constraint
+ if is_prerequisite is None:
+ is_prerequisite = self.is_prerequisite
+ elif is_prerequisite and not self.is_prerequisite:
+ raise ConnectionTypeConsistencyError(
+ f"Dataset type {self.parent_dataset_type_name!r} is a prerequisite input to {consumers}, "
+ f"but it is not a prerequisite to {self.task_label!r}."
+ )
+ elif not is_prerequisite and self.is_prerequisite:
+ if producer is not None:
+ raise ConnectionTypeConsistencyError(
+ f"Dataset type {self.parent_dataset_type_name!r} is a prerequisite input to "
+ f"{self.task_label}, but it is produced by {producer!r}."
+ )
+ else:
+ raise ConnectionTypeConsistencyError(
+ f"Dataset type {self.parent_dataset_type_name!r} is a prerequisite input to "
+ f"{self.task_label}, but it is a regular input to {consumers!r}."
+ )
+
+ def report_current_origin() -> str:
+ if is_registered:
+ return "data repository"
+ elif producer is not None:
+ return f"producing task {producer!r}"
+ else:
+ return f"consuming task(s) {consumers!r}"
+
+ if self.component is not None:
+ if current is None:
+ raise MissingDatasetTypeError(
+ f"Dataset type {self.parent_dataset_type_name!r} is not registered and not produced by "
+ f"this pipeline, but it used by task {self.task_label!r}, via component "
+ f"{self.component!r}. This pipeline cannot be resolved until the parent dataset type is "
+ "registered."
+ )
+ all_current_components = current.storageClass.allComponents()
+ if self.component not in all_current_components:
+ raise IncompatibleDatasetTypeError(
+ f"Dataset type {self.parent_dataset_type_name!r} has storage class "
+ f"{current.storageClass_name!r} (from {report_current_origin()}), "
+ f"which does not include component {self.component!r} "
+ f"as requested by task {self.task_label!r}."
+ )
+ if all_current_components[self.component].name != self.storage_class_name:
+ raise IncompatibleDatasetTypeError(
+ f"Dataset type '{self.parent_dataset_type_name}.{self.component}' has storage class "
+ f"{all_current_components[self.component].name!r} "
+ f"(from {report_current_origin()}), which does not match "
+ f"{self.storage_class_name!r}, as requested by task {self.task_label!r}. "
+ "Note that storage class conversions of components are not supported."
+ )
+ return current, is_initial_query_constraint, is_prerequisite
+ else:
+ dataset_type = DatasetType(
+ self.parent_dataset_type_name,
+ dimensions,
+ storageClass=self.storage_class_name,
+ isCalibration=self.is_calibration,
+ )
+ if current is not None:
+ if not is_registered and producer is None:
+ # Current definition comes from another consumer; we
+ # require the dataset types to be exactly equal (not just
+ # compatible), since neither connection should take
+ # precedence.
+ if dataset_type != current:
+ raise MissingDatasetTypeError(
+ f"Definitions differ for input dataset type {self.parent_dataset_type_name!r}; "
+ f"task {self.task_label!r} has {dataset_type}, but the definition "
+ f"from {report_current_origin()} is {current}. If the storage classes are "
+ "compatible but different, registering the dataset type in the data repository "
+ "in advance will avoid this error."
+ )
+ elif not dataset_type.is_compatible_with(current):
+ raise IncompatibleDatasetTypeError(
+ f"Incompatible definition for input dataset type {self.parent_dataset_type_name!r}; "
+ f"task {self.task_label!r} has {dataset_type}, but the definition "
+ f"from {report_current_origin()} is {current}."
+ )
+ return current, is_initial_query_constraint, is_prerequisite
+ else:
+ return dataset_type, is_initial_query_constraint, is_prerequisite
+
+ def _to_xgraph_state(self) -> dict[str, Any]:
+ # Docstring inherited.
+ result = super()._to_xgraph_state()
+ result["component"] = self.component
+ result["is_prerequisite"] = self.is_prerequisite
+ return result
+
+
+class WriteEdge(Edge):
+ """Representation of an output connection (including init-outputs) in a
+ pipeline graph.
+
+ Notes
+ -----
+ When included in an exported `networkx` graph (e.g.
+ `PipelineGraph.make_xgraph`), write edges set the following edge
+ attributes:
+
+ - ``parent_dataset_type_name``
+ - ``storage_class_name``
+ - ``is_init``
+
+ As with `WRiteEdge` instance attributes, these descriptions of dataset
+ types are those specific to a task, and may differ from the graph's
+ resolved dataset type or (if `PipelineGraph.resolve` has not been called)
+ there may not even be a consistent definition of the dataset type.
+ """
+
+ @property
+ def nodes(self) -> tuple[NodeKey, NodeKey]:
+ # Docstring inherited.
+ return (self.task_key, self.dataset_type_key)
+
+ def adapt_dataset_type(self, dataset_type: DatasetType) -> DatasetType:
+ # Docstring inherited.
+ if self.storage_class_name != dataset_type.storageClass_name:
+ return dataset_type.overrideStorageClass(self.storage_class_name)
+ return dataset_type
+
+ def adapt_dataset_ref(self, ref: DatasetRef) -> DatasetRef:
+ # Docstring inherited.
+ if self.storage_class_name != ref.datasetType.storageClass_name:
+ return ref.overrideStorageClass(self.storage_class_name)
+ return ref
+
+ @classmethod
+ def _from_connection_map(
+ cls,
+ task_key: NodeKey,
+ connection_name: str,
+ connection_map: Mapping[str, BaseConnection],
+ ) -> WriteEdge:
+ """Construct a `WriteEdge` instance from a `.BaseConnection` object.
+
+ Parameters
+ ----------
+ task_key : `NodeKey`
+ Key for the associated task node or task init node.
+ connection_name : `str`
+ Internal name for the connection as seen by the task,.
+ connection_map : Mapping [ `str`, `.BaseConnection` ]
+ Mapping of post-configuration object to draw dataset type
+ information from, keyed by connection name.
+
+ Returns
+ -------
+ edge : `WriteEdge`
+ New edge instance.
+ """
+ connection = connection_map[connection_name]
+ parent_dataset_type_name, component = DatasetType.splitDatasetTypeName(connection.name)
+ if component is not None:
+ raise ValueError(
+ f"Illegal output component dataset {connection.name!r} in task {task_key.name!r}."
+ )
+ return cls(
+ task_key=task_key,
+ dataset_type_key=NodeKey(NodeType.DATASET_TYPE, parent_dataset_type_name),
+ storage_class_name=connection.storageClass,
+ connection_name=connection_name,
+ # InitOutput connections don't have .isCalibration.
+ is_calibration=getattr(connection, "isCalibration", False),
+ # InitOutput connections don't have a .dimensions because they
+ # always have empty dimensions.
+ raw_dimensions=frozenset(getattr(connection, "dimensions", frozenset())),
+ )
+
+ def _resolve_dataset_type(self, current: DatasetType | None, universe: DimensionUniverse) -> DatasetType:
+ """Participate in the construction of the `DatasetTypeNode` object
+ associated with this edge.
+
+ Parameters
+ ----------
+ current : `lsst.daf.butler.DatasetType` or `None`
+ The current graph-wide `DatasetType`, or `None`. This will always
+ be the registry's definition of the parent dataset type, if one
+ exists.
+ universe : `lsst.daf.butler.DimensionUniverse`
+ Object that holds all dimension definitions.
+
+ Returns
+ -------
+ dataset_type : `DatasetType`
+ A dataset type compatible with this edge. If ``current`` was
+ provided, this must be equal to it.
+
+ Raises
+ ------
+ IncompatibleDatasetTypeError
+ Raised if ``current is not None`` and this edge's definition is not
+ compatible with it.
+ """
+ dimensions = universe.extract(self.raw_dimensions)
+ dataset_type = DatasetType(
+ self.parent_dataset_type_name,
+ dimensions,
+ storageClass=self.storage_class_name,
+ isCalibration=self.is_calibration,
+ )
+ if current is not None:
+ if not current.is_compatible_with(dataset_type):
+ raise IncompatibleDatasetTypeError(
+ f"Incompatible definition for output dataset type {self.parent_dataset_type_name!r}: "
+ f"task {self.task_label!r} has {current}, but data repository has {dataset_type}."
+ )
+ return current
+ else:
+ return dataset_type
diff --git a/python/lsst/pipe/base/pipeline_graph/_exceptions.py b/python/lsst/pipe/base/pipeline_graph/_exceptions.py
new file mode 100644
index 000000000..8ed6cd164
--- /dev/null
+++ b/python/lsst/pipe/base/pipeline_graph/_exceptions.py
@@ -0,0 +1,95 @@
+# This file is part of pipe_base.
+#
+# Developed for the LSST Data Management System.
+# This product includes software developed by the LSST Project
+# (http://www.lsst.org).
+# See the COPYRIGHT file at the top-level directory of this distribution
+# for details of code ownership.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see .
+from __future__ import annotations
+
+__all__ = (
+ "ConnectionTypeConsistencyError",
+ "DuplicateOutputError",
+ "IncompatibleDatasetTypeError",
+ "PipelineGraphExceptionSafetyError",
+ "PipelineDataCycleError",
+ "PipelineGraphError",
+ "PipelineGraphReadError",
+ "EdgesChangedError",
+ "UnresolvedGraphError",
+ "TaskNotImportedError",
+)
+
+
+class PipelineGraphError(RuntimeError):
+ """Base exception raised when there is a problem constructing or resolving
+ a pipeline graph.
+ """
+
+
+class DuplicateOutputError(PipelineGraphError):
+ """Exception raised when multiple tasks in one pipeline produce the same
+ output dataset type.
+ """
+
+
+class PipelineDataCycleError(PipelineGraphError):
+ """Exception raised when a pipeline graph contains a cycle."""
+
+
+class ConnectionTypeConsistencyError(PipelineGraphError):
+ """Exception raised when the tasks in a pipeline graph use different (and
+ incompatible) connection types for the same dataset type.
+ """
+
+
+class IncompatibleDatasetTypeError(PipelineGraphError):
+ """Exception raised when the tasks in a pipeline graph define dataset types
+ with the same name in incompatible ways, or when these are incompatible
+ with the data repository definition.
+ """
+
+
+class UnresolvedGraphError(PipelineGraphError):
+ """Exception raised when an operation requires dimensions or dataset types
+ to have been resolved, but they have not been.
+ """
+
+
+class PipelineGraphReadError(PipelineGraphError, IOError):
+ """Exception raised when a serialized PipelineGraph cannot be read."""
+
+
+class TaskNotImportedError(PipelineGraphError):
+ """Exception raised when accessing an attribute of a graph or graph node
+ that is not available unless the task class has been imported and
+ configured.
+ """
+
+
+class EdgesChangedError(PipelineGraphError):
+ """Exception raised when the edges in one version of a pipeline graph
+ are not consistent with those in another, but they were expected to be.
+ """
+
+
+class PipelineGraphExceptionSafetyError(PipelineGraphError):
+ """Exception raised when a PipelineGraph method could not provide strong
+ exception safety, and the graph may have been left in an inconsistent
+ state.
+
+ The originating exception is always chained when this exception is raised.
+ """
diff --git a/python/lsst/pipe/base/pipeline_graph/_mapping_views.py b/python/lsst/pipe/base/pipeline_graph/_mapping_views.py
new file mode 100644
index 000000000..6f12d42c3
--- /dev/null
+++ b/python/lsst/pipe/base/pipeline_graph/_mapping_views.py
@@ -0,0 +1,197 @@
+# This file is part of pipe_base.
+#
+# Developed for the LSST Data Management System.
+# This product includes software developed by the LSST Project
+# (http://www.lsst.org).
+# See the COPYRIGHT file at the top-level directory of this distribution
+# for details of code ownership.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see .
+from __future__ import annotations
+
+from collections.abc import Iterable, Iterator, Mapping
+from typing import Any, ClassVar, Sequence, TypeVar, cast, overload
+
+import networkx
+
+from ._dataset_types import DatasetTypeNode
+from ._exceptions import UnresolvedGraphError
+from ._nodes import NodeKey, NodeType
+from ._tasks import TaskInitNode, TaskNode
+
+_N = TypeVar("_N", covariant=True)
+_T = TypeVar("_T")
+
+
+class MappingView(Mapping[str, _N]):
+ """Base class for mapping views into nodes of certain types in a
+ `PipelineGraph`.
+
+
+ Parameters
+ ----------
+ parent_xgraph : `networkx.MultiDiGraph`
+ Backing networkx graph for the `PipelineGraph` instance.
+
+ Notes
+ -----
+ Instances should only be constructed by `PipelineGraph` and its helper
+ classes.
+
+ Iteration order is topologically sorted if and only if the backing
+ `PipelineGraph` has been sorted since its last modification.
+ """
+
+ def __init__(self, parent_xgraph: networkx.MultiDiGraph) -> None:
+ self._parent_xgraph = parent_xgraph
+ self._keys: list[str] | None = None
+
+ _NODE_TYPE: ClassVar[NodeType] # defined by derived classes
+
+ def __contains__(self, key: object) -> bool:
+ # The given key may not be a str, but if it isn't it'll just fail the
+ # check, which is what we want anyway.
+ return NodeKey(self._NODE_TYPE, cast(str, key)) in self._parent_xgraph
+
+ def __iter__(self) -> Iterator[str]:
+ if self._keys is None:
+ self._keys = self._make_keys(self._parent_xgraph)
+ return iter(self._keys)
+
+ def __getitem__(self, key: str) -> _N:
+ return self._parent_xgraph.nodes[NodeKey(self._NODE_TYPE, key)]["instance"]
+
+ def __len__(self) -> int:
+ if self._keys is None:
+ self._keys = self._make_keys(self._parent_xgraph)
+ return len(self._keys)
+
+ def __repr__(self) -> str:
+ return f"{type(self).__name__}({self!s})"
+
+ def __str__(self) -> str:
+ return f"{{{', '.join(iter(self))}}}"
+
+ def _reorder(self, parent_keys: Sequence[NodeKey]) -> None:
+ """Set this view's iteration order according to the given iterable of
+ parent keys.
+
+ Parameters
+ ----------
+ parent_keys : `~collections.abc.Sequence` [ `NodeKey` ]
+ Superset of the keys in this view, in the new order.
+ """
+ self._keys = self._make_keys(parent_keys)
+
+ def _reset(self) -> None:
+ """Reset all cached content.
+
+ This should be called by the parent graph after any changes that could
+ invalidate the view, causing it to be reconstructed when next
+ requested.
+ """
+ self._keys = None
+
+ def _make_keys(self, parent_keys: Iterable[NodeKey]) -> list[str]:
+ """Make a sequence of keys for this view from an iterable of parent
+ keys.
+
+ Parameters
+ ----------
+ parent_keys : `~collections.abc.Iterable` [ `NodeKey` ]
+ Superset of the keys in this view.
+ """
+ return [str(k) for k in parent_keys if k.node_type is self._NODE_TYPE]
+
+
+class TaskMappingView(MappingView[TaskNode]):
+ """A mapping view of the tasks in a `PipelineGraph`.
+
+ Notes
+ -----
+ Mapping keys are task labels and values are `TaskNode` instances.
+ Iteration order is topological if and only if the `PipelineGraph` has been
+ sorted since its last modification.
+ """
+
+ _NODE_TYPE = NodeType.TASK
+
+
+class TaskInitMappingView(MappingView[TaskInitNode]):
+ """A mapping view of the nodes representing task initialization in a
+ `PipelineGraph`.
+
+ Notes
+ -----
+ Mapping keys are task labels and values are `TaskInitNode` instances.
+ Iteration order is topological if and only if the `PipelineGraph` has been
+ sorted since its last modification.
+ """
+
+ _NODE_TYPE = NodeType.TASK_INIT
+
+
+class DatasetTypeMappingView(MappingView[DatasetTypeNode]):
+ """A mapping view of the nodes representing task initialization in a
+ `PipelineGraph`.
+
+ Notes
+ -----
+ Mapping keys are parent dataset type names and values are `DatasetTypeNode`
+ instances, but values are only available for nodes that have been resolved
+ (see `PipelineGraph.resolve`). Attempting to access an unresolved value
+ will result in `UnresolvedGraphError` being raised. Keys for unresolved
+ nodes are always present and iterable.
+
+ Iteration order is topological if and only if the `PipelineGraph` has been
+ sorted since its last modification.
+ """
+
+ _NODE_TYPE = NodeType.DATASET_TYPE
+
+ def __getitem__(self, key: str) -> DatasetTypeNode:
+ if (result := super().__getitem__(key)) is None:
+ raise UnresolvedGraphError(f"Node for dataset type {key!r} has not been resolved.")
+ return result
+
+ def is_resolved(self, key: str) -> bool:
+ """Test whether a node has been resolved."""
+ return super().__getitem__(key) is not None
+
+ @overload
+ def get_if_resolved(self, key: str) -> DatasetTypeNode | None:
+ ... # pragma: nocover
+
+ @overload
+ def get_if_resolved(self, key: str, default: _T) -> DatasetTypeNode | _T:
+ ... # pragma: nocover
+
+ def get_if_resolved(self, key: str, default: Any = None) -> DatasetTypeNode | Any:
+ """Get a node or return a default if it has not been resolved.
+
+ Parameters
+ ----------
+ key : `str`
+ Parent dataset type name.
+ default
+ Value to return if this dataset type has not been resolved.
+
+ Raises
+ ------
+ KeyError
+ Raised if the node is not present in the graph at all.
+ """
+ if (result := super().__getitem__(key)) is None:
+ return default # type: ignore
+ return result
diff --git a/python/lsst/pipe/base/pipeline_graph/_nodes.py b/python/lsst/pipe/base/pipeline_graph/_nodes.py
new file mode 100644
index 000000000..b9ec00fca
--- /dev/null
+++ b/python/lsst/pipe/base/pipeline_graph/_nodes.py
@@ -0,0 +1,85 @@
+# This file is part of pipe_base.
+#
+# Developed for the LSST Data Management System.
+# This product includes software developed by the LSST Project
+# (http://www.lsst.org).
+# See the COPYRIGHT file at the top-level directory of this distribution
+# for details of code ownership.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see .
+from __future__ import annotations
+
+__all__ = (
+ "NodeKey",
+ "NodeType",
+)
+
+import enum
+from typing import NamedTuple
+
+
+class NodeType(enum.Enum):
+ """Enumeration of the types of nodes in a PipelineGraph."""
+
+ DATASET_TYPE = 0
+ TASK_INIT = 1
+ TASK = 2
+
+ @property
+ def bipartite(self) -> int:
+ """The integer used as the "bipartite" key in networkx exports of a
+ `PipelineGraph`.
+
+ This key is used by the `networkx.algorithms.bipartite` module.
+ """
+ return int(self is not NodeType.DATASET_TYPE)
+
+ def __lt__(self, other: NodeType) -> bool:
+ # We define __lt__ only to be able to provide deterministic tiebreaking
+ # on top of topological ordering of `PipelineGraph`` and views thereof.
+ return self.value < other.value
+
+
+class NodeKey(NamedTuple):
+ """A special key type for nodes in networkx graphs.
+
+ Notes
+ -----
+ Using a tuple for the key allows tasks labels and dataset type names with
+ the same string value to coexist in the graph. These only rarely appear in
+ `PipelineGraph` public interfaces; when the node type is implicit, bare
+ `str` task labels or dataset type names are used instead.
+
+ NodeKey objects stringify to just their name, which is used both as a way
+ to convert to the `str` objects used in the main public interface and as an
+ easy way to usefully stringify containers returned directly by networkx
+ algorithms (especially in error messages). Note that this requires `repr`,
+ not just `str`, because Python builtin containers always use `repr` on
+ their items, even in their implementations for `str`.
+ """
+
+ node_type: NodeType
+ """Node type enum for this key."""
+
+ name: str
+ """Task label or dataset type name.
+
+ This is always the parent dataset type name for component dataset types.
+ """
+
+ def __repr__(self) -> str:
+ return self.name
+
+ def __str__(self) -> str:
+ return self.name
diff --git a/python/lsst/pipe/base/pipeline_graph/_pipeline_graph.py b/python/lsst/pipe/base/pipeline_graph/_pipeline_graph.py
new file mode 100644
index 000000000..5e8ea5c18
--- /dev/null
+++ b/python/lsst/pipe/base/pipeline_graph/_pipeline_graph.py
@@ -0,0 +1,1389 @@
+# This file is part of pipe_base.
+#
+# Developed for the LSST Data Management System.
+# This product includes software developed by the LSST Project
+# (http://www.lsst.org).
+# See the COPYRIGHT file at the top-level directory of this distribution
+# for details of code ownership.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see .
+from __future__ import annotations
+
+__all__ = ("PipelineGraph",)
+
+import gzip
+import itertools
+import json
+from collections.abc import Iterable, Iterator, Mapping, Sequence
+from typing import TYPE_CHECKING, Any, BinaryIO, Literal, TypeVar, cast
+
+import networkx
+import networkx.algorithms.bipartite
+import networkx.algorithms.dag
+from lsst.daf.butler import DataCoordinate, DataId, DimensionGraph, DimensionUniverse, Registry
+from lsst.resources import ResourcePath, ResourcePathExpression
+
+from ._dataset_types import DatasetTypeNode
+from ._edges import Edge, ReadEdge, WriteEdge
+from ._exceptions import (
+ EdgesChangedError,
+ PipelineDataCycleError,
+ PipelineGraphError,
+ PipelineGraphExceptionSafetyError,
+ UnresolvedGraphError,
+)
+from ._mapping_views import DatasetTypeMappingView, TaskMappingView
+from ._nodes import NodeKey, NodeType
+from ._task_subsets import TaskSubset
+from ._tasks import TaskInitNode, TaskNode, _TaskNodeImportedData
+
+if TYPE_CHECKING:
+ from ..config import PipelineTaskConfig
+ from ..connections import PipelineTaskConnections
+ from ..pipeline import TaskDef
+ from ..pipelineTask import PipelineTask
+
+
+_G = TypeVar("_G", bound=networkx.DiGraph | networkx.MultiDiGraph)
+
+
+class PipelineGraph:
+ """A graph representation of fully-configured pipeline.
+
+ `PipelineGraph` instances are typically constructed by calling
+ `.Pipeline.to_graph`, but in rare cases constructing and then populating
+ an empty one may be preferable.
+
+ Parameters
+ ----------
+ description : `str`, optional
+ String description for this pipeline.
+ universe : `lsst.daf.butler.DimensionUniverse`, optional
+ Definitions for all butler dimensions. If not provided, some
+ attributes will not be available until `resolve` is called.
+ data_id : `lsst.daf.butler.DataCoordinate` or other data ID, optional
+ Data ID that represents a constraint on all quanta generated by this
+ pipeline. This typically just holds the instrument constraint included
+ in the pipeline definition, if there was one.
+ """
+
+ def __init__(
+ self,
+ *,
+ description: str = "",
+ universe: DimensionUniverse | None = None,
+ data_id: DataId | None = None,
+ ) -> None:
+ self._init_from_args(
+ xgraph=None,
+ sorted_keys=None,
+ task_subsets=None,
+ description=description,
+ universe=universe,
+ data_id=data_id,
+ )
+
+ def _init_from_args(
+ self,
+ xgraph: networkx.MultiDiGraph | None,
+ sorted_keys: Sequence[NodeKey] | None,
+ task_subsets: dict[str, TaskSubset] | None,
+ description: str,
+ universe: DimensionUniverse | None,
+ data_id: DataId | None,
+ ) -> None:
+ """Initialize the graph with possibly-nontrivial arguments.
+
+ Parameters
+ ----------
+ xgraph : `networkx.MultiDiGraph` or `None`
+ The backing networkx graph, or `None` to create an empty one.
+ This graph has `NodeKey` instances for nodes and the same structure
+ as the graph exported by `make_xgraph`, but its nodes and edges
+ have a single ``instance`` attribute that holds a `TaskNode`,
+ `TaskInitNode`, `DatasetTypeNode` (or `None`), `ReadEdge`, or
+ `WriteEdge` instance.
+ sorted_keys : `Sequence` [ `NodeKey` ] or `None`
+ Topologically sorted sequence of node keys, or `None` if the graph
+ is not sorted.
+ task_subsets : `dict` [ `str`, `TaskSubset` ]
+ Labeled subsets of tasks. Values must be constructed with
+ ``xgraph`` as their parent graph.
+ description : `str`
+ String description for this pipeline.
+ universe : `lsst.daf.butler.DimensionUniverse` or `None`
+ Definitions of all dimensions.
+ data_id : `lsst.daf.butler.DataCoordinate` or other data ID mapping.
+ Data ID that represents a constraint on all quanta generated from
+ this pipeline.
+
+ Notes
+ -----
+ Only empty `PipelineGraph` instances should be constructed directly by
+ users, which sets the signature of ``__init__`` itself, but methods on
+ `PipelineGraph` and its helper classes need to be able to create them
+ with state. Those methods can call this after calling ``__new__``
+ manually, skipping ``__init__``.
+
+ `PipelineGraph` mutator methods provide strong exception safety (the
+ graph is left unchanged when an exception is raised and caught) unless
+ the exception raised is `PipelineGraphExceptionSafetyError`.
+ """
+ self._xgraph = xgraph if xgraph is not None else networkx.MultiDiGraph()
+ self._sorted_keys: Sequence[NodeKey] | None = None
+ self._task_subsets = task_subsets if task_subsets is not None else {}
+ self._description = description
+ self._tasks = TaskMappingView(self._xgraph)
+ self._dataset_types = DatasetTypeMappingView(self._xgraph)
+ self._raw_data_id: dict[str, Any]
+ if isinstance(data_id, DataCoordinate):
+ universe = data_id.universe
+ self._raw_data_id = data_id.byName()
+ elif data_id is None:
+ self._raw_data_id = {}
+ else:
+ self._raw_data_id = dict(data_id)
+ self._universe = universe
+ if sorted_keys is not None:
+ self._reorder(sorted_keys)
+
+ def __repr__(self) -> str:
+ return f"{type(self).__name__}({self.description!r}, tasks={self.tasks!s})"
+
+ @property
+ def description(self) -> str:
+ """String description for this pipeline."""
+ return self._description
+
+ @description.setter
+ def description(self, value: str) -> None:
+ # Docstring in setter.
+ self._description = value
+
+ @property
+ def universe(self) -> DimensionUniverse | None:
+ """Definitions for all butler dimensions."""
+ return self._universe
+
+ @property
+ def data_id(self) -> DataCoordinate:
+ """Data ID that represents a constraint on all quanta generated from
+ this pipeline.
+
+ This is may not be available unless `universe` is not `None`.
+ """
+ return DataCoordinate.standardize(self._raw_data_id, universe=self.universe)
+
+ @property
+ def tasks(self) -> TaskMappingView:
+ """A mapping view of the tasks in the graph.
+
+ This mapping has `str` task label keys and `TaskNode` values. Iteration
+ is topologically and deterministically ordered if and only if `sort`
+ has been called since the last modification to the graph.
+ """
+ return self._tasks
+
+ @property
+ def dataset_types(self) -> DatasetTypeMappingView:
+ """A mapping view of the dataset types in the graph.
+
+ This mapping has `str` parent dataset type name keys, but only provides
+ access to its `DatasetTypeNode` values if `resolve` has been called
+ since the last modification involving a task that uses a dataset type.
+ See `DatasetTypeMappingView` for details.
+ """
+ return self._dataset_types
+
+ @property
+ def task_subsets(self) -> Mapping[str, TaskSubset]:
+ """A mapping of all labeled subsets of tasks.
+
+ Keys are subset labels, values are sets of task labels. See
+ `TaskSubset` for more information.
+
+ Use `add_task_subset` to add a new subset. The subsets themselves may
+ be modified in-place.
+ """
+ return self._task_subsets
+
+ def iter_edges(self, init: bool = False) -> Iterator[Edge]:
+ """Iterate over edges in the graph.
+
+ Parameters
+ ----------
+ init : `bool`, optional
+ If `True` (`False` is default) iterate over the edges between task
+ initialization node and init input/output dataset types, instead of
+ the runtime task nodes and regular input/output/prerequisite
+ dataset types.
+
+ Returns
+ -------
+ edges : `~collections.abc.Iterator` [ `Edge` ]
+ A lazy iterator over `Edge` (`WriteEdge` or `ReadEdge`) instances.
+
+ Notes
+ -----
+ This method always returns _either_ init edges or runtime edges, never
+ both. The full (internal) graph that contains both also includes a
+ special edge that connects each task init node to its runtime node;
+ that is also never returned by this method, since it is never a part of
+ the init-only or runtime-only subgraphs.
+ """
+ edge: Edge
+ for _, _, edge in self._xgraph.edges(data="instance"):
+ if edge is not None and edge.is_init == init:
+ yield edge
+
+ def iter_nodes(
+ self,
+ ) -> Iterator[
+ tuple[Literal[NodeType.TASK_INIT], str, TaskInitNode]
+ | tuple[Literal[NodeType.TASK], str, TaskInitNode]
+ | tuple[Literal[NodeType.DATASET_TYPE], str, DatasetTypeNode | None]
+ ]:
+ """Iterate over nodes in the graph.
+
+ Returns
+ -------
+ nodes : `~collections.abc.Iterator` [ `tuple` ]
+ A lazy iterator over all of the nodes in the graph. Each yielded
+ element is a tuple of:
+
+ - the node type enum value (`NodeType`);
+ - the string name for the node (task label or parent dataset type
+ name);
+ - the node value (`TaskNode`, `TaskInitNode`, `DatasetTypeNode`,
+ or `None` for dataset type nodes that have not been resolved).
+ """
+ key: NodeKey
+ if self._sorted_keys is not None:
+ for key in self._sorted_keys:
+ yield key.node_type, key.name, self._xgraph.nodes[key]["instance"] # type: ignore
+ else:
+ for key, node in self._xgraph.nodes(data="instance"):
+ yield key.node_type, key.name, node # type: ignore
+
+ def iter_overall_inputs(self) -> Iterator[tuple[str, DatasetTypeNode | None]]:
+ """Iterate over all of the dataset types that are consumed but not
+ produced by the graph.
+
+ Returns
+ -------
+ dataset_types : `~collections.abc.Iterator` [ `tuple` ]
+ A lazy iterator over the overall-input dataset types (including
+ overall init inputs and prerequisites). Each yielded element is a
+ tuple of:
+
+ - the parent dataset type name;
+ - the resolved `DatasetTypeNode`, or `None` if the dataset type has
+ - not been resolved.
+ """
+ for generation in networkx.algorithms.dag.topological_generations(self._xgraph):
+ key: NodeKey
+ for key in generation:
+ # While we expect all tasks to have at least one input and
+ # hence never appear in the first topological generation, that
+ # is not true of task init nodes.
+ if key.node_type is NodeType.DATASET_TYPE:
+ yield key.name, self._xgraph.nodes[key]["instance"]
+ return
+
+ def make_xgraph(self) -> networkx.MultiDiGraph:
+ """Export a networkx representation of the full pipeline graph,
+ including both init and runtime edges.
+
+ Returns
+ -------
+ xgraph : `networkx.MultiDiGraph`
+ Directed acyclic graph with parallel edges.
+
+ Notes
+ -----
+ The returned graph uses `NodeKey` instances for nodes. Parallel edges
+ represent the same dataset type appearing in multiple connections for
+ the same task, and are hence rare. The connection name is used as the
+ edge key to disambiguate those parallel edges.
+
+ Almost all edges connect dataset type nodes to task or task init nodes
+ or vice versa, but there is also a special edge that connects each task
+ init node to its runtime node. The existence of these nodes makes the
+ graph not quite bipartite, unless its init-only and runtime-only
+ subgraphs.
+
+ See `TaskNode`, `TaskInitNode`, `DatasetTypeNode`, `ReadEdge`, and
+ `WriteEdge` for the descriptive node and edge attributes added.
+ """
+ return self._transform_xgraph_state(self._xgraph.copy(), skip_edges=False)
+
+ def make_bipartite_xgraph(self, init: bool = False) -> networkx.MultiDiGraph:
+ """Return a bipartite networkx representation of just the runtime or
+ init-time pipeline graph.
+
+ Parameters
+ ----------
+ init : `bool`, optional
+ If `True` (`False` is default) return the graph of task
+ initialization nodes and init input/output dataset types, instead
+ of the graph of runtime task nodes and regular
+ input/output/prerequisite dataset types.
+
+ Returns
+ -------
+ xgraph : `networkx.MultiDiGraph`
+ Directed acyclic graph with parallel edges.
+
+ Notes
+ -----
+ The returned graph uses `NodeKey` instances for nodes. Parallel edges
+ represent the same dataset type appearing in multiple connections for
+ the same task, and are hence rare. The connection name is used as the
+ edge key to disambiguate those parallel edges.
+
+ This graph is bipartite because each dataset type node only has edges
+ that connect it to a task [init] node, and vice versa.
+
+ See `TaskNode`, `TaskInitNode`, `DatasetTypeNode`, `ReadEdge`, and
+ `WriteEdge` for the descriptive node and edge attributes added.
+ """
+ return self._transform_xgraph_state(
+ self._make_bipartite_xgraph_internal(init).copy(), skip_edges=False
+ )
+
+ def make_task_xgraph(self, init: bool = False) -> networkx.DiGraph:
+ """Return a networkx representation of just the tasks in the pipeline.
+
+ Parameters
+ ----------
+ init : `bool`, optional
+ If `True` (`False` is default) return the graph of task
+ initialization nodes, instead of the graph of runtime task nodes.
+
+ Returns
+ -------
+ xgraph : `networkx.DiGraph`
+ Directed acyclic graph with no parallel edges.
+
+ Notes
+ -----
+ The returned graph uses `NodeKey` instances for nodes. The dataset
+ types that link these tasks are not represented at all; edges have no
+ attributes, and there are no parallel edges.
+
+ See `TaskNode` and `TaskInitNode` for the descriptive node and
+ attributes added.
+ """
+ bipartite_xgraph = self._make_bipartite_xgraph_internal(init)
+ task_keys = [
+ key
+ for key, bipartite in bipartite_xgraph.nodes(data="bipartite")
+ if bipartite == NodeType.TASK.bipartite
+ ]
+ return self._transform_xgraph_state(
+ networkx.algorithms.bipartite.projected_graph(networkx.DiGraph(bipartite_xgraph), task_keys),
+ skip_edges=True,
+ )
+
+ def make_dataset_type_xgraph(self, init: bool = False) -> networkx.DiGraph:
+ """Return a networkx representation of just the dataset types in the
+ pipeline.
+
+ Parameters
+ ----------
+ init : `bool`, optional
+ If `True` (`False` is default) return the graph of init input and
+ output dataset types, instead of the graph of runtime (input,
+ output, prerequisite input) dataset types.
+
+ Returns
+ -------
+ xgraph : `networkx.DiGraph`
+ Directed acyclic graph with no parallel edges.
+
+ Notes
+ -----
+ The returned graph uses `NodeKey` instances for nodes. The tasks that
+ link these tasks are not represented at all; edges have no attributes,
+ and there are no parallel edges.
+
+ See `DatasetTypeNode` for the descriptive node and attributes added.
+ """
+ bipartite_xgraph = self._make_bipartite_xgraph_internal(init)
+ dataset_type_keys = [
+ key
+ for key, bipartite in bipartite_xgraph.nodes(data="bipartite")
+ if bipartite == NodeType.DATASET_TYPE.bipartite
+ ]
+ return self._transform_xgraph_state(
+ networkx.algorithms.bipartite.projected_graph(
+ networkx.DiGraph(bipartite_xgraph), dataset_type_keys
+ ),
+ skip_edges=True,
+ )
+
+ def _make_bipartite_xgraph_internal(self, init: bool) -> networkx.MultiDiGraph:
+ """Make a bipartite init-only or runtime-only internal subgraph.
+
+ See `make_bipartite_xgraph` for parameters and return values.
+
+ Notes
+ -----
+ This method returns a view of the `PipelineGraph` object's internal
+ backing graph, and hence should only be called in methods that copy the
+ result either explicitly or by running a copying algorithm before
+ returning it to the user.
+ """
+ return self._xgraph.edge_subgraph([edge.key for edge in self.iter_edges(init)])
+
+ def _transform_xgraph_state(self, xgraph: _G, skip_edges: bool) -> _G:
+ """Transform networkx graph attributes in-place from the internal
+ "instance" attributes to the documented exported attributes.
+
+ Parameters
+ ----------
+ xgraph : `networkx.DiGraph` or `networkx.MultiDiGraph`
+ Graph whose state should be transformed.
+ skip_edges : `bool`
+ If `True`, do not transform edge state.
+
+ Returns
+ -------
+ xgraph : `networkx.DiGraph` or `networkx.MultiDiGraph`
+ The same object passed in, after modification.
+
+ Notes
+ -----
+ This should be called after making a copy of the internal graph but
+ before any projection down to just task or dataset type nodes, since
+ it assumes stateful edges.
+ """
+ state: dict[str, Any]
+ for state in xgraph.nodes.values():
+ node_value: TaskInitNode | TaskNode | DatasetTypeNode | None = state.pop("instance")
+ if node_value is not None:
+ state.update(node_value._to_xgraph_state())
+ if not skip_edges:
+ for _, _, state in xgraph.edges(data=True):
+ edge: Edge | None = state.pop("instance", None)
+ if edge is not None:
+ state.update(edge._to_xgraph_state())
+ return xgraph
+
+ def group_by_dimensions(
+ self, prerequisites: bool = False
+ ) -> dict[DimensionGraph, tuple[dict[str, TaskNode], dict[str, DatasetTypeNode]]]:
+ """Group this graph's tasks and dataset types by their dimensions.
+
+ Parameters
+ ----------
+ prerequisites : `bool`, optional
+ If `True`, include prerequisite dataset types as well as regular
+ input and output datasets (including intermediates).
+
+ Returns
+ -------
+ groups : `dict` [ `DimensionGraph`, `tuple` ]
+ A dictionary of groups keyed by `DimensionGraph`, in which each
+ value is a tuple of:
+
+ - a `dict` of `TaskNode` instances, keyed by task label
+ - a `dict` of `DatasetTypeNode` instances, keyed by
+ dataset type name.
+
+ that have those dimensions.
+
+ Notes
+ -----
+ Init inputs and outputs are always included, but always have empty
+ dimensions and are hence are all grouped together.
+ """
+ result: dict[DimensionGraph, tuple[dict[str, TaskNode], dict[str, DatasetTypeNode]]] = {}
+ next_new_value: tuple[dict[str, TaskNode], dict[str, DatasetTypeNode]] = ({}, {})
+ for task_label, task_node in self.tasks.items():
+ if task_node.dimensions is None:
+ raise UnresolvedGraphError(f"Task with label {task_label!r} has not been resolved.")
+ if (group := result.setdefault(task_node.dimensions, next_new_value)) is next_new_value:
+ next_new_value = ({}, {}) # make new lists for next time
+ group[0][task_node.label] = task_node
+ for dataset_type_name, dataset_type_node in self.dataset_types.items():
+ if dataset_type_node is None:
+ raise UnresolvedGraphError(f"Dataset type {dataset_type_name!r} has not been resolved.")
+ if not dataset_type_node.is_prerequisite or prerequisites:
+ if (
+ group := result.setdefault(dataset_type_node.dataset_type.dimensions, next_new_value)
+ ) is next_new_value:
+ next_new_value = ({}, {}) # make new lists for next time
+ group[1][dataset_type_node.name] = dataset_type_node
+ return result
+
+ @property
+ def is_sorted(self) -> bool:
+ """Whether this graph's tasks and dataset types are topologically
+ sorted with the exact same deterministic tiebreakers that `sort` would
+ apply.
+
+ This may perform (and then discard) a full sort if `has_been_sorted` is
+ `False`. If the goal is to obtain a sorted graph, it is better to just
+ call `sort` without guarding that with an ``if not graph.is_sorted``
+ check.
+ """
+ if self._sorted_keys is not None:
+ return True
+ return all(
+ sorted == unsorted
+ for sorted, unsorted in zip(networkx.lexicographical_topological_sort(self._xgraph), self._xgraph)
+ )
+
+ @property
+ def has_been_sorted(self) -> bool:
+ """Whether this graph's tasks and dataset types have been
+ topologically sorted (with unspecified but deterministic tiebreakers)
+ since the last modification to the graph.
+
+ This may return `False` if the graph *happens* to be sorted but `sort`
+ was never called, but it is potentially much faster than `is_sorted`,
+ which may attempt (and then discard) a full sort if `has_been_sorted`
+ is `False`.
+ """
+ return self._sorted_keys is not None
+
+ def sort(self) -> None:
+ """Sort this graph's nodes topologically with deterministic (but
+ unspecified) tiebreakers.
+
+ This does nothing if the graph is already known to be sorted.
+ """
+ if self._sorted_keys is None:
+ try:
+ sorted_keys: Sequence[NodeKey] = list(networkx.lexicographical_topological_sort(self._xgraph))
+ except networkx.NetworkXUnfeasible as err: # pragma: no cover
+ # Should't be possible to get here, because we check for cycles
+ # when adding tasks, but we guard against it anyway.
+ cycle = networkx.find_cycle(self._xgraph)
+ raise PipelineDataCycleError(
+ f"Cycle detected while attempting to sort graph: {cycle}."
+ ) from err
+ self._reorder(sorted_keys)
+
+ def producer_of(self, dataset_type_name: str) -> WriteEdge | None:
+ """Return the `WriteEdge` that links the producing task to the named
+ dataset type.
+
+ Parameters
+ ----------
+ dataset_type_name : `str`
+ Dataset type name. Must not be a component.
+
+ Returns
+ -------
+ edge : `WriteEdge` or `None`
+ Producing edge or `None` if there isn't one in this graph.
+ """
+ for _, _, edge in self._xgraph.in_edges(
+ NodeKey(NodeType.DATASET_TYPE, dataset_type_name), data="instance"
+ ):
+ return edge
+ return None
+
+ def consumers_of(self, dataset_type_name: str) -> list[ReadEdge]:
+ """Return the `ReadEdge` objects that link the named dataset type to
+ the tasks that consume it.
+
+ Parameters
+ ----------
+ dataset_type_name : `str`
+ Dataset type name. Must not be a component.
+
+ Returns
+ -------
+ edges : `list` [ `ReadEdge` ]
+ Edges that connect this dataset type to the tasks that consume it.
+ """
+ return [
+ edge
+ for _, _, edge in self._xgraph.out_edges(
+ NodeKey(NodeType.DATASET_TYPE, dataset_type_name), data="instance"
+ )
+ ]
+
+ def add_task(
+ self,
+ label: str,
+ task_class: type[PipelineTask],
+ config: PipelineTaskConfig,
+ connections: PipelineTaskConnections | None = None,
+ ) -> TaskNode:
+ """Add a new task to the graph.
+
+ Parameters
+ ----------
+ label : `str`
+ Label for the task in the pipeline.
+ task_class : `type` [ `PipelineTask` ]
+ Class object for the task.
+ config : `PipelineTaskConfig`
+ Configuration for the task.
+ connections : `PipelineTaskConnections`, optional
+ Object that describes the dataset types used by the task. If not
+ provided, one will be constructed from the given configuration. If
+ provided, it is assumed that ``config`` has already been validated
+ and frozen.
+
+ Returns
+ -------
+ node : `TaskNode`
+ The new task node added to the graph.
+
+ Raises
+ ------
+ ValueError
+ Raised if configuration validation failed when constructing
+ ``connections``.
+ PipelineDataCycleError
+ Raised if the graph is cyclic after this addition.
+ RuntimeError
+ Raised if an unexpected exception (which will be chained) occurred
+ at a stage that may have left the graph in an inconsistent state.
+ Other exceptions should leave the graph unchanged.
+
+ Notes
+ -----
+ Checks for dataset type consistency and multiple producers do not occur
+ until `resolve` is called, since the resolution depends on both the
+ state of the data repository and all contributing tasks.
+
+ Adding new tasks removes any existing resolutions of all dataset types
+ it references and marks the graph as unsorted. It is most effiecient
+ to add all tasks up front and only then resolve and/or sort the graph.
+ """
+ key = NodeKey(NodeType.TASK, label)
+ init_key = NodeKey(NodeType.TASK_INIT, label)
+ task_node = TaskNode._from_imported_data(
+ key,
+ init_key,
+ _TaskNodeImportedData.configure(label, task_class, config, connections),
+ universe=self.universe,
+ )
+ self.add_task_nodes([task_node])
+ return task_node
+
+ def add_task_nodes(self, nodes: Iterable[TaskNode]) -> None:
+ """Add one or more existing task nodes to the graph.
+
+ Parameters
+ ----------
+ nodes : `~collections.abc.Iterable` [ `TaskNode` ]
+ Iterable of task nodes to add. If any tasks have resolved
+ dimensions, they must have the same dimension universe as the rest
+ of the graph.
+
+ Raises
+ ------
+ PipelineDataCycleError
+ Raised if the graph is cyclic after this addition.
+
+ Notes
+ -----
+ Checks for dataset type consistency and multiple producers do not occur
+ until `resolve` is called, since the resolution depends on both the
+ state of the data repository and all contributing tasks.
+
+ Adding new tasks removes any existing resolutions of all dataset types
+ it references and marks the graph as unsorted. It is most effiecient
+ to add all tasks up front and only then resolve and/or sort the graph.
+ """
+ node_data: list[tuple[NodeKey, dict[str, Any]]] = []
+ edge_data: list[tuple[NodeKey, NodeKey, str, dict[str, Any]]] = []
+ for task_node in nodes:
+ task_node = task_node._resolved(self._universe)
+ node_data.append(
+ (task_node.key, {"instance": task_node, "bipartite": task_node.key.node_type.bipartite})
+ )
+ node_data.append(
+ (
+ task_node.init.key,
+ {"instance": task_node.init, "bipartite": task_node.init.key.node_type.bipartite},
+ )
+ )
+ # Convert the edge objects attached to the task node to networkx.
+ for read_edge in task_node.init.iter_all_inputs():
+ self._append_graph_data_from_edge(node_data, edge_data, read_edge)
+ for write_edge in task_node.init.iter_all_outputs():
+ self._append_graph_data_from_edge(node_data, edge_data, write_edge)
+ for read_edge in task_node.iter_all_inputs():
+ self._append_graph_data_from_edge(node_data, edge_data, read_edge)
+ for write_edge in task_node.iter_all_outputs():
+ self._append_graph_data_from_edge(node_data, edge_data, write_edge)
+ # Add a special edge (with no Edge instance) that connects the
+ # TaskInitNode to the runtime TaskNode.
+ edge_data.append((task_node.init.key, task_node.key, Edge.INIT_TO_TASK_NAME, {"instance": None}))
+ if not node_data and not edge_data:
+ return
+ # Checks and preparation complete; time to start the actual
+ # modification, during which it's hard to provide strong exception
+ # safety. Start by resetting the sort ordering, if there is one.
+ self._reset()
+ try:
+ self._xgraph.add_nodes_from(node_data)
+ self._xgraph.add_edges_from(edge_data)
+ if not networkx.algorithms.dag.is_directed_acyclic_graph(self._xgraph):
+ cycle = networkx.find_cycle(self._xgraph)
+ raise PipelineDataCycleError(f"Cycle detected while adding tasks: {cycle}.")
+ except Exception:
+ # First try to roll back our changes.
+ try:
+ self._xgraph.remove_edges_from(edge_data)
+ self._xgraph.remove_nodes_from(key for key, _ in node_data)
+ except Exception as err: # pragma: no cover
+ # There's no known way to get here, but we want to make it
+ # clear it's a big problem if we do.
+ raise PipelineGraphExceptionSafetyError(
+ "Error while attempting to revert PipelineGraph modification has left the graph in "
+ "an inconsistent state."
+ ) from err
+ # Successfully rolled back; raise the original exception.
+ raise
+
+ def reconfigure_tasks(
+ self,
+ *args: tuple[str, PipelineTaskConfig],
+ check_edges_unchanged: bool = False,
+ assume_edges_unchanged: bool = False,
+ **kwargs: PipelineTaskConfig,
+ ) -> None:
+ """Update the configuration for one or more tasks.
+
+ Parameters
+ ----------
+ *args : `tuple` [ `str`, `.PipelineTaskConfig` ]
+ Positional arguments are each a 2-tuple of task label and new
+ config object. Note that the same arguments may also be passed as
+ ``**kwargs``, which is usually more readable, but task labels in
+ ``*args`` are not required to be valid Python identifiers.
+ check_edges_unchanged : `bool`, optional
+ If `True`, require the edges (connections) of the modified tasks to
+ remain unchanged after the configuration updates, and verify that
+ this is the case.
+ assume_edges_unchanged : `bool`, optional
+ If `True`, the caller declares that the edges (connections) of the
+ modified tasks will remain unchanged after the configuration
+ updates, and that it is unnecessary to check this.
+ **kwargs : `.PipelineTaskConfig`
+ New config objects or overrides to apply to copies of the current
+ config objects, with task labels as the keywords.
+
+ Raises
+ ------
+ ValueError
+ Raised if ``assume_edges_unchanged`` and ``check_edges_unchanged``
+ are both `True`, or if the same task appears twice.
+ EdgesChangedError
+ Raised if ``check_edges_unchanged=True`` and the edges of a task do
+ change.
+
+ Notes
+ -----
+ If reconfiguring a task causes its edges to change, any dataset type
+ nodes connected to that task (not just those whose edges have changed!)
+ will be unresolved.
+ """
+ new_configs: dict[str, PipelineTaskConfig] = {}
+ for task_label, config_update in itertools.chain(args, kwargs.items()):
+ if new_configs.setdefault(task_label, config_update) is not config_update:
+ raise ValueError(f"Config for {task_label!r} provided more than once.")
+ updates = {
+ task_label: self.tasks[task_label]._reconfigured(config, rebuild=not assume_edges_unchanged)
+ for task_label, config in new_configs.items()
+ }
+ self._replace_task_nodes(
+ updates,
+ check_edges_unchanged=check_edges_unchanged,
+ assume_edges_unchanged=assume_edges_unchanged,
+ message_header=(
+ "Unexpected change in edges for task {task_label!r} from original config (A) to "
+ "new configs (B):"
+ ),
+ )
+
+ def remove_tasks(
+ self, labels: Iterable[str], drop_from_subsets: bool = True
+ ) -> list[tuple[TaskNode, set[str]]]:
+ """Remove one or more tasks from the graph.
+
+ Parameters
+ ----------
+ labels : `~collections.abc.Iterable` [ `str` ]
+ Iterable of the labels of the tasks to remove.
+ drop_from_subsets : `bool`, optional
+ If `True`, drop each removed task from any subset in which it
+ currently appears. If `False`, raise `PipelineGraphError` if any
+ such subsets exist.
+
+ Returns
+ -------
+ nodes_and_subsets : `list` [ `tuple` [ `TaskNode`, `set` [ `str` ] ] ]
+ List of nodes removed and the labels of task subsets that
+ referenced them.
+
+ Raises
+ ------
+ PipelineGraphError
+ Raised if ``drop_from_subsets`` is `False` and the task is still
+ part of one or more subsets.
+
+ Notes
+ -----
+ Removing a task will cause dataset nodes with no other referencing
+ tasks to be removed. Any other dataset type nodes referenced by a
+ removed task will be reset to an "unresolved" state.
+ """
+ task_nodes_and_subsets = []
+ dataset_types: set[NodeKey] = set()
+ nodes_to_remove = set()
+ for label in labels:
+ task_node: TaskNode = self._xgraph.nodes[NodeKey(NodeType.TASK, label)]["instance"]
+ # Find task subsets that reference this task.
+ referencing_subsets = {
+ subset_label
+ for subset_label, task_subset in self.task_subsets.items()
+ if label in task_subset
+ }
+ if not drop_from_subsets and referencing_subsets:
+ raise PipelineGraphError(
+ f"Task {label!r} is still referenced by subset(s) {referencing_subsets}."
+ )
+ task_nodes_and_subsets.append((task_node, referencing_subsets))
+ # Find dataset types referenced by this task.
+ dataset_types.update(self._xgraph.predecessors(task_node.key))
+ dataset_types.update(self._xgraph.successors(task_node.key))
+ dataset_types.update(self._xgraph.predecessors(task_node.init.key))
+ dataset_types.update(self._xgraph.successors(task_node.init.key))
+ # Since there's an edge between the task and its init node, we'll
+ # have added those two nodes here, too, and we don't want that.
+ dataset_types.remove(task_node.init.key)
+ dataset_types.remove(task_node.key)
+ # Mark the task node and its init node for removal from the graph.
+ nodes_to_remove.add(task_node.key)
+ nodes_to_remove.add(task_node.init.key)
+ # Process the referenced datasets to see which ones are orphaned and
+ # need to be removed vs. just unresolved.
+ nodes_to_unresolve = []
+ for dataset_type_key in dataset_types:
+ related_tasks = set()
+ related_tasks.update(self._xgraph.predecessors(dataset_type_key))
+ related_tasks.update(self._xgraph.successors(dataset_type_key))
+ related_tasks.difference_update(nodes_to_remove)
+ if not related_tasks:
+ nodes_to_remove.add(dataset_type_key)
+ else:
+ nodes_to_unresolve.append(dataset_type_key)
+ # Checks and preparation complete; time to start the actual
+ # modification, during which it's hard to provide strong exception
+ # safety. Start by resetting the sort ordering.
+ self._reset()
+ try:
+ for dataset_type_key in nodes_to_unresolve:
+ self._xgraph.nodes[dataset_type_key]["instance"] = None
+ for task_node, referencing_subsets in task_nodes_and_subsets:
+ for subset_label in referencing_subsets:
+ self._task_subsets[subset_label].remove(task_node.label)
+ self._xgraph.remove_nodes_from(nodes_to_remove)
+ except Exception as err: # pragma: no cover
+ # There's no known way to get here, but we want to make it
+ # clear it's a big problem if we do.
+ raise PipelineGraphExceptionSafetyError(
+ "Error during task removal has left the graph in an inconsistent state."
+ ) from err
+ return task_nodes_and_subsets
+
+ def add_task_subset(self, subset_label: str, task_labels: Iterable[str], description: str = "") -> None:
+ """Add a label for a set of tasks that are already in the pipeline.
+
+ Parameters
+ ----------
+ subset_label : `str`
+ Label for this set of tasks.
+ task_labels : `~collections.abc.Iterable` [ `str` ]
+ Labels of the tasks to include in the set. All must already be
+ included in the graph.
+ description : `str`, optional
+ String description to associate with this label.
+ """
+ subset = TaskSubset(self._xgraph, subset_label, set(task_labels), description)
+ self._task_subsets[subset_label] = subset
+
+ def remove_task_subset(self, subset_label: str) -> None:
+ """Remove a labeled set of tasks."""
+ del self._task_subsets[subset_label]
+
+ def copy(self) -> PipelineGraph:
+ """Return a copy of this graph that copies all mutable state."""
+ xgraph = self._xgraph.copy()
+ result = PipelineGraph.__new__(PipelineGraph)
+ result._init_from_args(
+ xgraph,
+ self._sorted_keys,
+ task_subsets={
+ k: TaskSubset(xgraph, v.label, set(v._members), v.description)
+ for k, v in self._task_subsets.items()
+ },
+ description=self._description,
+ universe=self.universe,
+ data_id=self._raw_data_id,
+ )
+ return result
+
+ def __copy__(self) -> PipelineGraph:
+ # Fully shallow copies are dangerous; we don't want shared mutable
+ # state to lead to broken class invariants.
+ return self.copy()
+
+ def __deepcopy__(self, memo: dict) -> PipelineGraph:
+ # Genuine deep copies are unnecessary, since we should only ever care
+ # that mutable state is copied.
+ return self.copy()
+
+ def import_and_configure(
+ self, check_edges_unchanged: bool = False, assume_edges_unchanged: bool = False
+ ) -> None:
+ """Import the `PipelineTask` classes referenced by all task nodes and
+ update those nodes accordingly.
+
+ Parameters
+ ----------
+ check_edges_unchanged : `bool`, optional
+ If `True`, require the edges (connections) of the modified tasks to
+ remain unchanged after importing and configuring each task, and
+ verify that this is the case.
+ assume_edges_unchanged : `bool`, optional
+ If `True`, the caller declares that the edges (connections) of the
+ modified tasks will remain unchanged importing and configuring each
+ task, and that it is unnecessary to check this.
+
+ Raises
+ ------
+ ValueError
+ Raised if ``assume_edges_unchanged`` and ``check_edges_unchanged``
+ are both `True`, or if a full config is provided for a task after
+ another full config or an override has already been provided.
+ EdgesChangedError
+ Raised if ``check_edges_unchanged=True`` and the edges of a task do
+ change.
+
+ Notes
+ -----
+ This method shouldn't need to be called unless the graph was
+ deserialized without importing and configuring immediately, which is
+ not the default behavior (but it can greatly speed up deserialization).
+ If all tasks have already been imported this does nothing.
+
+ Importing and configuring a task can change its
+ `~TaskNode.task_class_name` or `~TaskClass.get_config_str` output,
+ usually because the software used to read a serialized graph is newer
+ than the software used to write it (e.g. a new config option has been
+ added, or the task was moved to a new module with a forwarding alias
+ left behind). These changes are allowed by ``check=True``.
+
+ If importing and configuring a task causes its edges to change, any
+ dataset type nodes linked to those edges will be reset to the
+ unresolved state.
+ """
+ rebuild = check_edges_unchanged or not assume_edges_unchanged
+ updates: dict[str, TaskNode] = {}
+ node_key: NodeKey
+ for node_key, node_state in self._xgraph.nodes.items():
+ if node_key.node_type is NodeType.TASK:
+ task_node: TaskNode = node_state["instance"]
+ new_task_node = task_node._imported_and_configured(rebuild)
+ if new_task_node is not task_node:
+ updates[task_node.label] = new_task_node
+ self._replace_task_nodes(
+ updates,
+ check_edges_unchanged=check_edges_unchanged,
+ assume_edges_unchanged=assume_edges_unchanged,
+ message_header=(
+ "In task with label {task_label!r}, persisted edges (A)"
+ "differ from imported and configured edges (B):"
+ ),
+ )
+
+ def resolve(self, registry: Registry) -> None:
+ """Resolve all dimensions and dataset types and check them for
+ consistency.
+
+ Resolving a graph also causes it to be sorted.
+
+ Parameters
+ ----------
+ registry : `lsst.daf.butler.Registry`
+ Client for the data repository to resolve against.
+
+ Notes
+ -----
+ The `universe` attribute are set to ``registry.dimensions`` and used to
+ set all `TaskNode.dimensions` attributes. Dataset type nodes are
+ resolved by first looking for a registry definition, then using the
+ producing task's definition, then looking for consistency between all
+ consuming task definitions.
+
+ Raises
+ ------
+ ConnectionTypeConsistencyError
+ Raised if a prerequisite input for one task appears as a different
+ kind of connection in any other task.
+ DuplicateOutputError
+ Raised if multiple tasks have the same dataset type as an output.
+ IncompatibleDatasetTypeError
+ Raised if different tasks have different definitions of a dataset
+ type. Different but compatible storage classes are permitted.
+ MissingDatasetTypeError
+ Raised if a dataset type definition is required to exist in the
+ data repository but none was found. This should only occur for
+ dataset types that are not produced by a task in the pipeline and
+ are consumed with different storage classes or as components by
+ tasks in the pipeline.
+ EdgesChangedError
+ Raised if ``check_edges_unchanged=True`` and the edges of a task do
+ change after import and reconfiguration.
+ """
+ node_key: NodeKey
+ updates: dict[NodeKey, TaskNode | DatasetTypeNode] = {}
+ for node_key, node_state in self._xgraph.nodes.items():
+ match node_key.node_type:
+ case NodeType.TASK:
+ task_node: TaskNode = node_state["instance"]
+ new_task_node = task_node._resolved(registry.dimensions)
+ if new_task_node is not task_node:
+ updates[node_key] = new_task_node
+ case NodeType.DATASET_TYPE:
+ dataset_type_node: DatasetTypeNode | None = node_state["instance"]
+ new_dataset_type_node = DatasetTypeNode._from_edges(
+ node_key, self._xgraph, registry, previous=dataset_type_node
+ )
+ if new_dataset_type_node is not dataset_type_node:
+ updates[node_key] = new_dataset_type_node
+ try:
+ for node_key, node_value in updates.items():
+ self._xgraph.nodes[node_key]["instance"] = node_value
+ except Exception as err: # pragma: no cover
+ # There's no known way to get here, but we want to make it
+ # clear it's a big problem if we do.
+ raise PipelineGraphExceptionSafetyError(
+ "Error during dataset type resolution has left the graph in an inconsistent state."
+ ) from err
+ self.sort()
+ self._universe = registry.dimensions
+
+ @classmethod
+ def read_stream(
+ cls,
+ stream: BinaryIO,
+ import_and_configure: bool = True,
+ check_edges_unchanged: bool = False,
+ assume_edges_unchanged: bool = False,
+ ) -> PipelineGraph:
+ """Read a serialized `PipelineGraph` from a file-like object.
+
+ Parameters
+ ----------
+ stream : `BinaryIO`
+ File-like object opened for binary reading, containing
+ gzip-compressed JSON.
+ import_and_configure : `bool`, optional
+ If `True`, import and configure all tasks immediately (see the
+ `import_and_configure` method). If `False`, some `TaskNode` and
+ `TaskInitNode` attributes will not be available, but reading may be
+ much faster.
+ check_edges_unchanged : `bool`, optional
+ Forwarded to `import_and_configure` after reading.
+ assume_edges_unchanged : `bool`, optional
+ Forwarded to `import_and_configure` after reading.
+
+ Returns
+ -------
+ graph : `PipelineGraph`
+ Deserialized pipeline graph.
+
+ Raises
+ ------
+ PipelineGraphReadError
+ Raised if the serialized `PipelineGraph` is not self-consistent.
+ EdgesChangedError
+ Raised if ``check_edges_unchanged=True`` and the edges of a task do
+ change after import and reconfiguration.
+ """
+ from .io import SerializedPipelineGraph
+
+ with gzip.open(stream, "rb") as uncompressed_stream:
+ data = json.load(uncompressed_stream)
+ serialized_graph = SerializedPipelineGraph.parse_obj(data)
+ return serialized_graph.deserialize(
+ import_and_configure=import_and_configure,
+ check_edges_unchanged=check_edges_unchanged,
+ assume_edges_unchanged=assume_edges_unchanged,
+ )
+
+ @classmethod
+ def read_uri(
+ cls,
+ uri: ResourcePathExpression,
+ import_and_configure: bool = True,
+ check_edges_unchanged: bool = False,
+ assume_edges_unchanged: bool = False,
+ ) -> PipelineGraph:
+ """Read a serialized `PipelineGraph` from a file at a URI.
+
+ Parameters
+ ----------
+ uri : convertible to `lsst.resources.ResourcePath`
+ URI to a gzip-compressed JSON file containing a serialized pipeline
+ graph.
+ import_and_configure : `bool`, optional
+ If `True`, import and configure all tasks immediately (see
+ the `import_and_configure` method). If `False`, some `TaskNode`
+ and `TaskInitNode` attributes will not be available, but reading
+ may be much faster.
+ check_edges_unchanged : `bool`, optional
+ Forwarded to `import_and_configure` after reading.
+ assume_edges_unchanged : `bool`, optional
+ Forwarded to `import_and_configure` after reading.
+
+ Returns
+ -------
+ graph : `PipelineGraph`
+ Deserialized pipeline graph.
+
+ Raises
+ ------
+ PipelineGraphReadError
+ Raised if the serialized `PipelineGraph` is not self-consistent.
+ EdgesChangedError
+ Raised if ``check_edges_unchanged=True`` and the edges of a task do
+ change after import and reconfiguration.
+ """
+ uri = ResourcePath(uri)
+ with uri.open("rb") as stream:
+ return cls.read_stream(
+ cast(BinaryIO, stream),
+ import_and_configure=import_and_configure,
+ check_edges_unchanged=check_edges_unchanged,
+ assume_edges_unchanged=assume_edges_unchanged,
+ )
+
+ def write_stream(self, stream: BinaryIO) -> None:
+ """Write the pipeline to a file-like object.
+
+ Parameters
+ ----------
+ stream
+ File-like object opened for binary writing.
+
+ Notes
+ -----
+ The file format is gzipped JSON, and is intended to be human-readable,
+ but it should not be considered a stable public interface for outside
+ code, which should always use `PipelineGraph` methods (or at least the
+ `io.SerializedPipelineGraph` class) to read these files.
+ """
+ from .io import SerializedPipelineGraph
+
+ with gzip.open(stream, mode="wb") as compressed_stream:
+ compressed_stream.write(
+ SerializedPipelineGraph.serialize(self).json(exclude_defaults=True, indent=2).encode("utf-8")
+ )
+
+ def write_uri(self, uri: ResourcePathExpression) -> None:
+ """Write the pipeline to a file given a URI.
+
+ Parameters
+ ----------
+ uri : convertible to `lsst.resources.ResourcePath`
+ URI to write to . May have ``.json.gz`` or no extension (which
+ will cause a ``.json.gz`` extension to be added).
+
+ Notes
+ -----
+ The file format is gzipped JSON, and is intended to be human-readable,
+ but it should not be considered a stable public interface for outside
+ code, which should always use `PipelineGraph` methods (or at least the
+ `io.SerializedPipelineGraph` class) to read these files.
+ """
+ uri = ResourcePath(uri)
+ extension = uri.getExtension()
+ if not extension:
+ uri = uri.updatedExtension(".json.gz")
+ elif extension != ".json.gz":
+ raise ValueError("Expanded pipeline files should always have a .json.gz extension.")
+ with uri.open(mode="wb") as stream:
+ self.write_stream(cast(BinaryIO, stream))
+
+ def _iter_task_defs(self) -> Iterator[TaskDef]:
+ """Iterate over this pipeline as a sequence of `TaskDef` instances.
+
+ Notes
+ -----
+ This is a package-private method intended to aid in the transition to a
+ codebase more fully integrated with the `PipelineGraph` class, in which
+ both `TaskDef` and `PipelineDatasetTypes` are expected to go away, and
+ much of the functionality on the `Pipeline` class will be moved to
+ `PipelineGraph` as well.
+
+ Raises
+ ------
+ TaskNotImportedError
+ Raised if `TaskNode.is_imported` is `False` for any task.
+ """
+ from ..pipeline import TaskDef
+
+ for node in self._tasks.values():
+ yield TaskDef(
+ config=node.config,
+ taskClass=node.task_class,
+ label=node.label,
+ connections=node._get_imported_data().connections,
+ )
+
+ def _replace_task_nodes(
+ self,
+ updates: Mapping[str, TaskNode],
+ check_edges_unchanged: bool,
+ assume_edges_unchanged: bool,
+ message_header: str,
+ ) -> None:
+ """Replace task nodes and update edges and dataset type nodes
+ accordingly.
+
+ Parameters
+ ----------
+ updates : `Mapping` [ `str`, `TaskNode` ]
+ New task nodes with task label keys. All keys must be task labels
+ that are already present in the graph.
+ check_edges_unchanged : `bool`, optional
+ If `True`, require the edges (connections) of the modified tasks to
+ remain unchanged after importing and configuring each task, and
+ verify that this is the case.
+ assume_edges_unchanged : `bool`, optional
+ If `True`, the caller declares that the edges (connections) of the
+ modified tasks will remain unchanged importing and configuring each
+ task, and that it is unnecessary to check this.
+ message_header : `str`
+ Template for `str.format` with a single ``task_label`` placeholder
+ to use as the first line in `EdgesChangedError` messages that show
+ the differences between new task edges and old task edges. Should
+ include the fact that the rest of the message will refer to the old
+ task as "A" and the new task as "B", and end with a colon.
+
+ Raises
+ ------
+ ValueError
+ Raised if ``assume_edges_unchanged`` and ``check_edges_unchanged``
+ are both `True`, or if a full config is provided for a task after
+ another full config or an override has already been provided.
+ EdgesChangedError
+ Raised if ``check_edges_unchanged=True`` and the edges of a task do
+ change.
+ """
+ deep: dict[str, TaskNode] = {}
+ shallow: dict[str, TaskNode] = {}
+ if assume_edges_unchanged:
+ if check_edges_unchanged:
+ raise ValueError("Cannot simultaneously assume and check that edges have not changed.")
+ shallow.update(updates)
+ else:
+ for task_label, new_task_node in updates.items():
+ old_task_node = self.tasks[task_label]
+ messages = old_task_node.diff_edges(new_task_node)
+ if messages:
+ if check_edges_unchanged:
+ messages.insert(0, message_header.format(task_label=task_label))
+ raise EdgesChangedError("\n".join(messages))
+ else:
+ deep[task_label] = new_task_node
+ else:
+ shallow[task_label] = new_task_node
+ try:
+ if deep:
+ removed = self.remove_tasks(deep.keys(), drop_from_subsets=True)
+ self.add_task_nodes(deep.values())
+ for replaced_task_node, referencing_subsets in removed:
+ for subset_label in referencing_subsets:
+ self._task_subsets[subset_label].add(replaced_task_node.label)
+ for task_node in shallow.values():
+ self._xgraph.nodes[task_node.key]["instance"] = task_node
+ self._xgraph.nodes[task_node.init.key]["instance"] = task_node.init
+ except PipelineGraphExceptionSafetyError: # pragma: no cover
+ raise
+ except Exception as err: # pragma: no cover
+ # There's no known way to get here, but we want to make it clear
+ # it's a big problem if we do.
+ raise PipelineGraphExceptionSafetyError(
+ "Error while replacing tasks has left the graph in an inconsistent state."
+ ) from err
+
+ def _append_graph_data_from_edge(
+ self,
+ node_data: list[tuple[NodeKey, dict[str, Any]]],
+ edge_data: list[tuple[NodeKey, NodeKey, str, dict[str, Any]]],
+ edge: Edge,
+ ) -> None:
+ """Append networkx state dictionaries for an edge and the corresponding
+ dataset type node.
+
+ Parameters
+ ----------
+ node_data : `list`
+ List of node keys and state dictionaries. A node is appended if
+ one does not already exist for this dataset type.
+ edge_data : `list`
+ List of node key pairs, connection names, and state dictionaries
+ for edges.
+ edge : `Edge`
+ New edge being processed.
+ """
+ if (existing_dataset_type_state := self._xgraph.nodes.get(edge.dataset_type_key)) is not None:
+ existing_dataset_type_state["instance"] = None
+ else:
+ node_data.append(
+ (
+ edge.dataset_type_key,
+ {
+ "instance": None,
+ "bipartite": NodeType.DATASET_TYPE.bipartite,
+ },
+ )
+ )
+ edge_data.append(
+ edge.nodes
+ + (
+ edge.connection_name,
+ {"instance": edge},
+ )
+ )
+
+ def _reorder(self, sorted_keys: Sequence[NodeKey]) -> None:
+ """Set the order of all views of this graph from the given sorted
+ sequence of task labels and dataset type names.
+ """
+ self._sorted_keys = sorted_keys
+ self._tasks._reorder(sorted_keys)
+ self._dataset_types._reorder(sorted_keys)
+
+ def _reset(self) -> None:
+ """Reset the all views of this graph following a modification that
+ might invalidate them.
+ """
+ self._sorted_keys = None
+ self._tasks._reset()
+ self._dataset_types._reset()
diff --git a/python/lsst/pipe/base/pipeline_graph/_task_subsets.py b/python/lsst/pipe/base/pipeline_graph/_task_subsets.py
new file mode 100644
index 000000000..1c48ecab9
--- /dev/null
+++ b/python/lsst/pipe/base/pipeline_graph/_task_subsets.py
@@ -0,0 +1,122 @@
+# This file is part of pipe_base.
+#
+# Developed for the LSST Data Management System.
+# This product includes software developed by the LSST Project
+# (http://www.lsst.org).
+# See the COPYRIGHT file at the top-level directory of this distribution
+# for details of code ownership.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see .
+from __future__ import annotations
+
+__all__ = ("TaskSubset",)
+
+from collections.abc import Iterator, MutableSet
+
+import networkx
+import networkx.algorithms.boundary
+
+from ._exceptions import PipelineGraphError
+from ._nodes import NodeKey, NodeType
+
+
+class TaskSubset(MutableSet[str]):
+ """A specialized set that represents a labeles subset of the tasks in a
+ pipeline graph.
+
+ Instances of this class should never be constructed directly; they should
+ only be accessed via the `PipelineGraph.task_subsets` attribute and created
+ by the `PipelineGraph.add_task_subset` method.
+
+ Parameters
+ ----------
+ parent_xgraph : `networkx.DiGraph`
+ Parent networkx graph that this subgraph is part of.
+ label : `str`
+ Label associated with this subset of the pipeline.
+ members : `set` [ `str` ]
+ Labels of the tasks that are members of this subset.
+ description : `str`, optional
+ Description string associated with this labeled subset.
+
+ Notes
+ -----
+ Iteration order is arbitrary, even when the parent pipeline graph is
+ ordered (there is no guarantee that an ordering of the tasks in the graph
+ implies a consistent ordering of subsets).
+ """
+
+ def __init__(
+ self,
+ parent_xgraph: networkx.DiGraph,
+ label: str,
+ members: set[str],
+ description: str,
+ ):
+ self._parent_xgraph = parent_xgraph
+ self._label = label
+ self._members = members
+ self._description = description
+
+ @property
+ def label(self) -> str:
+ """Label associated with this subset of the pipeline."""
+ return self._label
+
+ @property
+ def description(self) -> str:
+ """Description string associated with this labeled subset."""
+ return self._description
+
+ @description.setter
+ def description(self, value: str) -> None:
+ # Docstring in getter.
+ self._description = value
+
+ def __repr__(self) -> str:
+ return f"{self.label}: {self.description!r}, tasks={{{', '.join(iter(self))}}}"
+
+ def __contains__(self, key: object) -> bool:
+ return key in self._members
+
+ def __len__(self) -> int:
+ return len(self._members)
+
+ def __iter__(self) -> Iterator[str]:
+ return iter(self._members)
+
+ def add(self, task_label: str) -> None:
+ """Add a new task to this subset.
+
+ Parameters
+ ----------
+ task_label : `str`
+ Label for the task. Must already be present in the parent pipeline
+ graph.
+ """
+ key = NodeKey(NodeType.TASK, task_label)
+ if key not in self._parent_xgraph:
+ raise PipelineGraphError(f"{task_label!r} is not a task in the parent pipeline.")
+ self._members.add(key.name)
+
+ def discard(self, task_label: str) -> None:
+ """Remove a task from the subset if it is present.
+
+ Parameters
+ ----------
+ task_label : `str`
+ Label for the task. Must already be present in the parent pipeline
+ graph.
+ """
+ self._members.discard(task_label)
diff --git a/python/lsst/pipe/base/pipeline_graph/_tasks.py b/python/lsst/pipe/base/pipeline_graph/_tasks.py
new file mode 100644
index 000000000..4cc841e2a
--- /dev/null
+++ b/python/lsst/pipe/base/pipeline_graph/_tasks.py
@@ -0,0 +1,854 @@
+# This file is part of pipe_base.
+#
+# Developed for the LSST Data Management System.
+# This product includes software developed by the LSST Project
+# (http://www.lsst.org).
+# See the COPYRIGHT file at the top-level directory of this distribution
+# for details of code ownership.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see .
+from __future__ import annotations
+
+__all__ = ("TaskNode", "TaskInitNode")
+
+import dataclasses
+from collections.abc import Iterator, Mapping
+from typing import TYPE_CHECKING, Any, cast
+
+from lsst.daf.butler import DimensionGraph, DimensionUniverse
+from lsst.utils.classes import immutable
+from lsst.utils.doImport import doImportType
+from lsst.utils.introspection import get_full_type_name
+
+from .. import automatic_connection_constants as acc
+from ..connections import PipelineTaskConnections
+from ..connectionTypes import BaseConnection, InitOutput, Output
+from ._edges import Edge, ReadEdge, WriteEdge
+from ._exceptions import TaskNotImportedError, UnresolvedGraphError
+from ._nodes import NodeKey, NodeType
+
+if TYPE_CHECKING:
+ from ..config import PipelineTaskConfig
+ from ..pipelineTask import PipelineTask
+
+
+@dataclasses.dataclass(frozen=True)
+class _TaskNodeImportedData:
+ """An internal struct that holds `TaskNode` and `TaskInitNode` state that
+ requires task classes to be imported.
+ """
+
+ task_class: type[PipelineTask]
+ """Type object for the task."""
+
+ config: PipelineTaskConfig
+ """Configuration object for the task."""
+
+ connection_map: dict[str, BaseConnection]
+ """Mapping from connection name to connection.
+
+ In addition to ``connections.allConnections``, this also holds the
+ "automatic" config, log, and metadata connections using the names defined
+ in the `.automatic_connection_constants` module.
+ """
+
+ connections: PipelineTaskConnections
+ """Configured connections object for the task."""
+
+ @classmethod
+ def configure(
+ cls,
+ label: str,
+ task_class: type[PipelineTask],
+ config: PipelineTaskConfig,
+ connections: PipelineTaskConnections | None = None,
+ ) -> _TaskNodeImportedData:
+ """Construct while creating a `PipelineTaskConnections` instance if
+ necessary.
+
+ Parameters
+ ----------
+ label : `str`
+ Label for the task in the pipeline. Only used in error messages.
+ task_class : `type` [ `.PipelineTask` ]
+ Pipeline task `type` object.
+ config : `.PipelineTaskConfig`
+ Configuration for the task.
+ connections : `.PipelineTaskConnections`, optional
+ Object that describes the dataset types used by the task. If not
+ provided, one will be constructed from the given configuration. If
+ provided, it is assumed that ``config`` has already been validated
+ and frozen.
+
+ Returns
+ -------
+ data : `_TaskNodeImportedData`
+ Instance of this struct.
+ """
+ if connections is None:
+ # If we don't have connections yet, assume the config hasn't been
+ # validated yet.
+ try:
+ config.validate()
+ except Exception as err:
+ raise ValueError(
+ f"Configuration validation failed for task {label!r} (see chained exception)."
+ ) from err
+ config.freeze()
+ connections = task_class.ConfigClass.ConnectionsClass(config=config)
+ connection_map = dict(connections.allConnections)
+ connection_map[acc.CONFIG_INIT_OUTPUT_CONNECTION_NAME] = InitOutput(
+ acc.CONFIG_INIT_OUTPUT_TEMPLATE.format(label=label),
+ acc.CONFIG_INIT_OUTPUT_STORAGE_CLASS,
+ )
+ if not config.saveMetadata:
+ raise ValueError(f"Metadata for task {label} cannot be disabled.")
+ connection_map[acc.METADATA_OUTPUT_CONNECTION_NAME] = Output(
+ acc.METADATA_OUTPUT_TEMPLATE.format(label=label),
+ acc.METADATA_OUTPUT_STORAGE_CLASS,
+ dimensions=set(connections.dimensions),
+ )
+ if config.saveLogOutput:
+ connection_map[acc.LOG_OUTPUT_CONNECTION_NAME] = Output(
+ acc.LOG_OUTPUT_TEMPLATE.format(label=label),
+ acc.LOG_OUTPUT_STORAGE_CLASS,
+ dimensions=set(connections.dimensions),
+ )
+ return cls(task_class, config, connection_map, connections)
+
+
+@immutable
+class TaskInitNode:
+ """A node in a pipeline graph that represents the construction of a
+ `PipelineTask`.
+
+ Parameters
+ ----------
+ inputs : `~collections.abc.Mapping` [ `str`, `ReadEdge` ]
+ Graph edges that represent inputs required just to construct an
+ instance of this task, keyed by connection name.
+ outputs : ~collections.abc.Mapping` [ `str`, `WriteEdge` ]
+ Graph edges that represent outputs of this task that are available
+ after just constructing it, keyed by connection name.
+
+ This does not include the special `config_init_output` edge; use
+ `iter_all_outputs` to include that, too.
+ config_output : `WriteEdge`
+ The special init output edge that persists the task's configuration.
+ imported_data : `_TaskNodeImportedData`, optional
+ Internal struct that holds information that requires the task class to
+ have been be imported.
+ task_class_name : `str`, optional
+ Fully-qualified name of the task class. Must be provided if
+ ``imported_data`` is not.
+ config_str : `str`, optional
+ Configuration for the task as a string of override statements. Must be
+ provided if ``imported_data`` is not.
+
+ Notes
+ -----
+ When included in an exported `networkx` graph (e.g.
+ `PipelineGraph.make_xgraph`), task initialization nodes set the following
+ node attributes:
+
+ - ``task_class_name``
+ - ``bipartite`` (see `NodeType.bipartite`)
+ - ``task_class`` (only if `is_imported` is `True`)
+ - ``config`` (only if `is_importd` is `True`)
+ """
+
+ def __init__(
+ self,
+ key: NodeKey,
+ *,
+ inputs: Mapping[str, ReadEdge],
+ outputs: Mapping[str, WriteEdge],
+ config_output: WriteEdge,
+ imported_data: _TaskNodeImportedData | None = None,
+ task_class_name: str | None = None,
+ config_str: str | None = None,
+ ):
+ self.key = key
+ self.inputs = inputs
+ self.outputs = outputs
+ self.config_output = config_output
+ # Instead of setting attributes to None, we do not set them at all;
+ # this works better with the @immutable decorator, which supports
+ # deferred initialization but not reassignment.
+ if task_class_name is not None:
+ self._task_class_name = task_class_name
+ if config_str is not None:
+ self._config_str = config_str
+ if imported_data is not None:
+ self._imported_data = imported_data
+ else:
+ assert (
+ self._task_class_name is not None and self._config_str is not None
+ ), "If imported_data is not present, task_class_name and config_str must be."
+
+ key: NodeKey
+ """Key that identifies this node in internal and exported networkx graphs.
+ """
+
+ inputs: Mapping[str, ReadEdge]
+ """Graph edges that represent inputs required just to construct an instance
+ of this task, keyed by connection name.
+ """
+
+ outputs: Mapping[str, WriteEdge]
+ """Graph edges that represent outputs of this task that are available after
+ just constructing it, keyed by connection name.
+
+ This does not include the special `config_output` edge; use
+ `iter_all_outputs` to include that, too.
+ """
+
+ config_output: WriteEdge
+ """The special output edge that persists the task's configuration.
+ """
+
+ @property
+ def label(self) -> str:
+ """Label of this configuration of a task in the pipeline."""
+ return str(self.key)
+
+ @property
+ def is_imported(self) -> bool:
+ """Whether this the task type for this node has been imported and
+ its configuration overrides applied.
+
+ If this is `False`, the `task_class` and `config` attributes may not
+ be accessed.
+ """
+ return hasattr(self, "_imported_data")
+
+ @property
+ def task_class(self) -> type[PipelineTask]:
+ """Type object for the task.
+
+ Accessing this attribute when `is_imported` is `False` will raise
+ `TaskNotImportedError`, but accessing `task_class_name` will not.
+ """
+ return self._get_imported_data().task_class
+
+ @property
+ def task_class_name(self) -> str:
+ """The fully-qualified string name of the task class."""
+ try:
+ return self._task_class_name
+ except AttributeError:
+ pass
+ self._task_class_name = get_full_type_name(self.task_class)
+ return self._task_class_name
+
+ @property
+ def config(self) -> PipelineTaskConfig:
+ """Configuration for the task.
+
+ This is always frozen.
+
+ Accessing this attribute when `is_imported` is `False` will raise
+ `TaskNotImportedError`, but calling `get_config_str` will not.
+ """
+ return self._get_imported_data().config
+
+ def get_config_str(self) -> str:
+ """Return the configuration for this task as a string of override
+ statements.
+
+ Returns
+ -------
+ config_str : `str`
+ String containing configuration-overload statements.
+ """
+ try:
+ return self._config_str
+ except AttributeError:
+ pass
+ self._config_str = self.config.saveToString()
+ return self._config_str
+
+ def iter_all_inputs(self) -> Iterator[ReadEdge]:
+ """Iterate over all inputs required for construction.
+
+ This is the same as iteration over ``inputs.values()``, but it will be
+ updated to include any automatic init-input connections added in the
+ future, while `inputs` will continue to hold only task-defined init
+ inputs.
+ """
+ return iter(self.inputs.values())
+
+ def iter_all_outputs(self) -> Iterator[WriteEdge]:
+ """Iterate over all outputs available after construction, including
+ special ones.
+ """
+ yield from self.outputs.values()
+ yield self.config_output
+
+ def diff_edges(self, other: TaskInitNode) -> list[str]:
+ """Compare the edges of this task initialization node to those from the
+ same task label in a different pipeline.
+
+ Parameters
+ ----------
+ other : `TaskInitNode`
+ Other node to compare to. Must have the same task label, but need
+ not have the same configuration or even the same task class.
+
+ Returns
+ -------
+ differences : `list` [ `str` ]
+ List of string messages describing differences between ``self`` and
+ ``other``. Will be empty if the two nodes have the same edges.
+ Messages will use 'A' to refer to ``self`` and 'B' to refer to
+ ``other``.
+ """
+ result = []
+ result += _diff_edge_mapping(self.inputs, self.inputs, self.label, "init input")
+ result += _diff_edge_mapping(self.outputs, other.outputs, self.label, "init output")
+ result += self.config_output.diff(other.config_output, "config init output")
+ return result
+
+ def _to_xgraph_state(self) -> dict[str, Any]:
+ """Convert this nodes's attributes into a dictionary suitable for use
+ in exported networkx graphs.
+ """
+ result = {"task_class_name": self.task_class_name, "bipartite": NodeType.TASK_INIT.bipartite}
+ if hasattr(self, "_imported_data"):
+ result["task_class"] = self.task_class
+ result["config"] = self.config
+ return result
+
+ def _get_imported_data(self) -> _TaskNodeImportedData:
+ """Return the imported data struct.
+
+ Returns
+ -------
+ imported_data : `_TaskNodeImportedData`
+ Internal structure holding state that requires the task class to
+ have been imported.
+
+ Raises
+ ------
+ TaskNotImportedError
+ Raised if `is_imported` is `False`.
+ """
+ try:
+ return self._imported_data
+ except AttributeError:
+ raise TaskNotImportedError(
+ f"Task class {self.task_class_name!r} for label {self.label!r} has not been imported "
+ "(see PipelineGraph.import_and_configure)."
+ ) from None
+
+
+@immutable
+class TaskNode:
+ """A node in a pipeline graph that represents a labeled configuration of a
+ `PipelineTask`.
+
+ Parameters
+ ----------
+ key : `NodeKey`
+ Identifier for this node in networkx graphs.
+ init : `TaskInitNode`
+ Node representing the initialization of this task.
+ prerequisite_inputs : `~collections.abc.Mapping` [ `str`, `ReadEdge` ]
+ Graph edges that represent prerequisite inputs to this task, keyed by
+ connection name.
+
+ Prerequisite inputs must already exist in the data repository when a
+ `QuantumGraph` is built, but have more flexibility in how they are
+ looked up than regular inputs.
+ inputs : `~collections.abc.Mapping` [ `str`, `ReadEdge` ]
+ Graph edges that represent regular runtime inputs to this task, keyed
+ by connection name.
+ outputs : ~collections.abc.Mapping` [ `str`, `WriteEdge` ]
+ Graph edges that represent regular runtime outputs of this task, keyed
+ by connection name.
+
+ This does not include the special `log_output` and `metadata_output`
+ edges; use `iter_all_outputs` to include that, too.
+ log_output : `WriteEdge` or `None`
+ The special runtime output that persists the task's logs.
+ metadata_output : `WriteEdge`
+ The special runtime output that persists the task's metadata.
+ dimensions : `lsst.daf.butler.DimensionGraph` or `frozenset`
+ Dimensions of the task. If a `frozenset`, the dimensions have not been
+ resolved by a `~lsst.daf.butler.DimensionUniverse` and cannot be safely
+ compared to other sets of dimensions.
+
+ Notes
+ -----
+ Task nodes are intentionally not equality comparable, since there are many
+ different (and useful) ways to compare these objects with no clear winner
+ as the most obvious behavior.
+
+ When included in an exported `networkx` graph (e.g.
+ `PipelineGraph.make_xgraph`), task nodes set the following node attributes:
+
+ - ``task_class_name``
+ - ``bipartite`` (see `NodeType.bipartite`)
+ - ``task_class`` (only if `is_imported` is `True`)
+ - ``config`` (only if `is_importd` is `True`)
+ """
+
+ def __init__(
+ self,
+ key: NodeKey,
+ init: TaskInitNode,
+ *,
+ prerequisite_inputs: Mapping[str, ReadEdge],
+ inputs: Mapping[str, ReadEdge],
+ outputs: Mapping[str, WriteEdge],
+ log_output: WriteEdge | None,
+ metadata_output: WriteEdge,
+ dimensions: DimensionGraph | frozenset,
+ ):
+ self.key = key
+ self.init = init
+ self.prerequisite_inputs = prerequisite_inputs
+ self.inputs = inputs
+ self.outputs = outputs
+ self.log_output = log_output
+ self.metadata_output = metadata_output
+ self._dimensions = dimensions
+
+ @staticmethod
+ def _from_imported_data(
+ key: NodeKey,
+ init_key: NodeKey,
+ data: _TaskNodeImportedData,
+ universe: DimensionUniverse | None,
+ ) -> TaskNode:
+ """Construct from a `PipelineTask` type and its configuration.
+
+ Parameters
+ ----------
+ key : `NodeKey`
+ Identifier for this node in networkx graphs.
+ init : `TaskInitNode`
+ Node representing the initialization of this task.
+ data : `_TaskNodeImportedData`
+ Internal struct that holds information that requires the task class
+ to have been be imported.
+ universe : `lsst.daf.butler.DimensionUniverse` or `None`
+ Definitions of all dimensions.
+
+ Returns
+ -------
+ node : `TaskNode`
+ New task node.
+
+ Raises
+ ------
+ ValueError
+ Raised if configuration validation failed when constructing
+ ``connections``.
+ """
+ init_inputs = {
+ name: ReadEdge._from_connection_map(init_key, name, data.connection_map)
+ for name in data.connections.initInputs
+ }
+ prerequisite_inputs = {
+ name: ReadEdge._from_connection_map(key, name, data.connection_map, is_prerequisite=True)
+ for name in data.connections.prerequisiteInputs
+ }
+ inputs = {
+ name: ReadEdge._from_connection_map(key, name, data.connection_map)
+ for name in data.connections.inputs
+ }
+ init_outputs = {
+ name: WriteEdge._from_connection_map(init_key, name, data.connection_map)
+ for name in data.connections.initOutputs
+ }
+ outputs = {
+ name: WriteEdge._from_connection_map(key, name, data.connection_map)
+ for name in data.connections.outputs
+ }
+ init = TaskInitNode(
+ key=init_key,
+ inputs=init_inputs,
+ outputs=init_outputs,
+ config_output=WriteEdge._from_connection_map(
+ init_key, acc.CONFIG_INIT_OUTPUT_CONNECTION_NAME, data.connection_map
+ ),
+ imported_data=data,
+ )
+ instance = TaskNode(
+ key=key,
+ init=init,
+ prerequisite_inputs=prerequisite_inputs,
+ inputs=inputs,
+ outputs=outputs,
+ log_output=(
+ WriteEdge._from_connection_map(key, acc.LOG_OUTPUT_CONNECTION_NAME, data.connection_map)
+ if data.config.saveLogOutput
+ else None
+ ),
+ metadata_output=WriteEdge._from_connection_map(
+ key, acc.METADATA_OUTPUT_CONNECTION_NAME, data.connection_map
+ ),
+ dimensions=(
+ frozenset(data.connections.dimensions)
+ if universe is None
+ else universe.extract(data.connections.dimensions)
+ ),
+ )
+ return instance
+
+ key: NodeKey
+ """Key that identifies this node in internal and exported networkx graphs.
+ """
+
+ prerequisite_inputs: Mapping[str, ReadEdge]
+ """Graph edges that represent prerequisite inputs to this task.
+
+ Prerequisite inputs must already exist in the data repository when a
+ `QuantumGraph` is built, but have more flexibility in how they are looked
+ up than regular inputs.
+ """
+
+ inputs: Mapping[str, ReadEdge]
+ """Graph edges that represent regular runtime inputs to this task.
+ """
+
+ outputs: Mapping[str, WriteEdge]
+ """Graph edges that represent regular runtime outputs of this task.
+
+ This does not include the special `log_output` and `metadata_output` edges;
+ use `iter_all_outputs` to include that, too.
+ """
+
+ log_output: WriteEdge | None
+ """The special runtime output that persists the task's logs.
+ """
+
+ metadata_output: WriteEdge
+ """The special runtime output that persists the task's metadata.
+ """
+
+ @property
+ def label(self) -> str:
+ """Label of this configuration of a task in the pipeline."""
+ return self.key.name
+
+ @property
+ def is_imported(self) -> bool:
+ """Whether this the task type for this node has been imported and
+ its configuration overrides applied.
+
+ If this is `False`, the `task_class` and `config` attributes may not
+ be accessed.
+ """
+ return self.init.is_imported
+
+ @property
+ def task_class(self) -> type[PipelineTask]:
+ """Type object for the task.
+
+ Accessing this attribute when `is_imported` is `False` will raise
+ `TaskNotImportedError`, but accessing `task_class_name` will not.
+ """
+ return self.init.task_class
+
+ @property
+ def task_class_name(self) -> str:
+ """The fully-qualified string name of the task class."""
+ return self.init.task_class_name
+
+ @property
+ def config(self) -> PipelineTaskConfig:
+ """Configuration for the task.
+
+ This is always frozen.
+
+ Accessing this attribute when `is_imported` is `False` will raise
+ `TaskNotImportedError`, but calling `get_config_str` will not.
+ """
+ return self.init.config
+
+ @property
+ def has_resolved_dimensions(self) -> bool:
+ """Whether the `dimensions` attribute may be accessed.
+
+ If `False`, the `raw_dimensions` attribute may be used to obtain a
+ set of dimension names that has not been resolved by a
+ `~lsst.daf.butler.DimensionsUniverse`.
+ """
+ return type(self._dimensions) is DimensionGraph
+
+ @property
+ def dimensions(self) -> DimensionGraph:
+ """Standardized dimensions of the task."""
+ if not self.has_resolved_dimensions:
+ raise UnresolvedGraphError(f"Dimensions for task {self.label!r} have not been resolved.")
+ return cast(DimensionGraph, self._dimensions)
+
+ @property
+ def raw_dimensions(self) -> frozenset[str]:
+ """Raw dimensions of the task, with standardization by a
+ `~lsst.daf.butler.DimensionUniverse` not guaranteed.
+ """
+ if self.has_resolved_dimensions:
+ return frozenset(cast(DimensionGraph, self._dimensions).names)
+ else:
+ return cast(frozenset[str], self._dimensions)
+
+ def __repr__(self) -> str:
+ if self.has_resolved_dimensions:
+ return f"{self.label} ({self.task_class_name}, {self.dimensions})"
+ else:
+ return f"{self.label} ({self.task_class_name})"
+
+ def get_config_str(self) -> str:
+ """Return the configuration for this task as a string of override
+ statements.
+
+ Returns
+ -------
+ config_str : `str`
+ String containing configuration-overload statements.
+ """
+ return self.init.get_config_str()
+
+ def iter_all_inputs(self) -> Iterator[ReadEdge]:
+ """Iterate over all runtime inputs, including both regular inputs and
+ prerequisites.
+ """
+ yield from self.prerequisite_inputs.values()
+ yield from self.inputs.values()
+
+ def iter_all_outputs(self) -> Iterator[WriteEdge]:
+ """Iterate over all runtime outputs, including special ones."""
+ yield from self.outputs.values()
+ yield self.metadata_output
+ if self.log_output is not None:
+ yield self.log_output
+
+ def diff_edges(self, other: TaskNode) -> list[str]:
+ """Compare the edges of this task node to those from the same task
+ label in a different pipeline.
+
+ This also calls `TaskInitNode.diff_edges`.
+
+ Parameters
+ ----------
+ other : `TaskInitNode`
+ Other node to compare to. Must have the same task label, but need
+ not have the same configuration or even the same task class.
+
+ Returns
+ -------
+ differences : `list` [ `str` ]
+ List of string messages describing differences between ``self`` and
+ ``other``. Will be empty if the two nodes have the same edges.
+ Messages will use 'A' to refer to ``self`` and 'B' to refer to
+ ``other``.
+ """
+ result = self.init.diff_edges(other.init)
+ result += _diff_edge_mapping(
+ self.prerequisite_inputs, other.prerequisite_inputs, self.label, "prerequisite input"
+ )
+ result += _diff_edge_mapping(self.inputs, other.inputs, self.label, "input")
+ result += _diff_edge_mapping(self.outputs, other.outputs, self.label, "output")
+ if self.log_output is not None:
+ if other.log_output is not None:
+ result += self.log_output.diff(other.log_output, "log output")
+ else:
+ result.append("Log output is present in A, but not in B.")
+ elif other.log_output is not None:
+ result.append("Log output is present in B, but not in A.")
+ result += self.metadata_output.diff(other.metadata_output, "metadata output")
+ return result
+
+ def _imported_and_configured(self, rebuild: bool) -> TaskNode:
+ """Import the task class and use it to construct a new instance.
+
+ Parameters
+ ----------
+ rebuild : `bool`
+ If `True`, import the task class and configure its connections to
+ generate new edges that may differ from the current ones. If
+ `False`, import the task class but just update the `task_class` and
+ `config` attributes, and assume the edges have not changed.
+
+ Returns
+ -------
+ node : `TaskNode`
+ Task node instance for which `is_imported` is `True`. Will be
+ ``self`` if this is the case already.
+ """
+ from ..pipelineTask import PipelineTask
+
+ if self.is_imported:
+ return self
+ task_class = doImportType(self.task_class_name)
+ if not issubclass(task_class, PipelineTask):
+ raise TypeError(f"{self.task_class_name!r} is not a PipelineTask subclass.")
+ config = task_class.ConfigClass()
+ config.loadFromString(self.get_config_str())
+ return self._reconfigured(config, rebuild=rebuild, task_class=task_class)
+
+ def _reconfigured(
+ self,
+ config: PipelineTaskConfig,
+ rebuild: bool,
+ task_class: type[PipelineTask] | None = None,
+ ) -> TaskNode:
+ """Return a version of this node with new configuration.
+
+ Parameters
+ ----------
+ config : `.PipelineTaskConfig`
+ New configuration for the task.
+ rebuild : `bool`
+ If `True`, use the configured connections to generate new edges
+ that may differ from the current ones. If `False`, just update the
+ `task_class` and `config` attributes, and assume the edges have not
+ changed.
+ task_class : `type` [ `PipelineTask` ], optional
+ Subclass of `PipelineTask`. This defaults to ``self.task_class`,
+ but may be passed as an argument if that is not available because
+ the task class was not imported when ``self`` was constructed.
+
+ Returns
+ -------
+ node : `TaskNode`
+ Task node instance with the new config.
+ """
+ if task_class is None:
+ task_class = self.task_class
+ imported_data = _TaskNodeImportedData.configure(self.key.name, task_class, config)
+ if rebuild:
+ return self._from_imported_data(
+ self.key,
+ self.init.key,
+ imported_data,
+ universe=self._dimensions.universe if type(self._dimensions) is DimensionGraph else None,
+ )
+ else:
+ return TaskNode(
+ self.key,
+ TaskInitNode(
+ self.init.key,
+ inputs=self.init.inputs,
+ outputs=self.init.outputs,
+ config_output=self.init.config_output,
+ imported_data=imported_data,
+ ),
+ prerequisite_inputs=self.prerequisite_inputs,
+ inputs=self.inputs,
+ outputs=self.outputs,
+ log_output=self.log_output,
+ metadata_output=self.metadata_output,
+ dimensions=self._dimensions,
+ )
+
+ def _resolved(self, universe: DimensionUniverse | None) -> TaskNode:
+ """Return an otherwise-equivalent task node with resolved dimensions.
+
+ Parameters
+ ----------
+ universe : `lsst.daf.butler.DimensionUniverse` or `None`
+ Definitions for all dimensions.
+
+ Returns
+ -------
+ node : `TaskNode`
+ Task node instance with `dimensions` resolved by the given
+ universe. Will be ``self`` if this is the case already.
+ """
+ if self.has_resolved_dimensions:
+ if cast(DimensionGraph, self._dimensions).universe is universe:
+ return self
+ elif universe is None:
+ return self
+ return TaskNode(
+ key=self.key,
+ init=self.init,
+ prerequisite_inputs=self.prerequisite_inputs,
+ inputs=self.inputs,
+ outputs=self.outputs,
+ log_output=self.log_output,
+ metadata_output=self.metadata_output,
+ dimensions=(
+ universe.extract(self.raw_dimensions) if universe is not None else self.raw_dimensions
+ ),
+ )
+
+ def _to_xgraph_state(self) -> dict[str, Any]:
+ """Convert this nodes's attributes into a dictionary suitable for use
+ in exported networkx graphs.
+ """
+ result = self.init._to_xgraph_state()
+ if self.has_resolved_dimensions:
+ result["dimensions"] = self._dimensions
+ result["raw_dimensions"] = self.raw_dimensions
+ return result
+
+ def _get_imported_data(self) -> _TaskNodeImportedData:
+ """Return the imported data struct.
+
+ Returns
+ -------
+ imported_data : `_TaskNodeImportedData`
+ Internal structure holding state that requires the task class to
+ have been imported.
+
+ Raises
+ ------
+ TaskNotImportedError
+ Raised if `is_imported` is `False`.
+ """
+ return self.init._get_imported_data()
+
+
+def _diff_edge_mapping(
+ a_mapping: Mapping[str, Edge], b_mapping: Mapping[str, Edge], task_label: str, connection_type: str
+) -> list[str]:
+ """Compare a pair of mappings of edges.
+
+ Parameters
+ ----------
+ a_mapping : `~collections.abc.Mapping` [ `str`, `Edge` ]
+ First mapping to compare. Expected to have connection names as keys.
+ b_mapping : `~collections.abc.Mapping` [ `str`, `Edge` ]
+ First mapping to compare. If keys differ from those of ``a_mapping``,
+ this will be reported as a difference (in addition to element-wise
+ comparisons).
+ task_label : `str`
+ Task label associated with both mappings.
+ connection_type : `str`
+ Type of connection (e.g. "input" or "init output") associated with both
+ connections. This is a human-readable string to include in difference
+ messages.
+ """
+ results = []
+ b_to_do = set(b_mapping.keys())
+ for connection_name, a_edge in a_mapping.items():
+ if (b_edge := b_mapping.get(connection_name)) is None:
+ results.append(
+ f"{connection_type.capitalize()} {connection_name!r} of task "
+ f"{task_label!r} exists in A, but not in B (or it may have a different connection type)."
+ )
+ else:
+ results.extend(a_edge.diff(b_edge, connection_type))
+ b_to_do.discard(connection_name)
+ for connection_name in b_to_do:
+ results.append(
+ f"{connection_type.capitalize()} {connection_name!r} of task "
+ f"{task_label!r} exists in A, but not in B (or it may have a different connection type)."
+ )
+ return results
diff --git a/python/lsst/pipe/base/pipeline_graph/io.py b/python/lsst/pipe/base/pipeline_graph/io.py
new file mode 100644
index 000000000..09e52df55
--- /dev/null
+++ b/python/lsst/pipe/base/pipeline_graph/io.py
@@ -0,0 +1,578 @@
+# This file is part of pipe_base.
+#
+# Developed for the LSST Data Management System.
+# This product includes software developed by the LSST Project
+# (http://www.lsst.org).
+# See the COPYRIGHT file at the top-level directory of this distribution
+# for details of code ownership.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see .
+from __future__ import annotations
+
+__all__ = (
+ "expect_not_none",
+ "SerializedEdge",
+ "SerializedTaskInitNode",
+ "SerializedTaskNode",
+ "SerializedDatasetTypeNode",
+ "SerializedTaskSubset",
+ "SerializedPipelineGraph",
+)
+
+from typing import Any, TypeVar
+
+import networkx
+import pydantic
+from lsst.daf.butler import DatasetType, DimensionConfig, DimensionGraph, DimensionUniverse
+
+from .. import automatic_connection_constants as acc
+from ._dataset_types import DatasetTypeNode
+from ._edges import Edge, ReadEdge, WriteEdge
+from ._exceptions import PipelineGraphReadError
+from ._nodes import NodeKey, NodeType
+from ._pipeline_graph import PipelineGraph
+from ._task_subsets import TaskSubset
+from ._tasks import TaskInitNode, TaskNode
+
+_U = TypeVar("_U")
+
+_IO_VERSION_INFO = (0, 0, 1)
+"""Version tuple embedded in saved PipelineGraphs.
+"""
+
+
+def expect_not_none(value: _U | None, msg: str) -> _U:
+ """Check that a value is not `None` and return it.
+
+ Parameters
+ ----------
+ value
+ Value to check
+ msg
+ Error message for the case where ``value is None``.
+
+ Returns
+ -------
+ value
+ Value, guaranteed not to be `None`.
+
+ Raises
+ ------
+ PipelineGraphReadError
+ Raised with ``msg`` if ``value is None``.
+ """
+ if value is None:
+ raise PipelineGraphReadError(msg)
+ return value
+
+
+class SerializedEdge(pydantic.BaseModel):
+ """Struct used to represent a serialized `Edge` in a `PipelineGraph`.
+
+ All `ReadEdge` and `WriteEdge` state not included here is instead
+ effectively serialized by the context in which a `SerializedEdge` appears
+ (e.g. the keys of the nested dictionaries in which it serves as the value
+ type).
+ """
+
+ dataset_type_name: str
+ """Full dataset type name (including component)."""
+
+ storage_class: str
+ """Name of the storage class."""
+
+ raw_dimensions: list[str]
+ """Raw dimensions of the dataset type from the task connections."""
+
+ is_calibration: bool = False
+ """Whether this dataset type can be included in
+ `~lsst.daf.butler.CollectionType.CALIBRATION` collections."""
+
+ defer_query_constraint: bool = False
+ """If `True`, by default do not include this dataset type's existence as a
+ constraint on the initial data ID query in QuantumGraph generation."""
+
+ @classmethod
+ def serialize(cls, target: Edge) -> SerializedEdge:
+ """Transform an `Edge` to a `SerializedEdge`."""
+ return SerializedEdge.construct(
+ storage_class=target.storage_class_name,
+ dataset_type_name=target.dataset_type_name,
+ raw_dimensions=sorted(target.raw_dimensions),
+ is_calibration=target.is_calibration,
+ defer_query_constraint=getattr(target, "defer_query_constraint", False),
+ )
+
+ def deserialize_read_edge(
+ self,
+ task_key: NodeKey,
+ connection_name: str,
+ is_prerequisite: bool = False,
+ ) -> ReadEdge:
+ """Transform a `SerializedEdge` to a `ReadEdge`."""
+ parent_dataset_type_name, component = DatasetType.splitDatasetTypeName(self.dataset_type_name)
+ dataset_type_key = NodeKey(NodeType.DATASET_TYPE, parent_dataset_type_name)
+ return ReadEdge(
+ dataset_type_key,
+ task_key,
+ storage_class_name=self.storage_class,
+ is_prerequisite=is_prerequisite,
+ component=component,
+ connection_name=connection_name,
+ is_calibration=self.is_calibration,
+ defer_query_constraint=self.defer_query_constraint,
+ raw_dimensions=frozenset(self.raw_dimensions),
+ )
+
+ def deserialize_write_edge(
+ self,
+ task_key: NodeKey,
+ connection_name: str,
+ ) -> WriteEdge:
+ """Transform a `SerializedEdge` to a `WriteEdge`."""
+ dataset_type_key = NodeKey(NodeType.DATASET_TYPE, self.dataset_type_name)
+ return WriteEdge(
+ task_key=task_key,
+ dataset_type_key=dataset_type_key,
+ storage_class_name=self.storage_class,
+ connection_name=connection_name,
+ is_calibration=self.is_calibration,
+ raw_dimensions=frozenset(self.raw_dimensions),
+ )
+
+
+class SerializedTaskInitNode(pydantic.BaseModel):
+ """Struct used to represent a serialized `TaskInitNode` in a
+ `PipelineGraph`.
+
+ The task label is serialized by the context in which a
+ `SerializedTaskInitNode` appears (e.g. the keys of the nested dictionary
+ in which it serves as the value type), and the task class name and config
+ string are save with the corresponding `SerializedTaskNode`.
+ """
+
+ inputs: dict[str, SerializedEdge]
+ """Mapping of serialized init-input edges, keyed by connection name."""
+
+ outputs: dict[str, SerializedEdge]
+ """Mapping of serialized init-output edges, keyed by connection name."""
+
+ config_output: SerializedEdge
+ """The serialized config init-output edge."""
+
+ index: int | None = None
+ """The index of this node in the sorted sequence of `PipelineGraph`.
+
+ This is `None` if the `PipelineGraph` was not sorted when it was
+ serialized.
+ """
+
+ @classmethod
+ def serialize(cls, target: TaskInitNode) -> SerializedTaskInitNode:
+ """Transform a `TaskInitNode` to a `SerializedTaskInitNode`."""
+ return cls.construct(
+ inputs={
+ connection_name: SerializedEdge.serialize(edge)
+ for connection_name, edge in sorted(target.inputs.items())
+ },
+ outputs={
+ connection_name: SerializedEdge.serialize(edge)
+ for connection_name, edge in sorted(target.outputs.items())
+ },
+ config_output=SerializedEdge.serialize(target.config_output),
+ )
+
+ def deserialize(
+ self,
+ key: NodeKey,
+ task_class_name: str,
+ config_str: str,
+ ) -> TaskInitNode:
+ """Transform a `SerializedTaskInitNode` to a `TaskInitNode`."""
+ return TaskInitNode(
+ key,
+ inputs={
+ connection_name: serialized_edge.deserialize_read_edge(key, connection_name)
+ for connection_name, serialized_edge in self.inputs.items()
+ },
+ outputs={
+ connection_name: serialized_edge.deserialize_write_edge(key, connection_name)
+ for connection_name, serialized_edge in self.outputs.items()
+ },
+ config_output=self.config_output.deserialize_write_edge(
+ key,
+ acc.CONFIG_INIT_OUTPUT_CONNECTION_NAME,
+ ),
+ task_class_name=task_class_name,
+ config_str=config_str,
+ )
+
+
+class SerializedTaskNode(pydantic.BaseModel):
+ """Struct used to represent a serialized `TaskNode` in a `PipelineGraph`.
+
+ The task label is serialized by the context in which a
+ `SerializedTaskNode` appears (e.g. the keys of the nested dictionary in
+ which it serves as the value type).
+ """
+
+ task_class: str
+ """Fully-qualified name of the task class."""
+
+ init: SerializedTaskInitNode
+ """Serialized task initialization node."""
+
+ config_str: str
+ """Configuration for the task as a string of override statements."""
+
+ prerequisite_inputs: dict[str, SerializedEdge]
+ """Mapping of serialized prerequisiste input edges, keyed by connection
+ name.
+ """
+
+ inputs: dict[str, SerializedEdge]
+ """Mapping of serialized input edges, keyed by connection name."""
+
+ outputs: dict[str, SerializedEdge]
+ """Mapping of serialized output edges, keyed by connection name."""
+
+ metadata_output: SerializedEdge
+ """The serialized metadata output edge."""
+
+ dimensions: list[str]
+ """The task's dimensions, if they were resolved."""
+
+ log_output: SerializedEdge | None = None
+ """The serialized log output edge."""
+
+ index: int | None = None
+ """The index of this node in the sorted sequence of `PipelineGraph`.
+
+ This is `None` if the `PipelineGraph` was not sorted when it was
+ serialized.
+ """
+
+ @classmethod
+ def serialize(cls, target: TaskNode) -> SerializedTaskNode:
+ """Transform a `TaskNode` to a `SerializedTaskNode`."""
+ return cls.construct(
+ task_class=target.task_class_name,
+ init=SerializedTaskInitNode.serialize(target.init),
+ config_str=target.get_config_str(),
+ dimensions=list(target.raw_dimensions),
+ prerequisite_inputs={
+ connection_name: SerializedEdge.serialize(edge)
+ for connection_name, edge in sorted(target.prerequisite_inputs.items())
+ },
+ inputs={
+ connection_name: SerializedEdge.serialize(edge)
+ for connection_name, edge in sorted(target.inputs.items())
+ },
+ outputs={
+ connection_name: SerializedEdge.serialize(edge)
+ for connection_name, edge in sorted(target.outputs.items())
+ },
+ metadata_output=SerializedEdge.serialize(target.metadata_output),
+ log_output=(
+ SerializedEdge.serialize(target.log_output) if target.log_output is not None else None
+ ),
+ )
+
+ def deserialize(self, key: NodeKey, init_key: NodeKey, universe: DimensionUniverse | None) -> TaskNode:
+ """Transform a `SerializedTaskNode` to a `TaskNode`."""
+ init = self.init.deserialize(
+ init_key,
+ task_class_name=self.task_class,
+ config_str=expect_not_none(
+ self.config_str, f"No serialized config file for task with label {key.name!r}."
+ ),
+ )
+ inputs = {
+ connection_name: serialized_edge.deserialize_read_edge(key, connection_name)
+ for connection_name, serialized_edge in self.inputs.items()
+ }
+ prerequisite_inputs = {
+ connection_name: serialized_edge.deserialize_read_edge(key, connection_name, is_prerequisite=True)
+ for connection_name, serialized_edge in self.prerequisite_inputs.items()
+ }
+ outputs = {
+ connection_name: serialized_edge.deserialize_write_edge(key, connection_name)
+ for connection_name, serialized_edge in self.outputs.items()
+ }
+ if (serialized_log_output := self.log_output) is not None:
+ log_output = serialized_log_output.deserialize_write_edge(key, acc.LOG_OUTPUT_CONNECTION_NAME)
+ else:
+ log_output = None
+ metadata_output = self.metadata_output.deserialize_write_edge(
+ key, acc.METADATA_OUTPUT_CONNECTION_NAME
+ )
+ dimensions: frozenset[str] | DimensionGraph
+ if universe is not None:
+ dimensions = universe.extract(self.dimensions)
+ else:
+ dimensions = frozenset(self.dimensions)
+ return TaskNode(
+ key=key,
+ init=init,
+ inputs=inputs,
+ prerequisite_inputs=prerequisite_inputs,
+ outputs=outputs,
+ log_output=log_output,
+ metadata_output=metadata_output,
+ dimensions=dimensions,
+ )
+
+
+class SerializedDatasetTypeNode(pydantic.BaseModel):
+ """Struct used to represent a serialized `DatasetTypeNode` in a
+ `PipelineGraph`.
+
+ Unresolved dataset types are serialized as instances with at most the
+ `index` attribute set, and are typically converted to JSON with pydantic's
+ ``exclude_defaults=True`` option to keep this compact.
+
+ The dataset typename is serialized by the context in which a
+ `SerializedDatasetTypeNode` appears (e.g. the keys of the nested dictionary
+ in which it serves as the value type).
+ """
+
+ dimensions: list[str] | None = None
+ """Dimensions of the dataset type."""
+
+ storage_class: str | None = None
+ """Name of the storage class."""
+
+ is_calibration: bool = False
+ """Whether this dataset type is a calibration."""
+
+ is_initial_query_constraint: bool = False
+ """Whether this dataset type should be a query constraint during
+ `QuantumGraph` generation."""
+
+ is_prerequisite: bool = False
+ """Whether datasets of this dataset type must exist in the input collection
+ before `QuantumGraph` generation."""
+
+ index: int | None = None
+ """The index of this node in the sorted sequence of `PipelineGraph`.
+
+ This is `None` if the `PipelineGraph` was not sorted when it was
+ serialized.
+ """
+
+ @classmethod
+ def serialize(cls, target: DatasetTypeNode | None) -> SerializedDatasetTypeNode:
+ """Transform a `DatasetTypeNode` to a `SerializedDatasetTypeNode`."""
+ if target is None:
+ return cls.construct()
+ return cls.construct(
+ dimensions=list(target.dataset_type.dimensions.names),
+ storage_class=target.dataset_type.storageClass_name,
+ is_calibration=target.dataset_type.isCalibration(),
+ is_initial_query_constraint=target.is_initial_query_constraint,
+ is_prerequisite=target.is_prerequisite,
+ )
+
+ def deserialize(self, key: NodeKey, universe: DimensionUniverse | None) -> DatasetTypeNode | None:
+ """Transform a `SerializedDatasetTypeNode` to a `DatasetTypeNode`."""
+ if self.dimensions is not None:
+ dataset_type = DatasetType(
+ key.name,
+ expect_not_none(
+ self.dimensions,
+ f"Serialized dataset type {key.name!r} has no dimensions.",
+ ),
+ storageClass=expect_not_none(
+ self.storage_class,
+ f"Serialized dataset type {key.name!r} has no storage class.",
+ ),
+ isCalibration=self.is_calibration,
+ universe=expect_not_none(
+ universe,
+ f"Serialized dataset type {key.name!r} has dimensions, "
+ "but no dimension universe was stored.",
+ ),
+ )
+ return DatasetTypeNode(
+ dataset_type=dataset_type,
+ is_prerequisite=self.is_prerequisite,
+ is_initial_query_constraint=self.is_initial_query_constraint,
+ )
+ return None
+
+
+class SerializedTaskSubset(pydantic.BaseModel):
+ """Struct used to represent a serialized `TaskSubset` in a `PipelineGraph`.
+
+ The subsetlabel is serialized by the context in which a
+ `SerializedDatasetTypeNode` appears (e.g. the keys of the nested dictionary
+ in which it serves as the value type).
+ """
+
+ description: str
+ """Description of the subset."""
+
+ tasks: list[str]
+ """Labels of tasks in the subset, sorted lexicographically for
+ determinism.
+ """
+
+ @classmethod
+ def serialize(cls, target: TaskSubset) -> SerializedTaskSubset:
+ """Transform a `TaskSubset` into a `SerializedTaskSubset`."""
+ return cls.construct(description=target._description, tasks=list(sorted(target)))
+
+ def deserialize_task_subset(self, label: str, xgraph: networkx.MultiDiGraph) -> TaskSubset:
+ """Transform a `SerializedTaskSubset` into a `TaskSubset`."""
+ members = set(self.tasks)
+ return TaskSubset(xgraph, label, members, self.description)
+
+
+class SerializedPipelineGraph(pydantic.BaseModel):
+ """Struct used to represent a serialized `PipelineGraph`."""
+
+ version: str = ".".join(str(v) for v in _IO_VERSION_INFO)
+ """Serialization version."""
+
+ description: str
+ """Human-readable description of the pipeline."""
+
+ tasks: dict[str, SerializedTaskNode] = pydantic.Field(default_factory=dict)
+ """Mapping of serialized tasks, keyed by label."""
+
+ dataset_types: dict[str, SerializedDatasetTypeNode] = pydantic.Field(default_factory=dict)
+ """Mapping of serialized dataset types, keyed by parent dataset type name.
+ """
+
+ task_subsets: dict[str, SerializedTaskSubset] = pydantic.Field(default_factory=dict)
+ """Mapping of task subsets, keyed by subset label."""
+
+ dimensions: dict[str, Any] | None = None
+ """Dimension universe configuration."""
+
+ data_id: dict[str, Any] = pydantic.Field(default_factory=dict)
+ """Data ID that constrains all quanta generated from this pipeline."""
+
+ @classmethod
+ def serialize(cls, target: PipelineGraph) -> SerializedPipelineGraph:
+ """Transform a `PipelineGraph` into a `SerializedPipelineGraph`."""
+ result = SerializedPipelineGraph.construct(
+ description=target.description,
+ tasks={label: SerializedTaskNode.serialize(node) for label, node in target.tasks.items()},
+ dataset_types={
+ name: SerializedDatasetTypeNode().serialize(target.dataset_types.get_if_resolved(name))
+ for name in target.dataset_types.keys()
+ },
+ task_subsets={
+ label: SerializedTaskSubset.serialize(subset) for label, subset in target.task_subsets.items()
+ },
+ dimensions=target.universe.dimensionConfig.toDict() if target.universe is not None else None,
+ data_id=target._raw_data_id,
+ )
+ if target._sorted_keys:
+ for index, node_key in enumerate(target._sorted_keys):
+ match node_key.node_type:
+ case NodeType.TASK:
+ result.tasks[node_key.name].index = index
+ case NodeType.DATASET_TYPE:
+ result.dataset_types[node_key.name].index = index
+ case NodeType.TASK_INIT:
+ result.tasks[node_key.name].init.index = index
+ return result
+
+ def deserialize(
+ self,
+ import_and_configure: bool = True,
+ check_edges_unchanged: bool = False,
+ assume_edges_unchanged: bool = False,
+ ) -> PipelineGraph:
+ """Transform a `SerializedPipelineGraph` into a `PipelineGraph`."""
+ universe: DimensionUniverse | None = None
+ if self.dimensions is not None:
+ universe = DimensionUniverse(
+ config=DimensionConfig(
+ expect_not_none(
+ self.dimensions,
+ "Serialized pipeline graph has not been resolved; "
+ "load it is a MutablePipelineGraph instead.",
+ )
+ )
+ )
+ xgraph = networkx.MultiDiGraph()
+ sort_index_map: dict[int, NodeKey] = {}
+ for dataset_type_name, serialized_dataset_type in self.dataset_types.items():
+ dataset_type_key = NodeKey(NodeType.DATASET_TYPE, dataset_type_name)
+ dataset_type_node = serialized_dataset_type.deserialize(dataset_type_key, universe)
+ xgraph.add_node(
+ dataset_type_key, instance=dataset_type_node, bipartite=NodeType.DATASET_TYPE.value
+ )
+ if serialized_dataset_type.index is not None:
+ sort_index_map[serialized_dataset_type.index] = dataset_type_key
+ for task_label, serialized_task in self.tasks.items():
+ task_key = NodeKey(NodeType.TASK, task_label)
+ task_init_key = NodeKey(NodeType.TASK_INIT, task_label)
+ task_node = serialized_task.deserialize(task_key, task_init_key, universe)
+ if serialized_task.index is not None:
+ sort_index_map[serialized_task.index] = task_key
+ if serialized_task.init.index is not None:
+ sort_index_map[serialized_task.init.index] = task_init_key
+ xgraph.add_node(task_key, instance=task_node, bipartite=NodeType.TASK.bipartite)
+ xgraph.add_node(task_init_key, instance=task_node.init, bipartite=NodeType.TASK_INIT.bipartite)
+ xgraph.add_edge(task_init_key, task_key, Edge.INIT_TO_TASK_NAME, instance=None)
+ for read_edge in task_node.init.iter_all_inputs():
+ xgraph.add_edge(
+ read_edge.dataset_type_key,
+ read_edge.task_key,
+ read_edge.connection_name,
+ instance=read_edge,
+ )
+ for write_edge in task_node.init.iter_all_outputs():
+ xgraph.add_edge(
+ write_edge.task_key,
+ write_edge.dataset_type_key,
+ write_edge.connection_name,
+ instance=write_edge,
+ )
+ for read_edge in task_node.iter_all_inputs():
+ xgraph.add_edge(
+ read_edge.dataset_type_key,
+ read_edge.task_key,
+ read_edge.connection_name,
+ instance=read_edge,
+ )
+ for write_edge in task_node.iter_all_outputs():
+ xgraph.add_edge(
+ write_edge.task_key,
+ write_edge.dataset_type_key,
+ write_edge.connection_name,
+ instance=write_edge,
+ )
+ result = PipelineGraph.__new__(PipelineGraph)
+ result._init_from_args(
+ xgraph,
+ sorted_keys=[sort_index_map[i] for i in range(len(xgraph))] if sort_index_map else None,
+ task_subsets={
+ subset_label: serialized_subset.deserialize_task_subset(subset_label, xgraph)
+ for subset_label, serialized_subset in self.task_subsets.items()
+ },
+ description=self.description,
+ universe=universe,
+ data_id=self.data_id,
+ )
+ if import_and_configure:
+ result.import_and_configure(
+ check_edges_unchanged=check_edges_unchanged,
+ assume_edges_unchanged=assume_edges_unchanged,
+ )
+ return result
diff --git a/python/lsst/pipe/base/tests/mocks/_pipeline_task.py b/python/lsst/pipe/base/tests/mocks/_pipeline_task.py
index 61edcd24c..f7d785b37 100644
--- a/python/lsst/pipe/base/tests/mocks/_pipeline_task.py
+++ b/python/lsst/pipe/base/tests/mocks/_pipeline_task.py
@@ -20,19 +20,27 @@
# along with this program. If not, see .
from __future__ import annotations
-__all__ = ("MockPipelineTask", "MockPipelineTaskConfig", "mock_task_defs")
+__all__ = (
+ "DynamicConnectionConfig",
+ "DynamicTestPipelineTask",
+ "DynamicTestPipelineTaskConfig",
+ "MockPipelineTask",
+ "MockPipelineTaskConfig",
+ "mock_task_defs",
+)
import dataclasses
import logging
from collections.abc import Iterable, Mapping
-from typing import TYPE_CHECKING, Any, ClassVar
+from typing import TYPE_CHECKING, Any, ClassVar, TypeVar
-from lsst.daf.butler import DeferredDatasetHandle
-from lsst.pex.config import ConfigurableField, Field, ListField
+from lsst.daf.butler import DatasetRef, DeferredDatasetHandle
+from lsst.pex.config import Config, ConfigDictField, ConfigurableField, Field, ListField
from lsst.utils.doImport import doImportType
from lsst.utils.introspection import get_full_type_name
from lsst.utils.iteration import ensure_iterable
+from ... import connectionTypes as cT
from ...config import PipelineTaskConfig
from ...connections import InputQuantizedConnection, OutputQuantizedConnection, PipelineTaskConnections
from ...pipeline import TaskDef
@@ -46,6 +54,9 @@
from ..._quantumContext import QuantumContext
+_T = TypeVar("_T", bound=cT.BaseConnection)
+
+
def mock_task_defs(
originals: Iterable[TaskDef],
unmocked_dataset_types: Iterable[str] = (),
@@ -96,73 +107,11 @@ def mock_task_defs(
return results
-class MockPipelineDefaultTargetConnections(PipelineTaskConnections, dimensions=()):
+class BaseTestPipelineTaskConnections(PipelineTaskConnections, dimensions=()):
pass
-class MockPipelineDefaultTargetConfig(
- PipelineTaskConfig, pipelineConnections=MockPipelineDefaultTargetConnections
-):
- pass
-
-
-class MockPipelineDefaultTargetTask(PipelineTask):
- """A `PipelineTask` class used as the default target for
- ``MockPipelineTaskConfig.original``.
-
- This is effectively a workaround for `lsst.pex.config.ConfigurableField`
- not supporting ``optional=True``, but that is generally a reasonable
- limitation for production code and it wouldn't make sense just to support
- test utilities.
- """
-
- ConfigClass = MockPipelineDefaultTargetConfig
-
-
-class MockPipelineTaskConnections(PipelineTaskConnections, dimensions=()):
- def __init__(self, *, config: MockPipelineTaskConfig):
- original: PipelineTaskConnections = config.original.connections.ConnectionsClass(
- config=config.original.value
- )
- self.dimensions.update(original.dimensions)
- unmocked_dataset_types = frozenset(config.unmocked_dataset_types)
- for name, connection in original.allConnections.items():
- if name in original.initInputs or name in original.initOutputs:
- # We just ignore initInputs and initOutputs, because the task
- # is never given DatasetRefs for those and hence can't create
- # mocks.
- continue
- if connection.name not in unmocked_dataset_types:
- # We register the mock storage class with the global singleton
- # here, but can only put its name in the connection. That means
- # the same global singleton (or one that also has these
- # registrations) has to be available whenever this dataset type
- # is used.
- storage_class = MockStorageClass.get_or_register_mock(connection.storageClass)
- kwargs = {}
- if hasattr(connection, "dimensions"):
- connection_dimensions = set(connection.dimensions)
- # Replace the generic "skypix" placeholder with htm7, since
- # that requires the dataset type to have already been
- # registered.
- if "skypix" in connection_dimensions:
- connection_dimensions.remove("skypix")
- connection_dimensions.add("htm7")
- kwargs["dimensions"] = connection_dimensions
- connection = dataclasses.replace(
- connection,
- name=get_mock_name(connection.name),
- storageClass=storage_class.name,
- **kwargs,
- )
- elif name in original.outputs:
- raise ValueError(f"Unmocked dataset type {connection.name!r} cannot be used as an output.")
- setattr(self, name, connection)
-
-
-class MockPipelineTaskConfig(PipelineTaskConfig, pipelineConnections=MockPipelineTaskConnections):
- """Configuration class for `MockPipelineTask`."""
-
+class BaseTestPipelineTaskConfig(PipelineTaskConfig, pipelineConnections=BaseTestPipelineTaskConnections):
fail_condition = Field[str](
dtype=str,
default="",
@@ -181,28 +130,15 @@ class MockPipelineTaskConfig(PipelineTaskConfig, pipelineConnections=MockPipelin
),
)
- original: ConfigurableField = ConfigurableField(
- doc="The original task being mocked by this one.", target=MockPipelineDefaultTargetTask
- )
-
- unmocked_dataset_types = ListField[str](
- doc=(
- "Names of input dataset types that should be used as-is instead "
- "of being mocked. May include dataset types not relevant for "
- "this task, which will be ignored."
- ),
- default=(),
- optional=False,
- )
-
def data_id_match(self) -> DataIdMatch | None:
if not self.fail_condition:
return None
return DataIdMatch(self.fail_condition)
-class MockPipelineTask(PipelineTask):
- """Implementation of `PipelineTask` used for running a mock pipeline.
+class BaseTestPipelineTask(PipelineTask):
+ """A base class for test-utility `PipelineTask` classes that read and write
+ mock datasets `runQuantum`.
Notes
-----
@@ -212,20 +148,21 @@ class MockPipelineTask(PipelineTask):
`MockDataset` inputs and simulates reading inputs of other types by
creating `MockDataset` inputs from their DatasetRefs.
- At present `MockPipelineTask` simply drops any ``initInput`` and
- ``initOutput`` connections present on the original, since `MockDataset`
- creation for those would have to happen in the code that executes the task,
- not in the task itself. Because `MockPipelineTask` never instantiates the
- mock task (just its connections class), this is a limitation on what the
- mocks can be used to test, not anything deeper.
+ Subclasses are responsible for defining connections, but init-input and
+ init-output connections are not supported at runtime (they may be present
+ as long as the task is never constructed). All output connections must
+ use mock storage classes. `..Input` and `..PrerequisiteInput` connections
+ that do not use mock storage classes will be handled by constructing a
+ `MockDataset` from the `~lsst.daf.butler.DatasetRef` rather than actually
+ reading them.
"""
- ConfigClass: ClassVar[type[PipelineTaskConfig]] = MockPipelineTaskConfig
+ ConfigClass: ClassVar[type[PipelineTaskConfig]] = BaseTestPipelineTaskConfig
def __init__(
self,
*,
- config: MockPipelineTaskConfig,
+ config: BaseTestPipelineTaskConfig,
**kwargs: Any,
):
super().__init__(config=config, **kwargs)
@@ -234,7 +171,7 @@ def __init__(
if self.data_id_match:
self.fail_exception = doImportType(self.config.fail_exception)
- config: MockPipelineTaskConfig
+ config: BaseTestPipelineTaskConfig
def runQuantum(
self,
@@ -263,6 +200,7 @@ def runQuantum(
)
for name, refs in inputRefs:
inputs_list = []
+ ref: DatasetRef
for ref in ensure_iterable(refs):
if isinstance(ref.datasetType.storageClass, MockStorageClass):
input_dataset = butlerQC.get(ref)
@@ -292,3 +230,218 @@ def runQuantum(
butlerQC.put(output, ref)
_LOG.info("Finished mocking task '%s' on quantum %s", self.getName(), quantum.dataId)
+
+
+class MockPipelineDefaultTargetConnections(PipelineTaskConnections, dimensions=()):
+ pass
+
+
+class MockPipelineDefaultTargetConfig(
+ PipelineTaskConfig, pipelineConnections=MockPipelineDefaultTargetConnections
+):
+ pass
+
+
+class MockPipelineDefaultTargetTask(PipelineTask):
+ """A `PipelineTask` class used as the default target for
+ ``MockPipelineTaskConfig.original``.
+
+ This is effectively a workaround for `lsst.pex.config.ConfigurableField`
+ not supporting ``optional=True``, but that is generally a reasonable
+ limitation for production code and it wouldn't make sense just to support
+ test utilities.
+ """
+
+ ConfigClass = MockPipelineDefaultTargetConfig
+
+
+class MockPipelineTaskConnections(BaseTestPipelineTaskConnections, dimensions=()):
+ """A connections class that creates mock connections from the connections
+ of a real PipelineTask.
+ """
+
+ def __init__(self, *, config: MockPipelineTaskConfig):
+ original: PipelineTaskConnections = config.original.connections.ConnectionsClass(
+ config=config.original.value
+ )
+ self.dimensions.update(original.dimensions)
+ unmocked_dataset_types = frozenset(config.unmocked_dataset_types)
+ for name, connection in original.allConnections.items():
+ if name in original.initInputs or name in original.initOutputs:
+ # We just ignore initInputs and initOutputs, because the task
+ # is never given DatasetRefs for those and hence can't create
+ # mocks.
+ continue
+ if connection.name not in unmocked_dataset_types:
+ # We register the mock storage class with the global singleton
+ # here, but can only put its name in the connection. That means
+ # the same global singleton (or one that also has these
+ # registrations) has to be available whenever this dataset type
+ # is used.
+ storage_class = MockStorageClass.get_or_register_mock(connection.storageClass)
+ kwargs = {}
+ if hasattr(connection, "dimensions"):
+ connection_dimensions = set(connection.dimensions)
+ # Replace the generic "skypix" placeholder with htm7, since
+ # that requires the dataset type to have already been
+ # registered.
+ if "skypix" in connection_dimensions:
+ connection_dimensions.remove("skypix")
+ connection_dimensions.add("htm7")
+ kwargs["dimensions"] = connection_dimensions
+ connection = dataclasses.replace(
+ connection,
+ name=get_mock_name(connection.name),
+ storageClass=storage_class.name,
+ **kwargs,
+ )
+ elif name in original.outputs:
+ raise ValueError(f"Unmocked dataset type {connection.name!r} cannot be used as an output.")
+ setattr(self, name, connection)
+
+
+class MockPipelineTaskConfig(BaseTestPipelineTaskConfig, pipelineConnections=MockPipelineTaskConnections):
+ """Configuration class for `MockPipelineTask`."""
+
+ original: ConfigurableField = ConfigurableField(
+ doc="The original task being mocked by this one.", target=MockPipelineDefaultTargetTask
+ )
+
+ unmocked_dataset_types = ListField[str](
+ doc=(
+ "Names of input dataset types that should be used as-is instead "
+ "of being mocked. May include dataset types not relevant for "
+ "this task, which will be ignored."
+ ),
+ default=(),
+ optional=False,
+ )
+
+
+class MockPipelineTask(BaseTestPipelineTask):
+ """A test-utility implementation of `PipelineTask` with connections
+ generated by mocking those of a real task.
+
+ Notes
+ -----
+ At present `MockPipelineTask` simply drops any ``initInput`` and
+ ``initOutput`` connections present on the original, since `MockDataset`
+ creation for those would have to happen in the code that executes the task,
+ not in the task itself. Because `MockPipelineTask` never instantiates the
+ mock task (just its connections class), this is a limitation on what the
+ mocks can be used to test, not anything deeper.
+ """
+
+ ConfigClass: ClassVar[type[PipelineTaskConfig]] = MockPipelineTaskConfig
+
+
+class DynamicConnectionConfig(Config):
+ """A config class that defines a completely dynamic connection."""
+
+ dataset_type_name = Field[str](doc="Name for the dataset type as seen by the butler.", dtype=str)
+ dimensions = ListField[str](doc="Dimensions for the dataset type.", dtype=str, default=[])
+ storage_class = Field[str](
+ doc="Name of the butler storage class for the dataset type.", dtype=str, default="StructuredDataDict"
+ )
+ is_calibration = Field[bool](doc="Whether this dataset type is a calibration.", dtype=bool, default=False)
+ multiple = Field[bool](
+ doc="Whether this connection gets or puts multiple datasets for each quantum.",
+ dtype=bool,
+ default=False,
+ )
+ mock_storage_class = Field[bool](
+ doc="Whether the storage class should actually be a mock of the storage class given.",
+ dtype=bool,
+ default=True,
+ )
+
+ def make_connection(self, cls: type[_T]) -> _T:
+ storage_class = self.storage_class
+ if self.mock_storage_class:
+ storage_class = MockStorageClass.get_or_register_mock(storage_class).name
+ if issubclass(cls, cT.DimensionedConnection):
+ return cls( # type: ignore
+ name=self.dataset_type_name,
+ storageClass=storage_class,
+ isCalibration=self.is_calibration,
+ multiple=self.multiple,
+ dimensions=frozenset(self.dimensions),
+ )
+ else:
+ return cls(
+ name=self.dataset_type_name,
+ storageClass=storage_class,
+ multiple=self.multiple,
+ )
+
+
+class DynamicTestPipelineTaskConnections(PipelineTaskConnections, dimensions=()):
+ """A connections class whose dimensions and connections are wholly
+ determined via configuration.
+ """
+
+ def __init__(self, *, config: DynamicTestPipelineTaskConfig):
+ self.dimensions.update(config.dimensions)
+ connection_config: DynamicConnectionConfig
+ for connection_name, connection_config in config.init_inputs.items():
+ setattr(self, connection_name, connection_config.make_connection(cT.InitInput))
+ for connection_name, connection_config in config.init_outputs.items():
+ setattr(self, connection_name, connection_config.make_connection(cT.InitOutput))
+ for connection_name, connection_config in config.prerequisite_inputs.items():
+ setattr(self, connection_name, connection_config.make_connection(cT.PrerequisiteInput))
+ for connection_name, connection_config in config.inputs.items():
+ setattr(self, connection_name, connection_config.make_connection(cT.Input))
+ for connection_name, connection_config in config.outputs.items():
+ setattr(self, connection_name, connection_config.make_connection(cT.Output))
+
+
+class DynamicTestPipelineTaskConfig(
+ PipelineTaskConfig, pipelineConnections=DynamicTestPipelineTaskConnections
+):
+ """Configuration for DynamicTestPipelineTask."""
+
+ dimensions = ListField[str](doc="Dimensions for the task's quanta.", dtype=str, default=[])
+ init_inputs = ConfigDictField(
+ doc=(
+ "Init-input connections, keyed by the connection name as seen by the task. "
+ "Must be empty if the task will be constructed."
+ ),
+ keytype=str,
+ itemtype=DynamicConnectionConfig,
+ default={},
+ )
+ init_outputs = ConfigDictField(
+ doc=(
+ "Init-output connections, keyed by the connection name as seen by the task. "
+ "Must be empty if the task will be constructed."
+ ),
+ keytype=str,
+ itemtype=DynamicConnectionConfig,
+ default={},
+ )
+ prerequisite_inputs = ConfigDictField(
+ doc="Prerequisite input connections, keyed by the connection name as seen by the task.",
+ keytype=str,
+ itemtype=DynamicConnectionConfig,
+ default={},
+ )
+ inputs = ConfigDictField(
+ doc="Regular input connections, keyed by the connection name as seen by the task.",
+ keytype=str,
+ itemtype=DynamicConnectionConfig,
+ default={},
+ )
+ outputs = ConfigDictField(
+ doc="Regular output connections, keyed by the connection name as seen by the task.",
+ keytype=str,
+ itemtype=DynamicConnectionConfig,
+ default={},
+ )
+
+
+class DynamicTestPipelineTask(BaseTestPipelineTask):
+ """A test-utility implementation of `PipelineTask` with dimensions and
+ connections determined wholly from configuration.
+ """
+
+ ConfigClass: ClassVar[type[PipelineTaskConfig]] = DynamicTestPipelineTaskConfig
diff --git a/python/lsst/pipe/base/tests/pipelineStepTester.py b/python/lsst/pipe/base/tests/pipelineStepTester.py
index 22e08e7de..ddbc680f8 100644
--- a/python/lsst/pipe/base/tests/pipelineStepTester.py
+++ b/python/lsst/pipe/base/tests/pipelineStepTester.py
@@ -28,7 +28,7 @@
import unittest
from lsst.daf.butler import Butler, DatasetType
-from lsst.pipe.base import Pipeline, PipelineDatasetTypes
+from lsst.pipe.base import Pipeline
@dataclasses.dataclass
@@ -88,32 +88,22 @@ def run(self, butler: Butler, test_case: unittest.TestCase) -> None:
pure_inputs: dict[str, str] = dict()
for suffix in self.step_suffixes:
- pipeline = Pipeline.from_uri(self.filename + suffix)
- dataset_types = PipelineDatasetTypes.fromPipeline(
- pipeline,
- registry=butler.registry,
- include_configs=False,
- include_packages=False,
- )
+ step_graph = Pipeline.from_uri(self.filename + suffix).to_graph()
+ step_graph.resolve(butler.registry)
- pure_inputs.update({k: suffix for k in dataset_types.prerequisites.names})
- parent_inputs = {t.nameAndComponent()[0] for t in dataset_types.inputs}
- pure_inputs.update({k: suffix for k in parent_inputs - all_outputs.keys()})
- all_outputs.update(dataset_types.outputs.asMapping())
- all_outputs.update(dataset_types.intermediates.asMapping())
-
- for name in dataset_types.inputs.names & all_outputs.keys():
- test_case.assertTrue(
- all_outputs[name].is_compatible_with(dataset_types.inputs[name]),
- msg=(
- f"dataset type {name} is defined as {dataset_types.inputs[name]} as an "
- f"input, but {all_outputs[name]} as an output, and these are not compatible."
- ),
- )
+ pure_inputs.update(
+ {name: suffix for name, _ in step_graph.iter_overall_inputs() if name not in all_outputs}
+ )
+ all_outputs.update(
+ {
+ name: node.dataset_type
+ for name, node in step_graph.dataset_types.items()
+ if step_graph.producer_of(name) is not None
+ }
+ )
- for dataset_type in dataset_types.outputs | dataset_types.intermediates:
- if not dataset_type.isComponent():
- butler.registry.registerDatasetType(dataset_type)
+ for node in step_graph.dataset_types.values():
+ butler.registry.registerDatasetType(node.dataset_type)
if not pure_inputs.keys() <= self.expected_inputs:
missing = [f"{k} ({pure_inputs[k]})" for k in pure_inputs.keys() - self.expected_inputs]
diff --git a/tests/test_pipeTools.py b/tests/test_pipeTools.py
index 5cf247f8c..f9708bdbd 100644
--- a/tests/test_pipeTools.py
+++ b/tests/test_pipeTools.py
@@ -130,18 +130,6 @@ def testIsOrdered(self):
)
self.assertTrue(pipeTools.isPipelineOrdered(pipeline))
- def testIsOrderedExceptions(self):
- """Tests for pipeTools.isPipelineOrdered method exceptions"""
- # two producers should throw ValueError
- with self.assertRaises(pipeTools.DuplicateOutputError):
- _makePipeline(
- [
- ("A", "B", "task1"),
- ("B", "C", "task2"),
- ("A", "C", "task3"),
- ]
- )
-
def testOrderPipeline(self):
"""Tests for pipeTools.orderPipeline method"""
pipeline = _makePipeline([("A", "B", "task1"), ("B", "C", "task2")])
@@ -198,14 +186,6 @@ def testOrderPipeline(self):
def testOrderPipelineExceptions(self):
"""Tests for pipeTools.orderPipeline method exceptions"""
- with self.assertRaises(pipeTools.DuplicateOutputError):
- _makePipeline(
- [
- ("A", "B", "task1"),
- ("B", "C", "task2"),
- ("A", "C", "task3"),
- ]
- )
# cycle in a graph should throw ValueError
with self.assertRaises(pipeTools.PipelineDataCycleError):
diff --git a/tests/test_pipeline_graph.py b/tests/test_pipeline_graph.py
new file mode 100644
index 000000000..60566a203
--- /dev/null
+++ b/tests/test_pipeline_graph.py
@@ -0,0 +1,1299 @@
+# This file is part of pipe_base.
+#
+# Developed for the LSST Data Management System.
+# This product includes software developed by the LSST Project
+# (https://www.lsst.org).
+# See the COPYRIGHT file at the top-level directory of this distribution
+# for details of code ownership.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see .
+
+"""Tests of things related to the GraphBuilder class."""
+
+import copy
+import io
+import logging
+import unittest
+from typing import Any
+
+import lsst.pipe.base.automatic_connection_constants as acc
+import lsst.utils.tests
+from lsst.daf.butler import DataCoordinate, DatasetRef, DatasetType, DimensionUniverse, StorageClassFactory
+from lsst.daf.butler.registry import MissingDatasetTypeError
+from lsst.pipe.base.pipeline_graph import (
+ ConnectionTypeConsistencyError,
+ DuplicateOutputError,
+ Edge,
+ EdgesChangedError,
+ IncompatibleDatasetTypeError,
+ NodeKey,
+ NodeType,
+ PipelineGraph,
+ PipelineGraphError,
+ UnresolvedGraphError,
+)
+from lsst.pipe.base.tests.mocks import (
+ DynamicConnectionConfig,
+ DynamicTestPipelineTask,
+ DynamicTestPipelineTaskConfig,
+ get_mock_name,
+)
+
+_LOG = logging.getLogger(__name__)
+
+
+class MockRegistry:
+ """A test-utility stand-in for lsst.daf.butler.Registry that just knows
+ how to get dataset types.
+ """
+
+ def __init__(self, dimensions: DimensionUniverse, dataset_types: dict[str, DatasetType]) -> None:
+ self.dimensions = dimensions
+ self._dataset_types = dataset_types
+
+ def getDatasetType(self, name: str) -> DatasetType:
+ try:
+ return self._dataset_types[name]
+ except KeyError:
+ raise MissingDatasetTypeError(name)
+
+
+class PipelineGraphTestCase(unittest.TestCase):
+ """Tests for the `PipelineGraph` class.
+
+ Tests for `PipelineGraph.resolve` are mostly in
+ `PipelineGraphResolveTestCase` later in this file.
+ """
+
+ def setUp(self) -> None:
+ # Simple test pipeline has two tasks, 'a' and 'b', with dataset types
+ # 'input', 'intermediate', and 'output'. There are no dimensions on
+ # any of those. We add tasks in reverse order to better test sorting.
+ # There is one labeled task subset, 'only_b', with just 'b' in it.
+ # We copy the configs so the originals (the instance attributes) can
+ # be modified and reused after the ones passed in to the graph are
+ # frozen.
+ self.description = "A pipeline for PipelineGraph unit tests."
+ self.graph = PipelineGraph()
+ self.graph.description = self.description
+ self.b_config = DynamicTestPipelineTaskConfig()
+ self.b_config.init_inputs["in_schema"] = DynamicConnectionConfig(dataset_type_name="schema")
+ self.b_config.inputs["input1"] = DynamicConnectionConfig(dataset_type_name="intermediate_1")
+ self.b_config.outputs["output1"] = DynamicConnectionConfig(dataset_type_name="output_1")
+ self.graph.add_task("b", DynamicTestPipelineTask, copy.deepcopy(self.b_config))
+ self.a_config = DynamicTestPipelineTaskConfig()
+ self.a_config.init_outputs["out_schema"] = DynamicConnectionConfig(dataset_type_name="schema")
+ self.a_config.inputs["input1"] = DynamicConnectionConfig(dataset_type_name="input_1")
+ self.a_config.outputs["output1"] = DynamicConnectionConfig(dataset_type_name="intermediate_1")
+ self.graph.add_task("a", DynamicTestPipelineTask, copy.deepcopy(self.a_config))
+ self.graph.add_task_subset("only_b", ["b"])
+ self.subset_description = "A subset with only task B in it."
+ self.graph.task_subsets["only_b"].description = self.subset_description
+ self.dimensions = DimensionUniverse()
+ self.maxDiff = None
+
+ def test_unresolved_accessors(self) -> None:
+ """Test attribute accessors, iteration, and simple methods on a graph
+ that has not had `PipelineGraph.resolve` called on it."""
+ self.check_base_accessors(self.graph)
+ self.assertEqual(
+ repr(self.graph.tasks["a"]), "a (lsst.pipe.base.tests.mocks.DynamicTestPipelineTask)"
+ )
+
+ def test_sorting(self) -> None:
+ """Test sort methods on PipelineGraph."""
+ self.assertFalse(self.graph.has_been_sorted)
+ self.assertFalse(self.graph.is_sorted)
+ self.graph.sort()
+ self.check_sorted(self.graph)
+
+ def test_unresolved_xgraph_export(self) -> None:
+ """Test exporting an unresolved PipelineGraph to networkx in various
+ ways."""
+ self.check_make_xgraph(self.graph, resolved=False)
+ self.check_make_bipartite_xgraph(self.graph, resolved=False)
+ self.check_make_task_xgraph(self.graph, resolved=False)
+ self.check_make_dataset_type_xgraph(self.graph, resolved=False)
+
+ def test_unresolved_stream_io(self) -> None:
+ """Test round-tripping an unresolved PipelineGraph through in-memory
+ serialization.
+ """
+ stream = io.BytesIO()
+ self.graph.write_stream(stream)
+ stream.seek(0)
+ roundtripped = PipelineGraph.read_stream(stream)
+ self.check_make_xgraph(roundtripped, resolved=False)
+
+ def test_unresolved_file_io(self) -> None:
+ """Test round-tripping an unresolved PipelineGraph through file
+ serialization.
+ """
+ with lsst.utils.tests.getTempFilePath(".json.gz") as filename:
+ self.graph.write_uri(filename)
+ roundtripped = PipelineGraph.read_uri(filename)
+ self.check_make_xgraph(roundtripped, resolved=False)
+
+ def test_unresolved_deferred_import_io(self) -> None:
+ """Test round-tripping an unresolved PipelineGraph through
+ serialization, without immediately importing tasks on read.
+ """
+ stream = io.BytesIO()
+ self.graph.write_stream(stream)
+ stream.seek(0)
+ roundtripped = PipelineGraph.read_stream(stream, import_and_configure=False)
+ self.check_make_xgraph(roundtripped, resolved=False, imported_and_configured=False)
+ # Check that we can still resolve the graph without importing tasks.
+ roundtripped.resolve(MockRegistry(self.dimensions, {}))
+ self.check_make_xgraph(roundtripped, resolved=True, imported_and_configured=False)
+ roundtripped.import_and_configure(assume_edges_unchanged=True)
+ self.check_make_xgraph(roundtripped, resolved=True, imported_and_configured=True)
+
+ def test_resolved_accessors(self) -> None:
+ """Test attribute accessors, iteration, and simple methods on a graph
+ that has had `PipelineGraph.resolve` called on it.
+
+ This includes the accessors available on unresolved graphs as well as
+ new ones, and we expect the resolved graph to be sorted as well.
+ """
+ self.graph.resolve(MockRegistry(self.dimensions, {}))
+ self.check_base_accessors(self.graph)
+ self.check_sorted(self.graph)
+ self.assertEqual(
+ repr(self.graph.tasks["a"]), "a (lsst.pipe.base.tests.mocks.DynamicTestPipelineTask, {})"
+ )
+ self.assertEqual(self.graph.tasks["a"].dimensions, self.dimensions.empty)
+ self.assertEqual(repr(self.graph.dataset_types["input_1"]), "input_1 (_mock_StructuredDataDict, {})")
+ self.assertEqual(self.graph.dataset_types["input_1"].key, NodeKey(NodeType.DATASET_TYPE, "input_1"))
+ self.assertEqual(self.graph.dataset_types["input_1"].dimensions, self.dimensions.empty)
+ self.assertEqual(self.graph.dataset_types["input_1"].storage_class_name, "_mock_StructuredDataDict")
+ self.assertEqual(self.graph.dataset_types["input_1"].storage_class.name, "_mock_StructuredDataDict")
+
+ def test_resolved_xgraph_export(self) -> None:
+ """Test exporting a resolved PipelineGraph to networkx in various
+ ways."""
+ self.graph.resolve(MockRegistry(self.dimensions, {}))
+ self.check_make_xgraph(self.graph, resolved=True)
+ self.check_make_bipartite_xgraph(self.graph, resolved=True)
+ self.check_make_task_xgraph(self.graph, resolved=True)
+ self.check_make_dataset_type_xgraph(self.graph, resolved=True)
+
+ def test_resolved_stream_io(self) -> None:
+ """Test round-tripping a resolved PipelineGraph through in-memory
+ serialization.
+ """
+ self.graph.resolve(MockRegistry(self.dimensions, {}))
+ stream = io.BytesIO()
+ self.graph.write_stream(stream)
+ stream.seek(0)
+ roundtripped = PipelineGraph.read_stream(stream)
+ self.check_make_xgraph(roundtripped, resolved=True)
+
+ def test_resolved_file_io(self) -> None:
+ """Test round-tripping a resolved PipelineGraph through file
+ serialization.
+ """
+ self.graph.resolve(MockRegistry(self.dimensions, {}))
+ with lsst.utils.tests.getTempFilePath(".json.gz") as filename:
+ self.graph.write_uri(filename)
+ roundtripped = PipelineGraph.read_uri(filename)
+ self.check_make_xgraph(roundtripped, resolved=True)
+
+ def test_resolved_deferred_import_io(self) -> None:
+ """Test round-tripping a resolved PipelineGraph through serialization,
+ without immediately importing tasks on read.
+ """
+ self.graph.resolve(MockRegistry(self.dimensions, {}))
+ stream = io.BytesIO()
+ self.graph.write_stream(stream)
+ stream.seek(0)
+ roundtripped = PipelineGraph.read_stream(stream, import_and_configure=False)
+ self.check_make_xgraph(roundtripped, resolved=True, imported_and_configured=False)
+ roundtripped.import_and_configure(check_edges_unchanged=True)
+ self.check_make_xgraph(roundtripped, resolved=True, imported_and_configured=True)
+
+ def test_unresolved_copies(self) -> None:
+ """Test making copies of an unresolved PipelineGraph."""
+ copy1 = self.graph.copy()
+ self.assertIsNot(copy1, self.graph)
+ self.check_make_xgraph(copy1, resolved=False)
+ copy2 = copy.copy(self.graph)
+ self.assertIsNot(copy2, self.graph)
+ self.check_make_xgraph(copy2, resolved=False)
+ copy3 = copy.deepcopy(self.graph)
+ self.assertIsNot(copy3, self.graph)
+ self.check_make_xgraph(copy3, resolved=False)
+
+ def test_resolved_copies(self) -> None:
+ """Test making copies of a resolved PipelineGraph."""
+ self.graph.resolve(MockRegistry(self.dimensions, {}))
+ copy1 = self.graph.copy()
+ self.assertIsNot(copy1, self.graph)
+ self.check_make_xgraph(copy1, resolved=True)
+ copy2 = copy.copy(self.graph)
+ self.assertIsNot(copy2, self.graph)
+ self.check_make_xgraph(copy2, resolved=True)
+ copy3 = copy.deepcopy(self.graph)
+ self.assertIsNot(copy3, self.graph)
+ self.check_make_xgraph(copy3, resolved=True)
+
+ def check_base_accessors(self, graph: PipelineGraph) -> None:
+ """Implementation for test methods that check attribute access,
+ iteration, and simple methods.
+
+ The given graph must be unchanged from the one defined in `setUp`,
+ other than sorting.
+ """
+ self.assertEqual(graph.description, self.description)
+ self.assertEqual(graph.tasks.keys(), {"a", "b"})
+ self.assertEqual(
+ graph.dataset_types.keys(),
+ {
+ "schema",
+ "input_1",
+ "intermediate_1",
+ "output_1",
+ "a_config",
+ "a_log",
+ "a_metadata",
+ "b_config",
+ "b_log",
+ "b_metadata",
+ },
+ )
+ self.assertEqual(graph.task_subsets.keys(), {"only_b"})
+ self.assertEqual(
+ {edge.nodes + (repr(edge),) for edge in graph.iter_edges(init=False)},
+ {
+ (
+ NodeKey(NodeType.DATASET_TYPE, "input_1"),
+ NodeKey(NodeType.TASK, "a"),
+ "input_1 -> a (input1)",
+ ),
+ (
+ NodeKey(NodeType.TASK, "a"),
+ NodeKey(NodeType.DATASET_TYPE, "intermediate_1"),
+ "a -> intermediate_1 (output1)",
+ ),
+ (
+ NodeKey(NodeType.DATASET_TYPE, "intermediate_1"),
+ NodeKey(NodeType.TASK, "b"),
+ "intermediate_1 -> b (input1)",
+ ),
+ (
+ NodeKey(NodeType.TASK, "b"),
+ NodeKey(NodeType.DATASET_TYPE, "output_1"),
+ "b -> output_1 (output1)",
+ ),
+ (NodeKey(NodeType.TASK, "a"), NodeKey(NodeType.DATASET_TYPE, "a_log"), "a -> a_log (_log)"),
+ (
+ NodeKey(NodeType.TASK, "a"),
+ NodeKey(NodeType.DATASET_TYPE, "a_metadata"),
+ "a -> a_metadata (_metadata)",
+ ),
+ (NodeKey(NodeType.TASK, "b"), NodeKey(NodeType.DATASET_TYPE, "b_log"), "b -> b_log (_log)"),
+ (
+ NodeKey(NodeType.TASK, "b"),
+ NodeKey(NodeType.DATASET_TYPE, "b_metadata"),
+ "b -> b_metadata (_metadata)",
+ ),
+ },
+ )
+ self.assertEqual(
+ {edge.nodes + (repr(edge),) for edge in graph.iter_edges(init=True)},
+ {
+ (
+ NodeKey(NodeType.TASK_INIT, "a"),
+ NodeKey(NodeType.DATASET_TYPE, "schema"),
+ "a -> schema (out_schema)",
+ ),
+ (
+ NodeKey(NodeType.DATASET_TYPE, "schema"),
+ NodeKey(NodeType.TASK_INIT, "b"),
+ "schema -> b (in_schema)",
+ ),
+ (
+ NodeKey(NodeType.TASK_INIT, "a"),
+ NodeKey(NodeType.DATASET_TYPE, "a_config"),
+ "a -> a_config (_config)",
+ ),
+ (
+ NodeKey(NodeType.TASK_INIT, "b"),
+ NodeKey(NodeType.DATASET_TYPE, "b_config"),
+ "b -> b_config (_config)",
+ ),
+ },
+ )
+ self.assertEqual(
+ {(node_type, name) for node_type, name, _ in graph.iter_nodes()},
+ {
+ NodeKey(NodeType.TASK, "a"),
+ NodeKey(NodeType.TASK, "b"),
+ NodeKey(NodeType.TASK_INIT, "a"),
+ NodeKey(NodeType.TASK_INIT, "b"),
+ NodeKey(NodeType.DATASET_TYPE, "schema"),
+ NodeKey(NodeType.DATASET_TYPE, "input_1"),
+ NodeKey(NodeType.DATASET_TYPE, "intermediate_1"),
+ NodeKey(NodeType.DATASET_TYPE, "output_1"),
+ NodeKey(NodeType.DATASET_TYPE, "a_config"),
+ NodeKey(NodeType.DATASET_TYPE, "a_log"),
+ NodeKey(NodeType.DATASET_TYPE, "a_metadata"),
+ NodeKey(NodeType.DATASET_TYPE, "b_config"),
+ NodeKey(NodeType.DATASET_TYPE, "b_log"),
+ NodeKey(NodeType.DATASET_TYPE, "b_metadata"),
+ },
+ )
+ self.assertEqual({name for name, _ in graph.iter_overall_inputs()}, {"input_1"})
+ self.assertEqual({edge.task_label for edge in graph.consumers_of("input_1")}, {"a"})
+ self.assertEqual({edge.task_label for edge in graph.consumers_of("intermediate_1")}, {"b"})
+ self.assertEqual({edge.task_label for edge in graph.consumers_of("output_1")}, set())
+ self.assertIsNone(graph.producer_of("input_1"))
+ self.assertEqual(graph.producer_of("intermediate_1").task_label, "a")
+ self.assertEqual(graph.producer_of("output_1").task_label, "b")
+ self.assertTrue(repr(self.graph).startswith(f"PipelineGraph({self.description!r}, tasks="))
+ self.assertEqual(
+ repr(graph.task_subsets["only_b"]), f"only_b: {self.subset_description!r}, tasks={{b}}"
+ )
+
+ def check_sorted(self, graph: PipelineGraph) -> None:
+ """Run a battery of tests on a PipelineGraph that must be
+ deterministically sorted.
+
+ The given graph must be unchanged from the one defined in `setUp`,
+ other than sorting.
+ """
+ self.assertTrue(graph.has_been_sorted)
+ self.assertTrue(graph.is_sorted)
+ self.assertEqual(
+ [(node_type, name) for node_type, name, _ in graph.iter_nodes()],
+ [
+ # We only advertise that the order is topological and
+ # deterministic, so this test is slightly over-specified; there
+ # are other orders that are consistent with our guarantees.
+ NodeKey(NodeType.DATASET_TYPE, "input_1"),
+ NodeKey(NodeType.TASK_INIT, "a"),
+ NodeKey(NodeType.DATASET_TYPE, "a_config"),
+ NodeKey(NodeType.DATASET_TYPE, "schema"),
+ NodeKey(NodeType.TASK_INIT, "b"),
+ NodeKey(NodeType.DATASET_TYPE, "b_config"),
+ NodeKey(NodeType.TASK, "a"),
+ NodeKey(NodeType.DATASET_TYPE, "a_log"),
+ NodeKey(NodeType.DATASET_TYPE, "a_metadata"),
+ NodeKey(NodeType.DATASET_TYPE, "intermediate_1"),
+ NodeKey(NodeType.TASK, "b"),
+ NodeKey(NodeType.DATASET_TYPE, "b_log"),
+ NodeKey(NodeType.DATASET_TYPE, "b_metadata"),
+ NodeKey(NodeType.DATASET_TYPE, "output_1"),
+ ],
+ )
+ # Most users should only care that the tasks and dataset types are
+ # topologically sorted.
+ self.assertEqual(list(graph.tasks), ["a", "b"])
+ self.assertEqual(
+ list(graph.dataset_types),
+ [
+ "input_1",
+ "a_config",
+ "schema",
+ "b_config",
+ "a_log",
+ "a_metadata",
+ "intermediate_1",
+ "b_log",
+ "b_metadata",
+ "output_1",
+ ],
+ )
+ # __str__ and __repr__ of course work on unsorted mapping views, too,
+ # but the order of elements is then nondeterministic and hard to test.
+ self.assertEqual(repr(self.graph.tasks), "TaskMappingView({a, b})")
+ self.assertEqual(
+ repr(self.graph.dataset_types),
+ (
+ "DatasetTypeMappingView({input_1, a_config, schema, b_config, a_log, a_metadata, "
+ "intermediate_1, b_log, b_metadata, output_1})"
+ ),
+ )
+
+ def check_make_xgraph(
+ self, graph: PipelineGraph, resolved: bool, imported_and_configured: bool = True
+ ) -> None:
+ """Check that the given graph exports as expected to networkx.
+
+ The given graph must be unchanged from the one defined in `setUp`,
+ other than being resolved (if ``resolved=True``) or round-tripped
+ through serialization without tasks being imported (if
+ ``imported_and_configured=False``).
+ """
+ xgraph = graph.make_xgraph()
+ expected_edges = (
+ {edge.key for edge in graph.iter_edges()}
+ | {edge.key for edge in graph.iter_edges(init=True)}
+ | {
+ (NodeKey(NodeType.TASK_INIT, "a"), NodeKey(NodeType.TASK, "a"), Edge.INIT_TO_TASK_NAME),
+ (NodeKey(NodeType.TASK_INIT, "b"), NodeKey(NodeType.TASK, "b"), Edge.INIT_TO_TASK_NAME),
+ }
+ )
+ test_edges = set(xgraph.edges)
+ self.assertEqual(test_edges, expected_edges)
+ expected_nodes = {
+ NodeKey(NodeType.TASK_INIT, "a"): self.get_expected_task_init_node(
+ "a", resolved, imported_and_configured=imported_and_configured
+ ),
+ NodeKey(NodeType.TASK, "a"): self.get_expected_task_node(
+ "a", resolved, imported_and_configured=imported_and_configured
+ ),
+ NodeKey(NodeType.TASK_INIT, "b"): self.get_expected_task_init_node(
+ "b", resolved, imported_and_configured=imported_and_configured
+ ),
+ NodeKey(NodeType.TASK, "b"): self.get_expected_task_node(
+ "b", resolved, imported_and_configured=imported_and_configured
+ ),
+ NodeKey(NodeType.DATASET_TYPE, "a_config"): self.get_expected_config_node("a", resolved),
+ NodeKey(NodeType.DATASET_TYPE, "b_config"): self.get_expected_config_node("b", resolved),
+ NodeKey(NodeType.DATASET_TYPE, "a_log"): self.get_expected_log_node("a", resolved),
+ NodeKey(NodeType.DATASET_TYPE, "b_log"): self.get_expected_log_node("b", resolved),
+ NodeKey(NodeType.DATASET_TYPE, "a_metadata"): self.get_expected_metadata_node("a", resolved),
+ NodeKey(NodeType.DATASET_TYPE, "b_metadata"): self.get_expected_metadata_node("b", resolved),
+ NodeKey(NodeType.DATASET_TYPE, "schema"): self.get_expected_connection_node(
+ "schema", resolved, is_initial_query_constraint=False
+ ),
+ NodeKey(NodeType.DATASET_TYPE, "input_1"): self.get_expected_connection_node(
+ "input_1", resolved, is_initial_query_constraint=True
+ ),
+ NodeKey(NodeType.DATASET_TYPE, "intermediate_1"): self.get_expected_connection_node(
+ "intermediate_1", resolved, is_initial_query_constraint=False
+ ),
+ NodeKey(NodeType.DATASET_TYPE, "output_1"): self.get_expected_connection_node(
+ "output_1", resolved, is_initial_query_constraint=False
+ ),
+ }
+ test_nodes = dict(xgraph.nodes.items())
+ self.assertEqual(set(test_nodes.keys()), set(expected_nodes.keys()))
+ for key, expected_node in expected_nodes.items():
+ test_node = test_nodes[key]
+ self.assertEqual(expected_node, test_node, key)
+
+ def check_make_bipartite_xgraph(self, graph: PipelineGraph, resolved: bool) -> None:
+ """Check that the given graph's init-only or runtime subset exports as
+ expected to networkx.
+
+ The given graph must be unchanged from the one defined in `setUp`,
+ other than being resolved (if ``resolved=True``).
+ """
+ run_xgraph = graph.make_bipartite_xgraph()
+ self.assertEqual(set(run_xgraph.edges), {edge.key for edge in graph.iter_edges()})
+ self.assertEqual(
+ dict(run_xgraph.nodes.items()),
+ {
+ NodeKey(NodeType.TASK, "a"): self.get_expected_task_node("a", resolved),
+ NodeKey(NodeType.TASK, "b"): self.get_expected_task_node("b", resolved),
+ NodeKey(NodeType.DATASET_TYPE, "a_log"): self.get_expected_log_node("a", resolved),
+ NodeKey(NodeType.DATASET_TYPE, "b_log"): self.get_expected_log_node("b", resolved),
+ NodeKey(NodeType.DATASET_TYPE, "a_metadata"): self.get_expected_metadata_node("a", resolved),
+ NodeKey(NodeType.DATASET_TYPE, "b_metadata"): self.get_expected_metadata_node("b", resolved),
+ NodeKey(NodeType.DATASET_TYPE, "input_1"): self.get_expected_connection_node(
+ "input_1", resolved, is_initial_query_constraint=True
+ ),
+ NodeKey(NodeType.DATASET_TYPE, "intermediate_1"): self.get_expected_connection_node(
+ "intermediate_1", resolved, is_initial_query_constraint=False
+ ),
+ NodeKey(NodeType.DATASET_TYPE, "output_1"): self.get_expected_connection_node(
+ "output_1", resolved, is_initial_query_constraint=False
+ ),
+ },
+ )
+ init_xgraph = graph.make_bipartite_xgraph(
+ init=True,
+ )
+ self.assertEqual(set(init_xgraph.edges), {edge.key for edge in graph.iter_edges(init=True)})
+ self.assertEqual(
+ dict(init_xgraph.nodes.items()),
+ {
+ NodeKey(NodeType.TASK_INIT, "a"): self.get_expected_task_init_node("a", resolved),
+ NodeKey(NodeType.TASK_INIT, "b"): self.get_expected_task_init_node("b", resolved),
+ NodeKey(NodeType.DATASET_TYPE, "schema"): self.get_expected_connection_node(
+ "schema", resolved, is_initial_query_constraint=False
+ ),
+ NodeKey(NodeType.DATASET_TYPE, "a_config"): self.get_expected_config_node("a", resolved),
+ NodeKey(NodeType.DATASET_TYPE, "b_config"): self.get_expected_config_node("b", resolved),
+ },
+ )
+
+ def check_make_task_xgraph(self, graph: PipelineGraph, resolved: bool) -> None:
+ """Check that the given graph's task-only projection exports as
+ expected to networkx.
+
+ The given graph must be unchanged from the one defined in `setUp`,
+ other than being resolved (if ``resolved=True``).
+ """
+ run_xgraph = graph.make_task_xgraph()
+ self.assertEqual(set(run_xgraph.edges), {(NodeKey(NodeType.TASK, "a"), NodeKey(NodeType.TASK, "b"))})
+ self.assertEqual(
+ dict(run_xgraph.nodes.items()),
+ {
+ NodeKey(NodeType.TASK, "a"): self.get_expected_task_node("a", resolved),
+ NodeKey(NodeType.TASK, "b"): self.get_expected_task_node("b", resolved),
+ },
+ )
+ init_xgraph = graph.make_task_xgraph(
+ init=True,
+ )
+ self.assertEqual(
+ set(init_xgraph.edges),
+ {(NodeKey(NodeType.TASK_INIT, "a"), NodeKey(NodeType.TASK_INIT, "b"))},
+ )
+ self.assertEqual(
+ dict(init_xgraph.nodes.items()),
+ {
+ NodeKey(NodeType.TASK_INIT, "a"): self.get_expected_task_init_node("a", resolved),
+ NodeKey(NodeType.TASK_INIT, "b"): self.get_expected_task_init_node("b", resolved),
+ },
+ )
+
+ def check_make_dataset_type_xgraph(self, graph: PipelineGraph, resolved: bool) -> None:
+ """Check that the given graph's dataset-type-only projection exports as
+ expected to networkx.
+
+ The given graph must be unchanged from the one defined in `setUp`,
+ other than being resolved (if ``resolved=True``).
+ """
+ run_xgraph = graph.make_dataset_type_xgraph()
+ self.assertEqual(
+ set(run_xgraph.edges),
+ {
+ (NodeKey(NodeType.DATASET_TYPE, "input_1"), NodeKey(NodeType.DATASET_TYPE, "intermediate_1")),
+ (NodeKey(NodeType.DATASET_TYPE, "input_1"), NodeKey(NodeType.DATASET_TYPE, "a_log")),
+ (NodeKey(NodeType.DATASET_TYPE, "input_1"), NodeKey(NodeType.DATASET_TYPE, "a_metadata")),
+ (
+ NodeKey(NodeType.DATASET_TYPE, "intermediate_1"),
+ NodeKey(NodeType.DATASET_TYPE, "output_1"),
+ ),
+ (NodeKey(NodeType.DATASET_TYPE, "intermediate_1"), NodeKey(NodeType.DATASET_TYPE, "b_log")),
+ (
+ NodeKey(NodeType.DATASET_TYPE, "intermediate_1"),
+ NodeKey(NodeType.DATASET_TYPE, "b_metadata"),
+ ),
+ },
+ )
+ self.assertEqual(
+ dict(run_xgraph.nodes.items()),
+ {
+ NodeKey(NodeType.DATASET_TYPE, "a_log"): self.get_expected_log_node("a", resolved),
+ NodeKey(NodeType.DATASET_TYPE, "b_log"): self.get_expected_log_node("b", resolved),
+ NodeKey(NodeType.DATASET_TYPE, "a_metadata"): self.get_expected_metadata_node("a", resolved),
+ NodeKey(NodeType.DATASET_TYPE, "b_metadata"): self.get_expected_metadata_node("b", resolved),
+ NodeKey(NodeType.DATASET_TYPE, "input_1"): self.get_expected_connection_node(
+ "input_1", resolved, is_initial_query_constraint=True
+ ),
+ NodeKey(NodeType.DATASET_TYPE, "intermediate_1"): self.get_expected_connection_node(
+ "intermediate_1", resolved, is_initial_query_constraint=False
+ ),
+ NodeKey(NodeType.DATASET_TYPE, "output_1"): self.get_expected_connection_node(
+ "output_1", resolved, is_initial_query_constraint=False
+ ),
+ },
+ )
+ init_xgraph = graph.make_dataset_type_xgraph(init=True)
+ self.assertEqual(
+ set(init_xgraph.edges),
+ {(NodeKey(NodeType.DATASET_TYPE, "schema"), NodeKey(NodeType.DATASET_TYPE, "b_config"))},
+ )
+ self.assertEqual(
+ dict(init_xgraph.nodes.items()),
+ {
+ NodeKey(NodeType.DATASET_TYPE, "schema"): self.get_expected_connection_node(
+ "schema", resolved, is_initial_query_constraint=False
+ ),
+ NodeKey(NodeType.DATASET_TYPE, "a_config"): self.get_expected_config_node("a", resolved),
+ NodeKey(NodeType.DATASET_TYPE, "b_config"): self.get_expected_config_node("b", resolved),
+ },
+ )
+
+ def get_expected_task_node(
+ self, label: str, resolved: bool, imported_and_configured: bool = True
+ ) -> dict[str, Any]:
+ """Construct a networkx-export task node for comparison."""
+ result = self.get_expected_task_init_node(
+ label, resolved, imported_and_configured=imported_and_configured
+ )
+ if resolved:
+ result["dimensions"] = self.dimensions.empty
+ result["raw_dimensions"] = frozenset()
+ return result
+
+ def get_expected_task_init_node(
+ self, label: str, resolved: bool, imported_and_configured: bool = True
+ ) -> dict[str, Any]:
+ """Construct a networkx-export task init for comparison."""
+ result = {
+ "task_class_name": "lsst.pipe.base.tests.mocks.DynamicTestPipelineTask",
+ "bipartite": 1,
+ }
+ if imported_and_configured:
+ result["task_class"] = DynamicTestPipelineTask
+ result["config"] = getattr(self, f"{label}_config")
+ return result
+
+ def get_expected_config_node(self, label: str, resolved: bool) -> dict[str, Any]:
+ """Construct a networkx-export init-output config dataset type node for
+ comparison.
+ """
+ if not resolved:
+ return {"bipartite": 0}
+ else:
+ return {
+ "dataset_type": DatasetType(
+ acc.CONFIG_INIT_OUTPUT_TEMPLATE.format(label=label),
+ self.dimensions.empty,
+ acc.CONFIG_INIT_OUTPUT_STORAGE_CLASS,
+ ),
+ "is_initial_query_constraint": False,
+ "is_prerequisite": False,
+ "dimensions": self.dimensions.empty,
+ "storage_class_name": acc.CONFIG_INIT_OUTPUT_STORAGE_CLASS,
+ "bipartite": 0,
+ }
+
+ def get_expected_log_node(self, label: str, resolved: bool) -> dict[str, Any]:
+ """Construct a networkx-export output log dataset type node for
+ comparison.
+ """
+ if not resolved:
+ return {"bipartite": 0}
+ else:
+ return {
+ "dataset_type": DatasetType(
+ acc.LOG_OUTPUT_TEMPLATE.format(label=label),
+ self.dimensions.empty,
+ acc.LOG_OUTPUT_STORAGE_CLASS,
+ ),
+ "is_initial_query_constraint": False,
+ "is_prerequisite": False,
+ "dimensions": self.dimensions.empty,
+ "storage_class_name": acc.LOG_OUTPUT_STORAGE_CLASS,
+ "bipartite": 0,
+ }
+
+ def get_expected_metadata_node(self, label: str, resolved: bool) -> dict[str, Any]:
+ """Construct a networkx-export output metadata dataset type node for
+ comparison.
+ """
+ if not resolved:
+ return {"bipartite": 0}
+ else:
+ return {
+ "dataset_type": DatasetType(
+ acc.METADATA_OUTPUT_TEMPLATE.format(label=label),
+ self.dimensions.empty,
+ acc.METADATA_OUTPUT_STORAGE_CLASS,
+ ),
+ "is_initial_query_constraint": False,
+ "is_prerequisite": False,
+ "dimensions": self.dimensions.empty,
+ "storage_class_name": acc.METADATA_OUTPUT_STORAGE_CLASS,
+ "bipartite": 0,
+ }
+
+ def get_expected_connection_node(
+ self, name: str, resolved: bool, *, is_initial_query_constraint: bool
+ ) -> dict[str, Any]:
+ """Construct a networkx-export dataset type node for comparison."""
+ if not resolved:
+ return {"bipartite": 0}
+ else:
+ return {
+ "dataset_type": DatasetType(
+ name,
+ self.dimensions.empty,
+ get_mock_name("StructuredDataDict"),
+ ),
+ "is_initial_query_constraint": is_initial_query_constraint,
+ "is_prerequisite": False,
+ "dimensions": self.dimensions.empty,
+ "storage_class_name": get_mock_name("StructuredDataDict"),
+ "bipartite": 0,
+ }
+
+ def test_construct_with_data_coordinate(self) -> None:
+ """Test constructing a graph with a DataCoordinate.
+
+ Since this creates a graph with DimensionUniverse, all tasks added to
+ it should have resolved dimensions, but not (yet) resolved dataset
+ types. We use that to test a few other operations in that state.
+ """
+ data_id = DataCoordinate.standardize(instrument="I", universe=self.dimensions)
+ graph = PipelineGraph(data_id=data_id)
+ self.assertEqual(graph.universe, self.dimensions)
+ self.assertEqual(graph.data_id, data_id)
+ graph.add_task("b1", DynamicTestPipelineTask, self.b_config)
+ self.assertEqual(graph.tasks["b1"].dimensions, self.dimensions.empty)
+ # Still can't group by dimensions, because the dataset types aren't
+ # resolved.
+ with self.assertRaises(UnresolvedGraphError):
+ graph.group_by_dimensions()
+ # Transferring a node from this graph to ``self.graph`` should
+ # unresolve the dimensions.
+ self.graph.add_task_nodes([graph.tasks["b1"]])
+ self.assertIsNot(self.graph.tasks["b1"], graph.tasks["b1"])
+ self.assertFalse(self.graph.tasks["b1"].has_resolved_dimensions)
+ # Do the opposite transfer, which should resolve dimensions.
+ graph.add_task_nodes([self.graph.tasks["a"]])
+ self.assertIsNot(self.graph.tasks["a"], graph.tasks["a"])
+ self.assertTrue(graph.tasks["a"].has_resolved_dimensions)
+
+ def test_group_by_dimensions(self) -> None:
+ """Test PipelineGraph.group_by_dimensions."""
+ with self.assertRaises(UnresolvedGraphError):
+ self.graph.group_by_dimensions()
+ self.a_config.dimensions = ["visit"]
+ self.a_config.outputs["output1"].dimensions = ["visit"]
+ self.a_config.prerequisite_inputs["prereq1"] = DynamicConnectionConfig(
+ dataset_type_name="prereq_1",
+ multiple=True,
+ dimensions=["htm7"],
+ is_calibration=True,
+ )
+ self.b_config.dimensions = ["htm7"]
+ self.b_config.inputs["input1"].dimensions = ["visit"]
+ self.b_config.inputs["input1"].multiple = True
+ self.b_config.outputs["output1"].dimensions = ["htm7"]
+ self.graph.reconfigure_tasks(a=self.a_config, b=self.b_config)
+ self.graph.resolve(MockRegistry(self.dimensions, {}))
+ visit_dims = self.dimensions.extract(["visit"])
+ htm7_dims = self.dimensions.extract(["htm7"])
+ expected = {
+ self.dimensions.empty: (
+ {},
+ {
+ "schema": self.graph.dataset_types["schema"],
+ "input_1": self.graph.dataset_types["input_1"],
+ "a_config": self.graph.dataset_types["a_config"],
+ "b_config": self.graph.dataset_types["b_config"],
+ },
+ ),
+ visit_dims: (
+ {"a": self.graph.tasks["a"]},
+ {
+ "a_log": self.graph.dataset_types["a_log"],
+ "a_metadata": self.graph.dataset_types["a_metadata"],
+ "intermediate_1": self.graph.dataset_types["intermediate_1"],
+ },
+ ),
+ htm7_dims: (
+ {"b": self.graph.tasks["b"]},
+ {
+ "b_log": self.graph.dataset_types["b_log"],
+ "b_metadata": self.graph.dataset_types["b_metadata"],
+ "output_1": self.graph.dataset_types["output_1"],
+ },
+ ),
+ }
+ self.assertEqual(self.graph.group_by_dimensions(), expected)
+ expected[htm7_dims][1]["prereq_1"] = self.graph.dataset_types["prereq_1"]
+ self.assertEqual(self.graph.group_by_dimensions(prerequisites=True), expected)
+
+ def test_add_and_remove(self) -> None:
+ """Tests for adding and removing tasks and task subsets from a
+ PipelineGraph.
+ """
+ # Can't remove a task while it's still in a subset.
+ with self.assertRaises(PipelineGraphError):
+ self.graph.remove_tasks(["b"], drop_from_subsets=False)
+ # ...unless you remove the subset.
+ self.graph.remove_task_subset("only_b")
+ self.assertFalse(self.graph.task_subsets)
+ ((b, referencing_subsets),) = self.graph.remove_tasks(["b"], drop_from_subsets=False)
+ self.assertFalse(referencing_subsets)
+ self.assertEqual(self.graph.tasks.keys(), {"a"})
+ # Add that task back in.
+ self.graph.add_task_nodes([b])
+ self.assertEqual(self.graph.tasks.keys(), {"a", "b"})
+ # Add the subset back in.
+ self.graph.add_task_subset("only_b", {"b"})
+ self.assertEqual(self.graph.task_subsets.keys(), {"only_b"})
+ # Resolve the graph's dataset types and task dimensions.
+ self.graph.resolve(MockRegistry(self.dimensions, {}))
+ self.assertTrue(self.graph.dataset_types.is_resolved("input_1"))
+ self.assertTrue(self.graph.dataset_types.is_resolved("output_1"))
+ self.assertTrue(self.graph.dataset_types.is_resolved("schema"))
+ self.assertTrue(self.graph.dataset_types.is_resolved("intermediate_1"))
+ # Remove the task while removing it from the subset automatically. This
+ # should also unresolve (only) the referenced dataset types and drop
+ # any datasets no longer attached to any task.
+ self.assertEqual(self.graph.tasks.keys(), {"a", "b"})
+ ((b, referencing_subsets),) = self.graph.remove_tasks(["b"], drop_from_subsets=True)
+ self.assertEqual(referencing_subsets, {"only_b"})
+ self.assertEqual(self.graph.tasks.keys(), {"a"})
+ self.assertTrue(self.graph.dataset_types.is_resolved("input_1"))
+ self.assertNotIn("output1", self.graph.dataset_types)
+ self.assertFalse(self.graph.dataset_types.is_resolved("schema"))
+ self.assertFalse(self.graph.dataset_types.is_resolved("intermediate_1"))
+
+ def test_reconfigure(self) -> None:
+ """Tests for PipelineGraph.reconfigure."""
+ self.graph.resolve(MockRegistry(self.dimensions, {}))
+ self.b_config.outputs["output1"].storage_class = "TaskMetadata"
+ with self.assertRaises(ValueError):
+ # Can't check and assume together.
+ self.graph.reconfigure_tasks(
+ b=self.b_config, assume_edges_unchanged=True, check_edges_unchanged=True
+ )
+ # Check that graph is unchanged after error.
+ self.check_base_accessors(self.graph)
+ with self.assertRaises(EdgesChangedError):
+ self.graph.reconfigure_tasks(b=self.b_config, check_edges_unchanged=True)
+ self.check_base_accessors(self.graph)
+ # Make a change that does affect edges; this will unresolve most
+ # dataset types.
+ self.graph.reconfigure_tasks(b=self.b_config)
+ self.assertTrue(self.graph.dataset_types.is_resolved("input_1"))
+ self.assertFalse(self.graph.dataset_types.is_resolved("output_1"))
+ self.assertFalse(self.graph.dataset_types.is_resolved("schema"))
+ self.assertFalse(self.graph.dataset_types.is_resolved("intermediate_1"))
+ # Resolving again will pick up the new storage class
+ self.graph.resolve(MockRegistry(self.dimensions, {}))
+ self.assertEqual(
+ self.graph.dataset_types["output_1"].storage_class_name, get_mock_name("TaskMetadata")
+ )
+
+
+def _have_example_storage_classes() -> bool:
+ """Check whether some storage classes work as expected.
+
+ Given that these have registered converters, it shouldn't actually be
+ necessary to import be able to those types in order to determine that
+ they're convertible, but the storage class machinery is implemented such
+ that types that can't be imported can't be converted, and while that's
+ inconvenient here it's totally fine in non-testing scenarios where you only
+ care about a storage class if you can actually use it.
+ """
+ getter = StorageClassFactory().getStorageClass
+ return (
+ getter("ArrowTable").can_convert(getter("ArrowAstropy"))
+ and getter("ArrowAstropy").can_convert(getter("ArrowTable"))
+ and getter("ArrowTable").can_convert(getter("DataFrame"))
+ and getter("DataFrame").can_convert(getter("ArrowTable"))
+ )
+
+
+class PipelineGraphResolveTestCase(unittest.TestCase):
+ """More extensive tests for PipelineGraph.resolve and its primate helper
+ methods.
+
+ These are in a separate TestCase because they utilize a different `setUp`
+ from the rest of the `PipelineGraph` tests.
+ """
+
+ def setUp(self) -> None:
+ self.a_config = DynamicTestPipelineTaskConfig()
+ self.b_config = DynamicTestPipelineTaskConfig()
+ self.dimensions = DimensionUniverse()
+ self.maxDiff = None
+
+ def make_graph(self) -> PipelineGraph:
+ graph = PipelineGraph()
+ graph.add_task("a", DynamicTestPipelineTask, self.a_config)
+ graph.add_task("b", DynamicTestPipelineTask, self.b_config)
+ return graph
+
+ def test_prerequisite_inconsistency(self) -> None:
+ """Test that we raise an exception when one edge defines a dataset type
+ as a prerequisite and another does not.
+
+ This test will hopefully someday go away (along with
+ `DatasetTypeNode.is_prerequisite`) when the QuantumGraph generation
+ algorithm becomes more flexible.
+ """
+ self.a_config.prerequisite_inputs["p"] = DynamicConnectionConfig(dataset_type_name="d")
+ self.b_config.inputs["i"] = DynamicConnectionConfig(dataset_type_name="d")
+ graph = self.make_graph()
+ with self.assertRaises(ConnectionTypeConsistencyError):
+ graph.resolve(MockRegistry(self.dimensions, {}))
+
+ def test_prerequisite_inconsistency_reversed(self) -> None:
+ """Same as `test_prerequisite_inconsistency`, with the order the edges
+ are added to the graph reversed.
+ """
+ self.a_config.inputs["i"] = DynamicConnectionConfig(dataset_type_name="d")
+ self.b_config.prerequisite_inputs["p"] = DynamicConnectionConfig(dataset_type_name="d")
+ graph = self.make_graph()
+ with self.assertRaises(ConnectionTypeConsistencyError):
+ graph.resolve(MockRegistry(self.dimensions, {}))
+
+ def test_prerequisite_output(self) -> None:
+ """Test that we raise an exception when one edge defines a dataset type
+ as a prerequisite but another defines it as an output.
+ """
+ self.a_config.prerequisite_inputs["p"] = DynamicConnectionConfig(dataset_type_name="d")
+ self.b_config.outputs["o"] = DynamicConnectionConfig(dataset_type_name="d")
+ graph = self.make_graph()
+ with self.assertRaises(ConnectionTypeConsistencyError):
+ graph.resolve(MockRegistry(self.dimensions, {}))
+
+ def test_skypix_missing(self) -> None:
+ """Test that we raise an exception when one edge uses the "skypix"
+ dimension as a placeholder but the dataset type is not registered.
+ """
+ self.a_config.prerequisite_inputs["i"] = DynamicConnectionConfig(
+ dataset_type_name="d", dimensions={"skypix"}
+ )
+ graph = self.make_graph()
+ with self.assertRaises(MissingDatasetTypeError):
+ graph.resolve(MockRegistry(self.dimensions, {}))
+
+ def test_skypix_inconsistent(self) -> None:
+ """Test that we raise an exception when one edge uses the "skypix"
+ dimension as a placeholder but the rest of the dimensions are
+ inconsistent with the registered dataset type.
+ """
+ self.a_config.prerequisite_inputs["i"] = DynamicConnectionConfig(
+ dataset_type_name="d", dimensions={"skypix", "visit"}
+ )
+ graph = self.make_graph()
+ with self.assertRaises(IncompatibleDatasetTypeError):
+ graph.resolve(
+ MockRegistry(
+ self.dimensions,
+ {
+ "d": DatasetType(
+ "d",
+ dimensions=self.dimensions.extract(["htm7"]),
+ storageClass="StructuredDataDict",
+ )
+ },
+ )
+ )
+ with self.assertRaises(IncompatibleDatasetTypeError):
+ graph.resolve(
+ MockRegistry(
+ self.dimensions,
+ {
+ "d": DatasetType(
+ "d",
+ dimensions=self.dimensions.extract(["htm7", "visit", "skymap"]),
+ storageClass="StructuredDataDict",
+ )
+ },
+ )
+ )
+
+ def test_duplicate_outputs(self) -> None:
+ """Test that we raise an exception when a dataset type node would have
+ two write edges.
+ """
+ self.a_config.outputs["o"] = DynamicConnectionConfig(dataset_type_name="d")
+ self.b_config.outputs["o"] = DynamicConnectionConfig(dataset_type_name="d")
+ graph = self.make_graph()
+ with self.assertRaises(DuplicateOutputError):
+ graph.resolve(MockRegistry(self.dimensions, {}))
+
+ def test_component_of_unregistered_parent(self) -> None:
+ """Test that we raise an exception when a component dataset type's
+ parent is not registered.
+ """
+ self.a_config.inputs["i"] = DynamicConnectionConfig(dataset_type_name="d.c")
+ graph = self.make_graph()
+ with self.assertRaises(MissingDatasetTypeError):
+ graph.resolve(MockRegistry(self.dimensions, {}))
+
+ def test_undefined_component(self) -> None:
+ """Test that we raise an exception when a component dataset type's
+ parent is registered, but its storage class does not have that
+ component.
+ """
+ self.a_config.inputs["i"] = DynamicConnectionConfig(dataset_type_name="d.c")
+ graph = self.make_graph()
+ with self.assertRaises(IncompatibleDatasetTypeError):
+ graph.resolve(
+ MockRegistry(
+ self.dimensions,
+ {"d": DatasetType("d", self.dimensions.empty, get_mock_name("StructuredDataDict"))},
+ )
+ )
+
+ @unittest.skipUnless(
+ _have_example_storage_classes(), "Arrow/Astropy/Pandas storage classes are not available."
+ )
+ def test_bad_component_storage_class(self) -> None:
+ """Test that we raise an exception when a component dataset type's
+ parent is registered, but does not have that component.
+ """
+ self.a_config.inputs["i"] = DynamicConnectionConfig(
+ dataset_type_name="d.schema", storage_class="StructuredDataDict"
+ )
+ graph = self.make_graph()
+ with self.assertRaises(IncompatibleDatasetTypeError):
+ graph.resolve(
+ MockRegistry(
+ self.dimensions,
+ {"d": DatasetType("d", self.dimensions.empty, get_mock_name("ArrowTable"))},
+ )
+ )
+
+ def test_input_storage_class_incompatible_with_registry(self) -> None:
+ """Test that we raise an exception when an input connection's storage
+ class is incompatible with the registry definition.
+ """
+ self.a_config.inputs["i"] = DynamicConnectionConfig(
+ dataset_type_name="d", storage_class="StructuredDataList"
+ )
+ graph = self.make_graph()
+ with self.assertRaises(IncompatibleDatasetTypeError):
+ graph.resolve(
+ MockRegistry(
+ self.dimensions,
+ {"d": DatasetType("d", self.dimensions.empty, get_mock_name("StructuredDataDict"))},
+ )
+ )
+
+ def test_output_storage_class_incompatible_with_registry(self) -> None:
+ """Test that we raise an exception when an output connection's storage
+ class is incompatible with the registry definition.
+ """
+ self.a_config.outputs["o"] = DynamicConnectionConfig(
+ dataset_type_name="d", storage_class="StructuredDataList"
+ )
+ graph = self.make_graph()
+ with self.assertRaises(IncompatibleDatasetTypeError):
+ graph.resolve(
+ MockRegistry(
+ self.dimensions,
+ {"d": DatasetType("d", self.dimensions.empty, get_mock_name("StructuredDataDict"))},
+ )
+ )
+
+ def test_input_storage_class_incompatible_with_output(self) -> None:
+ """Test that we raise an exception when an input connection's storage
+ class is incompatible with the storage class of the output connection.
+ """
+ self.a_config.outputs["o"] = DynamicConnectionConfig(
+ dataset_type_name="d", storage_class="StructuredDataDict"
+ )
+ self.b_config.inputs["i"] = DynamicConnectionConfig(
+ dataset_type_name="d", storage_class="StructuredDataList"
+ )
+ graph = self.make_graph()
+ with self.assertRaises(IncompatibleDatasetTypeError):
+ graph.resolve(MockRegistry(self.dimensions, {}))
+
+ def test_ambiguous_storage_class(self) -> None:
+ """Test that we raise an exception when two input connections define
+ the same dataset with different storage classes (even compatible ones)
+ and there is no output connection or registry definition to take
+ precedence.
+ """
+ self.a_config.inputs["i"] = DynamicConnectionConfig(
+ dataset_type_name="d", storage_class="StructuredDataDict"
+ )
+ self.b_config.inputs["i"] = DynamicConnectionConfig(
+ dataset_type_name="d", storage_class="StructuredDataList"
+ )
+ graph = self.make_graph()
+ with self.assertRaises(MissingDatasetTypeError):
+ graph.resolve(MockRegistry(self.dimensions, {}))
+
+ @unittest.skipUnless(
+ _have_example_storage_classes(), "Arrow/Astropy/Pandas storage classes are not available."
+ )
+ def test_inputs_compatible_with_registry(self) -> None:
+ """Test successful resolution of a dataset type where input edges have
+ different but compatible storage classes and the dataset type is
+ already registered.
+ """
+ self.a_config.inputs["i"] = DynamicConnectionConfig(dataset_type_name="d", storage_class="ArrowTable")
+ self.b_config.inputs["i"] = DynamicConnectionConfig(
+ dataset_type_name="d", storage_class="ArrowAstropy"
+ )
+ graph = self.make_graph()
+ dataset_type = DatasetType("d", self.dimensions.empty, get_mock_name("DataFrame"))
+ graph.resolve(MockRegistry(self.dimensions, {"d": dataset_type}))
+ self.assertEqual(graph.dataset_types["d"].dataset_type, dataset_type)
+ a_i = graph.tasks["a"].inputs["i"]
+ b_i = graph.tasks["b"].inputs["i"]
+ self.assertEqual(
+ a_i.adapt_dataset_type(dataset_type),
+ dataset_type.overrideStorageClass(get_mock_name("ArrowTable")),
+ )
+ self.assertEqual(
+ b_i.adapt_dataset_type(dataset_type),
+ dataset_type.overrideStorageClass(get_mock_name("ArrowAstropy")),
+ )
+ data_id = DataCoordinate.makeEmpty(self.dimensions)
+ ref = DatasetRef(dataset_type, data_id, run="r")
+ a_ref = a_i.adapt_dataset_ref(ref)
+ b_ref = b_i.adapt_dataset_ref(ref)
+ self.assertEqual(a_ref, ref.overrideStorageClass(get_mock_name("ArrowTable")))
+ self.assertEqual(b_ref, ref.overrideStorageClass(get_mock_name("ArrowAstropy")))
+ self.assertEqual(graph.dataset_types["d"].generalize_ref(a_ref), ref)
+ self.assertEqual(graph.dataset_types["d"].generalize_ref(b_ref), ref)
+
+ @unittest.skipUnless(
+ _have_example_storage_classes(), "Arrow/Astropy/Pandas storage classes are not available."
+ )
+ def test_output_compatible_with_registry(self) -> None:
+ """Test successful resolution of a dataset type where an output edge
+ has a different but compatible storage class from the dataset type
+ already registered.
+ """
+ self.a_config.outputs["o"] = DynamicConnectionConfig(
+ dataset_type_name="d", storage_class="ArrowTable"
+ )
+ graph = self.make_graph()
+ dataset_type = DatasetType("d", self.dimensions.empty, get_mock_name("DataFrame"))
+ graph.resolve(MockRegistry(self.dimensions, {"d": dataset_type}))
+ self.assertEqual(graph.dataset_types["d"].dataset_type, dataset_type)
+ a_o = graph.tasks["a"].outputs["o"]
+ self.assertEqual(
+ a_o.adapt_dataset_type(dataset_type),
+ dataset_type.overrideStorageClass(get_mock_name("ArrowTable")),
+ )
+ data_id = DataCoordinate.makeEmpty(self.dimensions)
+ ref = DatasetRef(dataset_type, data_id, run="r")
+ a_ref = a_o.adapt_dataset_ref(ref)
+ self.assertEqual(a_ref, ref.overrideStorageClass(get_mock_name("ArrowTable")))
+ self.assertEqual(graph.dataset_types["d"].generalize_ref(a_ref), ref)
+
+ @unittest.skipUnless(
+ _have_example_storage_classes(), "Arrow/Astropy/Pandas storage classes are not available."
+ )
+ def test_inputs_compatible_with_output(self) -> None:
+ """Test successful resolution of a dataset type where an input edge has
+ a different but compatible storage class from the output edge, and
+ the dataset type is not registered.
+ """
+ self.a_config.outputs["o"] = DynamicConnectionConfig(
+ dataset_type_name="d", storage_class="ArrowTable"
+ )
+ self.b_config.inputs["i"] = DynamicConnectionConfig(
+ dataset_type_name="d", storage_class="ArrowAstropy"
+ )
+ graph = self.make_graph()
+ a_o = graph.tasks["a"].outputs["o"]
+ b_i = graph.tasks["b"].inputs["i"]
+ graph.resolve(MockRegistry(self.dimensions, {}))
+ self.assertEqual(graph.dataset_types["d"].storage_class_name, get_mock_name("ArrowTable"))
+ self.assertEqual(
+ a_o.adapt_dataset_type(graph.dataset_types["d"].dataset_type),
+ graph.dataset_types["d"].dataset_type,
+ )
+ self.assertEqual(
+ b_i.adapt_dataset_type(graph.dataset_types["d"].dataset_type),
+ graph.dataset_types["d"].dataset_type.overrideStorageClass(get_mock_name("ArrowAstropy")),
+ )
+ data_id = DataCoordinate.makeEmpty(self.dimensions)
+ ref = DatasetRef(graph.dataset_types["d"].dataset_type, data_id, run="r")
+ a_ref = a_o.adapt_dataset_ref(ref)
+ b_ref = b_i.adapt_dataset_ref(ref)
+ self.assertEqual(a_ref, ref)
+ self.assertEqual(b_ref, ref.overrideStorageClass(get_mock_name("ArrowAstropy")))
+ self.assertEqual(graph.dataset_types["d"].generalize_ref(a_ref), ref)
+ self.assertEqual(graph.dataset_types["d"].generalize_ref(b_ref), ref)
+
+ @unittest.skipUnless(
+ _have_example_storage_classes(), "Arrow/Astropy/Pandas storage classes are not available."
+ )
+ def test_component_resolved_by_input(self) -> None:
+ """Test successful resolution of a component dataset type due to
+ another input referencing the parent dataset type.
+ """
+ self.a_config.inputs["i"] = DynamicConnectionConfig(dataset_type_name="d", storage_class="ArrowTable")
+ self.b_config.inputs["i"] = DynamicConnectionConfig(
+ dataset_type_name="d.schema", storage_class="ArrowSchema"
+ )
+ graph = self.make_graph()
+ parent_dataset_type = DatasetType("d", self.dimensions.empty, get_mock_name("ArrowTable"))
+ graph.resolve(MockRegistry(self.dimensions, {}))
+ self.assertEqual(graph.dataset_types["d"].dataset_type, parent_dataset_type)
+ a_i = graph.tasks["a"].inputs["i"]
+ b_i = graph.tasks["b"].inputs["i"]
+ self.assertEqual(b_i.dataset_type_name, "d.schema")
+ self.assertEqual(a_i.adapt_dataset_type(parent_dataset_type), parent_dataset_type)
+ self.assertEqual(
+ b_i.adapt_dataset_type(parent_dataset_type),
+ parent_dataset_type.makeComponentDatasetType("schema"),
+ )
+ data_id = DataCoordinate.makeEmpty(self.dimensions)
+ ref = DatasetRef(parent_dataset_type, data_id, run="r")
+ a_ref = a_i.adapt_dataset_ref(ref)
+ b_ref = b_i.adapt_dataset_ref(ref)
+ self.assertEqual(a_ref, ref)
+ self.assertEqual(b_ref, ref.makeComponentRef("schema"))
+ self.assertEqual(graph.dataset_types["d"].generalize_ref(a_ref), ref)
+ self.assertEqual(graph.dataset_types["d"].generalize_ref(b_ref), ref)
+
+ @unittest.skipUnless(
+ _have_example_storage_classes(), "Arrow/Astropy/Pandas storage classes are not available."
+ )
+ def test_component_resolved_by_output(self) -> None:
+ """Test successful resolution of a component dataset type due to
+ an output connection referencing the parent dataset type.
+ """
+ self.a_config.outputs["o"] = DynamicConnectionConfig(
+ dataset_type_name="d", storage_class="ArrowTable"
+ )
+ self.b_config.inputs["i"] = DynamicConnectionConfig(
+ dataset_type_name="d.schema", storage_class="ArrowSchema"
+ )
+ graph = self.make_graph()
+ parent_dataset_type = DatasetType("d", self.dimensions.empty, get_mock_name("ArrowTable"))
+ graph.resolve(MockRegistry(self.dimensions, {}))
+ self.assertEqual(graph.dataset_types["d"].dataset_type, parent_dataset_type)
+ a_o = graph.tasks["a"].outputs["o"]
+ b_i = graph.tasks["b"].inputs["i"]
+ self.assertEqual(b_i.dataset_type_name, "d.schema")
+ self.assertEqual(a_o.adapt_dataset_type(parent_dataset_type), parent_dataset_type)
+ self.assertEqual(
+ b_i.adapt_dataset_type(parent_dataset_type),
+ parent_dataset_type.makeComponentDatasetType("schema"),
+ )
+ data_id = DataCoordinate.makeEmpty(self.dimensions)
+ ref = DatasetRef(parent_dataset_type, data_id, run="r")
+ a_ref = a_o.adapt_dataset_ref(ref)
+ b_ref = b_i.adapt_dataset_ref(ref)
+ self.assertEqual(a_ref, ref)
+ self.assertEqual(b_ref, ref.makeComponentRef("schema"))
+ self.assertEqual(graph.dataset_types["d"].generalize_ref(a_ref), ref)
+ self.assertEqual(graph.dataset_types["d"].generalize_ref(b_ref), ref)
+
+ @unittest.skipUnless(
+ _have_example_storage_classes(), "Arrow/Astropy/Pandas storage classes are not available."
+ )
+ def test_component_resolved_by_registry(self) -> None:
+ """Test successful resolution of a component dataset type due to
+ the parent dataset type already being registered.
+ """
+ self.b_config.inputs["i"] = DynamicConnectionConfig(
+ dataset_type_name="d.schema", storage_class="ArrowSchema"
+ )
+ graph = self.make_graph()
+ parent_dataset_type = DatasetType("d", self.dimensions.empty, get_mock_name("ArrowTable"))
+ graph.resolve(MockRegistry(self.dimensions, {"d": parent_dataset_type}))
+ self.assertEqual(graph.dataset_types["d"].dataset_type, parent_dataset_type)
+ b_i = graph.tasks["b"].inputs["i"]
+ self.assertEqual(b_i.dataset_type_name, "d.schema")
+ self.assertEqual(
+ b_i.adapt_dataset_type(parent_dataset_type),
+ parent_dataset_type.makeComponentDatasetType("schema"),
+ )
+ data_id = DataCoordinate.makeEmpty(self.dimensions)
+ ref = DatasetRef(parent_dataset_type, data_id, run="r")
+ b_ref = b_i.adapt_dataset_ref(ref)
+ self.assertEqual(b_ref, ref.makeComponentRef("schema"))
+ self.assertEqual(graph.dataset_types["d"].generalize_ref(b_ref), ref)
+
+
+if __name__ == "__main__":
+ lsst.utils.tests.init()
+ unittest.main()