From 40326e956fce950e0ce7ee4b52db470d44c7e1cc Mon Sep 17 00:00:00 2001 From: giovp Date: Wed, 21 Aug 2024 20:39:35 +0200 Subject: [PATCH 01/22] implement selection --- src/spatialdata/dataloader/datasets.py | 47 ++++++++++++++++---------- 1 file changed, 30 insertions(+), 17 deletions(-) diff --git a/src/spatialdata/dataloader/datasets.py b/src/spatialdata/dataloader/datasets.py index c65bc4f5a..7370b2a6a 100644 --- a/src/spatialdata/dataloader/datasets.py +++ b/src/spatialdata/dataloader/datasets.py @@ -128,7 +128,8 @@ def __init__( from spatialdata import bounding_box_query from spatialdata._core.operations.rasterize import rasterize as rasterize_fn - self._validate(sdata, regions_to_images, regions_to_coordinate_systems, return_annotations, table_name) + self.sdata = sdata + self._validate(regions_to_images, regions_to_coordinate_systems, return_annotations, table_name) self._preprocess(tile_scale, tile_dim_in_units, rasterize, table_name) if rasterize_kwargs is not None and len(rasterize_kwargs) > 0 and rasterize is False: @@ -145,21 +146,19 @@ def __init__( **dict(rasterize_kwargs), ) if rasterize - else bounding_box_query # type: ignore[assignment] + else partial(bounding_box_query, return_request_only=True) # type: ignore[assignment] ) self._return = self._get_return(return_annotations, table_name) self.transform = transform def _validate( self, - sdata: SpatialData, regions_to_images: dict[str, str], regions_to_coordinate_systems: dict[str, str], return_annotations: str | list[str] | None, table_name: str | None, ) -> None: """Validate input parameters.""" - self.sdata = sdata if return_annotations is not None and table_name is None: raise ValueError("`table_name` must be provided if `return_annotations` is not `None`.") @@ -174,8 +173,8 @@ def _validate( image_name = regions_to_images[region_name] # get elements - region_elem = sdata[region_name] - image_elem = sdata[image_name] + region_elem = self.sdata[region_name] + image_elem = self.sdata[image_name] # check that the elements are supported if get_model(region_elem) == PointsModel: @@ -200,13 +199,13 @@ def _validate( ) if table_name is not None: - _, region_key, instance_key = get_table_keys(sdata.tables[table_name]) + _, region_key, instance_key = get_table_keys(self.sdata.tables[table_name]) if get_model(region_elem) in [Labels2DModel, Labels3DModel]: indices = get_element_instances(region_elem).tolist() else: indices = region_elem.index.tolist() - table = sdata.tables[table_name] - if not isinstance(sdata.tables[table_name].obs[region_key].dtype, CategoricalDtype): + table = self.sdata.tables[table_name] + if not isinstance(self.sdata.tables[table_name].obs[region_key].dtype, CategoricalDtype): raise TypeError( f"The `regions_element` column `{region_key}` in the table must be a categorical dtype. " f"Please convert it." @@ -229,8 +228,10 @@ def _preprocess( table_name: str | None, ) -> None: """Preprocess the dataset.""" + from spatialdata import bounding_box_query + if table_name is not None: - _, region_key, instance_key = get_table_keys(self.sdata.tables[table_name]) + _, region_key, _ = get_table_keys(self.sdata.tables[table_name]) filtered_table = self.sdata.tables[table_name][ self.sdata.tables[table_name].obs[region_key].isin(self.regions) ] # filtered table for the data loader @@ -250,6 +251,17 @@ def _preprocess( tile_scale=tile_scale, tile_dim_in_units=tile_dim_in_units, ) + tile_coords["selection"] = tile_coords.apply( + lambda row: bounding_box_query( + self.sdata[image_name], + ("x", "y"), + min_coordinate=row[["minx", "miny"]].values, + max_coordinate=row[["maxx", "maxy"]].values, + target_coordinate_system=cs, + return_request_only=True, + ), + axis=1, + ) tile_coords_df.append(tile_coords) inst = circles.index.values @@ -359,13 +371,14 @@ def __getitem__(self, idx: int) -> Any | SpatialData: t_coords = self.tiles_coords.iloc[idx] image = self.sdata[row["image"]] - tile = self._crop_image( - image, - axes=tuple(self.dims), - min_coordinate=t_coords[[f"min{i}" for i in self.dims]].values, - max_coordinate=t_coords[[f"max{i}" for i in self.dims]].values, - target_coordinate_system=row["cs"], - ) + # tile = self._crop_image( + # image, + # axes=tuple(self.dims), + # min_coordinate=t_coords[[f"min{i}" for i in self.dims]].values, + # max_coordinate=t_coords[[f"max{i}" for i in self.dims]].values, + # target_coordinate_system=row["cs"], + # ) + tile = image.sel(t_coords["selection"]) if self.transform is not None: out = self._return(idx, tile) return self.transform(out) From aa339aa719c808afdce662bd8a6e1305e06f3654 Mon Sep 17 00:00:00 2001 From: giovp Date: Wed, 21 Aug 2024 20:47:50 +0200 Subject: [PATCH 02/22] update --- src/spatialdata/dataloader/datasets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spatialdata/dataloader/datasets.py b/src/spatialdata/dataloader/datasets.py index 7370b2a6a..6b7ec7ae9 100644 --- a/src/spatialdata/dataloader/datasets.py +++ b/src/spatialdata/dataloader/datasets.py @@ -252,7 +252,7 @@ def _preprocess( tile_dim_in_units=tile_dim_in_units, ) tile_coords["selection"] = tile_coords.apply( - lambda row: bounding_box_query( + lambda row, cs=cs, image_name=image_name: bounding_box_query( self.sdata[image_name], ("x", "y"), min_coordinate=row[["minx", "miny"]].values, From 92d578fea77735b2279ee624aa9a316132690bca Mon Sep 17 00:00:00 2001 From: giovp Date: Mon, 2 Sep 2024 11:58:25 -0700 Subject: [PATCH 03/22] vectorize adjust_bounding_box_to_real_axes --- src/spatialdata/_core/query/spatial_query.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/spatialdata/_core/query/spatial_query.py b/src/spatialdata/_core/query/spatial_query.py index dea2280a5..5a312149f 100644 --- a/src/spatialdata/_core/query/spatial_query.py +++ b/src/spatialdata/_core/query/spatial_query.py @@ -230,6 +230,7 @@ def _adjust_bounding_box_to_real_axes( The bounding box is defined by the user and its axes may not coincide with the axes of the transformation. """ + axis = min_coordinate.ndim - 1 if set(axes_bb) != set(axes_out_without_c): axes_only_in_bb = set(axes_bb) - set(axes_out_without_c) axes_only_in_output = set(axes_out_without_c) - set(axes_bb) @@ -246,8 +247,8 @@ def _adjust_bounding_box_to_real_axes( for ax in axes_only_in_output: axes_bb = axes_bb + (ax,) M = np.finfo(np.float32).max - 1 - min_coordinate = np.append(min_coordinate, -M) - max_coordinate = np.append(max_coordinate, M) + min_coordinate = np.append(min_coordinate, -M, axis=axis) + max_coordinate = np.append(max_coordinate, M, axis=axis) else: indices = [axes_bb.index(ax) for ax in axes_out_without_c] min_coordinate = min_coordinate[np.array(indices)] From 2bb5c35e34cee512bad60011a0c2b7ae3bcfb5a8 Mon Sep 17 00:00:00 2001 From: giovp Date: Mon, 2 Sep 2024 12:00:58 -0700 Subject: [PATCH 04/22] update --- src/spatialdata/_core/query/spatial_query.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/spatialdata/_core/query/spatial_query.py b/src/spatialdata/_core/query/spatial_query.py index 5a312149f..fbfbe9c89 100644 --- a/src/spatialdata/_core/query/spatial_query.py +++ b/src/spatialdata/_core/query/spatial_query.py @@ -239,8 +239,8 @@ def _adjust_bounding_box_to_real_axes( # 3D bounding box) indices_to_remove_from_bb = [axes_bb.index(ax) for ax in axes_only_in_bb] axes_bb = tuple(ax for ax in axes_bb if ax not in axes_only_in_bb) - min_coordinate = np.delete(min_coordinate, indices_to_remove_from_bb) - max_coordinate = np.delete(max_coordinate, indices_to_remove_from_bb) + min_coordinate = np.delete(min_coordinate, indices_to_remove_from_bb, axis=axis) + max_coordinate = np.delete(max_coordinate, indices_to_remove_from_bb, axis=axis) # if there are axes in the output axes that are not in the bounding box, we need to add them to the bounding box # with a range that includes everything (e.g. querying 3D points with a 2D bounding box) @@ -251,8 +251,8 @@ def _adjust_bounding_box_to_real_axes( max_coordinate = np.append(max_coordinate, M, axis=axis) else: indices = [axes_bb.index(ax) for ax in axes_out_without_c] - min_coordinate = min_coordinate[np.array(indices)] - max_coordinate = max_coordinate[np.array(indices)] + min_coordinate = np.take(min_coordinate, indices, axis=axis) + max_coordinate = np.take(max_coordinate, indices, axis=axis) axes_bb = axes_out_without_c return axes_bb, min_coordinate, max_coordinate From c89dcdf62c2a704df00874e81b5a3ee076aacf4d Mon Sep 17 00:00:00 2001 From: giovp Date: Mon, 2 Sep 2024 13:59:35 -0700 Subject: [PATCH 05/22] replace append with insert --- src/spatialdata/_core/query/spatial_query.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/spatialdata/_core/query/spatial_query.py b/src/spatialdata/_core/query/spatial_query.py index fbfbe9c89..6c12b792c 100644 --- a/src/spatialdata/_core/query/spatial_query.py +++ b/src/spatialdata/_core/query/spatial_query.py @@ -244,11 +244,11 @@ def _adjust_bounding_box_to_real_axes( # if there are axes in the output axes that are not in the bounding box, we need to add them to the bounding box # with a range that includes everything (e.g. querying 3D points with a 2D bounding box) + M = np.finfo(np.float32).max - 1 for ax in axes_only_in_output: axes_bb = axes_bb + (ax,) - M = np.finfo(np.float32).max - 1 - min_coordinate = np.append(min_coordinate, -M, axis=axis) - max_coordinate = np.append(max_coordinate, M, axis=axis) + min_coordinate = np.insert(min_coordinate, min_coordinate.shape[axis], -M, axis=axis) + max_coordinate = np.insert(max_coordinate, max_coordinate.shape[axis], M, axis=axis) else: indices = [axes_bb.index(ax) for ax in axes_out_without_c] min_coordinate = np.take(min_coordinate, indices, axis=axis) From 5bf0b43e1756e8db4f273ae2ef2c7e5917853d49 Mon Sep 17 00:00:00 2001 From: giovp Date: Mon, 2 Sep 2024 14:10:07 -0700 Subject: [PATCH 06/22] add comment --- src/spatialdata/_core/query/spatial_query.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/spatialdata/_core/query/spatial_query.py b/src/spatialdata/_core/query/spatial_query.py index 6c12b792c..00abbba3f 100644 --- a/src/spatialdata/_core/query/spatial_query.py +++ b/src/spatialdata/_core/query/spatial_query.py @@ -230,6 +230,7 @@ def _adjust_bounding_box_to_real_axes( The bounding box is defined by the user and its axes may not coincide with the axes of the transformation. """ + # axis for slicing, if axis > 0, then the min_/max_coordinate multiple bounding boxes along axis 0 axis = min_coordinate.ndim - 1 if set(axes_bb) != set(axes_out_without_c): axes_only_in_bb = set(axes_bb) - set(axes_out_without_c) From a60bf6f3ee1f645fb3bef81d4f4ff4751633b0f0 Mon Sep 17 00:00:00 2001 From: giovp Date: Mon, 2 Sep 2024 15:23:26 -0700 Subject: [PATCH 07/22] vectorize --- src/spatialdata/_core/query/_utils.py | 67 +++++++++++++++++---------- 1 file changed, 43 insertions(+), 24 deletions(-) diff --git a/src/spatialdata/_core/query/_utils.py b/src/spatialdata/_core/query/_utils.py index 3b63470ed..c79e31d73 100644 --- a/src/spatialdata/_core/query/_utils.py +++ b/src/spatialdata/_core/query/_utils.py @@ -2,6 +2,7 @@ from typing import Any +import numpy as np from anndata import AnnData from xarray import DataArray @@ -36,37 +37,55 @@ def get_bounding_box_corners( min_coordinate = _parse_list_into_array(min_coordinate) max_coordinate = _parse_list_into_array(max_coordinate) - if len(min_coordinate) not in (2, 3): + if min_coordinate.ndim == 1: + min_coordinate = min_coordinate[np.newaxis, :] + max_coordinate = max_coordinate[np.newaxis, :] + + if min_coordinate.shape[1] not in (2, 3): raise ValueError("bounding box must be 2D or 3D") - if len(min_coordinate) == 2: + num_boxes = min_coordinate.shape[0] + num_dims = min_coordinate.shape[1] + + if num_dims == 2: # 2D bounding box assert len(axes) == 2 - return DataArray( + corners = np.array( [ - [min_coordinate[0], min_coordinate[1]], - [min_coordinate[0], max_coordinate[1]], - [max_coordinate[0], max_coordinate[1]], - [max_coordinate[0], min_coordinate[1]], - ], - coords={"corner": range(4), "axis": list(axes)}, + [min_coordinate[:, 0], min_coordinate[:, 1]], + [min_coordinate[:, 0], max_coordinate[:, 1]], + [max_coordinate[:, 0], max_coordinate[:, 1]], + [max_coordinate[:, 0], min_coordinate[:, 1]], + ] ) - - # 3D bounding cube - assert len(axes) == 3 - return DataArray( - [ - [min_coordinate[0], min_coordinate[1], min_coordinate[2]], - [min_coordinate[0], min_coordinate[1], max_coordinate[2]], - [min_coordinate[0], max_coordinate[1], max_coordinate[2]], - [min_coordinate[0], max_coordinate[1], min_coordinate[2]], - [max_coordinate[0], min_coordinate[1], min_coordinate[2]], - [max_coordinate[0], min_coordinate[1], max_coordinate[2]], - [max_coordinate[0], max_coordinate[1], max_coordinate[2]], - [max_coordinate[0], max_coordinate[1], min_coordinate[2]], - ], - coords={"corner": range(8), "axis": list(axes)}, + corners = np.transpose(corners, (2, 0, 1)) + else: + # 3D bounding cube + assert len(axes) == 3 + corners = np.array( + [ + [min_coordinate[:, 0], min_coordinate[:, 1], min_coordinate[:, 2]], + [min_coordinate[:, 0], min_coordinate[:, 1], max_coordinate[:, 2]], + [min_coordinate[:, 0], max_coordinate[:, 1], max_coordinate[:, 2]], + [min_coordinate[:, 0], max_coordinate[:, 1], min_coordinate[:, 2]], + [max_coordinate[:, 0], min_coordinate[:, 1], min_coordinate[:, 2]], + [max_coordinate[:, 0], min_coordinate[:, 1], max_coordinate[:, 2]], + [max_coordinate[:, 0], max_coordinate[:, 1], max_coordinate[:, 2]], + [max_coordinate[:, 0], max_coordinate[:, 1], min_coordinate[:, 2]], + ] + ) + corners = np.transpose(corners, (2, 0, 1)) + output = DataArray( + corners, + coords={ + "box": range(num_boxes), + "corner": range(corners.shape[1]), + "axis": list(axes), + }, ) + if num_boxes > 1: + return output + return output.squeeze().drop_vars("box") def _get_filtered_or_unfiltered_tables( From 017967b57e384974f1b3a5e0ff549f8f81039ddf Mon Sep 17 00:00:00 2001 From: giovp Date: Mon, 2 Sep 2024 16:19:32 -0700 Subject: [PATCH 08/22] update to handle multiple boxes --- src/spatialdata/_core/query/spatial_query.py | 67 +++++++++++++++----- 1 file changed, 50 insertions(+), 17 deletions(-) diff --git a/src/spatialdata/_core/query/spatial_query.py b/src/spatialdata/_core/query/spatial_query.py index 00abbba3f..b575d67a8 100644 --- a/src/spatialdata/_core/query/spatial_query.py +++ b/src/spatialdata/_core/query/spatial_query.py @@ -120,10 +120,18 @@ def _get_bounding_box_corners_in_intrinsic_coordinates( intrinsic_bounding_box_corners = bounding_box_corners.data @ rotation_matrix.T + translation + if bounding_box_corners.ndim > 2: # multiple boxes + coords = { + "box": range(len(bounding_box_corners)), + "corner": range(len(bounding_box_corners)), + "axis": list(inverse.output_axes), + } + else: + coords = {"corner": range(len(bounding_box_corners)), "axis": list(inverse.output_axes)} return ( DataArray( intrinsic_bounding_box_corners, - coords={"corner": range(len(bounding_box_corners)), "axis": list(inverse.output_axes)}, + coords=coords, ), input_axes_without_c, ) @@ -534,22 +542,47 @@ def _( # build the request: now that we have the bounding box corners in the intrinsic coordinate system, we can use them # to build the request to query the raster data using the xarray APIs - selection = {} - translation_vector = [] - for axis_name in axes: - # get the min value along the axis - min_value = intrinsic_bounding_box_corners.sel(axis=axis_name).min().item() - - # get max value, slices are open half interval - max_value = intrinsic_bounding_box_corners.sel(axis=axis_name).max().item() - - # add the - selection[axis_name] = slice(min_value, max_value) - - if min_value > 0: - translation_vector.append(np.ceil(min_value).item()) - else: - translation_vector.append(0) + # selection = {} + # translation_vector = [] + # for axis_name in axes: + # # get the min value along the axis + # min_value = intrinsic_bounding_box_corners.sel(axis=axis_name).min().item() + + # # get max value, slices are open half interval + # max_value = intrinsic_bounding_box_corners.sel(axis=axis_name).max().item() + + # # add the + # selection[axis_name] = slice(min_value, max_value) + + # if min_value > 0: + # translation_vector.append(np.ceil(min_value).item()) + # else: + # translation_vector.append(0) + + min_values = intrinsic_bounding_box_corners.min(dim="corner") + max_values = intrinsic_bounding_box_corners.max(dim="corner") + + # Convert to numpy arrays for faster operations + min_values_np = min_values.values + max_values_np = max_values.values + + if min_values.ndim == 2: # Multiple boxes + slices = np.array( + [ + [slice(min_val, max_val) for min_val, max_val in zip(box_min, box_max)] + for box_min, box_max in zip(min_values_np, max_values_np) + ] + ) + translation_vectors = np.ceil(np.maximum(min_values_np, 0)) + selection: list[dict[str, Any]] | dict[str, Any] = [ + {axis: slices[box_idx, axis_idx] for axis_idx, axis in enumerate(axes)} + for box_idx in range(len(min_values_np)) + ] + translation_vector = translation_vectors.tolist() + else: # Single box + slices = np.array([slice(min_val, max_val) for min_val, max_val in zip(min_values_np, max_values_np)]) + translation_vector = np.ceil(np.maximum(min_values_np, 0)).tolist() + selection = {axis: slices[axis_idx] for axis_idx, axis in enumerate(axes)} if return_request_only: return selection From ab774b7912cb3eb7f5cfb93de41df8fe36c90485 Mon Sep 17 00:00:00 2001 From: giovp Date: Mon, 2 Sep 2024 16:35:46 -0700 Subject: [PATCH 09/22] vectorize with numba --- src/spatialdata/_core/query/spatial_query.py | 67 ++++++++++---------- 1 file changed, 33 insertions(+), 34 deletions(-) diff --git a/src/spatialdata/_core/query/spatial_query.py b/src/spatialdata/_core/query/spatial_query.py index b575d67a8..dd0e8ffa7 100644 --- a/src/spatialdata/_core/query/spatial_query.py +++ b/src/spatialdata/_core/query/spatial_query.py @@ -9,6 +9,7 @@ import dask.array as da import dask.dataframe as dd +import numba as nb import numpy as np from dask.dataframe import DataFrame as DaskDataFrame from datatree import DataTree @@ -44,6 +45,24 @@ ) +@nb.njit(parallel=False, nopython=True) +def create_slices_and_translation( + min_values: nb.types.Array[nb.float64, nb.float64], + max_values: nb.types.Array[nb.float64, nb.float64], +) -> tuple[nb.types.Array[nb.float64, nb.float64], nb.types.Array[nb.float64, nb.float64]]: + n_boxes, n_dims = min_values.shape + slices = np.empty((n_boxes, n_dims, 2), dtype=np.float64) # (n_boxes, n_dims, [min, max]) + translation_vectors = np.empty((n_boxes, n_dims), dtype=np.float64) # (n_boxes, n_dims) + + for i in range(n_boxes): + for j in range(n_dims): + slices[i, j, 0] = min_values[i, j] + slices[i, j, 1] = max_values[i, j] + translation_vectors[i, j] = np.ceil(max(min_values[i, j], 0)) + + return slices, translation_vectors + + def _get_bounding_box_corners_in_intrinsic_coordinates( element: SpatialElement, axes: tuple[str, ...], @@ -540,49 +559,30 @@ def _( if TYPE_CHECKING: assert isinstance(intrinsic_bounding_box_corners, DataArray) - # build the request: now that we have the bounding box corners in the intrinsic coordinate system, we can use them - # to build the request to query the raster data using the xarray APIs - # selection = {} - # translation_vector = [] - # for axis_name in axes: - # # get the min value along the axis - # min_value = intrinsic_bounding_box_corners.sel(axis=axis_name).min().item() - - # # get max value, slices are open half interval - # max_value = intrinsic_bounding_box_corners.sel(axis=axis_name).max().item() - - # # add the - # selection[axis_name] = slice(min_value, max_value) - - # if min_value > 0: - # translation_vector.append(np.ceil(min_value).item()) - # else: - # translation_vector.append(0) - min_values = intrinsic_bounding_box_corners.min(dim="corner") max_values = intrinsic_bounding_box_corners.max(dim="corner") - # Convert to numpy arrays for faster operations - min_values_np = min_values.values - max_values_np = max_values.values + min_values_np = min_values.data + max_values_np = max_values.data + + if min_values_np.ndim == 1: + min_values_np = min_values_np[np.newaxis, :] + max_values_np = max_values_np[np.newaxis, :] + + slices, translation_vectors = create_slices_and_translation(min_values_np, max_values_np) if min_values.ndim == 2: # Multiple boxes - slices = np.array( - [ - [slice(min_val, max_val) for min_val, max_val in zip(box_min, box_max)] - for box_min, box_max in zip(min_values_np, max_values_np) - ] - ) - translation_vectors = np.ceil(np.maximum(min_values_np, 0)) selection: list[dict[str, Any]] | dict[str, Any] = [ - {axis: slices[box_idx, axis_idx] for axis_idx, axis in enumerate(axes)} + { + axis: slice(slices[box_idx, axis_idx, 0], slices[box_idx, axis_idx, 1]) + for axis_idx, axis in enumerate(axes) + } for box_idx in range(len(min_values_np)) ] translation_vector = translation_vectors.tolist() else: # Single box - slices = np.array([slice(min_val, max_val) for min_val, max_val in zip(min_values_np, max_values_np)]) - translation_vector = np.ceil(np.maximum(min_values_np, 0)).tolist() - selection = {axis: slices[axis_idx] for axis_idx, axis in enumerate(axes)} + selection = {axis: slice(slices[0, axis_idx, 0], slices[0, axis_idx, 1]) for axis_idx, axis in enumerate(axes)} + translation_vector = translation_vectors[0].tolist() if return_request_only: return selection @@ -858,7 +858,6 @@ def _( images: bool = True, labels: bool = True, ) -> SpatialData: - _check_deprecated_kwargs({"shapes": shapes, "points": points, "images": images, "labels": labels}) new_elements = {} for element_type in ["points", "images", "labels", "shapes"]: From 38dba2528a1f591bc9f52836b2c9cec49664d3df Mon Sep 17 00:00:00 2001 From: giovp Date: Mon, 2 Sep 2024 16:56:59 -0700 Subject: [PATCH 10/22] fix corner len --- src/spatialdata/_core/query/spatial_query.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spatialdata/_core/query/spatial_query.py b/src/spatialdata/_core/query/spatial_query.py index dd0e8ffa7..792248f32 100644 --- a/src/spatialdata/_core/query/spatial_query.py +++ b/src/spatialdata/_core/query/spatial_query.py @@ -142,7 +142,7 @@ def _get_bounding_box_corners_in_intrinsic_coordinates( if bounding_box_corners.ndim > 2: # multiple boxes coords = { "box": range(len(bounding_box_corners)), - "corner": range(len(bounding_box_corners)), + "corner": range(bounding_box_corners.shape[1]), "axis": list(inverse.output_axes), } else: From b27607e91962123c01b1d91df7014fad05010721 Mon Sep 17 00:00:00 2001 From: giovp Date: Mon, 2 Sep 2024 17:02:37 -0700 Subject: [PATCH 11/22] update --- src/spatialdata/_core/query/spatial_query.py | 12 ++++----- src/spatialdata/dataloader/datasets.py | 28 +++++++++++++------- 2 files changed, 24 insertions(+), 16 deletions(-) diff --git a/src/spatialdata/_core/query/spatial_query.py b/src/spatialdata/_core/query/spatial_query.py index 792248f32..0107b49b4 100644 --- a/src/spatialdata/_core/query/spatial_query.py +++ b/src/spatialdata/_core/query/spatial_query.py @@ -546,12 +546,12 @@ def _( max_coordinate = _parse_list_into_array(max_coordinate) # for triggering validation - _ = BoundingBoxRequest( - target_coordinate_system=target_coordinate_system, - axes=axes, - min_coordinate=min_coordinate, - max_coordinate=max_coordinate, - ) + # _ = BoundingBoxRequest( + # target_coordinate_system=target_coordinate_system, + # axes=axes, + # min_coordinate=min_coordinate, + # max_coordinate=max_coordinate, + # ) intrinsic_bounding_box_corners, axes = _get_bounding_box_corners_in_intrinsic_coordinates( image, axes, min_coordinate, max_coordinate, target_coordinate_system diff --git a/src/spatialdata/dataloader/datasets.py b/src/spatialdata/dataloader/datasets.py index 6b7ec7ae9..f6687d6b9 100644 --- a/src/spatialdata/dataloader/datasets.py +++ b/src/spatialdata/dataloader/datasets.py @@ -251,17 +251,25 @@ def _preprocess( tile_scale=tile_scale, tile_dim_in_units=tile_dim_in_units, ) - tile_coords["selection"] = tile_coords.apply( - lambda row, cs=cs, image_name=image_name: bounding_box_query( - self.sdata[image_name], - ("x", "y"), - min_coordinate=row[["minx", "miny"]].values, - max_coordinate=row[["maxx", "maxy"]].values, - target_coordinate_system=cs, - return_request_only=True, - ), - axis=1, + tile_coords["selection"] = bounding_box_query( + self.sdata[image_name], + ("x", "y"), + min_coordinate=tile_coords[["minx", "miny"]].values, + max_coordinate=tile_coords[["maxx", "maxy"]].values, + target_coordinate_system=cs, + return_request_only=True, ) + # tile_coords["selection"] = tile_coords.apply( + # lambda row, cs=cs, image_name=image_name: bounding_box_query( + # self.sdata[image_name], + # ("x", "y"), + # min_coordinate=row[["minx", "miny"]].values, + # max_coordinate=row[["maxx", "maxy"]].values, + # target_coordinate_system=cs, + # return_request_only=True, + # ), + # axis=1, + # ) tile_coords_df.append(tile_coords) inst = circles.index.values From a934e21a0421c925580a4319148e49f45f2e8bfd Mon Sep 17 00:00:00 2001 From: giovp Date: Mon, 2 Sep 2024 17:16:30 -0700 Subject: [PATCH 12/22] fix validation --- src/spatialdata/_core/query/spatial_query.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/spatialdata/_core/query/spatial_query.py b/src/spatialdata/_core/query/spatial_query.py index 792248f32..fb363121d 100644 --- a/src/spatialdata/_core/query/spatial_query.py +++ b/src/spatialdata/_core/query/spatial_query.py @@ -380,7 +380,12 @@ def __post_init__(self) -> None: raise ValueError(f"Non-spatial axes specified: {non_spatial_axes}") # validate the axes - if len(self.axes) != len(self.min_coordinate) or len(self.axes) != len(self.max_coordinate): + if self.min_coordinate.shape != self.max_coordinate.shape: + raise ValueError("The `min_coordinate` and `max_coordinate` must have the same shape.") + + n_axes_coordinate = len(self.min_coordinate) if self.min_coordinate.ndim == 1 else self.min_coordinate.shape[1] + + if len(self.axes) != n_axes_coordinate: raise ValueError("The number of axes must match the number of coordinates.") # validate the coordinates From 77f73f471fbd61eac0c00cf96652d9ac76f0af95 Mon Sep 17 00:00:00 2001 From: giovp Date: Tue, 3 Sep 2024 14:26:32 -0700 Subject: [PATCH 13/22] refactor --- src/spatialdata/_core/query/spatial_query.py | 68 +++++++++++--------- 1 file changed, 38 insertions(+), 30 deletions(-) diff --git a/src/spatialdata/_core/query/spatial_query.py b/src/spatialdata/_core/query/spatial_query.py index fb363121d..260b22896 100644 --- a/src/spatialdata/_core/query/spatial_query.py +++ b/src/spatialdata/_core/query/spatial_query.py @@ -528,6 +528,40 @@ def _( return SpatialData(**new_elements, tables=tables) +def _process_data_tree_query_result(query_result: DataTree) -> DataTree | None: + d = {} + for k, data_tree in query_result.items(): + v = data_tree.values() + assert len(v) == 1 + xdata = v.__iter__().__next__() + if 0 in xdata.shape: + if k == "scale0": + return None + else: + d[k] = xdata + + # Remove scales after finding a missing scale + scales_to_keep = [] + for i, scale_name in enumerate(d.keys()): + if scale_name == f"scale{i}": + scales_to_keep.append(scale_name) + else: + break + + # Case in which scale0 is not present but other scales are + if len(scales_to_keep) == 0: + return None + + d = {k: d[k] for k in scales_to_keep} + result = DataTree.from_dict(d) + + # Rechunk the data to avoid irregular chunks + for scale in result: + result[scale]["image"] = result[scale]["image"].chunk("auto") + + return result + + @bounding_box_query.register(DataArray) @bounding_box_query.register(DataTree) def _( @@ -593,46 +627,20 @@ def _( return selection # query the data - query_result = image.sel(selection) + query_result = image.sel(selection) if isinstance(selection, dict) else [image.sel(sel) for sel in selection] if isinstance(image, DataArray): if 0 in query_result.shape: return None assert isinstance(query_result, DataArray) # rechunk the data to avoid irregular chunks - image = image.chunk("auto") + query_result = query_result.chunk("auto") else: assert isinstance(image, DataTree) assert isinstance(query_result, DataTree) - - d = {} - for k, data_tree in query_result.items(): - v = data_tree.values() - assert len(v) == 1 - xdata = v.__iter__().__next__() - if 0 in xdata.shape: - if k == "scale0": - return None - else: - d[k] = xdata - # the list of scales may not be contiguous when the data has small shape (for instance with yx = 22 and - # rotations we may end up having scale0 and scale2 but not scale1. Practically this may occur in torch tiler if - # the tiles are request to be too small). - # Here we remove scales after we found a scale missing - scales_to_keep = [] - for i, scale_name in enumerate(d.keys()): - if scale_name == f"scale{i}": - scales_to_keep.append(scale_name) - else: - break - # case in which scale0 is not present but other scales are - if len(scales_to_keep) == 0: + query_result = _process_data_tree_query_result(query_result) + if query_result is None: return None - d = {k: d[k] for k in scales_to_keep} - query_result = DataTree.from_dict(d) - # rechunk the data to avoid irregular chunks - for scale in query_result: - query_result[scale]["image"] = query_result[scale]["image"].chunk("auto") query_result = compute_coordinates(query_result) # the bounding box, mapped back to the intrinsic coordinate system is a set of points. The bounding box of these From 3adfea8655d5a80f55d34b092db40dc3cddcfd3b Mon Sep 17 00:00:00 2001 From: giovp Date: Tue, 3 Sep 2024 14:42:53 -0700 Subject: [PATCH 14/22] refactor --- src/spatialdata/_core/query/_utils.py | 99 +++++++++++++++++ src/spatialdata/_core/query/spatial_query.py | 108 +++---------------- 2 files changed, 113 insertions(+), 94 deletions(-) diff --git a/src/spatialdata/_core/query/_utils.py b/src/spatialdata/_core/query/_utils.py index c79e31d73..e45f19b44 100644 --- a/src/spatialdata/_core/query/_utils.py +++ b/src/spatialdata/_core/query/_utils.py @@ -2,14 +2,22 @@ from typing import Any +import numba as nb import numpy as np from anndata import AnnData +from datatree import DataTree from xarray import DataArray from spatialdata._core._elements import Tables from spatialdata._core.spatialdata import SpatialData from spatialdata._types import ArrayLike from spatialdata._utils import Number, _parse_list_into_array +from spatialdata.transformations._utils import compute_coordinates +from spatialdata.transformations.transformations import ( + BaseTransformation, + Sequence, + Translation, +) def get_bounding_box_corners( @@ -88,6 +96,97 @@ def get_bounding_box_corners( return output.squeeze().drop_vars("box") +@nb.njit(parallel=False, nopython=True) +def _create_slices_and_translation( + min_values: nb.types.Array[nb.float64, nb.float64], + max_values: nb.types.Array[nb.float64, nb.float64], +) -> tuple[nb.types.Array[nb.float64, nb.float64], nb.types.Array[nb.float64, nb.float64]]: + n_boxes, n_dims = min_values.shape + slices = np.empty((n_boxes, n_dims, 2), dtype=np.float64) # (n_boxes, n_dims, [min, max]) + translation_vectors = np.empty((n_boxes, n_dims), dtype=np.float64) # (n_boxes, n_dims) + + for i in range(n_boxes): + for j in range(n_dims): + slices[i, j, 0] = min_values[i, j] + slices[i, j, 1] = max_values[i, j] + translation_vectors[i, j] = np.ceil(max(min_values[i, j], 0)) + + return slices, translation_vectors + + +def _process_data_tree_query_result(query_result: DataTree) -> DataTree | None: + d = {} + for k, data_tree in query_result.items(): + v = data_tree.values() + assert len(v) == 1 + xdata = v.__iter__().__next__() + if 0 in xdata.shape: + if k == "scale0": + return None + else: + d[k] = xdata + + # Remove scales after finding a missing scale + scales_to_keep = [] + for i, scale_name in enumerate(d.keys()): + if scale_name == f"scale{i}": + scales_to_keep.append(scale_name) + else: + break + + # Case in which scale0 is not present but other scales are + if len(scales_to_keep) == 0: + return None + + d = {k: d[k] for k in scales_to_keep} + result = DataTree.from_dict(d) + + # Rechunk the data to avoid irregular chunks + for scale in result: + result[scale]["image"] = result[scale]["image"].chunk("auto") + + return result + + +def _process_query_result( + result: DataArray | DataTree, translation_vector: ArrayLike, axes: tuple[str, ...] +) -> DataArray | DataTree | None: + from spatialdata.transformations import get_transformation, set_transformation + + if isinstance(result, DataArray): + if 0 in result.shape: + return None + # rechunk the data to avoid irregular chunks + result = result.chunk("auto") + elif isinstance(result, DataTree): + result = _process_data_tree_query_result(result) + if result is None: + return None + + result = compute_coordinates(result) + + if not np.allclose(np.array(translation_vector), 0): + translation_transform = Translation(translation=translation_vector, axes=axes) + + transformations = get_transformation(result, get_all=True) + assert isinstance(transformations, dict) + + new_transformations = {} + for coordinate_system, initial_transform in transformations.items(): + new_transformation: BaseTransformation = Sequence( + [translation_transform, initial_transform], + ) + new_transformations[coordinate_system] = new_transformation + set_transformation(result, new_transformations, set_all=True) + + # let's make a copy of the transformations so that we don't modify the original object + t = get_transformation(result, get_all=True) + assert isinstance(t, dict) + set_transformation(result, t.copy(), set_all=True) + + return result + + def _get_filtered_or_unfiltered_tables( filter_table: bool, elements: dict[str, Any], sdata: SpatialData ) -> dict[str, AnnData] | Tables: diff --git a/src/spatialdata/_core/query/spatial_query.py b/src/spatialdata/_core/query/spatial_query.py index 260b22896..77c4a934b 100644 --- a/src/spatialdata/_core/query/spatial_query.py +++ b/src/spatialdata/_core/query/spatial_query.py @@ -9,7 +9,6 @@ import dask.array as da import dask.dataframe as dd -import numba as nb import numpy as np from dask.dataframe import DataFrame as DaskDataFrame from datatree import DataTree @@ -34,35 +33,14 @@ points_geopandas_to_dask_dataframe, ) from spatialdata.models._utils import ValidAxis_t, get_spatial_axes -from spatialdata.transformations._utils import compute_coordinates from spatialdata.transformations.operations import set_transformation from spatialdata.transformations.transformations import ( Affine, BaseTransformation, - Sequence, - Translation, _get_affine_for_element, ) -@nb.njit(parallel=False, nopython=True) -def create_slices_and_translation( - min_values: nb.types.Array[nb.float64, nb.float64], - max_values: nb.types.Array[nb.float64, nb.float64], -) -> tuple[nb.types.Array[nb.float64, nb.float64], nb.types.Array[nb.float64, nb.float64]]: - n_boxes, n_dims = min_values.shape - slices = np.empty((n_boxes, n_dims, 2), dtype=np.float64) # (n_boxes, n_dims, [min, max]) - translation_vectors = np.empty((n_boxes, n_dims), dtype=np.float64) # (n_boxes, n_dims) - - for i in range(n_boxes): - for j in range(n_dims): - slices[i, j, 0] = min_values[i, j] - slices[i, j, 1] = max_values[i, j] - translation_vectors[i, j] = np.ceil(max(min_values[i, j], 0)) - - return slices, translation_vectors - - def _get_bounding_box_corners_in_intrinsic_coordinates( element: SpatialElement, axes: tuple[str, ...], @@ -528,40 +506,6 @@ def _( return SpatialData(**new_elements, tables=tables) -def _process_data_tree_query_result(query_result: DataTree) -> DataTree | None: - d = {} - for k, data_tree in query_result.items(): - v = data_tree.values() - assert len(v) == 1 - xdata = v.__iter__().__next__() - if 0 in xdata.shape: - if k == "scale0": - return None - else: - d[k] = xdata - - # Remove scales after finding a missing scale - scales_to_keep = [] - for i, scale_name in enumerate(d.keys()): - if scale_name == f"scale{i}": - scales_to_keep.append(scale_name) - else: - break - - # Case in which scale0 is not present but other scales are - if len(scales_to_keep) == 0: - return None - - d = {k: d[k] for k in scales_to_keep} - result = DataTree.from_dict(d) - - # Rechunk the data to avoid irregular chunks - for scale in result: - result[scale]["image"] = result[scale]["image"].chunk("auto") - - return result - - @bounding_box_query.register(DataArray) @bounding_box_query.register(DataTree) def _( @@ -579,7 +523,7 @@ def _( See https://github.com/scverse/spatialdata/pull/151 for a detailed overview of the logic of this code, and for the cases the comments refer to. """ - from spatialdata.transformations import get_transformation, set_transformation + from spatialdata._core.query._utils import _create_slices_and_translation, _process_query_result min_coordinate = _parse_list_into_array(min_coordinate) max_coordinate = _parse_list_into_array(max_coordinate) @@ -608,7 +552,7 @@ def _( min_values_np = min_values_np[np.newaxis, :] max_values_np = max_values_np[np.newaxis, :] - slices, translation_vectors = create_slices_and_translation(min_values_np, max_values_np) + slices, translation_vectors = _create_slices_and_translation(min_values_np, max_values_np) if min_values.ndim == 2: # Multiple boxes selection: list[dict[str, Any]] | dict[str, Any] = [ @@ -627,43 +571,19 @@ def _( return selection # query the data - query_result = image.sel(selection) if isinstance(selection, dict) else [image.sel(sel) for sel in selection] - if isinstance(image, DataArray): - if 0 in query_result.shape: - return None - assert isinstance(query_result, DataArray) - # rechunk the data to avoid irregular chunks - query_result = query_result.chunk("auto") + query_result: DataArray | DataTree | list[DataArray | DataTree] = ( + image.sel(selection) if isinstance(selection, dict) else [image.sel(sel) for sel in selection] + ) + + if isinstance(query_result, list): + processed_results = [] + for result in query_result: + processed_result = _process_query_result(result, translation_vector, axes) + if processed_result is not None: + processed_results.append(processed_result) + query_result = processed_results if processed_results else None else: - assert isinstance(image, DataTree) - assert isinstance(query_result, DataTree) - query_result = _process_data_tree_query_result(query_result) - if query_result is None: - return None - - query_result = compute_coordinates(query_result) - - # the bounding box, mapped back to the intrinsic coordinate system is a set of points. The bounding box of these - # points is likely starting away from the origin (this is described by translation_vector), so we need to prepend - # this translation to every transformation in the new queries elements (unless the translation_vector is zero, - # in that case the translation is not needed) - if not np.allclose(np.array(translation_vector), 0): - translation_transform = Translation(translation=translation_vector, axes=axes) - - transformations = get_transformation(query_result, get_all=True) - assert isinstance(transformations, dict) - - new_transformations = {} - for coordinate_system, initial_transform in transformations.items(): - new_transformation: BaseTransformation = Sequence( - [translation_transform, initial_transform], - ) - new_transformations[coordinate_system] = new_transformation - set_transformation(query_result, new_transformations, set_all=True) - # let's make a copy of the transformations so that we don't modify the original object - t = get_transformation(query_result, get_all=True) - assert isinstance(t, dict) - set_transformation(query_result, t.copy(), set_all=True) + query_result = _process_query_result(query_result, translation_vector, axes) return query_result From dfdfdbfef7e7ae47a507d9c87ca49e2907c0af42 Mon Sep 17 00:00:00 2001 From: giovp Date: Tue, 3 Sep 2024 15:03:26 -0700 Subject: [PATCH 15/22] add test for query with multiple bounding boxes --- src/spatialdata/_core/query/spatial_query.py | 8 +- tests/core/query/test_spatial_query.py | 82 ++++++++++++++------ 2 files changed, 63 insertions(+), 27 deletions(-) diff --git a/src/spatialdata/_core/query/spatial_query.py b/src/spatialdata/_core/query/spatial_query.py index 77c4a934b..36728adfb 100644 --- a/src/spatialdata/_core/query/spatial_query.py +++ b/src/spatialdata/_core/query/spatial_query.py @@ -562,10 +562,10 @@ def _( } for box_idx in range(len(min_values_np)) ] - translation_vector = translation_vectors.tolist() + translation_vectors = translation_vectors.tolist() else: # Single box selection = {axis: slice(slices[0, axis_idx, 0], slices[0, axis_idx, 1]) for axis_idx, axis in enumerate(axes)} - translation_vector = translation_vectors[0].tolist() + translation_vectors = translation_vectors[0].tolist() if return_request_only: return selection @@ -577,13 +577,13 @@ def _( if isinstance(query_result, list): processed_results = [] - for result in query_result: + for result, translation_vector in zip(query_result, translation_vectors): processed_result = _process_query_result(result, translation_vector, axes) if processed_result is not None: processed_results.append(processed_result) query_result = processed_results if processed_results else None else: - query_result = _process_query_result(query_result, translation_vector, axes) + query_result = _process_query_result(query_result, translation_vectors, axes) return query_result diff --git a/tests/core/query/test_spatial_query.py b/tests/core/query/test_spatial_query.py index 9444d8e9e..5e4482a81 100644 --- a/tests/core/query/test_spatial_query.py +++ b/tests/core/query/test_spatial_query.py @@ -192,8 +192,15 @@ def test_query_points_no_points(): @pytest.mark.parametrize("is_bb_3d", [True, False]) @pytest.mark.parametrize("with_polygon_query", [True, False]) @pytest.mark.parametrize("return_request_only", [True, False]) +@pytest.mark.parametrize("multiple_boxes", [True, False]) def test_query_raster( - n_channels: int, is_labels: bool, is_3d: bool, is_bb_3d: bool, with_polygon_query: bool, return_request_only: bool + n_channels: int, + is_labels: bool, + is_3d: bool, + is_bb_3d: bool, + with_polygon_query: bool, + return_request_only: bool, + multiple_boxes: bool, ): """Apply a bounding box to a raster element.""" if is_labels and n_channels > 1: @@ -232,16 +239,16 @@ def test_query_raster( for image in images: if is_bb_3d: - _min_coordinate = np.array([2, 5, 0]) - _max_coordinate = np.array([7, 10, 5]) + _min_coordinate = np.array([[2, 5, 0], [1, 4, 0]]) if multiple_boxes else np.array([2, 5, 0]) + _max_coordinate = np.array([[7, 10, 5], [6, 9, 4]]) if multiple_boxes else np.array([7, 10, 5]) _axes = ("z", "y", "x") else: - _min_coordinate = np.array([5, 0]) - _max_coordinate = np.array([10, 5]) + _min_coordinate = np.array([[5, 0], [4, 0]]) if multiple_boxes else np.array([5, 0]) + _max_coordinate = np.array([[10, 5], [9, 4]]) if multiple_boxes else np.array([10, 5]) _axes = ("y", "x") if with_polygon_query: - if is_bb_3d: + if is_bb_3d or multiple_boxes: return # make a triangle whose bounding box is the same as the bounding box specified with the query polygon = Polygon([(0, 5), (5, 5), (5, 10)]) @@ -258,29 +265,58 @@ def test_query_raster( return_request_only=return_request_only, ) - slices = {"y": slice(5, 10), "x": slice(0, 5)} - if is_bb_3d and is_3d: - slices["z"] = slice(2, 7) + if multiple_boxes: + slices = [{"y": slice(5, 10), "x": slice(0, 5)}, {"y": slice(4, 9), "x": slice(0, 4)}] + if is_bb_3d and is_3d: + slices[0]["z"] = slice(2, 7) + slices[1]["z"] = slice(1, 6) + else: + slices = {"y": slice(5, 10), "x": slice(0, 5)} + if is_bb_3d and is_3d: + slices["z"] = slice(2, 7) + if return_request_only: - assert isinstance(image_result, dict) - if not (is_bb_3d and is_3d) and ("z" in image_result): - image_result.pop("z") # remove z from slices if `polygon_query` - for k, v in image_result.items(): - assert isinstance(v, slice) - assert image_result[k] == slices[k] + assert isinstance(image_result, (dict, list)) + if multiple_boxes: + for i, result in enumerate(image_result): + if not (is_bb_3d and is_3d) and ("z" in result): + result.pop("z") # remove z from slices if `polygon_query` + for k, v in result.items(): + assert isinstance(v, slice) + assert result[k] == slices[i][k] + else: + if not (is_bb_3d and is_3d) and ("z" in image_result): + image_result.pop("z") # remove z from slices if `polygon_query` + for k, v in image_result.items(): + assert isinstance(v, slice) + assert image_result[k] == slices[k] return - expected_image = ximage.sel(**slices) + if multiple_boxes: + expected_images = [ximage.sel(**s) for s in slices] + else: + expected_image = ximage.sel(**slices) if isinstance(image, DataArray): - assert isinstance(image, DataArray) - np.testing.assert_allclose(image_result, expected_image) + assert isinstance(image_result, (DataArray, list)) + if multiple_boxes: + for result, expected in zip(image_result, expected_images): + np.testing.assert_allclose(result, expected) + else: + np.testing.assert_allclose(image_result, expected_image) elif isinstance(image, DataTree): - assert isinstance(image_result, DataTree) - v = image_result["scale0"].values() - assert len(v) == 1 - xdata = v.__iter__().__next__() - np.testing.assert_allclose(xdata, expected_image) + assert isinstance(image_result, (DataTree, list)) + if multiple_boxes: + for result, expected in zip(image_result, expected_images): + v = result["scale0"].values() + assert len(v) == 1 + xdata = v.__iter__().__next__() + np.testing.assert_allclose(xdata, expected) + else: + v = image_result["scale0"].values() + assert len(v) == 1 + xdata = v.__iter__().__next__() + np.testing.assert_allclose(xdata, expected_image) else: raise ValueError("Unexpected type") From 5c5560d517471ff6f67eb9b01dda2d58f02aeaa4 Mon Sep 17 00:00:00 2001 From: giovp Date: Tue, 3 Sep 2024 15:08:50 -0700 Subject: [PATCH 16/22] fix typing --- src/spatialdata/_core/query/spatial_query.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spatialdata/_core/query/spatial_query.py b/src/spatialdata/_core/query/spatial_query.py index 36728adfb..dc85556a2 100644 --- a/src/spatialdata/_core/query/spatial_query.py +++ b/src/spatialdata/_core/query/spatial_query.py @@ -515,7 +515,7 @@ def _( max_coordinate: list[Number] | ArrayLike, target_coordinate_system: str, return_request_only: bool = False, -) -> DataArray | DataTree | Mapping[str, slice] | None: +) -> DataArray | DataTree | Mapping[str, slice] | list[DataArray | DataTree] | None: """Implement bounding box query for Spatialdata supported DataArray. Notes From dd2c573d61932b06110dec07bc447f1496310c6a Mon Sep 17 00:00:00 2001 From: giovp Date: Tue, 3 Sep 2024 16:09:09 -0700 Subject: [PATCH 17/22] vectorize bounding box query on polygons --- src/spatialdata/_core/query/spatial_query.py | 34 ++++++++++++++------ tests/core/query/test_spatial_query.py | 21 ++++++++---- 2 files changed, 39 insertions(+), 16 deletions(-) diff --git a/src/spatialdata/_core/query/spatial_query.py b/src/spatialdata/_core/query/spatial_query.py index dc85556a2..9cb78b64d 100644 --- a/src/spatialdata/_core/query/spatial_query.py +++ b/src/spatialdata/_core/query/spatial_query.py @@ -673,7 +673,7 @@ def _( min_coordinate: list[Number] | ArrayLike, max_coordinate: list[Number] | ArrayLike, target_coordinate_system: str, -) -> GeoDataFrame | None: +) -> GeoDataFrame | list[GeoDataFrame] | None: from spatialdata.transformations import get_transformation min_coordinate = _parse_list_into_array(min_coordinate) @@ -695,16 +695,32 @@ def _( max_coordinate=max_coordinate, target_coordinate_system=target_coordinate_system, ) - intrinsic_bounding_box_corners = intrinsic_bounding_box_corners.data - bounding_box_non_axes_aligned = Polygon(intrinsic_bounding_box_corners) - indices = polygons.geometry.intersects(bounding_box_non_axes_aligned) - queried = polygons[indices] - if len(queried) == 0: - return None + + # Create a list of Polygons for each bounding box old_transformations = get_transformation(polygons, get_all=True) assert isinstance(old_transformations, dict) - del queried.attrs[ShapesModel.TRANSFORM_KEY] - return ShapesModel.parse(queried, transformations=old_transformations.copy()) + + queried_polygons = [] + intrinsic_bounding_box_corners = ( + intrinsic_bounding_box_corners.expand_dims(dim="box") + if "box" not in intrinsic_bounding_box_corners.dims + else intrinsic_bounding_box_corners + ) + for box_corners in intrinsic_bounding_box_corners: + bounding_box_non_axes_aligned = Polygon(box_corners.data) + indices = polygons.geometry.intersects(bounding_box_non_axes_aligned) + queried = polygons[indices] + if len(queried) == 0: + queried_polygon = None + else: + del queried.attrs[ShapesModel.TRANSFORM_KEY] + queried_polygon = ShapesModel.parse(queried, transformations=old_transformations.copy()) + queried_polygons.append(queried_polygon) + if len(queried_polygons) == 0: + return None + if len(queried_polygons) == 1: + return queried_polygons[0] + return queried_polygons # TODO: we can replace the manually triggered deprecation warning heres with the decorator from Wouter diff --git a/tests/core/query/test_spatial_query.py b/tests/core/query/test_spatial_query.py index 5e4482a81..c80270171 100644 --- a/tests/core/query/test_spatial_query.py +++ b/tests/core/query/test_spatial_query.py @@ -323,7 +323,8 @@ def test_query_raster( @pytest.mark.parametrize("is_bb_3d", [True, False]) @pytest.mark.parametrize("with_polygon_query", [True, False]) -def test_query_polygons(is_bb_3d: bool, with_polygon_query: bool): +@pytest.mark.parametrize("multiple_boxes", [True, False]) +def test_query_polygons(is_bb_3d: bool, with_polygon_query: bool, multiple_boxes: bool): centroids = np.array([[10, 10], [10, 80], [80, 20], [70, 60]]) half_widths = [6] * 4 sd_polygons = _make_squares(centroid_coordinates=centroids, half_widths=half_widths) @@ -339,12 +340,12 @@ def test_query_polygons(is_bb_3d: bool, with_polygon_query: bool): ) else: if is_bb_3d: - _min_coordinate = np.array([2, 40, 40]) - _max_coordinate = np.array([7, 100, 100]) + _min_coordinate = np.array([[2, 40, 40], [2, 50, 50]]) if multiple_boxes else np.array([2, 40, 40]) + _max_coordinate = np.array([[7, 100, 100], [7, 110, 110]]) if multiple_boxes else np.array([7, 100, 100]) _axes = ("z", "y", "x") else: - _min_coordinate = np.array([40, 40]) - _max_coordinate = np.array([100, 100]) + _min_coordinate = np.array([[40, 40], [50, 50]]) if multiple_boxes else np.array([40, 40]) + _max_coordinate = np.array([[100, 100], [110, 110]]) if multiple_boxes else np.array([100, 100]) _axes = ("y", "x") polygons_result = bounding_box_query( @@ -355,8 +356,14 @@ def test_query_polygons(is_bb_3d: bool, with_polygon_query: bool): max_coordinate=_max_coordinate, ) - assert len(polygons_result) == 1 - assert polygons_result.index[0] == 3 + if multiple_boxes and not with_polygon_query: + assert isinstance(polygons_result, list) + assert len(polygons_result) == 2 + assert polygons_result[0].index[0] == 3 + assert len(polygons_result[1]) == 1 + else: + assert len(polygons_result) == 1 + assert polygons_result.index[0] == 3 @pytest.mark.parametrize("is_bb_3d", [True, False]) From be9535834ea0ae1d89597fef0acefb669b6f7b5e Mon Sep 17 00:00:00 2001 From: giovp Date: Tue, 3 Sep 2024 16:31:31 -0700 Subject: [PATCH 18/22] add test to cover no polygon overlap (None) --- tests/core/query/test_spatial_query.py | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/tests/core/query/test_spatial_query.py b/tests/core/query/test_spatial_query.py index c80270171..e58e3424e 100644 --- a/tests/core/query/test_spatial_query.py +++ b/tests/core/query/test_spatial_query.py @@ -324,7 +324,8 @@ def test_query_raster( @pytest.mark.parametrize("is_bb_3d", [True, False]) @pytest.mark.parametrize("with_polygon_query", [True, False]) @pytest.mark.parametrize("multiple_boxes", [True, False]) -def test_query_polygons(is_bb_3d: bool, with_polygon_query: bool, multiple_boxes: bool): +@pytest.mark.parametrize("box_outside_polygon", [True, False]) +def test_query_polygons(is_bb_3d: bool, with_polygon_query: bool, multiple_boxes: bool, box_outside_polygon: bool): centroids = np.array([[10, 10], [10, 80], [80, 20], [70, 60]]) half_widths = [6] * 4 sd_polygons = _make_squares(centroid_coordinates=centroids, half_widths=half_widths) @@ -342,10 +343,18 @@ def test_query_polygons(is_bb_3d: bool, with_polygon_query: bool, multiple_boxes if is_bb_3d: _min_coordinate = np.array([[2, 40, 40], [2, 50, 50]]) if multiple_boxes else np.array([2, 40, 40]) _max_coordinate = np.array([[7, 100, 100], [7, 110, 110]]) if multiple_boxes else np.array([7, 100, 100]) + if box_outside_polygon: + _min_coordinate = np.array([[2, 100, 100], [2, 50, 50]]) if multiple_boxes else np.array([2, 40, 40]) + _max_coordinate = ( + np.array([[7, 110, 110], [7, 110, 110]]) if multiple_boxes else np.array([7, 100, 100]) + ) _axes = ("z", "y", "x") else: _min_coordinate = np.array([[40, 40], [50, 50]]) if multiple_boxes else np.array([40, 40]) _max_coordinate = np.array([[100, 100], [110, 110]]) if multiple_boxes else np.array([100, 100]) + if box_outside_polygon: + _min_coordinate = np.array([[100, 100], [50, 50]]) if multiple_boxes else np.array([40, 40]) + _max_coordinate = np.array([[110, 110], [110, 110]]) if multiple_boxes else np.array([100, 100]) _axes = ("y", "x") polygons_result = bounding_box_query( @@ -359,8 +368,13 @@ def test_query_polygons(is_bb_3d: bool, with_polygon_query: bool, multiple_boxes if multiple_boxes and not with_polygon_query: assert isinstance(polygons_result, list) assert len(polygons_result) == 2 - assert polygons_result[0].index[0] == 3 - assert len(polygons_result[1]) == 1 + if box_outside_polygon: + + assert polygons_result[0] is None + assert polygons_result[1].index[0] == 3 + else: + assert polygons_result[0].index[0] == 3 + assert len(polygons_result[1]) == 1 else: assert len(polygons_result) == 1 assert polygons_result.index[0] == 3 From fad9b1aa2dc72e1c16985efea8ee51f846cfef0a Mon Sep 17 00:00:00 2001 From: giovp Date: Wed, 4 Sep 2024 12:52:35 -0700 Subject: [PATCH 19/22] vectorize bounding box query on points and tests --- src/spatialdata/_core/query/spatial_query.py | 123 +++++++++++++------ tests/core/query/test_spatial_query.py | 72 ++++++++--- 2 files changed, 138 insertions(+), 57 deletions(-) diff --git a/src/spatialdata/_core/query/spatial_query.py b/src/spatialdata/_core/query/spatial_query.py index 9cb78b64d..31e75864c 100644 --- a/src/spatialdata/_core/query/spatial_query.py +++ b/src/spatialdata/_core/query/spatial_query.py @@ -385,7 +385,7 @@ def _bounding_box_mask_points( min_coordinate: list[Number] | ArrayLike, max_coordinate: list[Number] | ArrayLike, ) -> da.Array: - """Compute a mask that is true for the points inside an axis-aligned bounding box. + """Compute a mask that is true for the points inside axis-aligned bounding boxes. Parameters ---------- @@ -394,30 +394,42 @@ def _bounding_box_mask_points( axes The axes that min_coordinate and max_coordinate refer to. min_coordinate - The upper left hand corner of the bounding box (i.e., minimum coordinates along all dimensions). + The upper left hand corners of the bounding boxes (i.e., minimum coordinates along all dimensions). + Shape: (n_boxes, n_axes) or (n_axes,) for a single box. max_coordinate - The lower right hand corner of the bounding box (i.e., the maximum coordinates along all dimensions). + The lower right hand corners of the bounding boxes (i.e., the maximum coordinates along all dimensions). + Shape: (n_boxes, n_axes) or (n_axes,) for a single box. Returns ------- - The mask for the points inside the bounding box. + The masks for the points inside the bounding boxes. """ element_axes = get_axes_names(points) + min_coordinate = _parse_list_into_array(min_coordinate) max_coordinate = _parse_list_into_array(max_coordinate) + + # Ensure min_coordinate and max_coordinate are 2D arrays + min_coordinate = min_coordinate[np.newaxis, :] if min_coordinate.ndim == 1 else min_coordinate + max_coordinate = max_coordinate[np.newaxis, :] if max_coordinate.ndim == 1 else max_coordinate + + n_boxes = min_coordinate.shape[0] in_bounding_box_masks = [] - for axis_index, axis_name in enumerate(axes): - if axis_name not in element_axes: - continue - min_value = min_coordinate[axis_index] - in_bounding_box_masks.append(points[axis_name].gt(min_value).to_dask_array(lengths=True)) - for axis_index, axis_name in enumerate(axes): - if axis_name not in element_axes: - continue - max_value = max_coordinate[axis_index] - in_bounding_box_masks.append(points[axis_name].lt(max_value).to_dask_array(lengths=True)) - in_bounding_box_masks = da.stack(in_bounding_box_masks, axis=-1) - return da.all(in_bounding_box_masks, axis=1) + + for box in range(n_boxes): + box_masks = [] + for axis_index, axis_name in enumerate(axes): + if axis_name not in element_axes: + continue + min_value = min_coordinate[box, axis_index] + max_value = max_coordinate[box, axis_index] + box_masks.append( + points[axis_name].gt(min_value).to_dask_array(lengths=True) + & points[axis_name].lt(max_value).to_dask_array(lengths=True) + ) + bounding_box_mask = da.stack(box_masks, axis=-1) + in_bounding_box_masks.append(da.all(bounding_box_mask, axis=1)) + return in_bounding_box_masks def _dict_query_dispatcher( @@ -601,6 +613,10 @@ def _( min_coordinate = _parse_list_into_array(min_coordinate) max_coordinate = _parse_list_into_array(max_coordinate) + # Ensure min_coordinate and max_coordinate are 2D arrays + min_coordinate = min_coordinate[np.newaxis, :] if min_coordinate.ndim == 1 else min_coordinate + max_coordinate = max_coordinate[np.newaxis, :] if max_coordinate.ndim == 1 else max_coordinate + # for triggering validation _ = BoundingBoxRequest( target_coordinate_system=target_coordinate_system, @@ -617,9 +633,11 @@ def _( max_coordinate=max_coordinate, target_coordinate_system=target_coordinate_system, ) - intrinsic_bounding_box_corners = intrinsic_bounding_box_corners.data - min_coordinate_intrinsic = intrinsic_bounding_box_corners.min(axis=0) - max_coordinate_intrinsic = intrinsic_bounding_box_corners.max(axis=0) + min_coordinate_intrinsic = intrinsic_bounding_box_corners.min(dim="corner") + max_coordinate_intrinsic = intrinsic_bounding_box_corners.max(dim="corner") + + min_coordinate_intrinsic = min_coordinate_intrinsic.data + max_coordinate_intrinsic = max_coordinate_intrinsic.data # get the points in the intrinsic coordinate bounding box in_intrinsic_bounding_box = _bounding_box_mask_points( @@ -628,10 +646,20 @@ def _( min_coordinate=min_coordinate_intrinsic, max_coordinate=max_coordinate_intrinsic, ) - # if there aren't any points, just return - if in_intrinsic_bounding_box.sum() == 0: + + # assert that the number of bounding boxes is correct + assert len(in_intrinsic_bounding_box) == len(min_coordinate) + points_in_intrinsic_bounding_box: list[DaskDataFrame | None] = [] + for mask in in_intrinsic_bounding_box: + if mask.sum() == 0: + points_in_intrinsic_bounding_box.append(None) + else: + points_in_intrinsic_bounding_box.append(points.loc[mask]) + if len(points_in_intrinsic_bounding_box) == 0: return None - points_in_intrinsic_bounding_box = points.loc[in_intrinsic_bounding_box] + + # assert that the number of queried points is correct + assert len(points_in_intrinsic_bounding_box) == len(min_coordinate) # # we have to reset the index since we have subset # # https://stackoverflow.com/questions/61395351/how-to-reset-index-on-concatenated-dataframe-in-dask @@ -645,25 +673,42 @@ def _( # points_in_intrinsic_bounding_box = points_in_intrinsic_bounding_box.drop(columns=["idx"]) # transform the element to the query coordinate system - points_query_coordinate_system = transform( - points_in_intrinsic_bounding_box, to_coordinate_system=target_coordinate_system, maintain_positioning=False - ) # type: ignore[union-attr] + output: list[DaskDataFrame | None] = [] + for p, min_c, max_c in zip(points_in_intrinsic_bounding_box, min_coordinate, max_coordinate): + if p is None: + output.append(None) + else: + points_query_coordinate_system = transform( + p, to_coordinate_system=target_coordinate_system, maintain_positioning=False + ) - # get a mask for the points in the bounding box - bounding_box_mask = _bounding_box_mask_points( - points=points_query_coordinate_system, - axes=axes, - min_coordinate=min_coordinate, - max_coordinate=max_coordinate, - ) - bounding_box_indices = np.where(bounding_box_mask.compute())[0] - if len(bounding_box_indices) == 0: + # get a mask for the points in the bounding box + bounding_box_mask = _bounding_box_mask_points( + points=points_query_coordinate_system, + axes=axes, + min_coordinate=min_c, + max_coordinate=max_c, + ) + if len(bounding_box_mask) == 1: + bounding_box_mask = bounding_box_mask[0] + bounding_box_indices = np.where(bounding_box_mask.compute())[0] + + if len(bounding_box_indices) == 0: + output.append(None) + else: + points_df = p.compute().iloc[bounding_box_indices] + old_transformations = get_transformation(p, get_all=True) + assert isinstance(old_transformations, dict) + output.append( + PointsModel.parse( + dd.from_pandas(points_df, npartitions=1), transformations=old_transformations.copy() + ) + ) + if len(output) == 0: return None - points_df = points_in_intrinsic_bounding_box.compute().iloc[bounding_box_indices] - old_transformations = get_transformation(points, get_all=True) - assert isinstance(old_transformations, dict) - # an alternative approach is to query for each partition in parallel - return PointsModel.parse(dd.from_pandas(points_df, npartitions=1), transformations=old_transformations.copy()) + if len(output) == 1: + return output[0] + return output @bounding_box_query.register(GeoDataFrame) diff --git a/tests/core/query/test_spatial_query.py b/tests/core/query/test_spatial_query.py index e58e3424e..496bd3e6e 100644 --- a/tests/core/query/test_spatial_query.py +++ b/tests/core/query/test_spatial_query.py @@ -108,11 +108,12 @@ def test_bounding_box_request_wrong_coordinate_order(): @pytest.mark.parametrize("is_3d", [True, False]) @pytest.mark.parametrize("is_bb_3d", [True, False]) @pytest.mark.parametrize("with_polygon_query", [True, False]) -def test_query_points(is_3d: bool, is_bb_3d: bool, with_polygon_query: bool): +@pytest.mark.parametrize("multiple_boxes", [True, False]) +def test_query_points(is_3d: bool, is_bb_3d: bool, with_polygon_query: bool, multiple_boxes: bool): """test the points bounding box_query""" - data_x = np.array([10, 20, 20, 20]) - data_y = np.array([10, 20, 30, 30]) - data_z = np.array([100, 200, 200, 300]) + data_x = np.array([10, 20, 20, 20, 40]) + data_y = np.array([10, 20, 30, 30, 50]) + data_z = np.array([100, 200, 200, 300, 500]) data = np.stack((data_x, data_y), axis=1) if is_3d: @@ -125,16 +126,24 @@ def test_query_points(is_3d: bool, is_bb_3d: bool, with_polygon_query: bool): original_z = points_element["z"] if is_bb_3d: - _min_coordinate = np.array([18, 25, 250]) - _max_coordinate = np.array([22, 35, 350]) + if multiple_boxes: + _min_coordinate = np.array([[18, 25, 250], [35, 45, 450], [100, 110, 1100]]) + _max_coordinate = np.array([[22, 35, 350], [45, 55, 550], [110, 120, 1200]]) + else: + _min_coordinate = np.array([18, 25, 250]) + _max_coordinate = np.array([22, 35, 350]) _axes = ("x", "y", "z") else: - _min_coordinate = np.array([18, 25]) - _max_coordinate = np.array([22, 35]) + if multiple_boxes: + _min_coordinate = np.array([[18, 25], [35, 45], [100, 110]]) + _max_coordinate = np.array([[22, 35], [45, 55], [110, 120]]) + else: + _min_coordinate = np.array([18, 25]) + _max_coordinate = np.array([22, 35]) _axes = ("x", "y") if with_polygon_query: - if is_bb_3d: + if is_bb_3d or multiple_boxes: return polygon = Polygon([(18, 25), (18, 35), (22, 35), (22, 25)]) points_result = polygon_query(points_element, polygon=polygon, target_coordinate_system="global") @@ -147,22 +156,49 @@ def test_query_points(is_3d: bool, is_bb_3d: bool, with_polygon_query: bool): target_coordinate_system="global", ) - # Check that the correct point was selected + # Check that the correct points were selected if is_3d: if is_bb_3d: - np.testing.assert_allclose(points_result["x"].compute(), [20]) - np.testing.assert_allclose(points_result["y"].compute(), [30]) - np.testing.assert_allclose(points_result["z"].compute(), [300]) + if multiple_boxes: + np.testing.assert_allclose(points_result[0]["x"].compute(), [20]) + np.testing.assert_allclose(points_result[0]["y"].compute(), [30]) + np.testing.assert_allclose(points_result[0]["z"].compute(), [300]) + np.testing.assert_allclose(points_result[1]["x"].compute(), [40]) + np.testing.assert_allclose(points_result[1]["y"].compute(), [50]) + np.testing.assert_allclose(points_result[1]["z"].compute(), [500]) + else: + np.testing.assert_allclose(points_result["x"].compute(), [20]) + np.testing.assert_allclose(points_result["y"].compute(), [30]) + np.testing.assert_allclose(points_result["z"].compute(), [300]) + else: + if multiple_boxes: + np.testing.assert_allclose(points_result[0]["x"].compute(), [20, 20]) + np.testing.assert_allclose(points_result[0]["y"].compute(), [30, 30]) + np.testing.assert_allclose(points_result[0]["z"].compute(), [200, 300]) + np.testing.assert_allclose(points_result[1]["x"].compute(), [40]) + np.testing.assert_allclose(points_result[1]["y"].compute(), [50]) + np.testing.assert_allclose(points_result[1]["z"].compute(), [500]) + else: + np.testing.assert_allclose(points_result["x"].compute(), [20, 20]) + np.testing.assert_allclose(points_result["y"].compute(), [30, 30]) + np.testing.assert_allclose(points_result["z"].compute(), [200, 300]) + else: + if multiple_boxes: + np.testing.assert_allclose(points_result[0]["x"].compute(), [20, 20]) + np.testing.assert_allclose(points_result[0]["y"].compute(), [30, 30]) + np.testing.assert_allclose(points_result[1]["x"].compute(), [40]) + np.testing.assert_allclose(points_result[1]["y"].compute(), [50]) + assert points_result[2] is None else: np.testing.assert_allclose(points_result["x"].compute(), [20, 20]) np.testing.assert_allclose(points_result["y"].compute(), [30, 30]) - np.testing.assert_allclose(points_result["z"].compute(), [200, 300]) - else: - np.testing.assert_allclose(points_result["x"].compute(), [20, 20]) - np.testing.assert_allclose(points_result["y"].compute(), [30, 30]) # result should be valid points element - PointsModel.validate(points_result) + if multiple_boxes: + for result in points_result: + if result is None: + continue + PointsModel.validate(result) # original element should be unchanged np.testing.assert_allclose(points_element["x"].compute(), original_x) From 9b977d64b43704426406c230663ea7d8294a3ef2 Mon Sep 17 00:00:00 2001 From: giovp Date: Wed, 4 Sep 2024 12:53:52 -0700 Subject: [PATCH 20/22] fix type --- src/spatialdata/_core/query/spatial_query.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spatialdata/_core/query/spatial_query.py b/src/spatialdata/_core/query/spatial_query.py index 31e75864c..b08e56be1 100644 --- a/src/spatialdata/_core/query/spatial_query.py +++ b/src/spatialdata/_core/query/spatial_query.py @@ -606,7 +606,7 @@ def _( min_coordinate: list[Number] | ArrayLike, max_coordinate: list[Number] | ArrayLike, target_coordinate_system: str, -) -> DaskDataFrame | None: +) -> DaskDataFrame | list[DaskDataFrame] | None: from spatialdata import transform from spatialdata.transformations import get_transformation From e87d318371cfa31e68b466b238109bc47ea7a15e Mon Sep 17 00:00:00 2001 From: Luca Marconato Date: Fri, 15 May 2026 23:30:47 +0200 Subject: [PATCH 21/22] Fix rasterize path and bugs in PR #687 dataloader; add benchmark **Bugs fixed in datasets.py:** - rasterize=True path was broken: __getitem__ always called image.sel() regardless of rasterize flag, bypassing rasterize_fn entirely. Fixed by storing self._rasterize and branching in __getitem__. - ad.concat(*tables_l) unpacked the list as positional args, failing with >1 region. Fixed to ad.concat(tables_l). - Vectorized selection pre-computation was always run even for rasterize=True where it is unused. Fixed by guarding with `if not rasterize`. - Removed stale commented-out pandas.apply fallback code. **Fixes in _utils.py:** - Removed redundant nopython=True from @nb.njit (njit implies nopython=True, and the argument caused a RuntimeWarning). - Replaced invalid nb.types.Array[nb.float64, nb.float64] annotations with np.ndarray. **Fixes in spatial_query.py:** - Restored BoundingBoxRequest validation that was commented out. The validator's __post_init__ already handles both 1-D (single box) and 2-D (multi-box) arrays. **Benchmark (benchmark_dataloader.py):** Synthetic 2048x2048 image, 500 circle regions (32 px radius), 3-channel. Phase main PR (fixed) speedup init ~162 ms ~20 ms ~8x fetch 500 ~618 ms ~118 ms ~5x per-tile ~1237 us ~235 us ~5x Co-Authored-By: Claude Sonnet 4.6 --- benchmark_dataloader.py | 151 +++++++++++++++++++ src/spatialdata/_core/query/_utils.py | 8 +- src/spatialdata/_core/query/spatial_query.py | 14 +- src/spatialdata/dataloader/datasets.py | 55 ++++--- 4 files changed, 188 insertions(+), 40 deletions(-) create mode 100644 benchmark_dataloader.py diff --git a/benchmark_dataloader.py b/benchmark_dataloader.py new file mode 100644 index 000000000..115c58aba --- /dev/null +++ b/benchmark_dataloader.py @@ -0,0 +1,151 @@ +""" +Benchmark for ImageTilesDataset: init time + iteration time. + +Usage: + python benchmark_dataloader.py + +Measures two phases: + 1. Init — constructing ImageTilesDataset (includes bounding-box pre-computation). + 2. Fetch — iterating over every tile once (pure __getitem__ calls, no DataLoader overhead). + +Designed to run identically on `main` and the `giovp/dataloader3` branch so the two +numbers can be compared directly. +""" + +# ruff: noqa: T201 + +from __future__ import annotations + +import time + +import anndata as ad +import geopandas as gpd +import numpy as np +import pandas as pd +from shapely.geometry import Point + +import spatialdata as sd +from spatialdata.dataloader import ImageTilesDataset +from spatialdata.models import Image2DModel, ShapesModel, TableModel +from spatialdata.transformations import Identity + +RNG = np.random.default_rng(42) + +# --------------------------------------------------------------------------- +# Synthetic dataset +# --------------------------------------------------------------------------- + +IMAGE_SIZE = 2048 # pixels (square) +N_CIRCLES = 500 # number of region instances / tiles +N_CHANNELS = 3 + + +def make_sdata(n: int = N_CIRCLES, img_size: int = IMAGE_SIZE) -> sd.SpatialData: + """Build an in-memory SpatialData with a large image and N circle regions.""" + # Image: random uint8, shape (C, H, W) + img_data = RNG.integers(0, 256, size=(N_CHANNELS, img_size, img_size), dtype=np.uint8).astype(np.float32) + image = Image2DModel.parse( + img_data, + dims=["c", "y", "x"], + transformations={"global": Identity()}, + ) + + # Circles: random centres, fixed radius so each tile is ~64 pixels wide + radius = 32.0 + cx = RNG.uniform(radius, img_size - radius, size=n) + cy = RNG.uniform(radius, img_size - radius, size=n) + geom = gpd.GeoDataFrame({"geometry": [Point(x, y) for x, y in zip(cx, cy, strict=True)]}) + geom["radius"] = radius + circles = ShapesModel.parse(geom, transformations={"global": Identity()}) + + # Table: one row per circle + table = ad.AnnData( + RNG.random((n, 10)).astype(np.float32), + obs=pd.DataFrame( + { + "region": pd.Categorical(["circles"] * n), + "instance_id": np.arange(n, dtype=np.int64), + }, + index=[str(i) for i in range(n)], + ), + ) + table = TableModel.parse(table, region="circles", region_key="region", instance_key="instance_id") + + return sd.SpatialData( + images={"image": image}, + shapes={"circles": circles}, + tables={"table": table}, + ) + + +# --------------------------------------------------------------------------- +# Benchmark +# --------------------------------------------------------------------------- + + +def bench_init(sdata: sd.SpatialData, n_reps: int = 5) -> float: + """Time ImageTilesDataset construction.""" + times = [] + for _ in range(n_reps): + t0 = time.perf_counter() + _ = ImageTilesDataset( + sdata=sdata, + regions_to_images={"circles": "image"}, + regions_to_coordinate_systems={"circles": "global"}, + table_name="table", + return_annotations="instance_id", + ) + times.append(time.perf_counter() - t0) + return float(np.median(times)) + + +def bench_fetch(ds: ImageTilesDataset, n_reps: int = 3) -> float: + """Time iterating over every item in the dataset.""" + times = [] + for _ in range(n_reps): + t0 = time.perf_counter() + for i in range(len(ds)): + _ = ds[i] + times.append(time.perf_counter() - t0) + return float(np.median(times)) + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + + +def main() -> None: + """Run all benchmark phases and print a timing summary.""" + import spatialdata + + print(f"spatialdata version : {spatialdata.__version__}") + print(f"Image size : {IMAGE_SIZE}×{IMAGE_SIZE} ({N_CHANNELS} channels)") + print(f"Circles (tiles) : {N_CIRCLES}") + print() + + print("Building synthetic SpatialData …", flush=True) + t0 = time.perf_counter() + sdata = make_sdata() + print(f" done in {time.perf_counter() - t0:.2f} s\n") + + print("Benchmarking init (5 reps) …", flush=True) + t_init = bench_init(sdata, n_reps=5) + print(f" median init time : {t_init * 1000:.1f} ms\n") + + # Build one dataset for the fetch benchmark + ds = ImageTilesDataset( + sdata=sdata, + regions_to_images={"circles": "image"}, + regions_to_coordinate_systems={"circles": "global"}, + table_name="table", + return_annotations="instance_id", + ) + print(f"Benchmarking fetch of {len(ds)} tiles (3 reps) …", flush=True) + t_fetch = bench_fetch(ds, n_reps=3) + per_tile_us = t_fetch / len(ds) * 1e6 + print(f" median fetch time : {t_fetch * 1000:.1f} ms total ({per_tile_us:.0f} µs/tile)") + + +if __name__ == "__main__": + main() diff --git a/src/spatialdata/_core/query/_utils.py b/src/spatialdata/_core/query/_utils.py index 86dcfc83c..d21ee63e2 100644 --- a/src/spatialdata/_core/query/_utils.py +++ b/src/spatialdata/_core/query/_utils.py @@ -91,11 +91,11 @@ def get_bounding_box_corners( return output.squeeze().drop_vars("box") -@nb.njit(parallel=False, nopython=True) +@nb.njit(parallel=False) def _create_slices_and_translation( - min_values: nb.types.Array[nb.float64, nb.float64], - max_values: nb.types.Array[nb.float64, nb.float64], -) -> tuple[nb.types.Array[nb.float64, nb.float64], nb.types.Array[nb.float64, nb.float64]]: + min_values: np.ndarray, + max_values: np.ndarray, +) -> tuple[np.ndarray, np.ndarray]: n_boxes, n_dims = min_values.shape slices = np.empty((n_boxes, n_dims, 2), dtype=np.float64) # (n_boxes, n_dims, [min, max]) translation_vectors = np.empty((n_boxes, n_dims), dtype=np.float64) # (n_boxes, n_dims) diff --git a/src/spatialdata/_core/query/spatial_query.py b/src/spatialdata/_core/query/spatial_query.py index 7713945b6..475c36f4f 100644 --- a/src/spatialdata/_core/query/spatial_query.py +++ b/src/spatialdata/_core/query/spatial_query.py @@ -556,13 +556,13 @@ def _( min_coordinate = _parse_list_into_array(min_coordinate) max_coordinate = _parse_list_into_array(max_coordinate) - # for triggering validation - # _ = BoundingBoxRequest( - # target_coordinate_system=target_coordinate_system, - # axes=axes, - # min_coordinate=min_coordinate, - # max_coordinate=max_coordinate, - # ) + # for triggering validation (handles both 1-D single-box and 2-D multi-box arrays) + _ = BoundingBoxRequest( + target_coordinate_system=target_coordinate_system, + axes=axes, + min_coordinate=min_coordinate, + max_coordinate=max_coordinate, + ) intrinsic_bounding_box_corners, axes = _get_bounding_box_corners_in_intrinsic_coordinates( image, axes, min_coordinate, max_coordinate, target_coordinate_system diff --git a/src/spatialdata/dataloader/datasets.py b/src/spatialdata/dataloader/datasets.py index 78ca24d74..03879abc8 100644 --- a/src/spatialdata/dataloader/datasets.py +++ b/src/spatialdata/dataloader/datasets.py @@ -128,6 +128,7 @@ def __init__( from spatialdata._core.operations.rasterize import rasterize as rasterize_fn self.sdata = sdata + self._rasterize = rasterize self._validate(regions_to_images, regions_to_coordinate_systems, return_annotations, table_name) self._preprocess(tile_scale, tile_dim_in_units, rasterize, table_name) @@ -145,7 +146,7 @@ def __init__( **dict(rasterize_kwargs), ) if rasterize - else partial(bounding_box_query, return_request_only=True) # type: ignore[assignment] + else bounding_box_query ) self._return = self._get_return(return_annotations, table_name) self.transform = transform @@ -250,25 +251,18 @@ def _preprocess( tile_scale=tile_scale, tile_dim_in_units=tile_dim_in_units, ) - tile_coords["selection"] = bounding_box_query( - self.sdata[image_name], - ("x", "y"), - min_coordinate=tile_coords[["minx", "miny"]].values, - max_coordinate=tile_coords[["maxx", "maxy"]].values, - target_coordinate_system=cs, - return_request_only=True, - ) - # tile_coords["selection"] = tile_coords.apply( - # lambda row, cs=cs, image_name=image_name: bounding_box_query( - # self.sdata[image_name], - # ("x", "y"), - # min_coordinate=row[["minx", "miny"]].values, - # max_coordinate=row[["maxx", "maxy"]].values, - # target_coordinate_system=cs, - # return_request_only=True, - # ), - # axis=1, - # ) + if not rasterize: + # Pre-compute all per-tile slice selections in a single vectorized call. + # Passing 2-D min/max arrays triggers the multi-box path in bounding_box_query, + # which returns a list of {axis: slice} dicts — one per tile. + tile_coords["selection"] = bounding_box_query( + self.sdata[image_name], + ("x", "y"), + min_coordinate=tile_coords[["minx", "miny"]].values, + max_coordinate=tile_coords[["maxx", "maxy"]].values, + target_coordinate_system=cs, + return_request_only=True, + ) tile_coords_df.append(tile_coords) inst = circles.index.values @@ -296,7 +290,7 @@ def _preprocess( self.dataset_index = pd.concat(index_df).reset_index(drop=True) assert len(self.tiles_coords) == len(self.dataset_index) if table_name: - self.dataset_table = ad.concat(*tables_l) + self.dataset_table = ad.concat(tables_l) assert len(self.tiles_coords) == len(self.dataset_table) dims_ = set(chain(*dims_l)) @@ -376,14 +370,17 @@ def __getitem__(self, idx: int) -> Any | SpatialData: t_coords = self.tiles_coords.iloc[idx] image = self.sdata[row["image"]] - # tile = self._crop_image( - # image, - # axes=tuple(self.dims), - # min_coordinate=t_coords[[f"min{i}" for i in self.dims]].values, - # max_coordinate=t_coords[[f"max{i}" for i in self.dims]].values, - # target_coordinate_system=row["cs"], - # ) - tile = image.sel(t_coords["selection"]) + if self._rasterize: + tile = self._crop_image( + image, + axes=tuple(self.dims), + min_coordinate=t_coords[[f"min{i}" for i in self.dims]].values, + max_coordinate=t_coords[[f"max{i}" for i in self.dims]].values, + target_coordinate_system=row["cs"], + ) + else: + # Use pre-computed slice selection (vectorized at init time). + tile = image.sel(t_coords["selection"]) if self.transform is not None: out = self._return(idx, tile) return self.transform(out) From a51cb959e07c1a600b6a0474cd376d59dd2e198d Mon Sep 17 00:00:00 2001 From: Luca Marconato Date: Thu, 21 May 2026 13:39:04 +0200 Subject: [PATCH 22/22] add asv benchmark for dataloader performance --- benchmark_dataloader.py | 151 ----------------------------- benchmarks/README.md | 26 +++++ benchmarks/benchmark_dataloader.py | 75 ++++++++++++++ 3 files changed, 101 insertions(+), 151 deletions(-) delete mode 100644 benchmark_dataloader.py create mode 100644 benchmarks/benchmark_dataloader.py diff --git a/benchmark_dataloader.py b/benchmark_dataloader.py deleted file mode 100644 index 115c58aba..000000000 --- a/benchmark_dataloader.py +++ /dev/null @@ -1,151 +0,0 @@ -""" -Benchmark for ImageTilesDataset: init time + iteration time. - -Usage: - python benchmark_dataloader.py - -Measures two phases: - 1. Init — constructing ImageTilesDataset (includes bounding-box pre-computation). - 2. Fetch — iterating over every tile once (pure __getitem__ calls, no DataLoader overhead). - -Designed to run identically on `main` and the `giovp/dataloader3` branch so the two -numbers can be compared directly. -""" - -# ruff: noqa: T201 - -from __future__ import annotations - -import time - -import anndata as ad -import geopandas as gpd -import numpy as np -import pandas as pd -from shapely.geometry import Point - -import spatialdata as sd -from spatialdata.dataloader import ImageTilesDataset -from spatialdata.models import Image2DModel, ShapesModel, TableModel -from spatialdata.transformations import Identity - -RNG = np.random.default_rng(42) - -# --------------------------------------------------------------------------- -# Synthetic dataset -# --------------------------------------------------------------------------- - -IMAGE_SIZE = 2048 # pixels (square) -N_CIRCLES = 500 # number of region instances / tiles -N_CHANNELS = 3 - - -def make_sdata(n: int = N_CIRCLES, img_size: int = IMAGE_SIZE) -> sd.SpatialData: - """Build an in-memory SpatialData with a large image and N circle regions.""" - # Image: random uint8, shape (C, H, W) - img_data = RNG.integers(0, 256, size=(N_CHANNELS, img_size, img_size), dtype=np.uint8).astype(np.float32) - image = Image2DModel.parse( - img_data, - dims=["c", "y", "x"], - transformations={"global": Identity()}, - ) - - # Circles: random centres, fixed radius so each tile is ~64 pixels wide - radius = 32.0 - cx = RNG.uniform(radius, img_size - radius, size=n) - cy = RNG.uniform(radius, img_size - radius, size=n) - geom = gpd.GeoDataFrame({"geometry": [Point(x, y) for x, y in zip(cx, cy, strict=True)]}) - geom["radius"] = radius - circles = ShapesModel.parse(geom, transformations={"global": Identity()}) - - # Table: one row per circle - table = ad.AnnData( - RNG.random((n, 10)).astype(np.float32), - obs=pd.DataFrame( - { - "region": pd.Categorical(["circles"] * n), - "instance_id": np.arange(n, dtype=np.int64), - }, - index=[str(i) for i in range(n)], - ), - ) - table = TableModel.parse(table, region="circles", region_key="region", instance_key="instance_id") - - return sd.SpatialData( - images={"image": image}, - shapes={"circles": circles}, - tables={"table": table}, - ) - - -# --------------------------------------------------------------------------- -# Benchmark -# --------------------------------------------------------------------------- - - -def bench_init(sdata: sd.SpatialData, n_reps: int = 5) -> float: - """Time ImageTilesDataset construction.""" - times = [] - for _ in range(n_reps): - t0 = time.perf_counter() - _ = ImageTilesDataset( - sdata=sdata, - regions_to_images={"circles": "image"}, - regions_to_coordinate_systems={"circles": "global"}, - table_name="table", - return_annotations="instance_id", - ) - times.append(time.perf_counter() - t0) - return float(np.median(times)) - - -def bench_fetch(ds: ImageTilesDataset, n_reps: int = 3) -> float: - """Time iterating over every item in the dataset.""" - times = [] - for _ in range(n_reps): - t0 = time.perf_counter() - for i in range(len(ds)): - _ = ds[i] - times.append(time.perf_counter() - t0) - return float(np.median(times)) - - -# --------------------------------------------------------------------------- -# Main -# --------------------------------------------------------------------------- - - -def main() -> None: - """Run all benchmark phases and print a timing summary.""" - import spatialdata - - print(f"spatialdata version : {spatialdata.__version__}") - print(f"Image size : {IMAGE_SIZE}×{IMAGE_SIZE} ({N_CHANNELS} channels)") - print(f"Circles (tiles) : {N_CIRCLES}") - print() - - print("Building synthetic SpatialData …", flush=True) - t0 = time.perf_counter() - sdata = make_sdata() - print(f" done in {time.perf_counter() - t0:.2f} s\n") - - print("Benchmarking init (5 reps) …", flush=True) - t_init = bench_init(sdata, n_reps=5) - print(f" median init time : {t_init * 1000:.1f} ms\n") - - # Build one dataset for the fetch benchmark - ds = ImageTilesDataset( - sdata=sdata, - regions_to_images={"circles": "image"}, - regions_to_coordinate_systems={"circles": "global"}, - table_name="table", - return_annotations="instance_id", - ) - print(f"Benchmarking fetch of {len(ds)} tiles (3 reps) …", flush=True) - t_fetch = bench_fetch(ds, n_reps=3) - per_tile_us = t_fetch / len(ds) * 1e6 - print(f" median fetch time : {t_fetch * 1000:.1f} ms total ({per_tile_us:.0f} µs/tile)") - - -if __name__ == "__main__": - main() diff --git a/benchmarks/README.md b/benchmarks/README.md index 9f8903620..6ae1d7d03 100644 --- a/benchmarks/README.md +++ b/benchmarks/README.md @@ -65,6 +65,32 @@ git checkout - && git stash pop asv compare main HEAD ``` +### Dataloader benchmarks + +Dataloader benchmarks live in `benchmarks/benchmark_dataloader.py`. They use a synthetic in-memory `SpatialData` (2048×2048 image, 500 circle regions) and compute two metrics: + +- `time_init` — constructing `ImageTilesDataset` (includes bounding-box pre-computation). +- `time_fetch` — iterating over all 500 tiles once (pure `__getitem__` calls, no `DataLoader` overhead). + +Run both in your current environment: + +```bash +asv run --python=same --show-stderr -b TimeDataloader +``` + +Or a single method: + +```bash +asv run --python=same --show-stderr -b TimeDataloader.time_init +asv run --python=same --show-stderr -b TimeDataloader.time_fetch +``` + +Compare against `main` in one shot: + +```bash +asv continuous --show-stderr -v -b TimeDataloader main HEAD +``` + ### Querying benchmarks Querying using a bounding box without a spatial index is highly impacted by large amounts of points (transcripts), more than table rows (cells). diff --git a/benchmarks/benchmark_dataloader.py b/benchmarks/benchmark_dataloader.py new file mode 100644 index 000000000..474b658ba --- /dev/null +++ b/benchmarks/benchmark_dataloader.py @@ -0,0 +1,75 @@ +# type: ignore +"""Benchmarks for ImageTilesDataset: init time and iteration time.""" + +from __future__ import annotations + +import anndata as ad +import geopandas as gpd +import numpy as np +import pandas as pd +from shapely.geometry import Point + +import spatialdata as sd +from spatialdata.dataloader import ImageTilesDataset +from spatialdata.models import Image2DModel, ShapesModel, TableModel +from spatialdata.transformations import Identity + +_RNG = np.random.default_rng(42) + +_IMAGE_SIZE = 2048 +_N_CIRCLES = 500 +_N_CHANNELS = 3 + +_DATASET_KWARGS = { + "regions_to_images": {"circles": "image"}, + "regions_to_coordinate_systems": {"circles": "global"}, + "table_name": "table", + "return_annotations": "instance_id", +} + + +def _make_sdata() -> sd.SpatialData: + img_data = _RNG.integers(0, 256, size=(_N_CHANNELS, _IMAGE_SIZE, _IMAGE_SIZE), dtype=np.uint8).astype(np.float32) + image = Image2DModel.parse(img_data, dims=["c", "y", "x"], transformations={"global": Identity()}) + + radius = 32.0 + cx = _RNG.uniform(radius, _IMAGE_SIZE - radius, size=_N_CIRCLES) + cy = _RNG.uniform(radius, _IMAGE_SIZE - radius, size=_N_CIRCLES) + geom = gpd.GeoDataFrame({"geometry": [Point(x, y) for x, y in zip(cx, cy, strict=True)]}) + geom["radius"] = radius + circles = ShapesModel.parse(geom, transformations={"global": Identity()}) + + table = ad.AnnData( + _RNG.random((_N_CIRCLES, 10)).astype(np.float32), + obs=pd.DataFrame( + { + "region": pd.Categorical(["circles"] * _N_CIRCLES), + "instance_id": np.arange(_N_CIRCLES, dtype=np.int64), + }, + index=[str(i) for i in range(_N_CIRCLES)], + ), + ) + table = TableModel.parse(table, region="circles", region_key="region", instance_key="instance_id") + + return sd.SpatialData(images={"image": image}, shapes={"circles": circles}, tables={"table": table}) + + +class TimeDataloader: + """Time ImageTilesDataset construction and tile iteration.""" + + def setup(self): + self.sdata = _make_sdata() + self.ds = ImageTilesDataset(sdata=self.sdata, **_DATASET_KWARGS) + + def teardown(self): + del self.ds + del self.sdata + + def time_init(self): + """Time constructing ImageTilesDataset (bounding-box pre-computation).""" + ImageTilesDataset(sdata=self.sdata, **_DATASET_KWARGS) + + def time_fetch(self): + """Time iterating over every tile once.""" + for i in range(len(self.ds)): + _ = self.ds[i]