Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 17 additions & 18 deletions distributed/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5000,6 +5000,7 @@ def get_task_stream(
plot=False,
filename="task-stream.html",
bokeh_resources=None,
start_index=None,
):
"""Get task stream data from scheduler

Expand Down Expand Up @@ -5070,6 +5071,7 @@ def get_task_stream(
plot=plot,
filename=filename,
bokeh_resources=bokeh_resources,
start_index=start_index,
)

async def _get_task_stream(
Expand All @@ -5080,8 +5082,11 @@ async def _get_task_stream(
plot=False,
filename="task-stream.html",
bokeh_resources=None,
start_index=None,
):
msgs = await self.scheduler.get_task_stream(start=start, stop=stop, count=count)
msgs = await self.scheduler.get_task_stream(
start=start, stop=stop, count=count, start_index=start_index
)
if plot:
from distributed.diagnostics.task_stream import rectangles

Expand Down Expand Up @@ -6080,39 +6085,33 @@ def __init__(self, client=None, plot=False, filename="task-stream.html"):
self._filename = filename
self.figure = None
self.client = client or default_client()
self._init = False
self._start_index = None

def __enter__(self):
if not self._init:
self.client.get_task_stream(start=0, stop=0) # ensure plugin
self._init = True

# Smooth over time differences of client vs. workers
# FIXME this is very crude. We should query TaskStreamPlugin.index instead.
self.start = time() - 0.1
# Record the scheduler's task-stream cursor on entry and collect
# everything appended after it on exit. Using the monotonic index
# instead of a wall-clock boundary avoids dropping tasks when there is
# latency or clock skew between the client and the workers.
self._start_index = self.client.sync(
self.client.scheduler.get_task_stream_index
)
return self

def __exit__(self, exc_type, exc_value, traceback):
L = self.client.get_task_stream(
start=self.start, plot=self._plot, filename=self._filename
start_index=self._start_index, plot=self._plot, filename=self._filename
)
if self._plot:
L, self.figure = L
self.data.extend(L)

async def __aenter__(self):
if not self._init:
await self.client.get_task_stream(start=0, stop=0) # ensure plugin
self._init = True

# Smooth over time differences of client vs. workers
# FIXME this is very crude. We should query TaskStreamPlugin.index instead.
self.start = time() - 0.1
self._start_index = await self.client.scheduler.get_task_stream_index()
return self

async def __aexit__(self, exc_type, exc_value, traceback):
L = await self.client.get_task_stream(
start=self.start, plot=self._plot, filename=self._filename
start_index=self._start_index, plot=self._plot, filename=self._filename
)
if self._plot:
L, self.figure = L
Expand Down
12 changes: 11 additions & 1 deletion distributed/diagnostics/task_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,17 @@ def transition(self, key, start, finish, *args, **kwargs):
self.buffer.append(kwargs)
self.index += 1

def collect(self, start=None, stop=None, count=None):
def collect(self, start=None, stop=None, count=None, start_index=None):
# ``start_index`` selects records by their position in the monotonically
# increasing append counter (``self.index``) rather than by wall-clock
# time. This is immune to clock differences and latency between the
# client and the workers, which can otherwise cause time-based ``start``
# boundaries to drop tasks that have already completed.
if start_index is not None:
buffer_start = start_index - (self.index - len(self.buffer))
buffer_start = max(0, min(buffer_start, len(self.buffer)))
return [self.buffer[i] for i in range(buffer_start, len(self.buffer))]

def bisect(target, left, right):
while left != right:
mid = (left + right) // 2
Expand Down
41 changes: 41 additions & 0 deletions distributed/diagnostics/tests/test_task_stream.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

from collections import deque

import pytest
from tlz import frequencies

Expand Down Expand Up @@ -82,6 +84,45 @@ async def test_collect(c, s, a, b):
assert tasks.collect(start=start, count=3) == list(tasks.buffer)[:3]


@gen_cluster(client=True)
async def test_collect_start_index(c, s, a, b):
    """``collect(start_index=...)`` selects records by append position."""
    tasks = TaskStreamPlugin(s)
    s.add_plugin(tasks)

    futures = c.map(slowinc, range(5), delay=0.05)
    await wait(futures)
    midpoint = tasks.index

    futures = c.map(slowinc, range(5, 10), delay=0.05)
    await wait(futures)

    # ``start_index`` selects by append position, not wall-clock time, so it
    # returns exactly the records appended at or after the given index.
    assert len(tasks.collect(start_index=0)) == 10
    assert len(tasks.collect(start_index=midpoint)) == 5
    assert len(tasks.collect(start_index=tasks.index)) == 0

    # Counting alone would not notice the *wrong* records being returned
    # (e.g. the first five instead of the last five), so compare contents too.
    # Only ten records were appended, so nothing has been evicted and buffer
    # positions coincide with append indices.
    recorded = list(tasks.buffer)
    assert tasks.collect(start_index=0) == recorded
    assert tasks.collect(start_index=midpoint) == recorded[midpoint:]

    # Out-of-range cursors are clamped to the buffer instead of raising.
    assert tasks.collect(start_index=tasks.index + 100) == []
    assert tasks.collect(start_index=-1) == recorded


def test_collect_start_index_ignores_clock():
    """Index-based collection must not depend on recorded timestamps.

    A task whose recorded stop time lies before the client's ``start``
    boundary (worker clock behind the client clock, or plain latency) is
    dropped by time-based collection. That is the latency/clock-skew failure
    from the original bug report; the index-based path must still return it.
    """
    now = time()
    stale_record = {"key": "task", "startstops": [{"stop": now - 100}]}

    # Build a bare plugin holding exactly one already-finished record,
    # without going through ``__init__``.
    plugin = TaskStreamPlugin.__new__(TaskStreamPlugin)
    plugin.buffer = deque([stale_record])
    plugin.index = 1

    # Time-based collection misses the task because its stop time is in the past.
    by_time = plugin.collect(start=now)
    assert by_time == []

    # Index-based collection captures it regardless of the clock.
    by_index = plugin.collect(start_index=0)
    assert len(by_index) == 1


@gen_cluster(client=True)
async def test_client(c, s, a, b):
await c.get_task_stream()
Expand Down
27 changes: 21 additions & 6 deletions distributed/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4179,6 +4179,7 @@ async def post(self) -> None:
"heartbeat_worker": self.heartbeat_worker,
"get_task_status": self.get_task_status,
"get_task_stream": self.get_task_stream,
"get_task_stream_index": self.get_task_stream_index,
"get_task_prefix_states": self.get_task_prefix_states,
"register_scheduler_plugin": self.register_scheduler_plugin,
"unregister_scheduler_plugin": self.unregister_scheduler_plugin,
Expand Down Expand Up @@ -8075,20 +8076,34 @@ def get_task_status(self, keys: Iterable[Key]) -> dict[Key, TaskStateState | Non
key: (self.tasks[key].state if key in self.tasks else None) for key in keys
}

def _task_stream_plugin(self) -> TaskStreamPlugin:
    """Return the registered ``TaskStreamPlugin``, adding one if none exists."""
    from distributed.diagnostics.task_stream import TaskStreamPlugin

    plugin_name = TaskStreamPlugin.name
    already_registered = plugin_name in self.plugins
    if not already_registered:
        # First request for the task stream: register the plugin on demand.
        self.add_plugin(TaskStreamPlugin(self))

    plugin = self.plugins[plugin_name]
    return cast(TaskStreamPlugin, plugin)

def get_task_stream(
self,
start: str | float | None = None,
stop: str | float | None = None,
count: int | None = None,
start_index: int | None = None,
) -> list:
from distributed.diagnostics.task_stream import TaskStreamPlugin

if TaskStreamPlugin.name not in self.plugins:
self.add_plugin(TaskStreamPlugin(self))
plugin = self._task_stream_plugin()
return plugin.collect(
start=start, stop=stop, count=count, start_index=start_index
)

plugin = cast(TaskStreamPlugin, self.plugins[TaskStreamPlugin.name])
def get_task_stream_index(self) -> int:
"""Return the number of tasks recorded by the task stream so far.

return plugin.collect(start=start, stop=stop, count=count)
Used as an opaque cursor by the ``get_task_stream`` context manager so
that it can collect exactly the tasks that ran during the block without
relying on (latency- and clock-sensitive) wall-clock boundaries.
"""
return self._task_stream_plugin().index

def start_task_metadata(self, name: str) -> None:
plugin = CollectTaskMetaDataPlugin(scheduler=self, name=name)
Expand Down
Loading