From 2d46d3faaec104e7a98c0ff4bb3bc3bf1cecb41c Mon Sep 17 00:00:00 2001 From: MohammadYusif Date: Thu, 28 May 2026 21:10:59 +0300 Subject: [PATCH] fix: use task-stream index instead of wall clock in get_task_stream context manager (#9253) The get_task_stream context manager bounded the collected tasks with a wall-clock timestamp (time() - 0.1). collect() then bisected the buffer by comparing that boundary against each task's recorded stop time. When there was latency or clock skew between the client and the workers, a task that finished inside the block could carry a stop time earlier than the client's start boundary and be silently dropped, so get_task_stream() returned no tasks. Record the scheduler's monotonic task-stream append index on entry and collect everything appended after it on exit. This removes the dependency on synchronized clocks entirely, as the maintainers' FIXME suggested. Adds a get_task_stream_index scheduler RPC and a start_index path through collect()/get_task_stream(), with tests covering the index semantics and the clock-skew regression. --- distributed/client.py | 35 ++++++++-------- distributed/diagnostics/task_stream.py | 12 +++++- .../diagnostics/tests/test_task_stream.py | 41 +++++++++++++++++++ distributed/scheduler.py | 27 +++++++++--- 4 files changed, 90 insertions(+), 25 deletions(-) diff --git a/distributed/client.py b/distributed/client.py index a07025e1c3..49f56bf22c 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -5000,6 +5000,7 @@ def get_task_stream( plot=False, filename="task-stream.html", bokeh_resources=None, + start_index=None, ): """Get task stream data from scheduler @@ -5070,6 +5071,7 @@ def get_task_stream( plot=plot, filename=filename, bokeh_resources=bokeh_resources, + start_index=start_index, ) async def _get_task_stream( @@ -5080,8 +5082,11 @@ async def _get_task_stream( plot=False, filename="task-stream.html", bokeh_resources=None, + start_index=None, ): - msgs = await self.scheduler.get_task_stream(start=start, stop=stop, count=count) + msgs = await self.scheduler.get_task_stream( + start=start, stop=stop, count=count, start_index=start_index + ) if plot: from distributed.diagnostics.task_stream import rectangles @@ -6080,39 +6085,33 @@ def __init__(self, client=None, plot=False, filename="task-stream.html"): self._filename = filename self.figure = None self.client = client or default_client() - self._init = False + self._start_index = None def __enter__(self): - if not self._init: - self.client.get_task_stream(start=0, stop=0) # ensure plugin - self._init = True - - # Smooth over time differences of client vs. workers - # FIXME this is very crude. We should query TaskStreamPlugin.index instead. - self.start = time() - 0.1 + # Record the scheduler's task-stream cursor on entry and collect + # everything appended after it on exit. Using the monotonic index + # instead of a wall-clock boundary avoids dropping tasks when there is + # latency or clock skew between the client and the workers. + self._start_index = self.client.sync( + self.client.scheduler.get_task_stream_index + ) return self def __exit__(self, exc_type, exc_value, traceback): L = self.client.get_task_stream( - start=self.start, plot=self._plot, filename=self._filename + start_index=self._start_index, plot=self._plot, filename=self._filename ) if self._plot: L, self.figure = L self.data.extend(L) async def __aenter__(self): - if not self._init: - await self.client.get_task_stream(start=0, stop=0) # ensure plugin - self._init = True - - # Smooth over time differences of client vs. workers - # FIXME this is very crude. We should query TaskStreamPlugin.index instead. - self.start = time() - 0.1 + self._start_index = await self.client.scheduler.get_task_stream_index() return self async def __aexit__(self, exc_type, exc_value, traceback): L = await self.client.get_task_stream( - start=self.start, plot=self._plot, filename=self._filename + start_index=self._start_index, plot=self._plot, filename=self._filename ) if self._plot: L, self.figure = L diff --git a/distributed/diagnostics/task_stream.py b/distributed/diagnostics/task_stream.py index 063238fc08..a2c3b434f0 100644 --- a/distributed/diagnostics/task_stream.py +++ b/distributed/diagnostics/task_stream.py @@ -37,7 +37,17 @@ def transition(self, key, start, finish, *args, **kwargs): self.buffer.append(kwargs) self.index += 1 - def collect(self, start=None, stop=None, count=None): + def collect(self, start=None, stop=None, count=None, start_index=None): + # ``start_index`` selects records by their position in the monotonically + # increasing append counter (``self.index``) rather than by wall-clock + # time. This is immune to clock differences and latency between the + # client and the workers, which can otherwise cause time-based ``start`` + # boundaries to drop tasks that have already completed. + if start_index is not None: + buffer_start = start_index - (self.index - len(self.buffer)) + buffer_start = max(0, min(buffer_start, len(self.buffer))) + return [self.buffer[i] for i in range(buffer_start, len(self.buffer))] + def bisect(target, left, right): while left != right: mid = (left + right) // 2 diff --git a/distributed/diagnostics/tests/test_task_stream.py b/distributed/diagnostics/tests/test_task_stream.py index e7daab7d3f..5d368b4750 100644 --- a/distributed/diagnostics/tests/test_task_stream.py +++ b/distributed/diagnostics/tests/test_task_stream.py @@ -1,5 +1,7 @@ from __future__ import annotations +from collections import deque + import pytest from tlz import frequencies @@ -82,6 +84,45 @@ async def test_collect(c, s, a, b): assert tasks.collect(start=start, count=3) == list(tasks.buffer)[:3] +@gen_cluster(client=True) +async def test_collect_start_index(c, s, a, b): + tasks = TaskStreamPlugin(s) + s.add_plugin(tasks) + + futures = c.map(slowinc, range(5), delay=0.05) + await wait(futures) + midpoint = tasks.index + + futures = c.map(slowinc, range(5, 10), delay=0.05) + await wait(futures) + + # ``start_index`` selects by append position, not wall-clock time, so it + # returns exactly the records appended at or after the given index. + assert len(tasks.collect(start_index=0)) == 10 + assert len(tasks.collect(start_index=midpoint)) == 5 + assert len(tasks.collect(start_index=tasks.index)) == 0 + + +def test_collect_start_index_ignores_clock(): + # When the worker clock lags the client clock (or there is latency), a task + # can finish with a recorded stop time that is earlier than the client's + # ``start`` boundary. The time-based collection then drops the task, which + # is the latency/clock-skew failure from the original bug report. The + # index-based path must still return it. + plugin = TaskStreamPlugin.__new__(TaskStreamPlugin) + plugin.buffer = deque() + plugin.index = 0 + + now = time() + plugin.buffer.append({"key": "task", "startstops": [{"stop": now - 100}]}) + plugin.index += 1 + + # Time-based collection misses the task because its stop time is in the past. + assert plugin.collect(start=now) == [] + # Index-based collection captures it regardless of the clock. + assert len(plugin.collect(start_index=0)) == 1 + + @gen_cluster(client=True) async def test_client(c, s, a, b): await c.get_task_stream() diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 445ffd96e8..3c7f9ac2c4 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -4179,6 +4179,7 @@ async def post(self) -> None: "heartbeat_worker": self.heartbeat_worker, "get_task_status": self.get_task_status, "get_task_stream": self.get_task_stream, + "get_task_stream_index": self.get_task_stream_index, "get_task_prefix_states": self.get_task_prefix_states, "register_scheduler_plugin": self.register_scheduler_plugin, "unregister_scheduler_plugin": self.unregister_scheduler_plugin, @@ -8075,20 +8076,34 @@ def get_task_status(self, keys: Iterable[Key]) -> dict[Key, TaskStateState | Non key: (self.tasks[key].state if key in self.tasks else None) for key in keys } + def _task_stream_plugin(self) -> TaskStreamPlugin: + from distributed.diagnostics.task_stream import TaskStreamPlugin + + if TaskStreamPlugin.name not in self.plugins: + self.add_plugin(TaskStreamPlugin(self)) + + return cast(TaskStreamPlugin, self.plugins[TaskStreamPlugin.name]) + def get_task_stream( self, start: str | float | None = None, stop: str | float | None = None, count: int | None = None, + start_index: int | None = None, ) -> list: - from distributed.diagnostics.task_stream import TaskStreamPlugin - - if TaskStreamPlugin.name not in self.plugins: - self.add_plugin(TaskStreamPlugin(self)) + plugin = self._task_stream_plugin() + return plugin.collect( + start=start, stop=stop, count=count, start_index=start_index + ) - plugin = cast(TaskStreamPlugin, self.plugins[TaskStreamPlugin.name]) + def get_task_stream_index(self) -> int: + """Return the number of tasks recorded by the task stream so far. - return plugin.collect(start=start, stop=stop, count=count) + Used as an opaque cursor by the ``get_task_stream`` context manager so + that it can collect exactly the tasks that ran during the block without + relying on (latency- and clock-sensitive) wall-clock boundaries. + """ + return self._task_stream_plugin().index def start_task_metadata(self, name: str) -> None: plugin = CollectTaskMetaDataPlugin(scheduler=self, name=name)