Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions bindings/distrdf/python/DistRDF/Backends/Dask/Backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,20 @@ def __init__(self, daskclient: Optional[Client] = None):
# N is the number of cores on the local machine.
self.client = (daskclient if daskclient is not None else
Client(LocalCluster(n_workers=os.cpu_count(), threads_per_worker=1, processes=True)))

workers = self.client.scheduler_info().get("workers", None)

if workers is None:
return

for worker in workers.values():
threads = worker.get("nthreads", 1)

if threads > 1:
raise RuntimeError(
"DistRDF with Dask does not support threaded workers. "
"Please use processes=True and threads_per_worker=1."
)

def optimize_npartitions(self) -> int:
"""
Expand Down
21 changes: 21 additions & 0 deletions roottest/python/distrdf/backends/check_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,27 @@ def test_optimize_npartitions(self, payload):
backend = Backend.SparkBackend(sparkcontext=connection)
assert backend.optimize_npartitions() == 2

def test_dask_backend_handles_missing_workers(self, payload):
"""
Check that DaskBackend initialization succeeds when scheduler_info
does not provide worker information.
"""
connection, backend = payload

if backend != "dask":
return

from ROOT._distrdf.Backends.Dask import Backend

original_scheduler_info = connection.scheduler_info

try:
connection.scheduler_info = lambda: {}
backend = Backend.DaskBackend(daskclient=connection)
assert backend.client is connection
finally:
connection.scheduler_info = original_scheduler_info


class TestInitialization:
"""Check initialization method in the Dask backend"""
Expand Down