Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ dependencies = [
"zarr-metadata>=0.3",
]

[project.optional-dependencies]
xarray = ["xarray"]

[tool.maturin]
features = ["pyo3/extension-module", "abi3-py311"]
module-name = "zarrista._zarrista"
Expand All @@ -40,6 +43,7 @@ dev = [
"pytest>=8.0",
"pytest-asyncio>=1.4",
"ruff>=0.6",
"xarray>=2025.1.1",
"zarr>=3",
]
docs = [
Expand Down
117 changes: 117 additions & 0 deletions python/zarrista/xarray.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
"""A read-only, lazily-indexed xarray adapter for zarrista arrays.

This module wraps a zarrista `Array` so it can be used as an xarray backend
array. It is pure Python and builds only on zarrista's public API. Importing
this module requires the optional `xarray` dependency (`zarrista[xarray]`).

Only fixed-width data types are supported. Variable-length and masked decoded
layouts, slices with `step != 1`, and async stores are out of scope.
"""

from __future__ import annotations

import numpy as np
import xarray as xr
from xarray.backends.common import BackendArray
from xarray.core import indexing

from zarrista import Array, Tensor


class ZarristaBackendArray(BackendArray):
"""A lazily-indexed xarray backend array wrapping a zarrista `Array`.

The wrapped `Array` is held without reading any chunk data; reads happen
only when the array is indexed. Only fixed-width data types are supported.
"""

def __init__(self, array: Array) -> None:
"""Wrap `array`, deriving the numpy `shape` and `dtype` from its metadata.

Raises `NotImplementedError` for variable-length data types (those whose
`DataType.size` is `None`), which have no fixed-width numpy layout.
"""
if array.dtype.size is None:
raise NotImplementedError(
f"variable-length data type {array.dtype.name!r} is not supported",
)
self._array = array
self.shape = tuple(array.shape)
self.dtype = np.dtype(array.dtype.name)

def __getitem__(self, key: indexing.ExplicitIndexer) -> np.ndarray:
"""Read the region selected by `key`, returning a numpy array.

Declares `BASIC` indexing support; xarray decomposes outer/vectorized
indexing into a basic backend read plus numpy post-indexing.
"""
return indexing.explicit_indexing_adapter(
key,
self.shape,
indexing.IndexingSupport.BASIC,
self._raw_indexing,
)

def _raw_indexing(self, key: tuple[int | slice, ...]) -> np.ndarray:
"""Read `key` (one int or slice per axis) and squeeze integer axes.

Integer indexers are passed through to `retrieve_array_subset`, which is
ndim-preserving (an integer keeps a length-1 axis); those axes are then
squeezed so the result matches xarray's `BASIC` indexing contract.
Slices with `step != 1` are not supported.
"""
selection: list[int | slice] = []
squeeze_axes: list[int] = []
for axis, indexer in enumerate(key):
if isinstance(indexer, slice):
if indexer.step not in (None, 1):
raise NotImplementedError(
"slicing with step != 1 is not supported",
)
selection.append(indexer)
else:
selection.append(int(indexer))
squeeze_axes.append(axis)

decoded = self._array.retrieve_array_subset(tuple(selection))
if not isinstance(decoded, Tensor):
raise NotImplementedError(
f"data type {self._array.dtype.name!r} is not supported",
)
result = decoded.to_numpy()
if squeeze_axes:
result = np.squeeze(result, axis=tuple(squeeze_axes))
return result


def to_dataarray(array: Array, *, name: str | None = None) -> xr.DataArray:
"""Wrap a zarrista `Array` as a lazily-indexed `xarray.DataArray`.

Dimension names are taken from the array's `dimension_names`; entries that
are `None` (or absent entirely) are synthesised as `dim_{i}`. The array's
user attributes are copied onto the variable. The returned `DataArray` reads
lazily: chunk data is fetched only when the array is indexed or computed.
"""
backend_array = ZarristaBackendArray(array)
names = array.dimension_names or [None] * array.ndim
dims = [
name_i if name_i is not None else f"dim_{i}" for i, name_i in enumerate(names)
]
data = indexing.LazilyIndexedArray(backend_array)
variable = xr.Variable(dims, data, attrs=dict(array.attrs))
return xr.DataArray(variable, name=name)


def _array_xr(array: Array) -> xr.DataArray:
"""Return this `Array` as a lazily-indexed `xarray.DataArray`.

Convenience accessor equivalent to `to_dataarray(array)`.
"""
return to_dataarray(array)


# Attach a read-only `.xr` accessor to the core `Array` class. This lives in
# Python (rather than Rust) so the xarray integration stays optional and fully
# decoupled from the extension module; it becomes available once this module is
# imported.
Array.xr = property(_array_xr) # type: ignore[attr-defined]
126 changes: 126 additions & 0 deletions tests/test_xarray.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
"""Tests for the `zarrista.xarray` lazy BackendArray adapter.

Fixtures are written with zarr-python and read back through zarrista, then
wrapped for xarray. Reads are compared against the equivalent numpy slice.
"""

from pathlib import Path

import numpy as np
import pytest
import zarr
from numpy.typing import NDArray
from xarray.core import indexing

from zarrista import Array, FilesystemStore
from zarrista.xarray import ZarristaBackendArray, to_dataarray


@pytest.fixture
def int32_array(tmp_path: Path) -> tuple[Path, NDArray[np.int32]]:
"""A (9, 64, 100) int32 array with dim names and attrs; returns (path, data)."""
path = tmp_path / "a.zarr"
data = np.arange(9 * 64 * 100, dtype="int32").reshape(9, 64, 100)
z = zarr.create_array(
store=str(path),
shape=data.shape,
chunks=(3, 16, 50),
dtype=data.dtype,
dimension_names=("t", "y", "x"),
)
z[:] = data
z.attrs["units"] = "m"
return path, data


def test_backend_array_shape_and_dtype(int32_array):
path, _data = int32_array
arr = Array.open(FilesystemStore(path))
backend = ZarristaBackendArray(arr)
assert backend.shape == (9, 64, 100)
assert backend.dtype == np.dtype("int32")


def test_raw_indexing_slice_matches_numpy(int32_array):
path, data = int32_array
backend = ZarristaBackendArray(Array.open(FilesystemStore(path)))
result = backend._raw_indexing((slice(0, 2), slice(None), slice(5, 7)))
np.testing.assert_array_equal(result, data[0:2, :, 5:7])


def test_raw_indexing_int_squeezes_axis(int32_array):
path, data = int32_array
backend = ZarristaBackendArray(Array.open(FilesystemStore(path)))
result = backend._raw_indexing((5, slice(None), slice(None)))
assert result.shape == (64, 100)
np.testing.assert_array_equal(result, data[5])


def test_raw_indexing_negative_int(int32_array):
path, data = int32_array
backend = ZarristaBackendArray(Array.open(FilesystemStore(path)))
result = backend._raw_indexing((-1, slice(None), slice(None)))
assert result.shape == (64, 100)
np.testing.assert_array_equal(result, data[-1])


def test_raw_indexing_step_not_one_raises(int32_array):
path, _ = int32_array
backend = ZarristaBackendArray(Array.open(FilesystemStore(path)))
with pytest.raises(NotImplementedError):
backend._raw_indexing((slice(0, 9, 2), slice(None), slice(None)))


def test_variable_length_dtype_raises(tmp_path: Path):
path = tmp_path / "s.zarr"
z = zarr.create_array(store=str(path), shape=(4,), chunks=(4,), dtype=str)
z[:] = np.array(["a", "bb", "ccc", "dddd"], dtype=object)
arr = Array.open(FilesystemStore(path))
assert arr.dtype.size is None # precondition: variable-length
with pytest.raises(NotImplementedError):
ZarristaBackendArray(arr)


def test_to_dataarray_dims_attrs_and_lazy(int32_array):
path, _data = int32_array
arr = Array.open(FilesystemStore(path))
da = to_dataarray(arr, name="temp")

assert da.name == "temp"
assert da.dims == ("t", "y", "x")
assert da.shape == (9, 64, 100)
assert da.dtype == np.dtype("int32")
assert da.attrs["units"] == "m"
# The data is wrapped lazily and not yet loaded into memory.
assert isinstance(da.variable._data, indexing.LazilyIndexedArray)


def test_to_dataarray_indexing_matches_numpy(int32_array):
path, data = int32_array
da = to_dataarray(Array.open(FilesystemStore(path)))
np.testing.assert_array_equal(da[0:2, :, 5:7].to_numpy(), data[0:2, :, 5:7])
np.testing.assert_array_equal(da[5].to_numpy(), data[5])
np.testing.assert_array_equal(da.to_numpy(), data)


def test_array_xr_accessor_matches_to_dataarray(int32_array):
path, data = int32_array
arr = Array.open(FilesystemStore(path))
da = arr.xr
assert da.dims == ("t", "y", "x")
assert da.shape == (9, 64, 100)
np.testing.assert_array_equal(da.to_numpy(), data)


def test_to_dataarray_synthesizes_dim_names(tmp_path: Path):
path = tmp_path / "nodims.zarr"
data = np.arange(2 * 3, dtype="int16").reshape(2, 3)
z = zarr.create_array(
store=str(path),
shape=data.shape,
chunks=(2, 3),
dtype=data.dtype,
)
z[:] = data
da = to_dataarray(Array.open(FilesystemStore(path)))
assert da.dims == ("dim_0", "dim_1")
Loading
Loading