diff --git a/src/quantem/core/datastructures/__init__.py b/src/quantem/core/datastructures/__init__.py index dfb5b47a..0303706f 100644 --- a/src/quantem/core/datastructures/__init__.py +++ b/src/quantem/core/datastructures/__init__.py @@ -2,6 +2,7 @@ from quantem.core.datastructures.vector import Vector as Vector from quantem.core.datastructures.dataset4dstem import Dataset4dstem as Dataset4dstem +from quantem.core.datastructures.dataset5dstem import Dataset5dstem as Dataset5dstem from quantem.core.datastructures.dataset4d import Dataset4d as Dataset4d from quantem.core.datastructures.dataset3d import Dataset3d as Dataset3d from quantem.core.datastructures.dataset2d import Dataset2d as Dataset2d diff --git a/src/quantem/core/datastructures/dataset.py b/src/quantem/core/datastructures/dataset.py index 4d2ab9e1..fa1fcef2 100644 --- a/src/quantem/core/datastructures/dataset.py +++ b/src/quantem/core/datastructures/dataset.py @@ -4,6 +4,7 @@ from typing import Any, Literal, Optional, Self, Union, overload import numpy as np +import torch from numpy.typing import DTypeLike, NDArray from quantem.core.io.serialize import AutoSerialize @@ -38,24 +39,37 @@ class Dataset(AutoSerialize): def __init__( self, - array: Any, # Input can be array-like - name: str, - origin: NDArray | tuple | list | float | int, - sampling: NDArray | tuple | list | float | int, - units: list[str] | tuple | list, + array: NDArray | None = None, + tensor: torch.Tensor | None = None, + name: str = "", + origin: NDArray | tuple | list | float | int | None = None, + sampling: NDArray | tuple | list | float | int | None = None, + units: list[str] | tuple | list | None = None, signal_units: str = "arb. units", metadata: Optional[dict] = None, _token: object | None = None, ): if _token is not self._token: - raise RuntimeError("Use Dataset.from_array() to instantiate this class.") - super().__init__() - arr = ensure_valid_array(array) - if not isinstance(arr, np.ndarray): - raise TypeError( - "Dataset requires a NumPy array (CuPy is not supported on this branch)." + raise RuntimeError( + "Use Dataset.from_array() or Dataset.from_tensor() to instantiate this class." ) - self._array = arr + super().__init__() + # Dual-slot storage: exactly one of (_array, _tensor) is set. + if array is None and tensor is None: + raise ValueError("Provide either `array` (numpy) or `tensor` (torch).") + if array is not None and tensor is not None: + raise ValueError("Provide only one of `array` or `tensor`, not both.") + if array is not None: + arr = ensure_valid_array(array) + if not isinstance(arr, np.ndarray): + raise TypeError(f"Dataset.array must be numpy.ndarray, got {type(arr).__name__}.") + self._array = arr + self._tensor = None + else: + if not isinstance(tensor, torch.Tensor): + raise TypeError(f"Dataset.tensor must be torch.Tensor, got {type(tensor).__name__}.") + self._array = None + self._tensor = tensor self.name = name self.origin = origin self.sampling = sampling @@ -122,19 +136,31 @@ def from_array( # --- Properties --- @property - def array(self) -> NDArray: - """The underlying n-dimensional NumPy array data.""" - return self._array + def array(self) -> NDArray | None: + """The underlying n-dimensional NumPy array data. + + Returns ``None`` for tensor-backed datasets. Use ``.tensor`` for the + torch tensor, or ``.numpy()`` to materialize a numpy copy explicitly. + """ + return getattr(self, "_array", None) @array.setter def array(self, value: NDArray) -> None: - arr = ensure_valid_array(value, ndim=self.ndim) # want to allow changing dtype + arr = ensure_valid_array(value, ndim=self.ndim) if not isinstance(arr, np.ndarray): - raise TypeError( - "Dataset requires a NumPy array (CuPy is not supported on this branch)." - ) + raise TypeError(f"Dataset.array must be numpy.ndarray, got {type(arr).__name__}.") self._array = arr - # self._array = ensure_valid_array(value, dtype=self.dtype, ndim=self.ndim) + + @property + def tensor(self) -> torch.Tensor: + """Torch tensor backing the data. AttributeError if numpy-backed.""" + # getattr handles AutoSerialize-restored instances (no __init__ run). + tensor = getattr(self, "_tensor", None) + if tensor is None: + raise AttributeError( + f"Dataset '{self.name}' is numpy-backed; use Dataset.from_tensor() at construction." + ) + return tensor @property def metadata(self) -> dict: @@ -191,26 +217,50 @@ def file_path(self, value: os.PathLike | str | None) -> None: # --- Derived Properties --- @property def shape(self) -> tuple[int, ...]: - return self.array.shape + # Direct slot access (never triggers .array derive, which would force + # a full GPU->CPU copy on tensor-backed datasets). getattr handles + # AutoSerialize-restored instances (no __init__ run). + array = getattr(self, "_array", None) + return tuple((array if array is not None else self._tensor).shape) @property def ndim(self) -> int: - return self.array.ndim + array = getattr(self, "_array", None) + return (array if array is not None else self._tensor).ndim @property def dtype(self) -> DTypeLike: - return self.array.dtype + array = getattr(self, "_array", None) + return (array if array is not None else self._tensor).dtype @property def device(self) -> str: - """ - Outputting a string is likely temporary -- once we have our use cases we can - figure out a more permanent device solution that enables easier translation between - numpy <-> torch <-> numpy, etc. + """``"cpu"`` for numpy-backed; torch device string for tensor-backed.""" + tensor = getattr(self, "_tensor", None) + if tensor is not None: + return str(tensor.device) + return "cpu" - For NumPy-only datasets, this is always "cpu". + def numpy(self) -> NDArray: + """Return the data as a numpy array (mirrors ``torch.Tensor.numpy()``). + + For numpy-backed datasets, returns ``self.array`` directly. For + tensor-backed datasets, materializes a CPU copy via ``.detach().cpu().numpy()``. """ - return "cpu" + array = getattr(self, "_array", None) + if array is not None: + return array + return self._tensor.detach().cpu().numpy() + + def to(self, device) -> Self: + """Move the underlying tensor to ``device``. Raises if numpy-backed.""" + tensor = getattr(self, "_tensor", None) + if tensor is None: + raise AttributeError( + f"Cannot .to({device!r}) on numpy-backed Dataset '{self.name}'." + ) + self._tensor = tensor.to(device) + return self # --- Summaries --- def __repr__(self) -> str: diff --git a/src/quantem/core/datastructures/dataset4d.py b/src/quantem/core/datastructures/dataset4d.py index 8e5bdfa0..f5681730 100644 --- a/src/quantem/core/datastructures/dataset4d.py +++ b/src/quantem/core/datastructures/dataset4d.py @@ -1,6 +1,7 @@ from typing import Any, Self, Union import numpy as np +import torch from numpy.typing import NDArray from quantem.core.datastructures.dataset import Dataset @@ -21,11 +22,12 @@ class Dataset4d(Dataset): def __init__( self, - array: NDArray | Any, - name: str, - origin: NDArray | tuple | list | float | int, - sampling: NDArray | tuple | list | float | int, - units: list[str] | tuple | list, + array: NDArray | None = None, + tensor: torch.Tensor | None = None, + name: str = "", + origin: NDArray | tuple | list | float | int | None = None, + sampling: NDArray | tuple | list | float | int | None = None, + units: list[str] | tuple | list | None = None, signal_units: str = "arb. units", metadata: dict = {}, _token: object | None = None, @@ -34,8 +36,10 @@ def __init__( Parameters ---------- - array : NDArray | Any - The underlying 3D array data + array : NDArray | None + The underlying 4D numpy array. Provide exactly one of ``array`` or ``tensor``. + tensor : torch.Tensor | None + The underlying 4D torch tensor (any device). Provide exactly one of ``array`` or ``tensor``. name : str A descriptive name for the dataset origin : NDArray | tuple | list | float | int @@ -51,6 +55,7 @@ def __init__( """ super().__init__( array=array, + tensor=tensor, name=name, origin=origin, sampling=sampling, diff --git a/src/quantem/core/datastructures/dataset4dstem.py b/src/quantem/core/datastructures/dataset4dstem.py index 79cbc479..48d3f737 100644 --- a/src/quantem/core/datastructures/dataset4dstem.py +++ b/src/quantem/core/datastructures/dataset4dstem.py @@ -2,6 +2,7 @@ import matplotlib.pyplot as plt import numpy as np +import torch from matplotlib.patches import Circle, Wedge from numpy.typing import NDArray @@ -41,11 +42,12 @@ class Dataset4dstem(Dataset4d): def __init__( self, - array: NDArray | Any, - name: str, - origin: NDArray | tuple | list | float | int, - sampling: NDArray | tuple | list | float | int, - units: list[str] | tuple | list, + array: NDArray | None = None, + tensor: torch.Tensor | None = None, + name: str = "", + origin: NDArray | tuple | list | float | int | None = None, + sampling: NDArray | tuple | list | float | int | None = None, + units: list[str] | tuple | list | None = None, signal_units: str = "arb. units", metadata: dict = {}, _token: object | None = None, @@ -56,6 +58,9 @@ def __init__( ---------- array : NDArray | Any The underlying 4D array data + tensor : torch.Tensor | None, optional + Alternative to ``array``: the underlying 4D torch tensor (any device). + Provide exactly one of ``array`` or ``tensor``. name : str A descriptive name for the dataset origin : NDArray | tuple | list | float | int @@ -79,6 +84,7 @@ def __init__( super().__init__( array=array, + tensor=tensor, name=name, origin=origin, sampling=sampling, @@ -157,6 +163,45 @@ def from_array( _token=cls._token, ) + @classmethod + def from_tensor( + cls, + tensor: torch.Tensor, + name: str | None = None, + origin: NDArray | tuple | list | float | int | None = None, + sampling: NDArray | tuple | list | float | int | None = None, + units: list[str] | tuple | list | None = None, + signal_units: str = "arb. units", + metadata: dict | None = None, + ) -> Self: + """Create a Dataset4dstem from a torch tensor (any device). + + Use this when raw data is GPU-resident (CUDA pipelines, live detector + frames, GPU file readers) to skip the VRAM<->RAM round-trip. + + For cupy / jax arrays, wrap with ``torch.from_dlpack(arr)`` first. + """ + if not isinstance(tensor, torch.Tensor): + raise TypeError( + f"from_tensor requires torch.Tensor, got {type(tensor).__name__}. " + f"For cupy / jax, wrap with `torch.from_dlpack(arr)` first." + ) + if tensor.ndim != 4: + raise ValueError( + f"Dataset4dstem.from_tensor requires a 4D tensor " + f"(scan_row, scan_col, dp_row, dp_col), got shape {tuple(tensor.shape)}." + ) + return cls( + tensor=tensor, + name=name if name is not None else "4D-STEM dataset (torch)", + origin=origin if origin is not None else np.zeros(4), + sampling=sampling if sampling is not None else np.ones(4), + units=units if units is not None else ["pixels"] * 4, + signal_units=signal_units, + metadata=metadata if metadata is not None else {}, + _token=cls._token, + ) + @property def virtual_images(self) -> dict[str, Dataset2d]: """ diff --git a/src/quantem/core/datastructures/dataset5dstem.py b/src/quantem/core/datastructures/dataset5dstem.py new file mode 100644 index 00000000..339fbe14 --- /dev/null +++ b/src/quantem/core/datastructures/dataset5dstem.py @@ -0,0 +1,173 @@ +from typing import Iterator, Self + +import numpy as np +import torch +from numpy.typing import NDArray + +from quantem.core.datastructures.dataset import Dataset +from quantem.core.datastructures.dataset4dstem import Dataset4dstem +from quantem.core.utils.validators import validate_ndinfo, validate_units + + +_SERIES_TYPES = ("time", "tilt", "energy", "dose", "focus", "generic") + + +class Dataset5dstem(Dataset): + """**EXPERIMENTAL.** Torch-backed 5D-STEM series ``(N, scan_row, scan_col, k_row, k_col)``. + + Stack of 4D-STEM acquisitions sharing identical scan + k calibration. Axis 0 + represents ONE monotonically varying experimental parameter (time, tilt, + focus, dose, energy, generic). + + ``sampling`` / ``units`` / ``origin`` are 4-length (scan + k only) - the + series axis is described separately by ``series_type`` + ``series``. This + diverges from base Dataset's ``len(sampling) == ndim`` convention but keeps + the user-facing API clean (no axis-0 placeholders). + + Single-tensor / single-device only. Sharding deferred. API is experimental. + """ + + def __init__( + self, + tensor: torch.Tensor, + name: str = "", + sampling: NDArray | tuple | list | None = None, + units: list[str] | tuple | list | None = None, + origin: NDArray | tuple | list | None = None, + signal_units: str = "arb. units", + metadata: dict | None = None, + series_type: str = "generic", + series: NDArray | list | tuple | None = None, + _token: object | None = None, + ): + if _token is not self._token: + raise RuntimeError( + "Use Dataset5dstem.from_tensor() or Dataset5dstem.from_4dstem() to instantiate." + ) + if series_type not in _SERIES_TYPES: + raise ValueError(f"series_type must be one of {_SERIES_TYPES}, got {series_type!r}.") + super().__init__( + tensor=tensor, name=name, + sampling=sampling, units=units, origin=origin, + signal_units=signal_units, metadata=metadata, _token=_token, + ) + self.series_type = series_type + self.series = series + + @classmethod + def from_tensor( + cls, + tensor: torch.Tensor, + name: str | None = None, + sampling: NDArray | tuple | list | None = None, + units: list[str] | tuple | list | None = None, + origin: NDArray | tuple | list | None = None, + signal_units: str = "arb. units", + metadata: dict | None = None, + series_type: str = "generic", + series: NDArray | list | tuple | None = None, + ) -> Self: + """Wrap a 5D torch tensor. ``sampling`` / ``units`` / ``origin`` are 4-length + (scan_row, scan_col, k_row, k_col); axis 0 lives in ``series_type`` + ``series``. + """ + if tensor.ndim != 5: + raise ValueError( + f"from_tensor requires 5D tensor (N, scan_row, scan_col, k_row, k_col), " + f"got shape {tuple(tensor.shape)}." + ) + return cls( + tensor=tensor, + name=name if name is not None else "5D-STEM dataset (torch)", + sampling=sampling if sampling is not None else np.ones(4), + units=units if units is not None else ["pixels"] * 4, + origin=origin if origin is not None else np.zeros(4), + signal_units=signal_units, metadata=metadata, + series_type=series_type, series=series, + _token=cls._token, + ) + + @classmethod + def from_4dstem( + cls, + datasets: list[Dataset4dstem], + name: str | None = None, + series_type: str = "generic", + series: NDArray | list | tuple | None = None, + ) -> Self: + """Stack tensor-backed Dataset4dstem (same device). Spatial cal inherits from first.""" + member_tensors = [d.tensor for d in datasets] + devices = {str(t.device) for t in member_tensors} + if len(devices) > 1: + raise ValueError( + f"All Dataset4dstem must share device; got {sorted(devices)}. " + f"Sharding not yet supported - move to one device first via ds.to('cuda:N')." + ) + first = datasets[0] + return cls.from_tensor( + tensor=torch.stack(member_tensors, dim=0), + name=name if name is not None else f"{len(datasets)}x {first.name}", + sampling=first.sampling, units=first.units, origin=first.origin, + series_type=series_type, series=series, + ) + + # --- Override base sampling/units/origin: 4-length (scan + k), not ndim-length --- + @property + def sampling(self) -> NDArray: return self._sampling + + @sampling.setter + def sampling(self, value) -> None: + self._sampling = validate_ndinfo(value, 4, "sampling") + + @property + def origin(self) -> NDArray: return self._origin + + @origin.setter + def origin(self, value) -> None: + self._origin = validate_ndinfo(value, 4, "origin") + + @property + def units(self) -> list[str]: return self._units + + @units.setter + def units(self, value) -> None: + self._units = validate_units(value, 4) + + # --- Series metadata --- + @property + def series(self) -> NDArray | None: + return self._series + + @series.setter + def series(self, value) -> None: + if value is None: + self._series = None + return + arr = np.asarray(value, dtype=float) + n = len(self) + if arr.ndim != 1 or len(arr) != n: + raise ValueError(f"series must be 1D length {n}, got shape {arr.shape}.") + self._series = arr + + # --- Frame access --- + def __len__(self) -> int: + return int(self._tensor.shape[0]) + + def __getitem__(self, index: int | slice) -> Dataset4dstem | Self: + if isinstance(index, int): + return Dataset4dstem.from_tensor( + self._tensor[index], + name=f"{self.name}[{index}]", + sampling=self.sampling, units=self.units, + ) + sub_series = None if self._series is None else self._series[index] + return Dataset5dstem.from_tensor( + tensor=self._tensor[index], + name=self.name, + sampling=self.sampling, units=self.units, origin=self.origin, + signal_units=self.signal_units, metadata=self._metadata, + series_type=self.series_type, series=sub_series, + ) + + def __iter__(self) -> Iterator[Dataset4dstem]: + for i in range(len(self)): + yield self[i] diff --git a/tests/datastructures/test_dataset5dstem.py b/tests/datastructures/test_dataset5dstem.py new file mode 100644 index 00000000..347da7e4 --- /dev/null +++ b/tests/datastructures/test_dataset5dstem.py @@ -0,0 +1,62 @@ +"""Tests for Dataset5dstem (quantem.core.datastructures.dataset5dstem).""" + +import numpy as np +import torch + +from quantem.core.datastructures.dataset4dstem import Dataset4dstem +from quantem.core.datastructures.dataset5dstem import Dataset5dstem + + +def test_from_tensor(): + ds = Dataset5dstem.from_tensor( + tensor=torch.rand(3, 5, 5, 8, 8), + name="t", + sampling=(0.5, 0.5, 0.1, 0.1), + units=["nm", "nm", "1/nm", "1/nm"], + series_type="time", + series=[0.0, 2.0, 4.0], + ) + assert ds.shape == (3, 5, 5, 8, 8) + assert np.array_equal(ds.sampling, np.array([0.5, 0.5, 0.1, 0.1])) + assert ds.units == ["nm", "nm", "1/nm", "1/nm"] + assert ds.series_type == "time" + assert np.array_equal(ds.series, np.array([0.0, 2.0, 4.0])) + assert isinstance(ds[0], Dataset4dstem) + + +def test_slice(): + """A scientist slices a sub-stack - gets a smaller Dataset5dstem with series sliced.""" + ds = Dataset5dstem.from_tensor( + tensor=torch.rand(5, 5, 5, 8, 8), + series_type="time", series=[0.0, 1.0, 2.0, 3.0, 4.0], + ) + sub = ds[1:4] + assert isinstance(sub, Dataset5dstem) + assert sub.shape == (3, 5, 5, 8, 8) + assert np.array_equal(sub.series, np.array([1.0, 2.0, 3.0])) + + +def test_for_loop(): + """A scientist loops frame-by-frame - each yield is a Dataset4dstem.""" + ds = Dataset5dstem.from_tensor(tensor=torch.rand(3, 5, 5, 8, 8)) + seen = [f for f in ds] + assert len(seen) == 3 + assert all(isinstance(f, Dataset4dstem) and f.shape == (5, 5, 8, 8) for f in seen) + + +def test_from_4dstem(): + d4_list = [ + Dataset4dstem.from_tensor( + torch.rand(5, 5, 8, 8), + sampling=(0.5, 0.5, 0.1, 0.1), + units=("nm", "nm", "1/nm", "1/nm"), + name=f"f{i}", + ) + for i in range(3) + ] + ds = Dataset5dstem.from_4dstem(d4_list, series_type="tilt", series=[-30, 0, 30]) + assert ds.shape == (3, 5, 5, 8, 8) + assert np.array_equal(ds.sampling, np.array([0.5, 0.5, 0.1, 0.1])) + assert ds.units == ["nm", "nm", "1/nm", "1/nm"] + assert ds.series_type == "tilt" + assert np.array_equal(ds.series, np.array([-30.0, 0.0, 30.0])) diff --git a/widget/src/quantem/widget/show4dstem.py b/widget/src/quantem/widget/show4dstem.py index 0a6b256b..c3ee61a7 100644 --- a/widget/src/quantem/widget/show4dstem.py +++ b/widget/src/quantem/widget/show4dstem.py @@ -333,15 +333,18 @@ def __init__( _io_labels = None - # Auto-extract sampling + units from Dataset4dstem if available. - if hasattr(data, "sampling") and hasattr(data, "array"): - if not title and hasattr(data, "name") and data.name: + # Extract underlying array / tensor + auto-calibrate from Dataset input + # (duck-typed via the dual-slot private attributes _tensor / _array). + tensor = getattr(data, "_tensor", None) + array = getattr(data, "_array", None) + if tensor is not None or array is not None: + if not title and getattr(data, "name", ""): title = str(data.name) if sampling is None: sampling = tuple(float(s) for s in data.sampling) - if units is None and hasattr(data, "units"): + if units is None: units = list(data.units) - data = data.array + data = tensor if tensor is not None else array # Resolve sampling + units (4 axes for 4D-STEM): # [scan_row, scan_col, k_row, k_col]. Scalar/None broadcast to (1, 1, 1, 1).