Support CuPy-backed arrays in DaskManager#11383

Draft
weiji14 wants to merge 9 commits into
pydata:mainfrom
weiji14:dask_with_cupy

Conversation

@weiji14

@weiji14 weiji14 commented Jun 13, 2026

Contributor

Description

The default Dask ChunkManagerEntrypoint appears to be hardcoded to return NumPy arrays, even if the underlying arrays are CuPy arrays.

TODO:

Probably needs #11381 to be merged first. Part of resolving xarray-contrib/cupy-xarray#81 (comment).

Checklist

  • Closes #xxxx
  • Tests added
  • User visible changes (including notable bug fixes) are documented in whats-new.rst
  • New functions/methods are listed in api.rst

AI Disclosure

The "meta" argument passed to dask.array.from_array should not be hardcoded to just `numpy.ndarray`, but allow for `cupy.ndarray` too.
@github-actions github-actions Bot added the topic-NamedArray Lightweight version of Variable label Jun 13, 2026
Not sure how to type-hint np | cp | ??, so just use Any for output of get_array_namespace.
Comment thread xarray/namedarray/daskmanager.py Outdated
- # lazily loaded backend array classes should use NumPy array operations.
- kwargs["meta"] = np.ndarray
+ # lazily loaded backend array classes should use NumPy or CuPy array operations.
+ xp = get_array_namespace(data.get_duck_array())

@keewis keewis Jun 13, 2026

Collaborator

We probably need to add some API to allow getting the underlying array type / library without actually fetching data. Something like a data.get_array_namespace() or data.get_meta()? Not sure how easy it would be to implement that, though.

(I think this is what causes the tests to fail)
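To make the idea concrete, here is a rough sketch of what a data-free namespace lookup could look like. The class names `FakeBackendArray` and `LazyWrapper` and the `get_array_namespace()` method are hypothetical stand-ins, not existing xarray API, and the stdlib `math` module stands in for `numpy`/`cupy` so the snippet runs without either installed.

```python
import math
from types import ModuleType


class FakeBackendArray:
    """Hypothetical backend array that knows its namespace up front."""

    def __init__(self, namespace: ModuleType) -> None:
        self._namespace = namespace
        self.loaded = False  # flips to True only if data is actually read

    def get_duck_array(self) -> list:
        # Stand-in for an expensive read from disk or device memory.
        self.loaded = True
        return [0.0, 1.0, 2.0]

    def get_array_namespace(self) -> ModuleType:
        # Answers "which array library?" without touching the data.
        return self._namespace


class LazyWrapper:
    """Hypothetical lazy indexing layer wrapping another array-like."""

    def __init__(self, array) -> None:
        self.array = array

    def get_array_namespace(self) -> ModuleType:
        # Delegate down the wrapper chain instead of loading data.
        return self.array.get_array_namespace()


backend = FakeBackendArray(namespace=math)  # `math` stands in for numpy/cupy
wrapped = LazyWrapper(LazyWrapper(backend))
xp = wrapped.get_array_namespace()
```

In this sketch the namespace comes back through two wrapper layers while `backend.loaded` stays `False`, which is the property the current `get_array_namespace(data.get_duck_array())` call does not have.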

Contributor Author

Ideally we could just call xp = data.__array_namespace__() following the Array API spec (https://data-apis.org/array-api/2025.12/API_specification/generated/array_api.array.__array_namespace__.html), and it would propagate through all the subclassed layers to get the underlying array namespace (numpy or cupy). I thought this would work by putting it into the NDArrayMixin (b77cc57), but that breaks a lot of the lazy reprs...

Might need to think this through a bit more. Wondering if there needs to be a cached .__cached_array_namespace__ attribute of some sort to work with the lazy objects...
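A minimal sketch of the caching idea, assuming the namespace can be resolved once by some callable and then remembered on the lazy object. The `LazyArray` class, the `cached_array_namespace` attribute and the `resolve` helper are all hypothetical, and `math` again stands in for `numpy`/`cupy`.

```python
import math
from functools import cached_property
from types import ModuleType
from typing import Callable


class LazyArray:
    """Hypothetical lazy object that resolves its namespace at most once."""

    def __init__(self, resolve_namespace: Callable[[], ModuleType]) -> None:
        self._resolve_namespace = resolve_namespace

    @cached_property
    def cached_array_namespace(self) -> ModuleType:
        # Runs the (possibly expensive) lookup on first access only;
        # later accesses read the value stored on the instance.
        return self._resolve_namespace()


calls = []


def resolve() -> ModuleType:
    # Hypothetical resolver; records each call so caching is observable.
    calls.append(1)
    return math  # stand-in for numpy/cupy


lazy = LazyArray(resolve)
first = lazy.cached_array_namespace
second = lazy.cached_array_namespace
```

Here `resolve` is called a single time even though the attribute is read twice. Whether a plain attribute like this would interact well with xarray's lazy indexing classes is an open question.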

Contributor

I went down this road a while ago. The problem is that we need our lazy arrays coerced to an in-memory type at some point. If they advertise __array_namespace__ they can be treated as an in-memory type (deep within dask), and nothing works. Did you also run into this problem?
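A toy illustration of that failure mode, under the assumption that some consumer decides whether to coerce an object by checking for `__array_namespace__`. The check below is hypothetical and is not the actual logic inside dask; the class names are stand-ins too.

```python
import math


def looks_like_array_api_object(x) -> bool:
    # Hypothetical detection rule, loosely modelled on Array API duck typing.
    return hasattr(x, "__array_namespace__")


class LazyOnDisk:
    """Hypothetical lazy wrapper: holds a reference to data, not the data."""

    def get_duck_array(self):
        # Stand-in for the coercion step that produces an in-memory array.
        return [1.0, 2.0]


class LazyAdvertising(LazyOnDisk):
    """Same lazy wrapper, but now advertising an array namespace."""

    def __array_namespace__(self, api_version=None):
        return math  # stand-in for numpy/cupy


plain = looks_like_array_api_object(LazyOnDisk())
advertising = looks_like_array_api_object(LazyAdvertising())
```

With this rule the plain wrapper is not taken for an array, so a consumer would still coerce it, while the advertising wrapper passes the check and could be handed on as if it were already in memory.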

Contributor Author

I think I've gone pretty deep, but not deep into dask internals (yet?). If it ends up needing changes in dask, I'm just gonna push on finding a way to remove dask entirely - #9038 (comment).

Right now I think I've solved the repr issues by changing some of the logic in formatting.py. Still need to work my way through some other code that has been built around duck array / __array_namespace__ checks but shouldn't be...

Contributor Author

Actually:

[ ] Ensure calling .compute() on an xarray.Dataset backed by CuPy arrays doesn't coerce them to NumPy arrays.

I did realize that even with the cupy meta fixes here, the Dask-backed arrays still get loaded into NumPy instead. So maybe that's what you're referring to? That there's no way to have a fully CuPy-only array pipeline without changing some internals of Dask?

Contributor

Is there an example I can run without an NVIDIA GPU?

Contributor Author

Unfortunately no, I only know how to reproduce this on this cupy-xarray PR - https://github.com/xarray-contrib/cupy-xarray/pull/81/changes#diff-b866be1141ec4295c6a2ef9b8effab4a58d6dacd262918d0db4f56d6d7927818:

ds: xr.Dataset = xr.open_mfdataset(
    paths=[
        "https://github.com/developmentseed/titiler/raw/1.2.0/src/titiler/mosaic/tests/fixtures/B01.tif",
        "https://github.com/developmentseed/titiler/raw/1.2.0/src/titiler/mosaic/tests/fixtures/B09.tif",
    ],
    engine="cog3pio",
    concat_dim="band",
    combine="nested",
    device_id=None,
)
ds.raster.load()  # inplace load, requires https://github.com/pydata/xarray/pull/11381
# assert isinstance(
#     ds.raster, cp.ndarray  # TODO wait for https://github.com/pydata/xarray/pull/11383 ?
# )

But I should figure out a good duck array test somehow, and will let you know if it can be reproduced without a GPU.

weiji14 added 4 commits June 15, 2026 13:12
Centralize retrieving of the __array_namespace__ through several subclassed layers, to avoid having to go through `.get_duck_array()`. Need to put `from xarray.compat.array_api_compat import get_array_namespace` import within the method to avoid circular import.

Also type-hinted output of `get_array_namespace` as ModuleType following numpy/numpy#20719.
To fix repr AssertionError mismatches on:
- TestVariable::test_repr_lazy_data
- test_repr_pandas_multi_index
- test_repr_pandas_range_index
- test_display_nbytes
- test_repr_file_collapsed
- test_coordinate_transform_variable_repr

by preventing specific xarray internal array types from going through the is_duck_array repr path.
Comment thread xarray/core/formatting.py Outdated
  return short_array_repr(array)
- elif is_duck_array(internal_data):
+ elif not isinstance(
+ internal_data, (LazilyIndexedArray, MemoryCachedArray, IndexingAdapter)

Contributor

Suggested change
- internal_data, (LazilyIndexedArray, MemoryCachedArray, IndexingAdapter)
+ internal_data, ExplicitlyIndexed

is usually the way to do it
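A small sketch of why checking against the shared base class tends to be more robust than listing subclasses. The classes below reuse names from the diff but are simplified stand-ins, and `SomeNewLazyArray` and both helper functions are hypothetical; the real xarray class hierarchy may differ.

```python
class ExplicitlyIndexed:
    """Stand-in for a common base class of xarray's internal array wrappers."""


class LazilyIndexedArray(ExplicitlyIndexed):
    pass


class MemoryCachedArray(ExplicitlyIndexed):
    pass


class SomeNewLazyArray(ExplicitlyIndexed):
    """Hypothetical wrapper added later and missing from the tuple below."""


def takes_duck_path_tuple(obj) -> bool:
    # Enumerating subclasses: silently misses any wrapper not listed.
    return not isinstance(obj, (LazilyIndexedArray, MemoryCachedArray))


def takes_duck_path_base(obj) -> bool:
    # Checking the base class: covers every current and future subclass.
    return not isinstance(obj, ExplicitlyIndexed)


new_wrapper = SomeNewLazyArray()
tuple_result = takes_duck_path_tuple(new_wrapper)
base_result = takes_duck_path_base(new_wrapper)
```

With the tuple check the new wrapper falls into the duck array branch by accident, whereas the base-class check keeps it out.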

@weiji14 weiji14 Jun 18, 2026

Contributor Author

Excellent hint, but actually, not needed anymore. I talked to @keewis on the Pangeo call just now, and figured out that the right level of abstraction is to put the __array_namespace__ method in ImplicitToExplicitIndexingAdapter instead of NDArrayMixin (commit 7132f73). So we won't need to change this repr stuff anymore.

@weiji14 weiji14 Jun 18, 2026

Contributor Author

Oh wait, commit 7132f73 might not be correct either 😅 I'll actually need to use your suggestion here, and really should write a proper test for this first...

weiji14 added 2 commits June 18, 2026 12:30
Move `__array_namespace__` method from NDArrayMixin to just ImplicitToExplicitIndexingAdapter.
@weiji14

weiji14 commented Jun 18, 2026

Contributor Author

CI is failing with Error: ... duplicate parametrization of ... (pytest-dev/pytest#14591), but the other tests seem to be passing from what I can tell. Will re-run CI once pytest 9.1.1 comes out.

Still need to make sure calling .compute() on CuPy-backed Dask arrays returns CuPy arrays, though I'm wondering whether that should go into a separate PR.


Labels

topic-indexing topic-NamedArray Lightweight version of Variable


3 participants