Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 79 additions & 29 deletions src/quantem/core/datastructures/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Any, Literal, Optional, Self, Union, overload

import numpy as np
import torch
from numpy.typing import DTypeLike, NDArray

from quantem.core.io.serialize import AutoSerialize
Expand Down Expand Up @@ -38,24 +39,37 @@ class Dataset(AutoSerialize):

def __init__(
self,
array: Any, # Input can be array-like
name: str,
origin: NDArray | tuple | list | float | int,
sampling: NDArray | tuple | list | float | int,
units: list[str] | tuple | list,
array: NDArray | None = None,
tensor: torch.Tensor | None = None,
name: str = "",
origin: NDArray | tuple | list | float | int | None = None,
sampling: NDArray | tuple | list | float | int | None = None,
units: list[str] | tuple | list | None = None,
signal_units: str = "arb. units",
metadata: Optional[dict] = None,
_token: object | None = None,
):
if _token is not self._token:
raise RuntimeError("Use Dataset.from_array() to instantiate this class.")
super().__init__()
arr = ensure_valid_array(array)
if not isinstance(arr, np.ndarray):
raise TypeError(
"Dataset requires a NumPy array (CuPy is not supported on this branch)."
raise RuntimeError(
"Use Dataset.from_array() or Dataset.from_tensor() to instantiate this class."
)
self._array = arr
super().__init__()
# Dual-slot storage: exactly one of (_array, _tensor) is set.
if array is None and tensor is None:
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some conditional checks for now - user can either have nupmy-backed OR torch-backed. Not both at this stage

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a way that an array or tensor is never initialized for a Dataset? Otherwise, I feel like this first conditional is kind of redundant since everything is instantiated with from_data or from_tensor.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

agreed that some of these protections are probably unnecessary, but it's okay to leave them assuming that they will be removed once the transition is complete (maybe with a comment stating as much)

raise ValueError("Provide either `array` (numpy) or `tensor` (torch).")
if array is not None and tensor is not None:
raise ValueError("Provide only one of `array` or `tensor`, not both.")
if array is not None:
arr = ensure_valid_array(array)
if not isinstance(arr, np.ndarray):
raise TypeError(f"Dataset.array must be numpy.ndarray, got {type(arr).__name__}.")
self._array = arr
self._tensor = None
else:
if not isinstance(tensor, torch.Tensor):
raise TypeError(f"Dataset.tensor must be torch.Tensor, got {type(tensor).__name__}.")
self._array = None
self._tensor = tensor
self.name = name
self.origin = origin
self.sampling = sampling
Expand Down Expand Up @@ -122,19 +136,31 @@ def from_array(

# --- Properties ---
@property
def array(self) -> NDArray:
"""The underlying n-dimensional NumPy array data."""
return self._array
def array(self) -> NDArray | None:
"""The underlying n-dimensional NumPy array data.

Returns ``None`` for tensor-backed datasets. Use ``.tensor`` for the
torch tensor, or ``.numpy()`` to materialize a numpy copy explicitly.
"""
return getattr(self, "_array", None)

@array.setter
def array(self, value: NDArray) -> None:
arr = ensure_valid_array(value, ndim=self.ndim) # want to allow changing dtype
arr = ensure_valid_array(value, ndim=self.ndim)
if not isinstance(arr, np.ndarray):
raise TypeError(
"Dataset requires a NumPy array (CuPy is not supported on this branch)."
)
raise TypeError(f"Dataset.array must be numpy.ndarray, got {type(arr).__name__}.")
self._array = arr
# self._array = ensure_valid_array(value, dtype=self.dtype, ndim=self.ndim)

@property
Comment thread
arthurmccray marked this conversation as resolved.
def tensor(self) -> torch.Tensor:
"""Torch tensor backing the data. AttributeError if numpy-backed."""
# getattr handles AutoSerialize-restored instances (no __init__ run).
tensor = getattr(self, "_tensor", None)
if tensor is None:
raise AttributeError(
f"Dataset '{self.name}' is numpy-backed; use Dataset.from_tensor() at construction."
)
return tensor

@property
def metadata(self) -> dict:
Expand Down Expand Up @@ -191,26 +217,50 @@ def file_path(self, value: os.PathLike | str | None) -> None:
# --- Derived Properties ---
@property
def shape(self) -> tuple[int, ...]:
return self.array.shape
# Direct slot access (never triggers .array derive, which would force
# a full GPU->CPU copy on tensor-backed datasets). getattr handles
# AutoSerialize-restored instances (no __init__ run).
array = getattr(self, "_array", None)
return tuple((array if array is not None else self._tensor).shape)

@property
def ndim(self) -> int:
return self.array.ndim
array = getattr(self, "_array", None)
return (array if array is not None else self._tensor).ndim

@property
def dtype(self) -> DTypeLike:
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

dtype - based on the given numpy or torch

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are torch.dtype included in numpys DTypeLike?

return self.array.dtype
array = getattr(self, "_array", None)
return (array if array is not None else self._tensor).dtype

@property
def device(self) -> str:
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

device - cpy for numpy, for torch, depends on the tensor

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you can actually do .device on numpy arrays, np.arange(10).device -> "cpu". it's included to be compatible with other array packages :)

"""
Outputting a string is likely temporary -- once we have our use cases we can
figure out a more permanent device solution that enables easier translation between
numpy <-> torch <-> numpy, etc.
"""``"cpu"`` for numpy-backed; torch device string for tensor-backed."""
tensor = getattr(self, "_tensor", None)
if tensor is not None:
return str(tensor.device)
return "cpu"

For NumPy-only datasets, this is always "cpu".
def numpy(self) -> NDArray:
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@arthurmccray comment on - getting User used to this for explicit array type.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this looks good to me! Only thing i would add is the flags.writable thing that Cedric found to the torch tensor output, making it clear that it cannot be writable. I haven't tested this but it seems like what we want: #222 (comment)

"""Return the data as a numpy array (mirrors ``torch.Tensor.numpy()``).

For numpy-backed datasets, returns ``self.array`` directly. For
tensor-backed datasets, materializes a CPU copy via ``.detach().cpu().numpy()``.
"""
return "cpu"
array = getattr(self, "_array", None)
if array is not None:
return array
return self._tensor.detach().cpu().numpy()

def to(self, device) -> Self:
"""Move the underlying tensor to ``device``. Raises if numpy-backed."""
tensor = getattr(self, "_tensor", None)
if tensor is None:
raise AttributeError(
f"Cannot .to({device!r}) on numpy-backed Dataset '{self.name}'."
)
self._tensor = tensor.to(device)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From the config module we have a method for validating and getting canonical names for devices, which might be useful here.

from quantem.core import config

dev, _id = config.validate_device(device)
self._tensor = tensor.to(dev)

return self

# --- Summaries ---
def __repr__(self) -> str:
Expand Down
19 changes: 12 additions & 7 deletions src/quantem/core/datastructures/dataset4d.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Any, Self, Union

import numpy as np
import torch
from numpy.typing import NDArray

from quantem.core.datastructures.dataset import Dataset
Expand All @@ -21,11 +22,12 @@ class Dataset4d(Dataset):

def __init__(
self,
array: NDArray | Any,
name: str,
origin: NDArray | tuple | list | float | int,
sampling: NDArray | tuple | list | float | int,
units: list[str] | tuple | list,
array: NDArray | None = None,
tensor: torch.Tensor | None = None,
name: str = "",
origin: NDArray | tuple | list | float | int | None = None,
sampling: NDArray | tuple | list | float | int | None = None,
units: list[str] | tuple | list | None = None,
signal_units: str = "arb. units",
metadata: dict = {},
_token: object | None = None,
Expand All @@ -34,8 +36,10 @@ def __init__(

Parameters
----------
array : NDArray | Any
The underlying 3D array data
array : NDArray | None
The underlying 4D numpy array. Provide exactly one of ``array`` or ``tensor``.
tensor : torch.Tensor | None
The underlying 4D torch tensor (any device). Provide exactly one of ``array`` or ``tensor``.
name : str
A descriptive name for the dataset
origin : NDArray | tuple | list | float | int
Expand All @@ -51,6 +55,7 @@ def __init__(
"""
super().__init__(
array=array,
tensor=tensor,
name=name,
origin=origin,
sampling=sampling,
Expand Down
55 changes: 50 additions & 5 deletions src/quantem/core/datastructures/dataset4dstem.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.patches import Circle, Wedge
from numpy.typing import NDArray

Expand Down Expand Up @@ -41,11 +42,12 @@ class Dataset4dstem(Dataset4d):

def __init__(
self,
array: NDArray | Any,
name: str,
origin: NDArray | tuple | list | float | int,
sampling: NDArray | tuple | list | float | int,
units: list[str] | tuple | list,
array: NDArray | None = None,
tensor: torch.Tensor | None = None,
name: str = "",
origin: NDArray | tuple | list | float | int | None = None,
sampling: NDArray | tuple | list | float | int | None = None,
units: list[str] | tuple | list | None = None,
signal_units: str = "arb. units",
metadata: dict = {},
_token: object | None = None,
Expand All @@ -56,6 +58,9 @@ def __init__(
----------
array : NDArray | Any
The underlying 4D array data
tensor : torch.Tensor | None, optional
Alternative to ``array``: the underlying 4D torch tensor (any device).
Provide exactly one of ``array`` or ``tensor``.
name : str
A descriptive name for the dataset
origin : NDArray | tuple | list | float | int
Expand All @@ -79,6 +84,7 @@ def __init__(

super().__init__(
array=array,
tensor=tensor,
name=name,
origin=origin,
sampling=sampling,
Expand Down Expand Up @@ -157,6 +163,45 @@ def from_array(
_token=cls._token,
)

@classmethod
def from_tensor(
Comment thread
arthurmccray marked this conversation as resolved.
cls,
tensor: torch.Tensor,
name: str | None = None,
origin: NDArray | tuple | list | float | int | None = None,
sampling: NDArray | tuple | list | float | int | None = None,
units: list[str] | tuple | list | None = None,
signal_units: str = "arb. units",
metadata: dict | None = None,
) -> Self:
"""Create a Dataset4dstem from a torch tensor (any device).

Use this when raw data is GPU-resident (CUDA pipelines, live detector
frames, GPU file readers) to skip the VRAM<->RAM round-trip.

For cupy / jax arrays, wrap with ``torch.from_dlpack(arr)`` first.
"""
if not isinstance(tensor, torch.Tensor):
Comment thread
arthurmccray marked this conversation as resolved.
raise TypeError(
f"from_tensor requires torch.Tensor, got {type(tensor).__name__}. "
f"For cupy / jax, wrap with `torch.from_dlpack(arr)` first."
)
if tensor.ndim != 4:
raise ValueError(
f"Dataset4dstem.from_tensor requires a 4D tensor "
f"(scan_row, scan_col, dp_row, dp_col), got shape {tuple(tensor.shape)}."
)
Comment on lines +189 to +193
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is fine for now, but I think we should update the validators (or maybe make a new ensure_valid_tensor to match ensure_valid_array). I generally like having validators as it significantly cuts down on bloat.

return cls(
tensor=tensor,
name=name if name is not None else "4D-STEM dataset (torch)",
origin=origin if origin is not None else np.zeros(4),
sampling=sampling if sampling is not None else np.ones(4),
units=units if units is not None else ["pixels"] * 4,
signal_units=signal_units,
metadata=metadata if metadata is not None else {},
_token=cls._token,
)

@property
def virtual_images(self) -> dict[str, Dataset2d]:
"""
Expand Down
13 changes: 8 additions & 5 deletions widget/src/quantem/widget/show4dstem.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,15 +333,18 @@ def __init__(

_io_labels = None

# Auto-extract sampling + units from Dataset4dstem if available.
if hasattr(data, "sampling") and hasattr(data, "array"):
if not title and hasattr(data, "name") and data.name:
# Extract underlying array / tensor + auto-calibrate from Dataset input
Comment thread
arthurmccray marked this conversation as resolved.
# (duck-typed via the dual-slot private attributes _tensor / _array).
tensor = getattr(data, "_tensor", None)
array = getattr(data, "_array", None)
if tensor is not None or array is not None:
if not title and getattr(data, "name", ""):
title = str(data.name)
if sampling is None:
sampling = tuple(float(s) for s in data.sampling)
if units is None and hasattr(data, "units"):
if units is None:
units = list(data.units)
data = data.array
data = tensor if tensor is not None else array

# Resolve sampling + units (4 axes for 4D-STEM):
# [scan_row, scan_col, k_row, k_col]. Scalar/None broadcast to (1, 1, 1, 1).
Expand Down