Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions xarray/compat/dask_array_compat.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Any

from xarray.namedarray.parallelcompat import get_chunked_array_type
from xarray.namedarray.utils import module_available


Expand All @@ -8,10 +9,17 @@ def reshape_blockwise(
shape: int | tuple[int, ...],
chunks: tuple[tuple[int, ...], ...] | None = None,
):
if module_available("dask", "2024.08.2"):
from dask.array import reshape_blockwise
try:
array_api = get_chunked_array_type(x).array_api
except TypeError:
array_api = None

return reshape_blockwise(x, shape=shape, chunks=chunks)
if array_api is not None and hasattr(array_api, "reshape_blockwise"):
return array_api.reshape_blockwise(x, shape=shape, chunks=chunks)
elif module_available("dask", "2024.08.2"):
from dask.array import reshape_blockwise as dask_reshape_blockwise

return dask_reshape_blockwise(x, shape=shape, chunks=chunks)
else:
return x.reshape(shape)

Expand Down
11 changes: 7 additions & 4 deletions xarray/compat/dask_array_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from xarray.compat.dask_array_compat import reshape_blockwise
from xarray.core import dtypes, nputils
from xarray.namedarray.parallelcompat import get_chunked_array_type


def dask_rolling_wrapper(moving_func, a, window, min_count=None, axis=-1):
Expand All @@ -20,7 +21,7 @@ def dask_rolling_wrapper(moving_func, a, window, min_count=None, axis=-1):


def least_squares(lhs, rhs, rcond=None, skipna=False):
import dask.array as da
da = get_chunked_array_type(rhs).array_api

# The trick here is that the core dimension is axis 0.
# All other dimensions need to be reshaped down to one axis for `lstsq`
Expand Down Expand Up @@ -94,23 +95,25 @@ def push(array, n, axis, method="blelloch"):
"""
Dask-aware bottleneck.push
"""
import dask.array as da
import numpy as np

from xarray.core.duck_array_ops import _push
from xarray.core.nputils import nanlast

chunkmanager = get_chunked_array_type(array)
da = chunkmanager.array_api

if n is not None and all(n <= size for size in array.chunks[axis]):
return array.map_overlap(_push, depth={axis: (n, 0)}, n=n, axis=axis)

# TODO: Replace this entire function
# once https://github.com/pydata/xarray/issues/9229 is implemented

