From 4f63fe01e337fbac29c4151aabd2d5a7b192b7de Mon Sep 17 00:00:00 2001 From: Marwan Zouinkhi Date: Sun, 7 Jun 2026 20:38:25 -0400 Subject: [PATCH 1/2] Add Task.block_filter for eager block pruning before workers spawn MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit A new optional `block_filter` parameter on `daisy.Task` lets callers drop blocks from the dependency graph at construction time, before any worker process is spawned and before the scheduler begins handing blocks out. This is distinct from `check_function`, which runs lazily per block as a worker tries to acquire one and marks the block as already-completed. `block_filter` runs once, in the master, and the filtered blocks never count toward `total_block_count` and are never offered to a worker. Motivation: large blockwise inference jobs over sparse volumes (e.g. restricted to a coarse inference mask) often have tens of millions of candidate blocks but only a small fraction of real work. Today `num_workers` workers are bsub-launched up-front regardless, then sit idle while the master walks the block grid; with `block_filter` the graph collapses to just the live blocks before the worker pool is brought up. Wiring: - `Task.__init__` accepts `block_filter: Optional[Callable[[Block], bool]]` - `DependencyGraph.__add_task_dependency_graph` forwards it to the inner `BlockwiseDependencyGraph` - When set, `BlockwiseDependencyGraph` materializes the surviving blocks per level once in `__init__`. `num_blocks`, `num_roots`, and `level_blocks` then read from the cached filtered set. The original lazy enumeration path is preserved as `_unfiltered_level_blocks` and used when no filter is supplied — no behavior change for existing callers. Tests in `tests/test_scheduler.py` cover the typical case (filter half the blocks, scheduler only ever returns the kept ones) and the zero-blocks-after-filter degenerate case. --- daisy/dependency_graph.py | 32 ++++++++++++++++++ daisy/task.py | 21 ++++++++++++ tests/test_scheduler.py | 71 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 124 insertions(+) diff --git a/daisy/dependency_graph.py b/daisy/dependency_graph.py index 9d38bdbf..0fa8ee5c 100644 --- a/daisy/dependency_graph.py +++ b/daisy/dependency_graph.py @@ -93,9 +93,16 @@ def __init__( fit: str, total_read_roi: Optional[Roi] = None, total_write_roi: Optional[Roi] = None, + block_filter=None, ): self.block_read_roi = block_read_roi self.block_write_roi = block_write_roi + self.block_filter = block_filter + # When a block_filter is provided, the surviving blocks per level are + # materialized once during __init__ (see _apply_block_filter) so that + # subsequent num_blocks / level_blocks / num_roots calls see only the + # filtered set. Stored as a list[list[Block]] indexed by level. + self._filtered_blocks_by_level = None self.read_write_context = ( block_write_roi.begin - block_read_roi.begin, block_read_roi.end - block_write_roi.end, @@ -144,12 +151,28 @@ def __init__( self._level_offsets = self.compute_level_offsets() self._level_conflicts = self.compute_level_conflicts() + # Eagerly prune blocks if a filter was supplied. + if self.block_filter is not None: + self._apply_block_filter() + + def _apply_block_filter(self): + """Materialize every level's blocks once and keep only those passing + ``block_filter``. After this, ``num_blocks``, ``num_roots``, and + ``level_blocks`` all reflect the surviving set. + """ + self._filtered_blocks_by_level = [ + [b for b in self._unfiltered_level_blocks(level) if self.block_filter(b)] + for level in range(len(self._level_offsets)) + ] + @property def num_levels(self): return len(self._level_offsets) @property def num_blocks(self): + if self._filtered_blocks_by_level is not None: + return sum(len(b) for b in self._filtered_blocks_by_level) num_blocks = 0 for level in range(self.num_levels): num_blocks += self._num_level_blocks(level) @@ -177,6 +200,8 @@ def fit_block(self): return fit_block def num_roots(self): + if self._filtered_blocks_by_level is not None: + return len(self._filtered_blocks_by_level[0]) return self._num_level_blocks(0) def _num_level_blocks(self, level): @@ -206,6 +231,12 @@ def _num_level_blocks(self, level): return num_blocks def level_blocks(self, level): + if self._filtered_blocks_by_level is not None: + yield from self._filtered_blocks_by_level[level] + return + yield from self._unfiltered_level_blocks(level) + + def _unfiltered_level_blocks(self, level): for block_offset in self._compute_level_block_offsets(level): block = Block( self.total_read_roi, @@ -641,6 +672,7 @@ def __add_task_dependency_graph(self, task): task.read_write_conflict, task.fit, total_read_roi=task.total_roi, + block_filter=getattr(task, "block_filter", None), ) def __enumerate_all_dependencies(self): diff --git a/daisy/task.py b/daisy/task.py index 0a668b51..be19d724 100644 --- a/daisy/task.py +++ b/daisy/task.py @@ -54,6 +54,25 @@ class Task: to check if the block needs to be run, and if so, the second one will be called after it was run to check if the run succeeded. + block_filter (function, optional): + + A function that will be called once per block at graph-construction + time, in the master process, before any worker is spawned:: + + block_filter(block) -> bool + + Return ``True`` to keep the block and ``False`` to drop it from the + graph entirely. Dropped blocks never count toward the task's block + total and are never offered to a worker. Use this to prune blocks + whose work can be decided up-front (e.g. blocks that fall outside + an inference mask), so that ``num_workers`` workers are only spawned + when there is actual work to do. + + This differs from ``check_function``: ``check_function`` runs + lazily, per block, each time a worker tries to acquire one, and + marks blocks as already-done; ``block_filter`` runs eagerly, once, + and removes blocks from the graph before scheduling begins. + init_callback_fn (function, optional): A function that Daisy will call once when the task is started. @@ -133,6 +152,7 @@ def __init__( write_roi, process_function, check_function=None, + block_filter=None, init_callback_fn=None, read_write_conflict=True, num_workers=1, @@ -152,6 +172,7 @@ def __init__( ) self.process_function = process_function self.check_function = check_function + self.block_filter = block_filter self.read_write_conflict = read_write_conflict self.fit = fit self.num_workers = num_workers diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py index b22cb095..dd67ad2c 100644 --- a/tests/test_scheduler.py +++ b/tests/test_scheduler.py @@ -542,3 +542,74 @@ def process_block(block): task_state.failed_count + task_state.orphaned_count == task_state.total_block_count ), task_state + + +def test_block_filter_prunes_blocks_eagerly(): + """A block_filter passed to Task should remove blocks from the graph at + construction time, before scheduling begins. The scheduler should never + see filtered-out blocks and total_block_count should reflect the + surviving set. + """ + # No filter: a 1D task over 6 blocks (block_ids 1..6 with write_roi + # offsets at 1, 2, 3, 4, 5, 6). + unfiltered = Task( + task_id="filter_off", + total_roi=Roi((0,), (8,)), + read_roi=Roi((0,), (3,)), + write_roi=Roi((1,), (1,)), + process_function=process_block, + check_function=None, + read_write_conflict=False, + ) + unfiltered_sched = Scheduler([unfiltered]) + unfiltered_total = unfiltered_sched.task_states[unfiltered.task_id].total_block_count + assert unfiltered_total == 6 + + # With filter: keep only blocks whose write_roi begins at an even offset. + def even_offset_only(block): + return block.write_roi.begin[0] % 2 == 0 + + filtered = Task( + task_id="filter_on", + total_roi=Roi((0,), (8,)), + read_roi=Roi((0,), (3,)), + write_roi=Roi((1,), (1,)), + process_function=process_block, + check_function=None, + block_filter=even_offset_only, + read_write_conflict=False, + ) + sched = Scheduler([filtered]) + assert ( + sched.task_states[filtered.task_id].total_block_count == 3 + ), "filter should leave 3 blocks (write_roi begins at 2, 4, 6)" + + seen_offsets = [] + while True: + block = sched.acquire_block(filtered.task_id) + if block is None: + break + seen_offsets.append(block.write_roi.begin[0]) + block.status = BlockStatus.SUCCESS + sched.release_block(block) + + assert sorted(seen_offsets) == [2, 4, 6], seen_offsets + + +def test_block_filter_all_dropped_yields_no_work(): + """A block_filter that rejects every block leaves the task with zero + blocks; acquire_block returns None immediately. + """ + task = Task( + task_id="filter_all_out", + total_roi=Roi((0,), (8,)), + read_roi=Roi((0,), (3,)), + write_roi=Roi((1,), (1,)), + process_function=process_block, + check_function=None, + block_filter=lambda b: False, + read_write_conflict=False, + ) + sched = Scheduler([task]) + assert sched.task_states[task.task_id].total_block_count == 0 + assert sched.acquire_block(task.task_id) is None From d1c7be84f6885e658f0e272f2532cc0350234054 Mon Sep 17 00:00:00 2001 From: Marwan Zouinkhi Date: Sun, 7 Jun 2026 20:51:09 -0400 Subject: [PATCH 2/2] block_filter: log start/end and show tqdm progress When `block_filter` is set, `_apply_block_filter` can take many seconds to minutes on large block grids (e.g. ~14M candidate blocks for a sparse-mask volumetric inference). Emit an INFO log at start, drive a tqdm progress bar across all levels, and log the surviving block count at the end so callers can tell whether the master is making progress or stuck. Per-level totals are computed analytically up-front so tqdm reports a real total without exhausting the underlying generator. --- daisy/dependency_graph.py | 44 +++++++++++++++++++++++++++++++++++---- 1 file changed, 40 insertions(+), 4 deletions(-) diff --git a/daisy/dependency_graph.py b/daisy/dependency_graph.py index 0fa8ee5c..6f23e1f0 100644 --- a/daisy/dependency_graph.py +++ b/daisy/dependency_graph.py @@ -4,6 +4,7 @@ from .roi import Roi import numpy as np +from tqdm import tqdm from itertools import product import logging @@ -159,11 +160,46 @@ def _apply_block_filter(self): """Materialize every level's blocks once and keep only those passing ``block_filter``. After this, ``num_blocks``, ``num_roots``, and ``level_blocks`` all reflect the surviving set. + + Walks all blocks across all levels with a tqdm progress bar so callers + can see how the eager filter is advancing. """ - self._filtered_blocks_by_level = [ - [b for b in self._unfiltered_level_blocks(level) if self.block_filter(b)] - for level in range(len(self._level_offsets)) - ] + num_levels = len(self._level_offsets) + # Analytical per-level counts let tqdm show a real total without + # exhausting the underlying generator. + per_level_totals = [self._num_level_blocks(level) for level in range(num_levels)] + total = int(sum(per_level_totals)) + + logger.info( + "Task %s: starting block_filter on %d candidate blocks across %d levels...", + self.task_id, + total, + num_levels, + ) + + kept = [] + with tqdm( + total=total, + desc=f"block_filter({self.task_id})", + unit="block", + ) as pbar: + for level in range(num_levels): + level_kept = [] + for b in self._unfiltered_level_blocks(level): + if self.block_filter(b): + level_kept.append(b) + pbar.update(1) + kept.append(level_kept) + + self._filtered_blocks_by_level = kept + surviving = sum(len(k) for k in kept) + logger.info( + "Task %s: block_filter kept %d / %d blocks (%.2f%%) — workers will only see these.", + self.task_id, + surviving, + total, + (100.0 * surviving / total) if total else 0.0, + ) @property def num_levels(self):