Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 68 additions & 0 deletions daisy/dependency_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from .roi import Roi

import numpy as np
from tqdm import tqdm

from itertools import product
import logging
Expand Down Expand Up @@ -93,9 +94,16 @@ def __init__(
fit: str,
total_read_roi: Optional[Roi] = None,
total_write_roi: Optional[Roi] = None,
block_filter=None,
):
self.block_read_roi = block_read_roi
self.block_write_roi = block_write_roi
self.block_filter = block_filter
# When a block_filter is provided, the surviving blocks per level are
# materialized once during __init__ (see _apply_block_filter) so that
# subsequent num_blocks / level_blocks / num_roots calls see only the
# filtered set. Stored as a list[list[Block]] indexed by level.
self._filtered_blocks_by_level = None
self.read_write_context = (
block_write_roi.begin - block_read_roi.begin,
block_read_roi.end - block_write_roi.end,
Expand Down Expand Up @@ -144,12 +152,63 @@ def __init__(
self._level_offsets = self.compute_level_offsets()
self._level_conflicts = self.compute_level_conflicts()

# Eagerly prune blocks if a filter was supplied.
if self.block_filter is not None:
self._apply_block_filter()

def _apply_block_filter(self):
    """Run ``block_filter`` over every candidate block and cache the survivors.

    Each level's blocks are generated once via ``_unfiltered_level_blocks``;
    a block is kept when ``block_filter(block)`` returns a truthy value.
    The per-level lists of kept blocks are stored on
    ``_filtered_blocks_by_level`` (index = level), which is what
    ``num_blocks``, ``num_roots`` and ``level_blocks`` consult once it is
    no longer ``None``.

    Progress is reported through a tqdm bar, and the candidate and
    surviving counts are logged before and after the pass.
    """
    level_count = len(self._level_offsets)
    # The bar's total comes from the per-level counts reported by
    # ``_num_level_blocks`` so the block generators do not have to be
    # consumed just to size the progress bar.
    candidate_count = int(
        sum(self._num_level_blocks(lvl) for lvl in range(level_count))
    )

    logger.info(
        "Task %s: starting block_filter on %d candidate blocks across %d levels...",
        self.task_id,
        candidate_count,
        level_count,
    )

    survivors_by_level = []
    progress = tqdm(
        total=candidate_count,
        desc=f"block_filter({self.task_id})",
        unit="block",
    )
    with progress:
        for lvl in range(level_count):
            survivors = []
            for candidate in self._unfiltered_level_blocks(lvl):
                if self.block_filter(candidate):
                    survivors.append(candidate)
                # Advance once per candidate, whether it was kept or dropped.
                progress.update(1)
            survivors_by_level.append(survivors)

    self._filtered_blocks_by_level = survivors_by_level
    kept_count = sum(map(len, survivors_by_level))
    # Guard the percentage against a graph that had no candidates at all.
    if candidate_count:
        kept_percent = 100.0 * kept_count / candidate_count
    else:
        kept_percent = 0.0
    logger.info(
        "Task %s: block_filter kept %d / %d blocks (%.2f%%) — workers will only see these.",
        self.task_id,
        kept_count,
        candidate_count,
        kept_percent,
    )

@property
def num_levels(self):
    """Number of levels, i.e. the length of ``_level_offsets``."""
    offsets = self._level_offsets
    return len(offsets)

@property
def num_blocks(self):
if self._filtered_blocks_by_level is not None:
return sum(len(b) for b in self._filtered_blocks_by_level)
num_blocks = 0
for level in range(self.num_levels):
num_blocks += self._num_level_blocks(level)
Expand Down Expand Up @@ -177,6 +236,8 @@ def fit_block(self):
return fit_block

def num_roots(self):
    """Return how many blocks sit on level 0.

    When a filtered block set has been stored in
    ``_filtered_blocks_by_level``, the answer is the size of its level-0
    list; otherwise it is the count reported by ``_num_level_blocks(0)``.
    """
    filtered = self._filtered_blocks_by_level
    if filtered is None:
        return self._num_level_blocks(0)
    return len(filtered[0])
Comment on lines 238 to 241

def _num_level_blocks(self, level):
Expand Down Expand Up @@ -206,6 +267,12 @@ def _num_level_blocks(self, level):
return num_blocks

def level_blocks(self, level):
    """Yield the blocks belonging to ``level``.

    If a filtered block set has been stored in
    ``_filtered_blocks_by_level``, only that level's stored blocks are
    yielded; otherwise the blocks come from
    ``_unfiltered_level_blocks(level)``.
    """
    filtered = self._filtered_blocks_by_level
    if filtered is None:
        blocks = self._unfiltered_level_blocks(level)
    else:
        blocks = filtered[level]
    yield from blocks

def _unfiltered_level_blocks(self, level):
for block_offset in self._compute_level_block_offsets(level):
block = Block(
self.total_read_roi,
Expand Down Expand Up @@ -641,6 +708,7 @@ def __add_task_dependency_graph(self, task):
task.read_write_conflict,
task.fit,
total_read_roi=task.total_roi,
block_filter=getattr(task, "block_filter", None),
)

def __enumerate_all_dependencies(self):
Expand Down
21 changes: 21 additions & 0 deletions daisy/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,25 @@ class Task:
to check if the block needs to be run, and if so, the second one
will be called after it was run to check if the run succeeded.

block_filter (function, optional):

A function that will be called once per block at graph-construction
time, in the master process, before any worker is spawned::

block_filter(block) -> bool

Return ``True`` to keep the block and ``False`` to drop it from the
graph entirely. Dropped blocks never count toward the task's block
total and are never offered to a worker. Use this to prune blocks
whose work can be decided up-front (e.g. blocks that fall outside
an inference mask), so that ``num_workers`` workers are only spawned
when there is actual work to do.

This differs from ``check_function``: ``check_function`` runs
lazily, per block, each time a worker tries to acquire one, and
marks blocks as already-done; ``block_filter`` runs eagerly, once,
and removes blocks from the graph before scheduling begins.

init_callback_fn (function, optional):

A function that Daisy will call once when the task is started.
Expand Down Expand Up @@ -133,6 +152,7 @@ def __init__(
write_roi,
process_function,
check_function=None,
block_filter=None,
init_callback_fn=None,
Comment on lines 152 to 156
read_write_conflict=True,
num_workers=1,
Expand All @@ -152,6 +172,7 @@ def __init__(
)
self.process_function = process_function
self.check_function = check_function
self.block_filter = block_filter
self.read_write_conflict = read_write_conflict
self.fit = fit
self.num_workers = num_workers
Expand Down
71 changes: 71 additions & 0 deletions tests/test_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -542,3 +542,74 @@ def process_block(block):
task_state.failed_count + task_state.orphaned_count
== task_state.total_block_count
), task_state


def test_block_filter_prunes_blocks_eagerly():
    """Blocks rejected by ``block_filter`` must be gone before scheduling.

    Builds the same 1D task twice, once without and once with a filter,
    and checks that the filtered scheduler reports the reduced
    ``total_block_count`` and only ever hands out the surviving blocks.
    """

    def build_task(task_id, **extra):
        # Fresh Roi objects per task so the two tasks share no state.
        return Task(
            task_id=task_id,
            total_roi=Roi((0,), (8,)),
            read_roi=Roi((0,), (3,)),
            write_roi=Roi((1,), (1,)),
            process_function=process_block,
            check_function=None,
            read_write_conflict=False,
            **extra,
        )

    def even_offset_only(block):
        # Keep only blocks whose write_roi begins at an even offset.
        return block.write_roi.begin[0] % 2 == 0

    # Baseline without a filter: six blocks are expected.
    baseline_task = build_task("filter_off")
    baseline_sched = Scheduler([baseline_task])
    baseline_state = baseline_sched.task_states[baseline_task.task_id]
    assert baseline_state.total_block_count == 6

    # Same geometry with the filter attached.
    pruned_task = build_task("filter_on", block_filter=even_offset_only)
    pruned_sched = Scheduler([pruned_task])
    assert (
        pruned_sched.task_states[pruned_task.task_id].total_block_count == 3
    ), "filter should leave 3 blocks (write_roi begins at 2, 4, 6)"

    # Drain the scheduler, marking every acquired block as successful, and
    # record which write offsets were actually handed out.
    seen_offsets = []
    acquired = pruned_sched.acquire_block(pruned_task.task_id)
    while acquired is not None:
        seen_offsets.append(acquired.write_roi.begin[0])
        acquired.status = BlockStatus.SUCCESS
        pruned_sched.release_block(acquired)
        acquired = pruned_sched.acquire_block(pruned_task.task_id)

    assert sorted(seen_offsets) == [2, 4, 6], seen_offsets


def test_block_filter_all_dropped_yields_no_work():
    """A filter that rejects everything leaves nothing to schedule.

    The task must report a ``total_block_count`` of zero and the very
    first ``acquire_block`` call must return ``None``.
    """

    def reject_all(block):
        return False

    empty_task = Task(
        task_id="filter_all_out",
        total_roi=Roi((0,), (8,)),
        read_roi=Roi((0,), (3,)),
        write_roi=Roi((1,), (1,)),
        process_function=process_block,
        check_function=None,
        block_filter=reject_all,
        read_write_conflict=False,
    )
    scheduler = Scheduler([empty_task])
    task_state = scheduler.task_states[empty_task.task_id]
    assert task_state.total_block_count == 0
    assert scheduler.acquire_block(empty_task.task_id) is None
Loading