From 98ddd0d34e9ed51f63003131807f42f6843d7a16 Mon Sep 17 00:00:00 2001 From: Sangjoon Bob Lee Date: Mon, 18 May 2026 22:36:44 -0700 Subject: [PATCH 1/8] dataset4d, dataset4dstem hold torch array --- src/quantem/core/datastructures/dataset.py | 104 +++++++++++++----- src/quantem/core/datastructures/dataset4d.py | 33 ++---- .../core/datastructures/dataset4dstem.py | 52 ++++++++- widget/src/quantem/widget/show4dstem.py | 11 +- 4 files changed, 141 insertions(+), 59 deletions(-) diff --git a/src/quantem/core/datastructures/dataset.py b/src/quantem/core/datastructures/dataset.py index 4d2ab9e1..786b3d5d 100644 --- a/src/quantem/core/datastructures/dataset.py +++ b/src/quantem/core/datastructures/dataset.py @@ -4,6 +4,7 @@ from typing import Any, Literal, Optional, Self, Union, overload import numpy as np +import torch from numpy.typing import DTypeLike, NDArray from quantem.core.io.serialize import AutoSerialize @@ -38,24 +39,39 @@ class Dataset(AutoSerialize): def __init__( self, - array: Any, # Input can be array-like - name: str, - origin: NDArray | tuple | list | float | int, - sampling: NDArray | tuple | list | float | int, - units: list[str] | tuple | list, + array: NDArray | None = None, + tensor: torch.Tensor | None = None, + name: str = "", + origin: NDArray | tuple | list | float | int | None = None, + sampling: NDArray | tuple | list | float | int | None = None, + units: list[str] | tuple | list | None = None, signal_units: str = "arb. units", metadata: Optional[dict] = None, _token: object | None = None, ): if _token is not self._token: - raise RuntimeError("Use Dataset.from_array() to instantiate this class.") - super().__init__() - arr = ensure_valid_array(array) - if not isinstance(arr, np.ndarray): - raise TypeError( - "Dataset requires a NumPy array (CuPy is not supported on this branch)." + raise RuntimeError( + "Use Dataset.from_array() or Dataset.from_tensor() to instantiate this class." 
) - self._array = arr + super().__init__() + # Dual-slot storage: exactly one of (_array, _tensor) is set. + if array is None and tensor is None: + raise ValueError("Provide either `array` (numpy) or `tensor` (torch).") + if array is not None and tensor is not None: + raise ValueError("Provide only one of `array` or `tensor`, not both.") + if array is not None: + arr = ensure_valid_array(array) + if not isinstance(arr, np.ndarray): + raise TypeError(f"Dataset.array must be numpy.ndarray, got {type(arr).__name__}.") + self._array = arr + self._tensor = None + else: + if not isinstance(tensor, torch.Tensor): + raise TypeError(f"Dataset.tensor must be torch.Tensor, got {type(tensor).__name__}.") + self._array = None + self._tensor = tensor + # Lazy cache: derived numpy from tensor, materialized only on first .array access. + self._cached_numpy: np.ndarray | None = None self.name = name self.origin = origin self.sampling = sampling @@ -123,18 +139,34 @@ def from_array( # --- Properties --- @property def array(self) -> NDArray: - """The underlying n-dimensional NumPy array data.""" - return self._array + """The data as a numpy array. + + For tensor-backed datasets, returns a CACHED read-only CPU copy derived + from ``self.tensor`` (first access pays GPU->CPU transfer, subsequent + accesses are free). Torch-aware consumers should prefer ``.tensor``. + """ + if self._array is not None: + return self._array + if self._cached_numpy is None: + self._cached_numpy = self._tensor.detach().cpu().numpy() + self._cached_numpy.flags.writeable = False + return self._cached_numpy @array.setter def array(self, value: NDArray) -> None: - arr = ensure_valid_array(value, ndim=self.ndim) # want to allow changing dtype + arr = ensure_valid_array(value, ndim=self.ndim) if not isinstance(arr, np.ndarray): - raise TypeError( - "Dataset requires a NumPy array (CuPy is not supported on this branch)." 
- ) + raise TypeError(f"Dataset.array must be numpy.ndarray, got {type(arr).__name__}.") self._array = arr - # self._array = ensure_valid_array(value, dtype=self.dtype, ndim=self.ndim) + + @property + def tensor(self) -> torch.Tensor: + """Torch tensor backing the data. AttributeError if numpy-backed.""" + if self._tensor is None: + raise AttributeError( + f"Dataset '{self.name}' is numpy-backed; use Dataset.from_tensor() at construction." + ) + return self._tensor @property def metadata(self) -> dict: @@ -191,26 +223,42 @@ def file_path(self, value: os.PathLike | str | None) -> None: # --- Derived Properties --- @property def shape(self) -> tuple[int, ...]: - return self.array.shape + # Direct slot access — never triggers .array derive (which would force + # a full GPU->CPU copy on tensor-backed datasets). + return tuple((self._array if self._array is not None else self._tensor).shape) @property def ndim(self) -> int: - return self.array.ndim + return (self._array if self._array is not None else self._tensor).ndim @property def dtype(self) -> DTypeLike: - return self.array.dtype + return (self._array if self._array is not None else self._tensor).dtype @property def device(self) -> str: - """ - Outputting a string is likely temporary -- once we have our use cases we can - figure out a more permanent device solution that enables easier translation between - numpy <-> torch <-> numpy, etc. + """``"cpu"`` for numpy-backed; torch device string for tensor-backed.""" + if self._tensor is not None: + return str(self._tensor.device) + return "cpu" + + def numpy(self) -> NDArray: + """Return the data as a numpy array (mirrors ``torch.Tensor.numpy()``). - For NumPy-only datasets, this is always "cpu". + Equivalent to ``self.array`` — both return numpy. For tensor-backed + datasets, first call materializes a cached read-only CPU copy. """ - return "cpu" + return self.array + + def to(self, device) -> Self: + """Move the underlying tensor to ``device``. 
Raises if numpy-backed.""" + if self._tensor is None: + raise AttributeError( + f"Cannot .to({device!r}) on numpy-backed Dataset '{self.name}'." + ) + self._tensor = self._tensor.to(device) + self._cached_numpy = None # invalidate stale derived numpy + return self # --- Summaries --- def __repr__(self) -> str: diff --git a/src/quantem/core/datastructures/dataset4d.py b/src/quantem/core/datastructures/dataset4d.py index 8e5bdfa0..80d219ea 100644 --- a/src/quantem/core/datastructures/dataset4d.py +++ b/src/quantem/core/datastructures/dataset4d.py @@ -1,6 +1,7 @@ from typing import Any, Self, Union import numpy as np +import torch from numpy.typing import NDArray from quantem.core.datastructures.dataset import Dataset @@ -21,36 +22,20 @@ class Dataset4d(Dataset): def __init__( self, - array: NDArray | Any, - name: str, - origin: NDArray | tuple | list | float | int, - sampling: NDArray | tuple | list | float | int, - units: list[str] | tuple | list, + array: NDArray | None = None, + tensor: torch.Tensor | None = None, + name: str = "", + origin: NDArray | tuple | list | float | int | None = None, + sampling: NDArray | tuple | list | float | int | None = None, + units: list[str] | tuple | list | None = None, signal_units: str = "arb. units", metadata: dict = {}, _token: object | None = None, ): - """Initialize a 4D dataset. - - Parameters - ---------- - array : NDArray | Any - The underlying 3D array data - name : str - A descriptive name for the dataset - origin : NDArray | tuple | list | float | int - The origin coordinates for each dimension in calibrated units - sampling : NDArray | tuple | list | float | int - The sampling rate/spacing for each dimension - units : list[str] | tuple | list - Units for each dimension - signal_units : str, optional - Units for the array values, by default "arb. units" - _token : object | None, optional - Token to prevent direct instantiation, by default None - """ + """Initialize a 4D dataset. 
Pass exactly one of ``array`` (numpy) or ``tensor`` (torch).""" super().__init__( array=array, + tensor=tensor, name=name, origin=origin, sampling=sampling, diff --git a/src/quantem/core/datastructures/dataset4dstem.py b/src/quantem/core/datastructures/dataset4dstem.py index 79cbc479..60890475 100644 --- a/src/quantem/core/datastructures/dataset4dstem.py +++ b/src/quantem/core/datastructures/dataset4dstem.py @@ -2,6 +2,7 @@ import matplotlib.pyplot as plt import numpy as np +import torch from matplotlib.patches import Circle, Wedge from numpy.typing import NDArray @@ -41,11 +42,12 @@ class Dataset4dstem(Dataset4d): def __init__( self, - array: NDArray | Any, - name: str, - origin: NDArray | tuple | list | float | int, - sampling: NDArray | tuple | list | float | int, - units: list[str] | tuple | list, + array: NDArray | None = None, + tensor: torch.Tensor | None = None, + name: str = "", + origin: NDArray | tuple | list | float | int | None = None, + sampling: NDArray | tuple | list | float | int | None = None, + units: list[str] | tuple | list | None = None, signal_units: str = "arb. units", metadata: dict = {}, _token: object | None = None, @@ -79,6 +81,7 @@ def __init__( super().__init__( array=array, + tensor=tensor, name=name, origin=origin, sampling=sampling, @@ -157,6 +160,45 @@ def from_array( _token=cls._token, ) + @classmethod + def from_tensor( + cls, + tensor: torch.Tensor, + name: str | None = None, + origin: NDArray | tuple | list | float | int | None = None, + sampling: NDArray | tuple | list | float | int | None = None, + units: list[str] | tuple | list | None = None, + signal_units: str = "arb. units", + metadata: dict | None = None, + ) -> Self: + """Create a Dataset4dstem from a torch tensor (any device). + + Use this when raw data is GPU-resident (CUDA pipelines, live detector + frames, GPU file readers) to skip the VRAM<->RAM round-trip. + + For cupy / jax arrays, wrap with ``torch.from_dlpack(arr)`` first. 
+ """ + if not isinstance(tensor, torch.Tensor): + raise TypeError( + f"from_tensor requires torch.Tensor, got {type(tensor).__name__}. " + f"For cupy / jax, wrap with `torch.from_dlpack(arr)` first." + ) + if tensor.ndim != 4: + raise ValueError( + f"Dataset4dstem.from_tensor requires a 4D tensor " + f"(scan_y, scan_x, dp_y, dp_x), got shape {tuple(tensor.shape)}." + ) + return cls( + tensor=tensor, + name=name if name is not None else "4D-STEM dataset (torch)", + origin=origin if origin is not None else np.zeros(4), + sampling=sampling if sampling is not None else np.ones(4), + units=units if units is not None else ["pixels"] * 4, + signal_units=signal_units, + metadata=metadata if metadata is not None else {}, + _token=cls._token, + ) + @property def virtual_images(self) -> dict[str, Dataset2d]: """ diff --git a/widget/src/quantem/widget/show4dstem.py b/widget/src/quantem/widget/show4dstem.py index 0a6b256b..06f8b629 100644 --- a/widget/src/quantem/widget/show4dstem.py +++ b/widget/src/quantem/widget/show4dstem.py @@ -334,14 +334,21 @@ def __init__( _io_labels = None # Auto-extract sampling + units from Dataset4dstem if available. - if hasattr(data, "sampling") and hasattr(data, "array"): + # NOTE: avoid `hasattr(data, "array")` — for tensor-backed Datasets the + # `.array` getter is an expensive derive (full GPU->CPU copy). Use cheap + # `hasattr(data, "sampling")` to identify a Dataset. + if hasattr(data, "sampling"): if not title and hasattr(data, "name") and data.name: title = str(data.name) if sampling is None: sampling = tuple(float(s) for s in data.sampling) if units is None and hasattr(data, "units"): units = list(data.units) - data = data.array + # If tensor-backed (zero-copy GPU path), take .tensor. Else .array (numpy). + if getattr(data, "_tensor", None) is not None: + data = data.tensor + else: + data = data.array # Resolve sampling + units (4 axes for 4D-STEM): # [scan_row, scan_col, k_row, k_col]. Scalar/None broadcast to (1, 1, 1, 1). 
From d2dc32b390307f409ac8a8ba45b403a0aab58b01 Mon Sep 17 00:00:00 2001 From: Sangjoon Bob Lee Date: Mon, 18 May 2026 22:39:56 -0700 Subject: [PATCH 2/8] bring original docstring back --- src/quantem/core/datastructures/dataset.py | 2 +- src/quantem/core/datastructures/dataset4d.py | 22 ++++++++++++++++++- .../core/datastructures/dataset4dstem.py | 3 +++ 3 files changed, 25 insertions(+), 2 deletions(-) diff --git a/src/quantem/core/datastructures/dataset.py b/src/quantem/core/datastructures/dataset.py index 786b3d5d..5a89b8ad 100644 --- a/src/quantem/core/datastructures/dataset.py +++ b/src/quantem/core/datastructures/dataset.py @@ -139,7 +139,7 @@ def from_array( # --- Properties --- @property def array(self) -> NDArray: - """The data as a numpy array. + """The underlying n-dimensional NumPy array data. For tensor-backed datasets, returns a CACHED read-only CPU copy derived from ``self.tensor`` (first access pays GPU->CPU transfer, subsequent diff --git a/src/quantem/core/datastructures/dataset4d.py b/src/quantem/core/datastructures/dataset4d.py index 80d219ea..f5681730 100644 --- a/src/quantem/core/datastructures/dataset4d.py +++ b/src/quantem/core/datastructures/dataset4d.py @@ -32,7 +32,27 @@ def __init__( metadata: dict = {}, _token: object | None = None, ): - """Initialize a 4D dataset. Pass exactly one of ``array`` (numpy) or ``tensor`` (torch).""" + """Initialize a 4D dataset. + + Parameters + ---------- + array : NDArray | None + The underlying 4D numpy array. Provide exactly one of ``array`` or ``tensor``. + tensor : torch.Tensor | None + The underlying 4D torch tensor (any device). Provide exactly one of ``array`` or ``tensor``. 
+ name : str + A descriptive name for the dataset + origin : NDArray | tuple | list | float | int + The origin coordinates for each dimension in calibrated units + sampling : NDArray | tuple | list | float | int + The sampling rate/spacing for each dimension + units : list[str] | tuple | list + Units for each dimension + signal_units : str, optional + Units for the array values, by default "arb. units" + _token : object | None, optional + Token to prevent direct instantiation, by default None + """ super().__init__( array=array, tensor=tensor, diff --git a/src/quantem/core/datastructures/dataset4dstem.py b/src/quantem/core/datastructures/dataset4dstem.py index 60890475..c0c331b5 100644 --- a/src/quantem/core/datastructures/dataset4dstem.py +++ b/src/quantem/core/datastructures/dataset4dstem.py @@ -58,6 +58,9 @@ def __init__( ---------- array : NDArray | Any The underlying 4D array data + tensor : torch.Tensor | None, optional + Alternative to ``array``: the underlying 4D torch tensor (any device). + Provide exactly one of ``array`` or ``tensor``. name : str A descriptive name for the dataset origin : NDArray | tuple | list | float | int From 6b7319d7996553912b7b8ce485edb16c17257183 Mon Sep 17 00:00:00 2001 From: Sangjoon Bob Lee Date: Mon, 18 May 2026 22:42:43 -0700 Subject: [PATCH 3/8] remove need for cached numpy array --- src/quantem/core/datastructures/dataset.py | 25 ++++++++-------------- 1 file changed, 9 insertions(+), 16 deletions(-) diff --git a/src/quantem/core/datastructures/dataset.py b/src/quantem/core/datastructures/dataset.py index 5a89b8ad..3263720a 100644 --- a/src/quantem/core/datastructures/dataset.py +++ b/src/quantem/core/datastructures/dataset.py @@ -70,8 +70,6 @@ def __init__( raise TypeError(f"Dataset.tensor must be torch.Tensor, got {type(tensor).__name__}.") self._array = None self._tensor = tensor - # Lazy cache: derived numpy from tensor, materialized only on first .array access. 
- self._cached_numpy: np.ndarray | None = None self.name = name self.origin = origin self.sampling = sampling @@ -138,19 +136,13 @@ def from_array( # --- Properties --- @property - def array(self) -> NDArray: + def array(self) -> NDArray | None: """The underlying n-dimensional NumPy array data. - For tensor-backed datasets, returns a CACHED read-only CPU copy derived - from ``self.tensor`` (first access pays GPU->CPU transfer, subsequent - accesses are free). Torch-aware consumers should prefer ``.tensor``. + Returns ``None`` for tensor-backed datasets — use ``.tensor`` for the + torch tensor, or ``.numpy()`` to materialize a numpy copy explicitly. """ - if self._array is not None: - return self._array - if self._cached_numpy is None: - self._cached_numpy = self._tensor.detach().cpu().numpy() - self._cached_numpy.flags.writeable = False - return self._cached_numpy + return self._array @array.setter def array(self, value: NDArray) -> None: @@ -245,10 +237,12 @@ def device(self) -> str: def numpy(self) -> NDArray: """Return the data as a numpy array (mirrors ``torch.Tensor.numpy()``). - Equivalent to ``self.array`` — both return numpy. For tensor-backed - datasets, first call materializes a cached read-only CPU copy. + For numpy-backed datasets, returns ``self.array`` directly. For + tensor-backed datasets, materializes a CPU copy via ``.detach().cpu().numpy()``. """ - return self.array + if self._array is not None: + return self._array + return self._tensor.detach().cpu().numpy() def to(self, device) -> Self: """Move the underlying tensor to ``device``. Raises if numpy-backed.""" @@ -257,7 +251,6 @@ def to(self, device) -> Self: f"Cannot .to({device!r}) on numpy-backed Dataset '{self.name}'." 
) self._tensor = self._tensor.to(device) - self._cached_numpy = None # invalidate stale derived numpy return self # --- Summaries --- From 7f9913f61fcc9f0bafee2f79a5366b1d1358d6f4 Mon Sep 17 00:00:00 2001 From: Sangjoon Bob Lee Date: Mon, 18 May 2026 22:44:38 -0700 Subject: [PATCH 4/8] further cleanup api docstring --- src/quantem/core/datastructures/dataset.py | 4 ++-- widget/src/quantem/widget/show4dstem.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/quantem/core/datastructures/dataset.py b/src/quantem/core/datastructures/dataset.py index 3263720a..a89257a3 100644 --- a/src/quantem/core/datastructures/dataset.py +++ b/src/quantem/core/datastructures/dataset.py @@ -139,7 +139,7 @@ def from_array( def array(self) -> NDArray | None: """The underlying n-dimensional NumPy array data. - Returns ``None`` for tensor-backed datasets — use ``.tensor`` for the + Returns ``None`` for tensor-backed datasets. Use ``.tensor`` for the torch tensor, or ``.numpy()`` to materialize a numpy copy explicitly. """ return self._array @@ -215,7 +215,7 @@ def file_path(self, value: os.PathLike | str | None) -> None: # --- Derived Properties --- @property def shape(self) -> tuple[int, ...]: - # Direct slot access — never triggers .array derive (which would force + # Direct slot access (never triggers .array derive, which would force # a full GPU->CPU copy on tensor-backed datasets). return tuple((self._array if self._array is not None else self._tensor).shape) diff --git a/widget/src/quantem/widget/show4dstem.py b/widget/src/quantem/widget/show4dstem.py index 06f8b629..99a5eae8 100644 --- a/widget/src/quantem/widget/show4dstem.py +++ b/widget/src/quantem/widget/show4dstem.py @@ -334,9 +334,9 @@ def __init__( _io_labels = None # Auto-extract sampling + units from Dataset4dstem if available. - # NOTE: avoid `hasattr(data, "array")` — for tensor-backed Datasets the - # `.array` getter is an expensive derive (full GPU->CPU copy).
Use cheap - # `hasattr(data, "sampling")` to identify a Dataset. + # NOTE: avoid `hasattr(data, "array")` because for tensor-backed Datasets + # the `.array` getter is an expensive derive (full GPU->CPU copy). Use + # cheap `hasattr(data, "sampling")` to identify a Dataset. if hasattr(data, "sampling"): if not title and hasattr(data, "name") and data.name: title = str(data.name) From 9c93ae99d7ad483c773528c20e5c358ee9a49f55 Mon Sep 17 00:00:00 2001 From: Sangjoon Bob Lee Date: Mon, 18 May 2026 22:49:44 -0700 Subject: [PATCH 5/8] use _array _tensor duck typing for show4dstem --- widget/src/quantem/widget/show4dstem.py | 20 ++++++++------------ 1 file changed, 8 insertions(+), 12 deletions(-) diff --git a/widget/src/quantem/widget/show4dstem.py b/widget/src/quantem/widget/show4dstem.py index 99a5eae8..c3ee61a7 100644 --- a/widget/src/quantem/widget/show4dstem.py +++ b/widget/src/quantem/widget/show4dstem.py @@ -333,22 +333,18 @@ def __init__( _io_labels = None - # Auto-extract sampling + units from Dataset4dstem if available. - # NOTE: avoid `hasattr(data, "array")` because for tensor-backed Datasets - # the `.array` getter is an expensive derive (full GPU->CPU copy). Use - # cheap `hasattr(data, "sampling")` to identify a Dataset. - if hasattr(data, "sampling"): - if not title and hasattr(data, "name") and data.name: + # Extract underlying array / tensor + auto-calibrate from Dataset input + # (duck-typed via the dual-slot private attributes _tensor / _array). + tensor = getattr(data, "_tensor", None) + array = getattr(data, "_array", None) + if tensor is not None or array is not None: + if not title and getattr(data, "name", ""): title = str(data.name) if sampling is None: sampling = tuple(float(s) for s in data.sampling) - if units is None and hasattr(data, "units"): + if units is None: units = list(data.units) - # If tensor-backed (zero-copy GPU path), take .tensor. Else .array (numpy). 
- if getattr(data, "_tensor", None) is not None: - data = data.tensor - else: - data = data.array + data = tensor if tensor is not None else array # Resolve sampling + units (4 axes for 4D-STEM): # [scan_row, scan_col, k_row, k_col]. Scalar/None broadcast to (1, 1, 1, 1). From 40878d3d5f13f90ddabe4a4947b1b74550337085 Mon Sep 17 00:00:00 2001 From: Sangjoon Bob Lee Date: Mon, 18 May 2026 23:00:28 -0700 Subject: [PATCH 6/8] use row, col convention in docstring --- src/quantem/core/datastructures/dataset4dstem.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/quantem/core/datastructures/dataset4dstem.py b/src/quantem/core/datastructures/dataset4dstem.py index c0c331b5..48d3f737 100644 --- a/src/quantem/core/datastructures/dataset4dstem.py +++ b/src/quantem/core/datastructures/dataset4dstem.py @@ -189,7 +189,7 @@ def from_tensor( if tensor.ndim != 4: raise ValueError( f"Dataset4dstem.from_tensor requires a 4D tensor " - f"(scan_y, scan_x, dp_y, dp_x), got shape {tuple(tensor.shape)}." + f"(scan_row, scan_col, dp_row, dp_col), got shape {tuple(tensor.shape)}." ) return cls( tensor=tensor, From 21b684fa7e71f843626a27ad1602adb193494cfc Mon Sep 17 00:00:00 2001 From: Sangjoon Bob Lee Date: Fri, 22 May 2026 23:29:30 -0700 Subject: [PATCH 7/8] fix: tolerate missing _tensor slot on autoserialize-loaded datasets --- src/quantem/core/datastructures/dataset.py | 35 ++++++++++++++-------- 1 file changed, 22 insertions(+), 13 deletions(-) diff --git a/src/quantem/core/datastructures/dataset.py b/src/quantem/core/datastructures/dataset.py index a89257a3..fa1fcef2 100644 --- a/src/quantem/core/datastructures/dataset.py +++ b/src/quantem/core/datastructures/dataset.py @@ -142,7 +142,7 @@ def array(self) -> NDArray | None: Returns ``None`` for tensor-backed datasets. Use ``.tensor`` for the torch tensor, or ``.numpy()`` to materialize a numpy copy explicitly. 
""" - return self._array + return getattr(self, "_array", None) @array.setter def array(self, value: NDArray) -> None: @@ -154,11 +154,13 @@ def array(self, value: NDArray) -> None: @property def tensor(self) -> torch.Tensor: """Torch tensor backing the data. AttributeError if numpy-backed.""" - if self._tensor is None: + # getattr handles AutoSerialize-restored instances (no __init__ run). + tensor = getattr(self, "_tensor", None) + if tensor is None: raise AttributeError( f"Dataset '{self.name}' is numpy-backed; use Dataset.from_tensor() at construction." ) - return self._tensor + return tensor @property def metadata(self) -> dict: @@ -216,22 +218,27 @@ def file_path(self, value: os.PathLike | str | None) -> None: @property def shape(self) -> tuple[int, ...]: # Direct slot access (never triggers .array derive, which would force - # a full GPU->CPU copy on tensor-backed datasets). - return tuple((self._array if self._array is not None else self._tensor).shape) + # a full GPU->CPU copy on tensor-backed datasets). getattr handles + # AutoSerialize-restored instances (no __init__ run). 
+ array = getattr(self, "_array", None) + return tuple((array if array is not None else self._tensor).shape) @property def ndim(self) -> int: - return (self._array if self._array is not None else self._tensor).ndim + array = getattr(self, "_array", None) + return (array if array is not None else self._tensor).ndim @property def dtype(self) -> DTypeLike: - return (self._array if self._array is not None else self._tensor).dtype + array = getattr(self, "_array", None) + return (array if array is not None else self._tensor).dtype @property def device(self) -> str: """``"cpu"`` for numpy-backed; torch device string for tensor-backed.""" - if self._tensor is not None: - return str(self._tensor.device) + tensor = getattr(self, "_tensor", None) + if tensor is not None: + return str(tensor.device) return "cpu" def numpy(self) -> NDArray: @@ -240,17 +247,19 @@ def numpy(self) -> NDArray: For numpy-backed datasets, returns ``self.array`` directly. For tensor-backed datasets, materializes a CPU copy via ``.detach().cpu().numpy()``. """ - if self._array is not None: - return self._array + array = getattr(self, "_array", None) + if array is not None: + return array return self._tensor.detach().cpu().numpy() def to(self, device) -> Self: """Move the underlying tensor to ``device``. Raises if numpy-backed.""" - if self._tensor is None: + tensor = getattr(self, "_tensor", None) + if tensor is None: raise AttributeError( f"Cannot .to({device!r}) on numpy-backed Dataset '{self.name}'." 
) - self._tensor = self._tensor.to(device) + self._tensor = tensor.to(device) return self # --- Summaries --- From d97d6a0e11b10328eb3f4d3063da9ea96382fbde Mon Sep 17 00:00:00 2001 From: Sangjoon Bob Lee Date: Wed, 27 May 2026 18:41:11 -0700 Subject: [PATCH 8/8] add TODO for numpy/torch guarding, prevent numpy copy --- src/quantem/core/datastructures/dataset.py | 33 +++++++++++++------ .../core/datastructures/dataset4dstem.py | 2 ++ tests/datastructures/test_dataset.py | 18 ++++++++++ 3 files changed, 43 insertions(+), 10 deletions(-) diff --git a/src/quantem/core/datastructures/dataset.py b/src/quantem/core/datastructures/dataset.py index fa1fcef2..94744978 100644 --- a/src/quantem/core/datastructures/dataset.py +++ b/src/quantem/core/datastructures/dataset.py @@ -55,6 +55,7 @@ def __init__( ) super().__init__() # Dual-slot storage: exactly one of (_array, _tensor) is set. + # TODO: remove dual-init guards once torch transition is complete. if array is None and tensor is None: raise ValueError("Provide either `array` (numpy) or `tensor` (torch).") if array is not None and tensor is not None: @@ -229,37 +230,49 @@ def ndim(self) -> int: return (array if array is not None else self._tensor).ndim @property - def dtype(self) -> DTypeLike: + def dtype(self) -> DTypeLike | torch.dtype: array = getattr(self, "_array", None) return (array if array is not None else self._tensor).dtype @property def device(self) -> str: - """``"cpu"`` for numpy-backed; torch device string for tensor-backed.""" - tensor = getattr(self, "_tensor", None) - if tensor is not None: - return str(tensor.device) - return "cpu" + """Device string for the underlying storage. numpy 2.x ndarray and torch.Tensor + both expose ``.device`` (array-API convention), so this is uniform. + """ + array = getattr(self, "_array", None) + return str((array if array is not None else self._tensor).device) def numpy(self) -> NDArray: """Return the data as a numpy array (mirrors ``torch.Tensor.numpy()``). 
For numpy-backed datasets, returns ``self.array`` directly. For - tensor-backed datasets, materializes a CPU copy via ``.detach().cpu().numpy()``. + tensor-backed datasets, materializes a read-only CPU copy via + ``.detach().cpu().numpy()``. ``flags.writeable=False`` so accidental + in-place writes raise instead of silently being lost (the copy is not + the tensor). """ array = getattr(self, "_array", None) if array is not None: return array - return self._tensor.detach().cpu().numpy() + arr = self._tensor.detach().cpu().numpy() + arr.flags.writeable = False + return arr def to(self, device) -> Self: - """Move the underlying tensor to ``device``. Raises if numpy-backed.""" + """Move the underlying tensor to ``device``. Raises if numpy-backed. + + ``device`` is normalized via :func:`quantem.core.config.validate_device` + so values like ``"cuda"``, ``0``, ``"cuda:0"``, ``torch.device("cuda:0")`` + all resolve to the same canonical device. + """ + from quantem.core import config tensor = getattr(self, "_tensor", None) if tensor is None: raise AttributeError( f"Cannot .to({device!r}) on numpy-backed Dataset '{self.name}'." ) - self._tensor = tensor.to(device) + dev, _ = config.validate_device(device) + self._tensor = tensor.to(dev) return self # --- Summaries --- diff --git a/src/quantem/core/datastructures/dataset4dstem.py b/src/quantem/core/datastructures/dataset4dstem.py index 48d3f737..004db427 100644 --- a/src/quantem/core/datastructures/dataset4dstem.py +++ b/src/quantem/core/datastructures/dataset4dstem.py @@ -181,6 +181,8 @@ def from_tensor( For cupy / jax arrays, wrap with ``torch.from_dlpack(arr)`` first. """ + # TODO: factor type + ndim checks into `ensure_valid_tensor(value, ndim=4)` + # in validators.py, matching `ensure_valid_array` pattern. Cuts bloat. if not isinstance(tensor, torch.Tensor): raise TypeError( f"from_tensor requires torch.Tensor, got {type(tensor).__name__}. 
" diff --git a/tests/datastructures/test_dataset.py b/tests/datastructures/test_dataset.py index 9c83262c..201fe92b 100644 --- a/tests/datastructures/test_dataset.py +++ b/tests/datastructures/test_dataset.py @@ -434,3 +434,21 @@ def test_api_errors(self, sample_dataset_2d): # Neither specified with pytest.raises(ValueError): sample_dataset_2d.fourier_resample() + + +class TestDatasetTorch: + """Tests for torch-backed Dataset (from_tensor path).""" + + def test_numpy_copy_is_readonly(self): + """``.numpy()`` on a tensor-backed dataset returns a read-only CPU copy + so writes raise instead of silently updating only the detached copy. + """ + import torch + from quantem.core.datastructures.dataset4dstem import Dataset4dstem + ds = Dataset4dstem.from_tensor(torch.zeros(2, 2, 2, 2)) + arr = ds.numpy() + assert arr.flags.writeable is False + with pytest.raises(ValueError, match="read-only"): + arr[0, 0, 0, 0] = 99.0 + +