diff --git a/doc/changes/DM-54879.feature.md b/doc/changes/DM-54879.feature.md new file mode 100644 index 000000000..704eb13c2 --- /dev/null +++ b/doc/changes/DM-54879.feature.md @@ -0,0 +1 @@ +Added `retained_dataset_types` parameter to `QuantumGraphBuilder` and `SeparablePipelineExecutor`. diff --git a/python/lsst/pipe/base/quantum_graph_builder.py b/python/lsst/pipe/base/quantum_graph_builder.py index f230134f4..c986ceeb5 100644 --- a/python/lsst/pipe/base/quantum_graph_builder.py +++ b/python/lsst/pipe/base/quantum_graph_builder.py @@ -39,6 +39,7 @@ "QuantumGraphBuilderError", ) +import collections import dataclasses import operator from abc import ABC, abstractmethod @@ -61,6 +62,7 @@ from lsst.daf.butler._rubin import generate_uuidv7 from lsst.daf.butler.datastore.record_data import DatastoreRecordData from lsst.daf.butler.registry import MissingCollectionError, MissingDatasetTypeError +from lsst.daf.butler.utils import globToRegex from lsst.utils.logging import LsstLogAdapter, getLogger from lsst.utils.timer import timeMethod @@ -128,6 +130,14 @@ class QuantumGraphBuilder(ABC): skip_existing_in : `~collections.abc.Sequence` [ `str` ], optional Collections to search for outputs that already exist for the purpose of skipping quanta that have already been run. + retained_dataset_types : `~collections.abc.Sequence` [ `str` ], optional + Dataset type names or glob-style wildcard patterns for dataset types + that should exist in ``skip_existing_in`` when the producing task ran + successfully. When a quantum should run, the builder propagates the + must-run signal backward through non-retained input datasets, forcing + the upstream quanta that need to regenerate those intermediates to also + run. Has no effect without ``skip_existing_in``. ``["*"]`` means + retaining all datasets, equivalent to not providing this option. clobber : `bool`, optional Whether to raise if predicted outputs already exist in ``output_run`` (not including those quanta that would be skipped because they've @@ -171,6 +181,7 @@ def __init__( input_collections: Sequence[str] | None = None, output_run: str | None = None, skip_existing_in: Sequence[str] = (), + retained_dataset_types: Sequence[str] | None = None, clobber: bool = False, ): self.log = getLogger(__name__) @@ -188,6 +199,11 @@ def __init__( self.butler = butler.clone(collections=input_collections) self.output_run = output_run self.skip_existing_in = skip_existing_in + self._retained_dataset_type_patterns: list[str] | None = ( + list(retained_dataset_types) if retained_dataset_types is not None else None + ) + if self._retained_dataset_type_patterns is not None and not skip_existing_in: + raise ValueError("retained_dataset_types has no effect without skip_existing_in.") self.empty_data_id = DataCoordinate.make_empty(butler.dimensions) self.clobber = clobber # See whether the output run already exists. @@ -407,8 +423,32 @@ def _build_skeleton(self, attach_datastore_records: bool = True) -> QuantumGraph # so a quantum is only processed after any quantum that provides # its inputs has been processed. skipped_quanta: dict[str, list[QuantumKey]] = {} - for task_node in self._pipeline_graph.tasks.values(): - skipped_quanta[task_node.label] = self._resolve_task_quanta(task_node, full_skeleton) + retained_types = self._expand_retained_patterns(self._retained_dataset_type_patterns) + # retained_types is None when all types are retained or option + # absent: no ancestor unskipping is needed. + if retained_types is None: + for task_node in self._pipeline_graph.tasks.values(): + skipped_quanta[task_node.label] = self._resolve_task_quanta(task_node, full_skeleton) + else: + skip_decisions: dict[QuantumKey, bool] = {} + # Compute initial skip decisions without mutating the skeleton. + for task_node in self._pipeline_graph.tasks.values(): + for quantum_key in full_skeleton.get_quanta(task_node.label): + skip_decisions[quantum_key] = self._compute_skip_decision( + task_node, quantum_key, full_skeleton + ) + # Unskip ancestor quanta whose outputs are not retained. + n_unskipped = self._unskip_ancestors(full_skeleton, skip_decisions, retained_types) + if n_unskipped: + self.log.info( + "Forcing %s to rerun (output not retained).", + _quantum_or_quanta(n_unskipped), + ) + # Apply decisions. + for task_node in self._pipeline_graph.tasks.values(): + skipped_quanta[task_node.label] = self._resolve_task_quanta( + task_node, full_skeleton, skip_decisions=skip_decisions + ) # Add any dimension records not handled by the subclass, and # aggregate any that were added directly to data IDs. full_skeleton.attach_dimension_records(self.butler, self._pipeline_graph.get_all_dimensions()) @@ -490,7 +530,12 @@ class can add them later, albeit possibly less efficiently). @final @timeMethod - def _resolve_task_quanta(self, task_node: TaskNode, skeleton: QuantumGraphSkeleton) -> list[QuantumKey]: + def _resolve_task_quanta( + self, + task_node: TaskNode, + skeleton: QuantumGraphSkeleton, + skip_decisions: dict[QuantumKey, bool] | None = None, + ) -> list[QuantumKey]: """Process the quanta for one task in a skeleton graph to skip those that have already completed and add missing prerequisite inputs. @@ -500,6 +545,9 @@ def _resolve_task_quanta(self, task_node: TaskNode, skeleton: QuantumGraphSkelet Node for this task in the pipeline graph. skeleton : `.quantum_graph_skeleton.QuantumGraphSkeleton` Preliminary quantum graph, to be modified in-place. + skip_decisions : `dict` [ `QuantumKey`, `bool` ] or `None`, optional + Pre-computed per-quantum skip decisions. When provided, the + decisions are applied directly. Returns ------- @@ -531,7 +579,13 @@ def _resolve_task_quanta(self, task_node: TaskNode, skeleton: QuantumGraphSkelet # gotten rid of. skipped_quanta = [] for quantum_key in skeleton.get_quanta(task_node.label): - if self._skip_quantum_if_metadata_exists(task_node, quantum_key, skeleton): + if skip_decisions is not None: + if skip_decisions.get(quantum_key, False): + self._apply_skip_decision(task_node, quantum_key, skeleton) + skipped_quanta.append(quantum_key) + continue + elif self._compute_skip_decision(task_node, quantum_key, skeleton): + self._apply_skip_decision(task_node, quantum_key, skeleton) skipped_quanta.append(quantum_key) continue quantum_data_id = skeleton[quantum_key]["data_id"] @@ -662,7 +716,7 @@ def _adjust_task_quanta( if no_work_quanta: message_terms.append(f"{len(no_work_quanta)} had no work to do") if skipped_quanta: - message_terms.append(f"{len(skipped_quanta)} previously succeeded") + message_terms.append(f"{len(skipped_quanta)} previously succeeded and skipped") if adjuster.n_removed: message_terms.append(f"{adjuster.n_removed} removed by adjust_all_quanta") message_parenthetical = f" ({', '.join(message_terms)})" if message_terms else "" @@ -699,11 +753,11 @@ def _get_task_inputs_if_overall_only(self, task_node: TaskNode) -> list[str] | N return None return result - def _skip_quantum_if_metadata_exists( + def _compute_skip_decision( self, task_node: TaskNode, quantum_key: QuantumKey, skeleton: QuantumGraphSkeleton ) -> bool: - """Identify and drop quanta that should be skipped because their - metadata datasets already exist. + """Identify if a quantum should be skipped because its + metadata dataset already exists. Parameters ---------- @@ -712,18 +766,37 @@ def _skip_quantum_if_metadata_exists( quantum_key : `QuantumKey` Identifier for this quantum in the graph. skeleton : `.quantum_graph_skeleton.QuantumGraphSkeleton` - Preliminary quantum graph, to be modified in-place. + Preliminary quantum graph (not modified). Returns ------- - skipped : `bool` - `True` if the quantum is being skipped and has been removed from - the graph, `False` otherwise. + skip : `bool` + `True` if the quantum's metadata exists in ``skip_existing_in`` and + should be skipped. + """ + metadata_dataset_key = DatasetKey( + task_node.metadata_output.parent_dataset_type_name, quantum_key.data_id_values + ) + return bool(skeleton.get_output_for_skip(metadata_dataset_key)) + + def _apply_skip_decision( + self, task_node: TaskNode, quantum_key: QuantumKey, skeleton: QuantumGraphSkeleton + ) -> None: + """Update the skeleton for a quantum that has been decided to skip. + + Parameters + ---------- + task_node : `pipeline_graph.TaskNode` + Node for this task in the pipeline graph. + quantum_key : `QuantumKey` + Identifier for this quantum in the graph. + skeleton : `.quantum_graph_skeleton.QuantumGraphSkeleton` + Preliminary quantum graph, to be modified in-place. Notes ----- - If the metadata dataset for this quantum exists in the - `skip_existing_in` collections, the quantum will be skipped. This + The metadata dataset for this quantum exists in the + `skip_existing_in` collections and the quantum will be skipped. This causes the quantum node to be removed from the graph. Dataset nodes that were previously the outputs of this quantum will be associated with `lsst.daf.butler.DatasetRef` objects that were found in @@ -731,32 +804,124 @@ def _skip_quantum_if_metadata_exists( there. Any output dataset in `output_run` will be removed from the "output in the way" category. """ - metadata_dataset_key = DatasetKey( - task_node.metadata_output.parent_dataset_type_name, quantum_key.data_id_values + # This quantum's metadata is already present in the + # skip_existing_in collections; we'll skip it. But the presence of + # the metadata dataset doesn't guarantee that all of the other + # outputs we predicted are present; we have to check. + for output_dataset_key in list(skeleton.iter_outputs_of(quantum_key)): + # If this dataset was "in the way" (i.e. already in the + # output run), it isn't anymore. + skeleton.discard_output_in_the_way(output_dataset_key) + if (output_ref := skeleton.get_output_for_skip(output_dataset_key)) is not None: + # Populate the skeleton graph's node attributes + # with the existing DatasetRef, just like a + # predicted output of a non-skipped quantum. + skeleton.set_dataset_ref(output_ref, output_dataset_key) + else: + # Remove this dataset from the skeleton graph, + # because the quantum that would have produced it + # is being skipped and it doesn't already exist. + skeleton.remove_dataset_nodes([output_dataset_key]) + # Removing the quantum node from the graph will happen outside this + # function. + + def _expand_retained_patterns(self, patterns: list[str] | None) -> frozenset[str] | None: + """Expand wildcard patterns into a concrete set of retained dataset + type names. + + Parameters + ---------- + patterns : `list` [ `str` ] or `None` + Dataset type names or glob-style wildcard patterns, or `None` if + the option was not provided. + + Returns + ------- + retained_types : `frozenset` [ `str` ] or `None` + Concrete set of retained dataset type names, or `None` if no + ancestor unskipping is needed (option absent, empty list, or + patterns match everything). + + """ + if patterns is None: + return None + regexes = globToRegex(patterns) + if regexes is ...: + # globToRegex returns Ellipsis when patterns match everything, + # e.g. the list is empty or "*". Treat as "retain everything". + return None + all_names = set(self._pipeline_graph.dataset_types) + result: set[str] = set() + for original, expression in zip(patterns, regexes): + if isinstance(expression, str): + if expression not in all_names: + self.log.warning("Retained dataset type %r not found in the pipeline.", expression) + result.add(expression) + else: + matches = {n for n in all_names if expression.search(n)} + if not matches: + self.log.warning( + "Retained dataset type pattern %r matches no dataset types.", + original, + ) + result.update(matches) + return frozenset(result) + + def _unskip_ancestors( + self, + skeleton: QuantumGraphSkeleton, + skip_decisions: dict[QuantumKey, bool], + retained: frozenset[str], + ) -> int: + """Unskip ancestor quanta whose outputs are not retained. + + Parameters + ---------- + skeleton : `.quantum_graph_skeleton.QuantumGraphSkeleton` + Preliminary quantum graph (not modified). + skip_decisions : `dict` [ `QuantumKey`, `bool` ] + Per-quantum skip decisions, modified in-place. + retained : `frozenset` [ `str` ] + Dataset type names that should be present in + ``skip_existing_in`` when their producing task has been skipped. + Types not in this set are treated as not retained. + + Returns + ------- + n_unskipped : `int` + Number of quanta unskipped by backward propagation. + + Notes + ----- + Seeds the breadth-first search with every initially must-run quantum. + For each must-run quantum, walks non-prerequisite input edges whose + dataset type is not retained. If the producer of such a dataset is + currently marked skip, it is unskipped and enqueued so that its own + inputs are examined in turn. Each quantum is visited at most once. + """ + queue: collections.deque[QuantumKey] = collections.deque( + qk for qk, skip in skip_decisions.items() if not skip ) - if skeleton.get_output_for_skip(metadata_dataset_key): - # This quantum's metadata is already present in the the - # skip_existing_in collections; we'll skip it. But the presence of - # the metadata dataset doesn't guarantee that all of the other - # outputs we predicted are present; we have to check. - for output_dataset_key in list(skeleton.iter_outputs_of(quantum_key)): - # If this dataset was "in the way" (i.e. already in the - # output run), it isn't anymore. - skeleton.discard_output_in_the_way(output_dataset_key) - if (output_ref := skeleton.get_output_for_skip(output_dataset_key)) is not None: - # Populate the skeleton graph's node attributes - # with the existing DatasetRef, just like a - # predicted output of a non-skipped quantum. - skeleton.set_dataset_ref(output_ref, output_dataset_key) - else: - # Remove this dataset from the skeleton graph, - # because the quantum that would have produced it - # is being skipped and it doesn't already exist. - skeleton.remove_dataset_nodes([output_dataset_key]) - # Removing the quantum node from the graph will happen outside this - # function. - return True - return False + visited: set[QuantumKey] = set(queue) + n_unskipped = 0 + while queue: + qk = queue.popleft() + for input_key in skeleton.iter_inputs_of(qk): + if input_key.is_prerequisite: + continue + if input_key.parent_dataset_type_name in retained: + continue + for producer_key in skeleton.iter_producers_of(input_key): + if not isinstance(producer_key, QuantumKey): + continue + if producer_key in visited: + continue + visited.add(producer_key) + if skip_decisions.get(producer_key, False): + skip_decisions[producer_key] = False + queue.append(producer_key) + n_unskipped += 1 + return n_unskipped @final def _update_quantum_for_adjust( diff --git a/python/lsst/pipe/base/quantum_graph_skeleton.py b/python/lsst/pipe/base/quantum_graph_skeleton.py index 64c808bd4..a8bbd3576 100644 --- a/python/lsst/pipe/base/quantum_graph_skeleton.py +++ b/python/lsst/pipe/base/quantum_graph_skeleton.py @@ -308,6 +308,21 @@ def iter_inputs_of( """ return self._xgraph.predecessors(quantum_key) + def iter_producers_of(self, dataset_key: DatasetKey) -> Iterator[QuantumKey | TaskInitKey]: + """Iterate over the quanta that produce the given dataset. + + Parameters + ---------- + dataset_key : `DatasetKey` + Dataset to look up producers for. + + Returns + ------- + quanta : `~collections.abc.Iterator` of `QuantumKey` or `TaskInitKey` + Quanta that produce the given dataset. + """ + return self._xgraph.predecessors(dataset_key) + def update(self, other: QuantumGraphSkeleton) -> None: """Copy all nodes from ``other`` to ``self``. diff --git a/python/lsst/pipe/base/separable_pipeline_executor.py b/python/lsst/pipe/base/separable_pipeline_executor.py index 52d694ac9..ece66a0a4 100644 --- a/python/lsst/pipe/base/separable_pipeline_executor.py +++ b/python/lsst/pipe/base/separable_pipeline_executor.py @@ -84,6 +84,15 @@ class SeparablePipelineExecutor: for existing outputs, and skips any quanta that have run to completion (or have no work to do). Otherwise, all tasks are attempted (subject to ``clobber_output``). + retained_dataset_types : `~collections.abc.Iterable` [`str`], optional + Dataset type names or glob-style wildcard patterns for types that + should be present in ``skip_existing_in`` whenever the producing task + ran successfully. Dataset types not in this list are treated as not + retained: when a downstream quantum must run, the builder propagates + the must-run signal backward through non-retained input edges, forcing + the upstream quanta that need to regenerate those intermediates to also + run. Has no effect without ``skip_existing_in``. ``["*"]`` means + retaining all datasets, equivalent to not providing this option. task_factory : `.TaskFactory`, optional A custom task factory for use in pre-execution and execution. By default, a new instance of `.TaskFactory` is used. @@ -101,6 +110,7 @@ def __init__( butler: Butler, clobber_output: bool = False, skip_existing_in: Iterable[str] | None = None, + retained_dataset_types: Iterable[str] | None = None, task_factory: TaskFactory | None = None, resources: ExecutionResources | None = None, raise_on_partial_outputs: bool = True, @@ -115,6 +125,7 @@ def __init__( self._clobber_output = clobber_output self._skip_existing_in = list(skip_existing_in) if skip_existing_in else [] + self._retained_dataset_types = list(retained_dataset_types) if retained_dataset_types else None self._task_factory = task_factory if task_factory else TaskFactory() self.resources = resources @@ -216,6 +227,7 @@ class are provided automatically (from explicit arguments to this pipeline.to_graph(), self._butler, skip_existing_in=self._skip_existing_in, + retained_dataset_types=self._retained_dataset_types, clobber=self._clobber_output, **kwargs, ) @@ -276,6 +288,7 @@ class are provided automatically (from explicit arguments to this "output_run": self._butler.run, "skip_existing_in": self._skip_existing_in, "skip_existing": bool(self._skip_existing_in), + "retained_dataset_types": self._retained_dataset_types, "data_query": where, "user": getpass.getuser(), "time": str(datetime.datetime.now()), @@ -344,6 +357,7 @@ class are provided automatically (from explicit arguments to this metadata = { "skip_existing_in": self._skip_existing_in, "skip_existing": bool(self._skip_existing_in), + "retained_dataset_types": self._retained_dataset_types, "data_query": where, } qg_builder = self.make_quantum_graph_builder(pipeline, where, builder_class=builder_class, **kwargs) diff --git a/tests/test_graphBuilder.py b/tests/test_graphBuilder.py index b5587b8da..6bc764af4 100644 --- a/tests/test_graphBuilder.py +++ b/tests/test_graphBuilder.py @@ -32,7 +32,7 @@ import unittest import lsst.utils.tests -from lsst.daf.butler import Butler, DatasetType +from lsst.daf.butler import Butler, DataCoordinate, DatasetType from lsst.daf.butler.registry import UserExpressionError from lsst.pipe.base import PipelineGraph, QuantumGraph from lsst.pipe.base.all_dimensions_quantum_graph_builder import ( @@ -44,6 +44,7 @@ DynamicConnectionConfig, DynamicTestPipelineTask, DynamicTestPipelineTaskConfig, + InMemoryRepo, MockDataset, MockStorageClass, ) @@ -228,6 +229,242 @@ def test_datastore_records(self): self.assertEqual(quantum.datastore_records, {}) +class SkipExistingInTestCase(unittest.TestCase): + """Tests for the skip_existing_in behavior of QuantumGraphBuilder.""" + + def setUp(self): + self.helper = InMemoryRepo() + self.enterContext(self.helper) + self.helper.add_task() + self.helper.make_quantum_graph_builder(output_run="new_run") + self.helper.butler.collections.register("prior_run") + self._task_node = self.helper.pipeline_graph.tasks["task_auto1"] + self._empty_data_id = DataCoordinate.make_empty(self.helper.butler.dimensions) + + def _insert(self, *names, run="prior_run"): + """Register datasets with empty data IDs into a run collection.""" + for name in names: + dt = self.helper.pipeline_graph.dataset_types[name].dataset_type + self.helper.butler.registry.insertDatasets(dt, [self._empty_data_id], run=run) + + def _build(self, *, output_run="new_run", **kwargs): + return AllDimensionsQuantumGraphBuilder( + self.helper.pipeline_graph, + self.helper.butler, + input_collections=[self.helper.input_chain], + output_run=output_run, + **kwargs, + ).build(attach_datastore_records=False) + + def test_not_skipped_without_skip_existing_in(self): + """Without skip_existing_in, a quantum is never skipped even if + metadata exists in an input collection. + """ + self._insert(self._task_node.metadata_output.parent_dataset_type_name) + qgraph = self._build() + self.assertEqual(len(qgraph), 1) + + def test_skipped_when_metadata_exists(self): + """With skip_existing_in, a quantum is skipped when its metadata + dataset is present in the specified collections. + """ + self._insert(self._task_node.metadata_output.parent_dataset_type_name) + # Init-outputs required, otherwise InitInputMissingError. + for edge in self._task_node.init.iter_all_outputs(): + self._insert(edge.parent_dataset_type_name) + qgraph = self._build(skip_existing_in=["prior_run"]) + self.assertEqual(len(qgraph), 0) + + def test_not_skipped_when_metadata_absent(self): + """With skip_existing_in, a quantum is not skipped when its metadata + dataset is absent from the specified collections. + """ + qgraph = self._build(skip_existing_in=["prior_run"]) + self.assertEqual(len(qgraph), 1) + + +class RetainedDatasetTypesTestCase(unittest.TestCase): + """Tests for QuantumGraphBuilder.retained_dataset_types. + + dataset_auto0 -> task_auto1 -> dataset_auto1 -> task_auto2 + """ + + def setUp(self): + self.helper = InMemoryRepo() + self.enterContext(self.helper) + self.helper.add_task() + self.helper.add_task() + self.helper.make_quantum_graph_builder(output_run="new_run") + self.helper.butler.collections.register("prior_run") + self._task1 = self.helper.pipeline_graph.tasks["task_auto1"] + self._task2 = self.helper.pipeline_graph.tasks["task_auto2"] + self._empty_data_id = DataCoordinate.make_empty(self.helper.butler.dimensions) + + def _insert(self, *names, run="prior_run"): + """Register datasets with empty data IDs into a run collection.""" + for name in names: + dt = self.helper.pipeline_graph.dataset_types[name].dataset_type + self.helper.butler.registry.insertDatasets(dt, [self._empty_data_id], run=run) + + def _build(self, *, output_run="new_run", **kwargs): + return AllDimensionsQuantumGraphBuilder( + self.helper.pipeline_graph, + self.helper.butler, + input_collections=[self.helper.input_chain], + output_run=output_run, + **kwargs, + ).build(attach_datastore_records=False) + + def test_raises_without_skip_existing_in(self): + """retained_dataset_types invalid without skip_existing_in.""" + with self.assertRaises(ValueError): + self._build(retained_dataset_types=["dataset_auto1"]) + + def test_ancestor_unskipped_when_output_not_retained(self): + """task1 ran (metadata present) but did not retain its output; + task2 must run. Because dataset_auto1 is not retained, task1 + is unskipped to regenerate it. + """ + # task1 succeeded previously, but dataset_auto1 not retained. + self._insert(self._task1.metadata_output.parent_dataset_type_name) + qgraph = self._build( + skip_existing_in=["prior_run"], + retained_dataset_types=["*_metadata"], + ) + # Both tasks run: task1 regenerate dataset_auto1 for task2. + self.assertEqual(len(qgraph), 2) + + def test_ancestor_not_unskipped_when_output_retained(self): + """When the intermediate output is declared retained and is present in + skip_existing_in, unskipping stops there and task1 remains skipped. + """ + # task1 metadata and its output dataset_auto1 both present in + # prior_run. + self._insert(self._task1.metadata_output.parent_dataset_type_name) + self._insert("dataset_auto1") + for edge in self._task1.init.iter_all_outputs(): + self._insert(edge.parent_dataset_type_name) + qgraph = self._build( + skip_existing_in=["prior_run"], + retained_dataset_types=["dataset_auto1", "*_metadata"], + ) + # Only task2 runs. + self.assertEqual(len(qgraph), 1) + + def test_both_skipped_when_both_have_metadata(self): + """When both tasks have metadata, both remain skipped regardless of + which outputs are not retained. + """ + self._insert(self._task1.metadata_output.parent_dataset_type_name) + self._insert(self._task2.metadata_output.parent_dataset_type_name) + for edge in self._task1.init.iter_all_outputs(): + self._insert(edge.parent_dataset_type_name) + for edge in self._task2.init.iter_all_outputs(): + self._insert(edge.parent_dataset_type_name) + qgraph = self._build( + skip_existing_in=["prior_run"], + retained_dataset_types=["*_metadata"], + ) + self.assertEqual(len(qgraph), 0) + + def test_unrecognised_pattern_warns(self): + """Literal names and wildcard patterns that match nothing in the + pipeline emit a WARNING log message. + """ + with self.assertLogs("lsst.pipe.base.quantum_graph_builder", level="WARNING") as cm: + self._build( + skip_existing_in=["prior_run"], + retained_dataset_types=["no_such_dataset_type", "no_such_*"], + ) + self.assertTrue(any("no_such_dataset_type" in msg for msg in cm.output)) + self.assertTrue(any("no_such_*" in msg for msg in cm.output)) + + def test_no_unskipping_when_all_retained(self): + """'*' matches all dataset types; no ancestor unskipping occurs, + equivalent to not providing retained_dataset_types. + """ + # task1 ran; metadata and dataset_auto1 present. + self._insert(self._task1.metadata_output.parent_dataset_type_name) + self._insert("dataset_auto1") + for edge in self._task1.init.iter_all_outputs(): + self._insert(edge.parent_dataset_type_name) + qgraph = self._build( + skip_existing_in=["prior_run"], + retained_dataset_types=["*"], + ) + # All types retained -> no unskipping -> task1 stays skipped, + # only task2 runs. + self.assertEqual(len(qgraph), 1) + + +class RetainedDatasetTypesThreeTaskTestCase(unittest.TestCase): + """Tests for retained_dataset_types with a 3-task chain. + + Pipeline: dataset_auto0 -> task_auto1 -> dataset_auto1 + -> task_auto2 -> dataset_auto2 + -> task_auto3 -> dataset_auto3 + """ + + def setUp(self): + self.helper = InMemoryRepo() + self.enterContext(self.helper) + self.helper.add_task() + self.helper.add_task() + self.helper.add_task() + self.helper.make_quantum_graph_builder(output_run="new_run") + self.helper.butler.collections.register("prior_run") + self._task1 = self.helper.pipeline_graph.tasks["task_auto1"] + self._task2 = self.helper.pipeline_graph.tasks["task_auto2"] + self._task3 = self.helper.pipeline_graph.tasks["task_auto3"] + self._empty_data_id = DataCoordinate.make_empty(self.helper.butler.dimensions) + + def _insert(self, *names, run="prior_run"): + """Register datasets with empty data IDs into a run collection.""" + for name in names: + dt = self.helper.pipeline_graph.dataset_types[name].dataset_type + self.helper.butler.registry.insertDatasets(dt, [self._empty_data_id], run=run) + + def _build(self, *, output_run="new_run", **kwargs): + return AllDimensionsQuantumGraphBuilder( + self.helper.pipeline_graph, + self.helper.butler, + input_collections=[self.helper.input_chain], + output_run=output_run, + **kwargs, + ).build(attach_datastore_records=False) + + def test_unskipping_stops_at_retained_intermediate(self): + """task2's output is retained and present in skip_existing_in. + Only task3 runs, task1 and task2 remain skipped. + """ + self._insert(self._task1.metadata_output.parent_dataset_type_name) + self._insert(self._task2.metadata_output.parent_dataset_type_name) + self._insert("dataset_auto2") + for edge in self._task1.init.iter_all_outputs(): + self._insert(edge.parent_dataset_type_name) + for edge in self._task2.init.iter_all_outputs(): + self._insert(edge.parent_dataset_type_name) + qgraph = self._build( + skip_existing_in=["prior_run"], + retained_dataset_types=["dataset_auto2", "*_metadata"], + ) + self.assertEqual(len(qgraph), 1) + + def test_full_chain_unskipped_when_none_retained(self): + """task3 needs to run. Unskipping walks back through + dataset_auto2 (not retained) to unskip task2, then through + dataset_auto1 (not retained) to unskip task1. All three tasks + run to regenerate the non-retained datasets. + """ + self._insert(self._task1.metadata_output.parent_dataset_type_name) + self._insert(self._task2.metadata_output.parent_dataset_type_name) + qgraph = self._build( + skip_existing_in=["prior_run"], + retained_dataset_types=["*_metadata"], + ) + self.assertEqual(len(qgraph), 3) + + if __name__ == "__main__": lsst.utils.tests.init() unittest.main() diff --git a/tests/test_separable_pipeline_executor.py b/tests/test_separable_pipeline_executor.py index 5f71a18f6..704a01dce 100644 --- a/tests/test_separable_pipeline_executor.py +++ b/tests/test_separable_pipeline_executor.py @@ -715,6 +715,68 @@ def test_make_quantum_graph_nowhere_skippartial_clobber(self): self.assertEqual(len(graph), 2) self.assertEqual(graph.quanta_by_task.keys(), {"a", "b"}) + def test_make_quantum_graph_nowhere_retained_forces_upstream_rerun(self): + """When task b must run and 'intermediate' is not retained, + retained_dataset_types forces task a to rerun to regenerate the + missing intermediate. + """ + prior_run = "prior_run" + self.butler.registry.registerCollection(prior_run, lsst.daf.butler.CollectionType.RUN) + executor = SeparablePipelineExecutor( + self.butler, + skip_existing_in=[prior_run], + # Only metadata types are retained; 'intermediate' is not retained. + retained_dataset_types=["*_metadata"], + ) + pipeline = Pipeline.fromFile(self.pipeline_file) + self.butler.put({"zero": 0}, "input") + # Task a metadata present in prior run but 'intermediate' was not + # retained. + self.butler.put(TaskMetadata(), "a_metadata", run=prior_run) + graph = executor.build_quantum_graph(pipeline) + self.assertEqual(len(graph), 2) + self.assertEqual(graph.quanta_by_task.keys(), {"a", "b"}) + + def test_make_quantum_graph_nowhere_retained_metainway(self): + """When an ancestor is forced to rerun but its metadata is already in + the output run, OutputExistsError is raised without clobber_output. + """ + executor = SeparablePipelineExecutor( + self.butler, + skip_existing_in=[self.butler.run], + retained_dataset_types=["*_metadata"], + ) + pipeline = Pipeline.fromFile(self.pipeline_file) + self.butler.put({"zero": 0}, "input") + self.butler.put(TaskMetadata(), "a_metadata") + # a_metadata in output run, task b must run (no b_metadata), and + # 'intermediate' is not retained -> force task a to run + # no clobber -> error. + with self.assertRaises(OutputExistsError): + executor.build_quantum_graph(pipeline) + + def test_make_quantum_graph_nowhere_retained_both_skipped(self): + """Both tasks are skipped when both have metadata, regardless of which + outputs are not retained. + """ + executor = SeparablePipelineExecutor( + self.butler, + skip_existing_in=[self.butler.run], + retained_dataset_types=["*_metadata"], + ) + pipeline = Pipeline.fromFile(self.pipeline_file) + + butlerTests.addDatasetType(self.butler, "b_metadata", set(), "TaskMetadata") + butlerTests.addDatasetType(self.butler, "b_config", set(), "Config") + + self.butler.put(TaskMetadata(), "a_metadata") + self.butler.put(lsst.pex.config.Config(), "a_config") + self.butler.put(TaskMetadata(), "b_metadata") + self.butler.put(lsst.pex.config.Config(), "b_config") + + graph = executor.build_quantum_graph(pipeline) + self.assertEqual(len(graph), 0) + def test_make_quantum_graph_noinput(self): executor = SeparablePipelineExecutor(self.butler) pipeline = Pipeline.fromFile(self.pipeline_file)