Skip to content
4 changes: 3 additions & 1 deletion xarray/compat/array_api_compat.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from types import ModuleType

import numpy as np

from xarray.namedarray.pycompat import array_type
Expand Down Expand Up @@ -46,7 +48,7 @@ def result_type(*arrays_and_dtypes, xp) -> np.dtype:
return _future_array_api_result_type(*arrays_and_dtypes, xp=xp)


def get_array_namespace(*values):
def get_array_namespace(*values) -> ModuleType:
def _get_single_namespace(x):
if hasattr(x, "__array_namespace__"):
return x.__array_namespace__()
Expand Down
7 changes: 6 additions & 1 deletion xarray/core/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,15 @@
from contextlib import suppress
from dataclasses import dataclass, field
from datetime import timedelta
from types import ModuleType
from typing import TYPE_CHECKING, Any, cast, overload

import numpy as np
import pandas as pd
from numpy.typing import DTypeLike
from packaging.version import Version

from xarray.compat.array_api_compat import get_array_namespace
from xarray.compat.npcompat import HAS_STRING_DTYPE
from xarray.core import duck_array_ops
from xarray.core.coordinate_transform import CoordinateTransform
Expand Down Expand Up @@ -693,7 +695,10 @@ def __array__(
else:
return np.asarray(to_numpy(self.get_duck_array()), dtype=dtype)

def get_duck_array(self):
def __array_namespace__(self: Any) -> ModuleType:
return get_array_namespace(self.array)

def get_duck_array(self) -> duckarray:
return self.array.get_duck_array()

def __getitem__(self, key: Any):
Expand Down
2 changes: 1 addition & 1 deletion xarray/namedarray/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -843,7 +843,7 @@ def chunk(
# Using OuterIndexer is a pragmatic choice: dask does not yet handle
# different indexing types in an explicit way:
# https://github.com/dask/dask/issues/2883
ndata = ImplicitToExplicitIndexingAdapter(data_old, OuterIndexer) # type: ignore[assignment]
ndata = ImplicitToExplicitIndexingAdapter(data_old, OuterIndexer)

if is_dict_like(chunks):
chunks = tuple(starmap(chunks.get, enumerate(ndata.shape)))
Expand Down
5 changes: 3 additions & 2 deletions xarray/namedarray/daskmanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,9 @@ def from_array(
import dask.array as da

if isinstance(data, ImplicitToExplicitIndexingAdapter):
# lazily loaded backend array classes should use NumPy array operations.
kwargs["meta"] = np.ndarray
# lazily loaded backend array classes should use NumPy or CuPy array operations.
xp = data.__array_namespace__()
kwargs["meta"] = xp.ndarray

return da.from_array(
data,
Expand Down
2 changes: 1 addition & 1 deletion xarray/namedarray/pycompat.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def to_duck_array(data: Any, **kwargs: dict[str, Any]) -> duckarray[_ShapeType,
return loaded_data

if isinstance(data, ExplicitlyIndexed | ImplicitToExplicitIndexingAdapter):
return data.get_duck_array() # type: ignore[no-untyped-call, no-any-return]
return data.get_duck_array()
elif is_duck_array(data):
return data
else:
Expand Down
Loading