Support query optimization with Dask expression arrays #11382
147a748 to 3501992
Excited to check this out! Fingers crossed it helps with the xarray/dask 'large task graph' serialization warning.
mrocklin
left a comment
There was a problem hiding this comment.
To aid review I've added some comments on what I think is essential, and what could be dropped with only slight degradation in functionality (but lots of simplicity in review)
| elif module_available("dask", "2024.08.2"): | ||
| from dask.array import reshape_blockwise as dask_reshape_blockwise | ||
|
|
||
| return dask_reshape_blockwise(x, shape=shape, chunks=chunks) |
There was a problem hiding this comment.
I think that this change is useful regardless if xarray wants to do the chunk manager thing rather than be tied to dask.array
There was a problem hiding this comment.
SGTM. FWIW it's technically not part of the array API (and never will be, since it's only valuable for chunked array kinds). Presumably dask_array implements it?
There was a problem hiding this comment.
Fair point. I guess the question then becomes "how should xarray handle dispatch along these operations between different chunked APIs". What I've done here doesn't feel quite right. Any suggestions?
There was a problem hiding this comment.
We have the "chunk manager" for this.
There was a problem hiding this comment.
OK. I understand more clearly now. Thank you. I guess there are two options here then:
- Leave it as is. A little unclean because we're kind of abusing a protocol, but in a harmless way
- Extend Xarray's chunk manager
From your earlier SGTM comment my sense is that it'd be reasonable to extend the chunk manager, but not a big deal. Planning to leave this as is for now, but let me know if you'd prefer otherwise.
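To make the second option more concrete, here is a minimal sketch of type-based dispatch through a manager object. Everything in it (`ChunkedList`, `ToyChunkManager`, `ChunkedListManager`, `pick_manager`, `cumulative_max`) is hypothetical and written in plain Python; it is not xarray's actual chunk manager API, only an illustration of how routing a chunk-aware operation through a manager keeps the calling code free of backend imports.

```python
from __future__ import annotations

from itertools import accumulate
from typing import Any, Callable


class ChunkedList:
    """Hypothetical chunked array: a list of chunks, each a list of numbers."""

    def __init__(self, chunks: list[list[float]]):
        self.chunks = chunks


class ToyChunkManager:
    """Hypothetical manager interface; each backend registers one subclass."""

    array_cls: type = object

    def is_chunked_array(self, data: Any) -> bool:
        return isinstance(data, self.array_cls)

    def scan(self, func: Callable[[float, float], float], data: Any) -> Any:
        raise NotImplementedError


class ChunkedListManager(ToyChunkManager):
    array_cls = ChunkedList

    def scan(self, func, data):
        # Cumulative reduction carried across chunk boundaries,
        # then split back into the original chunk layout.
        flat = [x for chunk in data.chunks for x in chunk]
        scanned = list(accumulate(flat, func))
        out, start = [], 0
        for chunk in data.chunks:
            out.append(scanned[start : start + len(chunk)])
            start += len(chunk)
        return ChunkedList(out)


MANAGERS: list[ToyChunkManager] = [ChunkedListManager()]


def pick_manager(data: Any) -> ToyChunkManager:
    """Return the first registered manager that recognizes ``data``."""
    for manager in MANAGERS:
        if manager.is_chunked_array(data):
            return manager
    raise TypeError(f"no chunk manager registered for {type(data).__name__}")


def cumulative_max(data: Any) -> Any:
    # The caller never imports a specific backend.
    return pick_manager(data).scan(max, data)
```

Under this design, supporting a second chunked array type would mean registering another manager rather than adding a branch at each call site.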
| # once https://github.com/pydata/xarray/issues/9229 being implemented | ||
|
|
||
| pushed_array = da.reductions.cumreduction( | ||
| pushed_array = cumreduction( |
There was a problem hiding this comment.
Same here, this change is about keeping things generic.
There was a problem hiding this comment.
We have this in the "chunk manager" as scan. Can you have Claude make that change, please?
There was a problem hiding this comment.
Yup. Can do. Thanks for the pointer.
| @@ -0,0 +1,275 @@ | |||
| from __future__ import annotations | |||
There was a problem hiding this comment.
The changes in this file are the largest, and also aren't strictly necessary. They're here to support xarray's map_blocks function, which is a little odd in the proposed architecture. I'd be happy to remove these changes if it would accelerate review. In their defense though, they're also pretty isolated from the main codebase and so should have a low blast radius.
| def __dask_rebuild_from_exprs__(self, exprs): | ||
| ds = self._to_temp_dataset().__dask_rebuild_from_exprs__(exprs) | ||
| return self._from_temp_dataset(ds) | ||
|
|
There was a problem hiding this comment.
This is the core of the change. There's a newly proposed protocol in Dask and this is Xarray supporting that protocol.
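As a rough illustration of the shape of such a protocol, the sketch below uses plain Python stand-ins. `ToyExpr`, `ToyArray`, and `ToyDataset` are hypothetical; the real method signatures are defined in the Dask PR, and only the `__dask_rebuild_from_exprs__(self, exprs)` name and argument are taken from the diff above. The `__dask_exprs__` method is assumed here to return the collection's expressions in a stable order.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass(frozen=True)
class ToyExpr:
    """Hypothetical stand-in for a Dask array expression."""

    name: str

    def optimize(self) -> ToyExpr:
        return ToyExpr(self.name + "-optimized")


@dataclass
class ToyArray:
    """Hypothetical expression-backed array."""

    expr: ToyExpr


class ToyDataset:
    """Hypothetical container mixing lazy and in-memory variables."""

    def __init__(self, variables: dict[str, object]):
        self.variables = variables

    def __dask_exprs__(self) -> list[ToyExpr]:
        # Assumption: expose one expression per lazy variable, in order.
        return [v.expr for v in self.variables.values() if isinstance(v, ToyArray)]

    def __dask_rebuild_from_exprs__(self, exprs) -> ToyDataset:
        # Swap the new expressions in; leave in-memory variables alone.
        exprs = iter(exprs)
        rebuilt = {
            k: ToyArray(next(exprs)) if isinstance(v, ToyArray) else v
            for k, v in self.variables.items()
        }
        return ToyDataset(rebuilt)


ds = ToyDataset({"a": ToyArray(ToyExpr("a")), "n": 3, "b": ToyArray(ToyExpr("b"))})
optimized = [e.optimize() for e in ds.__dask_exprs__()]
new_ds = ds.__dask_rebuild_from_exprs__(optimized)
```

The point of the round trip is that an optimizer can see all of a container's expressions together and hand back rewritten ones without knowing anything about the container's other state.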
| return HighLevelGraph.merge(*graphs.values()) | ||
| except ImportError: | ||
| from dask import sharedict | ||
| pass |
There was a problem hiding this comment.
sharedict is pretty ancient
There was a problem hiding this comment.
can we remove the try/except then?
| self._indexes, | ||
| self._encoding, | ||
| self._close, | ||
| ) |
There was a problem hiding this comment.
This is, again, the core of the change I'm looking for. I hope that it's both fairly straightforward and has a low blast radius.
There was a problem hiding this comment.
confirmed this is v. similar to existing _dask_postcompute as expected.
| wrapper=_wrapper, | ||
| get_chunk_slicer=_get_chunk_slicer, | ||
| dataset_to_dataarray=dataset_to_dataarray, | ||
| ) # type: ignore[return-value] |
There was a problem hiding this comment.
This is part of the larger change that's not really necessary. It's only here to support xarray's map_blocks function.
| import dask | ||
| from dask._collections import new_collection | ||
|
|
||
| exprs_iter = iter(exprs) |
There was a problem hiding this comment.
hahahah can we just exprs = list(exprs) it and assert len(exprs) == 1? This is some epic Claude nonsense.
There was a problem hiding this comment.
I think there needs to be exactly one expression per Dask collection. zip(..., strict=True) would also be a cleaner way to do this.
There was a problem hiding this comment.
I'm not sure I follow. I don't think we want to assert that len(exprs) == 1. For context, in a Dataset there are likely to be several exprs, one for each dask array. We want to iterate through them while also iterating through the reconstructed dataset and replace the expressions into the Dataset
dcherian
left a comment
There was a problem hiding this comment.
Generally looks fine to me. I left some minor requests.
I didn't look at the map_blocks stuff too closely. But I can't understand how it works conceptually. It doesn't seem to be an 'expression'. Have we lost culling then (e.g. ds.pipe(xr.map_blocks(...)).sel(...) => ds.sel(...).pipe(xr.map_blocks, ...) )?
Can you add some docs to the top of that dask_array_expr file explaining what it does?
Also we'll need a CI env to test it :). Can we reuse the existing dask.array test suite?
shoyer
left a comment
There was a problem hiding this comment.
Should this wait until the upstream Dask changes go in, or is it safe to merge now?
| for v in self.variables.values(): | ||
| if dask.is_dask_collection(v): | ||
| if not is_dask_array_expr_array(v._data): | ||
| return None |
There was a problem hiding this comment.
Could you note why falling back to claiming there are no expressions in the mixed case is the right thing to do? Alternatively, I can imagine raising an error might be more user-friendly.
There was a problem hiding this comment.
They're both valid options. The mixed case does actually work (at least if you take the dask.compute(...) path). It's entirely possible though that this is indicative of a situation that users would still want to be made aware of and correct. Erring could make sense. So too could warning.
There was a problem hiding this comment.
I guess the common way this case might occur is when an external library constructs a dask.array.Array and a user combines that with a dask_array.
I can see a warning being useful. Should the choice between silence/warning/error be an option on the dask_array side? An error-by-default policy could push the ecosystem towards using expressions by default. In general, I have developed a strong dislike for this kind of "accept-everything" behaviour; it makes things hard to reason about.
There was a problem hiding this comment.
Happy to put a warning in if that's what people want. I think that this isn't a decision that I make. I also think that it's the kind of decision that doesn't need to block this PR. It's low stakes and easy to change in the future.
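The silence/warning/error choice discussed in this thread could be expressed as a single policy argument. The sketch below is hypothetical: `collect_exprs`, `on_mixed`, and the `ExprArray`/`LegacyArray` stand-ins are invented for illustration and do not reflect an existing option in xarray or dask_array.

```python
from __future__ import annotations

import warnings
from typing import Literal


class ExprArray:
    """Hypothetical expression-backed lazy array."""

    def __init__(self, expr: str):
        self.expr = expr


class LegacyArray:
    """Hypothetical stand-in for a legacy dask.array.Array."""


def collect_exprs(
    variables: dict[str, object],
    on_mixed: Literal["ignore", "warn", "raise"] = "warn",
) -> list[str] | None:
    """Return expressions, or None when legacy lazy arrays are also present."""
    legacy = [k for k, v in variables.items() if isinstance(v, LegacyArray)]
    if legacy:
        message = (
            f"variables {legacy} are not expression-backed; "
            "skipping expression optimization"
        )
        if on_mixed == "raise":
            raise TypeError(message)
        if on_mixed == "warn":
            warnings.warn(message, UserWarning, stacklevel=2)
        return None  # fall back to the non-expression path
    return [v.expr for v in variables.values() if isinstance(v, ExprArray)]
```

Keeping the default in one argument would make it cheap to move from "warn" to "raise" later, which fits the view that this decision need not block the PR.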
|
|
||
| dask = pytest.importorskip("dask") | ||
| da = pytest.importorskip("dask.array") | ||
| dask_array = pytest.importorskip("dask_array") |
There was a problem hiding this comment.
I don't think we install dask_array currently in our CI, which would probably be a good idea to ensure this doesn't break.
Could you try adding this into our pixi.toml config?
Line 400 in fb20c68
There was a problem hiding this comment.
I've added dask-array to the dask feature, which I think does the job. Not certain though.
There was a problem hiding this comment.
I've removed this for now. I could use guidance on CI.
It's fine to wait. I think it's good to coordinate merging both. I was waiting to push on merging the dask changes until this got some eyes on it. Happy to accelerate merging that PR as needed.
And thanks for the feedback all. Working on things now.
My plan is to remove map_blocks from the PR in order to get the more important changes in quickly. Broadly, though, how we're doing this is by creating a composite expression that takes in each of the dask-array expressions in the dataset and emits lots of dask-array expressions. The actual task does what xarray.map_blocks has always done (or at least that was my historical understanding): take each set of numpy array chunks, assemble an xarray dataset on the fly, call the user-defined function, and then emit the numpy arrays again. We're just doing that same thing, but now at the expr level. In terms of dask array optimizations, yes, you're correct that map_blocks is fairly opaque. But really, let's just kick that down the road, I think. What's here is OK, I think, but I'm not keen to push through a complex thing at the moment.
Co-Authored-By: Codex <codex@openai.com>
9b67dd6 to 584e611
For testing I tried briefly to add dask_array to one of the main CI lanes. I was hoping to run into just a few explicit Nothing that an agent couldn't chew through given time, but it would definitely make this PR much larger, which I'm not sure is appropriate at the moment.
Keep the dask-array chunk-manager fixes in xarray while dropping the dedicated dask-array CI environment. This leaves map_blocks out of scope, keeps optional dask_array discovery localized, and updates the groupby expectation now that it remains expression-backed. Co-Authored-By: Codex <codex@openai.com>
9f63cb1 to e856404
I don't know how best to handle testing here. I could use help thinking through Xarray's CI system and testing matrix, as well as help thinking through how deeply we want to test this when. Do we want the entire xarray test suite to pass before we merge in the protocol support? That's ok, but it'll probably be a lot of review. If not, then what do we want to make sure is tested before merging?
Also it's not clear to me that the CI failures here are due to these changes. I suspect it may be the recent pytest 9.1 release (perhaps pinning pytest would be wise if so).
Thank you for the review @dcherian @shoyer. I think I've handled the comments except for CI. On CI I don't have strong conviction on any plan due to ignorance of the project, but what I would probably do is the following:
I think that this is a lot of busy work that agents can handle pretty well. I'm happy to kick it off and iterate on it. If it were me I wouldn't include it in this PR. Personally I would keep this PR light. Another option would be to add a CI entry for just the tests added in this PR, and maybe a few more scattered throughout the repo. This feels ephemeral to me though. Anyway, I've done what I can here. Passing off to you all if you're still interested. Thanks for the time spent so far.
A while ago I finished Dask expression arrays which support query optimization. This PR supports them in Xarray. This required a few things:
- __dask_exprs__ protocol in Dask (see "Support composite expressions with __dask_exprs__ protocol", dask/dask#12457)
- dask.array to instead use the chunk manager (these should probably be changed regardless)
Example
Output