diff --git a/daisy/dependency_graph.py b/daisy/dependency_graph.py index 9d38bdbf..6f23e1f0 100644 --- a/daisy/dependency_graph.py +++ b/daisy/dependency_graph.py @@ -4,6 +4,7 @@ from .roi import Roi import numpy as np +from tqdm import tqdm from itertools import product import logging @@ -93,9 +94,16 @@ def __init__( fit: str, total_read_roi: Optional[Roi] = None, total_write_roi: Optional[Roi] = None, + block_filter=None, ): self.block_read_roi = block_read_roi self.block_write_roi = block_write_roi + self.block_filter = block_filter + # When a block_filter is provided, the surviving blocks per level are + # materialized once during __init__ (see _apply_block_filter) so that + # subsequent num_blocks / level_blocks / num_roots calls see only the + # filtered set. Stored as a list[list[Block]] indexed by level. + self._filtered_blocks_by_level = None self.read_write_context = ( block_write_roi.begin - block_read_roi.begin, block_read_roi.end - block_write_roi.end, @@ -144,12 +152,63 @@ def __init__( self._level_offsets = self.compute_level_offsets() self._level_conflicts = self.compute_level_conflicts() + # Eagerly prune blocks if a filter was supplied. + if self.block_filter is not None: + self._apply_block_filter() + + def _apply_block_filter(self): + """Materialize every level's blocks once and keep only those passing + ``block_filter``. After this, ``num_blocks``, ``num_roots``, and + ``level_blocks`` all reflect the surviving set. + + Walks all blocks across all levels with a tqdm progress bar so callers + can see how the eager filter is advancing. + """ + num_levels = len(self._level_offsets) + # Analytical per-level counts let tqdm show a real total without + # exhausting the underlying generator. + per_level_totals = [self._num_level_blocks(level) for level in range(num_levels)] + total = int(sum(per_level_totals)) + + logger.info( + "Task %s: starting block_filter on %d candidate blocks across %d levels...", + self.task_id, + total, + num_levels, + ) + + kept = [] + with tqdm( + total=total, + desc=f"block_filter({self.task_id})", + unit="block", + ) as pbar: + for level in range(num_levels): + level_kept = [] + for b in self._unfiltered_level_blocks(level): + if self.block_filter(b): + level_kept.append(b) + pbar.update(1) + kept.append(level_kept) + + self._filtered_blocks_by_level = kept + surviving = sum(len(k) for k in kept) + logger.info( + "Task %s: block_filter kept %d / %d blocks (%.2f%%) — workers will only see these.", + self.task_id, + surviving, + total, + (100.0 * surviving / total) if total else 0.0, + ) + @property def num_levels(self): return len(self._level_offsets) @property def num_blocks(self): + if self._filtered_blocks_by_level is not None: + return sum(len(b) for b in self._filtered_blocks_by_level) num_blocks = 0 for level in range(self.num_levels): num_blocks += self._num_level_blocks(level) @@ -177,6 +236,8 @@ def fit_block(self): return fit_block def num_roots(self): + if self._filtered_blocks_by_level is not None: + return len(self._filtered_blocks_by_level[0]) return self._num_level_blocks(0) def _num_level_blocks(self, level): @@ -206,6 +267,12 @@ def _num_level_blocks(self, level): return num_blocks def level_blocks(self, level): + if self._filtered_blocks_by_level is not None: + yield from self._filtered_blocks_by_level[level] + return + yield from self._unfiltered_level_blocks(level) + + def _unfiltered_level_blocks(self, level): for block_offset in self._compute_level_block_offsets(level): block = Block( self.total_read_roi, @@ -641,6 +708,7 @@ def __add_task_dependency_graph(self, task): task.read_write_conflict, task.fit, total_read_roi=task.total_roi, + block_filter=getattr(task, "block_filter", None), ) def __enumerate_all_dependencies(self): diff --git a/daisy/task.py b/daisy/task.py index 0a668b51..be19d724 100644 --- a/daisy/task.py +++ b/daisy/task.py @@ -54,6 +54,25 @@ class Task: to check if the block needs to be run, and if so, the second one will be called after it was run to check if the run succeeded. + block_filter (function, optional): + + A function that will be called once per block at graph-construction + time, in the master process, before any worker is spawned:: + + block_filter(block) -> bool + + Return ``True`` to keep the block and ``False`` to drop it from the + graph entirely. Dropped blocks never count toward the task's block + total and are never offered to a worker. Use this to prune blocks + whose work can be decided up-front (e.g. blocks that fall outside + an inference mask), so that ``num_workers`` workers are only spawned + when there is actual work to do. + + This differs from ``check_function``: ``check_function`` runs + lazily, per block, each time a worker tries to acquire one, and + marks blocks as already-done; ``block_filter`` runs eagerly, once, + and removes blocks from the graph before scheduling begins. + init_callback_fn (function, optional): A function that Daisy will call once when the task is started. @@ -133,6 +152,7 @@ def __init__( write_roi, process_function, check_function=None, + block_filter=None, init_callback_fn=None, read_write_conflict=True, num_workers=1, @@ -152,6 +172,7 @@ def __init__( ) self.process_function = process_function self.check_function = check_function + self.block_filter = block_filter self.read_write_conflict = read_write_conflict self.fit = fit self.num_workers = num_workers diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py index b22cb095..dd67ad2c 100644 --- a/tests/test_scheduler.py +++ b/tests/test_scheduler.py @@ -542,3 +542,74 @@ def process_block(block): task_state.failed_count + task_state.orphaned_count == task_state.total_block_count ), task_state + + +def test_block_filter_prunes_blocks_eagerly(): + """A block_filter passed to Task should remove blocks from the graph at + construction time, before scheduling begins. The scheduler should never + see filtered-out blocks and total_block_count should reflect the + surviving set. + """ + # No filter: a 1D task over 6 blocks (block_ids 1..6 with write_roi + # offsets at 1, 2, 3, 4, 5, 6). + unfiltered = Task( + task_id="filter_off", + total_roi=Roi((0,), (8,)), + read_roi=Roi((0,), (3,)), + write_roi=Roi((1,), (1,)), + process_function=process_block, + check_function=None, + read_write_conflict=False, + ) + unfiltered_sched = Scheduler([unfiltered]) + unfiltered_total = unfiltered_sched.task_states[unfiltered.task_id].total_block_count + assert unfiltered_total == 6 + + # With filter: keep only blocks whose write_roi begins at an even offset. + def even_offset_only(block): + return block.write_roi.begin[0] % 2 == 0 + + filtered = Task( + task_id="filter_on", + total_roi=Roi((0,), (8,)), + read_roi=Roi((0,), (3,)), + write_roi=Roi((1,), (1,)), + process_function=process_block, + check_function=None, + block_filter=even_offset_only, + read_write_conflict=False, + ) + sched = Scheduler([filtered]) + assert ( + sched.task_states[filtered.task_id].total_block_count == 3 + ), "filter should leave 3 blocks (write_roi begins at 2, 4, 6)" + + seen_offsets = [] + while True: + block = sched.acquire_block(filtered.task_id) + if block is None: + break + seen_offsets.append(block.write_roi.begin[0]) + block.status = BlockStatus.SUCCESS + sched.release_block(block) + + assert sorted(seen_offsets) == [2, 4, 6], seen_offsets + + +def test_block_filter_all_dropped_yields_no_work(): + """A block_filter that rejects every block leaves the task with zero + blocks; acquire_block returns None immediately. + """ + task = Task( + task_id="filter_all_out", + total_roi=Roi((0,), (8,)), + read_roi=Roi((0,), (3,)), + write_roi=Roi((1,), (1,)), + process_function=process_block, + check_function=None, + block_filter=lambda b: False, + read_write_conflict=False, + ) + sched = Scheduler([task]) + assert sched.task_states[task.task_id].total_block_count == 0 + assert sched.acquire_block(task.task_id) is None