diff --git a/src/quantem/core/datastructures/dataset.py b/src/quantem/core/datastructures/dataset.py index 4d2ab9e1..fa1fcef2 100644 --- a/src/quantem/core/datastructures/dataset.py +++ b/src/quantem/core/datastructures/dataset.py @@ -4,6 +4,7 @@ from typing import Any, Literal, Optional, Self, Union, overload import numpy as np +import torch from numpy.typing import DTypeLike, NDArray from quantem.core.io.serialize import AutoSerialize @@ -38,24 +39,37 @@ class Dataset(AutoSerialize): def __init__( self, - array: Any, # Input can be array-like - name: str, - origin: NDArray | tuple | list | float | int, - sampling: NDArray | tuple | list | float | int, - units: list[str] | tuple | list, + array: NDArray | None = None, + tensor: torch.Tensor | None = None, + name: str = "", + origin: NDArray | tuple | list | float | int | None = None, + sampling: NDArray | tuple | list | float | int | None = None, + units: list[str] | tuple | list | None = None, signal_units: str = "arb. units", metadata: Optional[dict] = None, _token: object | None = None, ): if _token is not self._token: - raise RuntimeError("Use Dataset.from_array() to instantiate this class.") - super().__init__() - arr = ensure_valid_array(array) - if not isinstance(arr, np.ndarray): - raise TypeError( - "Dataset requires a NumPy array (CuPy is not supported on this branch)." + raise RuntimeError( + "Use Dataset.from_array() or Dataset.from_tensor() to instantiate this class." ) - self._array = arr + super().__init__() + # Dual-slot storage: exactly one of (_array, _tensor) is set. + if array is None and tensor is None: + raise ValueError("Provide either `array` (numpy) or `tensor` (torch).") + if array is not None and tensor is not None: + raise ValueError("Provide only one of `array` or `tensor`, not both.") + if array is not None: + arr = ensure_valid_array(array) + if not isinstance(arr, np.ndarray): + raise TypeError(f"Dataset.array must be numpy.ndarray, got {type(arr).__name__}.") + self._array = arr + self._tensor = None + else: + if not isinstance(tensor, torch.Tensor): + raise TypeError(f"Dataset.tensor must be torch.Tensor, got {type(tensor).__name__}.") + self._array = None + self._tensor = tensor self.name = name self.origin = origin self.sampling = sampling @@ -122,19 +136,31 @@ def from_array( # --- Properties --- @property - def array(self) -> NDArray: - """The underlying n-dimensional NumPy array data.""" - return self._array + def array(self) -> NDArray | None: + """The underlying n-dimensional NumPy array data. + + Returns ``None`` for tensor-backed datasets. Use ``.tensor`` for the + torch tensor, or ``.numpy()`` to materialize a numpy copy explicitly. + """ + return getattr(self, "_array", None) @array.setter def array(self, value: NDArray) -> None: - arr = ensure_valid_array(value, ndim=self.ndim) # want to allow changing dtype + arr = ensure_valid_array(value, ndim=self.ndim) if not isinstance(arr, np.ndarray): - raise TypeError( - "Dataset requires a NumPy array (CuPy is not supported on this branch)." - ) + raise TypeError(f"Dataset.array must be numpy.ndarray, got {type(arr).__name__}.") self._array = arr - # self._array = ensure_valid_array(value, dtype=self.dtype, ndim=self.ndim) + + @property + def tensor(self) -> torch.Tensor: + """Torch tensor backing the data. AttributeError if numpy-backed.""" + # getattr handles AutoSerialize-restored instances (no __init__ run). + tensor = getattr(self, "_tensor", None) + if tensor is None: + raise AttributeError( + f"Dataset '{self.name}' is numpy-backed; use Dataset.from_tensor() at construction." + ) + return tensor @property def metadata(self) -> dict: @@ -191,26 +217,50 @@ def file_path(self, value: os.PathLike | str | None) -> None: # --- Derived Properties --- @property def shape(self) -> tuple[int, ...]: - return self.array.shape + # Direct slot access (never triggers .array derive, which would force + # a full GPU->CPU copy on tensor-backed datasets). getattr handles + # AutoSerialize-restored instances (no __init__ run). + array = getattr(self, "_array", None) + return tuple((array if array is not None else self._tensor).shape) @property def ndim(self) -> int: - return self.array.ndim + array = getattr(self, "_array", None) + return (array if array is not None else self._tensor).ndim @property def dtype(self) -> DTypeLike: - return self.array.dtype + array = getattr(self, "_array", None) + return (array if array is not None else self._tensor).dtype @property def device(self) -> str: - """ - Outputting a string is likely temporary -- once we have our use cases we can - figure out a more permanent device solution that enables easier translation between - numpy <-> torch <-> numpy, etc. + """``"cpu"`` for numpy-backed; torch device string for tensor-backed.""" + tensor = getattr(self, "_tensor", None) + if tensor is not None: + return str(tensor.device) + return "cpu" - For NumPy-only datasets, this is always "cpu". + def numpy(self) -> NDArray: + """Return the data as a numpy array (mirrors ``torch.Tensor.numpy()``). + + For numpy-backed datasets, returns ``self.array`` directly. For + tensor-backed datasets, materializes a CPU copy via ``.detach().cpu().numpy()``. """ - return "cpu" + array = getattr(self, "_array", None) + if array is not None: + return array + return self._tensor.detach().cpu().numpy() + + def to(self, device) -> Self: + """Move the underlying tensor to ``device``. Raises if numpy-backed.""" + tensor = getattr(self, "_tensor", None) + if tensor is None: + raise AttributeError( + f"Cannot .to({device!r}) on numpy-backed Dataset '{self.name}'." + ) + self._tensor = tensor.to(device) + return self # --- Summaries --- def __repr__(self) -> str: diff --git a/src/quantem/core/datastructures/dataset4d.py b/src/quantem/core/datastructures/dataset4d.py index 8e5bdfa0..f5681730 100644 --- a/src/quantem/core/datastructures/dataset4d.py +++ b/src/quantem/core/datastructures/dataset4d.py @@ -1,6 +1,7 @@ from typing import Any, Self, Union import numpy as np +import torch from numpy.typing import NDArray from quantem.core.datastructures.dataset import Dataset @@ -21,11 +22,12 @@ class Dataset4d(Dataset): def __init__( self, - array: NDArray | Any, - name: str, - origin: NDArray | tuple | list | float | int, - sampling: NDArray | tuple | list | float | int, - units: list[str] | tuple | list, + array: NDArray | None = None, + tensor: torch.Tensor | None = None, + name: str = "", + origin: NDArray | tuple | list | float | int | None = None, + sampling: NDArray | tuple | list | float | int | None = None, + units: list[str] | tuple | list | None = None, signal_units: str = "arb. units", metadata: dict = {}, _token: object | None = None, @@ -34,8 +36,10 @@ def __init__( Parameters ---------- - array : NDArray | Any - The underlying 3D array data + array : NDArray | None + The underlying 4D numpy array. Provide exactly one of ``array`` or ``tensor``. + tensor : torch.Tensor | None + The underlying 4D torch tensor (any device). Provide exactly one of ``array`` or ``tensor``. name : str A descriptive name for the dataset origin : NDArray | tuple | list | float | int @@ -51,6 +55,7 @@ def __init__( """ super().__init__( array=array, + tensor=tensor, name=name, origin=origin, sampling=sampling, diff --git a/src/quantem/core/datastructures/dataset4dstem.py b/src/quantem/core/datastructures/dataset4dstem.py index 79cbc479..48d3f737 100644 --- a/src/quantem/core/datastructures/dataset4dstem.py +++ b/src/quantem/core/datastructures/dataset4dstem.py @@ -2,6 +2,7 @@ import matplotlib.pyplot as plt import numpy as np +import torch from matplotlib.patches import Circle, Wedge from numpy.typing import NDArray @@ -41,11 +42,12 @@ class Dataset4dstem(Dataset4d): def __init__( self, - array: NDArray | Any, - name: str, - origin: NDArray | tuple | list | float | int, - sampling: NDArray | tuple | list | float | int, - units: list[str] | tuple | list, + array: NDArray | None = None, + tensor: torch.Tensor | None = None, + name: str = "", + origin: NDArray | tuple | list | float | int | None = None, + sampling: NDArray | tuple | list | float | int | None = None, + units: list[str] | tuple | list | None = None, signal_units: str = "arb. units", metadata: dict = {}, _token: object | None = None, @@ -56,6 +58,9 @@ def __init__( ---------- array : NDArray | Any The underlying 4D array data + tensor : torch.Tensor | None, optional + Alternative to ``array``: the underlying 4D torch tensor (any device). + Provide exactly one of ``array`` or ``tensor``. name : str A descriptive name for the dataset origin : NDArray | tuple | list | float | int @@ -79,6 +84,7 @@ def __init__( super().__init__( array=array, + tensor=tensor, name=name, origin=origin, sampling=sampling, @@ -157,6 +163,45 @@ def from_array( _token=cls._token, ) + @classmethod + def from_tensor( + cls, + tensor: torch.Tensor, + name: str | None = None, + origin: NDArray | tuple | list | float | int | None = None, + sampling: NDArray | tuple | list | float | int | None = None, + units: list[str] | tuple | list | None = None, + signal_units: str = "arb. units", + metadata: dict | None = None, + ) -> Self: + """Create a Dataset4dstem from a torch tensor (any device). + + Use this when raw data is GPU-resident (CUDA pipelines, live detector + frames, GPU file readers) to skip the VRAM<->RAM round-trip. + + For cupy / jax arrays, wrap with ``torch.from_dlpack(arr)`` first. + """ + if not isinstance(tensor, torch.Tensor): + raise TypeError( + f"from_tensor requires torch.Tensor, got {type(tensor).__name__}. " + f"For cupy / jax, wrap with `torch.from_dlpack(arr)` first." + ) + if tensor.ndim != 4: + raise ValueError( + f"Dataset4dstem.from_tensor requires a 4D tensor " + f"(scan_row, scan_col, dp_row, dp_col), got shape {tuple(tensor.shape)}." + ) + return cls( + tensor=tensor, + name=name if name is not None else "4D-STEM dataset (torch)", + origin=origin if origin is not None else np.zeros(4), + sampling=sampling if sampling is not None else np.ones(4), + units=units if units is not None else ["pixels"] * 4, + signal_units=signal_units, + metadata=metadata if metadata is not None else {}, + _token=cls._token, + ) + @property def virtual_images(self) -> dict[str, Dataset2d]: """ diff --git a/widget/src/quantem/widget/show4dstem.py b/widget/src/quantem/widget/show4dstem.py index 0a6b256b..c3ee61a7 100644 --- a/widget/src/quantem/widget/show4dstem.py +++ b/widget/src/quantem/widget/show4dstem.py @@ -333,15 +333,18 @@ def __init__( _io_labels = None - # Auto-extract sampling + units from Dataset4dstem if available. - if hasattr(data, "sampling") and hasattr(data, "array"): - if not title and hasattr(data, "name") and data.name: + # Extract underlying array / tensor + auto-calibrate from Dataset input + # (duck-typed via the dual-slot private attributes _tensor / _array). + tensor = getattr(data, "_tensor", None) + array = getattr(data, "_array", None) + if tensor is not None or array is not None: + if not title and getattr(data, "name", ""): title = str(data.name) if sampling is None: sampling = tuple(float(s) for s in data.sampling) - if units is None and hasattr(data, "units"): + if units is None: units = list(data.units) - data = data.array + data = tensor if tensor is not None else array # Resolve sampling + units (4 axes for 4D-STEM): # [scan_row, scan_col, k_row, k_col]. Scalar/None broadcast to (1, 1, 1, 1).