Support query optimization with Dask expression arrays #11382

Open
mrocklin wants to merge 9 commits into pydata:main from mrocklin:codex/composite-expr-protocol

@mrocklin

Contributor

A while ago I finished Dask expression arrays which support query optimization. This PR supports them in Xarray. This required a few things:

  • Creating a new __dask_exprs__ protocol in Dask (see Support composite expressions with __dask_exprs__ protocol dask/dask#12457)
  • Implementing that protocol in Xarray (this does most of the lifting)
  • Building an array chunk manager (this was done in the dask-array project)
  • Some silliness around xarray's map_blocks function
  • Changing a few explicit uses of dask.array to instead use the chunk manager (these should probably be changed regardless)

Example

import dask
import dask_array
import xarray as xr
from xarray.namedarray.parallelcompat import get_chunked_array_type


ds = xr.tutorial.scatter_example_dataset(seed=42).chunk({"x": 1, "y": 1, "z": 2, "w": 2})

# The slice and rechunk start above the elementwise operation.  dask-array's
# optimizer can push them down so it only builds the small requested window.
window = (ds.A + ds.B).chunk({"y": 3}).isel(x=slice(0, 1), y=slice(0, 3))

tasks_before = len(window.__dask_graph__())
(optimized_window,) = dask.optimize(window)
optimized_data = window.data.optimize()
tasks_after = len(optimized_data.__dask_graph__())

manager = get_chunked_array_type(ds.A.data)

print(f"xarray chunk manager: {type(manager).__name__}")
print(f"dask.optimize result: {type(optimized_window).__name__}")
print(f"array type: {type(window.data).__module__}.{type(window.data).__name__}")
print(f"graph tasks before optimize: {tasks_before}")
print(f"graph tasks after optimize:  {tasks_after}")
print()
print("Before optimize:")
window.data.pprint()
print()
print("After optimize:")
optimized_data.pprint()

Output

xarray chunk manager: DaskArrayExprManager
dask.optimize result: DataArray
array type: dask_array._collection.Array
graph tasks before optimize: 448
graph tasks after optimize:  12

Before optimize:
  Operation                Shape    Bytes   Chunks
  Getitem           (1, 3, 4, 4)    384 B  1×3×2×2
  └ Rechunk        (3, 11, 4, 4)  4.1 kiB  1×3×2×2
    └ Add          (3, 11, 4, 4)  4.1 kiB  1×1×2×2
      ├ FromArray  (3, 11, 4, 4)  4.1 kiB  1×1×2×2
      └ FromArray  (3, 11, 4, 4)  4.1 kiB  1×1×2×2

After optimize:
  Operation           Shape  Bytes   Chunks
  Add          (1, 3, 4, 4)  384 B  1×3×2×2
  ├ FromArray  (1, 3, 4, 4)  384 B  1×3×2×2
  └ FromArray  (1, 3, 4, 4)  384 B  1×3×2×2
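
The task counts above can be reproduced with back-of-the-envelope arithmetic. The sketch below assumes one task per output chunk for every operation in the tree. That is a simplification (a real rechunk may emit extra intermediate tasks), but it happens to match the reported 448 and 12. The helper `n_chunks` is illustrative and not part of dask-array.

```python
from math import ceil, prod


def n_chunks(shape, chunks):
    """Count the blocks produced by splitting `shape` into `chunks`-sized pieces."""
    return prod(ceil(size / chunk) for size, chunk in zip(shape, chunks))


# (operation, shape, chunk shape) rows copied from the "Before optimize" tree.
before = [
    ("Getitem", (1, 3, 4, 4), (1, 3, 2, 2)),
    ("Rechunk", (3, 11, 4, 4), (1, 3, 2, 2)),
    ("Add", (3, 11, 4, 4), (1, 1, 2, 2)),
    ("FromArray", (3, 11, 4, 4), (1, 1, 2, 2)),
    ("FromArray", (3, 11, 4, 4), (1, 1, 2, 2)),
]

# Rows copied from the "After optimize" tree.
after = [
    ("Add", (1, 3, 4, 4), (1, 3, 2, 2)),
    ("FromArray", (1, 3, 4, 4), (1, 3, 2, 2)),
    ("FromArray", (1, 3, 4, 4), (1, 3, 2, 2)),
]

# Assumption: every operation contributes one task per output chunk.
tasks_before = sum(n_chunks(shape, chunks) for _, shape, chunks in before)
tasks_after = sum(n_chunks(shape, chunks) for _, shape, chunks in after)

print(tasks_before, tasks_after)  # 448 12
```

Most of the savings come from the two FromArray nodes and the Add: each drops from 132 chunks to 4 once the slice is applied at the source.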

@norlandrhagen

Excited to check this out! Fingers crossed it helps the xarray/dask 'large task graph' serialization warning.

@mrocklin mrocklin left a comment

Contributor Author

To aid review, I've added some comments on what I think is essential and what could be dropped with only a slight loss of functionality (but a much simpler review).

elif module_available("dask", "2024.08.2"):
    from dask.array import reshape_blockwise as dask_reshape_blockwise

    return dask_reshape_blockwise(x, shape=shape, chunks=chunks)

Contributor Author

I think this change is useful regardless of this PR, assuming xarray wants to go through the chunk manager rather than be tied to dask.array.

Contributor

SGTM. FWIW it's technically not part of the array API (and never will be, since it's only valuable for chunked array kinds). Presumably dask_array implements it?

Contributor Author

Fair point. I guess the question then becomes "how should xarray dispatch these operations across different chunked array APIs?" What I've done here doesn't feel quite right. Any suggestions?

Contributor

We have the "chunk manager" for this.

Contributor Author

OK. I understand more clearly now. Thank you. I guess there are two options here then:

  1. Leave it as is. A little unclean because we're kind of abusing a protocol, but in a harmless way
  2. Extend Xarray's chunk manager

From your earlier SGTM comment my sense is that it'd be reasonable to extend the chunk manager, but not a big deal. Planning to leave this as is for now, but let me know if you'd prefer otherwise.
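
One way to picture option 2: each chunked backend registers a manager object, and generic code asks the manager for chunk-only operations such as reshape_blockwise instead of importing dask.array directly. The following is a toy sketch only. `ToyChunkManager`, `get_manager`, `LegacyArray`, and `ExprArray` are hypothetical names and do not reflect xarray's actual chunk manager API.

```python
from abc import ABC, abstractmethod


class LegacyArray(list):
    """Hypothetical stand-in for a classic chunked array type."""


class ExprArray(list):
    """Hypothetical stand-in for an expression-backed chunked array type."""


class ToyChunkManager(ABC):
    """Hypothetical manager interface; not xarray's real class."""

    array_cls: type

    def is_chunked_array(self, data) -> bool:
        return isinstance(data, self.array_cls)

    @abstractmethod
    def reshape_blockwise(self, data, shape):
        """Chunk-aware reshape, implemented per backend."""


class LegacyManager(ToyChunkManager):
    array_cls = LegacyArray

    def reshape_blockwise(self, data, shape):
        # A real manager would call into its array library here.
        return ("legacy.reshape_blockwise", shape)


class ExprManager(ToyChunkManager):
    array_cls = ExprArray

    def reshape_blockwise(self, data, shape):
        return ("expr.reshape_blockwise", shape)


MANAGERS = [LegacyManager(), ExprManager()]


def get_manager(data):
    """Pick the first registered manager that recognises `data`."""
    for manager in MANAGERS:
        if manager.is_chunked_array(data):
            return manager
    raise TypeError(f"no chunk manager registered for {type(data).__name__}")


def reshape_blockwise(data, shape):
    """Generic entry point: no backend-specific import needed by the caller."""
    return get_manager(data).reshape_blockwise(data, shape)


print(reshape_blockwise(ExprArray(), (2, 3)))
```

The trade-off is the one discussed above: the interface grows by one method per chunk-only operation, but the dispatch lives in a single place.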

Comment thread xarray/compat/dask_array_ops.py Outdated
# once https://github.com/pydata/xarray/issues/9229 being implemented

- pushed_array = da.reductions.cumreduction(
+ pushed_array = cumreduction(

Contributor Author

Same here, this change is about keeping things generic.

Contributor

We have this in the "chunk manager" as scan. Can you have Claude make that change, please?

Contributor Author

Yup. Can do. Thanks for the pointer.

Comment thread xarray/core/dask_array_expr.py Outdated
@@ -0,0 +1,275 @@
from __future__ import annotations

Contributor Author

The changes in this file are the largest, and also aren't strictly necessary. They're here to support xarray's map_blocks function, which is a little odd in the proposed architecture. I'd be happy to remove these changes if it would accelerate review. In their defense, though, they're also pretty isolated from the main codebase and so should have a low blast radius.

Comment thread xarray/core/dataarray.py
def __dask_rebuild_from_exprs__(self, exprs):
    ds = self._to_temp_dataset().__dask_rebuild_from_exprs__(exprs)
    return self._from_temp_dataset(ds)

Contributor Author

This is the core of the change. There's a newly proposed protocol in Dask and this is Xarray supporting that protocol.

Contributor

LGTM

Comment thread xarray/core/dataset.py Outdated
return HighLevelGraph.merge(*graphs.values())
except ImportError:
from dask import sharedict
pass

Contributor Author

sharedict is pretty ancient

Contributor

Can we remove the try/except then?

Contributor Author

Sure. Done.

Comment thread xarray/core/dataset.py
self._indexes,
self._encoding,
self._close,
)

Contributor Author

This is, again, the core of the change I'm looking for. I hope that it's both fairly straightforward and has a low blast radius.

Contributor

Confirmed: this is very similar to the existing _dask_postcompute, as expected.

Comment thread xarray/core/parallel.py Outdated
wrapper=_wrapper,
get_chunk_slicer=_get_chunk_slicer,
dataset_to_dataarray=dataset_to_dataarray,
) # type: ignore[return-value]

Contributor Author

This is part of the larger change that's not really necessary. It's only here to support xarray's map_blocks function.

Comment thread xarray/core/dataset.py Outdated
import dask
from dask._collections import new_collection

exprs_iter = iter(exprs)

Contributor

hahahah can we just do exprs = list(exprs) and assert len(exprs) == 1? This is some epic Claude nonsense.

Member

I think there needs to be exactly one expression per Dask collection. zip(..., strict=True) would also be a cleaner way to do this.

Contributor Author

I'm not sure I follow. I don't think we want to assert that len(exprs) == 1. For context, in a Dataset there are likely to be several exprs, one for each dask array. We want to iterate through them while also iterating through the reconstructed dataset, and substitute the expressions back into the Dataset.

Contributor Author

Ah, sorry, I was responding here to @dcherian's response, not @shoyer's. My internet access today is a bit spotty and I was responding to outdated information. +1 on zip strict.

@dcherian dcherian left a comment

Contributor

Generally looks fine to me. I left some minor requests.

I didn't look at the map_blocks stuff too closely. But I can't understand how it works conceptually. It doesn't seem to be an 'expression'. Have we lost culling then (e.g. ds.pipe(xr.map_blocks, ...).sel(...) => ds.sel(...).pipe(xr.map_blocks, ...))?

Can you add some docs to the top of that dask_array_expr file explaining what it does?

@dcherian

Contributor

Also, we'll need a CI env to test it :). Can we reuse the existing dask.array test suite?

@shoyer shoyer left a comment

Member

Should this wait until the upstream Dask changes go in, or is it safe to merge now?

Comment thread xarray/core/dataset.py
for v in self.variables.values():
    if dask.is_dask_collection(v):
        if not is_dask_array_expr_array(v._data):
            return None

Member

Could you note why falling back to claiming there are no expressions in the mixed case is the right thing to do? Alternatively, I can imagine raising an error might be more user-friendly.

Contributor Author

They're both valid options. The mixed case does actually work (at least if you take the dask.compute(...) path). It's entirely possible though that this is indicative of a situation that users would still want to be made aware of and correct. Raising an error could make sense. So too could warning.

Contributor

I guess the common way this case might occur is when an external library constructs a dask.array.Array and a user combines that with a dask_array.

I can see a warning being useful. Should the choice between silence/warning/error be an option on the dask_array side? An error-by-default policy could push the ecosystem towards using expressions by default. In general, I have developed a strong dislike for this kind of "accept-everything" behaviour, it makes things hard to reason about.

Contributor Author

Happy to put a warning in if that's what people want. I think that this isn't a decision that I make. I also think that it's the kind of decision that doesn't need to block this PR. It's low stakes and easy to change in the future.
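
The three behaviours discussed in this thread (silence, warning, error) could be selected by a single option. The sketch below shows one possible shape, written against toy classes. `ExprBacked`, `LegacyLazy`, `collect_exprs`, and the `on_mixed` parameter are all hypothetical and are not part of xarray or dask_array.

```python
import warnings


class ExprBacked:
    """Hypothetical stand-in for an expression-backed array."""


class LegacyLazy:
    """Hypothetical stand-in for a classic dask.array.Array."""


def collect_exprs(values, on_mixed="ignore"):
    """Return expression-backed values, or None if legacy lazy arrays are present.

    When both kinds appear together, `on_mixed` picks the behaviour:
    "ignore" (silent fallback), "warn", or "raise".
    """
    if on_mixed not in ("ignore", "warn", "raise"):
        raise ValueError(f"unknown policy: {on_mixed!r}")

    exprs = [v for v in values if isinstance(v, ExprBacked)]
    legacy = [v for v in values if isinstance(v, LegacyLazy)]

    if not legacy:
        return tuple(exprs)

    if exprs:  # both kinds present: the "mixed" case
        message = (
            "mixing expression-backed and legacy lazy arrays; "
            "skipping expression optimization"
        )
        if on_mixed == "raise":
            raise TypeError(message)
        if on_mixed == "warn":
            warnings.warn(message, UserWarning, stacklevel=2)

    # Fall back to claiming there are no expressions.
    return None


expr, legacy = ExprBacked(), LegacyLazy()
print(collect_exprs([expr, "plain value"]) == (expr,))
print(collect_exprs([expr, legacy]))
```

Defaulting to "ignore" preserves the behaviour in this PR. Switching the default later would be a one-line change, which fits the point that the decision need not block the merge.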


dask = pytest.importorskip("dask")
da = pytest.importorskip("dask.array")
dask_array = pytest.importorskip("dask_array")

Member

I don't think we currently install dask_array in our CI; doing so would probably be a good idea to ensure this doesn't break.

Could you try adding this into our pixi.toml config?

test-py313 = { features = [

Contributor Author

I've added dask-array to the dask feature, which I think does the job. Not certain though.

Contributor Author

I've removed this for now. I could use guidance on CI.

@mrocklin

Contributor Author

> Should this wait until the upstream Dask changes go in, or is it safe to merge now?

It's fine to wait. I think it's good to coordinate merging both. I was waiting to push on merging the dask changes until this got some eyes on it. Happy to accelerate merging that PR as needed.

@mrocklin

Contributor Author

And thanks for the feedback all. Working on things now.

@mrocklin

Contributor Author

> I didn't look at the map_blocks stuff too closely. But I can't understand how it works conceptually. It doesn't seem to be an 'expression'. Have we lost culling then (e.g. ds.pipe(xr.map_blocks, ...).sel(...) => ds.sel(...).pipe(xr.map_blocks, ...))?

My plan is to remove map_blocks from the PR in order to get the more important changes in quickly. Broadly, though, the approach is to create a composite expression that takes in each of the dask-array expressions in the dataset and emits many dask-array expressions. The actual task does what xarray.map_blocks has always done (at least as I've historically understood it): take each set of numpy array chunks, assemble an xarray dataset on the fly, call the user-defined function, and then emit the numpy arrays again. We're doing that same thing, but now at the expr level.

In terms of dask array optimizations yes, you're correct that map_blocks is fairly opaque.

But really, let's just kick that down the road I think. What's here is ok I think, but I'm not keen to push through a complex thing at the moment.
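
To make "opaque" concrete, here is a toy expression tree with a slice-pushdown rule. The rule can move a slice through an elementwise Add, but it has to stop at a node wrapping an arbitrary user function, because the optimizer cannot know whether slicing commutes with that function. Every name here (`Source`, `Add`, `Opaque`, `Getitem`, `push_down`) is hypothetical, and this is not dask-array's optimizer.

```python
from dataclasses import dataclass
from typing import Callable


@dataclass(frozen=True)
class Source:
    data: tuple


@dataclass(frozen=True)
class Add:
    left: object
    right: object


@dataclass(frozen=True)
class Opaque:
    child: object
    func: Callable  # arbitrary user function over the whole input


@dataclass(frozen=True)
class Getitem:
    child: object
    start: int
    stop: int


def push_down(expr):
    """Move slices toward the sources wherever that is known to be safe."""
    if isinstance(expr, Getitem):
        child, start, stop = expr.child, expr.start, expr.stop
        if isinstance(child, Add):  # elementwise: slice both inputs instead
            return Add(
                push_down(Getitem(child.left, start, stop)),
                push_down(Getitem(child.right, start, stop)),
            )
        if isinstance(child, Source):  # read only the requested window
            return Source(child.data[start:stop])
        return Getitem(push_down(child), start, stop)  # e.g. Opaque: stop here
    if isinstance(expr, Add):
        return Add(push_down(expr.left), push_down(expr.right))
    if isinstance(expr, Opaque):
        return Opaque(push_down(expr.child), expr.func)
    return expr


def compute(expr):
    """Evaluate the tree eagerly on plain Python lists."""
    if isinstance(expr, Source):
        return list(expr.data)
    if isinstance(expr, Add):
        return [x + y for x, y in zip(compute(expr.left), compute(expr.right))]
    if isinstance(expr, Opaque):
        return expr.func(compute(expr.child))
    if isinstance(expr, Getitem):
        return compute(expr.child)[expr.start:expr.stop]
    raise TypeError(f"unknown node: {type(expr).__name__}")


def loaded(expr):
    """Count how many source elements the tree reads."""
    if isinstance(expr, Source):
        return len(expr.data)
    if isinstance(expr, Add):
        return loaded(expr.left) + loaded(expr.right)
    return loaded(expr.child)


def double(values):
    return [2 * v for v in values]


a, b = Source((1, 2, 3, 4)), Source((10, 20, 30, 40))
plain = Getitem(Add(a, b), 0, 2)
blocked = Getitem(Opaque(Add(a, b), double), 0, 2)

print(loaded(plain), "->", loaded(push_down(plain)))      # 8 -> 4
print(loaded(blocked), "->", loaded(push_down(blocked)))  # 8 -> 8
```

Both trees still compute the same results before and after optimization. The difference is only in how much input has to be read.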

@mrocklin mrocklin force-pushed the codex/composite-expr-protocol branch from 9b67dd6 to 584e611 Compare June 17, 2026 18:21
@github-actions github-actions Bot added the Automation Github bots, testing workflows, release automation label Jun 17, 2026
@mrocklin

Contributor Author

For testing I tried briefly to add dask_array to one of the main CI lanes. I was hoping to run into just a few explicit dask.array issues. Turns out that there are many. I don't think that they're issues with code actually, but instead issues with the tests. Many are deeply tied to things like the dask.array.Array constructor.

Nothing that an agent couldn't chew through given time, but it would definitely make this PR much larger, which I'm not sure is appropriate at the moment.

Keep the dask-array chunk-manager fixes in xarray while dropping the dedicated dask-array CI environment. This leaves map_blocks out of scope, keeps optional dask_array discovery localized, and updates the groupby expectation now that it remains expression-backed.

Co-Authored-By: Codex <codex@openai.com>
@mrocklin mrocklin force-pushed the codex/composite-expr-protocol branch from 9f63cb1 to e856404 Compare June 17, 2026 23:36
@mrocklin

Contributor Author

I don't know how best to handle testing here. I could use help thinking through Xarray's CI system and testing matrix, as well as how deeply we want to test this, and when. Do we want the entire xarray test suite to pass before we merge in the protocol support? That's ok, but it'll probably be a lot of review. If not, then what do we want to make sure is tested before merging?

@mrocklin

Contributor Author

Also, it's not clear to me that the CI failures here are due to these changes. I suspect it may be the recent pytest 9.1 release (if so, pinning pytest might be wise).

@mrocklin

Contributor Author

Thank you for the review @dcherian @shoyer. I think I've handled the comments except for CI.

On CI I don't have strong conviction on any plan due to ignorance of the project, but what I would probably do is the following:

  • Merge this without CI support
  • Follow up with work that adds dask_array to one of the lanes (or make a new lane) and runs all existing tests on it. It will fail hard.
  • Go through all tests and either make them generic, or mark them with a custom pytest.mark for failing with dask_array due to some permissible reason (like they were made a decade ago and deeply tie into dask.array internal details)

I think that this is a lot of busy work that agents can handle pretty well. I'm happy to kick it off and iterate on it. If it were me I wouldn't include it in this PR. Personally I would keep this PR light.

Another option would be to add a CI entry for just the tests added in this PR, and maybe a few more scattered throughout the repo. This feels ephemeral to me though.

Anyway, I've done what I can here. Passing off to you all if you're still interested. Thanks for the time spent so far.

Labels

Automation Github bots, testing workflows, release automation topic-dask

4 participants