Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/changes/DM-54879.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Added `retained_dataset_types` parameter to `QuantumGraphBuilder` and `SeparablePipelineExecutor`.
243 changes: 204 additions & 39 deletions python/lsst/pipe/base/quantum_graph_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
"QuantumGraphBuilderError",
)

import collections
import dataclasses
import operator
from abc import ABC, abstractmethod
Expand All @@ -61,6 +62,7 @@
from lsst.daf.butler._rubin import generate_uuidv7
from lsst.daf.butler.datastore.record_data import DatastoreRecordData
from lsst.daf.butler.registry import MissingCollectionError, MissingDatasetTypeError
from lsst.daf.butler.utils import globToRegex
from lsst.utils.logging import LsstLogAdapter, getLogger
from lsst.utils.timer import timeMethod

Expand Down Expand Up @@ -128,6 +130,14 @@ class QuantumGraphBuilder(ABC):
skip_existing_in : `~collections.abc.Sequence` [ `str` ], optional
Collections to search for outputs that already exist for the purpose of
skipping quanta that have already been run.
retained_dataset_types : `~collections.abc.Sequence` [ `str` ], optional
Dataset type names or glob-style wildcard patterns for dataset types
that should exist in ``skip_existing_in`` when the producing task ran
successfully. When a quantum should run, the builder propagates the
must-run signal backward through non-retained input datasets, forcing
the upstream quanta that need to regenerate those intermediates to also
run. Has no effect without ``skip_existing_in``. ``["*"]`` means
retaining all datasets, equivalent to not providing this option.
clobber : `bool`, optional
Whether to raise if predicted outputs already exist in ``output_run``
(not including those quanta that would be skipped because they've
Expand Down Expand Up @@ -171,6 +181,7 @@ def __init__(
input_collections: Sequence[str] | None = None,
output_run: str | None = None,
skip_existing_in: Sequence[str] = (),
retained_dataset_types: Sequence[str] | None = None,
clobber: bool = False,
):
self.log = getLogger(__name__)
Expand All @@ -188,6 +199,11 @@ def __init__(
self.butler = butler.clone(collections=input_collections)
self.output_run = output_run
self.skip_existing_in = skip_existing_in
self._retained_dataset_type_patterns: list[str] | None = (
list(retained_dataset_types) if retained_dataset_types is not None else None
)
if self._retained_dataset_type_patterns is not None and not skip_existing_in:
raise ValueError("retained_dataset_types has no effect without skip_existing_in.")
self.empty_data_id = DataCoordinate.make_empty(butler.dimensions)
self.clobber = clobber
# See whether the output run already exists.
Expand Down Expand Up @@ -407,8 +423,32 @@ def _build_skeleton(self, attach_datastore_records: bool = True) -> QuantumGraph
# so a quantum is only processed after any quantum that provides
# its inputs has been processed.
skipped_quanta: dict[str, list[QuantumKey]] = {}
for task_node in self._pipeline_graph.tasks.values():
skipped_quanta[task_node.label] = self._resolve_task_quanta(task_node, full_skeleton)
retained_types = self._expand_retained_patterns(self._retained_dataset_type_patterns)
# retained_types is None when all types are retained or option
# absent: no ancestor unskipping is needed.
if retained_types is None:
for task_node in self._pipeline_graph.tasks.values():
skipped_quanta[task_node.label] = self._resolve_task_quanta(task_node, full_skeleton)
else:
skip_decisions: dict[QuantumKey, bool] = {}
# Compute initial skip decisions without mutating the skeleton.
for task_node in self._pipeline_graph.tasks.values():
for quantum_key in full_skeleton.get_quanta(task_node.label):
skip_decisions[quantum_key] = self._compute_skip_decision(
task_node, quantum_key, full_skeleton
)
# Unskip ancestor quanta whose outputs are not retained.
n_unskipped = self._unskip_ancestors(full_skeleton, skip_decisions, retained_types)
if n_unskipped:
self.log.info(
"Forcing %s to rerun (output not retained).",
_quantum_or_quanta(n_unskipped),
)
# Apply decisions.
for task_node in self._pipeline_graph.tasks.values():
skipped_quanta[task_node.label] = self._resolve_task_quanta(
task_node, full_skeleton, skip_decisions=skip_decisions
)
# Add any dimension records not handled by the subclass, and
# aggregate any that were added directly to data IDs.
full_skeleton.attach_dimension_records(self.butler, self._pipeline_graph.get_all_dimensions())
Expand Down Expand Up @@ -490,7 +530,12 @@ class can add them later, albeit possibly less efficiently).

@final
@timeMethod
def _resolve_task_quanta(self, task_node: TaskNode, skeleton: QuantumGraphSkeleton) -> list[QuantumKey]:
def _resolve_task_quanta(
self,
task_node: TaskNode,
skeleton: QuantumGraphSkeleton,
skip_decisions: dict[QuantumKey, bool] | None = None,
) -> list[QuantumKey]:
"""Process the quanta for one task in a skeleton graph to skip those
that have already completed and add missing prerequisite inputs.

Expand All @@ -500,6 +545,9 @@ def _resolve_task_quanta(self, task_node: TaskNode, skeleton: QuantumGraphSkelet
Node for this task in the pipeline graph.
skeleton : `.quantum_graph_skeleton.QuantumGraphSkeleton`
Preliminary quantum graph, to be modified in-place.
skip_decisions : `dict` [ `QuantumKey`, `bool` ] or `None`, optional
Pre-computed per-quantum skip decisions. When provided, the
decisions are applied directly.

Returns
-------
Expand Down Expand Up @@ -531,7 +579,13 @@ def _resolve_task_quanta(self, task_node: TaskNode, skeleton: QuantumGraphSkelet
# gotten rid of.
skipped_quanta = []
for quantum_key in skeleton.get_quanta(task_node.label):
if self._skip_quantum_if_metadata_exists(task_node, quantum_key, skeleton):
if skip_decisions is not None:
if skip_decisions.get(quantum_key, False):
self._apply_skip_decision(task_node, quantum_key, skeleton)
skipped_quanta.append(quantum_key)
continue
elif self._compute_skip_decision(task_node, quantum_key, skeleton):
self._apply_skip_decision(task_node, quantum_key, skeleton)
skipped_quanta.append(quantum_key)
continue
quantum_data_id = skeleton[quantum_key]["data_id"]
Expand Down Expand Up @@ -662,7 +716,7 @@ def _adjust_task_quanta(
if no_work_quanta:
message_terms.append(f"{len(no_work_quanta)} had no work to do")
if skipped_quanta:
message_terms.append(f"{len(skipped_quanta)} previously succeeded")
message_terms.append(f"{len(skipped_quanta)} previously succeeded and skipped")
if adjuster.n_removed:
message_terms.append(f"{adjuster.n_removed} removed by adjust_all_quanta")
message_parenthetical = f" ({', '.join(message_terms)})" if message_terms else ""
Expand Down Expand Up @@ -699,11 +753,11 @@ def _get_task_inputs_if_overall_only(self, task_node: TaskNode) -> list[str] | N
return None
return result

def _skip_quantum_if_metadata_exists(
def _compute_skip_decision(
self, task_node: TaskNode, quantum_key: QuantumKey, skeleton: QuantumGraphSkeleton
) -> bool:
"""Identify and drop quanta that should be skipped because their
metadata datasets already exist.
"""Identify if a quantum should be skipped because its
metadata dataset already exists.

Parameters
----------
Expand All @@ -712,51 +766,162 @@ def _skip_quantum_if_metadata_exists(
quantum_key : `QuantumKey`
Identifier for this quantum in the graph.
skeleton : `.quantum_graph_skeleton.QuantumGraphSkeleton`
Preliminary quantum graph, to be modified in-place.
Preliminary quantum graph (not modified).

Returns
-------
skipped : `bool`
`True` if the quantum is being skipped and has been removed from
the graph, `False` otherwise.
skip : `bool`
`True` if the quantum's metadata exists in ``skip_existing_in`` and
should be skipped.
"""
metadata_dataset_key = DatasetKey(
task_node.metadata_output.parent_dataset_type_name, quantum_key.data_id_values
)
return bool(skeleton.get_output_for_skip(metadata_dataset_key))

def _apply_skip_decision(
self, task_node: TaskNode, quantum_key: QuantumKey, skeleton: QuantumGraphSkeleton
) -> None:
"""Update the skeleton for a quantum that has been decided to skip.

Parameters
----------
task_node : `pipeline_graph.TaskNode`
Node for this task in the pipeline graph.
quantum_key : `QuantumKey`
Identifier for this quantum in the graph.
skeleton : `.quantum_graph_skeleton.QuantumGraphSkeleton`
Preliminary quantum graph, to be modified in-place.

Notes
-----
If the metadata dataset for this quantum exists in the
`skip_existing_in` collections, the quantum will be skipped. This
The metadata dataset for this quantum exists in the
`skip_existing_in` collections and the quantum will be skipped. This
causes the quantum node to be removed from the graph. Dataset nodes
that were previously the outputs of this quantum will be associated
with `lsst.daf.butler.DatasetRef` objects that were found in
``skip_existing_in``, or will be removed if there is no such dataset
there. Any output dataset in `output_run` will be removed from the
"output in the way" category.
"""
metadata_dataset_key = DatasetKey(
task_node.metadata_output.parent_dataset_type_name, quantum_key.data_id_values
# This quantum's metadata is already present in the
# skip_existing_in collections; we'll skip it. But the presence of
# the metadata dataset doesn't guarantee that all of the other
# outputs we predicted are present; we have to check.
for output_dataset_key in list(skeleton.iter_outputs_of(quantum_key)):
# If this dataset was "in the way" (i.e. already in the
# output run), it isn't anymore.
skeleton.discard_output_in_the_way(output_dataset_key)
if (output_ref := skeleton.get_output_for_skip(output_dataset_key)) is not None:
# Populate the skeleton graph's node attributes
# with the existing DatasetRef, just like a
# predicted output of a non-skipped quantum.
skeleton.set_dataset_ref(output_ref, output_dataset_key)
else:
# Remove this dataset from the skeleton graph,
# because the quantum that would have produced it
# is being skipped and it doesn't already exist.
skeleton.remove_dataset_nodes([output_dataset_key])
# Removing the quantum node from the graph will happen outside this
# function.

def _expand_retained_patterns(self, patterns: list[str] | None) -> frozenset[str] | None:
"""Expand wildcard patterns into a concrete set of retained dataset
type names.

Parameters
----------
patterns : `list` [ `str` ] or `None`
Dataset type names or glob-style wildcard patterns, or `None` if
the option was not provided.

Returns
-------
retained_types : `frozenset` [ `str` ] or `None`
Concrete set of retained dataset type names, or `None` if no
ancestor unskipping is needed (option absent, empty list, or
patterns match everything).

"""
if patterns is None:
return None
regexes = globToRegex(patterns)
if regexes is ...:
# globToRegex returns Ellipsis when patterns match everything,
# e.g. the list is empty or "*". Treat as "retain everything".
return None
all_names = set(self._pipeline_graph.dataset_types)
result: set[str] = set()
for original, expression in zip(patterns, regexes):
if isinstance(expression, str):
if expression not in all_names:
self.log.warning("Retained dataset type %r not found in the pipeline.", expression)
result.add(expression)
else:
matches = {n for n in all_names if expression.search(n)}
if not matches:
self.log.warning(
"Retained dataset type pattern %r matches no dataset types.",
original,
)
result.update(matches)
return frozenset(result)

def _unskip_ancestors(
self,
skeleton: QuantumGraphSkeleton,
skip_decisions: dict[QuantumKey, bool],
retained: frozenset[str],
) -> int:
"""Unskip ancestor quanta whose outputs are not retained.

Parameters
----------
skeleton : `.quantum_graph_skeleton.QuantumGraphSkeleton`
Preliminary quantum graph (not modified).
skip_decisions : `dict` [ `QuantumKey`, `bool` ]
Per-quantum skip decisions, modified in-place.
retained : `frozenset` [ `str` ]
Dataset type names that should be present in
``skip_existing_in`` when their producing task has been skipped.
Types not in this set are treated as not retained.

Returns
-------
n_unskipped : `int`
Number of quanta unskipped by backward propagation.

Notes
-----
Seeds the breadth-first search with every initially must-run quantum.
For each must-run quantum, walks non-prerequisite input edges whose
dataset type is not retained. If the producer of such a dataset is
currently marked skip, it is unskipped and enqueued so that its own
inputs are examined in turn. Each quantum is visited at most once.
"""
queue: collections.deque[QuantumKey] = collections.deque(
qk for qk, skip in skip_decisions.items() if not skip
)
if skeleton.get_output_for_skip(metadata_dataset_key):
# This quantum's metadata is already present in the the
# skip_existing_in collections; we'll skip it. But the presence of
# the metadata dataset doesn't guarantee that all of the other
# outputs we predicted are present; we have to check.
for output_dataset_key in list(skeleton.iter_outputs_of(quantum_key)):
# If this dataset was "in the way" (i.e. already in the
# output run), it isn't anymore.
skeleton.discard_output_in_the_way(output_dataset_key)
if (output_ref := skeleton.get_output_for_skip(output_dataset_key)) is not None:
# Populate the skeleton graph's node attributes
# with the existing DatasetRef, just like a
# predicted output of a non-skipped quantum.
skeleton.set_dataset_ref(output_ref, output_dataset_key)
else:
# Remove this dataset from the skeleton graph,
# because the quantum that would have produced it
# is being skipped and it doesn't already exist.
skeleton.remove_dataset_nodes([output_dataset_key])
# Removing the quantum node from the graph will happen outside this
# function.
return True
return False
visited: set[QuantumKey] = set(queue)
n_unskipped = 0
while queue:
qk = queue.popleft()
for input_key in skeleton.iter_inputs_of(qk):
if input_key.is_prerequisite:
continue
if input_key.parent_dataset_type_name in retained:
continue
for producer_key in skeleton.iter_producers_of(input_key):
if not isinstance(producer_key, QuantumKey):
continue
if producer_key in visited:
continue
visited.add(producer_key)
if skip_decisions.get(producer_key, False):
skip_decisions[producer_key] = False
queue.append(producer_key)
n_unskipped += 1
return n_unskipped

@final
def _update_quantum_for_adjust(
Expand Down
15 changes: 15 additions & 0 deletions python/lsst/pipe/base/quantum_graph_skeleton.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,21 @@ def iter_inputs_of(
"""
return self._xgraph.predecessors(quantum_key)

def iter_producers_of(self, dataset_key: DatasetKey) -> Iterator[QuantumKey | TaskInitKey]:
"""Iterate over the quanta that produce the given dataset.

Parameters
----------
dataset_key : `DatasetKey`
Dataset to look up producers for.

Returns
-------
quanta : `~collections.abc.Iterator` of `QuantumKey` or `TaskInitKey`
Quanta that produce the given dataset.
"""
return self._xgraph.predecessors(dataset_key)

def update(self, other: QuantumGraphSkeleton) -> None:
"""Copy all nodes from ``other`` to ``self``.

Expand Down
Loading
Loading