-
-
Notifications
You must be signed in to change notification settings - Fork 1.3k
Support query optimization with Dask expression arrays #11382
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
034774b
59039a0
87561b8
fd7a9c8
4b17760
2804b6a
584e611
566da38
e856404
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -648,14 +648,17 @@ def __dask_graph__(self): | |
| if not graphs: | ||
| return None | ||
| else: | ||
| try: | ||
| from dask.highlevelgraph import HighLevelGraph | ||
| from dask.highlevelgraph import HighLevelGraph | ||
|
|
||
| if all(isinstance(graph, HighLevelGraph) for graph in graphs.values()): | ||
| return HighLevelGraph.merge(*graphs.values()) | ||
| except ImportError: | ||
| from dask import sharedict | ||
|
|
||
| return sharedict.merge(*graphs.values()) | ||
| from dask.utils import ensure_dict | ||
|
|
||
| merged = {} | ||
| for graph in graphs.values(): | ||
| merged.update(ensure_dict(graph)) | ||
| return merged | ||
|
|
||
| def __dask_keys__(self): | ||
| import dask | ||
|
|
@@ -666,6 +669,56 @@ def __dask_keys__(self): | |
| if dask.is_dask_collection(v) | ||
| ] | ||
|
|
||
| def __dask_exprs__(self): | ||
| from importlib import import_module | ||
|
|
||
| import dask | ||
|
|
||
| try: | ||
| DaskArray = import_module("dask_array").Array | ||
| except ImportError: | ||
| return None | ||
|
|
||
| exprs = [] | ||
| for v in self.variables.values(): | ||
| if dask.is_dask_collection(v): | ||
| if not isinstance(v._data, DaskArray): | ||
| # Composite expressions must account for every Dask-backed | ||
| # variable. Returning None keeps Dask's collection APIs on | ||
| # the existing HighLevelGraph path for mixed | ||
| # legacy/expression datasets. | ||
| return None | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you note why falling back to claiming there are no expressions in the mixed case is the right thing to do? Alternatively, I can imagine raising an error might be more user-friendly.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. They're both valid options. The mixed case does actually work (at least if you take the
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I guess the common way this case might occur is when an external library constructs I can see a warning being useful. Should the choice between silence/warning/error be an option on the
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Happy to put a warning in if that's what people want. I think that this isn't a decision that I make. I also think that it's the kind of decision that doesn't need to block this PR. It's low stakes and easy to change in the future. |
||
| exprs.append(v._data.expr) | ||
| return exprs or None | ||
|
|
||
| def __dask_rebuild_from_exprs__(self, exprs): | ||
| import dask | ||
| from dask._collections import new_collection | ||
|
|
||
| dask_variables = [ | ||
| (k, v) for k, v in self._variables.items() if dask.is_dask_collection(v) | ||
| ] | ||
| exprs = list(exprs) | ||
| if len(exprs) != len(dask_variables): | ||
| raise ValueError( | ||
| f"Expected {len(dask_variables)} expressions to rebuild Dataset, " | ||
| f"got {len(exprs)}" | ||
| ) | ||
|
|
||
| variables = dict(self._variables) | ||
| for (k, v), expr in zip(dask_variables, exprs, strict=True): | ||
| variables[k] = v._replace(data=new_collection(expr)) | ||
|
|
||
| return type(self)._construct_direct( | ||
| variables, | ||
| self._coord_names, | ||
| self._dims, | ||
| self._attrs, | ||
| self._indexes, | ||
| self._encoding, | ||
| self._close, | ||
| ) | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is, again, the core of the change I'm looking for. I hope that it's both fairly straightforward and has a low blast radius.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. confirmed this is v. similar to existing |
||
|
|
||
| def __dask_layers__(self): | ||
| import dask | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is the core of the change. There's a newly proposed protocol in Dask and this is Xarray supporting that protocol.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM