diff --git a/daisy/worker_pool.py b/daisy/worker_pool.py index ee65b8ba..d400dbbe 100644 --- a/daisy/worker_pool.py +++ b/daisy/worker_pool.py @@ -103,21 +103,32 @@ def check_for_errors(self): pass def reap_dead_workers(self): - """Detect worker processes that have exited and remove them from the - pool. Returns the number of workers reaped.""" + """Detect worker processes that have exited. Dead workers + (exitcode != 0) are removed from the pool and counted in the return + value. Workers that exited normally (exitcode == 0) are kept in the + pool with process=None so len(workers) stays at the target count, + preventing unwanted respawning.""" dead_worker_ids = [] with self.workers_lock: for worker_id, worker in self.workers.items(): if worker.process is not None and not worker.process.is_alive(): - logger.warning( - "Worker %s (pid %d) exited with code %d", - worker, - worker.process.pid, - worker.process.exitcode, - ) + exitcode = worker.process.exitcode + if exitcode == 0: + logger.info( + "Worker %s (pid %d) exited normally", + worker, + worker.process.pid, + ) + else: + logger.warning( + "Worker %s (pid %d) failed with exit code %d", + worker, + worker.process.pid, + exitcode, + ) + dead_worker_ids.append(worker_id) worker.process = None - dead_worker_ids.append(worker_id) for worker_id in dead_worker_ids: del self.workers[worker_id] diff --git a/tests/process_block_or_quit.py b/tests/process_block_or_quit.py new file mode 100644 index 00000000..cffeea1e --- /dev/null +++ b/tests/process_block_or_quit.py @@ -0,0 +1,43 @@ +"""Worker that exits early on first invocation, loops normally on subsequent. + +The first worker to start creates a marker file and processes only one block +before exiting normally (code 0). Subsequent workers see the marker and +process all blocks in a loop. This forces a normal-exit worker to leave +while blocks are still pending, triggering the reap-replace path. +""" + +import daisy + +import os +import sys +import time + +from filelock import FileLock + +tmp_path = sys.argv[1] +marker = os.path.join(tmp_path, "first_worker_done") +counter = os.path.join(tmp_path, "worker_count") +lock = os.path.join(tmp_path, "count.lock") + +# Atomically increment spawn counter and check if first worker +with FileLock(lock): + count = int(open(counter).read()) + open(counter, "w").write(str(count + 1)) + is_first = not os.path.exists(marker) + if is_first: + open(marker, "w").write("done") + +client = daisy.Client() + +if is_first: + # Process exactly one block then exit normally + with client.acquire_block() as block: + pass +else: + # Process all blocks, but slowly — give the reap cycle time to + # notice the first worker exited while blocks are still pending + while True: + with client.acquire_block() as block: + if block is None: + break + time.sleep(0.5) diff --git a/tests/test_dead_workers.py b/tests/test_dead_workers.py index 392cacff..e55cd44d 100644 --- a/tests/test_dead_workers.py +++ b/tests/test_dead_workers.py @@ -7,6 +7,8 @@ import daisy from daisy.logging import set_log_basedir +from daisy.worker_pool import WorkerPool +from unittest.mock import MagicMock import logging import os @@ -27,9 +29,14 @@ def test_dead_worker_replacement(tmp_path): set_log_basedir(tmp_path) def start_worker(): - subprocess.run( + result = subprocess.run( [sys.executable, "tests/process_block_or_die.py", str(tmp_path)] ) + # Propagate subprocess exit code so the daisy worker process also + # exits non-zero on crash (SystemExit bypasses _spawn_wrapper's + # except Exception, producing a non-zero exitcode for reaping). + if result.returncode != 0: + raise SystemExit(result.returncode) task = daisy.Task( "test_dead_worker_task", @@ -53,3 +60,35 @@ def start_worker(): assert os.path.exists(tmp_path / "worker_crashed"), ( "Expected first worker to crash and leave a marker file" ) + + +def test_reap_distinguishes_normal_from_crash(): + """reap_dead_workers only removes crashed workers from the pool. + + Workers that exit normally (exitcode 0) stay in the dict with + process=None so len(workers) stays at the target count. Only + crashed workers (exitcode != 0) are removed and counted. + """ + pool = WorkerPool(lambda: None) + + normal_worker = MagicMock() + normal_worker.process.is_alive.return_value = False + normal_worker.process.exitcode = 0 + normal_worker.process.pid = 1000 + + crashed_worker = MagicMock() + crashed_worker.process.is_alive.return_value = False + crashed_worker.process.exitcode = 1 + crashed_worker.process.pid = 1001 + + pool.workers = {0: normal_worker, 1: crashed_worker} + + reaped = pool.reap_dead_workers() + + # Only the crashed worker counts as reaped + assert reaped == 1 + # Normal worker stays in dict with process=None + assert 0 in pool.workers + assert pool.workers[0].process is None + # Crashed worker is removed + assert 1 not in pool.workers diff --git a/tests/test_worker_spawning.py b/tests/test_worker_spawning.py new file mode 100644 index 00000000..4625d0bb --- /dev/null +++ b/tests/test_worker_spawning.py @@ -0,0 +1,62 @@ +"""Test that workers exiting normally are not endlessly respawned. + +When workers finish processing their blocks and exit with code 0, the server +should not treat them as crashed and replace them. Without this fix, the +reap-replace cycle causes unbounded worker growth. +""" + +import daisy +from daisy.logging import set_log_basedir + +import logging +import subprocess +import sys + +logging.basicConfig(level=logging.DEBUG) + + +def test_normal_exit_no_respawn(tmp_path): + """Workers that exit normally are not replaced. + + Uses 4 blocks with 2 workers. The first worker to start processes + exactly one block then exits normally (code 0) via + process_block_or_quit.py, while blocks are still pending. The second + worker loops normally. + + Without the fix, the exited worker is reaped and replaced, spawning + a third worker. With the fix, the exited worker stays counted in the + pool and no replacement is spawned. + """ + set_log_basedir(tmp_path) + + counter = tmp_path / "worker_count" + counter.write_text("0") + + def start_worker(): + subprocess.run([ + sys.executable, "tests/process_block_or_quit.py", str(tmp_path) + ]) + + task = daisy.Task( + "test_no_respawn_task", + total_roi=daisy.Roi((0,), (40,)), + read_roi=daisy.Roi((0,), (10,)), + write_roi=daisy.Roi((0,), (10,)), + process_function=start_worker, + check_function=None, + read_write_conflict=False, + fit="valid", + num_workers=2, + max_retries=2, + timeout=None, + ) + + server = daisy.Server() + task_states = server.run_blockwise([task]) + assert task_states[task.task_id].is_done(), task_states[task.task_id] + + total_workers = int(counter.read_text()) + assert total_workers <= 2, ( + f"Expected at most 2 workers, but {total_workers} were spawned. " + "Normal worker exits are being treated as dead and replaced." + )