-
Notifications
You must be signed in to change notification settings - Fork 26
dataset4dstem support torch array while maintaining backward compatibility
#228
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: dev
Are you sure you want to change the base?
Changes from all commits
98ddd0d
d2dc32b
6b7319d
7f9913f
9c93ae9
40878d3
21b684f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -4,6 +4,7 @@ | |
| from typing import Any, Literal, Optional, Self, Union, overload | ||
|
|
||
| import numpy as np | ||
| import torch | ||
| from numpy.typing import DTypeLike, NDArray | ||
|
|
||
| from quantem.core.io.serialize import AutoSerialize | ||
|
|
@@ -38,24 +39,37 @@ class Dataset(AutoSerialize): | |
|
|
||
| def __init__( | ||
| self, | ||
| array: Any, # Input can be array-like | ||
| name: str, | ||
| origin: NDArray | tuple | list | float | int, | ||
| sampling: NDArray | tuple | list | float | int, | ||
| units: list[str] | tuple | list, | ||
| array: NDArray | None = None, | ||
| tensor: torch.Tensor | None = None, | ||
| name: str = "", | ||
| origin: NDArray | tuple | list | float | int | None = None, | ||
| sampling: NDArray | tuple | list | float | int | None = None, | ||
| units: list[str] | tuple | list | None = None, | ||
| signal_units: str = "arb. units", | ||
| metadata: Optional[dict] = None, | ||
| _token: object | None = None, | ||
| ): | ||
| if _token is not self._token: | ||
| raise RuntimeError("Use Dataset.from_array() to instantiate this class.") | ||
| super().__init__() | ||
| arr = ensure_valid_array(array) | ||
| if not isinstance(arr, np.ndarray): | ||
| raise TypeError( | ||
| "Dataset requires a NumPy array (CuPy is not supported on this branch)." | ||
| raise RuntimeError( | ||
| "Use Dataset.from_array() or Dataset.from_tensor() to instantiate this class." | ||
| ) | ||
| self._array = arr | ||
| super().__init__() | ||
| # Dual-slot storage: exactly one of (_array, _tensor) is set. | ||
| if array is None and tensor is None: | ||
| raise ValueError("Provide either `array` (numpy) or `tensor` (torch).") | ||
| if array is not None and tensor is not None: | ||
| raise ValueError("Provide only one of `array` or `tensor`, not both.") | ||
| if array is not None: | ||
| arr = ensure_valid_array(array) | ||
| if not isinstance(arr, np.ndarray): | ||
| raise TypeError(f"Dataset.array must be numpy.ndarray, got {type(arr).__name__}.") | ||
| self._array = arr | ||
| self._tensor = None | ||
| else: | ||
| if not isinstance(tensor, torch.Tensor): | ||
| raise TypeError(f"Dataset.tensor must be torch.Tensor, got {type(tensor).__name__}.") | ||
| self._array = None | ||
| self._tensor = tensor | ||
| self.name = name | ||
| self.origin = origin | ||
| self.sampling = sampling | ||
|
|
@@ -122,19 +136,31 @@ def from_array( | |
|
|
||
| # --- Properties --- | ||
| @property | ||
| def array(self) -> NDArray: | ||
| """The underlying n-dimensional NumPy array data.""" | ||
| return self._array | ||
| def array(self) -> NDArray | None: | ||
| """The underlying n-dimensional NumPy array data. | ||
|
|
||
| Returns ``None`` for tensor-backed datasets. Use ``.tensor`` for the | ||
| torch tensor, or ``.numpy()`` to materialize a numpy copy explicitly. | ||
| """ | ||
| return getattr(self, "_array", None) | ||
|
|
||
| @array.setter | ||
| def array(self, value: NDArray) -> None: | ||
| arr = ensure_valid_array(value, ndim=self.ndim) # want to allow changing dtype | ||
| arr = ensure_valid_array(value, ndim=self.ndim) | ||
| if not isinstance(arr, np.ndarray): | ||
| raise TypeError( | ||
| "Dataset requires a NumPy array (CuPy is not supported on this branch)." | ||
| ) | ||
| raise TypeError(f"Dataset.array must be numpy.ndarray, got {type(arr).__name__}.") | ||
| self._array = arr | ||
| # self._array = ensure_valid_array(value, dtype=self.dtype, ndim=self.ndim) | ||
|
|
||
| @property | ||
|
arthurmccray marked this conversation as resolved.
|
||
| def tensor(self) -> torch.Tensor: | ||
| """Torch tensor backing the data. AttributeError if numpy-backed.""" | ||
| # getattr handles AutoSerialize-restored instances (no __init__ run). | ||
| tensor = getattr(self, "_tensor", None) | ||
| if tensor is None: | ||
| raise AttributeError( | ||
| f"Dataset '{self.name}' is numpy-backed; use Dataset.from_tensor() at construction." | ||
| ) | ||
| return tensor | ||
|
|
||
| @property | ||
| def metadata(self) -> dict: | ||
|
|
@@ -191,26 +217,50 @@ def file_path(self, value: os.PathLike | str | None) -> None: | |
| # --- Derived Properties --- | ||
| @property | ||
| def shape(self) -> tuple[int, ...]: | ||
| return self.array.shape | ||
| # Direct slot access (never triggers .array derive, which would force | ||
| # a full GPU->CPU copy on tensor-backed datasets). getattr handles | ||
| # AutoSerialize-restored instances (no __init__ run). | ||
| array = getattr(self, "_array", None) | ||
| return tuple((array if array is not None else self._tensor).shape) | ||
|
|
||
| @property | ||
| def ndim(self) -> int: | ||
| return self.array.ndim | ||
| array = getattr(self, "_array", None) | ||
| return (array if array is not None else self._tensor).ndim | ||
|
|
||
| @property | ||
| def dtype(self) -> DTypeLike: | ||
|
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Are |
||
| return self.array.dtype | ||
| array = getattr(self, "_array", None) | ||
| return (array if array is not None else self._tensor).dtype | ||
|
|
||
| @property | ||
| def device(self) -> str: | ||
|
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. you can actually do |
||
| """ | ||
| Outputting a string is likely temporary -- once we have our use cases we can | ||
| figure out a more permanent device solution that enables easier translation between | ||
| numpy <-> torch <-> numpy, etc. | ||
| """``"cpu"`` for numpy-backed; torch device string for tensor-backed.""" | ||
| tensor = getattr(self, "_tensor", None) | ||
| if tensor is not None: | ||
| return str(tensor.device) | ||
| return "cpu" | ||
|
|
||
| For NumPy-only datasets, this is always "cpu". | ||
| def numpy(self) -> NDArray: | ||
|
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @arthurmccray comment on - getting User used to this for explicit array type.
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this looks good to me! Only thing i would add is the |
||
| """Return the data as a numpy array (mirrors ``torch.Tensor.numpy()``). | ||
|
|
||
| For numpy-backed datasets, returns ``self.array`` directly. For | ||
| tensor-backed datasets, materializes a CPU copy via ``.detach().cpu().numpy()``. | ||
| """ | ||
| return "cpu" | ||
| array = getattr(self, "_array", None) | ||
| if array is not None: | ||
| return array | ||
| return self._tensor.detach().cpu().numpy() | ||
|
|
||
| def to(self, device) -> Self: | ||
| """Move the underlying tensor to ``device``. Raises if numpy-backed.""" | ||
| tensor = getattr(self, "_tensor", None) | ||
| if tensor is None: | ||
| raise AttributeError( | ||
| f"Cannot .to({device!r}) on numpy-backed Dataset '{self.name}'." | ||
| ) | ||
| self._tensor = tensor.to(device) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. From the |
||
| return self | ||
|
|
||
| # --- Summaries --- | ||
| def __repr__(self) -> str: | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2,6 +2,7 @@ | |
|
|
||
| import matplotlib.pyplot as plt | ||
| import numpy as np | ||
| import torch | ||
| from matplotlib.patches import Circle, Wedge | ||
| from numpy.typing import NDArray | ||
|
|
||
|
|
@@ -41,11 +42,12 @@ class Dataset4dstem(Dataset4d): | |
|
|
||
| def __init__( | ||
| self, | ||
| array: NDArray | Any, | ||
| name: str, | ||
| origin: NDArray | tuple | list | float | int, | ||
| sampling: NDArray | tuple | list | float | int, | ||
| units: list[str] | tuple | list, | ||
| array: NDArray | None = None, | ||
| tensor: torch.Tensor | None = None, | ||
| name: str = "", | ||
| origin: NDArray | tuple | list | float | int | None = None, | ||
| sampling: NDArray | tuple | list | float | int | None = None, | ||
| units: list[str] | tuple | list | None = None, | ||
| signal_units: str = "arb. units", | ||
| metadata: dict = {}, | ||
| _token: object | None = None, | ||
|
|
@@ -56,6 +58,9 @@ def __init__( | |
| ---------- | ||
| array : NDArray | Any | ||
| The underlying 4D array data | ||
| tensor : torch.Tensor | None, optional | ||
| Alternative to ``array``: the underlying 4D torch tensor (any device). | ||
| Provide exactly one of ``array`` or ``tensor``. | ||
| name : str | ||
| A descriptive name for the dataset | ||
| origin : NDArray | tuple | list | float | int | ||
|
|
@@ -79,6 +84,7 @@ def __init__( | |
|
|
||
| super().__init__( | ||
| array=array, | ||
| tensor=tensor, | ||
| name=name, | ||
| origin=origin, | ||
| sampling=sampling, | ||
|
|
@@ -157,6 +163,45 @@ def from_array( | |
| _token=cls._token, | ||
| ) | ||
|
|
||
| @classmethod | ||
| def from_tensor( | ||
|
arthurmccray marked this conversation as resolved.
|
||
| cls, | ||
| tensor: torch.Tensor, | ||
| name: str | None = None, | ||
| origin: NDArray | tuple | list | float | int | None = None, | ||
| sampling: NDArray | tuple | list | float | int | None = None, | ||
| units: list[str] | tuple | list | None = None, | ||
| signal_units: str = "arb. units", | ||
| metadata: dict | None = None, | ||
| ) -> Self: | ||
| """Create a Dataset4dstem from a torch tensor (any device). | ||
|
|
||
| Use this when raw data is GPU-resident (CUDA pipelines, live detector | ||
| frames, GPU file readers) to skip the VRAM<->RAM round-trip. | ||
|
|
||
| For cupy / jax arrays, wrap with ``torch.from_dlpack(arr)`` first. | ||
| """ | ||
| if not isinstance(tensor, torch.Tensor): | ||
|
arthurmccray marked this conversation as resolved.
|
||
| raise TypeError( | ||
| f"from_tensor requires torch.Tensor, got {type(tensor).__name__}. " | ||
| f"For cupy / jax, wrap with `torch.from_dlpack(arr)` first." | ||
| ) | ||
| if tensor.ndim != 4: | ||
| raise ValueError( | ||
| f"Dataset4dstem.from_tensor requires a 4D tensor " | ||
| f"(scan_row, scan_col, dp_row, dp_col), got shape {tuple(tensor.shape)}." | ||
| ) | ||
|
Comment on lines
+189
to
+193
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is fine for now, but I think we should update the validators (or maybe make a new |
||
| return cls( | ||
| tensor=tensor, | ||
| name=name if name is not None else "4D-STEM dataset (torch)", | ||
| origin=origin if origin is not None else np.zeros(4), | ||
| sampling=sampling if sampling is not None else np.ones(4), | ||
| units=units if units is not None else ["pixels"] * 4, | ||
| signal_units=signal_units, | ||
| metadata=metadata if metadata is not None else {}, | ||
| _token=cls._token, | ||
| ) | ||
|
|
||
| @property | ||
| def virtual_images(self) -> dict[str, Dataset2d]: | ||
| """ | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Some conditional checks for now - user can either have nupmy-backed OR torch-backed. Not both at this stage
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is there a way that an array or tensor is never initialized for a Dataset? Otherwise, I feel like this first conditional is kind of redundant since everything is instantiated with
from_dataorfrom_tensor.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
agreed that some of these protections are probably unnecessary, but it's okay to leave them assuming that they will be removed once the transition is complete (maybe with a comment stating as much)