From 034774b3920002cf3cfb2b9044990730f4b923d6 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Thu, 4 Jun 2026 16:27:32 -0500 Subject: [PATCH 1/9] Add Dask expression protocol support --- xarray/core/dataarray.py | 7 + xarray/core/dataset.py | 59 ++++- xarray/tests/test_dask_expr_protocol.py | 282 ++++++++++++++++++++++++ 3 files changed, 345 insertions(+), 3 deletions(-) create mode 100644 xarray/tests/test_dask_expr_protocol.py diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index d0df9bc061b..bdba5cb556a 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -1113,6 +1113,13 @@ def __dask_graph__(self): def __dask_keys__(self): return self._to_temp_dataset().__dask_keys__() + def __dask_exprs__(self): + return self._to_temp_dataset().__dask_exprs__() + + def __dask_rebuild_from_exprs__(self, exprs): + ds = self._to_temp_dataset().__dask_rebuild_from_exprs__(exprs) + return self._from_temp_dataset(ds) + def __dask_layers__(self): return self._to_temp_dataset().__dask_layers__() diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 03a0b594a13..d5c4f92f9ed 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -651,11 +651,17 @@ def __dask_graph__(self): try: from dask.highlevelgraph import HighLevelGraph - return HighLevelGraph.merge(*graphs.values()) + if all(isinstance(graph, HighLevelGraph) for graph in graphs.values()): + return HighLevelGraph.merge(*graphs.values()) except ImportError: - from dask import sharedict + pass - return sharedict.merge(*graphs.values()) + from dask.utils import ensure_dict + + merged = {} + for graph in graphs.values(): + merged.update(ensure_dict(graph)) + return merged def __dask_keys__(self): import dask @@ -666,6 +672,53 @@ def __dask_keys__(self): if dask.is_dask_collection(v) ] + def __dask_exprs__(self): + import dask + + exprs = [] + for v in self.variables.values(): + if dask.is_dask_collection(v): + if not hasattr(v._data, "expr"): + return None + exprs.append(v._data.expr) + return exprs or None + + def __dask_rebuild_from_exprs__(self, exprs): + import dask + from dask._collections import new_collection + + exprs_iter = iter(exprs) + variables = {} + + for k, v in self._variables.items(): + if dask.is_dask_collection(v): + try: + expr = next(exprs_iter) + except StopIteration as err: + raise ValueError( + "Not enough expressions to rebuild Dataset" + ) from err + variables[k] = v._replace(data=new_collection(expr)) + else: + variables[k] = v + + try: + next(exprs_iter) + except StopIteration: + pass + else: + raise ValueError("Too many expressions to rebuild Dataset") + + return type(self)._construct_direct( + variables, + self._coord_names, + self._dims, + self._attrs, + self._indexes, + self._encoding, + self._close, + ) + def __dask_layers__(self): import dask diff --git a/xarray/tests/test_dask_expr_protocol.py b/xarray/tests/test_dask_expr_protocol.py new file mode 100644 index 00000000000..d89434f2c0d --- /dev/null +++ b/xarray/tests/test_dask_expr_protocol.py @@ -0,0 +1,282 @@ +from __future__ import annotations + +import numpy as np +import pytest + +import xarray as xr +from xarray import DataArray, Dataset +from xarray.testing import assert_equal, assert_identical + +dask = pytest.importorskip("dask") +da = pytest.importorskip("dask.array") + + +def test_standalone_dask_array_dataset_composite_expr_protocol(): + dask_array = pytest.importorskip("dask_array") + from dask._expr import CompositeExpr + + x = dask_array.arange(6, chunks=(3,)) + ds = Dataset( + {"foo": ("i", x + 1)}, + coords={"coord": ("i", x + 10)}, + attrs={"source": "test"}, + ) + ds["foo"].encoding["example"] = "kept" + + expr = dask.base.collections_to_expr(ds) + + assert isinstance(expr, CompositeExpr) + assert len(expr.exprs) == 2 + + expected = Dataset( + {"foo": ("i", np.arange(6) + 1)}, + coords={"coord": ("i", np.arange(6) + 10)}, + attrs={"source": "test"}, + ) + computed = dask.compute(ds, scheduler="single-threaded")[0] + assert_identical(computed, expected) + assert computed["foo"].encoding["example"] == "kept" + + persisted = dask.persist(ds, scheduler="single-threaded")[0] + assert hasattr(persisted["foo"].data, "expr") + assert_identical(persisted.compute(), expected) + assert persisted["foo"].encoding["example"] == "kept" + + optimized = dask.optimize(ds)[0] + assert hasattr(optimized["foo"].data, "expr") + assert_identical(optimized.compute(), expected) + + +def test_standalone_dask_array_dataset_variable_named_expr(): + dask_array = pytest.importorskip("dask_array") + from dask._expr import CompositeExpr + + x = dask_array.arange(6, chunks=(3,)) + ds = Dataset({"expr": ("i", x + 1)}) + + expr = dask.base.collections_to_expr(ds) + + assert isinstance(expr, CompositeExpr) + assert_identical( + dask.compute(ds, scheduler="single-threaded")[0], + Dataset({"expr": ("i", np.arange(6) + 1)}), + ) + + +def test_standalone_dask_array_dataarray_composite_expr_protocol(): + dask_array = pytest.importorskip("dask_array") + from dask._expr import CompositeExpr + + x = dask_array.arange(6, chunks=(3,)) + arr = DataArray(x + 1, dims=("i",), coords={"coord": ("i", x)}, name="z") + + expr = dask.base.collections_to_expr(arr) + + assert isinstance(expr, CompositeExpr) + assert len(expr.exprs) == 2 + + expected = DataArray( + np.arange(6) + 1, + dims=("i",), + coords={"coord": ("i", np.arange(6))}, + name="z", + ) + computed = dask.compute(arr, scheduler="single-threaded")[0] + assert_identical(computed, expected) + + persisted = dask.persist(arr, scheduler="single-threaded")[0] + assert hasattr(persisted.data, "expr") + assert_identical(persisted.compute(), expected) + + +def test_standalone_dask_array_dataset_computes_with_raw_array(): + dask_array = pytest.importorskip("dask_array") + + x = dask_array.arange(6, chunks=(3,)) + ds = Dataset({"foo": ("i", x + 1)}) + + computed_ds, computed_x = dask.compute(ds, x, scheduler="single-threaded") + + assert_identical(computed_ds, Dataset({"foo": ("i", np.arange(6) + 1)})) + np.testing.assert_array_equal(computed_x, np.arange(6)) + + +def test_standalone_dask_array_optimize_culls_child_graphs(): + dask_array = pytest.importorskip("dask_array") + from dask.core import flatten + + x = dask_array.arange(6, chunks=(3,)) + ds = Dataset({"foo": ("i", x + 1), "bar": ("i", x + 2)}) + + optimized = dask.optimize(ds)[0] + + foo_graph_keys = set(optimized["foo"].data.__dask_graph__()) + bar_output_keys = set(flatten(optimized["bar"].data.__dask_keys__())) + assert not foo_graph_keys & bar_output_keys + assert_identical( + optimized.compute(), + Dataset({"foo": ("i", np.arange(6) + 1), "bar": ("i", np.arange(6) + 2)}), + ) + + +def test_standalone_dask_array_mixed_legacy_falls_back_from_composite_expr(): + dask_array = pytest.importorskip("dask_array") + + ds = Dataset( + { + "expr": ("i", dask_array.arange(3, chunks=(1,))), + "legacy": ("i", da.arange(3, chunks=(1,))), + } + ) + + assert ds.__dask_exprs__() is None + + +def test_standalone_dask_array_mixed_legacy_map_blocks_fallback_computes(): + dask_array = pytest.importorskip("dask_array") + + ds = Dataset( + { + "expr": ("i", dask_array.arange(6, chunks=(3,))), + "legacy": ("i", da.arange(6, chunks=(3,))), + } + ) + + out = xr.map_blocks(lambda block: block + 1, ds) + + expected = Dataset( + {"expr": ("i", np.arange(6) + 1), "legacy": ("i", np.arange(6) + 1)} + ) + assert_identical(out.compute(scheduler="single-threaded"), expected) + + +def test_standalone_dask_array_open_mfdataset_uses_expressions(tmp_path): + pytest.importorskip("dask_array") + from dask._expr import CompositeExpr + + paths = [] + for i in range(2): + path = tmp_path / f"part-{i}.nc" + Dataset( + {"x": ("t", np.arange(i * 3, i * 3 + 3))}, + coords={"t": np.arange(i * 3, i * 3 + 3)}, + ).to_netcdf(path) + paths.append(path) + + ds = xr.open_mfdataset(paths, chunks={"t": 2}, combine="by_coords") + try: + expr = dask.base.collections_to_expr(ds) + + assert type(ds["x"].data).__module__.startswith("dask_array") + assert isinstance(expr, CompositeExpr) + + expected = Dataset({"x": ("t", np.arange(6))}, coords={"t": np.arange(6)}) + assert_identical(ds.compute(scheduler="single-threaded"), expected) + assert_identical(dask.compute(ds, scheduler="single-threaded")[0], expected) + assert_identical( + ds.persist(scheduler="single-threaded").compute( + scheduler="single-threaded" + ), + expected, + ) + assert_identical( + dask.optimize(ds)[0].compute(scheduler="single-threaded"), expected + ) + finally: + ds.close() + + +def test_standalone_dask_array_shared_subexpressions_and_chunked_coords(): + dask_array = pytest.importorskip("dask_array") + + x = dask_array.arange(6, chunks=(3,)) + ds = Dataset( + {"a": ("t", x + 1), "b": ("t", x * 2)}, + coords={"qc": ("t", x + 10)}, + attrs={"case": "shared"}, + ) + + expected = Dataset( + {"a": ("t", np.arange(6) + 1), "b": ("t", np.arange(6) * 2)}, + coords={"qc": ("t", np.arange(6) + 10)}, + attrs={"case": "shared"}, + ) + assert len(dask.base.collections_to_expr(ds).exprs) == 3 + assert_identical(ds.compute(scheduler="single-threaded"), expected) + assert_identical( + dask.persist(ds, scheduler="single-threaded")[0].compute( + scheduler="single-threaded" + ), + expected, + ) + assert_identical( + dask.optimize(ds)[0].compute(scheduler="single-threaded"), expected + ) + + +def test_standalone_dask_array_rechunk_reduction_chain(): + dask_array = pytest.importorskip("dask_array") + from dask._expr import CompositeExpr + + x = dask_array.arange(12, chunks=(4,)).reshape((3, 4)) + out = Dataset({"x": (("a", "b"), x)}).chunk({"a": 1, "b": 2}).sum("b") + 1 + + assert isinstance(dask.base.collections_to_expr(out), CompositeExpr) + expected = Dataset({"x": ("a", np.arange(12).reshape(3, 4).sum(axis=1) + 1)}) + assert_identical(out.compute(scheduler="single-threaded"), expected) + assert_identical( + dask.optimize(out)[0].compute(scheduler="single-threaded"), expected + ) + + +def test_standalone_dask_array_groupby_sum(): + dask_array = pytest.importorskip("dask_array") + from dask._expr import CompositeExpr + + x = dask_array.arange(6, chunks=(3,)) + arr = DataArray( + x, + dims="t", + coords={"label": ("t", np.array(["a", "a", "b", "b", "a", "b"]))}, + name="x", + ) + out = arr.groupby("label").sum() + + assert isinstance(dask.base.collections_to_expr(out), CompositeExpr) + expected = DataArray( + [5, 10], + dims="label", + coords={"label": np.array(["a", "b"], dtype=object)}, + name="x", + ) + assert_equal(out.compute(scheduler="single-threaded"), expected) + assert_equal(dask.optimize(out)[0].compute(scheduler="single-threaded"), expected) + + +def test_standalone_dask_array_map_blocks_fallback_computes(): + dask_array = pytest.importorskip("dask_array") + + x = dask_array.arange(6, chunks=(3,)) + arr = DataArray(x, dims="t", name="x") + out = xr.map_blocks(lambda block: block + 1, arr) + + expected = DataArray(np.arange(6) + 1, dims="t", name="x") + assert_identical(out.compute(scheduler="single-threaded"), expected) + assert_identical( + dask.optimize(out)[0].compute(scheduler="single-threaded"), expected + ) + + +def test_standalone_dask_array_apply_ufunc_parallelized(): + dask_array = pytest.importorskip("dask_array") + from dask._expr import CompositeExpr + + x = dask_array.arange(6, chunks=(3,)) + arr = DataArray(x, dims="t", name="x") + out = xr.apply_ufunc(lambda z: z + 2, arr, dask="parallelized", output_dtypes=[int]) + + assert isinstance(dask.base.collections_to_expr(out), CompositeExpr) + assert_identical( + out.compute(scheduler="single-threaded"), + DataArray(np.arange(6) + 2, dims="t", name="x"), + ) From 59039a0716b94d7344ed5b7370e474db3035634d Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Thu, 4 Jun 2026 17:27:05 -0500 Subject: [PATCH 2/9] Use dask_array expressions for map_blocks --- xarray/core/parallel.py | 241 +++++++++++++++++++++++- xarray/namedarray/parallelcompat.py | 15 +- xarray/tests/test_dask_expr_protocol.py | 112 ++++++++++- 3 files changed, 357 insertions(+), 11 deletions(-) diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index a718ef5e911..82e97e652de 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -149,6 +149,37 @@ def make_dict(x: DataArray | Dataset) -> dict[Hashable, Any]: return {k: v.data for k, v in x.variables.items()} +def _execute_map_blocks_multi_output(block_spec, *blocks): + args = [] + for arg_spec in block_spec["args"]: + if arg_spec[0] == "literal": + args.append(arg_spec[1]) + continue + + _, data_vars, coords, attrs = arg_spec + + def build_variables(variable_specs): + variables = [] + for name, dims, data, var_attrs in variable_specs: + if data[0] == "block": + data = blocks[data[1]] + else: + data = data[1] + variables.append((name, (dims, data, var_attrs))) + return dict(variables) + + args.append(Dataset(build_variables(data_vars), build_variables(coords), attrs)) + + return block_spec["wrapper"]( + block_spec["func"], + args, + block_spec["kwargs"], + block_spec["arg_is_array"], + block_spec["expected"], + block_spec["expected_indexes"], + ) + + def _get_chunk_slicer(dim: Hashable, chunk_index: Mapping, chunk_bounds: Mapping): if dim in chunk_index: which_chunk = chunk_index[dim] @@ -440,6 +471,31 @@ def _wrapper( dataarray_to_dataset(arg) if isinstance(arg, DataArray) else arg for arg in aligned ) + + try: + from dask_array import Array as DaskArrayExprArray + except ImportError: + DaskArrayExprArray = () + + chunked_data = [ + variable.data + for arg in xarray_objs + for variable in arg.variables.values() + if is_dask_collection(variable.data) + ] + has_dask_array_expr = any( + isinstance(data, DaskArrayExprArray) for data in chunked_data + ) + has_other_chunked = any( + not isinstance(data, DaskArrayExprArray) for data in chunked_data + ) + + if has_dask_array_expr and has_other_chunked: + raise TypeError( + "xarray.map_blocks cannot mix dask_array.Array with legacy or other " + "Dask-backed arrays. Convert inputs to one array backend first." + ) + # rechunk any numpy variables appropriately xarray_objs = tuple(arg.chunk(arg.chunksizes) for arg in xarray_objs) @@ -546,7 +602,190 @@ def _wrapper( dim: np.cumsum((0,) + chunks_v) for dim, chunks_v in output_chunks.items() } - computed_variables = set(template.variables) - set(coordinates.indexes) + computed_variables = [ + name for name in template.variables if name not in coordinates.indexes + ] + + chunked_data = [ + variable.data + for arg in xarray_objs + for variable in arg.variables.values() + if is_dask_collection(variable.data) + ] + + if has_dask_array_expr: + missing_chunked_dims = { + dim + for dim, chunks in input_chunks.items() + if len(chunks) > 1 and dim not in output_chunks + } + if missing_chunked_dims: + raise NotImplementedError( + "dask_array-backed xarray.map_blocks does not yet support " + "dropping multi-chunk dimensions. Rechunk these dimensions to " + f"one chunk first: {sorted(missing_chunked_dims)!r}." + ) + + from xarray.namedarray.parallelcompat import get_chunked_array_type + + chunkmanager = get_chunked_array_type(*chunked_data) + + input_exprs = [] + input_indices = [] + arg_templates = [] + for isxr, arg in zip(is_xarray, npargs, strict=True): + if not isxr: + if is_dask_collection(arg): + raise TypeError( + "dask_array-backed xarray.map_blocks only supports Dask " + "collections inside xarray arguments." + ) + arg_templates.append(("literal", arg)) + continue + + variable_templates = [] + for name, variable in arg.variables.items(): + is_coord = name in arg._coord_names + if is_dask_collection(variable.data): + input_exprs.append(variable.data.expr) + input_indices.append(variable.dims) + variable_templates.append( + ( + "chunked", + name, + variable.dims, + variable.attrs, + len(input_exprs) - 1, + is_coord, + None, + ) + ) + else: + variable_templates.append( + ( + "static", + name, + variable.dims, + variable.attrs, + None, + is_coord, + variable, + ) + ) + arg_templates.append(("xarray", arg.attrs, variable_templates)) + + def build_block_specs(): + specs = {} + for chunk_tuple in itertools.product(*ichunk.values()): + chunk_index = dict(zip(ichunk.keys(), chunk_tuple, strict=True)) + arg_specs = [] + for arg_template in arg_templates: + if arg_template[0] == "literal": + arg_specs.append(arg_template) + continue + + _, attrs, variable_templates = arg_template + data_vars = [] + coords = [] + for ( + kind, + name, + dims, + attrs, + input_position, + is_coord, + variable, + ) in variable_templates: + if kind == "chunked": + data = ("block", input_position) + else: + assert variable is not None + subsetter = { + dim: _get_chunk_slicer( + dim, chunk_index, input_chunk_bounds + ) + for dim in variable.dims + } + data = ("static", variable.isel(subsetter)._data) + + target = coords if is_coord else data_vars + target.append((name, dims, data, attrs)) + + arg_specs.append(("xarray", data_vars, coords, attrs)) + + indexes = { + dim: coordinates.xindexes[dim][ + _get_chunk_slicer(dim, chunk_index, output_chunk_bounds) + ] + for dim in (new_indexes | modified_indexes) + } + expected: ExpectedDict = { + "shapes": { + k: output_chunks[k][v] + for k, v in chunk_index.items() + if k in output_chunks + }, + "data_vars": set(template.data_vars.keys()), + "coords": set(template.coords.keys()), + } + specs[chunk_tuple] = { + "wrapper": _wrapper, + "func": func, + "args": arg_specs, + "kwargs": kwargs, + "arg_is_array": is_array, + "expected": expected, + "expected_indexes": indexes, + } + return specs + + outputs = [] + for name in computed_variables: + variable = template.variables[name] + var_chunks = [] + for dim in variable.dims: + if dim in output_chunks: + var_chunks.append(output_chunks[dim]) + elif dim in template.dims: + var_chunks.append((template.sizes[dim],)) + + outputs.append( + { + "key": name, + "indices": variable.dims, + "chunks": tuple(var_chunks), + "dtype": variable.dtype, + "name": f"{name}-{gname}", + } + ) + + mapped_arrays = chunkmanager.map_blocks_multi_output( + _execute_map_blocks_multi_output, + input_exprs, + input_indices, + tuple(input_chunks), + build_block_specs(), + outputs, + token=gname, + ) + + result = Dataset(coords=coordinates, attrs=template.attrs) + for index in result._indexes: + result[index].attrs = template[index].attrs + result[index].encoding = template[index].encoding + + for name, data in zip(computed_variables, mapped_arrays, strict=True): + result[name] = (template[name].dims, data, template[name].attrs) + result[name].encoding = template[name].encoding + + result = result.set_coords(template._coord_names) + + if result_is_array: + da = dataset_to_dataarray(result) + da.name = template_name + return da # type: ignore[return-value] + return result # type: ignore[return-value] + # iterate over all possible chunk combinations for chunk_tuple in itertools.product(*ichunk.values()): # mapping from dimension name to chunk index diff --git a/xarray/namedarray/parallelcompat.py b/xarray/namedarray/parallelcompat.py index 8a68f5e9562..334e464a1fc 100644 --- a/xarray/namedarray/parallelcompat.py +++ b/xarray/namedarray/parallelcompat.py @@ -8,7 +8,7 @@ import functools from abc import ABC, abstractmethod -from collections.abc import Callable, Iterable, Sequence +from collections.abc import Callable, Iterable, Mapping, Sequence from importlib.metadata import EntryPoint, entry_points from typing import TYPE_CHECKING, Any, Generic, Protocol, TypeVar @@ -655,6 +655,19 @@ def map_blocks( """ raise NotImplementedError() + def map_blocks_multi_output( + self, + func: Callable[..., Any], + input_exprs: Sequence[Any], + input_indices: Sequence[Iterable[Any]], + shared_indices: Iterable[Any], + block_specs: Mapping[tuple[int, ...], Any], + outputs: Sequence[Mapping[str, Any]], + *, + token: str, + ) -> list[Any]: + raise NotImplementedError() + def blockwise( self, func: Callable[..., Any], diff --git a/xarray/tests/test_dask_expr_protocol.py b/xarray/tests/test_dask_expr_protocol.py index d89434f2c0d..f28d98bd62c 100644 --- a/xarray/tests/test_dask_expr_protocol.py +++ b/xarray/tests/test_dask_expr_protocol.py @@ -132,7 +132,7 @@ def test_standalone_dask_array_mixed_legacy_falls_back_from_composite_expr(): assert ds.__dask_exprs__() is None -def test_standalone_dask_array_mixed_legacy_map_blocks_fallback_computes(): +def test_standalone_dask_array_mixed_legacy_map_blocks_raises(): dask_array = pytest.importorskip("dask_array") ds = Dataset( @@ -142,16 +142,22 @@ def test_standalone_dask_array_mixed_legacy_map_blocks_fallback_computes(): } ) - out = xr.map_blocks(lambda block: block + 1, ds) + with pytest.raises(TypeError, match=r"cannot mix dask_array\.Array"): + xr.map_blocks(lambda block: block + 1, ds) - expected = Dataset( - {"expr": ("i", np.arange(6) + 1), "legacy": ("i", np.arange(6) + 1)} - ) - assert_identical(out.compute(scheduler="single-threaded"), expected) + +def test_standalone_dask_array_mixed_legacy_map_blocks_arg_raises(): + dask_array = pytest.importorskip("dask_array") + + arr = DataArray(dask_array.arange(6, chunks=(3,)), dims="i") + other = DataArray(da.arange(6, chunks=(3,)), dims="i") + + with pytest.raises(TypeError, match=r"cannot mix dask_array\.Array"): + xr.map_blocks(lambda a, b: a + b, arr, args=[other]) def test_standalone_dask_array_open_mfdataset_uses_expressions(tmp_path): - pytest.importorskip("dask_array") + dask_array = pytest.importorskip("dask_array") from dask._expr import CompositeExpr paths = [] @@ -167,7 +173,7 @@ def test_standalone_dask_array_open_mfdataset_uses_expressions(tmp_path): try: expr = dask.base.collections_to_expr(ds) - assert type(ds["x"].data).__module__.startswith("dask_array") + assert isinstance(ds["x"].data, dask_array.Array) assert isinstance(expr, CompositeExpr) expected = Dataset({"x": ("t", np.arange(6))}, coords={"t": np.arange(6)}) @@ -182,6 +188,14 @@ def test_standalone_dask_array_open_mfdataset_uses_expressions(tmp_path): assert_identical( dask.optimize(ds)[0].compute(scheduler="single-threaded"), expected ) + + mapped = xr.map_blocks(lambda block: block + 1, ds) + assert isinstance(mapped["x"].data, dask_array.Array) + assert isinstance(dask.base.collections_to_expr(mapped), CompositeExpr) + assert_identical( + mapped.compute(scheduler="single-threaded"), + Dataset({"x": ("t", np.arange(6) + 1)}, coords={"t": np.arange(6)}), + ) finally: ds.close() @@ -253,20 +267,100 @@ def test_standalone_dask_array_groupby_sum(): assert_equal(dask.optimize(out)[0].compute(scheduler="single-threaded"), expected) -def test_standalone_dask_array_map_blocks_fallback_computes(): +def test_standalone_dask_array_map_blocks_uses_expressions(): dask_array = pytest.importorskip("dask_array") + from dask._expr import CompositeExpr x = dask_array.arange(6, chunks=(3,)) arr = DataArray(x, dims="t", name="x") out = xr.map_blocks(lambda block: block + 1, arr) + assert isinstance(out.data, dask_array.Array) + assert isinstance(dask.base.collections_to_expr(out), CompositeExpr) + expected = DataArray(np.arange(6) + 1, dims="t", name="x") assert_identical(out.compute(scheduler="single-threaded"), expected) + assert_identical( + out.persist(scheduler="single-threaded").compute(scheduler="single-threaded"), + expected, + ) assert_identical( dask.optimize(out)[0].compute(scheduler="single-threaded"), expected ) +def test_standalone_dask_array_map_blocks_preserves_scalar_coords(): + dask_array = pytest.importorskip("dask_array") + from dask._expr import CompositeExpr + + x = dask_array.arange(6, chunks=(3,)).reshape((3, 2)) + arr = DataArray( + x, + dims=("x", "y"), + coords={"label": ("x", ["a", "b", "c"]), "scale": 2}, + name="z", + ) + + out = xr.map_blocks(lambda block: block + block.scale, arr) + + assert isinstance(dask.base.collections_to_expr(out), CompositeExpr) + expected = DataArray( + np.arange(6).reshape((3, 2)) + 2, + dims=("x", "y"), + coords={"label": ("x", ["a", "b", "c"]), "scale": 2}, + ) + assert_identical(out.compute(scheduler="single-threaded"), expected) + + +def test_standalone_dask_array_map_blocks_dataset_outputs_share_block_calls(): + dask_array = pytest.importorskip("dask_array") + from dask._expr import CompositeExpr + + calls = [] + x = dask_array.arange(6, chunks=(3,)) + ds = Dataset({"x": ("t", x)}, coords={"qc": ("t", x + 10)}) + template = Dataset( + {"a": ("t", x), "b": ("t", x)}, + coords={"qc": ("t", x + 10)}, + attrs={"kind": "mapped"}, + ) + + def func(block): + calls.append(block.sizes["t"]) + return Dataset( + {"a": block["x"] + 1, "b": block["x"] + 2}, + coords={"qc": block["qc"]}, + attrs={"kind": "mapped"}, + ) + + out = xr.map_blocks(func, ds, template=template) + + assert isinstance(dask.base.collections_to_expr(out), CompositeExpr) + assert all(isinstance(out[name].data, dask_array.Array) for name in out) + + expected = Dataset( + {"a": ("t", np.arange(6) + 1), "b": ("t", np.arange(6) + 2)}, + coords={"qc": ("t", np.arange(6) + 10)}, + attrs={"kind": "mapped"}, + ) + assert_identical(out.compute(scheduler="single-threaded"), expected) + assert sorted(calls) == [3, 3] + + +def test_standalone_dask_array_map_blocks_reduces_single_chunk_dimension(): + dask_array = pytest.importorskip("dask_array") + from dask._expr import CompositeExpr + + x = dask_array.arange(12, chunks=(12,)).reshape((3, 4)).rechunk((3, 2)) + arr = DataArray(x, dims=("x", "y"), name="z") + + out = xr.map_blocks(lambda block: block.sum("x"), arr) + + assert isinstance(dask.base.collections_to_expr(out), CompositeExpr) + expected = DataArray(np.arange(12).reshape(3, 4).sum(axis=0), dims="y", name="z") + assert_identical(out.compute(scheduler="single-threaded"), expected) + + def test_standalone_dask_array_apply_ufunc_parallelized(): dask_array = pytest.importorskip("dask_array") from dask._expr import CompositeExpr From 87561b87c18cfb494b4069e726de984b4db08d50 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Fri, 5 Jun 2026 07:43:05 -0500 Subject: [PATCH 3/9] Isolate dask_array expression integration Co-Authored-By: Codex --- xarray/core/dask_array_expr.py | 272 +++++++++++++++++ xarray/core/dataset.py | 4 +- xarray/core/parallel.py | 262 ++-------------- xarray/namedarray/parallelcompat.py | 15 +- xarray/tests/test_dask_expr_protocol.py | 391 ++++++++++-------------- 5 files changed, 473 insertions(+), 471 deletions(-) create mode 100644 xarray/core/dask_array_expr.py diff --git a/xarray/core/dask_array_expr.py b/xarray/core/dask_array_expr.py new file mode 100644 index 00000000000..96f81095063 --- /dev/null +++ b/xarray/core/dask_array_expr.py @@ -0,0 +1,272 @@ +from __future__ import annotations + +import itertools +from collections.abc import Callable, Hashable, Mapping, Sequence +from typing import Any + +from xarray.core.coordinates import Coordinates +from xarray.core.dataarray import DataArray +from xarray.core.dataset import Dataset +from xarray.core.utils import is_dask_collection + + +def is_dask_array_expr_array(data: Any) -> bool: + try: + from dask_array import Array + except ImportError: + return False + + return isinstance(data, Array) + + +def collect_dask_array_expr_chunked_data( + xarray_objs: Sequence[Dataset], +) -> tuple[bool, list[Any]]: + chunked_data = [ + variable.data + for arg in xarray_objs + for variable in arg.variables.values() + if is_dask_collection(variable.data) + ] + has_dask_array_expr = any(is_dask_array_expr_array(data) for data in chunked_data) + has_other_chunked = any(not is_dask_array_expr_array(data) for data in chunked_data) + + if has_dask_array_expr and has_other_chunked: + raise TypeError( + "xarray.map_blocks cannot mix dask_array.Array with legacy or other " + "Dask-backed arrays. Convert inputs to one array backend first." + ) + + return has_dask_array_expr, chunked_data + + +def _execute_map_blocks_multi_output(block_spec: Mapping[str, Any], *blocks: Any): + args = [] + for arg_spec in block_spec["args"]: + if arg_spec[0] == "literal": + args.append(arg_spec[1]) + continue + + _, data_vars, coords, attrs = arg_spec + + def build_variables(variable_specs): + variables = [] + for name, dims, data, var_attrs in variable_specs: + if data[0] == "block": + data = blocks[data[1]] + else: + data = data[1] + variables.append((name, (dims, data, var_attrs))) + return dict(variables) + + args.append(Dataset(build_variables(data_vars), build_variables(coords), attrs)) + + return block_spec["wrapper"]( + block_spec["func"], + args, + block_spec["kwargs"], + block_spec["arg_is_array"], + block_spec["expected"], + block_spec["expected_indexes"], + ) + + +def map_blocks_with_dask_array_expr( + *, + func: Callable[..., Any], + npargs: Sequence[Any], + kwargs: Mapping[str, Any], + is_xarray: Sequence[bool], + is_array: Sequence[bool], + input_chunks: Mapping[Hashable, tuple[int, ...]], + output_chunks: Mapping[Hashable, tuple[int, ...]], + coordinates: Coordinates, + template: Dataset, + result_is_array: bool, + template_name: Hashable | None, + gname: str, + ichunk: Mapping[Hashable, range], + input_chunk_bounds: Mapping[Hashable, Any], + output_chunk_bounds: Mapping[Hashable, Any], + computed_variables: Sequence[Hashable], + new_indexes: set[Hashable], + modified_indexes: set[Hashable], + chunked_data: Sequence[Any], + wrapper: Callable[..., Any], + get_chunk_slicer: Callable[[Hashable, Mapping[Any, Any], Mapping[Any, Any]], slice], + dataset_to_dataarray: Callable[[Dataset], DataArray], +) -> DataArray | Dataset: + missing_chunked_dims = { + dim + for dim, chunks in input_chunks.items() + if len(chunks) > 1 and dim not in output_chunks + } + if missing_chunked_dims: + raise NotImplementedError( + "dask_array-backed xarray.map_blocks does not yet support " + "dropping multi-chunk dimensions. Rechunk these dimensions to " + f"one chunk first: {sorted(missing_chunked_dims)!r}." + ) + + from xarray.namedarray.parallelcompat import get_chunked_array_type + + chunkmanager = get_chunked_array_type(*chunked_data) + map_blocks_multi_output = getattr(chunkmanager, "map_blocks_multi_output", None) + if map_blocks_multi_output is None: + raise NotImplementedError( + "The dask_array chunk manager does not support map_blocks_multi_output." + ) + + input_exprs = [] + input_indices = [] + arg_templates = [] + for isxr, arg in zip(is_xarray, npargs, strict=True): + if not isxr: + if is_dask_collection(arg): + raise TypeError( + "dask_array-backed xarray.map_blocks only supports Dask " + "collections inside xarray arguments." + ) + arg_templates.append(("literal", arg)) + continue + + variable_templates = [] + for name, variable in arg.variables.items(): + is_coord = name in arg._coord_names + if is_dask_collection(variable.data): + input_exprs.append(variable.data.expr) + input_indices.append(variable.dims) + variable_templates.append( + ( + "chunked", + name, + variable.dims, + variable.attrs, + len(input_exprs) - 1, + is_coord, + None, + ) + ) + else: + variable_templates.append( + ( + "static", + name, + variable.dims, + variable.attrs, + None, + is_coord, + variable, + ) + ) + arg_templates.append(("xarray", arg.attrs, variable_templates)) + + def build_block_specs(): + specs = {} + for chunk_tuple in itertools.product(*ichunk.values()): + chunk_index = dict(zip(ichunk.keys(), chunk_tuple, strict=True)) + arg_specs = [] + for arg_template in arg_templates: + if arg_template[0] == "literal": + arg_specs.append(arg_template) + continue + + _, attrs, variable_templates = arg_template + data_vars = [] + coords = [] + for ( + kind, + name, + dims, + attrs, + input_position, + is_coord, + variable, + ) in variable_templates: + if kind == "chunked": + data = ("block", input_position) + else: + assert variable is not None + subsetter = { + dim: get_chunk_slicer(dim, chunk_index, input_chunk_bounds) + for dim in variable.dims + } + data = ("static", variable.isel(subsetter)._data) + + target = coords if is_coord else data_vars + target.append((name, dims, data, attrs)) + + arg_specs.append(("xarray", data_vars, coords, attrs)) + + indexes = { + dim: coordinates.xindexes[dim][ + get_chunk_slicer(dim, chunk_index, output_chunk_bounds) + ] + for dim in (new_indexes | modified_indexes) + } + expected = { + "shapes": { + k: output_chunks[k][v] + for k, v in chunk_index.items() + if k in output_chunks + }, + "data_vars": set(template.data_vars.keys()), + "coords": set(template.coords.keys()), + } + specs[chunk_tuple] = { + "wrapper": wrapper, + "func": func, + "args": arg_specs, + "kwargs": kwargs, + "arg_is_array": is_array, + "expected": expected, + "expected_indexes": indexes, + } + return specs + + outputs = [] + for name in computed_variables: + variable = template.variables[name] + var_chunks = [] + for dim in variable.dims: + if dim in output_chunks: + var_chunks.append(output_chunks[dim]) + elif dim in template.dims: + var_chunks.append((template.sizes[dim],)) + + outputs.append( + { + "key": name, + "indices": variable.dims, + "chunks": tuple(var_chunks), + "dtype": variable.dtype, + "name": f"{name}-{gname}", + } + ) + + mapped_arrays = map_blocks_multi_output( + _execute_map_blocks_multi_output, + input_exprs, + input_indices, + tuple(input_chunks), + build_block_specs(), + outputs, + token=gname, + ) + + result = Dataset(coords=coordinates, attrs=template.attrs) + for index in result._indexes: + result[index].attrs = template[index].attrs + result[index].encoding = template[index].encoding + + for name, data in zip(computed_variables, mapped_arrays, strict=True): + result[name] = (template[name].dims, data, template[name].attrs) + result[name].encoding = template[name].encoding + + result = result.set_coords(template._coord_names) + + if result_is_array: + da = dataset_to_dataarray(result) + da.name = template_name + return da + return result diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index d5c4f92f9ed..68681f1955e 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -675,10 +675,12 @@ def __dask_keys__(self): def __dask_exprs__(self): import dask + from xarray.core.dask_array_expr import is_dask_array_expr_array + exprs = [] for v in self.variables.values(): if dask.is_dask_collection(v): - if not hasattr(v._data, "expr"): + if not is_dask_array_expr_array(v._data): return None exprs.append(v._data.expr) return exprs or None diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index 82e97e652de..bb500657eed 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -149,37 +149,6 @@ def make_dict(x: DataArray | Dataset) -> dict[Hashable, Any]: return {k: v.data for k, v in x.variables.items()} -def _execute_map_blocks_multi_output(block_spec, *blocks): - args = [] - for arg_spec in block_spec["args"]: - if arg_spec[0] == "literal": - args.append(arg_spec[1]) - continue - - _, data_vars, coords, attrs = arg_spec - - def build_variables(variable_specs): - variables = [] - for name, dims, data, var_attrs in variable_specs: - if data[0] == "block": - data = blocks[data[1]] - else: - data = data[1] - variables.append((name, (dims, data, var_attrs))) - return dict(variables) - - args.append(Dataset(build_variables(data_vars), build_variables(coords), attrs)) - - return block_spec["wrapper"]( - block_spec["func"], - args, - block_spec["kwargs"], - block_spec["arg_is_array"], - block_spec["expected"], - block_spec["expected_indexes"], - ) - - def _get_chunk_slicer(dim: Hashable, chunk_index: Mapping, chunk_bounds: Mapping): if dim in chunk_index: which_chunk = chunk_index[dim] @@ -472,32 +441,16 @@ def _wrapper( for arg in aligned ) - try: - from dask_array import Array as DaskArrayExprArray - except ImportError: - DaskArrayExprArray = () + from xarray.core.dask_array_expr import collect_dask_array_expr_chunked_data - chunked_data = [ - variable.data - for arg in xarray_objs - for variable in arg.variables.values() - if is_dask_collection(variable.data) - ] - has_dask_array_expr = any( - isinstance(data, DaskArrayExprArray) for data in chunked_data - ) - has_other_chunked = any( - not isinstance(data, DaskArrayExprArray) for data in chunked_data + has_dask_array_expr, chunked_data = collect_dask_array_expr_chunked_data( + xarray_objs ) - if has_dask_array_expr and has_other_chunked: - raise TypeError( - "xarray.map_blocks cannot mix dask_array.Array with legacy or other " - "Dask-backed arrays. Convert inputs to one array backend first." - ) - # rechunk any numpy variables appropriately xarray_objs = tuple(arg.chunk(arg.chunksizes) for arg in xarray_objs) + if has_dask_array_expr: + _, chunked_data = collect_dask_array_expr_chunked_data(xarray_objs) merged_coordinates = merge( [arg.coords for arg in aligned], @@ -577,6 +530,7 @@ def _wrapper( template = template._to_temp_dataset() elif isinstance(template, Dataset): result_is_array = False + template_name = None else: raise TypeError( f"func output must be DataArray or Dataset; got {type(template)}" @@ -606,185 +560,33 @@ def _wrapper( name for name in template.variables if name not in coordinates.indexes ] - chunked_data = [ - variable.data - for arg in xarray_objs - for variable in arg.variables.values() - if is_dask_collection(variable.data) - ] - if has_dask_array_expr: - missing_chunked_dims = { - dim - for dim, chunks in input_chunks.items() - if len(chunks) > 1 and dim not in output_chunks - } - if missing_chunked_dims: - raise NotImplementedError( - "dask_array-backed xarray.map_blocks does not yet support " - "dropping multi-chunk dimensions. Rechunk these dimensions to " - f"one chunk first: {sorted(missing_chunked_dims)!r}." - ) - - from xarray.namedarray.parallelcompat import get_chunked_array_type - - chunkmanager = get_chunked_array_type(*chunked_data) - - input_exprs = [] - input_indices = [] - arg_templates = [] - for isxr, arg in zip(is_xarray, npargs, strict=True): - if not isxr: - if is_dask_collection(arg): - raise TypeError( - "dask_array-backed xarray.map_blocks only supports Dask " - "collections inside xarray arguments." - ) - arg_templates.append(("literal", arg)) - continue - - variable_templates = [] - for name, variable in arg.variables.items(): - is_coord = name in arg._coord_names - if is_dask_collection(variable.data): - input_exprs.append(variable.data.expr) - input_indices.append(variable.dims) - variable_templates.append( - ( - "chunked", - name, - variable.dims, - variable.attrs, - len(input_exprs) - 1, - is_coord, - None, - ) - ) - else: - variable_templates.append( - ( - "static", - name, - variable.dims, - variable.attrs, - None, - is_coord, - variable, - ) - ) - arg_templates.append(("xarray", arg.attrs, variable_templates)) - - def build_block_specs(): - specs = {} - for chunk_tuple in itertools.product(*ichunk.values()): - chunk_index = dict(zip(ichunk.keys(), chunk_tuple, strict=True)) - arg_specs = [] - for arg_template in arg_templates: - if arg_template[0] == "literal": - arg_specs.append(arg_template) - continue - - _, attrs, variable_templates = arg_template - data_vars = [] - coords = [] - for ( - kind, - name, - dims, - attrs, - input_position, - is_coord, - variable, - ) in variable_templates: - if kind == "chunked": - data = ("block", input_position) - else: - assert variable is not None - subsetter = { - dim: _get_chunk_slicer( - dim, chunk_index, input_chunk_bounds - ) - for dim in variable.dims - } - data = ("static", variable.isel(subsetter)._data) - - target = coords if is_coord else data_vars - target.append((name, dims, data, attrs)) - - arg_specs.append(("xarray", data_vars, coords, attrs)) - - indexes = { - dim: coordinates.xindexes[dim][ - _get_chunk_slicer(dim, chunk_index, output_chunk_bounds) - ] - for dim in (new_indexes | modified_indexes) - } - expected: ExpectedDict = { - "shapes": { - k: output_chunks[k][v] - for k, v in chunk_index.items() - if k in output_chunks - }, - "data_vars": set(template.data_vars.keys()), - "coords": set(template.coords.keys()), - } - specs[chunk_tuple] = { - "wrapper": _wrapper, - "func": func, - "args": arg_specs, - "kwargs": kwargs, - "arg_is_array": is_array, - "expected": expected, - "expected_indexes": indexes, - } - return specs - - outputs = [] - for name in computed_variables: - variable = template.variables[name] - var_chunks = [] - for dim in variable.dims: - if dim in output_chunks: - var_chunks.append(output_chunks[dim]) - elif dim in template.dims: - var_chunks.append((template.sizes[dim],)) - - outputs.append( - { - "key": name, - "indices": variable.dims, - "chunks": tuple(var_chunks), - "dtype": variable.dtype, - "name": f"{name}-{gname}", - } - ) - - mapped_arrays = chunkmanager.map_blocks_multi_output( - _execute_map_blocks_multi_output, - input_exprs, - input_indices, - tuple(input_chunks), - build_block_specs(), - outputs, - token=gname, - ) - - result = Dataset(coords=coordinates, attrs=template.attrs) - for index in result._indexes: - result[index].attrs = template[index].attrs - result[index].encoding = template[index].encoding - - for name, data in zip(computed_variables, mapped_arrays, strict=True): - result[name] = (template[name].dims, data, template[name].attrs) - result[name].encoding = template[name].encoding - - result = result.set_coords(template._coord_names) - - if result_is_array: - da = dataset_to_dataarray(result) - da.name = template_name - return da # type: ignore[return-value] - return result # type: ignore[return-value] + from xarray.core.dask_array_expr import map_blocks_with_dask_array_expr + + return map_blocks_with_dask_array_expr( + func=func, + npargs=npargs, + kwargs=kwargs, + is_xarray=is_xarray, + is_array=is_array, + input_chunks=input_chunks, + output_chunks=output_chunks, + coordinates=coordinates, + template=template, + result_is_array=result_is_array, + template_name=template_name, + gname=gname, + ichunk=ichunk, + input_chunk_bounds=input_chunk_bounds, + output_chunk_bounds=output_chunk_bounds, + computed_variables=computed_variables, + new_indexes=new_indexes, + modified_indexes=modified_indexes, + chunked_data=chunked_data, + wrapper=_wrapper, + get_chunk_slicer=_get_chunk_slicer, + dataset_to_dataarray=dataset_to_dataarray, + ) # type: ignore[return-value] # iterate over all possible chunk combinations for chunk_tuple in itertools.product(*ichunk.values()): diff --git a/xarray/namedarray/parallelcompat.py b/xarray/namedarray/parallelcompat.py index 334e464a1fc..8a68f5e9562 100644 --- a/xarray/namedarray/parallelcompat.py +++ b/xarray/namedarray/parallelcompat.py @@ -8,7 +8,7 @@ import functools from abc import ABC, abstractmethod -from collections.abc import Callable, Iterable, Mapping, Sequence +from collections.abc import Callable, Iterable, Sequence from importlib.metadata import EntryPoint, entry_points from typing import TYPE_CHECKING, Any, Generic, Protocol, TypeVar @@ -655,19 +655,6 @@ def map_blocks( """ raise NotImplementedError() - def map_blocks_multi_output( - self, - func: Callable[..., Any], - input_exprs: Sequence[Any], - input_indices: Sequence[Iterable[Any]], - shared_indices: Iterable[Any], - block_specs: Mapping[tuple[int, ...], Any], - outputs: Sequence[Mapping[str, Any]], - *, - token: str, - ) -> list[Any]: - raise NotImplementedError() - def blockwise( self, func: Callable[..., Any], diff --git a/xarray/tests/test_dask_expr_protocol.py b/xarray/tests/test_dask_expr_protocol.py index f28d98bd62c..c2a1bd904e7 100644 --- a/xarray/tests/test_dask_expr_protocol.py +++ b/xarray/tests/test_dask_expr_protocol.py @@ -5,161 +5,84 @@ import xarray as xr from xarray import DataArray, Dataset -from xarray.testing import assert_equal, assert_identical +from xarray.testing import assert_identical +from xarray.tests import requires_scipy_or_netCDF4 dask = pytest.importorskip("dask") da = pytest.importorskip("dask.array") +dask_array = pytest.importorskip("dask_array") +try: + from dask._expr import CompositeExpr, HLGExpr +except ImportError: + pytest.skip("requires Dask composite expressions", allow_module_level=True) -def test_standalone_dask_array_dataset_composite_expr_protocol(): - dask_array = pytest.importorskip("dask_array") - from dask._expr import CompositeExpr - x = dask_array.arange(6, chunks=(3,)) - ds = Dataset( - {"foo": ("i", x + 1)}, - coords={"coord": ("i", x + 10)}, - attrs={"source": "test"}, - ) - ds["foo"].encoding["example"] = "kept" - - expr = dask.base.collections_to_expr(ds) - - assert isinstance(expr, CompositeExpr) - assert len(expr.exprs) == 2 - - expected = Dataset( - {"foo": ("i", np.arange(6) + 1)}, - coords={"coord": ("i", np.arange(6) + 10)}, - attrs={"source": "test"}, - ) - computed = dask.compute(ds, scheduler="single-threaded")[0] - assert_identical(computed, expected) - assert computed["foo"].encoding["example"] == "kept" - - persisted = dask.persist(ds, scheduler="single-threaded")[0] - assert hasattr(persisted["foo"].data, "expr") - assert_identical(persisted.compute(), expected) - assert persisted["foo"].encoding["example"] == "kept" - - optimized = dask.optimize(ds)[0] - assert hasattr(optimized["foo"].data, "expr") - assert_identical(optimized.compute(), expected) - - -def test_standalone_dask_array_dataset_variable_named_expr(): - dask_array = pytest.importorskip("dask_array") - from dask._expr import CompositeExpr - - x = dask_array.arange(6, chunks=(3,)) - ds = Dataset({"expr": ("i", x + 1)}) +def test_dataset_composite_expr_protocol_simple(): + x = dask_array.arange(3, chunks=(1,)) + ds = Dataset({"x": ("i", x + 1)}) expr = dask.base.collections_to_expr(ds) assert isinstance(expr, CompositeExpr) + assert len(expr.exprs) == 1 assert_identical( dask.compute(ds, scheduler="single-threaded")[0], - Dataset({"expr": ("i", np.arange(6) + 1)}), + Dataset({"x": ("i", np.arange(3) + 1)}), ) -def test_standalone_dask_array_dataarray_composite_expr_protocol(): - dask_array = pytest.importorskip("dask_array") - from dask._expr import CompositeExpr - +def test_dataarray_composite_expr_protocol_includes_chunked_coord(): x = dask_array.arange(6, chunks=(3,)) - arr = DataArray(x + 1, dims=("i",), coords={"coord": ("i", x)}, name="z") + arr = DataArray(x + 1, dims=("i",), coords={"coord": ("i", x + 10)}, name="z") expr = dask.base.collections_to_expr(arr) assert isinstance(expr, CompositeExpr) assert len(expr.exprs) == 2 - - expected = DataArray( - np.arange(6) + 1, - dims=("i",), - coords={"coord": ("i", np.arange(6))}, - name="z", - ) - computed = dask.compute(arr, scheduler="single-threaded")[0] - assert_identical(computed, expected) - - persisted = dask.persist(arr, scheduler="single-threaded")[0] - assert hasattr(persisted.data, "expr") - assert_identical(persisted.compute(), expected) - - -def test_standalone_dask_array_dataset_computes_with_raw_array(): - dask_array = pytest.importorskip("dask_array") - - x = dask_array.arange(6, chunks=(3,)) - ds = Dataset({"foo": ("i", x + 1)}) - - computed_ds, computed_x = dask.compute(ds, x, scheduler="single-threaded") - - assert_identical(computed_ds, Dataset({"foo": ("i", np.arange(6) + 1)})) - np.testing.assert_array_equal(computed_x, np.arange(6)) - - -def test_standalone_dask_array_optimize_culls_child_graphs(): - dask_array = pytest.importorskip("dask_array") - from dask.core import flatten - - x = dask_array.arange(6, chunks=(3,)) - ds = Dataset({"foo": ("i", x + 1), "bar": ("i", x + 2)}) - - optimized = dask.optimize(ds)[0] - - foo_graph_keys = set(optimized["foo"].data.__dask_graph__()) - bar_output_keys = set(flatten(optimized["bar"].data.__dask_keys__())) - assert not foo_graph_keys & bar_output_keys assert_identical( - optimized.compute(), - Dataset({"foo": ("i", np.arange(6) + 1), "bar": ("i", np.arange(6) + 2)}), + dask.compute(arr, scheduler="single-threaded")[0], + DataArray( + np.arange(6) + 1, + dims=("i",), + coords={"coord": ("i", np.arange(6) + 10)}, + name="z", + ), ) -def test_standalone_dask_array_mixed_legacy_falls_back_from_composite_expr(): - dask_array = pytest.importorskip("dask_array") - +def test_dataset_compute_persist_optimize_end_to_end(): + x = dask_array.arange(6, chunks=(3,)) ds = Dataset( - { - "expr": ("i", dask_array.arange(3, chunks=(1,))), - "legacy": ("i", da.arange(3, chunks=(1,))), - } + {"foo": ("i", x + 1)}, + coords={"coord": ("i", x + 10)}, + attrs={"source": "test"}, ) + ds["foo"].encoding["example"] = "kept" - assert ds.__dask_exprs__() is None - - -def test_standalone_dask_array_mixed_legacy_map_blocks_raises(): - dask_array = pytest.importorskip("dask_array") - - ds = Dataset( - { - "expr": ("i", dask_array.arange(6, chunks=(3,))), - "legacy": ("i", da.arange(6, chunks=(3,))), - } + expected = Dataset( + {"foo": ("i", np.arange(6) + 1)}, + coords={"coord": ("i", np.arange(6) + 10)}, + attrs={"source": "test"}, ) - with pytest.raises(TypeError, match=r"cannot mix dask_array\.Array"): - xr.map_blocks(lambda block: block + 1, ds) - - -def test_standalone_dask_array_mixed_legacy_map_blocks_arg_raises(): - dask_array = pytest.importorskip("dask_array") - - arr = DataArray(dask_array.arange(6, chunks=(3,)), dims="i") - other = DataArray(da.arange(6, chunks=(3,)), dims="i") + computed_ds, computed_x = dask.compute(ds, x, scheduler="single-threaded") + assert_identical(computed_ds, expected) + np.testing.assert_array_equal(computed_x, np.arange(6)) + assert computed_ds["foo"].encoding["example"] == "kept" - with pytest.raises(TypeError, match=r"cannot mix dask_array\.Array"): - xr.map_blocks(lambda a, b: a + b, arr, args=[other]) + persisted = dask.persist(ds, scheduler="single-threaded")[0] + assert isinstance(persisted["foo"].data, dask_array.Array) + assert_identical(persisted.compute(), expected) + assert persisted["foo"].encoding["example"] == "kept" + optimized = dask.optimize(ds)[0] + assert isinstance(optimized["foo"].data, dask_array.Array) + assert_identical(optimized.compute(), expected) -def test_standalone_dask_array_open_mfdataset_uses_expressions(tmp_path): - dask_array = pytest.importorskip("dask_array") - from dask._expr import CompositeExpr +@requires_scipy_or_netCDF4 +def test_open_mfdataset_map_blocks_end_to_end(tmp_path): paths = [] for i in range(2): path = tmp_path / f"part-{i}.nc" @@ -169,12 +92,9 @@ def test_standalone_dask_array_open_mfdataset_uses_expressions(tmp_path): ).to_netcdf(path) paths.append(path) - ds = xr.open_mfdataset(paths, chunks={"t": 2}, combine="by_coords") - try: - expr = dask.base.collections_to_expr(ds) - + with xr.open_mfdataset(paths, chunks={"t": 2}, combine="by_coords") as ds: assert isinstance(ds["x"].data, dask_array.Array) - assert isinstance(expr, CompositeExpr) + assert isinstance(dask.base.collections_to_expr(ds), CompositeExpr) expected = Dataset({"x": ("t", np.arange(6))}, coords={"t": np.arange(6)}) assert_identical(ds.compute(scheduler="single-threaded"), expected) @@ -196,81 +116,42 @@ def test_standalone_dask_array_open_mfdataset_uses_expressions(tmp_path): mapped.compute(scheduler="single-threaded"), Dataset({"x": ("t", np.arange(6) + 1)}, coords={"t": np.arange(6)}), ) - finally: - ds.close() - - -def test_standalone_dask_array_shared_subexpressions_and_chunked_coords(): - dask_array = pytest.importorskip("dask_array") - - x = dask_array.arange(6, chunks=(3,)) - ds = Dataset( - {"a": ("t", x + 1), "b": ("t", x * 2)}, - coords={"qc": ("t", x + 10)}, - attrs={"case": "shared"}, - ) - - expected = Dataset( - {"a": ("t", np.arange(6) + 1), "b": ("t", np.arange(6) * 2)}, - coords={"qc": ("t", np.arange(6) + 10)}, - attrs={"case": "shared"}, - ) - assert len(dask.base.collections_to_expr(ds).exprs) == 3 - assert_identical(ds.compute(scheduler="single-threaded"), expected) - assert_identical( - dask.persist(ds, scheduler="single-threaded")[0].compute( - scheduler="single-threaded" - ), - expected, - ) - assert_identical( - dask.optimize(ds)[0].compute(scheduler="single-threaded"), expected - ) - -def test_standalone_dask_array_rechunk_reduction_chain(): - dask_array = pytest.importorskip("dask_array") - from dask._expr import CompositeExpr - x = dask_array.arange(12, chunks=(4,)).reshape((3, 4)) - out = Dataset({"x": (("a", "b"), x)}).chunk({"a": 1, "b": 2}).sum("b") + 1 +@requires_scipy_or_netCDF4 +def test_open_dataset_rechunk_optimization_crosses_composite_expr(tmp_path): + from dask_array._rechunk import Rechunk - assert isinstance(dask.base.collections_to_expr(out), CompositeExpr) - expected = Dataset({"x": ("a", np.arange(12).reshape(3, 4).sum(axis=1) + 1)}) - assert_identical(out.compute(scheduler="single-threaded"), expected) - assert_identical( - dask.optimize(out)[0].compute(scheduler="single-threaded"), expected - ) + path = tmp_path / "data.nc" + Dataset({"x": ("t", np.arange(12))}, coords={"t": np.arange(12)}).to_netcdf(path) + with xr.open_dataset(path, chunks={"t": 3}) as ds: + out = ds.chunk({"t": 4}) + expr = dask.base.collections_to_expr(out) + source_expr = ds["x"].data.expr + expected_operands = list(source_expr.operands) + expected_operands[source_expr._parameters.index("_chunks")] = ((4, 4, 4),) + expected_expr = type(source_expr)(*expected_operands).optimize() -def test_standalone_dask_array_groupby_sum(): - dask_array = pytest.importorskip("dask_array") - from dask._expr import CompositeExpr + assert isinstance(expr, CompositeExpr) + assert len(expr.exprs) == 1 + assert list(expr.exprs[0].find_operations(Rechunk)) - x = dask_array.arange(6, chunks=(3,)) - arr = DataArray( - x, - dims="t", - coords={"label": ("t", np.array(["a", "a", "b", "b", "a", "b"]))}, - name="x", - ) - out = arr.groupby("label").sum() + optimized_expr = expr.optimize() - assert isinstance(dask.base.collections_to_expr(out), CompositeExpr) - expected = DataArray( - [5, 10], - dims="label", - coords={"label": np.array(["a", "b"], dtype=object)}, - name="x", - ) - assert_equal(out.compute(scheduler="single-threaded"), expected) - assert_equal(dask.optimize(out)[0].compute(scheduler="single-threaded"), expected) + assert optimized_expr.exprs[0]._name == expected_expr._name + assert not list(optimized_expr.exprs[0].find_operations(Rechunk)) + optimized = dask.optimize(out)[0] + assert isinstance(optimized["x"].data, dask_array.Array) + assert not list(optimized["x"].data.expr.find_operations(Rechunk)) + assert_identical( + optimized.compute(scheduler="single-threaded"), + Dataset({"x": ("t", np.arange(12))}, coords={"t": np.arange(12)}), + ) -def test_standalone_dask_array_map_blocks_uses_expressions(): - dask_array = pytest.importorskip("dask_array") - from dask._expr import CompositeExpr +def test_map_blocks_dataarray_end_to_end(): x = dask_array.arange(6, chunks=(3,)) arr = DataArray(x, dims="t", name="x") out = xr.map_blocks(lambda block: block + 1, arr) @@ -289,33 +170,7 @@ def test_standalone_dask_array_map_blocks_uses_expressions(): ) -def test_standalone_dask_array_map_blocks_preserves_scalar_coords(): - dask_array = pytest.importorskip("dask_array") - from dask._expr import CompositeExpr - - x = dask_array.arange(6, chunks=(3,)).reshape((3, 2)) - arr = DataArray( - x, - dims=("x", "y"), - coords={"label": ("x", ["a", "b", "c"]), "scale": 2}, - name="z", - ) - - out = xr.map_blocks(lambda block: block + block.scale, arr) - - assert isinstance(dask.base.collections_to_expr(out), CompositeExpr) - expected = DataArray( - np.arange(6).reshape((3, 2)) + 2, - dims=("x", "y"), - coords={"label": ("x", ["a", "b", "c"]), "scale": 2}, - ) - assert_identical(out.compute(scheduler="single-threaded"), expected) - - -def test_standalone_dask_array_map_blocks_dataset_outputs_share_block_calls(): - dask_array = pytest.importorskip("dask_array") - from dask._expr import CompositeExpr - +def test_map_blocks_dataset_outputs_share_block_calls(): calls = [] x = dask_array.arange(6, chunks=(3,)) ds = Dataset({"x": ("t", x)}, coords={"qc": ("t", x + 10)}) @@ -337,6 +192,7 @@ def func(block): assert isinstance(dask.base.collections_to_expr(out), CompositeExpr) assert all(isinstance(out[name].data, dask_array.Array) for name in out) + assert sorted(calls) == [] expected = Dataset( {"a": ("t", np.arange(6) + 1), "b": ("t", np.arange(6) + 2)}, @@ -347,24 +203,91 @@ def func(block): assert sorted(calls) == [3, 3] -def test_standalone_dask_array_map_blocks_reduces_single_chunk_dimension(): - dask_array = pytest.importorskip("dask_array") - from dask._expr import CompositeExpr +def test_map_blocks_preserves_scalar_coords(): + x = dask_array.arange(6, chunks=(3,)).reshape((3, 2)) + arr = DataArray( + x, + dims=("x", "y"), + coords={"label": ("x", ["a", "b", "c"]), "scale": 2}, + name="z", + ) + + out = xr.map_blocks(lambda block: block + block.scale, arr) + + assert isinstance(dask.base.collections_to_expr(out), CompositeExpr) + assert_identical( + out.compute(scheduler="single-threaded"), + DataArray( + np.arange(6).reshape((3, 2)) + 2, + dims=("x", "y"), + coords={"label": ("x", ["a", "b", "c"]), "scale": 2}, + ), + ) + +def test_map_blocks_reduces_single_chunk_dimension(): x = dask_array.arange(12, chunks=(12,)).reshape((3, 4)).rechunk((3, 2)) arr = DataArray(x, dims=("x", "y"), name="z") out = xr.map_blocks(lambda block: block.sum("x"), arr) assert isinstance(dask.base.collections_to_expr(out), CompositeExpr) - expected = DataArray(np.arange(12).reshape(3, 4).sum(axis=0), dims="y", name="z") - assert_identical(out.compute(scheduler="single-threaded"), expected) + assert_identical( + out.compute(scheduler="single-threaded"), + DataArray(np.arange(12).reshape(3, 4).sum(axis=0), dims="y", name="z"), + ) -def test_standalone_dask_array_apply_ufunc_parallelized(): - dask_array = pytest.importorskip("dask_array") - from dask._expr import CompositeExpr +def test_mixed_legacy_inputs_do_not_use_composite_path(): + ds = Dataset( + { + "x": ("i", dask_array.arange(3, chunks=(1,))), + "legacy": ("i", da.arange(3, chunks=(1,))), + } + ) + + assert ds.__dask_exprs__() is None + assert isinstance(dask.base.collections_to_expr(ds), HLGExpr) + + with pytest.raises(TypeError, match=r"cannot mix dask_array\.Array"): + xr.map_blocks(lambda block: block + 1, ds) + + arr = DataArray(dask_array.arange(6, chunks=(3,)), dims="i") + other = DataArray(da.arange(6, chunks=(3,)), dims="i") + with pytest.raises(TypeError, match=r"cannot mix dask_array\.Array"): + xr.map_blocks(lambda a, b: a + b, arr, args=[other]) + + +def test_shared_subexpressions_optimize_without_cross_contamination(): + from dask.core import flatten + x = dask_array.arange(6, chunks=(3,)) + ds = Dataset({"foo": ("i", x + 1), "bar": ("i", x + 2)}) + + optimized = dask.optimize(ds)[0] + + foo_graph_keys = set(optimized["foo"].data.__dask_graph__()) + bar_output_keys = set(flatten(optimized["bar"].data.__dask_keys__())) + assert not foo_graph_keys & bar_output_keys + assert_identical( + optimized.compute(), + Dataset({"foo": ("i", np.arange(6) + 1), "bar": ("i", np.arange(6) + 2)}), + ) + + +def test_rechunk_reduction_chain_uses_composite_expr(): + x = dask_array.arange(12, chunks=(4,)).reshape((3, 4)) + out = Dataset({"x": (("a", "b"), x)}).chunk({"a": 1, "b": 2}).sum("b") + 1 + + assert isinstance(dask.base.collections_to_expr(out), CompositeExpr) + expected = Dataset({"x": ("a", np.arange(12).reshape(3, 4).sum(axis=1) + 1)}) + assert_identical(out.compute(scheduler="single-threaded"), expected) + assert_identical( + dask.optimize(out)[0].compute(scheduler="single-threaded"), expected + ) + + +def test_apply_ufunc_parallelized_uses_composite_expr(): x = dask_array.arange(6, chunks=(3,)) arr = DataArray(x, dims="t", name="x") out = xr.apply_ufunc(lambda z: z + 2, arr, dask="parallelized", output_dtypes=[int]) @@ -374,3 +297,19 @@ def test_standalone_dask_array_apply_ufunc_parallelized(): out.compute(scheduler="single-threaded"), DataArray(np.arange(6) + 2, dims="t", name="x"), ) + + +def test_groupby_sum_currently_falls_back_after_legacy_conversion(): + x = dask_array.arange(6, chunks=(3,)) + arr = DataArray( + x, + dims="t", + coords={"label": ("t", np.array(["a", "a", "b", "b", "a", "b"]))}, + name="x", + ) + + with pytest.warns(UserWarning, match="already a Dask collection"): + out = arr.groupby("label").sum() + + assert not isinstance(out.data, dask_array.Array) + assert isinstance(dask.base.collections_to_expr(out), HLGExpr) From fd7a9c81b6a2431cefe99fdc141f55c70c05fc46 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Sun, 7 Jun 2026 00:47:00 +0000 Subject: [PATCH 4/9] Use active chunked array backend --- xarray/compat/dask_array_compat.py | 14 +++++++++++--- xarray/compat/dask_array_ops.py | 12 +++++++++--- 2 files changed, 20 insertions(+), 6 deletions(-) diff --git a/xarray/compat/dask_array_compat.py b/xarray/compat/dask_array_compat.py index b8c7da3e64f..e1648c4e34b 100644 --- a/xarray/compat/dask_array_compat.py +++ b/xarray/compat/dask_array_compat.py @@ -1,5 +1,6 @@ from typing import Any +from xarray.namedarray.parallelcompat import get_chunked_array_type from xarray.namedarray.utils import module_available @@ -8,10 +9,17 @@ def reshape_blockwise( shape: int | tuple[int, ...], chunks: tuple[tuple[int, ...], ...] | None = None, ): - if module_available("dask", "2024.08.2"): - from dask.array import reshape_blockwise + try: + array_api = get_chunked_array_type(x).array_api + except TypeError: + array_api = None - return reshape_blockwise(x, shape=shape, chunks=chunks) + if array_api is not None and hasattr(array_api, "reshape_blockwise"): + return array_api.reshape_blockwise(x, shape=shape, chunks=chunks) + elif module_available("dask", "2024.08.2"): + from dask.array import reshape_blockwise as dask_reshape_blockwise + + return dask_reshape_blockwise(x, shape=shape, chunks=chunks) else: return x.reshape(shape) diff --git a/xarray/compat/dask_array_ops.py b/xarray/compat/dask_array_ops.py index 9534351dbfd..ad52a175bd3 100644 --- a/xarray/compat/dask_array_ops.py +++ b/xarray/compat/dask_array_ops.py @@ -4,6 +4,7 @@ from xarray.compat.dask_array_compat import reshape_blockwise from xarray.core import dtypes, nputils +from xarray.namedarray.parallelcompat import get_chunked_array_type def dask_rolling_wrapper(moving_func, a, window, min_count=None, axis=-1): @@ -20,7 +21,7 @@ def dask_rolling_wrapper(moving_func, a, window, min_count=None, axis=-1): def least_squares(lhs, rhs, rcond=None, skipna=False): - import dask.array as da + da = get_chunked_array_type(rhs).array_api # The trick here is that the core dimension is axis 0. # All other dimensions need to be reshaped down to one axis for `lstsq` @@ -94,11 +95,16 @@ def push(array, n, axis, method="blelloch"): """ Dask-aware bottleneck.push """ - import dask.array as da import numpy as np from xarray.core.duck_array_ops import _push from xarray.core.nputils import nanlast + from xarray.namedarray.parallelcompat import get_chunked_array_type + + da = get_chunked_array_type(array).array_api + cumreduction = getattr(da, "cumreduction", None) + if cumreduction is None: + cumreduction = da.reductions.cumreduction if n is not None and all(n <= size for size in array.chunks[axis]): return array.map_overlap(_push, depth={axis: (n, 0)}, n=n, axis=axis) @@ -106,7 +112,7 @@ def push(array, n, axis, method="blelloch"): # TODO: Replace all this function # once https://github.com/pydata/xarray/issues/9229 being implemented - pushed_array = da.reductions.cumreduction( + pushed_array = cumreduction( func=_dtype_push, binop=_fill_with_last_one, ident=np.nan, From 4b17760aee7d6ceafa51df3616d7d03db58295fa Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Thu, 11 Jun 2026 20:56:02 -0500 Subject: [PATCH 5/9] Fix dask_array expression typing checks [skip-rtd] --- xarray/core/dask_array_expr.py | 29 ++++++++++++++----------- xarray/tests/test_dask_expr_protocol.py | 12 ++++++---- 2 files changed, 24 insertions(+), 17 deletions(-) diff --git a/xarray/core/dask_array_expr.py b/xarray/core/dask_array_expr.py index 96f81095063..1e0290449f0 100644 --- a/xarray/core/dask_array_expr.py +++ b/xarray/core/dask_array_expr.py @@ -2,6 +2,7 @@ import itertools from collections.abc import Callable, Hashable, Mapping, Sequence +from importlib import import_module from typing import Any from xarray.core.coordinates import Coordinates @@ -12,11 +13,12 @@ def is_dask_array_expr_array(data: Any) -> bool: try: - from dask_array import Array + dask_array = import_module("dask_array") except ImportError: return False - return isinstance(data, Array) + array_type = getattr(dask_array, "Array", None) + return array_type is not None and isinstance(data, array_type) def collect_dask_array_expr_chunked_data( @@ -105,7 +107,7 @@ def map_blocks_with_dask_array_expr( raise NotImplementedError( "dask_array-backed xarray.map_blocks does not yet support " "dropping multi-chunk dimensions. Rechunk these dimensions to " - f"one chunk first: {sorted(missing_chunked_dims)!r}." + f"one chunk first: {sorted(missing_chunked_dims, key=repr)!r}." ) from xarray.namedarray.parallelcompat import get_chunked_array_type @@ -117,9 +119,9 @@ def map_blocks_with_dask_array_expr( "The dask_array chunk manager does not support map_blocks_multi_output." ) - input_exprs = [] - input_indices = [] - arg_templates = [] + input_exprs: list[Any] = [] + input_indices: list[Any] = [] + arg_templates: list[Any] = [] for isxr, arg in zip(is_xarray, npargs, strict=True): if not isxr: if is_dask_collection(arg): @@ -130,11 +132,12 @@ def map_blocks_with_dask_array_expr( arg_templates.append(("literal", arg)) continue - variable_templates = [] + variable_templates: list[Any] = [] for name, variable in arg.variables.items(): is_coord = name in arg._coord_names if is_dask_collection(variable.data): - input_exprs.append(variable.data.expr) + data: Any = variable.data + input_exprs.append(data.expr) input_indices.append(variable.dims) variable_templates.append( ( @@ -161,19 +164,19 @@ def map_blocks_with_dask_array_expr( ) arg_templates.append(("xarray", arg.attrs, variable_templates)) - def build_block_specs(): - specs = {} + def build_block_specs() -> dict[tuple[Any, ...], dict[str, Any]]: + specs: dict[tuple[Any, ...], dict[str, Any]] = {} for chunk_tuple in itertools.product(*ichunk.values()): chunk_index = dict(zip(ichunk.keys(), chunk_tuple, strict=True)) - arg_specs = [] + arg_specs: list[Any] = [] for arg_template in arg_templates: if arg_template[0] == "literal": arg_specs.append(arg_template) continue _, attrs, variable_templates = arg_template - data_vars = [] - coords = [] + data_vars: list[Any] = [] + coords: list[Any] = [] for ( kind, name, diff --git a/xarray/tests/test_dask_expr_protocol.py b/xarray/tests/test_dask_expr_protocol.py index c2a1bd904e7..311bc23c615 100644 --- a/xarray/tests/test_dask_expr_protocol.py +++ b/xarray/tests/test_dask_expr_protocol.py @@ -1,5 +1,8 @@ from __future__ import annotations +from importlib import import_module +from typing import Any + import numpy as np import pytest @@ -12,9 +15,10 @@ da = pytest.importorskip("dask.array") dask_array = pytest.importorskip("dask_array") -try: - from dask._expr import CompositeExpr, HLGExpr -except ImportError: +dask_expr = pytest.importorskip("dask._expr") +CompositeExpr: Any = getattr(dask_expr, "CompositeExpr", None) +HLGExpr: Any = getattr(dask_expr, "HLGExpr", None) +if CompositeExpr is None or HLGExpr is None: pytest.skip("requires Dask composite expressions", allow_module_level=True) @@ -120,7 +124,7 @@ def test_open_mfdataset_map_blocks_end_to_end(tmp_path): @requires_scipy_or_netCDF4 def test_open_dataset_rechunk_optimization_crosses_composite_expr(tmp_path): - from dask_array._rechunk import Rechunk + Rechunk = import_module("dask_array._rechunk").Rechunk path = tmp_path / "data.nc" Dataset({"x": ("t", np.arange(12))}, coords={"t": np.arange(12)}).to_netcdf(path) From 2804b6ac6af674a4248d46498ff7c6e0e882dd70 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Wed, 17 Jun 2026 10:45:42 -0700 Subject: [PATCH 6/9] Address expression protocol review comments Co-Authored-By: Codex --- pixi.toml | 3 +++ xarray/compat/dask_array_ops.py | 10 +++----- xarray/core/dask_array_expr.py | 20 +++++++++++++++ xarray/core/dataset.py | 45 ++++++++++++++------------------- 4 files changed, 46 insertions(+), 32 deletions(-) diff --git a/pixi.toml b/pixi.toml index 3e429f6f6eb..b721d02e4cf 100644 --- a/pixi.toml +++ b/pixi.toml @@ -64,6 +64,9 @@ numbagg = "*" dask = "*" distributed = "*" +[feature.dask.pypi-dependencies] +dask-array = { git = "https://github.com/mrocklin/dask-array" } + [feature.accel.dependencies] flox = "*" bottleneck = "*" diff --git a/xarray/compat/dask_array_ops.py b/xarray/compat/dask_array_ops.py index ad52a175bd3..e0ffaf3bd7c 100644 --- a/xarray/compat/dask_array_ops.py +++ b/xarray/compat/dask_array_ops.py @@ -101,10 +101,8 @@ def push(array, n, axis, method="blelloch"): from xarray.core.nputils import nanlast from xarray.namedarray.parallelcompat import get_chunked_array_type - da = get_chunked_array_type(array).array_api - cumreduction = getattr(da, "cumreduction", None) - if cumreduction is None: - cumreduction = da.reductions.cumreduction + chunkmanager = get_chunked_array_type(array) + da = chunkmanager.array_api if n is not None and all(n <= size for size in array.chunks[axis]): return array.map_overlap(_push, depth={axis: (n, 0)}, n=n, axis=axis) @@ -112,11 +110,11 @@ def push(array, n, axis, method="blelloch"): # TODO: Replace all this function # once https://github.com/pydata/xarray/issues/9229 being implemented - pushed_array = cumreduction( + pushed_array = chunkmanager.scan( func=_dtype_push, binop=_fill_with_last_one, ident=np.nan, - x=array, + arr=array, axis=axis, dtype=array.dtype, method=method, diff --git a/xarray/core/dask_array_expr.py b/xarray/core/dask_array_expr.py index 1e0290449f0..dfa4ecfc044 100644 --- a/xarray/core/dask_array_expr.py +++ b/xarray/core/dask_array_expr.py @@ -1,3 +1,23 @@ +"""Optional xarray integration for ``dask_array`` expression-backed arrays. + +xarray stays expression-free. ``dask_array`` owns the lazy array expression +objects, while xarray owns Dataset/DataArray semantics such as coordinates, +indexes, attrs, template validation, and rebuilding final xarray objects. + +Most xarray operations only need the normal chunk-manager methods. The special +case here is ``xarray.map_blocks``: it can return multiple output variables from +one user function call per input block. The helper below converts xarray's block +metadata into a private ``dask_array`` multi-output map expression. Each output +variable is still a normal ``dask_array.Array`` child expression, so Dask can +group the children with the composite-collection protocol and ``dask_array`` can +optimize, cull, persist, and compute those arrays. + +If a Dataset mixes ``dask_array.Array`` with legacy ``dask.array.Array`` objects, +this path raises before constructing a graph. The generic Dataset expression +protocol instead returns ``None`` for mixed datasets so Dask can use xarray's +existing HighLevelGraph fallback. +""" + from __future__ import annotations import itertools diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 68681f1955e..e53fe6accb9 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -648,13 +648,10 @@ def __dask_graph__(self): if not graphs: return None else: - try: - from dask.highlevelgraph import HighLevelGraph + from dask.highlevelgraph import HighLevelGraph - if all(isinstance(graph, HighLevelGraph) for graph in graphs.values()): - return HighLevelGraph.merge(*graphs.values()) - except ImportError: - pass + if all(isinstance(graph, HighLevelGraph) for graph in graphs.values()): + return HighLevelGraph.merge(*graphs.values()) from dask.utils import ensure_dict @@ -681,6 +678,10 @@ def __dask_exprs__(self): for v in self.variables.values(): if dask.is_dask_collection(v): if not is_dask_array_expr_array(v._data): + # Composite expressions must account for every Dask-backed + # variable. Returning None lets Dask fall back to the + # existing HighLevelGraph path for mixed legacy/expression + # datasets rather than failing during protocol discovery. return None exprs.append(v._data.expr) return exprs or None @@ -689,27 +690,19 @@ def __dask_rebuild_from_exprs__(self, exprs): import dask from dask._collections import new_collection - exprs_iter = iter(exprs) - variables = {} - - for k, v in self._variables.items(): - if dask.is_dask_collection(v): - try: - expr = next(exprs_iter) - except StopIteration as err: - raise ValueError( - "Not enough expressions to rebuild Dataset" - ) from err - variables[k] = v._replace(data=new_collection(expr)) - else: - variables[k] = v + dask_variables = [ + (k, v) for k, v in self._variables.items() if dask.is_dask_collection(v) + ] + exprs = list(exprs) + if len(exprs) != len(dask_variables): + raise ValueError( + f"Expected {len(dask_variables)} expressions to rebuild Dataset, " + f"got {len(exprs)}" + ) - try: - next(exprs_iter) - except StopIteration: - pass - else: - raise ValueError("Too many expressions to rebuild Dataset") + variables = dict(self._variables) + for (k, v), expr in zip(dask_variables, exprs, strict=True): + variables[k] = v._replace(data=new_collection(expr)) return type(self)._construct_direct( variables, From 584e6117906ce549487235efa7fcaee11de72afd Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Wed, 17 Jun 2026 10:47:53 -0700 Subject: [PATCH 7/9] Defer dask_array map_blocks integration Co-Authored-By: Codex --- xarray/core/dask_array_expr.py | 281 ------------------------ xarray/core/parallel.py | 42 +--- xarray/tests/test_dask_expr_protocol.py | 105 +-------- 3 files changed, 2 insertions(+), 426 deletions(-) diff --git a/xarray/core/dask_array_expr.py b/xarray/core/dask_array_expr.py index dfa4ecfc044..db1e586624c 100644 --- a/xarray/core/dask_array_expr.py +++ b/xarray/core/dask_array_expr.py @@ -1,35 +1,8 @@ -"""Optional xarray integration for ``dask_array`` expression-backed arrays. - -xarray stays expression-free. ``dask_array`` owns the lazy array expression -objects, while xarray owns Dataset/DataArray semantics such as coordinates, -indexes, attrs, template validation, and rebuilding final xarray objects. - -Most xarray operations only need the normal chunk-manager methods. The special -case here is ``xarray.map_blocks``: it can return multiple output variables from -one user function call per input block. The helper below converts xarray's block -metadata into a private ``dask_array`` multi-output map expression. Each output -variable is still a normal ``dask_array.Array`` child expression, so Dask can -group the children with the composite-collection protocol and ``dask_array`` can -optimize, cull, persist, and compute those arrays. - -If a Dataset mixes ``dask_array.Array`` with legacy ``dask.array.Array`` objects, -this path raises before constructing a graph. The generic Dataset expression -protocol instead returns ``None`` for mixed datasets so Dask can use xarray's -existing HighLevelGraph fallback. -""" - from __future__ import annotations -import itertools -from collections.abc import Callable, Hashable, Mapping, Sequence from importlib import import_module from typing import Any -from xarray.core.coordinates import Coordinates -from xarray.core.dataarray import DataArray -from xarray.core.dataset import Dataset -from xarray.core.utils import is_dask_collection - def is_dask_array_expr_array(data: Any) -> bool: try: @@ -39,257 +12,3 @@ def is_dask_array_expr_array(data: Any) -> bool: array_type = getattr(dask_array, "Array", None) return array_type is not None and isinstance(data, array_type) - - -def collect_dask_array_expr_chunked_data( - xarray_objs: Sequence[Dataset], -) -> tuple[bool, list[Any]]: - chunked_data = [ - variable.data - for arg in xarray_objs - for variable in arg.variables.values() - if is_dask_collection(variable.data) - ] - has_dask_array_expr = any(is_dask_array_expr_array(data) for data in chunked_data) - has_other_chunked = any(not is_dask_array_expr_array(data) for data in chunked_data) - - if has_dask_array_expr and has_other_chunked: - raise TypeError( - "xarray.map_blocks cannot mix dask_array.Array with legacy or other " - "Dask-backed arrays. Convert inputs to one array backend first." - ) - - return has_dask_array_expr, chunked_data - - -def _execute_map_blocks_multi_output(block_spec: Mapping[str, Any], *blocks: Any): - args = [] - for arg_spec in block_spec["args"]: - if arg_spec[0] == "literal": - args.append(arg_spec[1]) - continue - - _, data_vars, coords, attrs = arg_spec - - def build_variables(variable_specs): - variables = [] - for name, dims, data, var_attrs in variable_specs: - if data[0] == "block": - data = blocks[data[1]] - else: - data = data[1] - variables.append((name, (dims, data, var_attrs))) - return dict(variables) - - args.append(Dataset(build_variables(data_vars), build_variables(coords), attrs)) - - return block_spec["wrapper"]( - block_spec["func"], - args, - block_spec["kwargs"], - block_spec["arg_is_array"], - block_spec["expected"], - block_spec["expected_indexes"], - ) - - -def map_blocks_with_dask_array_expr( - *, - func: Callable[..., Any], - npargs: Sequence[Any], - kwargs: Mapping[str, Any], - is_xarray: Sequence[bool], - is_array: Sequence[bool], - input_chunks: Mapping[Hashable, tuple[int, ...]], - output_chunks: Mapping[Hashable, tuple[int, ...]], - coordinates: Coordinates, - template: Dataset, - result_is_array: bool, - template_name: Hashable | None, - gname: str, - ichunk: Mapping[Hashable, range], - input_chunk_bounds: Mapping[Hashable, Any], - output_chunk_bounds: Mapping[Hashable, Any], - computed_variables: Sequence[Hashable], - new_indexes: set[Hashable], - modified_indexes: set[Hashable], - chunked_data: Sequence[Any], - wrapper: Callable[..., Any], - get_chunk_slicer: Callable[[Hashable, Mapping[Any, Any], Mapping[Any, Any]], slice], - dataset_to_dataarray: Callable[[Dataset], DataArray], -) -> DataArray | Dataset: - missing_chunked_dims = { - dim - for dim, chunks in input_chunks.items() - if len(chunks) > 1 and dim not in output_chunks - } - if missing_chunked_dims: - raise NotImplementedError( - "dask_array-backed xarray.map_blocks does not yet support " - "dropping multi-chunk dimensions. Rechunk these dimensions to " - f"one chunk first: {sorted(missing_chunked_dims, key=repr)!r}." - ) - - from xarray.namedarray.parallelcompat import get_chunked_array_type - - chunkmanager = get_chunked_array_type(*chunked_data) - map_blocks_multi_output = getattr(chunkmanager, "map_blocks_multi_output", None) - if map_blocks_multi_output is None: - raise NotImplementedError( - "The dask_array chunk manager does not support map_blocks_multi_output." - ) - - input_exprs: list[Any] = [] - input_indices: list[Any] = [] - arg_templates: list[Any] = [] - for isxr, arg in zip(is_xarray, npargs, strict=True): - if not isxr: - if is_dask_collection(arg): - raise TypeError( - "dask_array-backed xarray.map_blocks only supports Dask " - "collections inside xarray arguments." - ) - arg_templates.append(("literal", arg)) - continue - - variable_templates: list[Any] = [] - for name, variable in arg.variables.items(): - is_coord = name in arg._coord_names - if is_dask_collection(variable.data): - data: Any = variable.data - input_exprs.append(data.expr) - input_indices.append(variable.dims) - variable_templates.append( - ( - "chunked", - name, - variable.dims, - variable.attrs, - len(input_exprs) - 1, - is_coord, - None, - ) - ) - else: - variable_templates.append( - ( - "static", - name, - variable.dims, - variable.attrs, - None, - is_coord, - variable, - ) - ) - arg_templates.append(("xarray", arg.attrs, variable_templates)) - - def build_block_specs() -> dict[tuple[Any, ...], dict[str, Any]]: - specs: dict[tuple[Any, ...], dict[str, Any]] = {} - for chunk_tuple in itertools.product(*ichunk.values()): - chunk_index = dict(zip(ichunk.keys(), chunk_tuple, strict=True)) - arg_specs: list[Any] = [] - for arg_template in arg_templates: - if arg_template[0] == "literal": - arg_specs.append(arg_template) - continue - - _, attrs, variable_templates = arg_template - data_vars: list[Any] = [] - coords: list[Any] = [] - for ( - kind, - name, - dims, - attrs, - input_position, - is_coord, - variable, - ) in variable_templates: - if kind == "chunked": - data = ("block", input_position) - else: - assert variable is not None - subsetter = { - dim: get_chunk_slicer(dim, chunk_index, input_chunk_bounds) - for dim in variable.dims - } - data = ("static", variable.isel(subsetter)._data) - - target = coords if is_coord else data_vars - target.append((name, dims, data, attrs)) - - arg_specs.append(("xarray", data_vars, coords, attrs)) - - indexes = { - dim: coordinates.xindexes[dim][ - get_chunk_slicer(dim, chunk_index, output_chunk_bounds) - ] - for dim in (new_indexes | modified_indexes) - } - expected = { - "shapes": { - k: output_chunks[k][v] - for k, v in chunk_index.items() - if k in output_chunks - }, - "data_vars": set(template.data_vars.keys()), - "coords": set(template.coords.keys()), - } - specs[chunk_tuple] = { - "wrapper": wrapper, - "func": func, - "args": arg_specs, - "kwargs": kwargs, - "arg_is_array": is_array, - "expected": expected, - "expected_indexes": indexes, - } - return specs - - outputs = [] - for name in computed_variables: - variable = template.variables[name] - var_chunks = [] - for dim in variable.dims: - if dim in output_chunks: - var_chunks.append(output_chunks[dim]) - elif dim in template.dims: - var_chunks.append((template.sizes[dim],)) - - outputs.append( - { - "key": name, - "indices": variable.dims, - "chunks": tuple(var_chunks), - "dtype": variable.dtype, - "name": f"{name}-{gname}", - } - ) - - mapped_arrays = map_blocks_multi_output( - _execute_map_blocks_multi_output, - input_exprs, - input_indices, - tuple(input_chunks), - build_block_specs(), - outputs, - token=gname, - ) - - result = Dataset(coords=coordinates, attrs=template.attrs) - for index in result._indexes: - result[index].attrs = template[index].attrs - result[index].encoding = template[index].encoding - - for name, data in zip(computed_variables, mapped_arrays, strict=True): - result[name] = (template[name].dims, data, template[name].attrs) - result[name].encoding = template[name].encoding - - result = result.set_coords(template._coord_names) - - if result_is_array: - da = dataset_to_dataarray(result) - da.name = template_name - return da - return result diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index bb500657eed..fea3f44d8eb 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -441,16 +441,8 @@ def _wrapper( for arg in aligned ) - from xarray.core.dask_array_expr import collect_dask_array_expr_chunked_data - - has_dask_array_expr, chunked_data = collect_dask_array_expr_chunked_data( - xarray_objs - ) - # rechunk any numpy variables appropriately xarray_objs = tuple(arg.chunk(arg.chunksizes) for arg in xarray_objs) - if has_dask_array_expr: - _, chunked_data = collect_dask_array_expr_chunked_data(xarray_objs) merged_coordinates = merge( [arg.coords for arg in aligned], @@ -530,7 +522,6 @@ def _wrapper( template = template._to_temp_dataset() elif isinstance(template, Dataset): result_is_array = False - template_name = None else: raise TypeError( f"func output must be DataArray or Dataset; got {type(template)}" @@ -556,38 +547,7 @@ def _wrapper( dim: np.cumsum((0,) + chunks_v) for dim, chunks_v in output_chunks.items() } - computed_variables = [ - name for name in template.variables if name not in coordinates.indexes - ] - - if has_dask_array_expr: - from xarray.core.dask_array_expr import map_blocks_with_dask_array_expr - - return map_blocks_with_dask_array_expr( - func=func, - npargs=npargs, - kwargs=kwargs, - is_xarray=is_xarray, - is_array=is_array, - input_chunks=input_chunks, - output_chunks=output_chunks, - coordinates=coordinates, - template=template, - result_is_array=result_is_array, - template_name=template_name, - gname=gname, - ichunk=ichunk, - input_chunk_bounds=input_chunk_bounds, - output_chunk_bounds=output_chunk_bounds, - computed_variables=computed_variables, - new_indexes=new_indexes, - modified_indexes=modified_indexes, - chunked_data=chunked_data, - wrapper=_wrapper, - get_chunk_slicer=_get_chunk_slicer, - dataset_to_dataarray=dataset_to_dataarray, - ) # type: ignore[return-value] - + computed_variables = set(template.variables) - set(coordinates.indexes) # iterate over all possible chunk combinations for chunk_tuple in itertools.product(*ichunk.values()): # mapping from dimension name to chunk index diff --git a/xarray/tests/test_dask_expr_protocol.py b/xarray/tests/test_dask_expr_protocol.py index 311bc23c615..f01f5811c81 100644 --- a/xarray/tests/test_dask_expr_protocol.py +++ b/xarray/tests/test_dask_expr_protocol.py @@ -86,7 +86,7 @@ def test_dataset_compute_persist_optimize_end_to_end(): @requires_scipy_or_netCDF4 -def test_open_mfdataset_map_blocks_end_to_end(tmp_path): +def test_open_mfdataset_end_to_end(tmp_path): paths = [] for i in range(2): path = tmp_path / f"part-{i}.nc" @@ -113,14 +113,6 @@ def test_open_mfdataset_map_blocks_end_to_end(tmp_path): dask.optimize(ds)[0].compute(scheduler="single-threaded"), expected ) - mapped = xr.map_blocks(lambda block: block + 1, ds) - assert isinstance(mapped["x"].data, dask_array.Array) - assert isinstance(dask.base.collections_to_expr(mapped), CompositeExpr) - assert_identical( - mapped.compute(scheduler="single-threaded"), - Dataset({"x": ("t", np.arange(6) + 1)}, coords={"t": np.arange(6)}), - ) - @requires_scipy_or_netCDF4 def test_open_dataset_rechunk_optimization_crosses_composite_expr(tmp_path): @@ -155,93 +147,6 @@ def test_open_dataset_rechunk_optimization_crosses_composite_expr(tmp_path): ) -def test_map_blocks_dataarray_end_to_end(): - x = dask_array.arange(6, chunks=(3,)) - arr = DataArray(x, dims="t", name="x") - out = xr.map_blocks(lambda block: block + 1, arr) - - assert isinstance(out.data, dask_array.Array) - assert isinstance(dask.base.collections_to_expr(out), CompositeExpr) - - expected = DataArray(np.arange(6) + 1, dims="t", name="x") - assert_identical(out.compute(scheduler="single-threaded"), expected) - assert_identical( - out.persist(scheduler="single-threaded").compute(scheduler="single-threaded"), - expected, - ) - assert_identical( - dask.optimize(out)[0].compute(scheduler="single-threaded"), expected - ) - - -def test_map_blocks_dataset_outputs_share_block_calls(): - calls = [] - x = dask_array.arange(6, chunks=(3,)) - ds = Dataset({"x": ("t", x)}, coords={"qc": ("t", x + 10)}) - template = Dataset( - {"a": ("t", x), "b": ("t", x)}, - coords={"qc": ("t", x + 10)}, - attrs={"kind": "mapped"}, - ) - - def func(block): - calls.append(block.sizes["t"]) - return Dataset( - {"a": block["x"] + 1, "b": block["x"] + 2}, - coords={"qc": block["qc"]}, - attrs={"kind": "mapped"}, - ) - - out = xr.map_blocks(func, ds, template=template) - - assert isinstance(dask.base.collections_to_expr(out), CompositeExpr) - assert all(isinstance(out[name].data, dask_array.Array) for name in out) - assert sorted(calls) == [] - - expected = Dataset( - {"a": ("t", np.arange(6) + 1), "b": ("t", np.arange(6) + 2)}, - coords={"qc": ("t", np.arange(6) + 10)}, - attrs={"kind": "mapped"}, - ) - assert_identical(out.compute(scheduler="single-threaded"), expected) - assert sorted(calls) == [3, 3] - - -def test_map_blocks_preserves_scalar_coords(): - x = dask_array.arange(6, chunks=(3,)).reshape((3, 2)) - arr = DataArray( - x, - dims=("x", "y"), - coords={"label": ("x", ["a", "b", "c"]), "scale": 2}, - name="z", - ) - - out = xr.map_blocks(lambda block: block + block.scale, arr) - - assert isinstance(dask.base.collections_to_expr(out), CompositeExpr) - assert_identical( - out.compute(scheduler="single-threaded"), - DataArray( - np.arange(6).reshape((3, 2)) + 2, - dims=("x", "y"), - coords={"label": ("x", ["a", "b", "c"]), "scale": 2}, - ), - ) - - -def test_map_blocks_reduces_single_chunk_dimension(): - x = dask_array.arange(12, chunks=(12,)).reshape((3, 4)).rechunk((3, 2)) - arr = DataArray(x, dims=("x", "y"), name="z") - - out = xr.map_blocks(lambda block: block.sum("x"), arr) - - assert isinstance(dask.base.collections_to_expr(out), CompositeExpr) - assert_identical( - out.compute(scheduler="single-threaded"), - DataArray(np.arange(12).reshape(3, 4).sum(axis=0), dims="y", name="z"), - ) - - def test_mixed_legacy_inputs_do_not_use_composite_path(): ds = Dataset( { @@ -253,14 +158,6 @@ def test_mixed_legacy_inputs_do_not_use_composite_path(): assert ds.__dask_exprs__() is None assert isinstance(dask.base.collections_to_expr(ds), HLGExpr) - with pytest.raises(TypeError, match=r"cannot mix dask_array\.Array"): - xr.map_blocks(lambda block: block + 1, ds) - - arr = DataArray(dask_array.arange(6, chunks=(3,)), dims="i") - other = DataArray(da.arange(6, chunks=(3,)), dims="i") - with pytest.raises(TypeError, match=r"cannot mix dask_array\.Array"): - xr.map_blocks(lambda a, b: a + b, arr, args=[other]) - def test_shared_subexpressions_optimize_without_cross_contamination(): from dask.core import flatten From 566da388d2c48707a68a9b8263fcc1f645927b99 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Wed, 17 Jun 2026 11:35:57 -0700 Subject: [PATCH 8/9] Inline dask-array expression check --- xarray/core/dask_array_expr.py | 14 -------------- xarray/core/dataset.py | 13 ++++++++----- xarray/tests/test_dask_expr_protocol.py | 20 ++++++++++++++++++++ 3 files changed, 28 insertions(+), 19 deletions(-) delete mode 100644 xarray/core/dask_array_expr.py diff --git a/xarray/core/dask_array_expr.py b/xarray/core/dask_array_expr.py deleted file mode 100644 index db1e586624c..00000000000 --- a/xarray/core/dask_array_expr.py +++ /dev/null @@ -1,14 +0,0 @@ -from __future__ import annotations - -from importlib import import_module -from typing import Any - - -def is_dask_array_expr_array(data: Any) -> bool: - try: - dask_array = import_module("dask_array") - except ImportError: - return False - - array_type = getattr(dask_array, "Array", None) - return array_type is not None and isinstance(data, array_type) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index e53fe6accb9..68a9a6349b8 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -672,16 +672,19 @@ def __dask_keys__(self): def __dask_exprs__(self): import dask - from xarray.core.dask_array_expr import is_dask_array_expr_array + try: + from dask_array import Array as DaskArray + except ImportError: + return None exprs = [] for v in self.variables.values(): if dask.is_dask_collection(v): - if not is_dask_array_expr_array(v._data): + if not isinstance(v._data, DaskArray): # Composite expressions must account for every Dask-backed - # variable. Returning None lets Dask fall back to the - # existing HighLevelGraph path for mixed legacy/expression - # datasets rather than failing during protocol discovery. + # variable. Returning None keeps Dask's collection APIs on + # the existing HighLevelGraph path for mixed + # legacy/expression datasets. return None exprs.append(v._data.expr) return exprs or None diff --git a/xarray/tests/test_dask_expr_protocol.py b/xarray/tests/test_dask_expr_protocol.py index f01f5811c81..924faa300ef 100644 --- a/xarray/tests/test_dask_expr_protocol.py +++ b/xarray/tests/test_dask_expr_protocol.py @@ -154,9 +154,29 @@ def test_mixed_legacy_inputs_do_not_use_composite_path(): "legacy": ("i", da.arange(3, chunks=(1,))), } ) + expected = Dataset( + {"x": ("i", np.arange(3)), "legacy": ("i", np.arange(3))}, + ) assert ds.__dask_exprs__() is None assert isinstance(dask.base.collections_to_expr(ds), HLGExpr) + assert_identical(dask.compute(ds, scheduler="single-threaded")[0], expected) + + persisted = dask.persist(ds, scheduler="single-threaded")[0] + assert isinstance(persisted["x"].data, dask_array.Array) + assert isinstance(persisted["legacy"].data, da.Array) + assert_identical( + dask.compute(persisted, scheduler="single-threaded")[0], + expected, + ) + + optimized = dask.optimize(ds)[0] + assert isinstance(optimized["x"].data, dask_array.Array) + assert isinstance(optimized["legacy"].data, da.Array) + assert_identical( + dask.compute(optimized, scheduler="single-threaded")[0], + expected, + ) def test_shared_subexpressions_optimize_without_cross_contamination(): From e85640440b2c02d55c5dfaf581be5b78d185b86c Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Wed, 17 Jun 2026 15:46:05 -0700 Subject: [PATCH 9/9] Refine dask-array expression follow-ups Keep the dask-array chunk-manager fixes in xarray while dropping the dedicated dask-array CI environment. This leaves map_blocks out of scope, keeps optional dask_array discovery localized, and updates the groupby expectation now that it remains expression-backed. Co-Authored-By: Codex --- pixi.toml | 3 --- xarray/compat/dask_array_ops.py | 1 - xarray/core/dataset.py | 4 +++- xarray/tests/test_dask_expr_protocol.py | 13 ++++++++----- 4 files changed, 11 insertions(+), 10 deletions(-) diff --git a/pixi.toml b/pixi.toml index b721d02e4cf..3e429f6f6eb 100644 --- a/pixi.toml +++ b/pixi.toml @@ -64,9 +64,6 @@ numbagg = "*" dask = "*" distributed = "*" -[feature.dask.pypi-dependencies] -dask-array = { git = "https://github.com/mrocklin/dask-array" } - [feature.accel.dependencies] flox = "*" bottleneck = "*" diff --git a/xarray/compat/dask_array_ops.py b/xarray/compat/dask_array_ops.py index e0ffaf3bd7c..e803b219a29 100644 --- a/xarray/compat/dask_array_ops.py +++ b/xarray/compat/dask_array_ops.py @@ -99,7 +99,6 @@ def push(array, n, axis, method="blelloch"): from xarray.core.duck_array_ops import _push from xarray.core.nputils import nanlast - from xarray.namedarray.parallelcompat import get_chunked_array_type chunkmanager = get_chunked_array_type(array) da = chunkmanager.array_api diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 68a9a6349b8..1ce84904623 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -670,10 +670,12 @@ def __dask_keys__(self): ] def __dask_exprs__(self): + from importlib import import_module + import dask try: - from dask_array import Array as DaskArray + DaskArray = import_module("dask_array").Array except ImportError: return None diff --git a/xarray/tests/test_dask_expr_protocol.py b/xarray/tests/test_dask_expr_protocol.py index 924faa300ef..b40dff53178 100644 --- a/xarray/tests/test_dask_expr_protocol.py +++ b/xarray/tests/test_dask_expr_protocol.py @@ -220,7 +220,7 @@ def test_apply_ufunc_parallelized_uses_composite_expr(): ) -def test_groupby_sum_currently_falls_back_after_legacy_conversion(): +def test_groupby_sum_uses_composite_expr(): x = dask_array.arange(6, chunks=(3,)) arr = DataArray( x, @@ -229,8 +229,11 @@ def test_groupby_sum_currently_falls_back_after_legacy_conversion(): name="x", ) - with pytest.warns(UserWarning, match="already a Dask collection"): - out = arr.groupby("label").sum() + out = arr.groupby("label").sum() - assert not isinstance(out.data, dask_array.Array) - assert isinstance(dask.base.collections_to_expr(out), HLGExpr) + assert isinstance(out.data, dask_array.Array) + assert isinstance(dask.base.collections_to_expr(out), CompositeExpr) + assert_identical( + out.compute(scheduler="single-threaded"), + DataArray([5, 10], dims="label", coords={"label": ["a", "b"]}, name="x"), + )