pushed_array = da.reductions.cumreduction(
pushed_array = chunkmanager.scan(
func=_dtype_push,
binop=_fill_with_last_one,
ident=np.nan,
x=array,
arr=array,
axis=axis,
dtype=array.dtype,
method=method,
Expand Down
7 changes: 7 additions & 0 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1113,6 +1113,13 @@ def __dask_graph__(self):
def __dask_keys__(self):
    """Return the Dask keys reported by this object's temporary dataset.

    The object is converted with ``_to_temp_dataset`` and the call is
    forwarded to that dataset's ``__dask_keys__``.
    """
    temp_dataset = self._to_temp_dataset()
    return temp_dataset.__dask_keys__()

def __dask_exprs__(self):
    """Return the Dask expressions reported by this object's temporary dataset.

    The object is converted with ``_to_temp_dataset`` and the call is
    forwarded to that dataset's ``__dask_exprs__``; whatever it returns is
    passed back unchanged.
    """
    temp_dataset = self._to_temp_dataset()
    return temp_dataset.__dask_exprs__()

def __dask_rebuild_from_exprs__(self, exprs):
    """Rebuild this object from Dask expressions via its temporary dataset.

    Parameters
    ----------
    exprs
        Expressions forwarded as-is to the temporary dataset's
        ``__dask_rebuild_from_exprs__``.

    Returns
    -------
    The result of passing the rebuilt dataset to ``_from_temp_dataset``.
    """
    temp_dataset = self._to_temp_dataset()
    rebuilt = temp_dataset.__dask_rebuild_from_exprs__(exprs)
    return self._from_temp_dataset(rebuilt)

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the core of the change. There's a newly proposed protocol in Dask and this is Xarray supporting that protocol.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

def __dask_layers__(self):
    """Return the Dask layers reported by this object's temporary dataset.

    The object is converted with ``_to_temp_dataset`` and the call is
    forwarded to that dataset's ``__dask_layers__``.
    """
    temp_dataset = self._to_temp_dataset()
    return temp_dataset.__dask_layers__()

Expand Down
63 changes: 58 additions & 5 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -648,14 +648,17 @@ def __dask_graph__(self):
if not graphs:
return None
else:
try:
from dask.highlevelgraph import HighLevelGraph
from dask.highlevelgraph import HighLevelGraph

if all(isinstance(graph, HighLevelGraph) for graph in graphs.values()):
return HighLevelGraph.merge(*graphs.values())
except ImportError:
from dask import sharedict

return sharedict.merge(*graphs.values())
from dask.utils import ensure_dict

merged = {}
for graph in graphs.values():
merged.update(ensure_dict(graph))
return merged

def __dask_keys__(self):
import dask
Expand All @@ -666,6 +669,56 @@ def __dask_keys__(self):
if dask.is_dask_collection(v)
]

def __dask_exprs__(self):
from importlib import import_module

import dask

try:
DaskArray = import_module("dask_array").Array
except ImportError:
return None

exprs = []
for v in self.variables.values():
if dask.is_dask_collection(v):
if not isinstance(v._data, DaskArray):
# Composite expressions must account for every Dask-backed
# variable. Returning None keeps Dask's collection APIs on
# the existing HighLevelGraph path for mixed
# legacy/expression datasets.
return None

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you note why falling back to claiming there are no expressions in the mixed case is the right thing to do? Alternatively, I can imagine raising an error might be more user-friendly.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

They're both valid options. The mixed case does actually work (at least if you take the dask.compute(...) path). It's entirely possible, though, that this indicates a situation users would still want to be made aware of and correct. Raising an error could make sense; so could emitting a warning.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess the common way this case might occur is when an external library constructs a dask.array.Array and a user combines that with a dask_array.Array.

I can see a warning being useful. Should the choice between silence/warning/error be an option on the dask_array side? An error-by-default policy could push the ecosystem towards using expressions by default. In general, I have developed a strong dislike for this kind of "accept-everything" behaviour, it makes things hard to reason about.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Happy to put a warning in if that's what people want. I think that this isn't a decision that I make. I also think that it's the kind of decision that doesn't need to block this PR. It's low stakes and easy to change in the future.

exprs.append(v._data.expr)
return exprs or None

def __dask_rebuild_from_exprs__(self, exprs):
    """Build a new object of the same type from replacement Dask expressions.

    Every entry of ``self._variables`` for which ``dask.is_dask_collection``
    is true gets its data replaced by ``new_collection(expr)``. Variables and
    expressions are paired in iteration order. All other variables, and the
    remaining internal state, are handed to ``_construct_direct`` unchanged.

    Parameters
    ----------
    exprs : iterable
        One expression per Dask-backed variable.

    Returns
    -------
    An instance of ``type(self)`` created through ``_construct_direct``.

    Raises
    ------
    ValueError
        If the number of expressions differs from the number of Dask-backed
        variables.
    """
    import dask
    from dask._collections import new_collection

    # Variable names are unique keys, so a dict keeps the same pairs and
    # order as scanning ``self._variables`` directly.
    dask_backed = {
        name: var
        for name, var in self._variables.items()
        if dask.is_dask_collection(var)
    }
    replacements = list(exprs)

    expected, received = len(dask_backed), len(replacements)
    if expected != received:
        raise ValueError(
            f"Expected {expected} expressions to rebuild Dataset, "
            f"got {received}"
        )

    # Start from a shallow copy so variables that are not Dask-backed are
    # carried over untouched and the original mapping is not mutated.
    rebuilt = dict(self._variables)
    rebuilt.update(
        (name, var._replace(data=new_collection(expr)))
        for (name, var), expr in zip(
            dask_backed.items(), replacements, strict=True
        )
    )

    cls = type(self)
    return cls._construct_direct(
        rebuilt,
        self._coord_names,
        self._dims,
        self._attrs,
        self._indexes,
        self._encoding,
        self._close,
    )

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is, again, the core of the change I'm looking for. I hope that it's both fairly straightforward and has a low blast radius.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Confirmed: this is very similar to the existing _dask_postcompute, as expected.


def __dask_layers__(self):
import dask

Expand Down
1 change: 1 addition & 0 deletions xarray/core/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,6 +440,7 @@ def _wrapper(
dataarray_to_dataset(arg) if isinstance(arg, DataArray) else arg
for arg in aligned
)

# rechunk any numpy variables appropriately
xarray_objs = tuple(arg.chunk(arg.chunksizes) for arg in xarray_objs)

Expand Down
Loading
Loading