diff --git a/xarray/compat/dask_array_compat.py b/xarray/compat/dask_array_compat.py index b8c7da3e64f..e1648c4e34b 100644 --- a/xarray/compat/dask_array_compat.py +++ b/xarray/compat/dask_array_compat.py @@ -1,5 +1,6 @@ from typing import Any +from xarray.namedarray.parallelcompat import get_chunked_array_type from xarray.namedarray.utils import module_available @@ -8,10 +9,17 @@ def reshape_blockwise( shape: int | tuple[int, ...], chunks: tuple[tuple[int, ...], ...] | None = None, ): - if module_available("dask", "2024.08.2"): - from dask.array import reshape_blockwise + try: + array_api = get_chunked_array_type(x).array_api + except TypeError: + array_api = None - return reshape_blockwise(x, shape=shape, chunks=chunks) + if array_api is not None and hasattr(array_api, "reshape_blockwise"): + return array_api.reshape_blockwise(x, shape=shape, chunks=chunks) + elif module_available("dask", "2024.08.2"): + from dask.array import reshape_blockwise as dask_reshape_blockwise + + return dask_reshape_blockwise(x, shape=shape, chunks=chunks) else: return x.reshape(shape) diff --git a/xarray/compat/dask_array_ops.py b/xarray/compat/dask_array_ops.py index 9534351dbfd..e803b219a29 100644 --- a/xarray/compat/dask_array_ops.py +++ b/xarray/compat/dask_array_ops.py @@ -4,6 +4,7 @@ from xarray.compat.dask_array_compat import reshape_blockwise from xarray.core import dtypes, nputils +from xarray.namedarray.parallelcompat import get_chunked_array_type def dask_rolling_wrapper(moving_func, a, window, min_count=None, axis=-1): @@ -20,7 +21,7 @@ def dask_rolling_wrapper(moving_func, a, window, min_count=None, axis=-1): def least_squares(lhs, rhs, rcond=None, skipna=False): - import dask.array as da + da = get_chunked_array_type(rhs).array_api # The trick here is that the core dimension is axis 0. # All other dimensions need to be reshaped down to one axis for `lstsq` @@ -94,23 +95,25 @@ def push(array, n, axis, method="blelloch"): """ Dask-aware bottleneck.push """ - import dask.array as da import numpy as np from xarray.core.duck_array_ops import _push from xarray.core.nputils import nanlast + chunkmanager = get_chunked_array_type(array) + da = chunkmanager.array_api + if n is not None and all(n <= size for size in array.chunks[axis]): return array.map_overlap(_push, depth={axis: (n, 0)}, n=n, axis=axis) # TODO: Replace all this function # once https://github.com/pydata/xarray/issues/9229 being implemented - pushed_array = da.reductions.cumreduction( + pushed_array = chunkmanager.scan( func=_dtype_push, binop=_fill_with_last_one, ident=np.nan, - x=array, + arr=array, axis=axis, dtype=array.dtype, method=method, diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index d0df9bc061b..bdba5cb556a 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -1113,6 +1113,13 @@ def __dask_graph__(self): def __dask_keys__(self): return self._to_temp_dataset().__dask_keys__() + def __dask_exprs__(self): + return self._to_temp_dataset().__dask_exprs__() + + def __dask_rebuild_from_exprs__(self, exprs): + ds = self._to_temp_dataset().__dask_rebuild_from_exprs__(exprs) + return self._from_temp_dataset(ds) + def __dask_layers__(self): return self._to_temp_dataset().__dask_layers__() diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 03a0b594a13..1ce84904623 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -648,14 +648,17 @@ def __dask_graph__(self): if not graphs: return None else: - try: - from dask.highlevelgraph import HighLevelGraph + from dask.highlevelgraph import HighLevelGraph + if all(isinstance(graph, HighLevelGraph) for graph in graphs.values()): return HighLevelGraph.merge(*graphs.values()) - except ImportError: - from dask import sharedict - return sharedict.merge(*graphs.values()) + from dask.utils import ensure_dict + + merged = {} + for graph in graphs.values(): + merged.update(ensure_dict(graph)) + return merged def __dask_keys__(self): import dask @@ -666,6 +669,56 @@ def __dask_keys__(self): if dask.is_dask_collection(v) ] + def __dask_exprs__(self): + from importlib import import_module + + import dask + + try: + DaskArray = import_module("dask_array").Array + except ImportError: + return None + + exprs = [] + for v in self.variables.values(): + if dask.is_dask_collection(v): + if not isinstance(v._data, DaskArray): + # Composite expressions must account for every Dask-backed + # variable. Returning None keeps Dask's collection APIs on + # the existing HighLevelGraph path for mixed + # legacy/expression datasets. + return None + exprs.append(v._data.expr) + return exprs or None + + def __dask_rebuild_from_exprs__(self, exprs): + import dask + from dask._collections import new_collection + + dask_variables = [ + (k, v) for k, v in self._variables.items() if dask.is_dask_collection(v) + ] + exprs = list(exprs) + if len(exprs) != len(dask_variables): + raise ValueError( + f"Expected {len(dask_variables)} expressions to rebuild Dataset, " + f"got {len(exprs)}" + ) + + variables = dict(self._variables) + for (k, v), expr in zip(dask_variables, exprs, strict=True): + variables[k] = v._replace(data=new_collection(expr)) + + return type(self)._construct_direct( + variables, + self._coord_names, + self._dims, + self._attrs, + self._indexes, + self._encoding, + self._close, + ) + def __dask_layers__(self): import dask diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index a718ef5e911..fea3f44d8eb 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -440,6 +440,7 @@ def _wrapper( dataarray_to_dataset(arg) if isinstance(arg, DataArray) else arg for arg in aligned ) + # rechunk any numpy variables appropriately xarray_objs = tuple(arg.chunk(arg.chunksizes) for arg in xarray_objs) diff --git a/xarray/tests/test_dask_expr_protocol.py b/xarray/tests/test_dask_expr_protocol.py new file mode 100644 index 00000000000..b40dff53178 --- /dev/null +++ b/xarray/tests/test_dask_expr_protocol.py @@ -0,0 +1,239 @@ +from __future__ import annotations + +from importlib import import_module +from typing import Any + +import numpy as np +import pytest + +import xarray as xr +from xarray import DataArray, Dataset +from xarray.testing import assert_identical +from xarray.tests import requires_scipy_or_netCDF4 + +dask = pytest.importorskip("dask") +da = pytest.importorskip("dask.array") +dask_array = pytest.importorskip("dask_array") + +dask_expr = pytest.importorskip("dask._expr") +CompositeExpr: Any = getattr(dask_expr, "CompositeExpr", None) +HLGExpr: Any = getattr(dask_expr, "HLGExpr", None) +if CompositeExpr is None or HLGExpr is None: + pytest.skip("requires Dask composite expressions", allow_module_level=True) + + +def test_dataset_composite_expr_protocol_simple(): + x = dask_array.arange(3, chunks=(1,)) + ds = Dataset({"x": ("i", x + 1)}) + + expr = dask.base.collections_to_expr(ds) + + assert isinstance(expr, CompositeExpr) + assert len(expr.exprs) == 1 + assert_identical( + dask.compute(ds, scheduler="single-threaded")[0], + Dataset({"x": ("i", np.arange(3) + 1)}), + ) + + +def test_dataarray_composite_expr_protocol_includes_chunked_coord(): + x = dask_array.arange(6, chunks=(3,)) + arr = DataArray(x + 1, dims=("i",), coords={"coord": ("i", x + 10)}, name="z") + + expr = dask.base.collections_to_expr(arr) + + assert isinstance(expr, CompositeExpr) + assert len(expr.exprs) == 2 + assert_identical( + dask.compute(arr, scheduler="single-threaded")[0], + DataArray( + np.arange(6) + 1, + dims=("i",), + coords={"coord": ("i", np.arange(6) + 10)}, + name="z", + ), + ) + + +def test_dataset_compute_persist_optimize_end_to_end(): + x = dask_array.arange(6, chunks=(3,)) + ds = Dataset( + {"foo": ("i", x + 1)}, + coords={"coord": ("i", x + 10)}, + attrs={"source": "test"}, + ) + ds["foo"].encoding["example"] = "kept" + + expected = Dataset( + {"foo": ("i", np.arange(6) + 1)}, + coords={"coord": ("i", np.arange(6) + 10)}, + attrs={"source": "test"}, + ) + + computed_ds, computed_x = dask.compute(ds, x, scheduler="single-threaded") + assert_identical(computed_ds, expected) + np.testing.assert_array_equal(computed_x, np.arange(6)) + assert computed_ds["foo"].encoding["example"] == "kept" + + persisted = dask.persist(ds, scheduler="single-threaded")[0] + assert isinstance(persisted["foo"].data, dask_array.Array) + assert_identical(persisted.compute(), expected) + assert persisted["foo"].encoding["example"] == "kept" + + optimized = dask.optimize(ds)[0] + assert isinstance(optimized["foo"].data, dask_array.Array) + assert_identical(optimized.compute(), expected) + + +@requires_scipy_or_netCDF4 +def test_open_mfdataset_end_to_end(tmp_path): + paths = [] + for i in range(2): + path = tmp_path / f"part-{i}.nc" + Dataset( + {"x": ("t", np.arange(i * 3, i * 3 + 3))}, + coords={"t": np.arange(i * 3, i * 3 + 3)}, + ).to_netcdf(path) + paths.append(path) + + with xr.open_mfdataset(paths, chunks={"t": 2}, combine="by_coords") as ds: + assert isinstance(ds["x"].data, dask_array.Array) + assert isinstance(dask.base.collections_to_expr(ds), CompositeExpr) + + expected = Dataset({"x": ("t", np.arange(6))}, coords={"t": np.arange(6)}) + assert_identical(ds.compute(scheduler="single-threaded"), expected) + assert_identical(dask.compute(ds, scheduler="single-threaded")[0], expected) + assert_identical( + ds.persist(scheduler="single-threaded").compute( + scheduler="single-threaded" + ), + expected, + ) + assert_identical( + dask.optimize(ds)[0].compute(scheduler="single-threaded"), expected + ) + + +@requires_scipy_or_netCDF4 +def test_open_dataset_rechunk_optimization_crosses_composite_expr(tmp_path): + Rechunk = import_module("dask_array._rechunk").Rechunk + + path = tmp_path / "data.nc" + Dataset({"x": ("t", np.arange(12))}, coords={"t": np.arange(12)}).to_netcdf(path) + + with xr.open_dataset(path, chunks={"t": 3}) as ds: + out = ds.chunk({"t": 4}) + expr = dask.base.collections_to_expr(out) + source_expr = ds["x"].data.expr + expected_operands = list(source_expr.operands) + expected_operands[source_expr._parameters.index("_chunks")] = ((4, 4, 4),) + expected_expr = type(source_expr)(*expected_operands).optimize() + + assert isinstance(expr, CompositeExpr) + assert len(expr.exprs) == 1 + assert list(expr.exprs[0].find_operations(Rechunk)) + + optimized_expr = expr.optimize() + + assert optimized_expr.exprs[0]._name == expected_expr._name + assert not list(optimized_expr.exprs[0].find_operations(Rechunk)) + + optimized = dask.optimize(out)[0] + assert isinstance(optimized["x"].data, dask_array.Array) + assert not list(optimized["x"].data.expr.find_operations(Rechunk)) + assert_identical( + optimized.compute(scheduler="single-threaded"), + Dataset({"x": ("t", np.arange(12))}, coords={"t": np.arange(12)}), + ) + + +def test_mixed_legacy_inputs_do_not_use_composite_path(): + ds = Dataset( + { + "x": ("i", dask_array.arange(3, chunks=(1,))), + "legacy": ("i", da.arange(3, chunks=(1,))), + } + ) + expected = Dataset( + {"x": ("i", np.arange(3)), "legacy": ("i", np.arange(3))}, + ) + + assert ds.__dask_exprs__() is None + assert isinstance(dask.base.collections_to_expr(ds), HLGExpr) + assert_identical(dask.compute(ds, scheduler="single-threaded")[0], expected) + + persisted = dask.persist(ds, scheduler="single-threaded")[0] + assert isinstance(persisted["x"].data, dask_array.Array) + assert isinstance(persisted["legacy"].data, da.Array) + assert_identical( + dask.compute(persisted, scheduler="single-threaded")[0], + expected, + ) + + optimized = dask.optimize(ds)[0] + assert isinstance(optimized["x"].data, dask_array.Array) + assert isinstance(optimized["legacy"].data, da.Array) + assert_identical( + dask.compute(optimized, scheduler="single-threaded")[0], + expected, + ) + + +def test_shared_subexpressions_optimize_without_cross_contamination(): + from dask.core import flatten + + x = dask_array.arange(6, chunks=(3,)) + ds = Dataset({"foo": ("i", x + 1), "bar": ("i", x + 2)}) + + optimized = dask.optimize(ds)[0] + + foo_graph_keys = set(optimized["foo"].data.__dask_graph__()) + bar_output_keys = set(flatten(optimized["bar"].data.__dask_keys__())) + assert not foo_graph_keys & bar_output_keys + assert_identical( + optimized.compute(), + Dataset({"foo": ("i", np.arange(6) + 1), "bar": ("i", np.arange(6) + 2)}), + ) + + +def test_rechunk_reduction_chain_uses_composite_expr(): + x = dask_array.arange(12, chunks=(4,)).reshape((3, 4)) + out = Dataset({"x": (("a", "b"), x)}).chunk({"a": 1, "b": 2}).sum("b") + 1 + + assert isinstance(dask.base.collections_to_expr(out), CompositeExpr) + expected = Dataset({"x": ("a", np.arange(12).reshape(3, 4).sum(axis=1) + 1)}) + assert_identical(out.compute(scheduler="single-threaded"), expected) + assert_identical( + dask.optimize(out)[0].compute(scheduler="single-threaded"), expected + ) + + +def test_apply_ufunc_parallelized_uses_composite_expr(): + x = dask_array.arange(6, chunks=(3,)) + arr = DataArray(x, dims="t", name="x") + out = xr.apply_ufunc(lambda z: z + 2, arr, dask="parallelized", output_dtypes=[int]) + + assert isinstance(dask.base.collections_to_expr(out), CompositeExpr) + assert_identical( + out.compute(scheduler="single-threaded"), + DataArray(np.arange(6) + 2, dims="t", name="x"), + ) + + +def test_groupby_sum_uses_composite_expr(): + x = dask_array.arange(6, chunks=(3,)) + arr = DataArray( + x, + dims="t", + coords={"label": ("t", np.array(["a", "a", "b", "b", "a", "b"]))}, + name="x", + ) + + out = arr.groupby("label").sum() + + assert isinstance(out.data, dask_array.Array) + assert isinstance(dask.base.collections_to_expr(out), CompositeExpr) + assert_identical( + out.compute(scheduler="single-threaded"), + DataArray([5, 10], dims="label", coords={"label": ["a", "b"]}, name="x"), + )