diff --git a/xarray/compat/array_api_compat.py b/xarray/compat/array_api_compat.py index e1e5d5c5bdc..575f8cdf07d 100644 --- a/xarray/compat/array_api_compat.py +++ b/xarray/compat/array_api_compat.py @@ -1,3 +1,5 @@ +from types import ModuleType + import numpy as np from xarray.namedarray.pycompat import array_type @@ -46,7 +48,7 @@ def result_type(*arrays_and_dtypes, xp) -> np.dtype: return _future_array_api_result_type(*arrays_and_dtypes, xp=xp) -def get_array_namespace(*values): +def get_array_namespace(*values) -> ModuleType: def _get_single_namespace(x): if hasattr(x, "__array_namespace__"): return x.__array_namespace__() diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py index 7884c7bd74a..fa6c94c7e10 100644 --- a/xarray/core/indexing.py +++ b/xarray/core/indexing.py @@ -9,6 +9,7 @@ from contextlib import suppress from dataclasses import dataclass, field from datetime import timedelta +from types import ModuleType from typing import TYPE_CHECKING, Any, cast, overload import numpy as np @@ -16,6 +17,7 @@ from numpy.typing import DTypeLike from packaging.version import Version +from xarray.compat.array_api_compat import get_array_namespace from xarray.compat.npcompat import HAS_STRING_DTYPE from xarray.core import duck_array_ops from xarray.core.coordinate_transform import CoordinateTransform @@ -693,7 +695,10 @@ def __array__( else: return np.asarray(to_numpy(self.get_duck_array()), dtype=dtype) - def get_duck_array(self): + def __array_namespace__(self: Any) -> ModuleType: + return get_array_namespace(self.array) + + def get_duck_array(self) -> duckarray: return self.array.get_duck_array() def __getitem__(self, key: Any): diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index 23d55ee0a11..eab7503f262 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -843,7 +843,7 @@ def chunk( # Using OuterIndexer is a pragmatic choice: dask does not yet handle # different indexing types in an explicit way: # https://github.com/dask/dask/issues/2883 - ndata = ImplicitToExplicitIndexingAdapter(data_old, OuterIndexer) # type: ignore[assignment] + ndata = ImplicitToExplicitIndexingAdapter(data_old, OuterIndexer) if is_dict_like(chunks): chunks = tuple(starmap(chunks.get, enumerate(ndata.shape))) diff --git a/xarray/namedarray/daskmanager.py b/xarray/namedarray/daskmanager.py index eb01a150c18..c03b9a4da13 100644 --- a/xarray/namedarray/daskmanager.py +++ b/xarray/namedarray/daskmanager.py @@ -68,8 +68,9 @@ def from_array( import dask.array as da if isinstance(data, ImplicitToExplicitIndexingAdapter): - # lazily loaded backend array classes should use NumPy array operations. - kwargs["meta"] = np.ndarray + # lazily loaded backend array classes should use NumPy or CuPy array operations. + xp = data.__array_namespace__() + kwargs["meta"] = xp.ndarray return da.from_array( data, diff --git a/xarray/namedarray/pycompat.py b/xarray/namedarray/pycompat.py index a192930cea7..52349b13830 100644 --- a/xarray/namedarray/pycompat.py +++ b/xarray/namedarray/pycompat.py @@ -140,7 +140,7 @@ def to_duck_array(data: Any, **kwargs: dict[str, Any]) -> duckarray[_ShapeType, return loaded_data if isinstance(data, ExplicitlyIndexed | ImplicitToExplicitIndexingAdapter): - return data.get_duck_array() # type: ignore[no-untyped-call, no-any-return] + return data.get_duck_array() elif is_duck_array(data): return data else: