Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/quantem/core/datastructures/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from quantem.core.datastructures.vector import Vector as Vector

from quantem.core.datastructures.dataset4dstem import Dataset4dstem as Dataset4dstem
from quantem.core.datastructures.dataset5dstem import Dataset5dstem as Dataset5dstem
from quantem.core.datastructures.dataset4d import Dataset4d as Dataset4d
from quantem.core.datastructures.dataset3d import Dataset3d as Dataset3d
from quantem.core.datastructures.dataset2d import Dataset2d as Dataset2d
108 changes: 79 additions & 29 deletions src/quantem/core/datastructures/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Any, Literal, Optional, Self, Union, overload

import numpy as np
import torch
from numpy.typing import DTypeLike, NDArray

from quantem.core.io.serialize import AutoSerialize
Expand Down Expand Up @@ -38,24 +39,37 @@ class Dataset(AutoSerialize):

def __init__(
self,
array: Any, # Input can be array-like
name: str,
origin: NDArray | tuple | list | float | int,
sampling: NDArray | tuple | list | float | int,
units: list[str] | tuple | list,
array: NDArray | None = None,
tensor: torch.Tensor | None = None,
name: str = "",
origin: NDArray | tuple | list | float | int | None = None,
sampling: NDArray | tuple | list | float | int | None = None,
units: list[str] | tuple | list | None = None,
signal_units: str = "arb. units",
metadata: Optional[dict] = None,
_token: object | None = None,
):
if _token is not self._token:
raise RuntimeError("Use Dataset.from_array() to instantiate this class.")
super().__init__()
arr = ensure_valid_array(array)
if not isinstance(arr, np.ndarray):
raise TypeError(
"Dataset requires a NumPy array (CuPy is not supported on this branch)."
raise RuntimeError(
"Use Dataset.from_array() or Dataset.from_tensor() to instantiate this class."
)
self._array = arr
super().__init__()
# Dual-slot storage: exactly one of (_array, _tensor) is set.
if array is None and tensor is None:
raise ValueError("Provide either `array` (numpy) or `tensor` (torch).")
if array is not None and tensor is not None:
raise ValueError("Provide only one of `array` or `tensor`, not both.")
if array is not None:
arr = ensure_valid_array(array)
if not isinstance(arr, np.ndarray):
raise TypeError(f"Dataset.array must be numpy.ndarray, got {type(arr).__name__}.")
self._array = arr
self._tensor = None
else:
if not isinstance(tensor, torch.Tensor):
raise TypeError(f"Dataset.tensor must be torch.Tensor, got {type(tensor).__name__}.")
self._array = None
self._tensor = tensor
self.name = name
self.origin = origin
self.sampling = sampling
Expand Down Expand Up @@ -122,19 +136,31 @@ def from_array(

# --- Properties ---
@property
def array(self) -> NDArray:
"""The underlying n-dimensional NumPy array data."""
return self._array
def array(self) -> NDArray | None:
"""The underlying n-dimensional NumPy array data.

Returns ``None`` for tensor-backed datasets. Use ``.tensor`` for the
torch tensor, or ``.numpy()`` to materialize a numpy copy explicitly.
"""
return getattr(self, "_array", None)

@array.setter
def array(self, value: NDArray) -> None:
arr = ensure_valid_array(value, ndim=self.ndim) # want to allow changing dtype
arr = ensure_valid_array(value, ndim=self.ndim)
if not isinstance(arr, np.ndarray):
raise TypeError(
"Dataset requires a NumPy array (CuPy is not supported on this branch)."
)
raise TypeError(f"Dataset.array must be numpy.ndarray, got {type(arr).__name__}.")
self._array = arr
# self._array = ensure_valid_array(value, dtype=self.dtype, ndim=self.ndim)

@property
def tensor(self) -> torch.Tensor:
    """Return the torch tensor that backs this dataset.

    Raises
    ------
    AttributeError
        If no tensor is stored, i.e. the dataset is numpy-backed.
    """
    # Look the slot up defensively: an instance restored by AutoSerialize
    # does not run __init__, so the `_tensor` attribute may be absent.
    backing = getattr(self, "_tensor", None)
    if backing is not None:
        return backing
    raise AttributeError(
        f"Dataset '{self.name}' is numpy-backed; use Dataset.from_tensor() at construction."
    )

@property
def metadata(self) -> dict:
Expand Down Expand Up @@ -191,26 +217,50 @@ def file_path(self, value: os.PathLike | str | None) -> None:
# --- Derived Properties ---
@property
def shape(self) -> tuple[int, ...]:
return self.array.shape
# Direct slot access (never triggers .array derive, which would force
# a full GPU->CPU copy on tensor-backed datasets). getattr handles
# AutoSerialize-restored instances (no __init__ run).
array = getattr(self, "_array", None)
return tuple((array if array is not None else self._tensor).shape)

@property
def ndim(self) -> int:
return self.array.ndim
array = getattr(self, "_array", None)
return (array if array is not None else self._tensor).ndim

@property
def dtype(self) -> DTypeLike:
return self.array.dtype
array = getattr(self, "_array", None)
return (array if array is not None else self._tensor).dtype

@property
def device(self) -> str:
"""
Outputting a string is likely temporary -- once we have our use cases we can
figure out a more permanent device solution that enables easier translation between
numpy <-> torch <-> numpy, etc.
"""``"cpu"`` for numpy-backed; torch device string for tensor-backed."""
tensor = getattr(self, "_tensor", None)
if tensor is not None:
return str(tensor.device)
return "cpu"

For NumPy-only datasets, this is always "cpu".
def numpy(self) -> NDArray:
"""Return the data as a numpy array (mirrors ``torch.Tensor.numpy()``).

For numpy-backed datasets, returns ``self.array`` directly. For
tensor-backed datasets, materializes a CPU copy via ``.detach().cpu().numpy()``.
"""
return "cpu"
array = getattr(self, "_array", None)
if array is not None:
return array
return self._tensor.detach().cpu().numpy()

def to(self, device) -> Self:
"""Move the underlying tensor to ``device``. Raises if numpy-backed."""
tensor = getattr(self, "_tensor", None)
if tensor is None:
raise AttributeError(
f"Cannot .to({device!r}) on numpy-backed Dataset '{self.name}'."
)
self._tensor = tensor.to(device)
return self

# --- Summaries ---
def __repr__(self) -> str:
Expand Down
19 changes: 12 additions & 7 deletions src/quantem/core/datastructures/dataset4d.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Any, Self, Union

import numpy as np
import torch
from numpy.typing import NDArray

from quantem.core.datastructures.dataset import Dataset
Expand All @@ -21,11 +22,12 @@ class Dataset4d(Dataset):

def __init__(
self,
array: NDArray | Any,
name: str,
origin: NDArray | tuple | list | float | int,
sampling: NDArray | tuple | list | float | int,
units: list[str] | tuple | list,
array: NDArray | None = None,
tensor: torch.Tensor | None = None,
name: str = "",
origin: NDArray | tuple | list | float | int | None = None,
sampling: NDArray | tuple | list | float | int | None = None,
units: list[str] | tuple | list | None = None,
signal_units: str = "arb. units",
metadata: dict = {},
_token: object | None = None,
Expand All @@ -34,8 +36,10 @@ def __init__(

Parameters
----------
array : NDArray | Any
The underlying 3D array data
array : NDArray | None
The underlying 4D numpy array. Provide exactly one of ``array`` or ``tensor``.
tensor : torch.Tensor | None
The underlying 4D torch tensor (any device). Provide exactly one of ``array`` or ``tensor``.
name : str
A descriptive name for the dataset
origin : NDArray | tuple | list | float | int
Expand All @@ -51,6 +55,7 @@ def __init__(
"""
super().__init__(
array=array,
tensor=tensor,
name=name,
origin=origin,
sampling=sampling,
Expand Down
55 changes: 50 additions & 5 deletions src/quantem/core/datastructures/dataset4dstem.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.patches import Circle, Wedge
from numpy.typing import NDArray

Expand Down Expand Up @@ -41,11 +42,12 @@ class Dataset4dstem(Dataset4d):

def __init__(
self,
array: NDArray | Any,
name: str,
origin: NDArray | tuple | list | float | int,
sampling: NDArray | tuple | list | float | int,
units: list[str] | tuple | list,
array: NDArray | None = None,
tensor: torch.Tensor | None = None,
name: str = "",
origin: NDArray | tuple | list | float | int | None = None,
sampling: NDArray | tuple | list | float | int | None = None,
units: list[str] | tuple | list | None = None,
signal_units: str = "arb. units",
metadata: dict = {},
_token: object | None = None,
Expand All @@ -56,6 +58,9 @@ def __init__(
----------
array : NDArray | Any
The underlying 4D array data
tensor : torch.Tensor | None, optional
Alternative to ``array``: the underlying 4D torch tensor (any device).
Provide exactly one of ``array`` or ``tensor``.
name : str
A descriptive name for the dataset
origin : NDArray | tuple | list | float | int
Expand All @@ -79,6 +84,7 @@ def __init__(

super().__init__(
array=array,
tensor=tensor,
name=name,
origin=origin,
sampling=sampling,
Expand Down Expand Up @@ -157,6 +163,45 @@ def from_array(
_token=cls._token,
)

@classmethod
def from_tensor(
cls,
tensor: torch.Tensor,
name: str | None = None,
origin: NDArray | tuple | list | float | int | None = None,
sampling: NDArray | tuple | list | float | int | None = None,
units: list[str] | tuple | list | None = None,
signal_units: str = "arb. units",
metadata: dict | None = None,
) -> Self:
"""Create a Dataset4dstem from a torch tensor (any device).

Use this when raw data is GPU-resident (CUDA pipelines, live detector
frames, GPU file readers) to skip the VRAM<->RAM round-trip.

For cupy / jax arrays, wrap with ``torch.from_dlpack(arr)`` first.
"""
if not isinstance(tensor, torch.Tensor):
raise TypeError(
f"from_tensor requires torch.Tensor, got {type(tensor).__name__}. "
f"For cupy / jax, wrap with `torch.from_dlpack(arr)` first."
)
if tensor.ndim != 4:
raise ValueError(
f"Dataset4dstem.from_tensor requires a 4D tensor "
f"(scan_row, scan_col, dp_row, dp_col), got shape {tuple(tensor.shape)}."
)
return cls(
tensor=tensor,
name=name if name is not None else "4D-STEM dataset (torch)",
origin=origin if origin is not None else np.zeros(4),
sampling=sampling if sampling is not None else np.ones(4),
units=units if units is not None else ["pixels"] * 4,
signal_units=signal_units,
metadata=metadata if metadata is not None else {},
_token=cls._token,
)

@property
def virtual_images(self) -> dict[str, Dataset2d]:
"""
Expand Down
Loading