diff --git a/bindings/distrdf/python/DistRDF/Backends/Dask/Backend.py b/bindings/distrdf/python/DistRDF/Backends/Dask/Backend.py
index 84a9a0bbd4abc..5e3da6ad5a836 100644
--- a/bindings/distrdf/python/DistRDF/Backends/Dask/Backend.py
+++ b/bindings/distrdf/python/DistRDF/Backends/Dask/Backend.py
@@ -104,6 +104,18 @@ def __init__(self, daskclient: Optional[Client] = None):
         # N is the number of cores on the local machine.
         self.client = (daskclient if daskclient is not None else
                        Client(LocalCluster(n_workers=os.cpu_count(), threads_per_worker=1, processes=True)))
+
+        # Refuse clients whose workers run more than one thread: DistRDF with
+        # Dask does not support threaded workers. A scheduler that reports no
+        # "workers" entry is accepted as-is, since there is nothing to
+        # validate; no early return, so code added below always runs.
+        workers = self.client.scheduler_info().get("workers") or {}
+
+        if any(worker.get("nthreads", 1) > 1 for worker in workers.values()):
+            raise RuntimeError(
+                "DistRDF with Dask does not support threaded workers. "
+                "Please use processes=True and threads_per_worker=1."
+            )
 
     def optimize_npartitions(self) -> int:
         """
diff --git a/roottest/python/distrdf/backends/check_backend.py b/roottest/python/distrdf/backends/check_backend.py
index b7ae271ca26ea..34e3fc8b737fd 100644
--- a/roottest/python/distrdf/backends/check_backend.py
+++ b/roottest/python/distrdf/backends/check_backend.py
@@ -59,6 +59,30 @@ def test_optimize_npartitions(self, payload):
             backend = Backend.SparkBackend(sparkcontext=connection)
             assert backend.optimize_npartitions() == 2
 
+    def test_dask_backend_handles_missing_workers(self, payload):
+        """
+        Check that DaskBackend initialization succeeds when scheduler_info
+        does not provide worker information.
+        """
+        connection, backend = payload
+
+        # This check only applies to the Dask backend.
+        if backend != "dask":
+            return
+
+        from ROOT._distrdf.Backends.Dask import Backend
+
+        original_scheduler_info = connection.scheduler_info
+
+        try:
+            # Simulate a scheduler that reports no "workers" entry at all.
+            connection.scheduler_info = lambda: {}
+            backend = Backend.DaskBackend(daskclient=connection)
+            assert backend.client is connection
+        finally:
+            # Always undo the monkeypatch on the client object.
+            connection.scheduler_info = original_scheduler_info
+
 
 class TestInitialization:
     """Check initialization method in the Dask backend"""