From eb487014548cdcfd363d100cc2b8b5957998e2dc Mon Sep 17 00:00:00 2001 From: Boris Ivanovic Date: Tue, 14 Oct 2025 15:36:54 -0700 Subject: [PATCH 1/4] Updates that enable usage with alpasim (additional map components, protobuf4, removal of map rasterization, etc). --- DATASETS.md | 38 -- README.md | 21 +- examples/visualization_example.py | 25 +- pyproject.toml | 13 +- src/trajdata/caching/df_cache.py | 234 +--------- src/trajdata/data_structures/batch.py | 356 ++++++++++++--- src/trajdata/data_structures/batch_element.py | 333 +++++++++++--- src/trajdata/data_structures/collation.py | 366 +++++---------- src/trajdata/data_structures/data_index.py | 4 +- src/trajdata/data_structures/state.py | 41 +- src/trajdata/dataset.py | 157 +++---- .../dataset_specific/nuplan/nuplan_dataset.py | 80 ++-- .../dataset_specific/nuplan/nuplan_utils.py | 33 +- .../dataset_specific/nusc/nusc_dataset.py | 6 +- .../dataset_specific/nusc/nusc_utils.py | 12 +- .../dataset_specific/scene_records.py | 14 - .../dataset_specific/waymo/waymo_dataset.py | 46 +- .../dataset_specific/waymo/waymo_utils.py | 133 +++++- src/trajdata/maps/map_api.py | 45 +- src/trajdata/maps/map_kdtree.py | 80 +++- src/trajdata/maps/vec_map.py | 429 +++++++++++++----- src/trajdata/maps/vec_map_elements.py | 188 +++++++- src/trajdata/parallel/data_preprocessor.py | 8 +- src/trajdata/proto/vectorized_map.proto | 45 +- src/trajdata/proto/vectorized_map_pb2.py | 147 ++---- src/trajdata/proto/vectorized_map_pb2_grpc.py | 24 + src/trajdata/utils/arr_utils.py | 423 ++++++++++++++++- src/trajdata/utils/batch_utils.py | 104 ++++- src/trajdata/utils/comm_utils.py | 24 + src/trajdata/utils/env_utils.py | 10 - src/trajdata/utils/map_utils.py | 176 ++++++- src/trajdata/utils/scene_utils.py | 8 +- src/trajdata/utils/vis_utils.py | 68 ++- .../visualization/interactive_animation.py | 243 +++++----- .../visualization/interactive_figure.py | 4 +- src/trajdata/visualization/interactive_vis.py | 16 +- tests/test_raster_map.py | 100 ++++ tests/test_state.py | 211 +++++++++ 38 files changed, 2968 insertions(+), 1297 deletions(-) create mode 100644 src/trajdata/proto/vectorized_map_pb2_grpc.py create mode 100644 src/trajdata/utils/comm_utils.py create mode 100644 tests/test_raster_map.py diff --git a/DATASETS.md b/DATASETS.md index 528bae2..b2f9255 100644 --- a/DATASETS.md +++ b/DATASETS.md @@ -1,16 +1,5 @@ # Supported Datasets and Required Formats -## View-of-Delft -Nothing special needs to be done for the View-of-Delft Prediction dataset, simply download it as per [the instructions in the devkit README](https://github.com/tudelft-iv/view-of-delft-prediction-devkit?tab=readme-ov-file#vod-p-setup). - -It should look like this after downloading: -``` -/path/to/VoD/ - ├── maps/ - ├── v1.0-test/ - └── v1.0-trainval/ -``` - ## nuScenes Nothing special needs to be done for the nuScenes dataset, simply download it as per [the instructions in the devkit README](https://github.com/nutonomy/nuscenes-devkit#nuscenes-setup). @@ -199,30 +188,3 @@ It should look like this after downloading: ``` **Note**: Only the annotations need to be downloaded (not the videos). - - -## Argoverse 2 Motion Forecasting -The dataset can be downloaded from [here](https://www.argoverse.org/av2.html#download-link). - -It should look like this after downloading: -``` -/path/to/av2mf/ - ├── train/ - | ├── 0000b0f9-99f9-4a1f-a231-5be9e4c523f7/ - | | ├── log_map_archive_0000b0f9-99f9-4a1f-a231-5be9e4c523f7.json - | | └── scenario_0000b0f9-99f9-4a1f-a231-5be9e4c523f7.parquet - | ├── 0000b6ab-e100-4f6b-aee8-b520b57c0530/ - | | ├── log_map_archive_0000b6ab-e100-4f6b-aee8-b520b57c0530.json - | | └── scenario_0000b6ab-e100-4f6b-aee8-b520b57c0530.parquet - | └── ... - ├── val/ - | ├── 00010486-9a07-48ae-b493-cf4545855937/ - | | ├── log_map_archive_00010486-9a07-48ae-b493-cf4545855937.json - | | └── scenario_00010486-9a07-48ae-b493-cf4545855937.parquet - | └── ... - └── test/ - ├── 0000b329-f890-4c2b-93f2-7e2413d4ca5b/ - | ├── log_map_archive_0000b329-f890-4c2b-93f2-7e2413d4ca5b.json - | └── scenario_0000b329-f890-4c2b-93f2-7e2413d4ca5b.parquet - └── ... -``` \ No newline at end of file diff --git a/README.md b/README.md index db5678d..7858fe5 100644 --- a/README.md +++ b/README.md @@ -17,7 +17,7 @@ The easiest way to install trajdata is through PyPI with pip install trajdata ``` -In case you would also like to use datasets such as nuScenes, Lyft Level 5, View-of-Delft, or Waymo Open Motion Dataset (which require their own devkits to access raw data or additional package dependencies), the following will also install the respective devkits and/or package dependencies. +In case you would also like to use datasets such as nuScenes, Lyft Level 5, or Waymo Open Motion Dataset (which require their own devkits to access raw data or additional package dependencies), the following will also install the respective devkits and/or package dependencies. ```sh # For nuScenes pip install "trajdata[nusc]" @@ -31,13 +31,10 @@ pip install "trajdata[waymo]" # For INTERACTION pip install "trajdata[interaction]" -# For View-of-Delft -pip install "trajdata[vod]" - # All -pip install "trajdata[nusc,lyft,waymo,interaction,vod]" +pip install "trajdata[nusc,lyft,waymo,interaction]" ``` -Then, download the raw datasets (nuScenes, Lyft Level 5, View-of-Delft, ETH/UCY, etc.) in case you do not already have them. For more information about how to structure dataset folders/files, please see [`DATASETS.md`](./DATASETS.md). +Then, download the raw datasets (nuScenes, Lyft Level 5, ETH/UCY, etc.) in case you do not already have them. For more information about how to structure dataset folders/files, please see [`DATASETS.md`](./DATASETS.md). ### Package Developer Installation @@ -103,8 +100,6 @@ Currently, the dataloader supports interfacing with the following datasets: | nuPlan Validation | `nuplan_val` | N/A | `boston`, `singapore`, `pittsburgh`, `las_vegas` | nuPlan validation split (90.30 GB) | 0.05s (20Hz) | :white_check_mark: | | nuPlan Test | `nuplan_test` | N/A | `boston`, `singapore`, `pittsburgh`, `las_vegas` | nuPlan testing split (89.33 GB) | 0.05s (20Hz) | :white_check_mark: | | nuPlan Mini | `nuplan_mini` | `mini_train`, `mini_val`, `mini_test` | `boston`, `singapore`, `pittsburgh`, `las_vegas` | nuPlan mini training/validation/test splits (942/197/224 scenes, 7.96 GB) | 0.05s (20Hz) | :white_check_mark: | -| View-of-Delft Train/TrainVal/Val | `vod_trainval` | `train`, `train_val`, `val` | `delft` | View-of-Delft Prediction training and validation splits | 0.1s (10Hz) | :white_check_mark: | -| View-of-Delft Test | `vod_test` | `test` | `delft` | View-of-Delft Prediction test split | 0.1s (10Hz) | :white_check_mark: | | Waymo Open Motion Training | `waymo_train` | `train` | N/A | Waymo Open Motion Dataset `training` split | 0.1s (10Hz) | :white_check_mark: | | Waymo Open Motion Validation | `waymo_val` | `val` | N/A | Waymo Open Motion Dataset `validation` split | 0.1s (10Hz) | :white_check_mark: | | Waymo Open Motion Testing | `waymo_test` | `test` | N/A | Waymo Open Motion Dataset `testing` split | 0.1s (10Hz) | :white_check_mark: | @@ -112,7 +107,6 @@ Currently, the dataloader supports interfacing with the following datasets: | Lyft Level 5 Train Full | `lyft_train_full` | `train` | `palo_alto` | Lyft Level 5 training data - part 2/2 (70 GB) | 0.1s (10Hz) | :white_check_mark: | | Lyft Level 5 Validation | `lyft_val` | `val` | `palo_alto` | Lyft Level 5 validation data (8.2 GB) | 0.1s (10Hz) | :white_check_mark: | | Lyft Level 5 Sample | `lyft_sample` | `mini_train`, `mini_val` | `palo_alto` | Lyft Level 5 sample data (100 scenes, randomly split 80/20 for training/validation) | 0.1s (10Hz) | :white_check_mark: | -| Argoverse 2 Motion Forecasting | `av2_motion_forecasting` | `train`, `val`, `test` | N/A | 250,000 motion forecasting scenarios of 11s each | 0.1s (10Hz) | :white_check_mark: | | INTERACTION Dataset Single-Agent | `interaction_single` | `train`, `val`, `test`, `test_conditional` | `usa`, `china`, `germany`, `bulgaria` | Single-agent split of the INTERACTION Dataset (where the goal is to predict one target agents' future motion) | 0.1s (10Hz) | :white_check_mark: | | INTERACTION Dataset Multi-Agent | `interaction_multi` | `train`, `val`, `test`, `test_conditional` | `usa`, `china`, `germany`, `bulgaria` | Multi-agent split of the INTERACTION Dataset (where the goal is to jointly predict multiple agents' future motion) | 0.1s (10Hz) | :white_check_mark: | | ETH - Univ | `eupeds_eth` | `train`, `val`, `train_loo`, `val_loo`, `test_loo` | `zurich` | The ETH (University) scene from the ETH BIWI Walking Pedestrians dataset | 0.4s (2.5Hz) | | @@ -234,3 +228,12 @@ If you use this software, please cite it as follows: ## TODO - Create a method like finalize() which writes all the batch information to a TFRecord/WebDataset/some other format which is (very) fast to read from for higher epoch training. + +## Protobuf +This branch is intended to compile protobuf files for use with Proto V4. In order to do this, install dev depedencies with `pip install trajdata[dev]` and use + +``` +python -m grpc_tools.protoc --python_out=src/trajdata/proto --proto_path=src/trajdata/proto --experimental_allow_proto3_optional src/trajdata/proto/vectorized_map.proto +``` + +To compile the vector map protobuf. diff --git a/examples/visualization_example.py b/examples/visualization_example.py index ba65f53..7c45004 100644 --- a/examples/visualization_example.py +++ b/examples/visualization_example.py @@ -1,13 +1,10 @@ from collections import defaultdict +import panel as pn from torch.utils.data import DataLoader from tqdm import tqdm - from trajdata import AgentBatch, AgentType, UnifiedDataset -from trajdata.visualization.interactive_animation import ( - InteractiveAnimation, - animate_agent_batch_interactive, -) +from trajdata.visualization.interactive_animation import animate_agent_batch_interactive from trajdata.visualization.interactive_vis import plot_agent_batch_interactive from trajdata.visualization.vis import plot_agent_batch @@ -51,17 +48,17 @@ def main(): batch: AgentBatch for batch in tqdm(dataloader): - plot_agent_batch_interactive(batch, batch_idx=0, cache_path=dataset.cache_path) - plot_agent_batch(batch, batch_idx=0) + # plot_agent_batch_interactive(batch, batch_idx=0, cache_path=dataset.cache_path) + # plot_agent_batch(batch, batch_idx=0) - animation = InteractiveAnimation( - animate_agent_batch_interactive, - batch=batch, - batch_idx=0, - cache_path=dataset.cache_path, + server = pn.serve( + animate_agent_batch_interactive( + batch=batch, batch_idx=0, cache_path=dataset.cache_path + ), ) - animation.show() - # break + server.io_loop.start() + + break if __name__ == "__main__": diff --git a/pyproject.toml b/pyproject.toml index 2413b5c..fb52a43 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,8 +12,8 @@ classifiers = [ "Programming Language :: Python :: 3.8", "License :: OSI Approved :: Apache Software License", ] -name = "trajdata" -version = "1.4.0" +name = "trajdata-alpasim" +version = "1.4.2" authors = [{ name = "Boris Ivanovic", email = "bivanovic@nvidia.com" }] description = "A unified interface to many trajectory forecasting datasets." readme = "README.md" @@ -26,25 +26,20 @@ dependencies = [ "pandas>=1.4.1", "pyarrow>=7.0.0", "torch>=1.10.2", - "zarr>=2.11.0", - "kornia>=0.6.4", "seaborn>=0.12", "bokeh>=3.0.3", "geopandas>=0.13.2", - "protobuf==3.19.4", + "protobuf>=4.0.0,<5.0.0", "scipy>=1.9.0", - "opencv-python>=4.5.0", "shapely>=2.0.0", ] [project.optional-dependencies] -av2 = ["av2==0.2.1"] -dev = ["black", "isort", "pytest", "pytest-xdist", "twine", "build"] +dev = ["black", "isort", "pytest", "pytest-xdist", "twine", "build", "grpcio-tools==1.62.2"] interaction = ["lanelet2==1.2.1"] lyft = ["l5kit==1.5.0"] nusc = ["nuscenes-devkit==1.1.9"] waymo = ["tensorflow==2.11.0", "waymo-open-dataset-tf-2-11-0", "intervaltree"] -vod = ["vod-devkit==1.1.1"] [project.urls] "Homepage" = "https://github.com/nvr-avg/trajdata" diff --git a/src/trajdata/caching/df_cache.py b/src/trajdata/caching/df_cache.py index 329bb3d..b2cf312 100644 --- a/src/trajdata/caching/df_cache.py +++ b/src/trajdata/caching/df_cache.py @@ -6,8 +6,6 @@ if TYPE_CHECKING: from trajdata.maps import ( - RasterizedMap, - RasterizedMapMetadata, VectorMap, ) from trajdata.maps.map_kdtree import MapElementKDTree @@ -19,11 +17,8 @@ from typing import Any, Dict, Final, List, Optional, Tuple import dill -import kornia import numpy as np import pandas as pd -import torch -import zarr from trajdata.augmentation.augmentation import Augmentation, DatasetAugmentation from trajdata.caching.scene_cache import SceneCache @@ -31,7 +26,8 @@ from trajdata.data_structures.scene_metadata import Scene from trajdata.data_structures.state import NP_STATE_TYPES, StateArray from trajdata.maps.traffic_light_status import TrafficLightStatus -from trajdata.utils import arr_utils, df_utils, raster_utils, state_utils +from trajdata.utils import arr_utils, df_utils +from trajdata.utils.scene_utils import is_integer_robust STATE_COLS: Final[List[str]] = ["x", "y", "z", "vx", "vy", "ax", "ay"] EXTENT_COLS: Final[List[str]] = ["length", "width", "height"] @@ -322,7 +318,7 @@ def _transform_pair( def _upsample_data( self, new_index: pd.MultiIndex, upsample_dt_ratio: float, method: str ) -> pd.DataFrame: - upsample_dt_factor: int = int(upsample_dt_ratio) + upsample_dt_factor: int = int(round(upsample_dt_ratio)) interpolated_df: pd.DataFrame = pd.DataFrame( index=new_index, columns=self.scene_data_df.columns @@ -338,9 +334,9 @@ def _upsample_data( new_index.get_level_values("scene_ts") % upsample_dt_factor == 0 )[0] interpolated_df.iloc[scene_data_idxs] = scene_data - interpolated_df.iloc[ - scene_data_idxs, self.column_dict["heading"] - ] = unwrapped_heading + interpolated_df.iloc[scene_data_idxs, self.column_dict["heading"]] = ( + unwrapped_heading + ) # Interpolation. interpolated_df.interpolate(method=method, limit_area="inside", inplace=True) @@ -355,7 +351,7 @@ def _upsample_data( def _downsample_data( self, new_index: pd.MultiIndex, downsample_dt_ratio: float ) -> pd.DataFrame: - downsample_dt_factor: int = int(downsample_dt_ratio) + downsample_dt_factor: int = int(round(downsample_dt_ratio)) subsample_index: pd.MultiIndex = new_index.set_levels( new_index.levels[1] * downsample_dt_factor, level=1 @@ -370,7 +366,8 @@ def _downsample_data( def interpolate_data(self, desired_dt: float, method: str = "linear") -> None: upsample_dt_ratio: float = self.scene.env_metadata.dt / desired_dt downsample_dt_ratio: float = desired_dt / self.scene.env_metadata.dt - if not upsample_dt_ratio.is_integer() and not downsample_dt_ratio.is_integer(): + + if not is_integer_robust(upsample_dt_ratio) and not is_integer_robust(downsample_dt_ratio): raise ValueError( f"{str(self.scene)}'s dt of {self.scene.dt}s " f"is not integer divisible by the desired dt {desired_dt}s." @@ -705,23 +702,19 @@ def are_maps_cached(cache_path: Path, env_name: str) -> bool: @staticmethod def get_map_paths( - cache_path: Path, env_name: str, map_name: str, resolution: float - ) -> Tuple[Path, Path, Path, Path, Path, Path]: + cache_path: Path, env_name: str, map_name: str + ) -> Tuple[Path, Path, Path, Path]: maps_path: Path = DataFrameCache.get_maps_path(cache_path, env_name) vector_map_path: Path = maps_path / f"{map_name}.pb" kdtrees_path: Path = maps_path / f"{map_name}_kdtrees.dill" rtrees_path: Path = maps_path / f"{map_name}_rtrees.dill" - raster_map_path: Path = maps_path / f"{map_name}_{resolution:.2f}px_m.zarr" - raster_metadata_path: Path = maps_path / f"{map_name}_{resolution:.2f}px_m.dill" return ( maps_path, vector_map_path, kdtrees_path, rtrees_path, - raster_map_path, - raster_metadata_path, ) @staticmethod @@ -733,8 +726,6 @@ def is_map_cached( vector_map_path, kdtrees_path, rtrees_path, - raster_map_path, - raster_metadata_path, ) = DataFrameCache.get_map_paths(cache_path, env_name, map_name, resolution) # TODO(bivanovic): For now, rtrees are optional to have in the cache. @@ -745,8 +736,6 @@ def is_map_cached( and vector_map_path.exists() and kdtrees_path.exists() # and rtrees_path.exists() - and raster_metadata_path.exists() - and raster_map_path.exists() ) @staticmethod @@ -755,22 +744,13 @@ def finalize_and_cache_map( vector_map: VectorMap, map_params: Dict[str, Any], ) -> None: - raster_resolution: float = map_params["px_per_m"] - ( maps_path, vector_map_path, kdtrees_path, rtrees_path, - raster_map_path, - raster_metadata_path, ) = DataFrameCache.get_map_paths( - cache_path, vector_map.env_name, vector_map.map_name, raster_resolution - ) - - pbar_kwargs = {"position": 2, "leave": False, "disable": True} - rasterized_map: RasterizedMap = raster_utils.rasterize_map( - vector_map, raster_resolution, **pbar_kwargs + cache_path, vector_map.env_name, vector_map.map_name ) vector_map.compute_search_indices() @@ -790,13 +770,6 @@ def finalize_and_cache_map( with open(rtrees_path, "wb") as f: dill.dump(vector_map.search_rtrees, f) - # Saving the rasterized map data. - zarr.save(raster_map_path, rasterized_map.data) - - # Saving the rasterized map metadata. - with open(raster_metadata_path, "wb") as f: - dill.dump(rasterized_map.metadata, f) - def pad_map_patch( self, patch: np.ndarray, @@ -805,28 +778,7 @@ def pad_map_patch( patch_size: int, map_dims: Tuple[int, int, int], ) -> np.ndarray: - if patch.shape[-2:] == (patch_size, patch_size): - return patch - - top, bot, left, right = patch_sides - channels, height, width = map_dims - - # If we're off the map, just return zeros in the - # desired size of the patch. - if bot <= 0 or top >= height or right <= 0 or left >= width: - return np.zeros((channels, patch_size, patch_size)) - - pad_top, pad_bot, pad_left, pad_right = 0, 0, 0, 0 - if top < 0: - pad_top = 0 - top - if bot >= height: - pad_bot = bot - height - if left < 0: - pad_left = 0 - left - if right >= width: - pad_right = right - width - - return np.pad(patch, [(0, 0), (pad_top, pad_bot), (pad_left, pad_right)]) + raise NotImplementedError() def load_kdtrees(self) -> Dict[str, MapElementKDTree]: _, _, kdtrees_path, _, _, _ = DataFrameCache.get_map_paths( @@ -895,163 +847,3 @@ def get_rtrees(self, load_only_once: bool = True): else: return self._rtrees - - def load_map_patch( - self, - world_x: float, - world_y: float, - desired_patch_size: int, - resolution: float, - offset_xy: Tuple[float, float], - agent_heading: float, - return_rgb: bool, - rot_pad_factor: float = 1.0, - no_map_val: float = 0.0, - ) -> Tuple[np.ndarray, np.ndarray, bool]: - ( - maps_path, - _, - _, - _, - raster_map_path, - raster_metadata_path, - ) = DataFrameCache.get_map_paths( - self.path, self.scene.env_name, self.scene.location, resolution - ) - if not maps_path.exists(): - # This dataset (or location) does not have any maps, - # so we return an empty map. - patch_size: int = ceil((rot_pad_factor * desired_patch_size) / 2) * 2 - - return ( - np.full( - (1 if not return_rgb else 3, patch_size, patch_size), - fill_value=no_map_val, - ), - np.eye(3), - False, - ) - - with open(raster_metadata_path, "rb") as f: - map_info: RasterizedMapMetadata = dill.load(f) - - raster_from_world_tf: np.ndarray = map_info.map_from_world - map_coords: np.ndarray = map_info.map_from_world @ np.array( - [world_x, world_y, 1.0] - ) - map_x, map_y = map_coords[0].item(), map_coords[1].item() - - raster_from_world_tf = ( - np.array( - [ - [1.0, 0.0, -map_x], - [0.0, 1.0, -map_y], - [0.0, 0.0, 1.0], - ] - ) - @ raster_from_world_tf - ) - - # This first size is how much of the map we - # need to extract to match the requested metric size (meters x meters) of - # the patch. - data_patch_size: int = ceil( - desired_patch_size * map_info.resolution / resolution - ) - - # Incorporating offsets. - if offset_xy != (0.0, 0.0): - # x is negative here because I am moving the map - # center so that the agent ends up where the user wishes - # (the agent is pinned from the end user's perspective). - map_offset: Tuple[float, float] = ( - -offset_xy[0] * data_patch_size // 2, - offset_xy[1] * data_patch_size // 2, - ) - - rotated_offset: np.ndarray = ( - arr_utils.rotation_matrix(agent_heading) @ map_offset - ) - - off_x = rotated_offset[0] - off_y = rotated_offset[1] - - map_x += off_x - map_y += off_y - - raster_from_world_tf = ( - np.array( - [ - [1.0, 0.0, -off_x], - [0.0, 1.0, -off_y], - [0.0, 0.0, 1.0], - ] - ) - @ raster_from_world_tf - ) - - # This is the size of the patch taking into account expansion to allow for - # rotation to match the agent's heading. We also ensure the final size is - # divisible by two so that the // 2 below does not chop any information off. - data_with_rot_pad_size: int = ceil((rot_pad_factor * data_patch_size) / 2) * 2 - - disk_data = zarr.open_array(raster_map_path, mode="r") - - map_x = round(map_x) - map_y = round(map_y) - - # Half of the patch's side length. - half_extent: int = data_with_rot_pad_size // 2 - - top: int = map_y - half_extent - bot: int = map_y + half_extent - left: int = map_x - half_extent - right: int = map_x + half_extent - - data_patch: np.ndarray = self.pad_map_patch( - disk_data[ - ..., - max(top, 0) : min(bot, disk_data.shape[1]), - max(left, 0) : min(right, disk_data.shape[2]), - ], - (top, bot, left, right), - data_with_rot_pad_size, - disk_data.shape, - ) - - if return_rgb: - rgb_groups = map_info.layer_rgb_groups - data_patch = np.stack( - [ - np.amax(data_patch[rgb_groups[0]], axis=0), - np.amax(data_patch[rgb_groups[1]], axis=0), - np.amax(data_patch[rgb_groups[2]], axis=0), - ], - ) - - if desired_patch_size != data_patch_size: - scale_factor: float = desired_patch_size / data_patch_size - data_patch = ( - kornia.geometry.rescale( - torch.from_numpy(data_patch).unsqueeze(0), - scale_factor, - # Default align_corners value, just putting it to remove warnings - align_corners=False, - antialias=True, - ) - .squeeze(0) - .numpy() - ) - - raster_from_world_tf = ( - np.array( - [ - [1 / scale_factor, 0.0, 0.0], - [0.0, 1 / scale_factor, 0.0], - [0.0, 0.0, 1.0], - ] - ) - @ raster_from_world_tf - ) - - return data_patch, raster_from_world_tf, True diff --git a/src/trajdata/data_structures/batch.py b/src/trajdata/data_structures/batch.py index c93fb8b..c040b54 100644 --- a/src/trajdata/data_structures/batch.py +++ b/src/trajdata/data_structures/batch.py @@ -1,15 +1,19 @@ from __future__ import annotations -from dataclasses import dataclass +from dataclasses import dataclass, replace from typing import Dict, List, Optional, Union import torch from torch import Tensor - from trajdata.data_structures.agent import AgentType from trajdata.data_structures.state import StateTensor from trajdata.maps import VectorMap -from trajdata.utils.arr_utils import PadDirection +from trajdata.utils.arr_utils import ( + PadDirection, + batch_nd_transform_xyvvaahh_pt, + roll_with_tensor, + transform_xyh_torch, +) @dataclass @@ -34,18 +38,25 @@ class AgentBatch: neigh_fut: StateTensor neigh_fut_extents: Tensor neigh_fut_len: Tensor + agents_from_world_tf: Tensor + history_pad_dir: PadDirection + extras: Dict[str, Tensor] + vector_maps: Optional[List[VectorMap]] + rasters_from_world_tf: Optional[Tensor] + scene_ids: Optional[List] robot_fut: Optional[StateTensor] robot_fut_len: Optional[Tensor] + track_ids: Optional[List[List[str]]] map_names: Optional[List[str]] maps: Optional[Tensor] maps_resolution: Optional[Tensor] - vector_maps: Optional[List[VectorMap]] - rasters_from_world_tf: Optional[Tensor] - agents_from_world_tf: Tensor - scene_ids: Optional[List] - history_pad_dir: PadDirection - extras: Dict[str, Tensor] - + lane_xyh: Optional[Tensor] + lane_adj: Optional[Tensor] + lane_ids: Optional[List[List[str]]] + lane_mask: Optional[Tensor] + road_edge_xyzh: Optional[Tensor] + road_edge_xyzh: Optional[Tensor] = None + def to(self, device) -> None: excl_vals = { "data_idx", @@ -63,6 +74,7 @@ def to(self, device) -> None: "scene_ids", "history_pad_dir", "extras", + "lane_ids", } for val in vars(self).keys(): tensor_val = getattr(self, val) @@ -131,22 +143,39 @@ def filter_batch(self, filter_mask: torch.Tensor) -> AgentBatch: neigh_fut_extents=_filter(self.neigh_fut_extents), neigh_fut_len=_filter(self.neigh_fut_len), robot_fut=_filter(self.robot_fut) if self.robot_fut is not None else None, - robot_fut_len=_filter(self.robot_fut_len) - if self.robot_fut_len is not None - else None, - map_names=_filter_tensor_or_list(self.map_names) - if self.map_names is not None - else None, + robot_fut_len=( + _filter(self.robot_fut_len) if self.robot_fut_len is not None else None + ), + map_names=( + _filter_tensor_or_list(self.map_names) + if self.map_names is not None + else None + ), maps=_filter(self.maps) if self.maps is not None else None, - maps_resolution=_filter(self.maps_resolution) - if self.maps_resolution is not None - else None, - vector_maps=_filter(self.vector_maps) - if self.vector_maps is not None - else None, - rasters_from_world_tf=_filter(self.rasters_from_world_tf) - if self.rasters_from_world_tf is not None - else None, + maps_resolution=( + _filter(self.maps_resolution) + if self.maps_resolution is not None + else None + ), + vector_maps=( + _filter_tensor_or_list(self.vector_maps) + if self.vector_maps is not None + else None + ), + rasters_from_world_tf=( + _filter(self.rasters_from_world_tf) + if self.rasters_from_world_tf is not None + else None + ), + lane_xyh=_filter(self.lane_xyh) if self.lane_xyh is not None else None, + lane_adj=_filter(self.lane_adj) if self.lane_adj is not None else None, + lane_ids=self.lane_ids, + lane_mask=_filter(self.lane_mask) if self.lane_mask is not None else None, + road_edge_xyzh=( + _filter(self.road_edge_xyzh) + if self.road_edge_xyzh is not None + else None + ), agents_from_world_tf=_filter(self.agents_from_world_tf), scene_ids=_filter_tensor_or_list(self.scene_ids), history_pad_dir=self.history_pad_dir, @@ -155,6 +184,51 @@ def filter_batch(self, filter_mask: torch.Tensor) -> AgentBatch: }, ) + def to_scene_batch(self, agent_ind: int) -> SceneBatch: + """ + Converts AgentBatch to SeceneBatch by combining neighbors and agent. + + The agent of AgentBatch will be treated as if it was the last neighbor. + self.extras will be simply copied over, any custom conversion must be + implemented externally. + """ + + batch_size = self.neigh_hist.shape[0] + num_neigh = self.neigh_hist.shape[1] + + combine = lambda neigh, agent: torch.cat((neigh, agent.unsqueeze(0)), dim=0) + combine_list = lambda neigh, agent: neigh + [agent] + + return SceneBatch( + data_idx=self.data_idx, + scene_ts=self.scene_ts, + dt=self.dt, + num_agents=self.num_neigh + 1, + agent_type=combine(self.neigh_types, self.agent_type), + centered_agent_state=self.curr_agent_state, # TODO this is not actually the agent but the `global` coordinate frame + agent_names=combine_list( + ["UNKNOWN" for _ in range(num_neigh)], self.agent_name + ), + agent_hist=combine(self.neigh_hist, self.agent_hist), + agent_hist_extent=combine(self.neigh_hist_extents, self.agent_hist_extent), + agent_hist_len=combine(self.neigh_hist_len, self.agent_hist_len), + agent_fut=combine(self.neigh_fut, self.agent_fut), + agent_fut_extent=combine(self.neigh_fut_extents, self.agent_fut_extent), + agent_fut_len=combine(self.neigh_fut_len, self.agent_fut_len), + robot_fut=self.robot_fut, + robot_fut_len=self.robot_fut_len, + map_names=self.map_names, # TODO + maps=self.maps, + maps_resolution=self.maps_resolution, + vector_maps=self.vector_maps, + rasters_from_world_tf=self.rasters_from_world_tf, + centered_agent_from_world_tf=self.agents_from_world_tf, + centered_world_from_agent_tf=torch.linalg.inv(self.agents_from_world_tf), + scene_ids=self.scene_ids, + history_pad_dir=self.history_pad_dir, + extras=self.extras, + ) + @dataclass class SceneBatch: @@ -164,34 +238,48 @@ class SceneBatch: num_agents: Tensor agent_type: Tensor centered_agent_state: StateTensor - agent_names: List[str] + agent_names: List[List[str]] + track_ids: Optional[List[List[str]]] agent_hist: StateTensor agent_hist_extent: Tensor agent_hist_len: Tensor agent_fut: StateTensor agent_fut_extent: Tensor agent_fut_len: Tensor + centered_agent_from_world_tf: Tensor + centered_world_from_agent_tf: Tensor + history_pad_dir: PadDirection + vector_maps: Optional[List[VectorMap]] + lane_ids: Optional(List[List[str]]) robot_fut: Optional[StateTensor] robot_fut_len: Optional[Tensor] map_names: Optional[Tensor] maps: Optional[Tensor] maps_resolution: Optional[Tensor] - vector_maps: Optional[List[VectorMap]] + lane_xyh: Optional[Tensor] + lane_adj: Optional[Tensor] + lane_mask: Optional[Tensor] + road_edge_xyzh: Optional[Tensor] rasters_from_world_tf: Optional[Tensor] - centered_agent_from_world_tf: Tensor - centered_world_from_agent_tf: Tensor scene_ids: Optional[List] - history_pad_dir: PadDirection + extras: Dict[str, Tensor] def to(self, device) -> None: excl_vals = { + "num_agents", "agent_names", + "track_ids", + "agent_type", + "agent_hist_len", + "agent_fut_len", + "robot_fut_len", "map_names", "vector_maps", "history_pad_dir", "scene_ids", "extras", + "lane_ids", } for val in vars(self).keys(): @@ -205,6 +293,31 @@ def to(self, device) -> None: self.extras[key] = val.__to__(device, non_blocking=True) else: self.extras[key] = val.to(device, non_blocking=True) + return self + + def astype(self, dtype) -> None: + new_obj = replace(self) + excl_vals = { + "num_agents", + "agent_names", + "track_ids", + "agent_type", + "agent_hist_len", + "agent_fut_len", + "robot_fut_len", + "map_names", + "vector_maps", + "history_pad_dir", + "scene_ids", + "extras", + "lane_ids", + } + + for val in vars(self).keys(): + tensor_val = getattr(self, val) + if val not in excl_vals and tensor_val is not None: + setattr(new_obj, val, tensor_val.type(dtype)) + return new_obj def agent_types(self) -> List[AgentType]: unique_types: Tensor = torch.unique(self.agent_type) @@ -214,13 +327,33 @@ def agent_types(self) -> List[AgentType]: if unique_type >= 0 ] - def for_agent_type(self, agent_type: AgentType) -> SceneBatch: - match_type = self.agent_type == agent_type - return self.filter_batch(match_type) + def copy(self): + # Shallow copy + return replace(self) - def filter_batch(self, filter_mask: torch.tensor) -> SceneBatch: + def convert_pad_direction(self, pad_dir: PadDirection) -> SceneBatch: + if self.history_pad_dir == pad_dir: + return self + batch: SceneBatch = self.copy() + if self.history_pad_dir == PadDirection.BEFORE: + # n, n, -2 , -1, 0 --> -2, -1, 0, n, n + shifts = batch.agent_hist_len + else: + # -2, -1, 0, n, n --> n, n, -2 , -1, 0 + shifts = -batch.agent_hist_len + batch.agent_hist = roll_with_tensor(batch.agent_hist, shifts, dim=-2) + batch.agent_hist_extent = roll_with_tensor( + batch.agent_hist_extent, shifts, dim=-2 + ) + batch.history_pad_dir = pad_dir + return batch + + def filter_batch(self, filter_mask: torch.Tensor) -> SceneBatch: """Build a new batch with elements for which filter_mask[i] == True.""" + if filter_mask.ndim != 1: + raise ValueError("Expected 1d filter mask.") + # Some of the tensors might be on different devices, so we define some convenience functions # to make sure the filter_mask is always on the same device as the tensor we are indexing. filter_mask_dict = {} @@ -229,7 +362,10 @@ def filter_batch(self, filter_mask: torch.tensor) -> SceneBatch: self.agent_hist.device ) - _filter = lambda tensor: tensor[filter_mask_dict[str(tensor.device)]] + # Use tensor.__class__ to keep TensorState. + _filter = lambda tensor: tensor.__class__( + tensor[filter_mask_dict[str(tensor.device)]] + ) _filter_tensor_or_list = lambda tensor_or_list: ( _filter(tensor_or_list) if isinstance(tensor_or_list, torch.Tensor) @@ -248,6 +384,8 @@ def filter_batch(self, filter_mask: torch.tensor) -> SceneBatch: dt=_filter(self.dt), num_agents=_filter(self.num_agents), agent_type=_filter(self.agent_type), + agent_names=_filter_tensor_or_list(self.agent_names), + track_ids=_filter_tensor_or_list(self.track_ids), centered_agent_state=_filter(self.centered_agent_state), agent_hist=_filter(self.agent_hist), agent_hist_extent=_filter(self.agent_hist_extent), @@ -256,29 +394,45 @@ def filter_batch(self, filter_mask: torch.tensor) -> SceneBatch: agent_fut_extent=_filter(self.agent_fut_extent), agent_fut_len=_filter(self.agent_fut_len), robot_fut=_filter(self.robot_fut) if self.robot_fut is not None else None, - robot_fut_len=_filter(self.robot_fut_len) - if self.robot_fut_len is not None - else None, - map_names=_filter_tensor_or_list(self.map_names) - if self.map_names is not None - else None, + robot_fut_len=( + _filter(self.robot_fut_len) if self.robot_fut_len is not None else None + ), + map_names=( + _filter_tensor_or_list(self.map_names) + if self.map_names is not None + else None + ), maps=_filter(self.maps) if self.maps is not None else None, - maps_resolution=_filter(self.maps_resolution) - if self.maps_resolution is not None - else None, - vector_maps=_filter(self.vector_maps) - if self.vector_maps is not None - else None, - rasters_from_world_tf=_filter(self.rasters_from_world_tf) - if self.rasters_from_world_tf is not None - else None, + maps_resolution=( + _filter(self.maps_resolution) + if self.maps_resolution is not None + else None + ), + vector_maps=( + _filter_tensor_or_list(self.vector_maps) + if self.vector_maps is not None + else None + ), + lane_xyh=_filter(self.lane_xyh) if self.lane_xyh is not None else None, + lane_adj=_filter(self.lane_adj) if self.lane_adj is not None else None, + lane_ids=self.lane_ids, + lane_mask=_filter(self.lane_mask) if self.lane_mask is not None else None, + road_edge_xyzh=( + _filter(self.road_edge_xyzh) + if self.road_edge_xyzh is not None + else None + ), + rasters_from_world_tf=( + _filter(self.rasters_from_world_tf) + if self.rasters_from_world_tf is not None + else None + ), centered_agent_from_world_tf=_filter(self.centered_agent_from_world_tf), centered_world_from_agent_tf=_filter(self.centered_world_from_agent_tf), scene_ids=_filter_tensor_or_list(self.scene_ids), history_pad_dir=self.history_pad_dir, extras={ - key: _filter_tensor_or_list(val, filter_mask) - for key, val in self.extras.items() + key: _filter_tensor_or_list(val) for key, val in self.extras.items() }, ) @@ -303,10 +457,8 @@ def to_agent_batch(self, agent_inds: torch.Tensor) -> AgentBatch: others_mask = torch.ones((batch_size, num_agents), dtype=torch.bool) others_mask[batch_inds, agent_inds] = False index_agent = lambda x: x[batch_inds, agent_inds] if x is not None else None - index_agent_list = ( - lambda xlist: [x[ind] for x, ind in zip(xlist, agent_inds)] - if xlist is not None - else None + index_agent_list = lambda xlist: ( + [x[ind] for x, ind in zip(xlist, agent_inds)] if xlist is not None else None ) index_neighbors = lambda x: x[others_mask].reshape( [ @@ -321,31 +473,105 @@ def to_agent_batch(self, agent_inds: torch.Tensor) -> AgentBatch: scene_ts=self.scene_ts, dt=self.dt, agent_name=index_agent_list(self.agent_names), + track_ids=index_agent_list(self.track_ids), agent_type=index_agent(self.agent_type), curr_agent_state=self.centered_agent_state, # TODO this is not actually the agent but the `global` coordinate frame - agent_hist=index_agent(self.agent_hist), + agent_hist=StateTensor.from_array( + index_agent(self.agent_hist), self.agent_hist._format + ), agent_hist_extent=index_agent(self.agent_hist_extent), agent_hist_len=index_agent(self.agent_hist_len), - agent_fut=index_agent(self.agent_fut), + agent_fut=StateTensor.from_array( + index_agent(self.agent_fut), self.agent_fut._format + ), agent_fut_extent=index_agent(self.agent_fut_extent), agent_fut_len=index_agent(self.agent_fut_len), num_neigh=self.num_agents - 1, neigh_types=index_neighbors(self.agent_type), - neigh_hist=index_neighbors(self.agent_hist), + neigh_hist=StateTensor.from_array( + index_neighbors(self.agent_hist), self.agent_hist._format + ), neigh_hist_extents=index_neighbors(self.agent_hist_extent), neigh_hist_len=index_neighbors(self.agent_hist_len), - neigh_fut=index_neighbors(self.agent_fut), + neigh_fut=StateTensor.from_array( + index_neighbors(self.agent_fut), self.agent_fut._format + ), neigh_fut_extents=index_neighbors(self.agent_fut_extent), neigh_fut_len=index_neighbors(self.agent_fut_len), robot_fut=self.robot_fut, robot_fut_len=self.robot_fut_len, - map_names=index_agent_list(self.map_names), - maps=index_agent(self.maps), - vector_maps=index_agent(self.vector_maps), - maps_resolution=index_agent(self.maps_resolution), - rasters_from_world_tf=index_agent(self.rasters_from_world_tf), + map_names=self.map_names, + maps=self.maps, + vector_maps=self.vector_maps, + lane_xyh=self.lane_xyh, + lane_adj=self.lane_adj, + lane_ids=self.lane_ids, + lane_mask=self.lane_mask, + road_edge_xyzh=self.road_edge_xyzh, + maps_resolution=self.maps_resolution, + rasters_from_world_tf=self.rasters_from_world_tf, agents_from_world_tf=self.centered_agent_from_world_tf, scene_ids=self.scene_ids, history_pad_dir=self.history_pad_dir, extras=self.extras, ) + + def apply_transform( + self, tf: torch.Tensor, dtype: Optional[torch.dtype] = None + ) -> SceneBatch: + """ + Applies a transformation matrix to all coordinates stored in the SceneBatch. + + Returns a shallow copy, only coordinate fields are replaced. + self.extras will be simply copied over (shallow copy), any custom conversion must be + implemented externally. + """ + assert tf.ndim == 3 # b, 3, 3 + assert tf.shape[-1] == 3 and tf.shape[-1] == 3 + assert ( + tf.dtype == torch.double + ) # tf should be double precision, otherwise we have large numerical errors + if dtype is None: + dtype = self.agent_hist.dtype + + # Shallow copy + batch: SceneBatch = replace(self) + + # TODO support generic format + assert batch.agent_hist._format == "x,y,xd,yd,xdd,ydd,s,c" + assert batch.agent_fut._format == "x,y,xd,yd,xdd,ydd,s,c" + state_class = batch.agent_hist.__class__ + + # Transforms + batch.agent_hist = state_class( + batch_nd_transform_xyvvaahh_pt(batch.agent_hist.double(), tf).type(dtype) + ) + batch.agent_fut = state_class( + batch_nd_transform_xyvvaahh_pt(batch.agent_fut.double(), tf).type(dtype) + ) + batch.rasters_from_world_tf = ( + tf.unsqueeze(1) @ batch.rasters_from_world_tf + if batch.rasters_from_world_tf is not None + else None + ) + batch.centered_agent_from_world_tf = tf @ batch.centered_agent_from_world_tf + centered_world_from_agent_tf = torch.linalg.inv( + batch.centered_agent_from_world_tf + ) + if batch.lane_xyh is not None: + batch.lane_xyh = transform_xyh_torch(batch.lane_xyh.double(), tf).type( + dtype + ) + if batch.road_edge_xyzh is not None: + batch.road_edge_xyzh = transform_xyh_torch( + batch.road_edge_xyzh.double(), tf + ).type(dtype) + # sanity check + assert torch.isclose( + batch.centered_world_from_agent_tf @ torch.linalg.inv(tf), + centered_world_from_agent_tf, + atol=1e-5, + ).all() + batch.centered_world_from_agent_tf = centered_world_from_agent_tf + + return batch diff --git a/src/trajdata/data_structures/batch_element.py b/src/trajdata/data_structures/batch_element.py index cf61764..b0f0381 100644 --- a/src/trajdata/data_structures/batch_element.py +++ b/src/trajdata/data_structures/batch_element.py @@ -3,13 +3,21 @@ from typing import Any, Callable, Dict, List, Optional, Tuple import numpy as np - from trajdata.caching import SceneCache from trajdata.data_structures.agent import AgentMetadata, AgentType from trajdata.data_structures.scene import SceneTime, SceneTimeAgent from trajdata.data_structures.state import StateArray from trajdata.maps import MapAPI, RasterizedMapPatch, VectorMap +from trajdata.utils.arr_utils import ( + get_close_lanes, + get_close_road_edges, + transform_xyh_np, +) +from trajdata.utils.map_utils import LaneSegRelation from trajdata.utils.state_utils import convert_to_frame_state, transform_from_frame +from trajdata.utils.arr_utils import transform_xyh_np, get_close_lanes + +from trajdata.utils.map_utils import LaneSegRelation class AgentBatchElement: @@ -35,6 +43,7 @@ def __init__( standardize_data: bool = False, standardize_derivatives: bool = False, max_neighbor_num: Optional[int] = None, + lane_graph_cache: Optional[dict] = None, ) -> None: self.cache: SceneCache = cache self.data_index: int = data_index @@ -54,6 +63,8 @@ def __init__( else: self.curr_agent_state_np = raw_state + incl_z = self.curr_agent_state_np.has_attr("position3d") + self.standardize_data = standardize_data if self.standardize_data: # Request cache to return observations relative to current agent @@ -68,16 +79,26 @@ def __init__( agent_pos = self.curr_agent_state_np.position agent_heading_vector = self.curr_agent_state_np.heading_vector cos_agent, sin_agent = agent_heading_vector[0], agent_heading_vector[1] - world_from_agent_tf: np.ndarray = np.array( - [ - [cos_agent, -sin_agent, agent_pos[0]], - [sin_agent, cos_agent, agent_pos[1]], - [0.0, 0.0, 1.0], - ] - ) + if incl_z: + world_from_agent_tf: np.ndarray = np.array( + [ + [cos_agent, -sin_agent, 0.0, agent_pos[0]], + [sin_agent, cos_agent, 0.0, agent_pos[1]], + [0.0, 0.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 1.0], + ] + ) + else: + world_from_agent_tf: np.ndarray = np.array( + [ + [cos_agent, -sin_agent, agent_pos[0]], + [sin_agent, cos_agent, agent_pos[1]], + [0.0, 0.0, 1.0], + ] + ) self.agent_from_world_tf: np.ndarray = np.linalg.inv(world_from_agent_tf) else: - self.agent_from_world_tf: np.ndarray = np.eye(3) + self.agent_from_world_tf: np.ndarray = np.eye(4 if incl_z else 3) ### AGENT-SPECIFIC DATA ### self.agent_history_np, self.agent_history_extent_np = self.get_agent_history( @@ -149,15 +170,55 @@ def distance_limit(agent_types: np.ndarray, target_type: int) -> np.ndarray: if map_api is not None: self.vec_map = map_api.get_map( map_name, - self.cache - if self.cache.is_traffic_light_data_cached( - # Is the original dt cached? If so, we can continue by - # interpolating time to get whatever the user desires. - self.cache.scene.env_metadata.dt - ) - else None, + ( + self.cache + if self.cache.is_traffic_light_data_cached( + # Is the original dt cached? If so, we can continue by + # interpolating time to get whatever the user desires. + self.cache.scene.env_metadata.dt + ) + else None + ), **vector_map_params if vector_map_params is not None else None, ) + if vector_map_params.get("calc_lane_graph", False): + # not tested + ego_xyh = np.concatenate( + [ + self.curr_agent_state_np.position, + self.curr_agent_state_np.heading, + ] + ) + num_pts = vector_map_params.get("num_lane_pts", 30) + max_num_lanes = vector_map_params.get("max_num_lanes", 20) + remove_single_successor = vector_map_params.get( + "remove_single_successor", False + ) + radius = vector_map_params.get("radius", 100) + ( + self.num_lanes, + self.lane_xyh, + self.lane_adj, + self.lane_ids, + self.road_edge_xyzh, + ) = gen_lane_graph( + self.vec_map, + ego_xyh, + self.agent_from_world_tf, + num_pts, + max_num_lanes, + radius, + remove_single_successor=remove_single_successor, + get_road_edges=vector_map_params.get("incl_road_edges", False), + lane_graph_cache=lane_graph_cache, + ) + + else: + self.lane_xyh = None + self.lane_adj = None + self.lane_ids = list() + self.num_lanes = 0 + self.road_edge_xyzh = None self.scene_id = scene_time_agent.scene.name @@ -271,47 +332,7 @@ def get_robot_current_and_future( return robot_curr_and_fut_np def get_agent_map_patch(self, patch_params: Dict[str, int]) -> RasterizedMapPatch: - world_x, world_y = self.curr_agent_state_np.position - desired_patch_size: int = patch_params["map_size_px"] - resolution: float = patch_params["px_per_m"] - offset_xy: Tuple[float, float] = patch_params.get("offset_frac_xy", (0.0, 0.0)) - return_rgb: bool = patch_params.get("return_rgb", True) - no_map_fill_val: float = patch_params.get("no_map_fill_value", 0.0) - - if self.standardize_data: - heading = self.curr_agent_state_np.heading[0] - patch_data, raster_from_world_tf, has_data = self.cache.load_map_patch( - world_x, - world_y, - desired_patch_size, - resolution, - offset_xy, - heading, - return_rgb, - rot_pad_factor=sqrt(2), - no_map_val=no_map_fill_val, - ) - else: - heading = 0.0 - patch_data, raster_from_world_tf, has_data = self.cache.load_map_patch( - world_x, - world_y, - desired_patch_size, - resolution, - offset_xy, - heading, - return_rgb, - no_map_val=no_map_fill_val, - ) - - return RasterizedMapPatch( - data=patch_data, - rot_angle=heading, - crop_size=desired_patch_size, - resolution=resolution, - raster_from_world_tf=raster_from_world_tf, - has_data=has_data, - ) + raise NotImplementedError() class SceneBatchElement: @@ -336,6 +357,7 @@ def __init__( standardize_data: bool = False, standardize_derivatives: bool = False, max_agent_num: Optional[int] = None, + lane_graph_cache: Optional[dict] = None, ) -> None: self.cache: SceneCache = cache self.data_index = data_index @@ -362,6 +384,8 @@ def __init__( else: self.centered_agent_state_np = raw_state + incl_z = self.centered_agent_state_np.has_attr("position3d") + self.standardize_data = standardize_data if self.standardize_data: @@ -378,19 +402,30 @@ def __init__( agent_heading: float = self.centered_agent_state_np.heading[0] cos_agent, sin_agent = np.cos(agent_heading), np.sin(agent_heading) - self.centered_world_from_agent_tf: np.ndarray = np.array( - [ - [cos_agent, -sin_agent, agent_pos[0]], - [sin_agent, cos_agent, agent_pos[1]], - [0.0, 0.0, 1.0], - ] - ) + + if incl_z: + self.centered_world_from_agent_tf: np.ndarray = np.array( + [ + [cos_agent, -sin_agent, 0.0, agent_pos[0]], + [sin_agent, cos_agent, 0.0, agent_pos[1]], + [0.0, 0.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 1.0], + ] + ) + else: + self.centered_world_from_agent_tf: np.ndarray = np.array( + [ + [cos_agent, -sin_agent, agent_pos[0]], + [sin_agent, cos_agent, agent_pos[1]], + [0.0, 0.0, 1.0], + ] + ) self.centered_agent_from_world_tf: np.ndarray = np.linalg.inv( self.centered_world_from_agent_tf ) else: - self.centered_agent_from_world_tf: np.ndarray = np.eye(3) - self.centered_world_from_agent_tf: np.ndarray = np.eye(3) + self.centered_agent_from_world_tf: np.ndarray = np.eye(4 if incl_z else 3) + self.centered_world_from_agent_tf: np.ndarray = np.eye(4 if incl_z else 3) ### NEIGHBOR-SPECIFIC DATA ### def distance_limit(agent_types: np.ndarray, target_type: int) -> np.ndarray: @@ -404,7 +439,7 @@ def distance_limit(agent_types: np.ndarray, target_type: int) -> np.ndarray: nearby_agents, self.agent_types_np = self.get_nearby_agents( scene_time, self.centered_agent, distance_limit ) - + self.agents = nearby_agents self.num_agents = len(nearby_agents) self.agent_names = [agent.name for agent in nearby_agents] ( @@ -440,6 +475,44 @@ def distance_limit(agent_types: np.ndarray, target_type: int) -> np.ndarray: self.cache if self.cache.is_traffic_light_data_cached() else None, **vector_map_params if vector_map_params is not None else None, ) + if vector_map_params.get("calc_lane_graph", False): + # not tested + ego_xyh = np.concatenate( + [ + self.centered_agent_state_np.position, + self.centered_agent_state_np.heading, + ] + ) + num_pts = vector_map_params.get("num_lane_pts", 30) + max_num_lanes = vector_map_params.get("max_num_lanes", 20) + remove_single_successor = vector_map_params.get( + "remove_single_successor", False + ) + ( + self.num_lanes, + self.lane_xyh, + self.lane_adj, + self.lane_ids, + self.road_edge_xyzh, + ) = gen_lane_graph( + self.vec_map, + ego_xyh, + self.centered_agent_from_world_tf, + num_pts, + max_num_lanes, + remove_single_successor=remove_single_successor, + get_road_edges=vector_map_params.get("incl_road_edges", False), + lane_graph_cache=lane_graph_cache, + ) + + else: + self.lane_xyh = None + self.lane_adj = None + self.lane_ids = list() + self.num_lanes = 0 + self.road_edge_xyzh = None + + self.scene_id = scene_time.scene.name ### ROBOT DATA ### self.robot_future_np: Optional[StateArray] = None @@ -583,6 +656,132 @@ def get_robot_current_and_future( return robot_curr_and_fut_np +def gen_lane_graph( + vec_map, + ego_xyh, + agent_from_world, + num_pts=20, + max_num_lanes=15, + radius=150, + get_road_edges=False, + remove_single_successor=False, + lane_graph_cache=None, +): + close_lanes, dis = get_close_lanes(radius, ego_xyh, vec_map, num_pts) + lanes_by_id = {lane.id: lane for lane in close_lanes} + dis_by_id = {lane.id: dis[i] for i, lane in enumerate(close_lanes)} + if lane_graph_cache is not None: + idx = np.argsort(dis)[:max_num_lanes] + lane_ids = [close_lanes[i].id for i in idx] + num_lanes = len(lane_ids) + cache_idx = [lane_graph_cache.lane_ids.index(id) for id in lane_ids] + lane_xyh = lane_graph_cache.lane_centerlines[cache_idx] + lane_adj = lane_graph_cache.lane_connectivity[cache_idx][:, cache_idx] + lane_xyh = transform_xyh_np( + lane_xyh.reshape(-1, 3), agent_from_world[None] + ).reshape(num_lanes, -1, 3) + road_edge_xyzh = None + return num_lanes, lane_xyh, lane_adj, lane_ids, road_edge_xyzh + + if remove_single_successor: + for lane in close_lanes: + while len(lane.next_lanes) == 1: + # if there are more than one succeeding lanes, then we abort the merging + next_id = list(lane.next_lanes)[0] + + if next_id in lanes_by_id: + next_lane = lanes_by_id[next_id] + shared_next = False + for id in next_lane.prev_lanes: + if id != lane.id and id in lanes_by_id: + shared_next = True + break + if shared_next: + # if the next lane shares two prev lanes in the close_lanes, then we abort the merging + break + lane.combine_next(lanes_by_id[next_id]) + dis_by_id[lane.id] = min(dis_by_id[lane.id], dis_by_id[next_id]) + lanes_by_id.pop(next_id) + else: + break + close_lanes = list(lanes_by_id.values()) + dis = np.array([dis_by_id[lane.id] for lane in close_lanes]) + num_lanes = len(close_lanes) + if num_lanes > max_num_lanes: + idx = dis.argsort()[:max_num_lanes] + close_lanes = [lane for i, lane in enumerate(close_lanes) if i in idx] + num_lanes = max_num_lanes + + if num_lanes > 0: + lane_xyh = list() + lane_adj = np.zeros([len(close_lanes), len(close_lanes)], dtype=np.int32) + lane_ids = [lane.id for lane in close_lanes] + + for i, lane in enumerate(close_lanes): + center = lane.center.interpolate(num_pts).points[:, [0, 1, 3]] + center_local = transform_xyh_np( + # Add pts dimension and select x, y and homogeneous dimension + center, + agent_from_world[None][..., [0, 1, -1]][..., [0, 1, -1], :], + ) + lane_xyh.append(center_local) + # construct lane adjacency matrix + for adj_lane_id in lane.next_lanes: + if adj_lane_id in lane_ids: + lane_adj[ + i, lane_ids.index(adj_lane_id) + ] = LaneSegRelation.NEXT.value + + for adj_lane_id in lane.prev_lanes: + if adj_lane_id in lane_ids: + lane_adj[ + i, lane_ids.index(adj_lane_id) + ] = LaneSegRelation.PREV.value + + for adj_lane_id in lane.adj_lanes_left: + if adj_lane_id in lane_ids: + lane_adj[ + i, lane_ids.index(adj_lane_id) + ] = LaneSegRelation.LEFT.value + + for adj_lane_id in lane.adj_lanes_right: + if adj_lane_id in lane_ids: + lane_adj[ + i, lane_ids.index(adj_lane_id) + ] = LaneSegRelation.RIGHT.value + lane_xyh = np.stack(lane_xyh, axis=0) + lane_xyh = lane_xyh + lane_adj = lane_adj + else: + lane_xyh = np.zeros([0, num_pts, 3]) + lane_adj = np.zeros([0, 0]) + lane_ids = list() + + road_edge_xyzh = None + if get_road_edges: + close_road_edges, re_dis = get_close_road_edges( + radius, ego_xyh, vec_map, num_pts + ) + num_road_edges = len(close_road_edges) + if num_road_edges > max_num_lanes: + idx = re_dis.argsort()[:max_num_lanes] + close_road_edges = [ + road_edge for i, road_edge in enumerate(close_road_edges) if i in idx + ] + num_road_edges = max_num_lanes + + if num_road_edges > 0: + road_edge_xyzh = list() + for i, road_edge in enumerate(close_road_edges): + polyline = road_edge.polyline.interpolate(num_pts).points + # TODO: What to do when `agent_from_world` doesn't have z coord? + polyline_local = transform_xyh_np(polyline, agent_from_world[None]) + road_edge_xyzh.append(polyline_local) + road_edge_xyzh = np.stack(road_edge_xyzh, axis=0) + + return num_lanes, lane_xyh, lane_adj, lane_ids, road_edge_xyzh + + def is_agent_stationary(cache: SceneCache, agent_info: AgentMetadata) -> bool: # Agent is considered stationary if it moves less than 1m between the first and last valid timestep. first_state: StateArray = cache.get_state( diff --git a/src/trajdata/data_structures/collation.py b/src/trajdata/data_structures/collation.py index f08b3b6..6729847 100644 --- a/src/trajdata/data_structures/collation.py +++ b/src/trajdata/data_structures/collation.py @@ -4,10 +4,8 @@ import numpy as np import torch import torch.nn.functional as F -from kornia.geometry.transform import rotate from torch import Tensor from torch.nn.utils.rnn import pad_sequence - from trajdata.augmentation import BatchAugmentation from trajdata.data_structures.batch import AgentBatch, SceneBatch from trajdata.data_structures.batch_element import AgentBatchElement, SceneBatchElement @@ -34,250 +32,53 @@ def _collate_data(elems): else: return torch.as_tensor(np.stack(elems)) - -def raster_map_collate_fn_agent( - batch_elems: List[AgentBatchElement], -): - if batch_elems[0].map_patch is None: - return None, None, None, None - - map_names = [batch_elem.map_name for batch_elem in batch_elems] - - # Ensuring that any empty map patches have the correct number of channels - # prior to collation. - has_data: np.ndarray = np.array( - [batch_elem.map_patch.has_data for batch_elem in batch_elems], - dtype=bool, - ) - no_data: np.ndarray = ~has_data - - patch_channels: np.ndarray = np.array( - [batch_elem.map_patch.data.shape[0] for batch_elem in batch_elems], - dtype=int, - ) - - desired_num_channels: int - if np.any(has_data): - # If any of the batch elements' maps have data, then use - # their number of channels as the reference. - unique_num_channels = np.unique(patch_channels[has_data]) - else: - # All map patches in this batch are from datasets with no maps. - unique_num_channels = np.unique(patch_channels) - - if unique_num_channels.size > 1: - raise ValueError( - "Maps must all have the same number of channels in a batch, " - f"but found maps with {unique_num_channels.tolist()} channels." - ) - - desired_num_channels = unique_num_channels[0].item() - - # Getting the map patch data and preparing it for batched rotation. - patch_size_y, patch_size_x = batch_elems[0].map_patch.data.shape[-2:] - patch_data: Tensor = torch.empty( - (len(batch_elems), desired_num_channels, patch_size_y, patch_size_x) - ) - - if np.any(has_data): - patch_data[has_data] = torch.as_tensor( - np.stack( - [ - batch_elem.map_patch.data - for idx, batch_elem in enumerate(batch_elems) - if has_data[idx] - ] - ), - dtype=torch.float, - ) - - if np.any(no_data): - patch_data[no_data] = torch.as_tensor( - np.stack( - [ - batch_elem.map_patch.data - for idx, batch_elem in enumerate(batch_elems) - if no_data[idx] - ] - ), - dtype=torch.float, - ).expand(-1, desired_num_channels, -1, -1) - - patch_size: int = batch_elems[0].map_patch.crop_size - assert all( - batch_elem.map_patch.crop_size == patch_size for batch_elem in batch_elems - ) - - rot_angles: Tensor = torch.as_tensor( - [batch_elem.map_patch.rot_angle for batch_elem in batch_elems], - dtype=torch.float, - ) - resolution: Tensor = torch.as_tensor( - [batch_elem.map_patch.resolution for batch_elem in batch_elems], - dtype=torch.float, - ) - rasters_from_world_tf: Tensor = torch.as_tensor( - np.stack( - [batch_elem.map_patch.raster_from_world_tf for batch_elem in batch_elems] - ), - dtype=torch.float, - ) - - center_y: int = patch_size_y // 2 - center_x: int = patch_size_x // 2 - half_extent: int = patch_size // 2 - - if ( - torch.count_nonzero(rot_angles) == 0 - and patch_size == patch_data.shape[-1] == patch_data.shape[-2] - ): - rasters_from_world_tf = torch.bmm( - torch.tensor( - [ - [ - [1.0, 0.0, half_extent], - [0.0, 1.0, half_extent], - [0.0, 0.0, 1.0], - ] - ], - dtype=rasters_from_world_tf.dtype, - device=rasters_from_world_tf.device, - ).expand((rasters_from_world_tf.shape[0], -1, -1)), - rasters_from_world_tf, - ) - - rot_crop_patches: Tensor = patch_data - +def _collate_lane_graph(elems): + num_lanes = [elem.num_lanes for elem in elems] + bs = len(elems) + M = max(num_lanes) + lane_xyh = np.zeros([bs,M,*elems[0].lane_xyh.shape[-2:]]) + lane_adj = np.zeros([bs,M,M],dtype=int) + lane_ids = list() + lane_mask = np.zeros([bs, M], dtype=int) + for i,elem in enumerate(elems): + lane_xyh[i,:num_lanes[i]] = elem.lane_xyh + lane_adj[i,:num_lanes[i],:num_lanes[i]] = elem.lane_adj + lane_ids.append(elem.lane_ids) + lane_mask[i,:num_lanes[i]] = 1 + + if elems[0].road_edge_xyzh is not None: + assert ( + elems[0].road_edge_xyzh.shape[-1] == 4 + ), "Road edge data must have 4 dimensions: x, y, z, heading. " + num_road_edges = [elem.road_edge_xyzh.shape[0] for elem in elems] + N = max(num_road_edges) + road_edge_xyzh = np.zeros([bs, N, *elems[0].road_edge_xyzh.shape[-2:]]) + for i, elem in enumerate(elems): + road_edge_xyzh[i, : num_road_edges[i]] = elem.road_edge_xyzh else: - # Batch rotating patches by rot_angles. - rot_patches: Tensor = rotate(patch_data, torch.rad2deg(rot_angles)) - - # Center cropping via slicing. - rot_crop_patches: Tensor = rot_patches[ - ..., - center_y - half_extent : center_y + half_extent, - center_x - half_extent : center_x + half_extent, - ] - - rasters_from_world_tf = torch.bmm( - arr_utils.transform_matrices( - -rot_angles, - torch.tensor([[half_extent, half_extent]]).expand( - (rot_angles.shape[0], -1) - ), - ), - rasters_from_world_tf, - ) + road_edge_xyzh = None return ( - map_names, - rot_crop_patches, - resolution, - rasters_from_world_tf, + torch.as_tensor(lane_xyh), + torch.as_tensor(lane_adj), + torch.as_tensor(lane_mask), + lane_ids, + torch.as_tensor(road_edge_xyzh) if road_edge_xyzh is not None else None, ) +def raster_map_collate_fn_agent( + batch_elems: List[AgentBatchElement], +): + raise NotImplementedError() + def raster_map_collate_fn_scene( batch_elems: List[SceneBatchElement], max_agent_num: Optional[int] = None, pad_value: Any = np.nan, -) -> Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]: - if batch_elems[0].map_patches is None: - return None, None, None, None - - patch_size: int = batch_elems[0].map_patches[0].crop_size - assert all( - batch_elem.map_patches[0].crop_size == patch_size for batch_elem in batch_elems - ) +) -> Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Tensor]]: - map_names: List[str] = list() - num_agents: List[int] = list() - agents_rasters_from_world_tfs: List[np.ndarray] = list() - agents_patches: List[np.ndarray] = list() - agents_rot_angles_list: List[float] = list() - agents_res_list: List[float] = list() - - for elem in batch_elems: - map_names.append(elem.map_name) - num_agents.append(min(elem.num_agents, max_agent_num)) - agents_rasters_from_world_tfs += [ - x.raster_from_world_tf for x in elem.map_patches[:max_agent_num] - ] - agents_patches += [x.data for x in elem.map_patches[:max_agent_num]] - agents_rot_angles_list += [ - x.rot_angle for x in elem.map_patches[:max_agent_num] - ] - agents_res_list += [x.resolution for x in elem.map_patches[:max_agent_num]] - - patch_data: Tensor = torch.as_tensor(np.stack(agents_patches), dtype=torch.float) - agents_rot_angles: Tensor = torch.as_tensor( - np.stack(agents_rot_angles_list), dtype=torch.float - ) - agents_rasters_from_world_tf: Tensor = torch.as_tensor( - np.stack(agents_rasters_from_world_tfs), dtype=torch.float - ) - agents_resolution: Tensor = torch.as_tensor( - np.stack(agents_res_list), dtype=torch.float - ) - - patch_size_y, patch_size_x = patch_data.shape[-2:] - center_y: int = patch_size_y // 2 - center_x: int = patch_size_x // 2 - half_extent: int = patch_size // 2 - - if torch.count_nonzero(agents_rot_angles) == 0: - agents_rasters_from_world_tf = torch.bmm( - torch.tensor( - [ - [ - [1.0, 0.0, half_extent], - [0.0, 1.0, half_extent], - [0.0, 0.0, 1.0], - ] - ], - dtype=agents_rasters_from_world_tf.dtype, - device=agents_rasters_from_world_tf.device, - ).expand((agents_rasters_from_world_tf.shape[0], -1, -1)), - agents_rasters_from_world_tf, - ) - - rot_crop_patches = patch_data - else: - agents_rasters_from_world_tf = torch.bmm( - arr_utils.transform_matrices( - -agents_rot_angles, - torch.tensor([[half_extent, half_extent]]).expand( - (agents_rot_angles.shape[0], -1) - ), - ), - agents_rasters_from_world_tf, - ) - - # Batch rotating patches by rot_angles. - rot_patches: Tensor = rotate(patch_data, torch.rad2deg(agents_rot_angles)) - - # Center cropping via slicing. - rot_crop_patches = rot_patches[ - ..., - center_y - half_extent : center_y + half_extent, - center_x - half_extent : center_x + half_extent, - ] - - rot_crop_patches = split_pad_crop( - rot_crop_patches, num_agents, pad_value=pad_value, desired_size=max_agent_num - ) - - agents_rasters_from_world_tf = split_pad_crop( - agents_rasters_from_world_tf, - num_agents, - pad_value=pad_value, - desired_size=max_agent_num, - ) - agents_resolution = split_pad_crop( - agents_resolution, num_agents, pad_value=0, desired_size=max_agent_num - ) - - return map_names, rot_crop_patches, agents_resolution, agents_rasters_from_world_tf + raise NotImplementedError() def agent_collate_fn( @@ -419,17 +220,21 @@ def agent_collate_fn( to_add = max_neigh_history_len - padded_neighbor_histories.shape[-2] padded_neighbor_histories = F.pad( padded_neighbor_histories, - pad=(0, 0, to_add, 0) - if history_pad_dir == arr_utils.PadDirection.BEFORE - else (0, 0, 0, to_add), + pad=( + (0, 0, to_add, 0) + if history_pad_dir == arr_utils.PadDirection.BEFORE + else (0, 0, 0, to_add) + ), mode="constant", value=np.nan, ) padded_neighbor_history_extents = F.pad( padded_neighbor_history_extents, - pad=(0, 0, to_add, 0) - if history_pad_dir == arr_utils.PadDirection.BEFORE - else (0, 0, 0, to_add), + pad=( + (0, 0, to_add, 0) + if history_pad_dir == arr_utils.PadDirection.BEFORE + else (0, 0, 0, to_add) + ), mode="constant", value=np.nan, ) @@ -547,9 +352,11 @@ def agent_collate_fn( to_add: int = hist_len - agent_history_t.shape[-2] agent_history_t = F.pad( agent_history_t, - (0, 0, to_add, 0) - if history_pad_dir == arr_utils.PadDirection.BEFORE - else (0, 0, 0, to_add), + ( + (0, 0, to_add, 0) + if history_pad_dir == arr_utils.PadDirection.BEFORE + else (0, 0, 0, to_add) + ), value=np.nan, ).as_subclass(AgentObsTensor) @@ -557,9 +364,11 @@ def agent_collate_fn( to_add: int = hist_len - agent_history_extent_t.shape[-2] agent_history_extent_t = F.pad( agent_history_extent_t, - (0, 0, to_add, 0) - if history_pad_dir == arr_utils.PadDirection.BEFORE - else (0, 0, 0, to_add), + ( + (0, 0, to_add, 0) + if history_pad_dir == arr_utils.PadDirection.BEFORE + else (0, 0, 0, to_add) + ), value=np.nan, ) @@ -673,6 +482,18 @@ def agent_collate_fn( if batch_elems[0].vec_map is not None: vector_maps = [batch_elem.vec_map for batch_elem in batch_elems] + lane_xyh, lane_adj, lane_mask, lane_ids, road_edge_xyzh = ( + None, + None, + None, + None, + None, + ) + if hasattr(batch_elems[0],"lane_xyh") and batch_elems[0].lane_xyh is not None: + lane_xyh, lane_adj, lane_mask, lane_ids, road_edge_xyzh = _collate_lane_graph( + batch_elems + ) + agents_from_world_tf = torch.as_tensor( np.stack([batch_elem.agent_from_world_tf for batch_elem in batch_elems]), dtype=torch.float, @@ -685,7 +506,7 @@ def agent_collate_fn( extras[key] = _collate_data( [batch_elem.extras[key] for batch_elem in batch_elems] ) - + track_ids = _collate_data([batch_elem.track_id for batch_elem in batch_elems]) if hasattr(batch_elems[0],"track_id") else None batch = AgentBatch( data_idx=data_index_t, scene_ts=scene_ts_t, @@ -711,12 +532,18 @@ def agent_collate_fn( robot_fut_len=robot_future_len, map_names=map_names, maps=map_patches, + lane_xyh=lane_xyh, + lane_adj=lane_adj, + lane_mask=lane_mask, + lane_ids=lane_ids, + road_edge_xyzh=road_edge_xyzh, maps_resolution=maps_resolution, vector_maps=vector_maps, rasters_from_world_tf=rasters_from_world_tf, agents_from_world_tf=agents_from_world_tf, scene_ids=scene_ids, history_pad_dir=history_pad_dir, + track_ids=track_ids, extras=extras, ) @@ -780,6 +607,9 @@ def scene_collate_fn( return_dict: bool, pad_format: str, batch_augments: Optional[List[BatchAugmentation]] = None, + desired_num_agents = None, + desired_hist_len=None, + desired_fut_len=None, ) -> SceneBatch: batch_size: int = len(batch_elems) history_pad_dir: arr_utils.PadDirection = ( @@ -799,6 +629,8 @@ def scene_collate_fn( AgentObsTensor = TORCH_STATE_TYPES[obs_format] max_agent_num: int = max(elem.num_agents for elem in batch_elems) + if desired_num_agents is not None: + max_agent_num = max(max_agent_num,desired_num_agents) centered_agent_state: List[AgentStateTensor] = list() agents_types: List[Tensor] = list() @@ -819,6 +651,10 @@ def scene_collate_fn( max_history_len: int = max(elem.agent_history_lens_np.max() for elem in batch_elems) max_future_len: int = max(elem.agent_future_lens_np.max() for elem in batch_elems) + if desired_hist_len is not None: + max_history_len = max(max_history_len,desired_hist_len) + if desired_fut_len is not None: + max_future_len = max(max_future_len,desired_fut_len) robot_future: List[AgentObsTensor] = list() robot_future_len: Tensor = torch.zeros((batch_size,), dtype=torch.long) @@ -859,17 +695,21 @@ def scene_collate_fn( to_add = max_history_len - padded_agents_histories.shape[-2] padded_agents_histories = F.pad( padded_agents_histories, - pad=(0, 0, to_add, 0) - if history_pad_dir == arr_utils.PadDirection.BEFORE - else (0, 0, 0, to_add), + pad=( + (0, 0, to_add, 0) + if history_pad_dir == arr_utils.PadDirection.BEFORE + else (0, 0, 0, to_add) + ), mode="constant", value=np.nan, ) padded_agents_history_extents = F.pad( padded_agents_history_extents, - pad=(0, 0, to_add, 0) - if history_pad_dir == arr_utils.PadDirection.BEFORE - else (0, 0, 0, to_add), + pad=( + (0, 0, to_add, 0) + if history_pad_dir == arr_utils.PadDirection.BEFORE + else (0, 0, 0, to_add) + ), mode="constant", value=np.nan, ) @@ -950,6 +790,18 @@ def scene_collate_fn( if batch_elems[0].vec_map is not None: vector_maps = [batch_elem.vec_map for batch_elem in batch_elems] + lane_xyh, lane_adj, lane_mask, lane_ids, road_edge_xyzh = ( + None, + None, + None, + None, + None, + ) + if hasattr(batch_elems[0],"lane_xyh") and batch_elems[0].lane_xyh is not None: + lane_xyh, lane_adj, lane_mask, lane_ids, road_edge_xyzh = _collate_lane_graph( + batch_elems + ) + centered_agent_from_world_tf = torch.as_tensor( np.stack( [batch_elem.centered_agent_from_world_tf for batch_elem in batch_elems] @@ -989,6 +841,7 @@ def scene_collate_fn( agent_type=agents_types_t, centered_agent_state=centered_agent_state_t, agent_names=agent_names, + track_ids=None, agent_hist=agents_histories_t, agent_hist_extent=agents_history_extents_t, agent_hist_len=agents_history_len, @@ -999,6 +852,11 @@ def scene_collate_fn( robot_fut_len=robot_future_len, map_names=map_names, maps=map_patches, + lane_xyh=lane_xyh, + lane_adj=lane_adj, + lane_mask=lane_mask, + lane_ids=lane_ids, + road_edge_xyzh=road_edge_xyzh, maps_resolution=maps_resolution, vector_maps=vector_maps, rasters_from_world_tf=rasters_from_world_tf, diff --git a/src/trajdata/data_structures/data_index.py b/src/trajdata/data_structures/data_index.py index 54f70ca..8fe64d7 100644 --- a/src/trajdata/data_structures/data_index.py +++ b/src/trajdata/data_structures/data_index.py @@ -26,7 +26,7 @@ def __init__( ) self._len: int = self._cumulative_lengths[-1].item() - self._scene_paths: np.ndarray = np.array(scene_paths).astype(np.string_) + self._scene_paths: np.ndarray = np.array(scene_paths).astype(np.bytes_) def __len__(self) -> int: return self._len @@ -61,7 +61,7 @@ def __init__( ): agent_ids, agent_times = zip(*scene_data_index) - self._agent_ids.append(np.array(agent_ids).astype(np.string_)) + self._agent_ids.append(np.array(agent_ids).astype(np.bytes_)) agent_ts: np.ndarray = np.stack(agent_times) self._agent_times.append(agent_ts) diff --git a/src/trajdata/data_structures/state.py b/src/trajdata/data_structures/state.py index 9dcc09b..69ae95e 100644 --- a/src/trajdata/data_structures/state.py +++ b/src/trajdata/data_structures/state.py @@ -76,6 +76,15 @@ def y_component(long, lat, c, s): return long * s + lat * c +def _check_format_length_compatible_with_array_dim(array, format): + if array.shape[-1] != len(format.split(",")): + raise ValueError( + f"Array shape {array.shape} incompatible with format {format}: " + f"Array has {array.shape[-1]} entries but format has length " + f"{len(format.split(','))}." + ) + + class State: """ Base class implementing property access to state elements @@ -168,6 +177,23 @@ def as_format(self, new_format: str, create_type=True): return self.from_array(result, new_format) else: return result + + def has_attr(self, attr: str) -> bool: + try: + # Check whether we can compute attr without raising ValueError. + getattr(self, attr) + return True + except ValueError: + return False + + + def has_attr(self, attr: str) -> bool: + try: + # Check whether we can compute attr without raising ValueError. + getattr(self, attr) + return True + except ValueError: + return False def _compute_attr(self, attr: str, _depth: int = MAX_RECURSION_LEVELS): """ @@ -320,6 +346,7 @@ def as_ndarray(self) -> np.ndarray: @classmethod def from_array(cls, array: Array, format: str): + _check_format_length_compatible_with_array_dim(array, format) return array.view(NP_STATE_TYPES[format]) @classmethod @@ -393,11 +420,22 @@ def __torch_function__(cls, func, types, args=(), kwargs=None): elif indices == slice(None): new_class = cls elif isinstance(indices, tuple): - if len(indices) < self.ndim: + if len(indices) < self.ndim and Ellipsis not in indices: new_class = cls elif len(indices) == self.ndim and indices[-1] == slice(None): new_class = cls + if func in (torch.reshape, Tensor.reshape): + original_state_size = args[0].shape[-1] + if len(args) > 1: + # shape passed in as positional argument + # It can be either a separate tuple or a series of ints + shape = args[1] if isinstance(args[1], tuple) else args[1:] + else: + shape = kwargs.get("shape") + if shape[-1] == original_state_size: + new_class = cls + if isinstance(result, Tensor) and new_class != cls: result = result.as_subclass(new_class) @@ -426,6 +464,7 @@ def from_numpy(cls, state: StateArray, **kwargs): @classmethod def from_array(cls, array: Array, format: str): + _check_format_length_compatible_with_array_dim(array, format) return array.as_subclass(TORCH_STATE_TYPES[format]) @classmethod diff --git a/src/trajdata/dataset.py b/src/trajdata/dataset.py index c3667b5..cd138b7 100644 --- a/src/trajdata/dataset.py +++ b/src/trajdata/dataset.py @@ -1,15 +1,24 @@ import gc import json import random -import re import time -import warnings from collections import defaultdict from functools import partial from itertools import chain from os.path import isfile from pathlib import Path -from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Union +from typing import ( + Any, + Callable, + Dict, + Final, + Iterable, + List, + Optional, + Set, + Tuple, + Union, +) import dill import numpy as np @@ -45,12 +54,15 @@ agent_utils, env_utils, py_utils, - raster_utils, scene_utils, string_utils, + map_utils, ) from trajdata.utils.parallel_utils import parallel_iapply +# TODO(bivanovic): Move this to a better place in the codebase. +DEFAULT_PX_PER_M: Final[float] = 2.0 + class UnifiedDataset(Dataset): # @profile @@ -115,6 +127,7 @@ def __init__( Callable[..., Union[AgentBatchElement, SceneBatchElement]] ] = (), rank: int = 0, + cache_lane_graphs: bool = False, ) -> None: """Instantiates a PyTorch Dataset object which aggregates data from multiple trajectory forecasting datasets. @@ -157,9 +170,9 @@ def __init__( rank (int, optional): Proccess rank when using torch DistributedDataParallel for multi-GPU training. Only the rank 0 process will be used for caching. """ self.desired_data: List[str] = desired_data - self.scene_description_contains: Optional[ - List[str] - ] = scene_description_contains + self.scene_description_contains: Optional[List[str]] = ( + scene_description_contains + ) self.centric: str = centric self.desired_dt: float = desired_dt @@ -191,7 +204,7 @@ def __init__( raster_map_params if raster_map_params is not None # Allowing for parallel map processing in case the user specifies num_workers. - else {"px_per_m": raster_utils.DEFAULT_PX_PER_M, "num_workers": num_workers} + else {"px_per_m": DEFAULT_PX_PER_M, "num_workers": num_workers} ) self.incl_vector_map = incl_vector_map @@ -256,15 +269,19 @@ def __init__( flush=True, ) - self.check_args_combinations(matching_datasets) - self._map_api: Optional[MapAPI] = None if self.incl_vector_map: self._map_api = MapAPI( self.cache_path, - keep_in_memory=self.vector_map_params.get("keep_in_memory", True), + keep_in_memory=vector_map_params.get("keep_in_memory", True), ) + self.cache_lane_graphs = cache_lane_graphs + if cache_lane_graphs: + self.lane_graph_cache = dict() + else: + self.lane_graph_cache = dict() + all_scenes_list: Union[List[SceneMetadata], List[Scene]] = list() for env in self.envs: if any(env.name in dataset_tuple for dataset_tuple in matching_datasets): @@ -318,7 +335,7 @@ def __init__( env.cache_maps( self.cache_path, self.cache_class, - self.raster_map_params, + {**self.raster_map_params,**self.vector_map_params}, ) # Wait for rank 0 process to be done with caching. @@ -328,9 +345,9 @@ def __init__( ): distributed.barrier() - scenes_list: List[ - SceneMetadata - ] = self._get_desired_scenes_from_env(matching_datasets, env) + scenes_list: List[SceneMetadata] = ( + self._get_desired_scenes_from_env(matching_datasets, env) + ) if self.incl_vector_map and env.metadata.map_locations is not None: # env.metadata.map_locations can be none for map-containing @@ -387,63 +404,10 @@ def __init__( ) self._cache_data_index(data_index) - - # Wait for rank 0 process to be done with caching. if distributed.is_initialized() and distributed.get_world_size() > 1: distributed.barrier() - self._cached_batch_elements = None - def check_args_combinations(self, chosen_datasets: List[SceneTag]) -> None: - """Warn users about potential "gotcha" combinations of arguments, - usually involving fundamental limits in datasets. - """ - waymo_warning_given: bool = False - waymo_pattern: re.Pattern = re.compile("waymo") - - nuplan_warning_given: bool = False - nuplan_pattern: re.Pattern = re.compile("nuplan") - - dataset: SceneTag - for dataset in chosen_datasets: - if ( - not waymo_warning_given - and self.vector_map_params["incl_road_areas"] - and dataset.matches_any(waymo_pattern) - ): - warnings.warn( - ( - "\n\n############ WARNING! ############\n" - "Waymo has many gaps in the associations between " - "lane centerlines and boundaries,\nmaking it difficult " - "to construct lane edge polylines or road area polygons.\n" - "The ones currently provided by trajdata should be considered " - "low quality!" - "\n#################################\n" - ) - ) - waymo_warning_given = True - - elif ( - not nuplan_warning_given - and self.incl_vector_map - and dataset.matches_any(nuplan_pattern) - ): - warnings.warn( - ( - "\n\n############ WARNING! ############\n" - "nuPlan uses Shapely to represent its map. " - "Shapely only supports 2D coordinates.\nHowever, " - "nuPlan's agent trajectories are provided in 3D " - "(with a non-zero z-coordinate).\nThus, any " - "spatial queries (e.g., nearest lane or road area) " - "should be executed with the z-coordinate " - "set to 0.0 manually." - "\n#################################\n" - ) - ) - nuplan_warning_given = True - def _index_cache_path( self, ret_args: bool = False ) -> Union[Path, Tuple[Path, Dict[str, Any]]]: @@ -451,23 +415,31 @@ def _index_cache_path( # and hashed together here. impactful_args: Dict[str, Any] = { "desired_data": tuple(self.desired_data), - "scene_description_contains": tuple(self.scene_description_contains) - if self.scene_description_contains is not None - else None, + "scene_description_contains": ( + tuple(self.scene_description_contains) + if self.scene_description_contains is not None + else None + ), "centric": self.centric, "desired_dt": self.desired_dt, "history_sec": self.history_sec, "future_sec": self.future_sec, "incl_robot_future": self.incl_robot_future, - "only_types": tuple(t.name for t in self.only_types) - if self.only_types is not None - else None, - "only_predict": tuple(t.name for t in self.only_predict) - if self.only_predict is not None - else None, - "no_types": tuple(t.name for t in self.no_types) - if self.no_types is not None - else None, + "only_types": ( + tuple(t.name for t in self.only_types) + if self.only_types is not None + else None + ), + "only_predict": ( + tuple(t.name for t in self.only_predict) + if self.only_predict is not None + else None + ), + "no_types": ( + tuple(t.name for t in self.no_types) + if self.no_types is not None + else None + ), "ego_only": self.ego_only, } index_hash: str = py_utils.hash_dict(impactful_args) @@ -528,7 +500,7 @@ def load_or_create_cache( print(f"Loading cache from {cache_path} ...", end="") t = time.time() with open(cache_path, "rb") as f: - self._cached_batch_elements, keep_ids = dill.load(f, encoding="latin1") + self._cached_batch_elements, keep_mask = dill.load(f, encoding="latin1") print(f" done in {time.time() - t:.1f}s.") else: @@ -653,9 +625,7 @@ def remove_elements(self, keep_ids: Union[np.ndarray, List[int]]): f"Kept {self._data_len}/{old_len} elements, {self._data_len/old_len*100.0:.2f}%." ) - def _get_data_index( - self, num_workers: int, scene_paths: List[Path] - ) -> Union[ + def _get_data_index(self, num_workers: int, scene_paths: List[Path]) -> Union[ List[Tuple[str, int, np.ndarray]], List[Tuple[str, int, List[Tuple[str, np.ndarray]]]], ]: @@ -840,6 +810,9 @@ def get_collate_fn( return_dict=return_dict, pad_format=pad_format, batch_augments=batch_augments, + desired_num_agents = self.max_agent_num, + desired_hist_len = int(self.history_sec[1]/self.desired_dt), + desired_fut_len = int(self.future_sec[1]/self.desired_dt), ) else: raise ValueError(f"{self.centric}-centric data batches are not supported.") @@ -874,8 +847,7 @@ def _get_desired_scenes_from_env( ) -> Union[List[Scene], List[SceneMetadata]]: scenes_list: Union[List[Scene], List[SceneMetadata]] = list() for scene_tag in tqdm( - scene_tags, desc=f"Getting Scenes from {env.name} with scene tag {scene_tags}", - disable=not self.verbose + scene_tags, desc=f"Getting Scenes from {env.name}", disable=not self.verbose ): if env.name in scene_tag: scenes_list += env.get_matching_scenes( @@ -1078,6 +1050,16 @@ def __getitem__(self, idx: int) -> Union[SceneBatchElement, AgentBatchElement]: only_types=self.only_types, no_types=self.no_types, ) + if self.cache_lane_graphs: + map_name: str = ( + f"{scene_time.scene.env_name}:{scene_time.scene.location}" + ) + if map_name not in self.lane_graph_cache: + self.lane_graph_cache[map_name] = map_utils.obtain_lane_graph( + scene_time.scene, + self._map_api, + self.vector_map_params, + ) batch_element: SceneBatchElement = SceneBatchElement( scene_cache, @@ -1095,6 +1077,9 @@ def __getitem__(self, idx: int) -> Union[SceneBatchElement, AgentBatchElement]: self.standardize_data, self.standardize_derivatives, self.max_agent_num, + lane_graph_cache=( + self.lane_graph_cache[map_name] if self.cache_lane_graphs else None + ), ) elif self.centric == "agent": scene_time_agent: SceneTimeAgent = SceneTimeAgent.from_cache( diff --git a/src/trajdata/dataset_specific/nuplan/nuplan_dataset.py b/src/trajdata/dataset_specific/nuplan/nuplan_dataset.py index 1c4df3f..fe6f898 100644 --- a/src/trajdata/dataset_specific/nuplan/nuplan_dataset.py +++ b/src/trajdata/dataset_specific/nuplan/nuplan_dataset.py @@ -360,38 +360,6 @@ def get_agent_info( return agent_list, agent_presence - def cache_map( - self, - map_name: str, - cache_path: Path, - map_cache_class: Type[SceneCache], - map_params: Dict[str, Any], - ) -> None: - nuplan_map: NuPlanMap = map_factory.get_maps_api( - map_root=str(self.metadata.data_dir.parent / "maps"), - map_version=nuplan_utils.NUPLAN_MAP_VERSION, - map_name=nuplan_utils.NUPLAN_FULL_MAP_NAME_DICT[map_name], - ) - - # Loading all layer geometries. - nuplan_map.initialize_all_layers() - - # This df has the normal lane_connectors with additional boundary information, - # which we want to use, however the default index is not the lane_connector_fid, - # although it is a 1:1 mapping so we instead create another index with the - # lane_connector_fids as the key and the resulting integer indices as the value. - lane_connector_fids: pd.Series = nuplan_map._vector_map[ - "gen_lane_connectors_scaled_width_polygons" - ]["lane_connector_fid"] - lane_connector_idxs: pd.Series = pd.Series( - index=lane_connector_fids, data=range(len(lane_connector_fids)) - ) - - vector_map = VectorMap(map_id=f"{self.name}:{map_name}") - nuplan_utils.populate_vector_map(vector_map, nuplan_map, lane_connector_idxs) - - map_cache_class.finalize_and_cache_map(cache_path, vector_map, map_params) - def cache_maps( self, cache_path: Path, @@ -406,4 +374,50 @@ def cache_maps( desc=f"Caching {self.name} Maps at {map_params['px_per_m']:.2f} px/m", position=0, ): - self.cache_map(map_name, cache_path, map_cache_class, map_params) + cache_map( + map_root=str(self.metadata.data_dir.parent / "maps"), + env_name=self.name, + map_name=map_name, + cache_path=cache_path, + map_cache_class=map_cache_class, + map_params=map_params, + ) + + +def cache_map( + map_root: str, + env_name: str, + map_name: str, + cache_path: Path, + map_cache_class: Type[SceneCache], + map_params: Dict[str, Any], +) -> None: + nuplan_map: NuPlanMap = map_factory.get_maps_api( + map_root=map_root, + map_version=nuplan_utils.NUPLAN_MAP_VERSION, + map_name=nuplan_utils.NUPLAN_FULL_MAP_NAME_DICT[map_name], + ) + + # Loading all layer geometries. + nuplan_map.initialize_all_layers() + + # This df has the normal lane_connectors with additional boundary information, + # which we want to use, however the default index is not the lane_connector_fid, + # although it is a 1:1 mapping so we instead create another index with the + # lane_connector_fids as the key and the resulting integer indices as the value. + lane_connector_fids: pd.Series = nuplan_map._vector_map[ + "gen_lane_connectors_scaled_width_polygons" + ]["lane_connector_fid"] + lane_connector_idxs: pd.Series = pd.Series( + index=lane_connector_fids, data=range(len(lane_connector_fids)) + ) + + vector_map = VectorMap(map_id=f"{env_name}:{map_name}") + nuplan_utils.populate_vector_map( + vector_map, + nuplan_map, + lane_connector_idxs, + max_lane_length=map_params.get("max_lane_length", None), + ) + + map_cache_class.finalize_and_cache_map(cache_path, vector_map, map_params) diff --git a/src/trajdata/dataset_specific/nuplan/nuplan_utils.py b/src/trajdata/dataset_specific/nuplan/nuplan_utils.py index c98c748..9614929 100644 --- a/src/trajdata/dataset_specific/nuplan/nuplan_utils.py +++ b/src/trajdata/dataset_specific/nuplan/nuplan_utils.py @@ -8,6 +8,7 @@ import nuplan.planning.script.config.common as common_cfg import pandas as pd import yaml +import math from nuplan.common.maps.nuplan_map.nuplan_map import NuPlanMap from tqdm import tqdm @@ -23,6 +24,7 @@ RoadLane, ) from trajdata.utils import map_utils +from trajdata.maps.vec_map import split_lane_segments NUPLAN_DT: Final[float] = 0.05 NUPLAN_FULL_MAP_NAME_DICT: Final[Dict[str, str]] = { @@ -190,6 +192,7 @@ def close_db(self) -> None: def nuplan_type_to_unified_type(nuplan_type: str) -> AgentType: + # TODO map traffic cones, barriers to static; generic_object to pedestrian if nuplan_type == "pedestrian": return AgentType.PEDESTRIAN elif nuplan_type == "bicycle": @@ -259,7 +262,10 @@ def extract_area(nuplan_map: NuPlanMap, area_record) -> np.ndarray: def populate_vector_map( - vector_map: VectorMap, nuplan_map: NuPlanMap, lane_connector_idxs: pd.Series + vector_map: VectorMap, + nuplan_map: NuPlanMap, + lane_connector_idxs: pd.Series, + max_lane_length: Optional[float] = None, ) -> None: # Setting the map bounds. # NOTE: min_pt is especially important here since the world coordinates of nuPlan @@ -328,14 +334,37 @@ def populate_vector_map( # The right boundary of Lane A has Lane A to its left. boundary_connectivity_dict[right_boundary_id]["left"].append(lane_id) + # Find road areas that this lane intersects for faster lane-based lookup later. + intersect_filt = nuplan_map._vector_map["drivable_area"].intersects( + lane_info["geometry"] + ) + isnear_filt = ( + nuplan_map._vector_map["drivable_area"].distance(lane_info["geometry"]) + < 3.0 + ) + road_area_ids = set( + nuplan_map._vector_map["drivable_area"][intersect_filt | isnear_filt][ + "fid" + ].values + ) + if not road_area_ids: + print(f"Warning: no road lane associated with lane {lane_id}") + # "partial" because we aren't adding lane connectivity until later. partial_new_lane = RoadLane( id=lane_id, center=Polyline(center_pts), left_edge=Polyline(left_pts), right_edge=Polyline(right_pts), + road_area_ids=road_area_ids, ) - vector_map.add_map_element(partial_new_lane) + if max_lane_length is not None: + split_lanes = split_lane_segments(partial_new_lane, max_len=max_lane_length) + for lane in split_lanes: + vector_map.add_map_element(lane) + lane_boundary_dict[lane.id] = boundary_info + else: + vector_map.add_map_element(partial_new_lane) overall_pbar.update() for fid, polygon_info in nuplan_map._vector_map["drivable_area"].iterrows(): diff --git a/src/trajdata/dataset_specific/nusc/nusc_dataset.py b/src/trajdata/dataset_specific/nusc/nusc_dataset.py index 209824e..9fc8de3 100644 --- a/src/trajdata/dataset_specific/nusc/nusc_dataset.py +++ b/src/trajdata/dataset_specific/nusc/nusc_dataset.py @@ -279,7 +279,11 @@ def cache_map( ) vector_map = VectorMap(map_id=f"{self.name}:{map_name}") - nusc_utils.populate_vector_map(vector_map, nusc_map) + nusc_utils.populate_vector_map( + vector_map, + nusc_map, + max_lane_length=map_params.get("max_lane_length", None), + ) map_cache_class.finalize_and_cache_map(cache_path, vector_map, map_params) diff --git a/src/trajdata/dataset_specific/nusc/nusc_utils.py b/src/trajdata/dataset_specific/nusc/nusc_utils.py index 356debd..b343b3d 100644 --- a/src/trajdata/dataset_specific/nusc/nusc_utils.py +++ b/src/trajdata/dataset_specific/nusc/nusc_utils.py @@ -20,6 +20,7 @@ RoadLane, ) from trajdata.utils import arr_utils, map_utils +from trajdata.maps.vec_map import split_lane_segments NUSC_DT: Final[float] = 0.5 @@ -371,7 +372,9 @@ def extract_area(nusc_map: NuScenesMap, area_record) -> np.ndarray: return np.array([(node["x"], node["y"]) for node in polygon_nodes]) -def populate_vector_map(vector_map: VectorMap, nusc_map: NuScenesMap) -> None: +def populate_vector_map( + vector_map: VectorMap, nusc_map: NuScenesMap, max_lane_length=None +) -> None: # Setting the map bounds. vector_map.extent = np.array( [ @@ -427,7 +430,12 @@ def populate_vector_map(vector_map: VectorMap, nusc_map: NuScenesMap) -> None: # ) # Adding the element to the map. - vector_map.add_map_element(new_lane) + if max_lane_length is not None: + split_lanes = split_lane_segments(new_lane, max_len=max_lane_length) + for lane in split_lanes: + vector_map.add_map_element(lane) + else: + vector_map.add_map_element(new_lane) overall_pbar.update() for lane_record in nusc_map.lane_connector: diff --git a/src/trajdata/dataset_specific/scene_records.py b/src/trajdata/dataset_specific/scene_records.py index b785415..68bd5b9 100644 --- a/src/trajdata/dataset_specific/scene_records.py +++ b/src/trajdata/dataset_specific/scene_records.py @@ -1,11 +1,6 @@ from typing import NamedTuple -class Argoverse2Record(NamedTuple): - name: str - data_idx: int - - class EUPedsRecord(NamedTuple): name: str location: str @@ -34,15 +29,6 @@ class NuscSceneRecord(NamedTuple): data_idx: int -class VODSceneRecord(NamedTuple): - token: str - name: str - location: str - length: str - desc: str - data_idx: int - - class LyftSceneRecord(NamedTuple): name: str length: str diff --git a/src/trajdata/dataset_specific/waymo/waymo_dataset.py b/src/trajdata/dataset_specific/waymo/waymo_dataset.py index 4497a79..4621936 100644 --- a/src/trajdata/dataset_specific/waymo/waymo_dataset.py +++ b/src/trajdata/dataset_specific/waymo/waymo_dataset.py @@ -8,8 +8,6 @@ import pandas as pd import tensorflow as tf import tqdm -from waymo_open_dataset.protos.scenario_pb2 import Scenario - from trajdata.caching import EnvCache, SceneCache from trajdata.data_structures import ( AgentMetadata, @@ -43,6 +41,7 @@ ) from trajdata.utils import arr_utils from trajdata.utils.parallel_utils import parallel_apply +from waymo_open_dataset.protos.scenario_pb2 import Scenario def const_lambda(const_val: Any) -> Any: @@ -178,7 +177,7 @@ def get_agent_info( ) scenario: Scenario = Scenario() for data in dataset: - scenario.ParseFromString(bytearray(data.numpy())) + scenario.ParseFromString(bytes(data.numpy())) break agent_ids = [] @@ -186,6 +185,7 @@ def get_agent_info( all_agent_data = [] agents_to_remove = [] ego_id = None + agent_info_dict: dict[str, AgentMetadata] = {} for index, track in enumerate(scenario.tracks): agent_type: AgentType = translate_agent_type(track.object_type) if agent_type == -1: @@ -238,19 +238,16 @@ def get_agent_info( ego_id = agent_id agent_name = "ego" - agent_info = AgentMetadata( - name=agent_name, - agent_type=agent_type, - first_timestep=first_timestep, - last_timestep=last_timestep, - extent=VariableExtent(), - ) - if last_timestep - first_timestep > 0: - agent_list.append(agent_info) - for timestep in range(first_timestep, last_timestep + 1): - agent_presence[timestep].append(agent_info) - else: + if last_timestep - first_timestep <= 0: agents_to_remove.append(agent_id) + else: + agent_info_dict[agent_name] = AgentMetadata( + name=agent_name, + agent_type=agent_type, + first_timestep=first_timestep, + last_timestep=last_timestep, + extent=VariableExtent(), + ) # agent_ml_class = np.repeat(agent_ml_class, scene.length_timesteps) # all_agent_data = np.insert(all_agent_data, 6, agent_ml_class, axis=1) @@ -307,6 +304,22 @@ def get_agent_info( index={str(ego_id): "ego"}, inplace=True, level="agent_id" ) + all_agent_data_df = waymo_utils.sort_df_by_distance_to_ego_vehicle( + all_agent_data_df, + timestep=0, + ego_name="ego", + ) + + # Get sorted (by distance) list of agent_ids + agent_ids = all_agent_data_df.index.get_level_values("agent_id").unique() + for agent_id in agent_ids: + agent_info = agent_info_dict[str(agent_id)] + agent_list.append(agent_info) + for timestep in range( + agent_info.first_timestep, agent_info.last_timestep + 1 + ): + agent_presence[timestep].append(agent_info) + cache_class.save_agent_data( all_agent_data_df.loc[:, final_cols], cache_path, @@ -340,12 +353,13 @@ def cache_map( scenario: Scenario = Scenario() for data in dataset: - scenario.ParseFromString(bytearray(data.numpy())) + scenario.ParseFromString(bytes(data.numpy())) break vector_map: VectorMap = waymo_utils.extract_vectorized( map_features=scenario.map_features, map_name=f"{self.name}:{self.name}_{data_idx}", + max_lane_length=map_params.get("max_lane_length", 100), ) map_cache_class.finalize_and_cache_map(cache_path, vector_map, map_params) diff --git a/src/trajdata/dataset_specific/waymo/waymo_utils.py b/src/trajdata/dataset_specific/waymo/waymo_utils.py index 83cc210..eb544cf 100644 --- a/src/trajdata/dataset_specific/waymo/waymo_utils.py +++ b/src/trajdata/dataset_specific/waymo/waymo_utils.py @@ -8,12 +8,12 @@ import tensorflow as tf from intervaltree import Interval, IntervalTree from tqdm import tqdm +from trajdata.maps import TrafficLightStatus, VectorMap +from trajdata.maps.vec_map import split_lane_segments +from trajdata.maps.vec_map_elements import PedCrosswalk, Polyline, RoadEdge, RoadLane from waymo_open_dataset.protos import map_pb2 as waymo_map_pb2 from waymo_open_dataset.protos import scenario_pb2 -from trajdata.maps import TrafficLightStatus, VectorMap -from trajdata.maps.vec_map_elements import PedCrosswalk, Polyline, RoadLane - WAYMO_DT: Final[float] = 0.1 WAYMO_DATASET_NAMES = [ "testing", @@ -145,7 +145,10 @@ def get_filename(self, data_idx): def extract_vectorized( - map_features: List[waymo_map_pb2.MapFeature], map_name: str, verbose: bool = False + map_features: List[waymo_map_pb2.MapFeature], + map_name: str, + verbose: bool = False, + max_lane_length: Optional[float] = None, ) -> VectorMap: vec_map = VectorMap(map_id=map_name) @@ -175,7 +178,9 @@ def extract_vectorized( # aren't interpolating between others?? continue - road_lanes, modified_lane_ids = translate_lane(map_feature, boundaries) + road_lanes, modified_lane_ids = translate_lane( + map_feature, boundaries, max_lane_length=max_lane_length + ) if modified_lane_ids: lane_id_remap_dict.update(modified_lane_ids) @@ -207,6 +212,21 @@ def extract_vectorized( max_pt = np.fmax(max_pt, crosswalk.polygon.xyz.max(axis=0)) min_pt = np.fmin(min_pt, crosswalk.polygon.xyz.min(axis=0)) + elif map_feature.WhichOneof("feature_data") == "road_edge": + if len(map_feature.road_edge.polyline) == 1: + continue + road_edge = RoadEdge( + id=str(map_feature.id), + polyline=Polyline( + np.array( + [(pt.x, pt.y, pt.z) for pt in map_feature.road_edge.polyline] + ) + ), + ) + vec_map.add_map_element(road_edge) + max_pt = np.fmax(max_pt, road_edge.polyline.xyz.max(axis=0)) + min_pt = np.fmin(min_pt, road_edge.polyline.xyz.min(axis=0)) + else: continue @@ -473,6 +493,7 @@ def subselect_boundary( def translate_lane( map_feature: waymo_map_pb2.MapFeature, boundaries: Dict[int, Polyline], + max_lane_length: Optional[float] = None, ) -> Tuple[RoadLane, Optional[Dict[int, List[bytes]]]]: lane: waymo_map_pb2.LaneCenter = map_feature.lane @@ -485,37 +506,68 @@ def translate_lane( lane_chunks = split_lane_into_chunks(lane, boundaries) road_lanes: List[RoadLane] = [] new_ids: List[bytes] = [] + for idx, (lane_center, left_edge, right_edge) in enumerate(lane_chunks): + road_lane = RoadLane( - id=f"{map_feature.id}_{idx}" - if len(lane_chunks) > 1 - else str(map_feature.id), + id=( + f"{map_feature.id}_{idx}" + if len(lane_chunks) > 1 + else str(map_feature.id) + ), center=lane_center, left_edge=left_edge, right_edge=right_edge, ) - new_ids.append(road_lane.id) + if max_lane_length is not None: + split_lanes = split_lane_segments(road_lane, max_len=max_lane_length) + if idx == 0: + split_lanes[0].prev_lanes.update( + [str(eid) for eid in lane.entry_lanes] + ) + else: + split_lanes[0].prev_lanes.add(f"{map_feature.id}_{idx-1}") - if idx == 0: - road_lane.prev_lanes.update([str(eid) for eid in lane.entry_lanes]) - else: - road_lane.prev_lanes.add(f"{map_feature.id}_{idx-1}") + if idx == len(lane_chunks) - 1: + split_lanes[-1].next_lanes.update( + [str(eid) for eid in lane.exit_lanes] + ) + else: + split_lanes[-1].next_lanes.add(f"{map_feature.id}_{idx+1}") + + for split_lane in split_lanes: + # We'll take care of reassigning these IDs to the chunked versions later. + for neighbor in lane.left_neighbors: + split_lane.adj_lanes_left.add(str(neighbor.feature_id)) + + for neighbor in lane.right_neighbors: + split_lane.adj_lanes_right.add(str(neighbor.feature_id)) + new_ids.append(split_lane.id) + road_lanes.append(split_lane) - if idx == len(lane_chunks) - 1: - road_lane.next_lanes.update([str(eid) for eid in lane.exit_lanes]) else: - road_lane.next_lanes.add(f"{map_feature.id}_{idx+1}") + new_ids.append(road_lane.id) + + if idx == 0: + road_lane.prev_lanes.update([str(eid) for eid in lane.entry_lanes]) + else: + road_lane.prev_lanes.add(f"{map_feature.id}_{idx-1}") - # We'll take care of reassigning these IDs to the chunked versions later. - for neighbor in lane.left_neighbors: - road_lane.adj_lanes_left.add(str(neighbor.feature_id)) + if idx == len(lane_chunks) - 1: + road_lane.next_lanes.update([str(eid) for eid in lane.exit_lanes]) + else: + road_lane.next_lanes.add(f"{map_feature.id}_{idx+1}") - for neighbor in lane.right_neighbors: - road_lane.adj_lanes_right.add(str(neighbor.feature_id)) + # We'll take care of reassigning these IDs to the chunked versions later. + for neighbor in lane.left_neighbors: + road_lane.adj_lanes_left.add(str(neighbor.feature_id)) - road_lanes.append(road_lane) + for neighbor in lane.right_neighbors: + road_lane.adj_lanes_right.add(str(neighbor.feature_id)) - if len(lane_chunks) > 1: + road_lanes.append(road_lane) + + if len(road_lanes) > 1: return road_lanes, {str(map_feature.id): new_ids} else: return road_lanes, None @@ -564,3 +616,38 @@ def translate_traffic_state( def interpolate_array(data: List) -> np.array: return pd.DataFrame(data).interpolate(limit_area="inside").to_numpy() + + +def sort_df_by_distance_to_ego_vehicle( + df: pd.DataFrame, + timestep: int, + ego_name: str = "ego", + keep_distance_column: bool = False, +): + """Sort df by distance to ego vehicle at timestep.""" + ego_xy = tuple(df.loc["ego", timestep][["x", "y"]]) + + def distance_to_ego_xy(x, y): + return np.linalg.norm([x, y] - np.array(ego_xy)) + + all_xy_positions = np.stack([df["x"], df["y"]], axis=-1) + + df["dist_to_ego"] = np.linalg.norm(all_xy_positions - np.array(ego_xy), axis=-1) + only_t0 = ( + df.groupby(level="agent_id") + .apply( + lambda x: ( + x[x.index.get_level_values("scene_ts") == timestep] + if timestep in x.index.get_level_values("scene_ts") + else x.iloc[[0]] + ) + ) + .droplevel(0) + ) + only_t0_sorted = only_t0.sort_values(by="dist_to_ego") + sorted_ids = list(only_t0_sorted.index.get_level_values("agent_id")) + sorted_df = df.reindex(labels=sorted_ids, level=0) + df.drop(columns="dist_to_ego", inplace=True) + if not keep_distance_column: + sorted_df.drop(columns="dist_to_ego", inplace=True) + return sorted_df diff --git a/src/trajdata/maps/map_api.py b/src/trajdata/maps/map_api.py index 6e58aab..7b52617 100644 --- a/src/trajdata/maps/map_api.py +++ b/src/trajdata/maps/map_api.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Optional, Union if TYPE_CHECKING: from trajdata.maps.map_kdtree import MapElementKDTree @@ -15,7 +15,7 @@ class MapAPI: - def __init__(self, unified_cache_path: Path, keep_in_memory: bool = False) -> None: + def __init__(self, unified_cache_path: Path, keep_in_memory: bool = False, data_dirs: Optional[Dict] = None) -> None: """A simple interface for loading trajdata's vector maps which does not require instantiation of a `UnifiedDataset` object. @@ -25,20 +25,53 @@ def __init__(self, unified_cache_path: Path, keep_in_memory: bool = False) -> No in memory (memoized) for later re-use. For most cases (e.g., batched dataloading), this is a good idea. However, this can cause rapid memory usage growth for some datasets (e.g., Waymo) and it can be better to disable this. Defaults to False. + data_dirs (Optional[Dict], optional): A dictionary of dataset names and their paths. """ self.unified_cache_path: Path = unified_cache_path self.maps: Dict[str, VectorMap] = dict() self._keep_in_memory = keep_in_memory + self.data_dirs = data_dirs + + def has_map(self, map_id: str) -> bool: + env_name, map_name = map_id.split(":") + vec_map_path: Path = ( + self.unified_cache_path / env_name / "maps" / f"{map_name}.pb" + ) + + if map_id in self.maps or Path.exists(vec_map_path): + return True + else: + return False + def get_map( - self, map_id: str, scene_cache: Optional[SceneCache] = None, **kwargs + self, map_id: str, scene_cache: Optional[SceneCache] = None, dataset_kwargs={},**kwargs ) -> VectorMap: if map_id not in self.maps: env_name, map_name = map_id.split(":") env_maps_path: Path = self.unified_cache_path / env_name / "maps" - stored_vec_map: VectorizedMap = map_utils.load_vector_map( - env_maps_path / f"{map_name}.pb" - ) + vec_map_path: Path = env_maps_path / f"{map_name}.pb" + + if not Path.exists(vec_map_path): + if self.data_dirs is None: + raise ValueError( + f"There is no cached map at {vec_map_path} and there was no " + + "`data_dirs` provided to rebuild cache.") + + # Rebuild maps by creating a dummy dataset object. + # TODO We need support for rebuilding map files only, without creating dataset and building agent data. + from trajdata.dataset import UnifiedDataset + dataset = UnifiedDataset( + desired_data=[env_name], + rebuild_cache=True, + rebuild_maps=True, + data_dirs=self.data_dirs, + cache_location=self.unified_cache_path, + verbose=True, + **dataset_kwargs, + ) + + stored_vec_map: VectorizedMap = map_utils.load_vector_map(vec_map_path) vec_map: VectorMap = VectorMap.from_proto(stored_vec_map, **kwargs) vec_map.search_kdtrees = map_utils.load_kdtrees( diff --git a/src/trajdata/maps/map_kdtree.py b/src/trajdata/maps/map_kdtree.py index c9f36d0..450425a 100644 --- a/src/trajdata/maps/map_kdtree.py +++ b/src/trajdata/maps/map_kdtree.py @@ -1,7 +1,7 @@ from __future__ import annotations from collections import defaultdict -from typing import TYPE_CHECKING, Dict +from typing import TYPE_CHECKING, Dict, Iterator if TYPE_CHECKING: from trajdata.maps.vec_map import VectorMap @@ -11,7 +11,6 @@ import numpy as np from scipy.spatial import KDTree from tqdm import tqdm - from trajdata.maps.vec_map_elements import MapElement, MapElementType, Polyline from trajdata.utils.arr_utils import angle_wrap @@ -43,8 +42,7 @@ def _build_kdtree(self, vector_map: VectorMap, verbose: bool = False): total=len(vector_map), disable=not verbose, ): - result = self._extract_points_and_metadata(map_elem) - if result is not None: + for result in self._extract_points_and_metadata(map_elem): points, extras = result polyline_inds.extend([len(polylines)] * points.shape[0]) @@ -54,6 +52,7 @@ def _build_kdtree(self, vector_map: VectorMap, verbose: bool = False): for k, v in extras.items(): metadata[k].append(v) + metadata["map_elem_id"].append(np.array([map_elem.id])) points = np.concatenate(polylines, axis=0) polyline_inds = np.array(polyline_inds) @@ -64,7 +63,7 @@ def _build_kdtree(self, vector_map: VectorMap, verbose: bool = False): def _extract_points_and_metadata( self, map_element: MapElement - ) -> Optional[Tuple[np.ndarray, Dict[str, np.ndarray]]]: + ) -> Iterator[Tuple[np.ndarray, dict[str, np.ndarray]]]: """Defines the coordinates we want to store in the KDTree for a MapElement. Args: map_element (MapElement): the MapElement to store in the KDTree. @@ -101,6 +100,33 @@ def polyline_inds_in_range(self, point: np.ndarray, range: float) -> np.ndarray: return np.unique(self.polyline_inds[data_inds], axis=0) +class RoadEdgeKDTree(MapElementKDTree): + """KDTree for lane center polylines.""" + + def __init__( + self, vector_map: VectorMap, max_segment_len: Optional[float] = None + ) -> None: + """ + Args: + vec_map: the VectorizedMap object to build the KDTree for + max_segment_len (float, optional): if specified, we will insert extra points into the KDTree + such that all polyline segments are shorter then max_segment_len. + """ + self.max_segment_len = max_segment_len + super().__init__(vector_map) + + def _extract_points_and_metadata( + self, map_element: MapElement + ) -> Iterator[Tuple[np.ndarray, dict[str, np.ndarray]]]: + if map_element.elem_type == MapElementType.ROAD_EDGE: + pts: Polyline = map_element.polyline + if self.max_segment_len is not None: + pts = pts.interpolate(max_dist=self.max_segment_len) + + # We only want to store xyz in the kdtree, not heading. + yield pts.xyz, {"heading": pts.h} + + class LaneCenterKDTree(MapElementKDTree): """KDTree for lane center polylines.""" @@ -118,16 +144,14 @@ def __init__( def _extract_points_and_metadata( self, map_element: MapElement - ) -> Optional[Tuple[np.ndarray, Dict[str, np.ndarray]]]: + ) -> Iterator[Tuple[np.ndarray, dict[str, np.ndarray]]]: if map_element.elem_type == MapElementType.ROAD_LANE: pts: Polyline = map_element.center if self.max_segment_len is not None: pts = pts.interpolate(max_dist=self.max_segment_len) # We only want to store xyz in the kdtree, not heading. - return pts.xyz, {"heading": pts.h} - else: - return None + yield pts.xyz, {"heading": pts.h} def current_lane_inds( self, @@ -183,3 +207,41 @@ def current_lane_inds( min_costs = [np.min(costs[lane_inds == ind]) for ind in unique_lane_inds] return unique_lane_inds[np.argsort(min_costs)] + + +class RoadAreaKDTree(MapElementKDTree): + """KDTree for road area polygons. + The polygons may have holes. We will simply store points along both the + exterior_polygon and all interior_holes. Finding a nearest point in this KDTree will + correspond to finding any + """ + + def __init__( + self, vector_map: VectorMap, max_segment_len: Optional[float] = None + ) -> None: + """ + Args: + vec_map: the VectorizedMap object to build the KDTree for + max_segment_len (float, optional): if specified, we will insert extra points into the KDTree + such that all polyline segments are shorter then max_segment_len. + """ + self.max_segment_len = max_segment_len + super().__init__(vector_map) + + def _extract_points_and_metadata( + self, map_element: MapElement + ) -> Iterator[Tuple[np.ndarray, dict[str, np.ndarray]]]: + if map_element.elem_type == MapElementType.ROAD_AREA: + # Exterior polygon + pts: Polyline = map_element.exterior_polygon + if self.max_segment_len is not None: + pts = pts.interpolate(max_dist=self.max_segment_len) + # We only want to store xyz in the kdtree, not heading. + yield pts.xyz, {"exterior": np.array([True])} + + # Interior holes + for pts in map_element.interior_holes: + if self.max_segment_len is not None: + pts = pts.interpolate(max_dist=self.max_segment_len) + # We only want to store xyz in the kdtree, not heading. + yield pts.xyz, {"exterior": np.array([False])} diff --git a/src/trajdata/maps/vec_map.py b/src/trajdata/maps/vec_map.py index 6b720cc..a6c2ebb 100644 --- a/src/trajdata/maps/vec_map.py +++ b/src/trajdata/maps/vec_map.py @@ -3,7 +3,11 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: - from trajdata.maps.map_kdtree import MapElementKDTree, LaneCenterKDTree + from trajdata.maps.map_kdtree import ( + MapElementKDTree, + LaneCenterKDTree, + RoadEdgeKDTree, + ) from trajdata.maps.map_strtree import MapElementSTRTree from collections import defaultdict @@ -24,11 +28,11 @@ import matplotlib as mpl import matplotlib.pyplot as plt import numpy as np +import trajdata.proto.vectorized_map_pb2 as map_proto from matplotlib.axes import Axes +from shapely.geometry import Polygon from tqdm import tqdm - -import trajdata.proto.vectorized_map_pb2 as map_proto -from trajdata.maps.map_kdtree import LaneCenterKDTree +from trajdata.maps.map_kdtree import LaneCenterKDTree, RoadAreaKDTree, RoadEdgeKDTree from trajdata.maps.map_strtree import MapElementSTRTree from trajdata.maps.traffic_light_status import TrafficLightStatus from trajdata.maps.vec_map_elements import ( @@ -38,23 +42,27 @@ PedWalkway, Polyline, RoadArea, + RoadEdge, RoadLane, + TrafficSign, + WaitLine, ) -from trajdata.utils import map_utils, raster_utils +from trajdata.utils import map_utils @dataclass(repr=False) class VectorMap: map_id: str - extent: Optional[ - np.ndarray - ] = None # extent is [min_x, min_y, min_z, max_x, max_y, max_z] + extent: Optional[np.ndarray] = ( + None # extent is [min_x, min_y, min_z, max_x, max_y, max_z] + ) elements: DefaultDict[MapElementType, Dict[str, MapElement]] = field( default_factory=lambda: defaultdict(dict) ) search_kdtrees: Optional[Dict[MapElementType, MapElementKDTree]] = None search_rtrees: Optional[Dict[MapElementType, MapElementSTRTree]] = None traffic_light_status: Optional[Dict[Tuple[str, int], TrafficLightStatus]] = None + online_metadict: Optional[Dict[Tuple[str, int], Dict]] = None def __post_init__(self) -> None: self.env_name, self.map_name = self.map_id.split(":") @@ -62,13 +70,22 @@ def __post_init__(self) -> None: self.lanes: Optional[List[RoadLane]] = None if MapElementType.ROAD_LANE in self.elements: self.lanes = list(self.elements[MapElementType.ROAD_LANE].values()) + self.road_edges: Optional[List[RoadEdge]] = None + if MapElementType.ROAD_EDGE in self.elements: + self.road_edges = list(self.elements[MapElementType.ROAD_EDGE].values()) def add_map_element(self, map_elem: MapElement) -> None: self.elements[map_elem.elem_type][map_elem.id] = map_elem def compute_search_indices(self) -> None: # TODO(bivanovic@nvidia.com): merge tree dicts? - self.search_kdtrees = {MapElementType.ROAD_LANE: LaneCenterKDTree(self)} + self.search_kdtrees = { + MapElementType.ROAD_LANE: LaneCenterKDTree(self), + } + if MapElementType.ROAD_EDGE in self.elements: + self.search_kdtrees[MapElementType.ROAD_EDGE] = RoadEdgeKDTree(self) + if MapElementType.ROAD_AREA in self.elements: + self.search_kdtrees[MapElementType.ROAD_AREA] = RoadAreaKDTree(self) self.search_rtrees = { elem_type: MapElementSTRTree(self, elem_type) for elem_type in [ @@ -86,6 +103,9 @@ def iter_elems(self) -> Iterator[MapElement]: def get_road_lane(self, lane_id: str) -> RoadLane: return self.elements[MapElementType.ROAD_LANE][lane_id] + def get_road_edge(self, lane_id: str) -> RoadEdge: + return self.elements[MapElementType.ROAD_EDGE][lane_id] + def __len__(self) -> int: return sum(len(elems_dict) for elems_dict in self.elements.values()) @@ -138,6 +158,45 @@ def _write_road_areas( shifted_origin, ) + def _write_road_edges( + self, vectorized_map: map_proto.VectorizedMap, shifted_origin: np.ndarray + ) -> None: + road_edge: RoadEdge + for elem_id, road_edge in self.elements[MapElementType.ROAD_EDGE].items(): + new_element: map_proto.MapElement = vectorized_map.elements.add() + new_element.id = elem_id.encode() + + new_edge: map_proto.RoadEdge = new_element.road_edge + map_utils.populate_road_edge_polylines(new_edge, road_edge, shifted_origin) + + def _write_traffic_signs( + self, vectorized_map: map_proto.VectorizedMap, shifted_origin: np.ndarray + ) -> None: + traffic_sign: TrafficSign + for elem_id, traffic_sign in self.elements[MapElementType.TRAFFIC_SIGN].items(): + new_element: map_proto.MapElement = vectorized_map.elements.add() + new_element.id = elem_id.encode() + + new_traffic_sign: map_proto.TrafficSign = new_element.traffic_sign + shifted_position: np.ndarray = traffic_sign.position - shifted_origin + new_traffic_sign.position.x = shifted_position[0] + new_traffic_sign.position.y = shifted_position[1] + new_traffic_sign.position.z = shifted_position[2] + new_traffic_sign.sign_type = traffic_sign.sign_type.encode() + + def _write_wait_lines( + self, vectorized_map: map_proto.VectorizedMap, shifted_origin: np.ndarray + ) -> None: + wait_line: WaitLine + for elem_id, wait_line in self.elements[MapElementType.WAIT_LINE].items(): + new_element: map_proto.MapElement = vectorized_map.elements.add() + new_element.id = elem_id.encode() + + new_wait_line: map_proto.WaitLine = new_element.wait_line + map_utils.populate_polygon(new_wait_line.polyline, wait_line.polyline.xyz, shifted_origin) + new_wait_line.wait_line_type = wait_line.wait_line_type.encode() + new_wait_line.is_implicit = wait_line.is_implicit + def _write_ped_crosswalks( self, vectorized_map: map_proto.VectorizedMap, shifted_origin: np.ndarray ) -> None: @@ -195,6 +254,9 @@ def to_proto(self) -> map_proto.VectorizedMap: self._write_road_areas(output_map, shifted_origin) self._write_ped_crosswalks(output_map, shifted_origin) self._write_ped_walkways(output_map, shifted_origin) + self._write_road_edges(output_map, shifted_origin) + self._write_traffic_signs(output_map, shifted_origin) + self._write_wait_lines(output_map, shifted_origin) return output_map @@ -205,6 +267,9 @@ def from_proto(cls, vec_map: map_proto.VectorizedMap, **kwargs): incl_road_areas: bool = kwargs.get("incl_road_areas", False) incl_ped_crosswalks: bool = kwargs.get("incl_ped_crosswalks", False) incl_ped_walkways: bool = kwargs.get("incl_ped_walkways", False) + incl_road_edges: bool = kwargs.get("incl_road_edges", False) + incl_traffic_signs: bool = kwargs.get("incl_traffic_signs", False) + incl_wait_lines: bool = kwargs.get("incl_wait_lines", False) # Add any map offset in case the map origin was shifted for storage efficiency. shifted_origin: np.ndarray = np.array( @@ -248,6 +313,19 @@ def from_proto(cls, vec_map: map_proto.VectorizedMap, **kwargs): ) + shifted_origin[:3] ) + + traffic_sign_ids: Optional[Set[str]] = None + if len(road_lane_obj.traffic_sign_ids) > 0: + traffic_sign_ids = set( + [iden.decode() for iden in road_lane_obj.traffic_sign_ids] + ) + + wait_line_ids: Optional[Set[str]] = None + if len(road_lane_obj.wait_line_ids) > 0: + wait_line_ids = set( + [iden.decode() for iden in road_lane_obj.wait_line_ids] + ) + adj_lanes_left: Set[str] = set( [iden.decode() for iden in road_lane_obj.adjacent_lanes_left] @@ -270,6 +348,8 @@ def from_proto(cls, vec_map: map_proto.VectorizedMap, **kwargs): center_pl, left_pl, right_pl, + traffic_sign_ids, + wait_line_ids, adj_lanes_left, adj_lanes_right, next_lanes, @@ -322,6 +402,42 @@ def from_proto(cls, vec_map: map_proto.VectorizedMap, **kwargs): curr_area = PedWalkway(elem_id, polygon_vertices) map_elem_dict[MapElementType.PED_WALKWAY][elem_id] = curr_area + elif incl_road_edges and map_elem.HasField("road_edge"): + road_edge_obj: map_proto.RoadEdge = map_elem.road_edge + + polyline: Polyline = Polyline( + map_utils.proto_to_np(road_edge_obj.polyline) + shifted_origin + ) + + curr_edge = RoadEdge(elem_id, polyline) + map_elem_dict[MapElementType.ROAD_EDGE][elem_id] = curr_edge + + elif incl_traffic_signs and map_elem.HasField("traffic_sign"): + traffic_sign_obj: map_proto.TrafficSign = map_elem.traffic_sign + + position: np.ndarray = np.array( + [traffic_sign_obj.position.x, + traffic_sign_obj.position.y, + traffic_sign_obj.position.z,] + ) + shifted_origin[:3] + + curr_traffic_sign = TrafficSign( + elem_id, position, traffic_sign_obj.sign_type + ) + map_elem_dict[MapElementType.TRAFFIC_SIGN][elem_id] = curr_traffic_sign + + elif incl_wait_lines and map_elem.HasField("wait_line"): + wait_line_obj: map_proto.WaitLine = map_elem.wait_line + + polyline: Polyline = Polyline( + map_utils.proto_to_np(wait_line_obj.polyline, incl_heading=False) + shifted_origin[:3] + ) + + curr_wait_line = WaitLine( + elem_id, polyline, wait_line_obj.wait_line_type, wait_line_obj.is_implicit + ) + map_elem_dict[MapElementType.WAIT_LINE][elem_id] = curr_wait_line + return cls( map_id=vec_map.name, extent=np.array( @@ -375,6 +491,13 @@ def get_closest_lane(self, xyz: np.ndarray) -> RoadLane: lane_kdtree: LaneCenterKDTree = self.search_kdtrees[MapElementType.ROAD_LANE] return self.lanes[lane_kdtree.closest_polyline_ind(xyz)] + def get_closest_road_edge(self, xyz: np.ndarray) -> RoadLane: + assert ( + self.search_kdtrees is not None + ), "Search KDTree not found, please rebuild cache." + road_edge_kdtree: RoadEdgeKDTree = self.search_kdtrees[MapElementType.ROAD_EDGE] + return self.road_edges[road_edge_kdtree.closest_polyline_ind(xyz)] + def get_closest_unique_lanes(self, xyz_vec: np.ndarray) -> List[RoadLane]: assert ( self.search_kdtrees is not None @@ -394,6 +517,16 @@ def get_lanes_within(self, xyz: np.ndarray, dist: float) -> List[RoadLane]: self.lanes[idx] for idx in lane_kdtree.polyline_inds_in_range(xyz, dist) ] + def get_road_edges_within(self, xyz: np.ndarray, dist: float) -> List[RoadEdge]: + assert ( + self.search_kdtrees is not None + ), "Search KDTree not found, please rebuild cache." + road_edge_kdtree: RoadEdgeKDTree = self.search_kdtrees[MapElementType.ROAD_EDGE] + return [ + self.road_edges[idx] + for idx in road_edge_kdtree.polyline_inds_in_range(xyz, dist) + ] + def get_closest_area( self, xy: np.ndarray, elem_type: MapElementType ) -> Union[RoadArea, PedCrosswalk, PedWalkway]: @@ -435,6 +568,31 @@ def get_areas_within( ) return [self.elements[elem_type][id] for id in ids] + def get_road_areas_within(self, xyz: np.ndarray, dist: float) -> List[RoadArea]: + road_area_kdtree: RoadAreaKDTree = self.search_kdtrees[MapElementType.ROAD_AREA] + polyline_inds = road_area_kdtree.polyline_inds_in_range(xyz, dist) + element_ids = set( + [road_area_kdtree.metadata["map_elem_id"][ind] for ind in polyline_inds] + ) + if MapElementType.ROAD_AREA not in self.elements: + raise ValueError( + "Road areas are not loaded. Use map_api.get_map(..., incl_road_areas=True)." + ) + return [self.elements[MapElementType.ROAD_AREA][id] for id in element_ids] + + def get_road_area_polygon_2d(self, id: str) -> Polygon: + if id not in self._road_area_polygons: + road_area: RoadArea = self.elements[MapElementType.ROAD_AREA][id] + road_area_polygon = Polygon( + shell=[(pt[0], pt[1]) for pt in road_area.exterior_polygon.points], + holes=[ + [(pt[0], pt[1]) for pt in polyline.points] + for polyline in road_area.interior_holes + ], + ) + self._road_area_polygons[id] = road_area_polygon + return self._road_area_polygons[id] + def get_traffic_light_status( self, lane_id: str, scene_ts: int ) -> TrafficLightStatus: @@ -445,6 +603,91 @@ def get_traffic_light_status( if self.traffic_light_status is not None else TrafficLightStatus.NO_DATA ) + + def has_stop_sign( + self, lane_id: str + ) -> bool: + traffic_sign_ids = self.get_road_lane(lane_id).traffic_sign_ids + if traffic_sign_ids is not None: + for sign_id in traffic_sign_ids: + if self.elements[MapElementType.TRAFFIC_SIGN][sign_id].sign_type == "TRAFFIC_SIGN_REGULATORY_R1_STOP": + return True + return False + + def get_wait_lines(self, lane_id: str) -> List[WaitLine]: + # Get waitlines for a lane + wait_line_ids = self.get_road_lane(lane_id).wait_line_ids + wait_lines: List[WaitLine] = [] + if wait_line_ids is not None: + wait_lines: List[WaitLine] = [ + self.elements[MapElementType.WAIT_LINE][wait_line_id] + for wait_line_id in wait_line_ids + ] + return wait_lines + + def get_traffic_signs(self, lane_id: str) -> List[TrafficSign]: + # Get traffic signs for a lane + traffic_sign_ids = self.get_road_lane(lane_id).traffic_sign_ids + traffic_signs: List[TrafficSign] = [] + if traffic_sign_ids is not None: + traffic_signs: List[TrafficSign] = [ + self.elements[MapElementType.TRAFFIC_SIGN][traffic_sign_id] + for traffic_sign_id in traffic_sign_ids + ] + + return traffic_signs + + def associate_traffic_sign_with_wait_line(self, lane_id: str) -> List[Tuple[TrafficSign, WaitLine]]: + """Associate traffic signs with wait lines for a lane using nearest distance. + + Args: + lane_id (str): lane id + + Returns: + List[Tuple[TrafficSign, WaitLine]]: List of (traffic_sign, wait_line) tuples that are associated. + """ + + # Get waitlines and traffic signs + wait_lines: List[WaitLine] = self.get_wait_lines(lane_id) + traffic_signs: List[TrafficSign] = self.get_traffic_signs(lane_id) + + # Associate traffic signs with wait lines + traffic_sign_wait_line_associations: List[Tuple[TrafficSign, WaitLine]] = [] + for traffic_sign in traffic_signs: + smallest_distance_to_wait_line = np.inf + nearest_wait_line = None + for wait_line in wait_lines: + distance_to_wait_line = wait_line.polyline.distance_to_point( + traffic_sign.position[None,:] + ) + if distance_to_wait_line < smallest_distance_to_wait_line: + smallest_distance_to_wait_line = distance_to_wait_line + nearest_wait_line = wait_line + if nearest_wait_line is not None: + traffic_sign_wait_line_associations.append( + (traffic_sign, nearest_wait_line) + ) + return traffic_sign_wait_line_associations + + def is_stop_sign(self, traffic_sign: TrafficSign) -> bool: + if traffic_sign.sign_type == "TRAFFIC_SIGN_REGULATORY_R1_STOP": + return True + return False + + def get_wait_line( + self, lane_id: str + ) -> Optional[WaitLine]: + wait_line_ids = self.get_road_lane(lane_id).wait_line_ids + if len(wait_line_ids) == 0: + return None + return self.elements[MapElementType.WAIT_LINE][wait_line_ids.pop()] + + def get_online_metadict(self, lane_id: str, scene_ts: int = 0) -> Dict: + return ( + self.online_metadict[(str(lane_id), scene_ts)] + if self.online_metadict is not None + else {} + ) def rasterize( self, resolution: float = 2, **kwargs @@ -457,110 +700,7 @@ def rasterize( Returns: np.ndarray: The rasterized RGB image. """ - return_tf_mat: bool = kwargs.get("return_tf_mat", False) - incl_centerlines: bool = kwargs.get("incl_centerlines", True) - incl_lane_edges: bool = kwargs.get("incl_lane_edges", True) - incl_lane_area: bool = kwargs.get("incl_lane_area", True) - - scene_ts: Optional[int] = kwargs.get("scene_ts", None) - - # (255, 102, 99) also looks nice. - center_color: Tuple[int, int, int] = kwargs.get("center_color", (129, 51, 255)) - # (86, 203, 249) also looks nice. - edge_color: Tuple[int, int, int] = kwargs.get("edge_color", (118, 185, 0)) - # (191, 215, 234) also looks nice. - area_color: Tuple[int, int, int] = kwargs.get("area_color", (214, 232, 181)) - - min_x, min_y, _, max_x, max_y, _ = self.extent - - world_center_m: Tuple[float, float] = ( - (max_x + min_x) / 2, - (max_y + min_y) / 2, - ) - - raster_size_x: int = ceil((max_x - min_x) * resolution) - raster_size_y: int = ceil((max_y - min_y) * resolution) - - raster_from_local: np.ndarray = np.array( - [ - [resolution, 0, raster_size_x / 2], - [0, resolution, raster_size_y / 2], - [0, 0, 1], - ] - ) - - # Compute pose from its position and rotation. - pose_from_world: np.ndarray = np.array( - [ - [1, 0, -world_center_m[0]], - [0, 1, -world_center_m[1]], - [0, 0, 1], - ] - ) - - raster_from_world: np.ndarray = raster_from_local @ pose_from_world - - map_img: np.ndarray = np.zeros( - shape=(raster_size_y, raster_size_x, 3), dtype=np.uint8 - ) - - lane_edges: List[np.ndarray] = list() - centerlines: List[np.ndarray] = list() - lane: RoadLane - for lane in tqdm( - self.elements[MapElementType.ROAD_LANE].values(), - desc=f"Rasterizing Map at {resolution:.2f} px/m", - leave=False, - ): - centerlines.append( - raster_utils.world_to_subpixel( - lane.center.points[:, :2], raster_from_world - ) - ) - if lane.left_edge is not None and lane.right_edge is not None: - left_pts: np.ndarray = lane.left_edge.points[:, :2] - right_pts: np.ndarray = lane.right_edge.points[:, :2] - - lane_edges += [ - raster_utils.world_to_subpixel(left_pts, raster_from_world), - raster_utils.world_to_subpixel(right_pts, raster_from_world), - ] - - lane_color = area_color - status = self.get_traffic_light_status(lane.id, scene_ts) - if status == TrafficLightStatus.GREEN: - lane_color = [0, 200, 0] - elif status == TrafficLightStatus.RED: - lane_color = [200, 0, 0] - elif status == TrafficLightStatus.UNKNOWN: - lane_color = [150, 150, 0] - - # Drawing lane areas. Need to do per loop because doing it all at once can - # create lots of wonky holes in the image. - # See https://stackoverflow.com/questions/69768620/cv2-fillpoly-failing-for-intersecting-polygons - if incl_lane_area: - lane_area: np.ndarray = np.concatenate( - [left_pts, right_pts[::-1]], axis=0 - ) - raster_utils.rasterize_world_polygon( - lane_area, - map_img, - raster_from_world, - color=lane_color, - ) - - # Drawing all lane edge lines at the same time. - if incl_lane_edges: - raster_utils.cv2_draw_polylines(lane_edges, map_img, color=edge_color) - - # Drawing centerlines last (on top of everything else). - if incl_centerlines: - raster_utils.cv2_draw_polylines(centerlines, map_img, color=center_color) - - if return_tf_mat: - return map_img.astype(float) / 255, raster_from_world - else: - return map_img.astype(float) / 255 + raise NotImplementedError() @overload def visualize_lane_graph( @@ -568,16 +708,17 @@ def visualize_lane_graph( origin_lane: RoadLane, num_hops: int, **kwargs, - ) -> Axes: - ... + ) -> Axes: ... @overload - def visualize_lane_graph(self, origin_lane: str, num_hops: int, **kwargs) -> Axes: - ... + def visualize_lane_graph( + self, origin_lane: str, num_hops: int, **kwargs + ) -> Axes: ... @overload - def visualize_lane_graph(self, origin_lane: int, num_hops: int, **kwargs) -> Axes: - ... + def visualize_lane_graph( + self, origin_lane: int, num_hops: int, **kwargs + ) -> Axes: ... def visualize_lane_graph( self, origin_lane: Union[RoadLane, str, int], num_hops: int, **kwargs @@ -646,3 +787,53 @@ def visualize_lane_graph( ax.legend(loc="best", frameon=True) return ax + + +def split_lane_segments(lane, n=None, max_len=None): + if n is None: + length = np.linalg.norm(lane.center.xy[1:] - lane.center.xy[:-1], axis=-1).sum() + n = ceil(length / max_len) + if n == 1: + return [lane] + idx = np.linspace(0, lane.center.xy.shape[0] - 1, n + 1).astype(int) + left_idx = ( + np.linspace(0, lane.left_edge.xy.shape[0] - 1, n + 1).astype(int) + if lane.left_edge is not None + else None + ) + right_idx = ( + np.linspace(0, lane.right_edge.xy.shape[0] - 1, n + 1).astype(int) + if lane.right_edge is not None + else None + ) + + split_lanes = list() + for i in range(n): + center = Polyline(points=lane.center.points[idx[i] : idx[i + 1] + 1]) + left_edge = ( + Polyline(lane.left_edge.points[left_idx[i] : left_idx[i + 1] + 1]) + if lane.left_edge is not None + else None + ) + right_edge = ( + Polyline(lane.right_edge.points[right_idx[i] : right_idx[i + 1] + 1]) + if lane.right_edge is not None + else None + ) + + new_lane = RoadLane( + center=center, + left_edge=left_edge, + right_edge=right_edge, + adj_lanes_left=lane.adj_lanes_left, + adj_lanes_right=lane.adj_lanes_right, + next_lanes=lane.next_lanes if i == n - 1 else set(), + prev_lanes=lane.prev_lanes if i == 0 else {split_lanes[i - 1].id}, + road_area_ids=lane.road_area_ids, + elem_type=lane.elem_type, + id=lane.id + f"_{i}", + ) + if i > 0: + split_lanes[i - 1].next_lanes = {new_lane.id} + split_lanes.append(new_lane) + return split_lanes diff --git a/src/trajdata/maps/vec_map_elements.py b/src/trajdata/maps/vec_map_elements.py index fd7a9f1..c5f57a8 100644 --- a/src/trajdata/maps/vec_map_elements.py +++ b/src/trajdata/maps/vec_map_elements.py @@ -1,10 +1,10 @@ from dataclasses import dataclass, field from enum import IntEnum -from typing import List, Optional, Set +from typing import List, Optional, Set, Union import numpy as np - from trajdata.utils import map_utils +from trajdata.utils.arr_utils import angle_wrap class MapElementType(IntEnum): @@ -12,6 +12,9 @@ class MapElementType(IntEnum): ROAD_AREA = 2 PED_CROSSWALK = 3 PED_WALKWAY = 4 + ROAD_EDGE = 5 + TRAFFIC_SIGN = 6 + WAIT_LINE = 7 @dataclass @@ -47,6 +50,15 @@ def xy(self) -> np.ndarray: def xyz(self) -> np.ndarray: return self.points[..., :3] + @property + def xyh(self) -> np.ndarray: + if self.has_heading: + return self.points[..., (0, 1, 3)] + else: + raise ValueError( + f"This Polyline only has {self.points.shape[-1]} coordinates, expected 4." + ) + @property def xyzh(self) -> np.ndarray: if self.has_heading: @@ -67,14 +79,17 @@ def interpolate( map_utils.interpolate(self.points, num_pts=num_pts, max_dist=max_dist) ) - def project_onto(self, xyz_or_xyzh: np.ndarray) -> np.ndarray: + def project_onto(self, xyz_or_xyzh: np.ndarray, return_index: bool = False) -> Union[np.ndarray, List]: """Project the given points onto this Polyline. Args: xyzh (np.ndarray): Points to project, of shape (M, D) + return_indices (bool): Return the index of starting point of the line segment + on which the projected points lies on. Returns: np.ndarray: The projected points, of shape (M, D) + np.ndarray: The index of previous polyline points if return_indices == True. Note: D = 4 if this Polyline has headings, otherwise D = 3 @@ -94,7 +109,7 @@ def project_onto(self, xyz_or_xyzh: np.ndarray) -> np.ndarray: dot_products: np.ndarray = (point_seg_diffs * line_seg_diffs).sum( axis=-1, keepdims=True ) - norms: np.ndarray = np.linalg.norm(line_seg_diffs, axis=-1, keepdims=True) ** 2 + norms: np.ndarray = np.square(line_seg_diffs).sum(axis=-1, keepdims=True) # Clip ensures that the projected point stays within the line segment boundaries. projs: np.ndarray = ( @@ -102,20 +117,114 @@ def project_onto(self, xyz_or_xyzh: np.ndarray) -> np.ndarray: ) # 2. Find the nearest projections to the original points. - closest_proj_idxs: int = np.linalg.norm(xyz - projs, axis=-1).argmin(axis=-1) + # We have nan values when two consecutive points are equal. This will never be + # the closest projection point, so we replace nans with a large number. + point_to_proj_dist = np.nan_to_num(np.linalg.norm(xyz - projs, axis=-1), nan=1e6) + closest_proj_idxs: int = point_to_proj_dist.argmin(axis=-1) + + proj_points = projs[range(xyz.shape[0]), closest_proj_idxs] if self.has_heading: # Adding in the heading of the corresponding p0 point (which makes # sense as p0 to p1 is a line => same heading along it). - return np.concatenate( + proj_points = np.concatenate( [ - projs[range(xyz.shape[0]), closest_proj_idxs], + proj_points, np.expand_dims(self.points[closest_proj_idxs, -1], axis=-1), ], axis=-1, ) + + if return_index: + return proj_points, closest_proj_idxs else: - return projs[range(xyz.shape[0]), closest_proj_idxs] + return proj_points + + def distance_to_point(self, xyz: np.ndarray): + assert xyz.ndim == 2 + xyz_proj = self.project_onto(xyz) + return np.linalg.norm(xyz[..., :3] - xyz_proj[..., :3], axis=-1) + + def get_length(self): + # TODO we could store cummulative distances to speed this up + dists = np.linalg.norm(self.xyz[1:, :3] - self.xyz[:-1, :3], axis=-1) + length = dists.sum() + return length + + + def get_length_from(self, start_ind: np.ndarray): + # TODO we could store cummulative distances to speed this up + assert start_ind.ndim == 1 + dists = np.linalg.norm(self.xyz[1:, :3] - self.xyz[:-1, :3], axis=-1) + length_upto = np.cumsum(np.pad(dists, (1, 0))) + length_from = length_upto[-1][None] - length_upto[start_ind] + return length_from + + + def traverse_along(self, dist: np.ndarray, start_ind: Optional[np.ndarray] = None) -> np.ndarray: + """ + Interpolated endpoint of traversing `dist` distance along polyline from a starting point. + + Returns nan if the end point is not inside the polyline. + TODO we could store cummulative distances to speed this up + + Args: + dist (np.ndarray): distances, any shape [...] + start_ind (np.ndarray): index of point along polyline to calcualte distance from. + Optional. Shape must match dist. [...] + + Returns: + endpoint_xyzh (np.ndarray): points along polyline `dist` distance from the + starting point. Nan if endpoint would require extrapolation. [..., 4] + + """ + assert self.has_heading + + # Add up distances from beginning of polyline + segment_lens = np.linalg.norm(self.xyz[1:] - self.xyz[:-1], axis=-1) # n-1 + cum_len = np.pad(np.cumsum(segment_lens, axis=0), (1, 0)) # n + + # Increase dist with the length of lane up to start_ind + if start_ind is not None: + assert start_ind.ndim == dist.ndim + dist = dist + cum_len[start_ind] + + # Find the first index where cummulative length is larger or equal than `dist` + inds = np.searchsorted(cum_len, dist, side='right') + # Invalidate inds == 0 and inds == len(cum_len), which means endpoint is outside the polyline. + invalid = np.logical_or(inds == 0, inds == len(cum_len)) + # Replace invalid indices so we can easily carry out computation below, and invalidate output eventually. + inds[invalid] = 1 + + # Remaining distance from last point + remaining_dist = dist - cum_len[inds-1] + + # Invalidate negative remaining dist (this should only happen when dist < 0) + invalid = np.logical_or(invalid, remaining_dist < 0.) + + # Interpolate between the previous and next points. + segment_vect_xyz = self.xyz[inds] - self.xyz[inds-1] + segment_len = np.linalg.norm(segment_vect_xyz, axis=-1) + assert (segment_len > 0.).all(), "Polyline segment has zero length" + + proportion = (remaining_dist / segment_len) + endpoint_xyz = segment_vect_xyz * proportion[..., np.newaxis] + self.xyz[inds] + endpoint_h = angle_wrap(angle_wrap(self.h[inds] - self.h[inds-1]) * proportion + self.h[inds-1]) + endpoint_xyzh = np.concatenate((endpoint_xyz, endpoint_h[..., np.newaxis]), axis=-1) + + # Invalidate dummy output + endpoint_xyzh[invalid] = np.nan + + return endpoint_xyzh + + def concatenate_with(self, other: "Polyline") -> "Polyline": + return self.concatenate([self, other]) + + @staticmethod + def concatenate(polylines: List["Polyline"]) -> "Polyline": + # Assumes no overlap between consecutive polylines, i.e. next lane starts after current lane ends. + points = np.concatenate([polyline.points for polyline in polylines], axis=0) + return Polyline(points) @dataclass @@ -128,10 +237,13 @@ class RoadLane(MapElement): center: Polyline left_edge: Optional[Polyline] = None right_edge: Optional[Polyline] = None + traffic_sign_ids: Optional[Set[str]] = None + wait_line_ids: Optional[Set[str]] = None adj_lanes_left: Set[str] = field(default_factory=lambda: set()) adj_lanes_right: Set[str] = field(default_factory=lambda: set()) next_lanes: Set[str] = field(default_factory=lambda: set()) prev_lanes: Set[str] = field(default_factory=lambda: set()) + road_area_ids: Set[str] = field(default_factory=lambda: set()) elem_type: MapElementType = MapElementType.ROAD_LANE def __post_init__(self) -> None: @@ -151,6 +263,32 @@ def __hash__(self) -> int: def reachable_lanes(self) -> Set[str]: return self.adj_lanes_left | self.adj_lanes_right | self.next_lanes + def combine_next(self, next_lane): + assert next_lane.id in self.next_lanes + self.next_lanes.remove(next_lane.id) + self.next_lanes = self.next_lanes.union(next_lane.next_lanes) + self.center = self.center.concatenate_with(next_lane.center) + if self.left_edge is not None and next_lane.left_edge is not None: + self.left_edge = self.left_edge.concatenate_with(next_lane.left_edge) + if self.right_edge is not None and next_lane.right_edge is not None: + self.right_edge = self.right_edge.concatenate_with(next_lane.right_edge) + self.adj_lanes_right = self.adj_lanes_right.union(next_lane.adj_lanes_right) + self.adj_lanes_left = self.adj_lanes_left.union(next_lane.adj_lanes_left) + self.road_area_ids = self.road_area_ids.union(next_lane.road_area_ids) + + def combine_prev(self,prev_lane): + assert prev_lane.id in self.prev_lanes + self.prev_lanes.remove(prev_lane.id) + self.prev_lanes = self.prev_lanes.union(prev_lane.prev_lanes) + self.center = prev_lane.center.concatenate_with(self.center) + if self.left_edge is not None and prev_lane.left_edge is not None: + self.left_edge = prev_lane.left_edge.concatenate_with(self.left_edge) + if self.right_edge is not None and prev_lane.right_edge is not None: + self.right_edge = prev_lane.right_edge.concatenate_with(self.right_edge) + self.adj_lanes_right = self.adj_lanes_right.union(prev_lane.adj_lanes_right) + self.adj_lanes_left = self.adj_lanes_left.union(prev_lane.adj_lanes_left) + self.road_area_ids = self.road_area_ids.union(prev_lane.road_area_ids) + @dataclass class RoadArea(MapElement): @@ -169,3 +307,37 @@ class PedCrosswalk(MapElement): class PedWalkway(MapElement): polygon: Polyline elem_type: MapElementType = MapElementType.PED_WALKWAY + + +@dataclass +class RoadEdge(MapElement): + polyline: Polyline + elem_type: MapElementType = MapElementType.ROAD_EDGE + + def __post_init__(self) -> None: + if not self.polyline.has_heading: + self.polyline = Polyline( + np.append( + self.polyline.xyz, + map_utils.get_polyline_headings(self.polyline.xyz), + axis=-1, + ) + ) + + def __hash__(self) -> int: + return hash(self.id) + + +@dataclass +class TrafficSign(MapElement): + position: np.ndarray + sign_type: str + elem_type: MapElementType = MapElementType.TRAFFIC_SIGN + +@dataclass +class WaitLine(MapElement): + polyline: Polyline + wait_line_type: str + is_implicit: bool + elem_type: MapElementType = MapElementType.WAIT_LINE + diff --git a/src/trajdata/parallel/data_preprocessor.py b/src/trajdata/parallel/data_preprocessor.py index 6c61353..fa44e8e 100644 --- a/src/trajdata/parallel/data_preprocessor.py +++ b/src/trajdata/parallel/data_preprocessor.py @@ -23,7 +23,7 @@ def __init__( cache_class: Type[SceneCache], rebuild_cache: bool, ) -> None: - self.env_cache_path = np.array(env_cache_path).astype(np.string_) + self.env_cache_path = np.array(env_cache_path).astype(np.bytes_) self.desired_dt = desired_dt self.cache_class = cache_class self.rebuild_cache = rebuild_cache @@ -43,9 +43,9 @@ def __init__( ) self.scene_name_idxs = np.array(scene_name_idxs, dtype=int) - self.env_names_arr = np.array(env_names).astype(np.string_) - self.scene_names_arr = np.array(scene_names).astype(np.string_) - self.data_dir_arr = np.array(list(envs_dir_dict.values())).astype(np.string_) + self.env_names_arr = np.array(env_names).astype(np.bytes_) + self.scene_names_arr = np.array(scene_names).astype(np.bytes_) + self.data_dir_arr = np.array(list(envs_dir_dict.values())).astype(np.bytes_) self.data_len: int = len(scene_info_list) diff --git a/src/trajdata/proto/vectorized_map.proto b/src/trajdata/proto/vectorized_map.proto index e1cd502..505c957 100644 --- a/src/trajdata/proto/vectorized_map.proto +++ b/src/trajdata/proto/vectorized_map.proto @@ -29,6 +29,9 @@ message MapElement { RoadArea road_area = 3; PedCrosswalk ped_crosswalk = 4; PedWalkway ped_walkway = 5; + RoadEdge road_edge = 6; + TrafficSign traffic_sign = 7; + WaitLine wait_line = 8; } } @@ -39,14 +42,14 @@ message Point { } message Polyline { - // Position deltas in millimeters. The origin is an arbitrary location. - // From https://github.com/woven-planet/l5kit/blob/master/l5kit/l5kit/data/proto/road_network.proto#L446 + // Position deltas in 10^-5 meters. The origin is an arbitrary location. + // Inspired by https://github.com/woven-planet/l5kit/blob/master/l5kit/l5kit/data/proto/road_network.proto#L446 // The delta for the first point is just its coordinates tuple, i.e. it is a "delta" from // the origin. For subsequent points, this field stores the difference between the point's // coordinates and the previous point's coordinates. This is for representation efficiency. - repeated sint32 dx_mm = 1; - repeated sint32 dy_mm = 2; - repeated sint32 dz_mm = 3; + repeated sint64 dx_mm = 1; + repeated sint64 dy_mm = 2; + repeated sint64 dz_mm = 3; repeated double h_rad = 4; } @@ -74,6 +77,16 @@ message RoadLane { // A list of neighbors to the right of this lane. Neighbor lanes // include only adjacent lanes going the same direction. repeated bytes adjacent_lanes_right = 7; + + // A list of associated road area ids. + repeated bytes road_area_ids = 8; + + // A list of associated traffic sign ids. + repeated bytes traffic_sign_ids = 9; + + // A list of associated wait line ids. + repeated bytes wait_line_ids = 10; + } message RoadArea { @@ -98,4 +111,26 @@ message PedWalkway { // The polygon is assumed to be closed (i.e. a segment exists between the last // point and the first point). Polyline polygon = 1; +} + +message RoadEdge { + // The polyline data for the edge of the road. This is used to define the + // boundary of the road network. + Polyline polyline = 1; +} + +message TrafficSign { + // Position of the traffic sign + Point position = 1; + // Type of the traffic sign + string sign_type = 2; +} + +message WaitLine { + // The polyline data for the wait line + Polyline polyline = 1; + // Type of the wait line. e.g., stop line + string wait_line_type = 2; + // Is the waitline explicit or implicit + bool is_implicit = 3; } \ No newline at end of file diff --git a/src/trajdata/proto/vectorized_map_pb2.py b/src/trajdata/proto/vectorized_map_pb2.py index acdd9cc..cb67f79 100644 --- a/src/trajdata/proto/vectorized_map_pb2.py +++ b/src/trajdata/proto/vectorized_map_pb2.py @@ -1,135 +1,46 @@ # -*- coding: utf-8 -*- # Generated by the protocol buffer compiler. DO NOT EDIT! # source: vectorized_map.proto +# Protobuf Python Version: 4.25.1 """Generated protocol buffer code.""" from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor_pool as _descriptor_pool -from google.protobuf import message as _message -from google.protobuf import reflection as _reflection from google.protobuf import symbol_database as _symbol_database - +from google.protobuf.internal import builder as _builder # @@protoc_insertion_point(imports) _sym_db = _symbol_database.Default() -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x14vectorized_map.proto\x12\x08trajdata"\xb0\x01\n\rVectorizedMap\x12\x0c\n\x04name\x18\x01 \x01(\t\x12&\n\x08\x65lements\x18\x02 \x03(\x0b\x32\x14.trajdata.MapElement\x12\x1f\n\x06max_pt\x18\x03 \x01(\x0b\x32\x0f.trajdata.Point\x12\x1f\n\x06min_pt\x18\x04 \x01(\x0b\x32\x0f.trajdata.Point\x12\'\n\x0eshifted_origin\x18\x05 \x01(\x0b\x32\x0f.trajdata.Point"\xd8\x01\n\nMapElement\x12\n\n\x02id\x18\x01 \x01(\x0c\x12\'\n\troad_lane\x18\x02 \x01(\x0b\x32\x12.trajdata.RoadLaneH\x00\x12\'\n\troad_area\x18\x03 \x01(\x0b\x32\x12.trajdata.RoadAreaH\x00\x12/\n\rped_crosswalk\x18\x04 \x01(\x0b\x32\x16.trajdata.PedCrosswalkH\x00\x12+\n\x0bped_walkway\x18\x05 \x01(\x0b\x32\x14.trajdata.PedWalkwayH\x00\x42\x0e\n\x0c\x65lement_data"(\n\x05Point\x12\t\n\x01x\x18\x01 \x01(\x01\x12\t\n\x01y\x18\x02 \x01(\x01\x12\t\n\x01z\x18\x03 \x01(\x01"F\n\x08Polyline\x12\r\n\x05\x64x_mm\x18\x01 \x03(\x11\x12\r\n\x05\x64y_mm\x18\x02 \x03(\x11\x12\r\n\x05\x64z_mm\x18\x03 \x03(\x11\x12\r\n\x05h_rad\x18\x04 \x03(\x01"\x98\x02\n\x08RoadLane\x12"\n\x06\x63\x65nter\x18\x01 \x01(\x0b\x32\x12.trajdata.Polyline\x12.\n\rleft_boundary\x18\x02 \x01(\x0b\x32\x12.trajdata.PolylineH\x00\x88\x01\x01\x12/\n\x0eright_boundary\x18\x03 \x01(\x0b\x32\x12.trajdata.PolylineH\x01\x88\x01\x01\x12\x13\n\x0b\x65ntry_lanes\x18\x04 \x03(\x0c\x12\x12\n\nexit_lanes\x18\x05 \x03(\x0c\x12\x1b\n\x13\x61\x64jacent_lanes_left\x18\x06 \x03(\x0c\x12\x1c\n\x14\x61\x64jacent_lanes_right\x18\x07 \x03(\x0c\x42\x10\n\x0e_left_boundaryB\x11\n\x0f_right_boundary"d\n\x08RoadArea\x12,\n\x10\x65xterior_polygon\x18\x01 \x01(\x0b\x32\x12.trajdata.Polyline\x12*\n\x0einterior_holes\x18\x02 \x03(\x0b\x32\x12.trajdata.Polyline"3\n\x0cPedCrosswalk\x12#\n\x07polygon\x18\x01 \x01(\x0b\x32\x12.trajdata.Polyline"1\n\nPedWalkway\x12#\n\x07polygon\x18\x01 \x01(\x0b\x32\x12.trajdata.Polylineb\x06proto3' -) - - -_VECTORIZEDMAP = DESCRIPTOR.message_types_by_name["VectorizedMap"] -_MAPELEMENT = DESCRIPTOR.message_types_by_name["MapElement"] -_POINT = DESCRIPTOR.message_types_by_name["Point"] -_POLYLINE = DESCRIPTOR.message_types_by_name["Polyline"] -_ROADLANE = DESCRIPTOR.message_types_by_name["RoadLane"] -_ROADAREA = DESCRIPTOR.message_types_by_name["RoadArea"] -_PEDCROSSWALK = DESCRIPTOR.message_types_by_name["PedCrosswalk"] -_PEDWALKWAY = DESCRIPTOR.message_types_by_name["PedWalkway"] -VectorizedMap = _reflection.GeneratedProtocolMessageType( - "VectorizedMap", - (_message.Message,), - { - "DESCRIPTOR": _VECTORIZEDMAP, - "__module__": "vectorized_map_pb2" - # @@protoc_insertion_point(class_scope:trajdata.VectorizedMap) - }, -) -_sym_db.RegisterMessage(VectorizedMap) - -MapElement = _reflection.GeneratedProtocolMessageType( - "MapElement", - (_message.Message,), - { - "DESCRIPTOR": _MAPELEMENT, - "__module__": "vectorized_map_pb2" - # @@protoc_insertion_point(class_scope:trajdata.MapElement) - }, -) -_sym_db.RegisterMessage(MapElement) - -Point = _reflection.GeneratedProtocolMessageType( - "Point", - (_message.Message,), - { - "DESCRIPTOR": _POINT, - "__module__": "vectorized_map_pb2" - # @@protoc_insertion_point(class_scope:trajdata.Point) - }, -) -_sym_db.RegisterMessage(Point) - -Polyline = _reflection.GeneratedProtocolMessageType( - "Polyline", - (_message.Message,), - { - "DESCRIPTOR": _POLYLINE, - "__module__": "vectorized_map_pb2" - # @@protoc_insertion_point(class_scope:trajdata.Polyline) - }, -) -_sym_db.RegisterMessage(Polyline) - -RoadLane = _reflection.GeneratedProtocolMessageType( - "RoadLane", - (_message.Message,), - { - "DESCRIPTOR": _ROADLANE, - "__module__": "vectorized_map_pb2" - # @@protoc_insertion_point(class_scope:trajdata.RoadLane) - }, -) -_sym_db.RegisterMessage(RoadLane) - -RoadArea = _reflection.GeneratedProtocolMessageType( - "RoadArea", - (_message.Message,), - { - "DESCRIPTOR": _ROADAREA, - "__module__": "vectorized_map_pb2" - # @@protoc_insertion_point(class_scope:trajdata.RoadArea) - }, -) -_sym_db.RegisterMessage(RoadArea) -PedCrosswalk = _reflection.GeneratedProtocolMessageType( - "PedCrosswalk", - (_message.Message,), - { - "DESCRIPTOR": _PEDCROSSWALK, - "__module__": "vectorized_map_pb2" - # @@protoc_insertion_point(class_scope:trajdata.PedCrosswalk) - }, -) -_sym_db.RegisterMessage(PedCrosswalk) -PedWalkway = _reflection.GeneratedProtocolMessageType( - "PedWalkway", - (_message.Message,), - { - "DESCRIPTOR": _PEDWALKWAY, - "__module__": "vectorized_map_pb2" - # @@protoc_insertion_point(class_scope:trajdata.PedWalkway) - }, -) -_sym_db.RegisterMessage(PedWalkway) +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x14vectorized_map.proto\x12\x08trajdata\"\xb0\x01\n\rVectorizedMap\x12\x0c\n\x04name\x18\x01 \x01(\t\x12&\n\x08\x65lements\x18\x02 \x03(\x0b\x32\x14.trajdata.MapElement\x12\x1f\n\x06max_pt\x18\x03 \x01(\x0b\x32\x0f.trajdata.Point\x12\x1f\n\x06min_pt\x18\x04 \x01(\x0b\x32\x0f.trajdata.Point\x12\'\n\x0eshifted_origin\x18\x05 \x01(\x0b\x32\x0f.trajdata.Point\"\xd9\x02\n\nMapElement\x12\n\n\x02id\x18\x01 \x01(\x0c\x12\'\n\troad_lane\x18\x02 \x01(\x0b\x32\x12.trajdata.RoadLaneH\x00\x12\'\n\troad_area\x18\x03 \x01(\x0b\x32\x12.trajdata.RoadAreaH\x00\x12/\n\rped_crosswalk\x18\x04 \x01(\x0b\x32\x16.trajdata.PedCrosswalkH\x00\x12+\n\x0bped_walkway\x18\x05 \x01(\x0b\x32\x14.trajdata.PedWalkwayH\x00\x12\'\n\troad_edge\x18\x06 \x01(\x0b\x32\x12.trajdata.RoadEdgeH\x00\x12-\n\x0ctraffic_sign\x18\x07 \x01(\x0b\x32\x15.trajdata.TrafficSignH\x00\x12\'\n\twait_line\x18\x08 \x01(\x0b\x32\x12.trajdata.WaitLineH\x00\x42\x0e\n\x0c\x65lement_data\"(\n\x05Point\x12\t\n\x01x\x18\x01 \x01(\x01\x12\t\n\x01y\x18\x02 \x01(\x01\x12\t\n\x01z\x18\x03 \x01(\x01\"F\n\x08Polyline\x12\r\n\x05\x64x_mm\x18\x01 \x03(\x12\x12\r\n\x05\x64y_mm\x18\x02 \x03(\x12\x12\r\n\x05\x64z_mm\x18\x03 \x03(\x12\x12\r\n\x05h_rad\x18\x04 \x03(\x01\"\xe0\x02\n\x08RoadLane\x12\"\n\x06\x63\x65nter\x18\x01 \x01(\x0b\x32\x12.trajdata.Polyline\x12.\n\rleft_boundary\x18\x02 \x01(\x0b\x32\x12.trajdata.PolylineH\x00\x88\x01\x01\x12/\n\x0eright_boundary\x18\x03 \x01(\x0b\x32\x12.trajdata.PolylineH\x01\x88\x01\x01\x12\x13\n\x0b\x65ntry_lanes\x18\x04 \x03(\x0c\x12\x12\n\nexit_lanes\x18\x05 \x03(\x0c\x12\x1b\n\x13\x61\x64jacent_lanes_left\x18\x06 \x03(\x0c\x12\x1c\n\x14\x61\x64jacent_lanes_right\x18\x07 \x03(\x0c\x12\x15\n\rroad_area_ids\x18\x08 \x03(\x0c\x12\x18\n\x10traffic_sign_ids\x18\t \x03(\x0c\x12\x15\n\rwait_line_ids\x18\n \x03(\x0c\x42\x10\n\x0e_left_boundaryB\x11\n\x0f_right_boundary\"d\n\x08RoadArea\x12,\n\x10\x65xterior_polygon\x18\x01 \x01(\x0b\x32\x12.trajdata.Polyline\x12*\n\x0einterior_holes\x18\x02 \x03(\x0b\x32\x12.trajdata.Polyline\"3\n\x0cPedCrosswalk\x12#\n\x07polygon\x18\x01 \x01(\x0b\x32\x12.trajdata.Polyline\"1\n\nPedWalkway\x12#\n\x07polygon\x18\x01 \x01(\x0b\x32\x12.trajdata.Polyline\"0\n\x08RoadEdge\x12$\n\x08polyline\x18\x01 \x01(\x0b\x32\x12.trajdata.Polyline\"C\n\x0bTrafficSign\x12!\n\x08position\x18\x01 \x01(\x0b\x32\x0f.trajdata.Point\x12\x11\n\tsign_type\x18\x02 \x01(\t\"]\n\x08WaitLine\x12$\n\x08polyline\x18\x01 \x01(\x0b\x32\x12.trajdata.Polyline\x12\x16\n\x0ewait_line_type\x18\x02 \x01(\t\x12\x13\n\x0bis_implicit\x18\x03 \x01(\x08\x62\x06proto3') +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'vectorized_map_pb2', _globals) if _descriptor._USE_C_DESCRIPTORS == False: - DESCRIPTOR._options = None - _VECTORIZEDMAP._serialized_start = 35 - _VECTORIZEDMAP._serialized_end = 211 - _MAPELEMENT._serialized_start = 214 - _MAPELEMENT._serialized_end = 430 - _POINT._serialized_start = 432 - _POINT._serialized_end = 472 - _POLYLINE._serialized_start = 474 - _POLYLINE._serialized_end = 544 - _ROADLANE._serialized_start = 547 - _ROADLANE._serialized_end = 827 - _ROADAREA._serialized_start = 829 - _ROADAREA._serialized_end = 929 - _PEDCROSSWALK._serialized_start = 931 - _PEDCROSSWALK._serialized_end = 982 - _PEDWALKWAY._serialized_start = 984 - _PEDWALKWAY._serialized_end = 1033 + DESCRIPTOR._options = None + _globals['_VECTORIZEDMAP']._serialized_start=35 + _globals['_VECTORIZEDMAP']._serialized_end=211 + _globals['_MAPELEMENT']._serialized_start=214 + _globals['_MAPELEMENT']._serialized_end=559 + _globals['_POINT']._serialized_start=561 + _globals['_POINT']._serialized_end=601 + _globals['_POLYLINE']._serialized_start=603 + _globals['_POLYLINE']._serialized_end=673 + _globals['_ROADLANE']._serialized_start=676 + _globals['_ROADLANE']._serialized_end=1028 + _globals['_ROADAREA']._serialized_start=1030 + _globals['_ROADAREA']._serialized_end=1130 + _globals['_PEDCROSSWALK']._serialized_start=1132 + _globals['_PEDCROSSWALK']._serialized_end=1183 + _globals['_PEDWALKWAY']._serialized_start=1185 + _globals['_PEDWALKWAY']._serialized_end=1234 + _globals['_ROADEDGE']._serialized_start=1236 + _globals['_ROADEDGE']._serialized_end=1284 + _globals['_TRAFFICSIGN']._serialized_start=1286 + _globals['_TRAFFICSIGN']._serialized_end=1353 + _globals['_WAITLINE']._serialized_start=1355 + _globals['_WAITLINE']._serialized_end=1448 # @@protoc_insertion_point(module_scope) diff --git a/src/trajdata/proto/vectorized_map_pb2_grpc.py b/src/trajdata/proto/vectorized_map_pb2_grpc.py new file mode 100644 index 0000000..c1f9c43 --- /dev/null +++ b/src/trajdata/proto/vectorized_map_pb2_grpc.py @@ -0,0 +1,24 @@ +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! +"""Client and server classes corresponding to protobuf-defined services.""" +import grpc +import warnings + + +GRPC_GENERATED_VERSION = '1.67.0' +GRPC_VERSION = grpc.__version__ +_version_not_supported = False + +try: + from grpc._utilities import first_version_is_lower + _version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION) +except ImportError: + _version_not_supported = True + +if _version_not_supported: + raise RuntimeError( + f'The grpc package installed is at version {GRPC_VERSION},' + + f' but the generated code in vectorized_map_pb2_grpc.py depends on' + + f' grpcio>={GRPC_GENERATED_VERSION}.' + + f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}' + + f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.' + ) diff --git a/src/trajdata/utils/arr_utils.py b/src/trajdata/utils/arr_utils.py index e76a678..fa305d3 100644 --- a/src/trajdata/utils/arr_utils.py +++ b/src/trajdata/utils/arr_utils.py @@ -93,7 +93,9 @@ def vrange(starts: np.ndarray, stops: np.ndarray) -> np.ndarray: return np.repeat(stops - lens.cumsum(), lens) + np.arange(lens.sum()) -def angle_wrap(radians: np.ndarray) -> np.ndarray: +def angle_wrap( + radians: Union[np.ndarray, torch.Tensor] +) -> Union[np.ndarray, torch.Tensor]: """This function wraps angles to lie within [-pi, pi). Args: @@ -130,25 +132,34 @@ def rotation_matrix(angle: Union[float, np.ndarray]) -> np.ndarray: return rotmat.transpose(*np.arange(2, batch_dims + 2), 0, 1) -def transform_matrices(angles: Tensor, translations: Tensor) -> Tensor: +def transform_matrices(angles: Tensor, translations: Optional[Tensor]) -> Tensor: """Creates a 3x3 transformation matrix for each angle and translation in the input. Args: - angles (Tensor): The (N,)-shaped angles tensor to rotate points by (in radians). - translations (Tensor): The (N,2)-shaped translations to shift points by. + angles (Tensor): The (...)-shaped angles tensor to rotate points by (in radians). + translations (Tensor): The (...,2)-shaped translations to shift points by. Returns: Tensor: The Nx3x3 transformation matrices. """ cos_vals = torch.cos(angles) sin_vals = torch.sin(angles) - last_rows = torch.tensor( - [[0.0, 0.0, 1.0]], dtype=angles.dtype, device=angles.device - ).expand((angles.shape[0], -1)) + last_rows = ( + torch.tensor([0.0, 0.0, 1.0], dtype=angles.dtype, device=angles.device) + .view([1] * angles.ndim + [3]) + .expand(list(angles.shape) + [-1]) + ) + + if translations is None: + trans_x = torch.zeros_like(angles) + trans_y = trans_x + else: + trans_x, trans_y = torch.unbind(translations, dim=-1) + return torch.stack( [ - torch.stack([cos_vals, -sin_vals, translations[:, 0]], dim=-1), - torch.stack([sin_vals, cos_vals, translations[:, 1]], dim=-1), + torch.stack([cos_vals, -sin_vals, trans_x], dim=-1), + torch.stack([sin_vals, cos_vals, trans_y], dim=-1), last_rows, ], dim=-2, @@ -239,10 +250,253 @@ def transform_xyh_np(xyh: np.ndarray, tf_mat: np.ndarray) -> np.ndarray: xyh (np.ndarray): shape [...,3] tf_mat (np.ndarray): shape [...,3,3] """ - transformed_xy = transform_coords_np(xyh[..., :2], tf_mat) - transformed_angles = transform_angles_np(xyh[..., 2], tf_mat) + transformed_xy = transform_coords_np(xyh[..., :-1], tf_mat) + transformed_angles = transform_angles_np(xyh[..., -1], tf_mat) return np.concatenate([transformed_xy, transformed_angles[..., None]], axis=-1) +def transform_xyh_torch(xyh: torch.Tensor, tf_mat: torch.Tensor) -> torch.Tensor: + """ + Returns transformed set of xyh points + + Args: + xyh (torch.Tensor): shape [...,3] + tf_mat (torch.Tensor): shape [...,3,3] + """ + transformed_xy = batch_nd_transform_points_pt(xyh[..., :2], tf_mat) + transformed_angles = batch_nd_transform_angles_pt(xyh[..., 2], tf_mat) + return torch.cat([transformed_xy, transformed_angles[..., None]], dim=-1) + +# -------- TODO redundant transforms, remove them + + +def batch_nd_transform_points_np(points: np.ndarray, Mat: np.ndarray) -> np.ndarray: + ndim = Mat.shape[-1] - 1 + batch = list(range(Mat.ndim - 2)) + [Mat.ndim - 1] + [Mat.ndim - 2] + Mat = np.transpose(Mat, batch) + if points.ndim == Mat.ndim - 1: + return (points[..., np.newaxis, :] @ Mat[..., :ndim, :ndim]).squeeze(-2) + Mat[ + ..., -1:, :ndim + ].squeeze(-2) + elif points.ndim == Mat.ndim: + return ( + (points[..., np.newaxis, :] @ Mat[..., np.newaxis, :ndim, :ndim]) + + Mat[..., np.newaxis, -1:, :ndim] + ).squeeze(-2) + else: + raise Exception("wrong shape") + +def batch_nd_transform_points_pt( + points: torch.Tensor, Mat: torch.Tensor +) -> torch.Tensor: + ndim = Mat.shape[-1] - 1 + Mat = torch.transpose(Mat, -1, -2) + if points.ndim == Mat.ndim - 1: + return (points[..., np.newaxis, :] @ Mat[..., :ndim, :ndim]).squeeze(-2) + Mat[ + ..., -1:, :ndim + ].squeeze(-2) + elif points.ndim == Mat.ndim: + return ( + (points[..., np.newaxis, :] @ Mat[..., np.newaxis, :ndim, :ndim]) + + Mat[..., np.newaxis, -1:, :ndim] + ).squeeze(-2) + elif points.ndim == Mat.ndim + 1: + return ( + ( + points[..., np.newaxis, :] + @ Mat[..., np.newaxis, np.newaxis, :ndim, :ndim] + ) + + Mat[..., np.newaxis, np.newaxis, -1:, :ndim] + ).squeeze(-2) + else: + raise Exception("wrong shape") + + +def batch_nd_transform_angles_np(angles: np.ndarray, Mat: np.ndarray) -> np.ndarray: + cos_vals, sin_vals = Mat[..., 0, 0], Mat[..., 1, 0] + rot_angle = np.arctan2(sin_vals, cos_vals) + angles = angles + rot_angle + angles = angle_wrap(angles) + return angles + + +def batch_nd_transform_angles_pt( + angles: torch.Tensor, Mat: torch.Tensor +) -> torch.Tensor: + cos_vals, sin_vals = Mat[..., 0, 0], Mat[..., 1, 0] + rot_angle = torch.arctan2(sin_vals, cos_vals) + if rot_angle.ndim > angles.ndim: + raise ValueError("wrong shape") + while rot_angle.ndim < angles.ndim: + rot_angle = rot_angle.unsqueeze(-1) + angles = angles + rot_angle + angles = angle_wrap(angles) + return angles + + +def batch_nd_transform_points_angles_np( + points_angles: np.ndarray, Mat: np.ndarray +) -> np.ndarray: + assert points_angles.shape[-1] == 3 + points = batch_nd_transform_points_np(points_angles[..., :2], Mat) + angles = batch_nd_transform_angles_np(points_angles[..., 2:3], Mat) + points_angles = np.concatenate([points, angles], axis=-1) + return points_angles + + +def batch_nd_transform_points_angles_pt( + points_angles: torch.Tensor, Mat: torch.Tensor +) -> torch.Tensor: + assert points_angles.shape[-1] == 3 + points = batch_nd_transform_points_pt(points_angles[..., :2], Mat) + angles = batch_nd_transform_angles_pt(points_angles[..., 2:3], Mat) + points_angles = torch.concat([points, angles], axis=-1) + return points_angles + + +def batch_nd_transform_xyvvaahh_pt(traj_xyvvaahh: torch.Tensor, tf: torch.Tensor) -> torch.Tensor: + """ + traj_xyvvaahh: [..., state_dim] where state_dim = [x, y, vx, vy, ax, ay, sinh, cosh] + This is the state representation used in AgentBatch and SceneBatch. + """ + rot_only_tf = tf.clone() + rot_only_tf[..., :2, -1] = 0. + + xy, vv, aa, hh = torch.split(traj_xyvvaahh, (2, 2, 2, 2), dim=-1) + xy = batch_nd_transform_points_pt(xy, tf) + vv = batch_nd_transform_points_pt(vv, rot_only_tf) + aa = batch_nd_transform_points_pt(aa, rot_only_tf) + # hh: sinh, cosh instead of cosh, sinh, so we use flip + hh = batch_nd_transform_points_pt(hh.flip(-1), rot_only_tf).flip(-1) + + return torch.concat((xy, vv, aa, hh), dim=-1) + + +# -------- end of redundant transforms + + +def transform_xyh_torch(xyh: torch.Tensor, tf_mat: torch.Tensor) -> torch.Tensor: + """ + Returns transformed set of xyh points + + Args: + xyh (torch.Tensor): shape [...,3] + tf_mat (torch.Tensor): shape [...,3,3] + """ + transformed_xy = batch_nd_transform_points_pt(xyh[..., :2], tf_mat) + transformed_angles = batch_nd_transform_angles_pt(xyh[..., 2], tf_mat) + return torch.cat([transformed_xy, transformed_angles[..., None]], dim=-1) + + +# -------- TODO redundant transforms, remove them + + +def batch_nd_transform_points_np(points: np.ndarray, Mat: np.ndarray) -> np.ndarray: + ndim = Mat.shape[-1] - 1 + batch = list(range(Mat.ndim - 2)) + [Mat.ndim - 1] + [Mat.ndim - 2] + Mat = np.transpose(Mat, batch) + if points.ndim == Mat.ndim - 1: + return (points[..., np.newaxis, :] @ Mat[..., :ndim, :ndim]).squeeze(-2) + Mat[ + ..., -1:, :ndim + ].squeeze(-2) + elif points.ndim == Mat.ndim: + return ( + (points[..., np.newaxis, :] @ Mat[..., np.newaxis, :ndim, :ndim]) + + Mat[..., np.newaxis, -1:, :ndim] + ).squeeze(-2) + else: + raise Exception("wrong shape") + + +def batch_nd_transform_points_pt( + points: torch.Tensor, Mat: torch.Tensor +) -> torch.Tensor: + ndim = Mat.shape[-1] - 1 + Mat = torch.transpose(Mat, -1, -2) + if points.ndim == Mat.ndim - 1: + return (points[..., np.newaxis, :] @ Mat[..., :ndim, :ndim]).squeeze(-2) + Mat[ + ..., -1:, :ndim + ].squeeze(-2) + elif points.ndim == Mat.ndim: + return ( + (points[..., np.newaxis, :] @ Mat[..., np.newaxis, :ndim, :ndim]) + + Mat[..., np.newaxis, -1:, :ndim] + ).squeeze(-2) + elif points.ndim == Mat.ndim + 1: + return ( + ( + points[..., np.newaxis, :] + @ Mat[..., np.newaxis, np.newaxis, :ndim, :ndim] + ) + + Mat[..., np.newaxis, np.newaxis, -1:, :ndim] + ).squeeze(-2) + else: + raise Exception("wrong shape") + + +def batch_nd_transform_angles_np(angles: np.ndarray, Mat: np.ndarray) -> np.ndarray: + cos_vals, sin_vals = Mat[..., 0, 0], Mat[..., 1, 0] + rot_angle = np.arctan2(sin_vals, cos_vals) + angles = angles + rot_angle + angles = angle_wrap(angles) + return angles + + +def batch_nd_transform_angles_pt( + angles: torch.Tensor, Mat: torch.Tensor +) -> torch.Tensor: + cos_vals, sin_vals = Mat[..., 0, 0], Mat[..., 1, 0] + rot_angle = torch.arctan2(sin_vals, cos_vals) + if rot_angle.ndim > angles.ndim: + raise ValueError("wrong shape") + while rot_angle.ndim < angles.ndim: + rot_angle = rot_angle.unsqueeze(-1) + angles = angles + rot_angle + angles = angle_wrap(angles) + return angles + + +def batch_nd_transform_points_angles_np( + points_angles: np.ndarray, Mat: np.ndarray +) -> np.ndarray: + assert points_angles.shape[-1] == 3 + points = batch_nd_transform_points_np(points_angles[..., :2], Mat) + angles = batch_nd_transform_angles_np(points_angles[..., 2:3], Mat) + points_angles = np.concatenate([points, angles], axis=-1) + return points_angles + + +def batch_nd_transform_points_angles_pt( + points_angles: torch.Tensor, Mat: torch.Tensor +) -> torch.Tensor: + assert points_angles.shape[-1] == 3 + points = batch_nd_transform_points_pt(points_angles[..., :2], Mat) + angles = batch_nd_transform_angles_pt(points_angles[..., 2:3], Mat) + points_angles = torch.concat([points, angles], axis=-1) + return points_angles + + +def batch_nd_transform_xyvvaahh_pt( + traj_xyvvaahh: torch.Tensor, tf: torch.Tensor +) -> torch.Tensor: + """ + traj_xyvvaahh: [..., state_dim] where state_dim = [x, y, vx, vy, ax, ay, sinh, cosh] + This is the state representation used in AgentBatch and SceneBatch. + """ + rot_only_tf = tf.clone() + rot_only_tf[..., :2, -1] = 0.0 + + xy, vv, aa, hh = torch.split(traj_xyvvaahh, (2, 2, 2, 2), dim=-1) + xy = batch_nd_transform_points_pt(xy, tf) + vv = batch_nd_transform_points_pt(vv, rot_only_tf) + aa = batch_nd_transform_points_pt(aa, rot_only_tf) + # hh: sinh, cosh instead of cosh, sinh, so we use flip + hh = batch_nd_transform_points_pt(hh.flip(-1), rot_only_tf).flip(-1) + + return torch.concat((xy, vv, aa, hh), dim=-1) + + +# -------- end of redundant transforms + def agent_aware_diff(values: np.ndarray, agent_ids: np.ndarray) -> np.ndarray: values_diff: np.ndarray = np.diff( @@ -325,3 +579,150 @@ def quaternion_to_yaw(q: np.ndarray): 2 * (q[..., 0] * q[..., 3] - q[..., 1] * q[..., 2]), 1 - 2 * (q[..., 2] ** 2 + q[..., 3] ** 2), ) + + +def batch_select( + x: torch.Tensor, + index: torch.Tensor, + batch_dims: int +) -> torch.Tensor: + # Indexing into tensor, treating the first `batch_dims` dimensions as batch. + # Kind of: output[..., k] = x[..., index[...]] + + assert index.ndim >= batch_dims + assert index.ndim <= x.ndim + assert x.shape[:batch_dims] == index.shape[:batch_dims] + + batch_shape = x.shape[:batch_dims] + x_flat = x.reshape(-1, *x.shape[batch_dims:]) + index_flat = index.reshape(-1, *index.shape[batch_dims:]) + x_flat = x_flat[torch.arange(x_flat.shape[0]), index_flat] + x = x_flat.reshape(*batch_shape, *x_flat.shape[1:]) + + return x + + +def roll_with_tensor(mat: torch.Tensor, shifts: torch.LongTensor, dim: int): + if dim < 0: + dim = mat.ndim + dim + arange1 = torch.arange(mat.shape[dim], device=shifts.device) + expanded_shape = [1] * dim + [-1] + [1] * (mat.ndim-dim-1) + arange1 = arange1.view(expanded_shape).expand(mat.shape) + if shifts.ndim == 1: + shifts = shifts.view([1] * (dim-1) + [-1]) + # TODO assert that shift dimenesions either match mat or 1 + shifts = shifts.view(list(shifts.shape) + [1] * (mat.ndim-dim)) + + arange2 = (arange1 - shifts) % mat.shape[dim] + # print(arange2) + return torch.gather(mat, dim, arange2) + +def round_2pi(x): + return (x + np.pi) % (2 * np.pi) - np.pi + +def batch_proj(x, line): + # x:[batch,3], line:[batch,N,3] + line_length = line.shape[-2] + batch_dim = x.ndim - 1 + if isinstance(x, torch.Tensor): + delta = line[..., 0:2] - torch.unsqueeze(x[..., 0:2], dim=-2).repeat( + *([1] * batch_dim), line_length, 1 + ) + dis = torch.linalg.norm(delta, axis=-1) + idx0 = torch.argmin(dis, dim=-1) + idx = idx0.view(*line.shape[:-2], 1, 1).repeat( + *([1] * (batch_dim + 1)), line.shape[-1] + ) + line_min = torch.squeeze(torch.gather(line, -2, idx), dim=-2) + dx = x[..., None, 0] - line[..., 0] + dy = x[..., None, 1] - line[..., 1] + delta_y = -dx * torch.sin(line_min[..., None, 2]) + dy * torch.cos( + line_min[..., None, 2] + ) + delta_x = dx * torch.cos(line_min[..., None, 2]) + dy * torch.sin( + line_min[..., None, 2] + ) + + delta_psi = round_2pi(x[..., 2] - line_min[..., 2]) + + return ( + delta_x, + delta_y, + torch.unsqueeze(delta_psi, dim=-1), + ) + elif isinstance(x, np.ndarray): + delta = line[..., 0:2] - np.repeat( + x[..., np.newaxis, 0:2], line_length, axis=-2 + ) + dis = np.linalg.norm(delta, axis=-1) + idx0 = np.argmin(dis, axis=-1) + idx = idx0.reshape(*line.shape[:-2], 1, 1).repeat(line.shape[-1], axis=-1) + line_min = np.squeeze(np.take_along_axis(line, idx, axis=-2), axis=-2) + dx = x[..., None, 0] - line[..., 0] + dy = x[..., None, 1] - line[..., 1] + delta_y = -dx * np.sin(line_min[..., None, 2]) + dy * np.cos( + line_min[..., None, 2] + ) + delta_x = dx * np.cos(line_min[..., None, 2]) + dy * np.sin( + line_min[..., None, 2] + ) + + delta_psi = round_2pi(x[..., 2] - line_min[..., 2]) + return ( + delta_x, + delta_y, + np.expand_dims(delta_psi, axis=-1), + ) + +def get_close_lanes(radius,ego_xyh,vec_map,num_pts): + # obtain close lanes, their distance to the ego + close_lanes = [] + while len(close_lanes)==0: + close_lanes=vec_map.get_lanes_within(ego_xyh,radius) + radius+=20 + dis = list() + lane_pts = np.stack([lane.center.interpolate(num_pts).points[:,[0,1,3]] for lane in close_lanes],0) + dx,dy,dh = batch_proj(ego_xyh[None].repeat(lane_pts.shape[0],0),lane_pts) + + idx = np.abs(dx).argmin(axis=1) + # hausdorff distance to the lane (longitudinal) + x_dis = np.take_along_axis(np.abs(dx),idx[:,None],axis=1).squeeze(1) + x_dis[(dx.min(1)<0) & (dx.max(1)>0)] = 0 + + y_dis = np.take_along_axis(np.abs(dy),idx[:,None],axis=1).squeeze(1) + + # distance metric to the lane (combining x,y) + dis = x_dis+y_dis + + return close_lanes, dis + + +def get_close_road_edges(radius, ego_xyh, vec_map, num_pts): + # obtain close lanes, their distance to the ego + close_road_edges = [] + while len(close_road_edges) == 0: + close_road_edges = vec_map.get_road_edges_within(ego_xyh, radius) + radius += 20 + dis = list() + road_edge_pts = np.stack( + [ + road_edge.polyline.interpolate(num_pts).points[:, [0, 1, 3]] + for road_edge in close_road_edges + ], + 0, + ) + dx, dy, dh = batch_proj( + ego_xyh[None].repeat(road_edge_pts.shape[0], 0), road_edge_pts + ) + + idx = np.abs(dx).argmin(axis=1) + # hausdorff distance to the lane (longitudinal) + x_dis = np.take_along_axis(np.abs(dx), idx[:, None], axis=1).squeeze(1) + x_dis[(dx.min(1) < 0) & (dx.max(1) > 0)] = 0 + + y_dis = np.take_along_axis(np.abs(dy), idx[:, None], axis=1).squeeze(1) + + # distance metric to the lane (combining x,y) + dis = x_dis + y_dis + + return close_road_edges, dis diff --git a/src/trajdata/utils/batch_utils.py b/src/trajdata/utils/batch_utils.py index cf63009..be0d79f 100644 --- a/src/trajdata/utils/batch_utils.py +++ b/src/trajdata/utils/batch_utils.py @@ -1,8 +1,14 @@ from collections import defaultdict -from typing import Any, Dict, Iterator, List, Optional, Tuple + +import torch + +from pathlib import Path +from typing import Any, Dict, Iterator, List, Optional, Tuple, Union import numpy as np from torch.utils.data import Sampler +import torch +import dill from trajdata import UnifiedDataset from trajdata.data_structures import ( @@ -10,10 +16,24 @@ AgentBatchElement, AgentDataIndex, AgentType, + SceneBatch, SceneBatchElement, SceneTimeAgent, + SceneBatch, +) +from trajdata.data_structures.collation import ( + agent_collate_fn, + batch_rotate_raster_maps_for_agents_in_scene, ) -from trajdata.data_structures.collation import agent_collate_fn +from trajdata.maps import RasterizedMapMetadata, RasterizedMapPatch, VectorMap +from trajdata.utils.map_utils import load_map_patch +from trajdata.utils.arr_utils import ( + batch_nd_transform_xyvvaahh_pt, + batch_select, + PadDirection, +) +from trajdata.caching.df_cache import DataFrameCache +from pathlib import Path class SceneTimeBatcher(Sampler): @@ -173,3 +193,83 @@ def convert_to_agent_batch( ) return agent_collate_fn(batch_elems, return_dict=False, pad_format=pad_format) + + +def get_agents_map_patch( + cache_path: Path, + map_name: str, + patch_params: Dict[str, int], + agent_world_states_xyh: Union[np.ndarray, torch.Tensor], + allow_nan: float = False, + vector_map: Optional[VectorMap] = None, +) -> List[RasterizedMapPatch]: + + if isinstance(agent_world_states_xyh, torch.Tensor): + agent_world_states_xyh = agent_world_states_xyh.cpu().numpy() + assert agent_world_states_xyh.ndim == 2 + assert agent_world_states_xyh.shape[-1] == 3 + desired_patch_size: int = patch_params["map_size_px"] + resolution: float = patch_params["px_per_m"] + offset_xy: Tuple[float, float] = patch_params.get("offset_frac_xy", (0.0, 0.0)) + return_rgb: bool = patch_params.get("return_rgb", True) + no_map_fill_val: float = patch_params.get("no_map_fill_value", 0.0) + + if ( + vector_map is not None + and getattr(vector_map, "raster_map_data", None) is not None + ): + # Use preloaded raster map data + raster_map_or_path = vector_map.raster_map_data + raster_metadata_or_path = vector_map.raster_map_metadata + else: + env_name, location_name = map_name.split( + ":" + ) # assumes map_name format nusc_mini:boston-seaports + ( + maps_path, + _, + _, + raster_map_or_path, + raster_metadata_or_path, + ) = DataFrameCache.get_map_paths( + cache_path, env_name, location_name, resolution + ) + + map_patches = list() + for i in range(agent_world_states_xyh.shape[0]): + patch_data, raster_from_world_tf, has_data = load_map_patch( + raster_map_or_path, + raster_metadata_or_path, + agent_world_states_xyh[i, 0], + agent_world_states_xyh[i, 1], + desired_patch_size, + resolution, + offset_xy, + agent_world_states_xyh[i, 2], + return_rgb, + rot_pad_factor=np.sqrt(2), + no_map_val=no_map_fill_val, + ) + map_patches.append( + RasterizedMapPatch( + data=patch_data, + rot_angle=agent_world_states_xyh[i, 2], + crop_size=desired_patch_size, + resolution=resolution, + raster_from_world_tf=raster_from_world_tf, + has_data=has_data, + ) + ) + + return map_patches + + +def load_raster_map_data( + cache_path: Path, map_name: str, patch_params: Dict[str, int] +): + raise NotImplementedError() + +def get_raster_maps_for_scene_batch( + batch: SceneBatch, cache_path: Path, raster_map_params: Dict, ego_only: bool = False +): + raise NotImplementedError() diff --git a/src/trajdata/utils/comm_utils.py b/src/trajdata/utils/comm_utils.py new file mode 100644 index 0000000..9c860aa --- /dev/null +++ b/src/trajdata/utils/comm_utils.py @@ -0,0 +1,24 @@ +import numpy as np +import socket + +from contextlib import closing +from typing import Callable, Optional + + +def find_open_port(): + with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s: + s.bind(("", 0)) + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + return s.getsockname()[1] + + +def find_open_port_in_range(start_port, end_port): + for port in range(start_port, end_port + 1): + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("localhost", port)) + s.listen(1) + return port + except OSError: + continue + return None diff --git a/src/trajdata/utils/env_utils.py b/src/trajdata/utils/env_utils.py index 49c22ec..47f13b2 100644 --- a/src/trajdata/utils/env_utils.py +++ b/src/trajdata/utils/env_utils.py @@ -9,11 +9,6 @@ def get_raw_dataset(dataset_name: str, data_dir: str) -> RawDataset: return NuscDataset(dataset_name, data_dir, parallelizable=False, has_maps=True) - if "vod" in dataset_name: - from trajdata.dataset_specific.vod import VODDataset - - return VODDataset(dataset_name, data_dir, parallelizable=True, has_maps=True) - if "lyft" in dataset_name: from trajdata.dataset_specific.lyft import LyftDataset @@ -50,11 +45,6 @@ def get_raw_dataset(dataset_name: str, data_dir: str) -> RawDataset: dataset_name, data_dir, parallelizable=True, has_maps=True ) - if "av2" in dataset_name: - from trajdata.dataset_specific.argoverse2 import Av2Dataset - - return Av2Dataset(dataset_name, data_dir, parallelizable=True, has_maps=True) - raise ValueError(f"Dataset with name '{dataset_name}' is not supported") diff --git a/src/trajdata/utils/map_utils.py b/src/trajdata/utils/map_utils.py index 9856972..de7fb86 100644 --- a/src/trajdata/utils/map_utils.py +++ b/src/trajdata/utils/map_utils.py @@ -1,15 +1,19 @@ from __future__ import annotations -import warnings from typing import TYPE_CHECKING +import warnings if TYPE_CHECKING: - from trajdata.maps import map_kdtree, vec_map, map_strtree + from trajdata.maps import ( + map_kdtree, + vec_map, + map_strtree + ) from trajdata.maps.vec_map_elements import MapElementType from pathlib import Path -from typing import Dict, Final, Optional - +from typing import Dict, Final, Optional, List, NamedTuple +import enum import dill import numpy as np from scipy.stats import circmean @@ -17,7 +21,28 @@ import trajdata.proto.vectorized_map_pb2 as map_proto from trajdata.utils import arr_utils -MM_PER_M: Final[float] = 1000 +NUM_DECIMALS: Final[int] = 5 +COMPRESSION_SCALE: Final[float] = 10**NUM_DECIMALS + + +class LaneSegRelation(enum.IntEnum): + """ + Categorical token describing the relationship between an agent and a Lane + """ + + NOTCONNECTED = 0 + NEXT = 1 + PREV = 2 + LEFT = 3 + RIGHT = 4 + + +class LaneGraph(NamedTuple): + lane_ids: List[str] + lane_centerlines: np.ndarray + lane_left_edges: np.ndarray + lane_right_edges: np.ndarray + lane_connectivity: np.ndarray def decompress_values(data: np.ndarray) -> np.ndarray: @@ -25,11 +50,11 @@ def decompress_values(data: np.ndarray) -> np.ndarray: # The delta for the first point is just its coordinates tuple, i.e. it is a "delta" from # the origin. For subsequent points, this field stores the difference between the point's # coordinates and the previous point's coordinates. This is for representation efficiency. - return np.cumsum(data, axis=0, dtype=float) / MM_PER_M + return np.cumsum(data, axis=0, dtype=float) / COMPRESSION_SCALE def compress_values(data: np.ndarray) -> np.ndarray: - return (np.diff(data, axis=0, prepend=0.0) * MM_PER_M).astype(np.int32) + return (np.diff(data, axis=0, prepend=0.0) * COMPRESSION_SCALE).astype(np.int64) def get_polyline_headings(points: np.ndarray) -> np.ndarray: @@ -109,6 +134,30 @@ def populate_lane_polylines( new_lane_proto.right_boundary.dx_mm.extend(compressed_right_pts[:, 0].tolist()) new_lane_proto.right_boundary.dy_mm.extend(compressed_right_pts[:, 1].tolist()) new_lane_proto.right_boundary.dz_mm.extend(compressed_right_pts[:, 2].tolist()) + if road_lane_py.traffic_sign_ids is not None: + new_lane_proto.traffic_sign_ids.extend([iden.encode() for iden in road_lane_py.traffic_sign_ids]) + if road_lane_py.wait_line_ids is not None: + new_lane_proto.wait_line_ids.extend([iden.encode() for iden in road_lane_py.wait_line_ids]) + + +def populate_road_edge_polylines( + new_road_edge_proto: map_proto.RoadEdge, + road_edge_py: vec_map.RoadEdge, + origin: np.ndarray, +) -> None: + """Fill a Lane object's polyline attributes. + All points should be in world coordinates. + + Args: + new_road_edge_proto (RoadEdge): _description_ + road_edge_py (np.ndarray): _description_ + origin (np.ndarray): _description_ + """ + compressed_pts: np.ndarray = compress_values(road_edge_py.polyline.xyz - origin) + new_road_edge_proto.polyline.dx_mm.extend(compressed_pts[:, 0].tolist()) + new_road_edge_proto.polyline.dy_mm.extend(compressed_pts[:, 1].tolist()) + new_road_edge_proto.polyline.dz_mm.extend(compressed_pts[:, 2].tolist()) + new_road_edge_proto.polyline.h_rad.extend(road_edge_py.polyline.h.tolist()) def populate_polygon( @@ -290,6 +339,119 @@ def load_kdtrees( return kdtrees +def obtain_lane_graph(scene, map_api, vector_map_params) -> LaneGraph: + map_name = f"{scene.env_name}:{scene.location}" + num_pts = vector_map_params.get("num_lane_pts", 30) + vec_map = map_api.get_map( + map_name, + None, + **vector_map_params, + ) + infer_lane_connectivity = vector_map_params.get("infer_lane_connectivity", False) + lane_ids = [lane.id for lane in vec_map.lanes] + lane_centerlines = np.stack( + [ + lane.center.interpolate(num_pts).points[:, [0, 1, 3]] + for lane in vec_map.lanes + ], + 0, + ) + lane_left_edges = np.stack( + [ + ( + lane.left_edge.interpolate(num_pts).points[:, :2] + if lane.left_edge is not None + else np.full( + [num_pts, 2], dtype=lane_centerlines.dtype, fill_value=np.nan + ) + ) + for lane in vec_map.lanes + ], + 0, + ) + lane_right_edges = np.stack( + [ + ( + lane.right_edge.interpolate(num_pts).points[:, :2] + if lane.right_edge is not None + else np.full( + [num_pts, 2], dtype=lane_centerlines.dtype, fill_value=np.nan + ) + ) + for lane in vec_map.lanes + ], + 0, + ) + # Lane connectivity + rough_dis_map = lane_centerlines.mean() + lane_adj = np.zeros((len(lane_ids), len(lane_ids)), dtype=np.int8) + + for i, lane in enumerate(vec_map.lanes): + try: + for adj_lane_id in lane.next_lanes: + lane_adj[i, lane_ids.index(adj_lane_id)] = LaneSegRelation.NEXT.value + + for adj_lane_id in lane.prev_lanes: + lane_adj[i, lane_ids.index(adj_lane_id)] = LaneSegRelation.PREV.value + + for adj_lane_id in lane.adj_lanes_left: + lane_adj[i, lane_ids.index(adj_lane_id)] = LaneSegRelation.LEFT.value + + for adj_lane_id in lane.adj_lanes_right: + lane_adj[i, lane_ids.index(adj_lane_id)] = LaneSegRelation.RIGHT.value + except: + pass + if infer_lane_connectivity: + LAT_THRESHOLD = 4.5 + H_THRESHOLD = 0.1 + PTS_THRESHOLD = 3 + lane_center_pt = lane_centerlines[:, :, :2].mean(1) + rough_dis_map = np.linalg.norm( + lane_center_pt[:, None] - lane_center_pt[None], axis=-1 + ) + topk = 10 + topk_idx = np.argsort(rough_dis_map, axis=1)[:, 1 : topk + 1] + infered_lane_adj = np.zeros((len(lane_ids), len(lane_ids)), dtype=np.int8) + for i, lane in enumerate(vec_map.lanes): + centerline = lane_centerlines[i] + relevant_lanes = lane_centerlines[topk_idx[i]] + dx, dy, dh = arr_utils.batch_proj( + centerline.repeat(topk, 0), np.tile(relevant_lanes, (num_pts, 1, 1)) + ) + min_idx = np.argmin(np.abs(dx), axis=1) + min_dx = np.take_along_axis(dx, min_idx[:, None], axis=1).reshape( + num_pts, topk + ) + min_dy = np.take_along_axis(dy, min_idx[:, None], axis=1).reshape( + num_pts, topk + ) + dh = dh.reshape(num_pts, topk) + dx = dx.reshape(num_pts, topk, num_pts) + + # if one of the waypoints has a negative and positive dx, then it is contained in between the start and end of the adjacent lane + longi_adj = np.logical_and(dx.min(2) < 0, dx.max(2) > 0) + + lat_left_adj = np.logical_and(min_dy > 0, min_dy < LAT_THRESHOLD) + lat_right_adj = np.logical_and(min_dy < 0, min_dy > -LAT_THRESHOLD) + heading_adj = np.abs(dh) < H_THRESHOLD + + left_adj_flag = (longi_adj * lat_left_adj * heading_adj).sum( + 0 + ) > PTS_THRESHOLD + right_adj_flag = (longi_adj * lat_right_adj * heading_adj).sum( + 0 + ) > PTS_THRESHOLD + infered_lane_adj[i, topk_idx[i][left_adj_flag]] = LaneSegRelation.LEFT.value + infered_lane_adj[i, topk_idx[i][right_adj_flag]] = ( + LaneSegRelation.RIGHT.value + ) + + lane_adj = lane_adj + (lane_adj == 0) * infered_lane_adj + return LaneGraph( + lane_ids, lane_centerlines, lane_left_edges, lane_right_edges, lane_adj + ) + + def load_rtrees( rtrees_path: Path, ) -> Optional[Dict[MapElementType, map_strtree.MapElementSTRTree]]: diff --git a/src/trajdata/utils/scene_utils.py b/src/trajdata/utils/scene_utils.py index 804c498..1c29bc5 100644 --- a/src/trajdata/utils/scene_utils.py +++ b/src/trajdata/utils/scene_utils.py @@ -60,12 +60,13 @@ def interpolate_scene_dt(scene: Scene, desired_dt: float) -> None: def subsample_scene_dt(scene: Scene, desired_dt: float) -> None: dt_ratio: float = desired_dt / scene.dt - if not dt_ratio.is_integer(): + + if not is_integer_robust(dt_ratio): raise ValueError( f"Cannot subsample scene: {desired_dt} is not integer divisible by {scene.dt} for {str(scene)}" ) - dt_factor: int = int(dt_ratio) + dt_factor: int = int(round(dt_ratio)) # E.g., the scene is currently at dt = 0.1s (10 Hz), # but we want desired_dt = 0.5s (2 Hz). @@ -86,3 +87,6 @@ def subsample_scene_dt(scene: Scene, desired_dt: float) -> None: scene.dt = desired_dt # Note we do not touch scene_info.env_metadata.dt, this will serve as our # source of the "original" data dt information. + +def is_integer_robust(x): + return abs(x-round(x))<1e-6 \ No newline at end of file diff --git a/src/trajdata/utils/vis_utils.py b/src/trajdata/utils/vis_utils.py index cf8f146..30b6dd4 100644 --- a/src/trajdata/utils/vis_utils.py +++ b/src/trajdata/utils/vis_utils.py @@ -1,14 +1,16 @@ from collections import defaultdict -from typing import List, Optional, Tuple +from typing import Dict, List, Optional, Tuple +import bokeh import geopandas as gpd import numpy as np import pandas as pd import seaborn as sns +from bokeh.io import export_png from bokeh.models import ColumnDataSource, GlyphRenderer -from bokeh.plotting import figure +from bokeh.plotting import curdoc, figure +from PIL import Image from shapely.geometry import LineString, Polygon - from trajdata.data_structures.agent import AgentType from trajdata.data_structures.batch import AgentBatch from trajdata.data_structures.state import StateArray @@ -19,8 +21,13 @@ PedWalkway, RoadArea, RoadLane, + RoadEdge, +) +from trajdata.utils.arr_utils import ( + batch_nd_transform_points_np, + batch_nd_transform_points_pt, + transform_coords_2d_np, ) -from trajdata.utils.arr_utils import transform_coords_2d_np def apply_default_settings(fig: figure) -> None: @@ -82,23 +89,25 @@ def calculate_figure_sizes( def pretty_print_agent_type(agent_type: AgentType): - return str(agent_type)[len("AgentType.") :].capitalize() + return str(agent_type).capitalize() def agent_type_to_str(agent_type_int: int) -> str: - return pretty_print_agent_type(AgentType(agent_type_int)) + return pretty_print_agent_type(AgentType(agent_type_int).name) def get_agent_type_color(agent_type: AgentType) -> str: - palette = sns.color_palette("husl", 4).as_hex() + palette = sns.color_palette("husl", 8).as_hex() if agent_type == AgentType.VEHICLE: - return palette[0] + return palette[3] elif agent_type == AgentType.PEDESTRIAN: return "darkorange" elif agent_type == AgentType.BICYCLE: - return palette[2] + return palette[5] elif agent_type == AgentType.MOTORCYCLE: - return palette[3] + return palette[7] + elif agent_type == 'EGO': + return palette[0] else: return "#A9A9A9" @@ -213,8 +222,7 @@ def compute_agent_rects_coords( return agent_rect_coords, dir_patch_coords - -def extract_full_agent_data_df(batch: AgentBatch, batch_idx: int) -> pd.DataFrame: +def extract_full_agent_data_df(batch: AgentBatch, batch_idx: int, violations) -> pd.DataFrame: main_data_dict = defaultdict(list) # Historical information @@ -252,7 +260,9 @@ def extract_full_agent_data_df(batch: AgentBatch, batch_idx: int) -> pd.DataFram main_data_dict["length"].extend(lengths) main_data_dict["width"].extend(widths) main_data_dict["pred_agent"].extend([True] * H) - main_data_dict["color"].extend([get_agent_type_color(agent_type)] * H) + # main_data_dict["color"].extend([get_agent_type_color(agent_type)] * H) + main_data_dict["color"] = ["lightblue"] * H + main_data_dict["violation"] = [''] * H ## Neighbors num_neighbors: int = batch.num_neigh[batch_idx].item() @@ -296,6 +306,7 @@ def extract_full_agent_data_df(batch: AgentBatch, batch_idx: int) -> pd.DataFram main_data_dict["width"].extend(widths) main_data_dict["pred_agent"].extend([False] * H) main_data_dict["color"].extend([get_agent_type_color(agent_type)] * H) + main_data_dict["violation"].extend([''] * H) # Future information ## Agent @@ -332,8 +343,12 @@ def extract_full_agent_data_df(batch: AgentBatch, batch_idx: int) -> pd.DataFram main_data_dict["length"].extend(lengths) main_data_dict["width"].extend(widths) main_data_dict["pred_agent"].extend([True] * T) - main_data_dict["color"].extend([get_agent_type_color(agent_type)] * T) - + main_data_dict["color"].extend(["lightblue"] * T) + main_data_dict["violation"].extend([''] * T) + # main_data_dict["color"].extend([get_agent_type_color(agent_type)] * T) + if violations is not None: + main_data_dict['color'][-T:len(main_data_dict['violation']) - T + len(violations)] = ['lightblue' if len(violation) == 0 else get_agent_type_color('EGO') for violation in violations] + main_data_dict['violation'][-T:len(main_data_dict['violation']) - T + len(violations)] = [','.join(violation) for violation in violations] ## Neighbors num_neighbors: int = batch.num_neigh[batch_idx].item() @@ -374,6 +389,12 @@ def extract_full_agent_data_df(batch: AgentBatch, batch_idx: int) -> pd.DataFram main_data_dict["width"].extend(widths) main_data_dict["pred_agent"].extend([False] * T) main_data_dict["color"].extend([get_agent_type_color(agent_type)] * T) + main_data_dict['violation'].extend([''] * T) + + # Fix: Ensure all lists are of the same length + min_lenght = min(len(v) for v in main_data_dict.values()) + for k in main_data_dict.keys(): + main_data_dict[k] = main_data_dict[k][:min_lenght] return pd.DataFrame(main_data_dict) @@ -394,6 +415,8 @@ def convert_to_gpd(vec_map: VectorMap) -> gpd.GeoDataFrame: holes=[hole.xyz for hole in elem.interior_holes], ) ) + elif isinstance(elem, RoadEdge): + geo_data["geometry"].append(LineString(elem.polyline.xyz)) return gpd.GeoDataFrame(geo_data) @@ -408,12 +431,14 @@ def get_map_cds( ColumnDataSource, ColumnDataSource, ColumnDataSource, + ColumnDataSource, ]: road_lane_data = defaultdict(list) lane_center_data = defaultdict(list) ped_crosswalk_data = defaultdict(list) ped_walkway_data = defaultdict(list) road_area_data = defaultdict(list) + road_edge_data = defaultdict(list) map_gpd = convert_to_gpd(vec_map) affine_tf_params = ( @@ -467,12 +492,17 @@ def get_map_cds( road_area_data["ys"].append( [[xy[..., 1]] + [hole[..., 1] for hole in holes_xy]] ) + elif row['type'] == MapElementType.ROAD_EDGE: + xy = np.stack(row["geometry"].xy, axis=1) + road_edge_data["xs"].append(xy[..., 0]) + road_edge_data["ys"].append(xy[..., 1]) return ( ColumnDataSource(data=lane_center_data), ColumnDataSource(data=road_lane_data), ColumnDataSource(data=ped_crosswalk_data), ColumnDataSource(data=ped_walkway_data), ColumnDataSource(data=road_area_data), + ColumnDataSource(data=road_edge_data), ) @@ -500,6 +530,7 @@ def draw_map_elems( ped_crosswalk_cds, ped_walkway_cds, road_area_cds, + road_edge_cds ) = get_map_cds(map_from_world_tf, vec_map, bbox) road_areas = fig.multi_polygons( @@ -539,5 +570,12 @@ def draw_map_elems( line_color="gray", line_alpha=0.5, ) + + if kwargs.get('incl_road_edges', True): + road_edges = fig.multi_line( + source=road_edge_cds, + line_color="red", + line_alpha=0.5, + ) return road_areas, road_lanes, ped_crosswalks, ped_walkways, lane_centers diff --git a/src/trajdata/visualization/interactive_animation.py b/src/trajdata/visualization/interactive_animation.py index bf71b2c..b47ac74 100644 --- a/src/trajdata/visualization/interactive_animation.py +++ b/src/trajdata/visualization/interactive_animation.py @@ -1,4 +1,5 @@ import logging +import os import socket import threading import time @@ -9,12 +10,13 @@ from pathlib import Path from typing import Any, Callable, Dict, List, Optional -import cv2 import numpy as np import pandas as pd +import panel as pn from bokeh.application import Application from bokeh.application.handlers import FunctionHandler from bokeh.document import Document, without_document_lock +from bokeh.io import export_png from bokeh.io.export import get_screenshot_as_png from bokeh.layouts import column, row from bokeh.models import ( @@ -29,7 +31,7 @@ Select, Slider, ) -from bokeh.plotting import figure +from bokeh.plotting import curdoc, figure from bokeh.server.server import Server from selenium import webdriver from tornado import gen @@ -43,61 +45,19 @@ from trajdata.utils import vis_utils -class InteractiveAnimation: - def __init__( - self, - main_func: Callable[[Document, IOLoop], None], - port: Optional[int] = None, - **kwargs, - ) -> None: - self.main_func = main_func - self.port = port - self.kwargs = kwargs - - def get_open_port(self) -> int: - with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s: - s.bind(("", 0)) - s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - return s.getsockname()[1] - - def show(self) -> None: - io_loop = IOLoop() - - if self.port is None: - self.port = self.get_open_port() - - def kill_on_tab_close(session_context): - io_loop.stop() - - def app_init(doc: Document): - doc.on_session_destroyed(kill_on_tab_close) - self.main_func(doc=doc, io_loop=io_loop, **self.kwargs) - return doc - - server = Server( - {"/": Application(FunctionHandler(app_init))}, - io_loop=io_loop, - port=self.port, - check_unused_sessions_milliseconds=500, - unused_session_lifetime_milliseconds=500, - ) - server.start() - - # print(f"Opening Bokeh application on http://localhost:{self.port}/") - server.io_loop.add_callback(server.show, "/") - server.io_loop.start() - - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", category=RuntimeWarning) - server.io_loop.close() - - def animate_agent_batch_interactive( - doc: Document, io_loop: IOLoop, batch: AgentBatch, batch_idx: int, cache_path: Path + batch: AgentBatch, + batch_idx: int, + cache_path: Path, + render_immediately: bool = False, + incl_road_edges: bool = False, + embeded: bool = False, + image_files: List[str] = None, + violations = None, ) -> None: - agent_data_df = vis_utils.extract_full_agent_data_df(batch, batch_idx) + agent_data_df = vis_utils.extract_full_agent_data_df(batch, batch_idx,violations) - # Figure creation and a few initial settings. + # figure creation and a few initial settings. width: int = 1280 aspect_ratio: float = 16 / 9 data_vis_margin: float = 10.0 @@ -154,6 +114,7 @@ def animate_agent_batch_interactive( incl_road_areas=True, incl_ped_crosswalks=True, incl_ped_walkways=True, + incl_road_edges=True if incl_road_edges else False, ) ( @@ -180,13 +141,18 @@ def animate_agent_batch_interactive( # Some neighbors can have more history than the agent to be predicted # (the data-collecting agent has observed the neighbors for longer). - full_H = max( - batch.agent_hist_len[batch_idx].item(), - *batch.neigh_hist_len[batch_idx].tolist(), - ) - full_T = max( - batch.agent_fut_len[batch_idx].item(), *batch.neigh_fut_len[batch_idx].tolist() - ) + if batch.neigh_hist_len[batch_idx].shape[0] == 0: + full_H = batch.agent_hist_len[batch_idx].item() + full_T = batch.agent_fut_len[batch_idx].item() + else: + full_H = max( + batch.agent_hist_len[batch_idx].item(), + *batch.neigh_hist_len[batch_idx].tolist(), + ) + full_T = max( + batch.agent_fut_len[batch_idx].item(), + *batch.neigh_fut_len[batch_idx].tolist(), + ) def create_multi_line_data(agents_df: pd.DataFrame) -> Dict[str, List]: lines_data = defaultdict(list) @@ -245,6 +211,9 @@ def slice_multi_line_data( slice_multi_line_data(future_line_data_df, slice(full_H, None), check_idx=0) ) + + + history_lines = fig.multi_line( xs="xs", ys="ys", @@ -262,6 +231,7 @@ def slice_multi_line_data( line_width=2, source=future_lines_cds, ) + # Agent rectangles/directional arrows at the current timestep. agent_rects = fig.patches( @@ -283,19 +253,63 @@ def slice_multi_line_data( source=agent_cds, view=curr_time_view, ) - + if batch.extras is not None and "action_sample" in batch.extras: + action_sample = batch.extras["action_sample"][batch_idx,0] + action_sample = action_sample.cpu().numpy() + action_sample = np.concatenate([np.full([full_H,*action_sample.shape[1:]],fill_value=np.nan),action_sample],axis=0) + action_sample_cds = ColumnDataSource( + dict(xs=[action_sample[0,i,:,0] for i in range(action_sample.shape[1])], + ys=[action_sample[0,i,:,1] for i in range(action_sample.shape[1])], + ) + ) + action_sample_lines = fig.multi_line( + xs="xs", + ys="ys", + line_color="green", + line_dash="dashed", + line_width=1.5, + source=action_sample_cds + ) + else: + action_sample_lines = None scene_ts: int = batch.scene_ts[batch_idx].item() # Controlling the timestep shown to users. + end_time: int = min(agent_cds.data["t"].max(), len(violations)) if violations is not None else agent_cds.data["t"].max() + total_timesteps: int = end_time - agent_cds.data["t"].min() + 1 + abs_total_timesteps: int = agent_cds.data["t"].max() - agent_cds.data["t"].min() + 1 time_slider = Slider( start=agent_cds.data["t"].min(), - end=agent_cds.data["t"].max(), + end=end_time, step=1, value=0, title=f"Current Timestep (scene timestep {scene_ts})", ) dt: float = batch.dt[batch_idx].item() + + # adding image if available + if image_files is not None: + from PIL import Image + + def pil_image_to_rgba(image_file): + image = Image.open(image_file) + image = image.convert("RGBA") # Ensure the image is in RGBA mode + img_array = np.array(image) # Convert the image to a numpy array + # Flatten the RGBA array into a 1D array in the format Bokeh expects (row-major order) + img_flat = np.flipud(img_array).flatten() + return img_flat.view(np.uint32).reshape((image.height, image.width)) + + img_source = ColumnDataSource(data={'image': [pil_image_to_rgba(image_files[0])]}) + + # Set up the figure for displaying the image + p = figure(x_range=(0, 1), y_range=(0, 1), width=Image.open(image_files[0]).width, height=Image.open(image_files[0]).height) + p.image_rgba(image='image', source=img_source, x=0, y=0, dw=1, dh=1) + p.xaxis.visible = False + p.yaxis.visible = False + p.xgrid.visible = False + p.ygrid.visible = False + p.outline_line_color = None # Ensuring that information gets updated upon a cahnge in the slider value. def time_callback(attr, old, new) -> None: @@ -306,6 +320,10 @@ def time_callback(attr, old, new) -> None: future_lines_cds.data = slice_multi_line_data( future_line_data_df, slice(new + full_H, None), check_idx=0 ) + if action_sample_lines is not None: + action_sample_cds.data = dict(xs=[action_sample[new+full_H,i,:,0] for i in range(action_sample.shape[1])], + ys=[action_sample[new+full_H,i,:,1] for i in range(action_sample.shape[1])], + ) if new == 0: time_slider.title = f"Current Timestep (scene timestep {scene_ts})" @@ -313,6 +331,10 @@ def time_callback(attr, old, new) -> None: n_steps = abs(new) time_slider.title = f"{n_steps} timesteps ({n_steps * dt:.2f} s) into the {'future' if new > 0 else 'past'}" + if image_files is not None: + new_image_idx = new * len(image_files) // abs_total_timesteps + img_source.data = {'image': [pil_image_to_rgba(image_files[new_image_idx])]} + time_slider.on_change("value", time_callback) # Adding tooltips on mouse hover. @@ -322,16 +344,20 @@ def time_callback(attr, old, new) -> None: ("Class", "@type"), ("Position", "(@x, @y) m"), ("Speed", "@speed_mps m/s (@speed_kph km/h)"), + ("Violation", "@violation"), ], renderers=[agent_rects], ) ) + exit_button = Button(label="Exit", button_type="danger", width=60) def button_callback(): # Stop the server. - io_loop.stop() + import sys - exit_button = Button(label="Exit", button_type="danger", width=60) + from tornado.ioloop import IOLoop + + sys.exit() exit_button.on_click(button_callback) # Writing animation callback functions so that the play/pause button animate the @@ -351,12 +377,12 @@ def animate(): if play_button.label.startswith("►"): play_button.label = "❚❚ Pause" - play_cb_manager[0] = doc.add_periodic_callback( + play_cb_manager[0] = play_button.document.add_periodic_callback( animate_update, period_milliseconds=int(dt * 1000) ) else: play_button.label = "► Play" - doc.remove_periodic_callback(play_cb_manager[0]) + play_button.document.remove_periodic_callback(play_cb_manager[0]) play_button = Button(label="► Play", width=100) play_button.on_click(animate) @@ -364,13 +390,25 @@ def animate(): # Creating the legend elements and connecting them to their original elements # (allows us to hide them on click later!) agent_legend_elems = [ + fig.rect( + fill_color='lightblue', + line_color='black', + name='EGO' + ), + fig.rect( + fill_color=vis_utils.get_agent_type_color('EGO'), + line_color='black', + name='EGO_Violation' + ) + ] + agent_legend_elems.extend([ fig.rect( fill_color=vis_utils.get_agent_type_color(x), line_color="black", name=vis_utils.agent_type_to_str(x), ) for x in AgentType - ] + ]) map_legend_elems = [LegendItem(label="Lane Center", renderers=[lane_centers])] @@ -450,42 +488,7 @@ def after_frame_save(label: str) -> None: animate_update() def execute_save_animation(file_path: Path) -> None: - images = [] - - chrome_options = webdriver.ChromeOptions() - chrome_options.headless = True - driver = webdriver.Chrome(chrome_options=chrome_options) - - n_frames = render_range_slider.value[1] - render_range_slider.value[0] + 1 - for frame_index in trange(n_frames, desc="Rendering Video"): - # Giving the doc a chance to update the figure. - time.sleep(0.1) - - image = get_screenshot_as_png(fig, driver=driver) - shape = image.size - images.append(image) - - doc.add_next_tick_callback( - partial( - after_frame_save, - label=f"Rendering... ({100*(frame_index+1)/n_frames:.0f}%)", - ) - ) - - if file_path.suffix == ".mp4": - fourcc = cv2.VideoWriter_fourcc(*"mp4v") - elif file_path.suffix == ".avi": - fourcc = cv2.VideoWriter_fourcc("M", "J", "P", "G") - - video_obj = cv2.VideoWriter( - filename=str(file_path), fourcc=fourcc, fps=1.0 / dt, frameSize=shape - ) - for image in images: - cv2_image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR) - video_obj.write(cv2_image) - video_obj.release() - - doc.add_next_tick_callback(reset_buttons) + raise NotImplementedError() @gen.coroutine @without_document_lock @@ -516,19 +519,25 @@ def save_animation(filename: str) -> None: args=(Path(filename + filetype_select.value),), ).start() - video_button.on_click( - partial( - save_animation, - filename=( - "_".join([env_name, map_name, scene_id, f"t{scene_ts}", agent_name]) - ), - ) + video_button_fn = partial( + save_animation, + filename=("_".join([env_name, map_name, scene_id, f"t{scene_ts}", agent_name])), ) + video_button.on_click(video_button_fn) - doc.add_root( - column( - fig, - row(play_button, time_slider, exit_button), - row(video_button, render_range_slider, filetype_select), - ) + layout = column( + fig if image_files is None else row(fig, p), + # row(play_button, time_slider, exit_button), + row(play_button, time_slider), + row(video_button, render_range_slider, filetype_select), ) + bokeh_pane = pn.pane.Bokeh(layout) + + if embeded: + return bokeh_pane + + server = pn.serve(bokeh_pane) + if render_immediately: + video_button_fn() + + return server diff --git a/src/trajdata/visualization/interactive_figure.py b/src/trajdata/visualization/interactive_figure.py index b61ed59..bf68bac 100644 --- a/src/trajdata/visualization/interactive_figure.py +++ b/src/trajdata/visualization/interactive_figure.py @@ -17,8 +17,8 @@ class InteractiveFigure: def __init__(self, **kwargs) -> None: self.aspect_ratio: float = kwargs.get("aspect_ratio", 16 / 9) - self.width: int = kwargs.get("width", 1280) - self.height: int = kwargs.get("height", int(self.width / self.aspect_ratio)) + self.width: int = kwargs.pop("width", 1280) + self.height: int = kwargs.pop("height", int(self.width / self.aspect_ratio)) # We'll be tracking the maxes and mins of data with these. self.x_min = np.inf diff --git a/src/trajdata/visualization/interactive_vis.py b/src/trajdata/visualization/interactive_vis.py index 85a263a..bc17d81 100644 --- a/src/trajdata/visualization/interactive_vis.py +++ b/src/trajdata/visualization/interactive_vis.py @@ -12,7 +12,7 @@ from trajdata.visualization.interactive_figure import InteractiveFigure -def plot_agent_batch_interactive(batch: AgentBatch, batch_idx: int, cache_path: Path): +def plot_agent_batch_interactive(batch: AgentBatch, batch_idx: int, cache_path: Path, **kwargs) -> None: fig = InteractiveFigure( tooltips=[ ("Class", "@type"), @@ -76,6 +76,7 @@ def plot_agent_batch_interactive(batch: AgentBatch, batch_idx: int, cache_path: incl_road_areas=True, incl_ped_crosswalks=True, incl_ped_walkways=True, + incl_road_edges=True if kwargs.get("incl_road_edges", False) else False ), # x_min, x_max, y_min, y_max bbox=( @@ -84,10 +85,12 @@ def plot_agent_batch_interactive(batch: AgentBatch, batch_idx: int, cache_path: y - map_vis_radius, y + map_vis_radius, ), + kwargs=kwargs, ) fig.add_lines(agent_histories) - fig.add_lines(agent_futures) + if agent_fut_np.shape[0] > 0: + fig.add_lines(agent_futures) agent_extent: np.ndarray = batch.agent_hist_extent[batch_idx, -1] if agent_extent.isnan().any(): @@ -116,7 +119,8 @@ def plot_agent_batch_interactive(batch: AgentBatch, batch_idx: int, cache_path: "y": [y], "xs": [agent_rect_coords[:, 0] + x], "ys": [agent_rect_coords[:, 1] + y], - "fill_color": [vis_utils.get_agent_type_color(agent_type)], + # "fill_color": [vis_utils.get_agent_type_color(agent_type)], + "fill_color": [vis_utils.get_agent_type_color('EGO')], # ego color "line_color": ["black"], "fill_alpha": [0.7], "type": [str(AgentType(agent_type))[len("AgentType.") :]], @@ -142,7 +146,8 @@ def plot_agent_batch_interactive(batch: AgentBatch, batch_idx: int, cache_path: dir_patches_data = { "xs": [dir_patch_coords[:, 0] + x], "ys": [dir_patch_coords[:, 1] + y], - "fill_color": [vis_utils.get_agent_type_color(agent_type)], + "fill_color": [vis_utils.get_agent_type_color('EGO')], # ego color + # "fill_color": [vis_utils.get_agent_type_color(agent_type)], "line_color": ["black"], "alpha": [0.7], } @@ -153,7 +158,8 @@ def plot_agent_batch_interactive(batch: AgentBatch, batch_idx: int, cache_path: agent_extent: np.ndarray = batch.neigh_hist_extents[batch_idx, n_neigh, -1] if agent_extent.isnan().any(): - raise ValueError("Agent extents cannot be NaN!") + continue + # raise ValueError("Agent extents cannot be NaN!") length = agent_extent[0].item() width = agent_extent[1].item() diff --git a/tests/test_raster_map.py b/tests/test_raster_map.py new file mode 100644 index 0000000..9bfc197 --- /dev/null +++ b/tests/test_raster_map.py @@ -0,0 +1,100 @@ +import unittest +from pathlib import Path +from typing import Dict, List + +from trajdata import MapAPI, VectorMap + +import unittest +from collections import defaultdict + +import torch + +from trajdata import AgentType, UnifiedDataset, SceneBatch +from trajdata.dataset import DataLoader +from trajdata.utils.batch_utils import get_raster_maps_for_scene_batch + + +class TestRasterMap(unittest.TestCase): + def __init__(self, methodName: str = "batchConversion") -> None: + super().__init__(methodName) + + data_source = "nusc_mini" + history_sec = 2.0 + prediction_sec = 6.0 + + attention_radius = defaultdict( + lambda: 20.0 + ) # Default range is 20m unless otherwise specified. + attention_radius[(AgentType.PEDESTRIAN, AgentType.PEDESTRIAN)] = 10.0 + attention_radius[(AgentType.PEDESTRIAN, AgentType.VEHICLE)] = 20.0 + attention_radius[(AgentType.VEHICLE, AgentType.PEDESTRIAN)] = 20.0 + attention_radius[(AgentType.VEHICLE, AgentType.VEHICLE)] = 30.0 + + self._map_params = {"px_per_m": 2, "map_size_px": 100, "offset_frac_xy": (-0.75, 0.0)} + + self._scene_dataset = UnifiedDataset( + centric="scene", + desired_data=[data_source], + history_sec=(history_sec, history_sec), + future_sec=(prediction_sec, prediction_sec), + agent_interaction_distances=attention_radius, + incl_robot_future=False, + incl_raster_map=True, + raster_map_params=self._map_params, + only_predict=[AgentType.VEHICLE, AgentType.PEDESTRIAN], + no_types=[AgentType.UNKNOWN], + num_workers=0, + standardize_data=True, + data_dirs={ + "nusc_mini": "~/datasets/nuScenes", + }, + ) + + self._scene_dataloader = DataLoader( + self._scene_dataset, + batch_size=4, + shuffle=False, + collate_fn=self._scene_dataset.get_collate_fn(), + num_workers=0, + ) + + def _assert_allclose_with_nans(self, tensor1, tensor2, atol=1e-8): + """ + asserts that the two tensors have nans in the same locations, and the non-nan + elements all are close. + """ + # Check nans are in the same place + self.assertFalse( + torch.any( # True if there's any mismatch + torch.logical_xor( # True where either tensor1 or tensor 2 has nans, but not both (mismatch) + torch.isnan(tensor1), # True where tensor1 has nans + torch.isnan(tensor2), # True where tensor2 has nans + ) + ), + msg="Nans occur in different places.", + ) + valid_mask = torch.logical_not(torch.isnan(tensor1)) + self.assertTrue( + torch.allclose(tensor1[valid_mask], tensor2[valid_mask], atol=atol), + msg="Non-nan values don't match.", + ) + + def test_map_transform_scenebatch(self): + scene_batch: SceneBatch + for i, scene_batch in enumerate(self._scene_dataloader): + + # Make the tf double for more accurate transform. + scene_batch.centered_world_from_agent_tf = scene_batch.centered_world_from_agent_tf.double() + + maps, maps_resolution, raster_from_world_tf = get_raster_maps_for_scene_batch( + scene_batch, self._scene_dataset.cache_path, "nusc_mini", self._map_params) + + self._assert_allclose_with_nans(scene_batch.rasters_from_world_tf, raster_from_world_tf, atol=1e-2) + self._assert_allclose_with_nans(scene_batch.maps_resolution, maps_resolution) + self._assert_allclose_with_nans(scene_batch.maps, maps, atol=1e-4) + + if i > 50: + break + +if __name__ == "__main__": + unittest.main(catchbreak=False) diff --git a/tests/test_state.py b/tests/test_state.py index f0d0c12..a954145 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -4,6 +4,10 @@ import torch from trajdata.data_structures.state import NP_STATE_TYPES, TORCH_STATE_TYPES +from trajdata import UnifiedDataset, AgentBatch, AgentType +from trajdata.data_structures import AgentBatchElement, SceneBatchElement +from collections import defaultdict +from torch.utils.data import DataLoader AgentStateArray = NP_STATE_TYPES["x,y,z,xd,yd,xdd,ydd,h"] AgentObsArray = NP_STATE_TYPES["x,y,z,xd,yd,xdd,ydd,s,c"] @@ -99,6 +103,15 @@ def test_tensor_ops(self): self.assertFalse(isinstance(c, AgentStateTensor)) self.assertTrue(isinstance(c, torch.Tensor)) + def test_reshape_keeps_class_if_possible(self): + a = AgentStateTensor(torch.rand(2, 3, 8)) + self.assertTrue(isinstance(a.reshape(6, 8), AgentStateTensor)) + self.assertTrue(isinstance(a.reshape((6, 8)), AgentStateTensor)) + self.assertTrue(isinstance(torch.reshape(a, (6, 8)), AgentStateTensor)) + self.assertFalse(isinstance(a.reshape(12, 4), AgentStateTensor)) + self.assertFalse(isinstance(a.reshape((12, 4)), AgentStateTensor)) + self.assertFalse(isinstance(torch.reshape(a, (12, 4)), AgentStateTensor)) + class TestStateArray(unittest.TestCase): def test_construction(self): @@ -177,5 +190,203 @@ def test_tensor_ops(self): self.assertTrue(isinstance(c, float)) +class TestDataset(unittest.TestCase): + def test_dataloading(self): + dataset = UnifiedDataset( + desired_data=["nusc_mini-mini_val"], + centric="agent", + desired_dt=0.1, + history_sec=(3.2, 3.2), + future_sec=(4.8, 4.8), + only_predict=[AgentType.VEHICLE], + agent_interaction_distances=defaultdict(lambda: 30.0), + incl_robot_future=True, + incl_raster_map=True, + standardize_data=False, + raster_map_params={ + "px_per_m": 2, + "map_size_px": 224, + "offset_frac_xy": (-0.5, 0.0), + }, + num_workers=4, + verbose=True, + data_dirs={ # Remember to change this to match your filesystem! + "nusc_mini": "~/datasets/nuScenes", + }, + ) + + dataloader = DataLoader( + dataset, + batch_size=4, + shuffle=True, + collate_fn=dataset.get_collate_fn(), + num_workers=0, + ) + + i = 0 + batch: AgentBatch + for batch in dataloader: + i += 1 + + batch.to("cuda") + + self.assertIsInstance(batch.curr_agent_state, dataset.torch_state_type) + self.assertIsInstance(batch.agent_hist, dataset.torch_obs_type) + self.assertIsInstance(batch.agent_fut, dataset.torch_obs_type) + self.assertIsInstance(batch.robot_fut, dataset.torch_obs_type) + + if i == 5: + break + + def test_dict_dataloading(self): + dataset = UnifiedDataset( + desired_data=["nusc_mini-mini_val"], + centric="agent", + desired_dt=0.1, + history_sec=(3.2, 3.2), + future_sec=(4.8, 4.8), + only_predict=[AgentType.VEHICLE], + agent_interaction_distances=defaultdict(lambda: 30.0), + incl_robot_future=True, + incl_raster_map=True, + standardize_data=False, + raster_map_params={ + "px_per_m": 2, + "map_size_px": 224, + "offset_frac_xy": (-0.5, 0.0), + }, + num_workers=4, + verbose=True, + data_dirs={ # Remember to change this to match your filesystem! + "nusc_mini": "~/datasets/nuScenes", + }, + ) + + dataloader = DataLoader( + dataset, + batch_size=4, + shuffle=True, + collate_fn=dataset.get_collate_fn(return_dict=True), + num_workers=0, + ) + + i = 0 + for batch in dataloader: + i += 1 + + self.assertIsInstance(batch["curr_agent_state"], dataset.torch_state_type) + self.assertIsInstance(batch["agent_hist"], dataset.torch_obs_type) + self.assertIsInstance(batch["agent_fut"], dataset.torch_obs_type) + self.assertIsInstance(batch["robot_fut"], dataset.torch_obs_type) + + if i == 5: + break + + dataset = UnifiedDataset( + desired_data=["nusc_mini-mini_val"], + centric="scene", + desired_dt=0.1, + history_sec=(3.2, 3.2), + future_sec=(4.8, 4.8), + only_predict=[AgentType.VEHICLE], + agent_interaction_distances=defaultdict(lambda: 30.0), + incl_robot_future=True, + incl_raster_map=True, + standardize_data=False, + raster_map_params={ + "px_per_m": 2, + "map_size_px": 224, + "offset_frac_xy": (-0.5, 0.0), + }, + num_workers=4, + verbose=True, + data_dirs={ # Remember to change this to match your filesystem! + "nusc_mini": "~/datasets/nuScenes", + }, + ) + + dataloader = DataLoader( + dataset, + batch_size=4, + shuffle=True, + collate_fn=dataset.get_collate_fn(return_dict=True), + num_workers=0, + ) + + i = 0 + for batch in dataloader: + i += 1 + + self.assertIsInstance( + batch["centered_agent_state"], dataset.torch_state_type + ) + self.assertIsInstance(batch["agent_hist"], dataset.torch_obs_type) + self.assertIsInstance(batch["agent_fut"], dataset.torch_obs_type) + self.assertIsInstance(batch["robot_fut"], dataset.torch_obs_type) + + if i == 5: + break + + def test_default_datatypes_agent(self): + dataset = UnifiedDataset( + desired_data=["nusc_mini-mini_val"], + centric="agent", + desired_dt=0.1, + history_sec=(3.2, 3.2), + future_sec=(4.8, 4.8), + only_predict=[AgentType.VEHICLE], + agent_interaction_distances=defaultdict(lambda: 30.0), + incl_robot_future=True, + incl_raster_map=True, + standardize_data=False, + raster_map_params={ + "px_per_m": 2, + "map_size_px": 224, + "offset_frac_xy": (-0.5, 0.0), + }, + num_workers=4, + verbose=True, + data_dirs={ # Remember to change this to match your filesystem! + "nusc_mini": "~/datasets/nuScenes", + }, + ) + + elem: AgentBatchElement = dataset[0] + self.assertIsInstance(elem.curr_agent_state_np, dataset.np_state_type) + self.assertIsInstance(elem.agent_history_np, dataset.np_obs_type) + self.assertIsInstance(elem.agent_future_np, dataset.np_obs_type) + self.assertIsInstance(elem.robot_future_np, dataset.np_obs_type) + + def test_default_datatypes_scene(self): + dataset = UnifiedDataset( + desired_data=["nusc_mini-mini_val"], + centric="scene", + desired_dt=0.1, + history_sec=(3.2, 3.2), + future_sec=(4.8, 4.8), + only_predict=[AgentType.VEHICLE], + agent_interaction_distances=defaultdict(lambda: 30.0), + incl_robot_future=True, + incl_raster_map=True, + standardize_data=False, + raster_map_params={ + "px_per_m": 2, + "map_size_px": 224, + "offset_frac_xy": (-0.5, 0.0), + }, + num_workers=4, + verbose=True, + data_dirs={ # Remember to change this to match your filesystem! + "nusc_mini": "~/datasets/nuScenes", + }, + ) + + elem: SceneBatchElement = dataset[0] + self.assertIsInstance(elem.centered_agent_state_np, dataset.np_state_type) + self.assertIsInstance(elem.agent_histories[0], dataset.np_obs_type) + self.assertIsInstance(elem.agent_futures[0], dataset.np_obs_type) + self.assertIsInstance(elem.robot_future_np, dataset.np_obs_type) + + if __name__ == "__main__": unittest.main() From c4c3c6d6907d1d18d85a3788ce2adcebe12ec55a Mon Sep 17 00:00:00 2001 From: Boris Ivanovic Date: Sat, 1 Nov 2025 22:53:55 -0700 Subject: [PATCH 2/4] Fix lane connections and directions and bump version. --- pyproject.toml | 2 +- .../dataset_specific/xodr/connectivity.py | 365 +++++++++--------- .../dataset_specific/xodr/lane_processing.py | 53 ++- src/trajdata/dataset_specific/xodr/parser.py | 18 +- 4 files changed, 235 insertions(+), 203 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index fb52a43..5f0ecb4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,7 +13,7 @@ classifiers = [ "License :: OSI Approved :: Apache Software License", ] name = "trajdata-alpasim" -version = "1.4.2" +version = "1.4.3" authors = [{ name = "Boris Ivanovic", email = "bivanovic@nvidia.com" }] description = "A unified interface to many trajectory forecasting datasets." readme = "README.md" diff --git a/src/trajdata/dataset_specific/xodr/connectivity.py b/src/trajdata/dataset_specific/xodr/connectivity.py index c23d8a8..503a06b 100644 --- a/src/trajdata/dataset_specific/xodr/connectivity.py +++ b/src/trajdata/dataset_specific/xodr/connectivity.py @@ -9,37 +9,46 @@ import xml.etree.ElementTree as ET from typing import Dict, List -import numpy as np - from .datatypes import LaneGeom -# Constants -# Maximum distance (meters) to connect junction lanes -JUNCTION_LANE_CONNECTION_THRESHOLD = 10.0 - def build_lane_adjacency(road_to_lanes: Dict[str, List[LaneGeom]]) -> None: - """Build adjacency relationships between lanes within each road. + """Build physical adjacency relationships between lanes within each road. + + Sets the adj_left and adj_right attributes for each lane based on + lane ID ordering within the same road. Args: - road_to_lanes: Dictionary mapping road IDs to their lane geometries. - Modified in place to add adjacency relationships. + road_to_lanes: Dictionary mapping road IDs to their lanes. + Modified in place to add adjacency info. """ for lanes in road_to_lanes.values(): - left_lanes = sorted( - [lane for lane in lanes if lane.is_left], key=lambda lane: lane.lane_id_xml - ) - right_lanes = sorted( - [lane for lane in lanes if not lane.is_left], - key=lambda lane: lane.lane_id_xml, - reverse=True, - ) - for side in (left_lanes, right_lanes): - for i, curr in enumerate(side): - if i > 0: - curr.adj_right.add(side[i - 1].unique_id) - if i < len(side) - 1: - curr.adj_left.add(side[i + 1].unique_id) + # Sort lanes by ID for adjacency determination + sorted_lanes = sorted(lanes, key=lambda lane: lane.lane_id_xml) + + # Create a map from lane ID to lane object for quick lookup + id_to_lane = {lane.lane_id_xml: lane for lane in sorted_lanes} + + for lane in sorted_lanes: + lane_id = lane.lane_id_xml + + # Check for left adjacent lane + # Left = geometrically to the left = higher ID (for both positive and negative) + # But never cross zero (lane 0 is reference line, not a real lane) + left_id = lane_id + 1 + if left_id != 0 and left_id in id_to_lane: + adj_lane = id_to_lane[left_id] + if adj_lane.is_driving: # Only connect to driveable lanes + lane.adj_left.add(adj_lane.unique_id) + + # Check for right adjacent lane + # Right = geometrically to the right = lower ID (for both positive and negative) + # But never cross zero (lane 0 is reference line, not a real lane) + right_id = lane_id - 1 + if right_id != 0 and right_id in id_to_lane: + adj_lane = id_to_lane[right_id] + if adj_lane.is_driving: # Only connect to driveable lanes + lane.adj_right.add(adj_lane.unique_id) # Populate legal lane change permissions from XODR neighbor references # IMPORTANT: Only allow lane changes that are BOTH legal AND physically possible @@ -62,37 +71,42 @@ def build_lane_adjacency(road_to_lanes: Dict[str, List[LaneGeom]]) -> None: def extract_road_connections(root: ET.Element) -> Dict[str, Dict]: - """Extract road-to-road connectivity from elements. + """Extract road-to-road connectivity from XODR link elements. Args: root: Root XML element of the XODR document Returns: - Dictionary mapping road IDs to their connectivity info + Dictionary mapping road IDs to their connections: + {road_id: {"successor": {"road_id": str, "contactPoint": str}, + "predecessor": {"road_id": str, "contactPoint": str}}} """ - road_connections: Dict[str, Dict] = {} + road_connections = {} for road in root.findall("road"): road_id = road.attrib["id"] link_elem = road.find("link") + if link_elem is not None: connections = {} - predecessor = link_elem.find("predecessor") + + # Extract successor successor = link_elem.find("successor") + if successor is not None and successor.attrib.get("elementType") == "road": + connections["successor"] = { + "road_id": successor.attrib["elementId"], + "contactPoint": successor.attrib.get("contactPoint", "start"), + } + # Extract predecessor + predecessor = link_elem.find("predecessor") if ( predecessor is not None and predecessor.attrib.get("elementType") == "road" ): connections["predecessor"] = { "road_id": predecessor.attrib["elementId"], - "contact": predecessor.attrib.get("contactPoint", "start"), - } - - if successor is not None and successor.attrib.get("elementType") == "road": - connections["successor"] = { - "road_id": successor.attrib["elementId"], - "contact": successor.attrib.get("contactPoint", "start"), + "contactPoint": predecessor.attrib.get("contactPoint", "end"), } if connections: @@ -101,176 +115,171 @@ def extract_road_connections(root: ET.Element) -> Dict[str, Dict]: return road_connections -def connect_lanes_between_roads( - root: ET.Element, - road_connections: Dict[str, Dict], - road_to_lanes: Dict[str, List[LaneGeom]], +def _add_lane_link_connection( + link_elem: ET.Element, + link_type: str, + connected_road_id: str, + current_lane_id: str, + current_lane_geom: LaneGeom, + lane_geoms: Dict[str, LaneGeom], + to_next: bool, ) -> None: - """Connect lanes between roads using connectivity information. + """Parse and add a single lane link connection (successor or predecessor). - Handles both regular road connections (exact lane ID match) and - junction connections (spatial proximity). + Handles directional logic for positive lane IDs (left lanes) which flow + opposite to the road's reference direction. Args: - root: Root XML element of the XODR document - road_connections: Road connectivity information - road_to_lanes: Dictionary mapping road IDs to their lanes - Modified in place to add lane connections. - """ - for road_id, connections in road_connections.items(): - current_road_lanes = road_to_lanes.get(road_id, []) - - for direction, conn_info in connections.items(): - connected_road_id = conn_info["road_id"] - connected_road_lanes = road_to_lanes.get(connected_road_id, []) - - # Determine if this is a junction connection - curr_road = next( - (r for r in root.findall("road") if r.attrib["id"] == road_id), None - ) - conn_road = next( - ( - r - for r in root.findall("road") - if r.attrib["id"] == connected_road_id - ), - None, - ) - - curr_is_junction = ( - curr_road is not None and curr_road.attrib.get("junction", "-1") != "-1" - ) - conn_is_junction = ( - conn_road is not None and conn_road.attrib.get("junction", "-1") != "-1" - ) - - if curr_is_junction or conn_is_junction: - # Junction connection: use spatial proximity - _connect_junction_lanes( - current_road_lanes, - connected_road_lanes, - direction, - ) - else: - # Regular road connection: use exact lane ID match - _connect_regular_lanes( - current_road_lanes, - connected_road_lanes, - direction, - ) - - -def _connect_junction_lanes( - current_road_lanes: List[LaneGeom], - connected_road_lanes: List[LaneGeom], - direction: str, -) -> None: - """Connect lanes in junction areas using spatial proximity. + link_elem: The XML element containing successor/predecessor + link_type: "successor" or "predecessor" + connected_road_id: ID of the road containing the connected lane + current_lane_id: Unique ID (roadId_laneId) of current lane + current_lane_geom: LaneGeom object of current lane to update + lane_geoms: Dictionary of all lane geometries + to_next: True if connection should be added to next_lanes, + False for prev_lanes - Args: - current_road_lanes: Lanes from current road - connected_road_lanes: Lanes from connected road - direction: "successor" or "predecessor" """ - for curr_lane in current_road_lanes: - if not curr_lane.is_driving: - continue # Only connect driveable lanes - - # Find closest lane in connected road (by end/start point proximity) - curr_endpoint = ( - curr_lane.center[-1] if direction == "successor" else curr_lane.center[0] - ) - - best_match = None - min_distance = float("inf") - - for conn_lane in connected_road_lanes: - if not conn_lane.is_driving: - continue - - # Check distance to appropriate endpoint - conn_endpoint = ( - conn_lane.center[0] - if direction == "successor" - else conn_lane.center[-1] - ) - distance = np.linalg.norm(curr_endpoint[:2] - conn_endpoint[:2]) - - if ( - distance < min_distance - and distance < JUNCTION_LANE_CONNECTION_THRESHOLD - ): - min_distance = distance - best_match = conn_lane - - if best_match is not None: - if direction == "successor": - curr_lane.next_lanes.add(best_match.unique_id) - best_match.prev_lanes.add(curr_lane.unique_id) - elif direction == "predecessor": - curr_lane.prev_lanes.add(best_match.unique_id) - best_match.next_lanes.add(curr_lane.unique_id) - - -def _connect_regular_lanes( - current_road_lanes: List[LaneGeom], - connected_road_lanes: List[LaneGeom], - direction: str, + elem = link_elem.find(link_type) + if elem is not None: + connected_lane_id = elem.attrib.get("id") + if connected_lane_id and connected_road_id is not None: + connected_unique_id = f"{connected_road_id}_{connected_lane_id}" + if connected_unique_id in lane_geoms: + if to_next: + current_lane_geom.next_lanes.add(connected_unique_id) + lane_geoms[connected_unique_id].prev_lanes.add(current_lane_id) + else: + current_lane_geom.prev_lanes.add(connected_unique_id) + lane_geoms[connected_unique_id].next_lanes.add(current_lane_id) + + +def parse_lane_successors_predecessors( + root: ET.Element, lane_geoms: Dict[str, LaneGeom], road_connections: Dict[str, Dict] ) -> None: - """Connect lanes between regular roads using exact lane ID match. + """Parse lane-level successor/predecessor connections from XODR. + + This handles the lane.link.successor and lane.link.predecessor elements + which specify direct lane-to-lane connections across road boundaries. Args: - current_road_lanes: Lanes from current road - connected_road_lanes: Lanes from connected road - direction: "successor" or "predecessor" + root: Root XML element + lane_geoms: Dictionary of lane geometries to update + road_connections: Road-to-road connectivity info """ - for curr_lane in current_road_lanes: - for conn_lane in connected_road_lanes: - if curr_lane.lane_id_xml == conn_lane.lane_id_xml: # Same lane number - if direction == "successor": - curr_lane.next_lanes.add(conn_lane.unique_id) - conn_lane.prev_lanes.add(curr_lane.unique_id) - elif direction == "predecessor": - curr_lane.prev_lanes.add(conn_lane.unique_id) - conn_lane.next_lanes.add(curr_lane.unique_id) + for road in root.findall("road"): + road_id = road.attrib["id"] + lanes_elem = road.find("lanes") + if lanes_elem is None: + continue + + # Get the successor/predecessor road IDs for this road + successor_road_id = None + predecessor_road_id = None + + if road_id in road_connections: + if "successor" in road_connections[road_id]: + successor_road_id = road_connections[road_id]["successor"]["road_id"] + if "predecessor" in road_connections[road_id]: + predecessor_road_id = road_connections[road_id]["predecessor"]["road_id"] + + for lane_section in lanes_elem.findall("laneSection"): + for side in ["left", "right"]: + side_elem = lane_section.find(side) + if side_elem is None: + continue + + for lane in side_elem.findall("lane"): + lane_id = int(lane.attrib["id"]) + unique_id = f"{road_id}_{lane_id}" + + if unique_id not in lane_geoms: + continue + + lg = lane_geoms[unique_id] + + # Parse lane links + link_elem = lane.find("link") + if link_elem is None: + continue + + # Lane direction depends on lane ID sign: + # - Negative lanes (right): follow road direction + # - Positive lanes (left): opposite to road direction + # For positive lanes, we need to swap predecessor/successor meaning + should_invert = lane_id > 0 + + # Parse successor + # For positive lanes (should_invert=True): successor connects to prev + # For negative lanes (should_invert=False): successor connects to next + _add_lane_link_connection( + link_elem, + "successor", + successor_road_id, + unique_id, + lg, + lane_geoms, + to_next=not should_invert, + ) + + # Parse predecessor + # For positive lanes (should_invert=True): predecessor connects to next + # For negative lanes (should_invert=False): predecessor connects to prev + _add_lane_link_connection( + link_elem, + "predecessor", + predecessor_road_id, + unique_id, + lg, + lane_geoms, + to_next=should_invert, + ) def process_junction_connections( root: ET.Element, lane_geoms: Dict[str, LaneGeom], ) -> None: - """Process explicit junction laneLink connectivity. + """Process explicit junction connections from XODR. + + Parses elements to establish + lane-to-lane connectivity within junctions. Args: - root: Root XML element of the XODR document - lane_geoms: Dictionary of all lane geometries - Modified in place to add junction connections. + root: Root XML element + lane_geoms: Dictionary of lane geometries to update """ for junction in root.findall("junction"): - for conn in junction.findall("connection"): - conn_road_id = conn.attrib["connectingRoad"] - incoming_road_id = conn.attrib["incomingRoad"] - contact_pt = conn.attrib.get( - "contactPoint", "end" - ) # 'start' or 'end' relative to incoming road + for connection in junction.findall("connection"): + incoming_road = connection.attrib.get("incomingRoad") + connecting_road = connection.attrib.get("connectingRoad") + contact_point = connection.attrib.get("contactPoint", "start") + + if not incoming_road or not connecting_road: + continue + + # Process lane links + for lane_link in connection.findall("laneLink"): + from_lane_id = lane_link.attrib.get("from") + to_lane_id = lane_link.attrib.get("to") - for ll in conn.findall("laneLink"): - from_lane = int(ll.attrib["from"]) # lane id in connectingRoad - to_lane = int(ll.attrib["to"]) # lane id in incomingRoad + if not from_lane_id or not to_lane_id: + continue - unique_from = f"{conn_road_id}_{from_lane}" - unique_to = f"{incoming_road_id}_{to_lane}" + # Construct unique lane IDs + from_unique = f"{incoming_road}_{from_lane_id}" + to_unique = f"{connecting_road}_{to_lane_id}" - if unique_from not in lane_geoms or unique_to not in lane_geoms: - continue # malformed reference, skip + if from_unique not in lane_geoms or to_unique not in lane_geoms: + continue # Establish prev/next according to OpenDRIVE spec: # incomingRoad -> connectingRoad -> (other outgoing road) - if contact_pt == "start": + if contact_point == "start": # connectingRoad starts at incoming road, so incoming -> connecting - lane_geoms[unique_from].prev_lanes.add(unique_to) - lane_geoms[unique_to].next_lanes.add(unique_from) + lane_geoms[from_unique].next_lanes.add(to_unique) + lane_geoms[to_unique].prev_lanes.add(from_unique) else: # 'end' # connectingRoad ends at incoming road, so connecting -> incoming - lane_geoms[unique_from].next_lanes.add(unique_to) - lane_geoms[unique_to].prev_lanes.add(unique_from) + lane_geoms[to_unique].next_lanes.add(from_unique) + lane_geoms[from_unique].prev_lanes.add(to_unique) diff --git a/src/trajdata/dataset_specific/xodr/lane_processing.py b/src/trajdata/dataset_specific/xodr/lane_processing.py index 0601aee..d8fb935 100644 --- a/src/trajdata/dataset_specific/xodr/lane_processing.py +++ b/src/trajdata/dataset_specific/xodr/lane_processing.py @@ -6,7 +6,6 @@ from __future__ import annotations -import math import xml.etree.ElementTree as ET from typing import Dict, List, Tuple @@ -243,15 +242,31 @@ def _process_lane_geometry( max_xyz: Maximum extent (updated in place) """ # Process left side (positive IDs) and right side (negative IDs) separately - for side_ids in [ - sorted([lid for lid, _ in lane_offsets if lid > 0]), - sorted([lid for lid, _ in lane_offsets if lid < 0], reverse=True), - ]: - # Start from road centerline - current_edge_x, current_edge_y = center_x.copy(), center_y.copy() + # Both are processed from inside out (centerline -> outer edge) + left_lane_ids = sorted([lid for lid, _ in lane_offsets if lid > 0]) # [1, 2, 3...] + right_lane_ids = sorted([lid for lid, _ in lane_offsets if lid < 0], reverse=True) # [-1, -2, -3...] + + for is_left_side, side_ids in [(True, left_lane_ids), (False, right_lane_ids)]: + if is_left_side: + # For left lanes, start from reversed centerline since they flow opposite + current_edge_x = center_x[::-1] + current_edge_y = center_y[::-1] + current_center_z = center_z[::-1] + current_road_headings = road_headings[::-1] + else: + # For right lanes, use normal centerline + current_edge_x = center_x.copy() + current_edge_y = center_y.copy() + current_center_z = center_z.copy() + current_road_headings = road_headings.copy() for lid in side_ids: - widths = lane_width_samples[lid] + if is_left_side: + # For left lanes, reverse the width array to match reversed geometry + widths = lane_width_samples[lid][::-1] + else: + widths = lane_width_samples[lid] + lane_type = lane_types.get(lid, "none") lane_direction = lane_directions.get(lid, "standard") is_driving = lane_type in DRIVEABLE_LANE_TYPES @@ -265,8 +280,8 @@ def _process_lane_geometry( sign, current_edge_x, current_edge_y, - center_z, - road_headings, + current_center_z, + current_road_headings, lane_type, lane_direction, is_driving, @@ -297,7 +312,7 @@ def _process_lane_geometry( # Update edge for next iteration (move to outer edge) outer_edge_x, outer_edge_y = compute_polyline_from_width( - current_edge_x, current_edge_y, widths, sign, road_headings + current_edge_x, current_edge_y, widths, sign, current_road_headings ) current_edge_x, current_edge_y = outer_edge_x, outer_edge_y @@ -324,6 +339,7 @@ def _create_single_lane_geometry( sign: 1 for left lanes, -1 for right lanes current_edge_x: X coordinates of the inner edge current_edge_y: Y coordinates of the inner edge + center_z: Z coordinates along the lane road_headings: Heading angles along the road lane_type: Lane type from XODR lane_direction: Lane direction from XODR ("standard", "reversed", or "both") @@ -345,18 +361,15 @@ def _create_single_lane_geometry( # Build 3D coordinates using actual elevation data xyz_center = np.stack([mid_x, mid_y, center_z], axis=1) - - if lane_id > 0: # Left lane - xyz_left = np.stack([outer_edge_x, outer_edge_y, center_z], axis=1) - xyz_right = np.stack([current_edge_x, current_edge_y, center_z], axis=1) - else: # Right lane - xyz_left = np.stack([current_edge_x, current_edge_y, center_z], axis=1) - xyz_right = np.stack([outer_edge_x, outer_edge_y, center_z], axis=1) + + # Assign edges + # For right lanes: current edge is inner (left), outer edge is outer (right) + # For left lanes: with pre-reversed geometry, we use the same assignment + xyz_left = np.stack([current_edge_x, current_edge_y, center_z], axis=1) + xyz_right = np.stack([outer_edge_x, outer_edge_y, center_z], axis=1) # Compute lane headings from actual lane centerline lane_headings = recompute_headings(mid_x, mid_y) - if lane_id > 0: # Left lane - add π - lane_headings = (lane_headings + math.pi) % (2 * math.pi) - math.pi return LaneGeom( lane_id_xml=lane_id, diff --git a/src/trajdata/dataset_specific/xodr/parser.py b/src/trajdata/dataset_specific/xodr/parser.py index 146f8d7..82ca902 100644 --- a/src/trajdata/dataset_specific/xodr/parser.py +++ b/src/trajdata/dataset_specific/xodr/parser.py @@ -17,8 +17,8 @@ from .connectivity import ( build_lane_adjacency, - connect_lanes_between_roads, extract_road_connections, + parse_lane_successors_predecessors, process_junction_connections, ) from .datatypes import LaneGeom, ParsedXodr @@ -84,7 +84,9 @@ def parse_xodr(xodr_str: str, resolution: float = 0.5) -> ParsedXodr: # Extract and apply road-to-road connectivity road_connections = extract_road_connections(root) - connect_lanes_between_roads(root, road_connections, road_to_lanes) + + # Parse lane-level successor/predecessor connections + parse_lane_successors_predecessors(root, lane_geoms, road_connections) # Process explicit junction connections process_junction_connections(root, lane_geoms) @@ -266,11 +268,19 @@ def _parse_neighbor_lanes(root: ET.Element, lane_geoms: Dict[str, LaneGeom]) -> if left_elem is not None: neighbor_id = left_elem.attrib.get("id") if neighbor_id: - lg.left_neighbor_forward.add(f"{road_id}_{neighbor_id}") + direction = left_elem.attrib.get("direction", "forward") + if direction == "forward": + lg.left_neighbor_forward.add(f"{road_id}_{neighbor_id}") + elif direction == "backward": + lg.left_neighbor_backward.add(f"{road_id}_{neighbor_id}") # Right neighbor right_elem = link_elem.find("right") if right_elem is not None: neighbor_id = right_elem.attrib.get("id") if neighbor_id: - lg.right_neighbor_forward.add(f"{road_id}_{neighbor_id}") + direction = right_elem.attrib.get("direction", "forward") + if direction == "forward": + lg.right_neighbor_forward.add(f"{road_id}_{neighbor_id}") + elif direction == "backward": + lg.right_neighbor_backward.add(f"{road_id}_{neighbor_id}") From d428ba23094c32c010439b191fa38353a25c79c2 Mon Sep 17 00:00:00 2001 From: maximilianigl Date: Fri, 13 Mar 2026 22:04:44 +0100 Subject: [PATCH 3/4] Add clipgt dataset support (#61) * Add MADS dataset support Add the MADS dataset module, which provides: - MADSDataset: full dataset integration for loading MADS clipgt data including ego and obstacle agent parsing with cubic interpolation, overlap filtering, and map caching - populate_vector_map: constructs VectorMap instances from MADS parquet map data (lanes, road edges, traffic signs, wait lines) - MadsSceneRecord and env_utils registration for dataset discovery * Fix test * Update src/trajdata/dataset_specific/mads/mads_utils.py Co-authored-by: Boris Ivanovic <8534290+BorisIvanovic@users.noreply.github.com> * Update src/trajdata/dataset_specific/mads/mads_utils.py Co-authored-by: Boris Ivanovic <8534290+BorisIvanovic@users.noreply.github.com> * Update src/trajdata/utils/env_utils.py Co-authored-by: Boris Ivanovic <8534290+BorisIvanovic@users.noreply.github.com> * Update src/trajdata/dataset_specific/mads/mads_utils.py Co-authored-by: Boris Ivanovic <8534290+BorisIvanovic@users.noreply.github.com> * Update src/trajdata/dataset_specific/mads/mads_utils.py Co-authored-by: Boris Ivanovic <8534290+BorisIvanovic@users.noreply.github.com> * Update src/trajdata/dataset_specific/mads/mads_dataset.py Co-authored-by: Boris Ivanovic <8534290+BorisIvanovic@users.noreply.github.com> * Update src/trajdata/dataset_specific/mads/mads_dataset.py Co-authored-by: Boris Ivanovic <8534290+BorisIvanovic@users.noreply.github.com> * Update src/trajdata/dataset_specific/mads/mads_dataset.py Co-authored-by: Boris Ivanovic <8534290+BorisIvanovic@users.noreply.github.com> * Update src/trajdata/dataset_specific/mads/mads_dataset.py Co-authored-by: Boris Ivanovic <8534290+BorisIvanovic@users.noreply.github.com> * Update mads_dataset.py * fixes * Update env_utils.py * Update mads_utils.py --------- Co-authored-by: Boris Ivanovic <8534290+BorisIvanovic@users.noreply.github.com> --- .../dataset_specific/mads/__init__.py | 1 + .../dataset_specific/mads/mads_dataset.py | 575 ++++++++++++++++++ .../dataset_specific/mads/mads_utils.py | 398 ++++++++++++ .../dataset_specific/scene_records.py | 8 + src/trajdata/utils/env_utils.py | 4 + tests/test_state.py | 5 + 6 files changed, 991 insertions(+) create mode 100644 src/trajdata/dataset_specific/mads/__init__.py create mode 100644 src/trajdata/dataset_specific/mads/mads_dataset.py create mode 100644 src/trajdata/dataset_specific/mads/mads_utils.py diff --git a/src/trajdata/dataset_specific/mads/__init__.py b/src/trajdata/dataset_specific/mads/__init__.py new file mode 100644 index 0000000..d1eed6a --- /dev/null +++ b/src/trajdata/dataset_specific/mads/__init__.py @@ -0,0 +1 @@ +from .mads_dataset import MADSDataset diff --git a/src/trajdata/dataset_specific/mads/mads_dataset.py b/src/trajdata/dataset_specific/mads/mads_dataset.py new file mode 100644 index 0000000..fed0f95 --- /dev/null +++ b/src/trajdata/dataset_specific/mads/mads_dataset.py @@ -0,0 +1,575 @@ +import glob +import os +import random +from collections import defaultdict +from functools import partial +from pathlib import Path +from typing import Any, Dict, Final, List, Optional, Set, Tuple, Type, Union + +import numpy as np +import pandas as pd +from scipy.interpolate import CubicSpline +from scipy.spatial.transform import Rotation as R +from scipy.spatial.transform import Slerp +from tqdm import tqdm +from trajdata.caching import EnvCache, SceneCache +from trajdata.data_structures.agent import AgentMetadata, AgentType, FixedExtent +from trajdata.data_structures.environment import EnvMetadata +from trajdata.data_structures.scene_metadata import Scene, SceneMetadata +from trajdata.data_structures.scene_tag import SceneTag +from trajdata.dataset_specific.mads import mads_utils +from trajdata.dataset_specific.raw_dataset import RawDataset +from trajdata.dataset_specific.scene_records import MadsSceneRecord +from trajdata.maps import VectorMap +from trajdata.utils import arr_utils + +MADS_DT: Final[float] = 0.1 +EGO_LENGTH = 5.2993629 +EGO_WIDTH = 2.11311007 +EGO_HEIGHT = 1.34435794 + +# Minimum frames for an agent to be considered +MIN_FRAMES = 10 + +USE_CUBIC_INTERPOLATION = True + +agents_to_remove: List[str] = list() + + +def mads_type_to_unified_type(mads_type: str) -> AgentType: + if mads_type.startswith("person"): + return AgentType.PEDESTRIAN + elif mads_type == "automobile": + return AgentType.VEHICLE + elif mads_type == "other_vehicle": + return AgentType.VEHICLE + elif mads_type.startswith("cycle"): + return AgentType.BICYCLE + elif mads_type.startswith("motocycle"): + return AgentType.MOTORCYCLE + else: + return AgentType.UNKNOWN + + +class MADSDataset(RawDataset): + def compute_metadata(self, env_name: str, data_dir: str) -> EnvMetadata: + # Create scene splits + dataset_parts: List[Tuple[str, ...]] = [("mini_train", "mini_val")] + self.data_dir = data_dir + + # List to hold file names without extensions + file_names = [] + expanded_data_dir = str(Path(data_dir).expanduser()) + clip_dir = dict() + clip_duration = dict() + for subdir in os.listdir(expanded_data_dir): + # Construct the full path using the expanded data_dir + + # Sometimes the folder structure is .../clipgt_id/clipgt_id/files + clip_file = glob.glob( + os.path.join(expanded_data_dir, subdir) + "/**/clip.parquet", + recursive=True, + ) + + if len(clip_file) == 0: + continue + clip_file = clip_file[0] + df_meta = pd.read_parquet(clip_file) + df_meta = mads_utils.df_expand_json(df_meta) + session_id = df_meta["key.session_id"][0] + clip_id = df_meta["key.clip_id"][0] + clip_dir[clip_id] = str(Path(clip_file).parent.absolute()) + t0 = df_meta["key.time_range.start_micros"][0] + tf = df_meta["key.time_range.end_micros"][0] + clip_duration[clip_id] = (tf - t0) / 1e6 + self.clip_duration = clip_duration + + # Match scene with corresponding map. Each scene has a matching map. + + clip_items = list(clip_dir.items()) + random.Random(42).shuffle(clip_items) + self.clip_dir = dict(clip_items) + clip_ids = list(clip_dir.keys()) + + split_index = int(0.8 * len(clip_dir)) # 80% for training, 20% for validation + train_clips = [clip_id for clip_id, _ in clip_items[:split_index]] + test_clips = [clip_id for clip_id, _ in clip_items[split_index:]] + + scene_split_map = {} + for clip_id in train_clips: + scene_split_map[clip_id] = "mini_train" + for clip_id in test_clips: + scene_split_map[clip_id] = "mini_val" + # scene_0: mini_train + # scene_1: mini_val + # scene_2: mini_train + # scene_3: mini_train + # ... and so on + env_metadata = EnvMetadata( + name=env_name, + data_dir=data_dir, + dt=MADS_DT, + parts=dataset_parts, + scene_split_map=scene_split_map, + # The location names should match the map names used in + # the unified data cache. + map_locations=tuple(clip_ids), + ) + + return env_metadata + + def load_dataset_obj(self, verbose: bool = False) -> None: + pass + + def _get_matching_scenes_from_obj( + self, + scene_tag: SceneTag, + scene_desc_contains: Optional[List[str]], + env_cache: EnvCache, + ) -> List[SceneMetadata]: + all_scenes_list: List[MadsSceneRecord] = list() + + scenes_list: List[SceneMetadata] = list() + for idx, (clip_id, dir) in enumerate(self.clip_dir.items()): + scene_location = "clipGT" + if clip_id not in self.metadata.scene_split_map: + raise ValueError(f"Scene {clip_id} not in scene_split_map") + + scene_split: str = self.metadata.scene_split_map[clip_id] + scene_length: int = int(self.clip_duration[clip_id] / MADS_DT) + + if scene_length > 1: + all_scenes_list.append( + MadsSceneRecord( + clip_id, scene_location, scene_length, scene_split, idx + ) + ) + + if (scene_split in scene_tag) and scene_desc_contains is None: + scene_metadata = SceneMetadata( + env_name=self.metadata.name, + name=clip_id, + dt=self.metadata.dt, + raw_data_idx=idx, + ) + scenes_list.append(scene_metadata) + + self.cache_all_scenes_list(env_cache, all_scenes_list) + return scenes_list + + def _get_matching_scenes_from_cache( + self, + scene_tag: SceneTag, + scene_desc_contains: Optional[List[str]], + env_cache: EnvCache, + ) -> List[Scene]: + all_scenes_list: List[MadsSceneRecord] = env_cache.load_env_scenes_list( + self.name + ) + + scenes_list: List[SceneMetadata] = list() + for scene_record in all_scenes_list: + scene_name, scene_location, scene_length, scene_split, data_idx = ( + scene_record + ) + + if scene_split in scene_tag and scene_desc_contains is None: + scene_metadata = Scene( + self.metadata, + scene_name, + scene_location, + scene_split, + scene_length, + data_idx, + None, # This isn't used if everything is already cached. + ) + scenes_list.append(scene_metadata) + + return scenes_list + + def get_scene(self, scene_info: SceneMetadata) -> Scene: + # Type hinting for scene_info is not working properly in python 3.10 + # _, scene_name, _, data_idx = scene_info + scene_name = scene_info.name + data_idx = scene_info.raw_data_idx + + scene_location = scene_info.name + scene_split: str = self.metadata.scene_split_map[scene_name] + scene_length: int = int(self.clip_duration[scene_name] / MADS_DT) + 1 + + return Scene( + self.metadata, + scene_name, + scene_location, + scene_split, + scene_length, + data_idx, + None, + ) + + def _get_df_from_path(self, scene_path, scene_name): + dynamic_df = pd.read_parquet(os.path.join(scene_path, "obstacle.parquet")) + dynamic_df = mads_utils.df_expand_json(dynamic_df) + + if len(dynamic_df["key.clip_id"].unique()) != 1: + raise ValueError( + f"Expected only one clip_id, but got {dynamic_df['key.clip_id'].unique()}" + ) + + # Filter for clip_id + dynamic_df = dynamic_df[(dynamic_df["key.clip_id"] == scene_name)] + + manual_label_df = dynamic_df[ + dynamic_df["key.label_class_id"] == "seq3d:obstacles:v1" + ].copy() + manual_label_df["Obstacle.trackline_id"] = manual_label_df[ + "Obstacle.trackline_id" + ].map(lambda x: x.split(":")[-1]) + + manual_label_df["agent_id"] = "agent_gt_" + manual_label_df[ + "Obstacle.trackline_id" + ].astype(float).astype(int).astype(str) + + auto_label_df = dynamic_df[ + dynamic_df["key.label_class_id"] == "scene:obstacles:autolabels:v1" + ] + available_gt_labels = manual_label_df["Obstacle.trackline_id"].unique() + auto_label_df = auto_label_df[ + ~auto_label_df["Obstacle.trackline_id"].isin(available_gt_labels) + ] + auto_label_df["agent_id"] = "agent_auto_" + auto_label_df[ + "Obstacle.trackline_id" + ].astype(float).astype(int).astype(str) + + nr_auto_labels = len(auto_label_df["Obstacle.trackline_id"].unique()) + nr_gt_labels = len(available_gt_labels) + nr_obstacles = nr_auto_labels + nr_gt_labels + fraction = nr_gt_labels / nr_obstacles if nr_obstacles > 0 else 0 + print(f"{fraction * 100:.2f}% of {nr_obstacles} labels are manual.") + + manual_label_df["source"] = "manual" + auto_label_df["source"] = "autolabel" + dynamic_df = pd.concat([manual_label_df, auto_label_df]) + + dynamic_df.drop(columns=["Obstacle.trackline_id"], inplace=True) + + ego_df = pd.read_parquet(os.path.join(scene_path, "egomotion_estimate.parquet")) + ego_df = mads_utils.df_expand_json(ego_df) + + assert ego_df["EgomotionEstimate.name"].unique().size == 1 + ego_df = ego_df.assign( + length=EGO_LENGTH, + width=EGO_WIDTH, + height=EGO_HEIGHT, + type="automobile", + agent_id="ego", + source="manual", + ) + ego_df["key.label_class_id"] = ego_df["EgomotionEstimate.name"].iat[0] + + assert ego_df["EgomotionEstimate.name"].unique().size == 1 + ego_df = ego_df.drop(columns=["EgomotionEstimate.name"]) + + ego_df.rename( + columns={ + "EgomotionEstimate.location.x": "x", + "EgomotionEstimate.location.y": "y", + "EgomotionEstimate.location.z": "z", + "EgomotionEstimate.orientation.x": "qx", + "EgomotionEstimate.orientation.y": "qy", + "EgomotionEstimate.orientation.z": "qz", + "EgomotionEstimate.orientation.w": "qw", + }, + inplace=True, + ) + dynamic_df.rename( + columns={ + "Obstacle.center.x": "x", + "Obstacle.center.y": "y", + "Obstacle.center.z": "z", + "Obstacle.orientation.x": "qx", + "Obstacle.orientation.y": "qy", + "Obstacle.orientation.z": "qz", + "Obstacle.orientation.w": "qw", + "Obstacle.size.x": "length", + "Obstacle.size.y": "width", + "Obstacle.size.z": "height", + "Obstacle.category": "type", + }, + inplace=True, + ) + + t0 = ego_df["key.timestamp_micros"].iat[0] + tf = ego_df["key.timestamp_micros"].iat[-1] + + if not (dynamic_df["key.timestamp_micros"] < tf).all(): + raise ValueError("Some dynamic data is after the last ego data.") + + dynamic_df = pd.concat([ego_df, dynamic_df]) + + # Only select relevant dynamic data + dynamic_df = dynamic_df[(dynamic_df["key.timestamp_micros"] <= tf)] + + dynamic_df["rel_time_seconds"] = (dynamic_df["key.timestamp_micros"] - t0) / 1e6 + + interpolated_dfs = [] + for group_name, group_df in dynamic_df.groupby( + ["key.clip_id", "key.label_class_id", "agent_id"] + ): + group_df = group_df.sort_values(by=["rel_time_seconds"]) + + duplicated = group_df.duplicated(subset=["rel_time_seconds"]) + + if duplicated.sum() > 0: + print(f"Duplicated timestamps found for agent: {group_name}") + group_df = group_df[~duplicated] + + min_time = group_df["rel_time_seconds"].min() + min_step = int(np.ceil(min_time / MADS_DT)) + max_time = group_df["rel_time_seconds"].max() + max_step = int(np.floor(max_time / MADS_DT)) + target_steps = np.arange(min_step, max_step + 1) + target_times = target_steps * MADS_DT + + if max_step - min_step + 1 < MIN_FRAMES: + continue + + def _interp(col_name): + x = group_df["rel_time_seconds"] + y = group_df[col_name] + if USE_CUBIC_INTERPOLATION: + return CubicSpline(x, y)(target_times) + return np.interp(target_times, x, y) + + if not group_df["type"].unique().size == 1: + print( + "Multiple types encountered for agent: " + f"{group_name}, {group_df['type'].unique()}" + ) + + # [N, 4] + quats_tensor = np.stack( + [ + group_df["qx"], + group_df["qy"], + group_df["qz"], + group_df["qw"], + ], + axis=1, + ) + # Takes in scalar-last quaternion (x, y, z, w) + r = R.from_quat(quats_tensor) + slerp = Slerp(group_df["rel_time_seconds"], r) + interp_r = slerp(target_times) + headings = interp_r.as_euler("zyx", degrees=False)[:, 0] + + df = pd.DataFrame( + { + "key.clip_id": group_name[0], + "key.label_class_id": group_name[1], + "agent_id": group_name[2], + "scene_ts": target_steps, + "rel_time_seconds": target_times, + "x": _interp("x"), + "y": _interp("y"), + "z": _interp("z"), + "heading": headings, + # We interpolate this as this might change! + # In particular, I found this to change for manual labels. + "length": _interp("length"), + "width": _interp("width"), + "height": _interp("height"), + "type": group_df["type"].iat[0], + "source": group_df["source"].iat[0], + } + ) + + df["vx"] = df["x"].diff() / MADS_DT + df["vy"] = df["y"].diff() / MADS_DT + + # Calculate ego accelerations 'ax' and 'ay' + df["ax"] = df["vx"].diff() / MADS_DT + df["ay"] = df["vy"].diff() / MADS_DT + + # Replace infinity with nan for later nan handling + df["ax"] = df["ax"].replace([np.inf, -np.inf], np.nan) + df["ay"] = df["ay"].replace([np.inf, -np.inf], np.nan) + + # The first row of ax and ay is NaN, fill in values where NaN exists + df["vx"] = df["vx"].bfill().ffill() + df["vy"] = df["vy"].bfill().ffill() + df["ax"] = df["ax"].bfill().ffill() + df["ay"] = df["ay"].bfill().ffill() + + interpolated_dfs.append(df) + interpolated_df = pd.concat(interpolated_dfs).reset_index(drop=True) + assert interpolated_df.duplicated(subset=["scene_ts", "agent_id"]).sum() == 0 + + T = (tf - t0) / (1e6 * MADS_DT) + interpolated_df = interpolated_df[ + (interpolated_df["scene_ts"] >= 0) & (interpolated_df["scene_ts"] <= T) + ] + + # Sort by distance to ego + ego_start = interpolated_df.query("agent_id == 'ego' and scene_ts == 0") + ego_x = ego_start["x"].iat[0] + ego_y = ego_start["y"].iat[0] + + unique_distances = set() + + def get_group_distance_to_ego(group_df): + min_ts = group_df["scene_ts"].min() + agent_start = group_df.query(f"scene_ts == {min_ts}") + agent_x = agent_start["x"].iat[0] + agent_y = agent_start["y"].iat[0] + distance_to_ego = np.sqrt((ego_x - agent_x) ** 2 + (ego_y - agent_y) ** 2) + assert distance_to_ego not in unique_distances + unique_distances.add(distance_to_ego) + group_df["distance_to_ego"] = distance_to_ego + return group_df + + sorted_df = interpolated_df.groupby("agent_id").apply( + get_group_distance_to_ego, include_groups=False + ) + sorted_df = ( + sorted_df.sort_values(by=["distance_to_ego", "agent_id", "scene_ts"]) + .reset_index() + .drop(columns=["level_1"]) + ) + + # Filter out agents that are too close to each other + # Strategy: For simplicity and speed we only compare the xy locations of agents + # when they are first seen. This might miss cases when the agent moves + # and the 'ghost' object appears later. + # We start by adding all agents with gt labels. Then, we iterate over the rest + # of the agents and either: + # - accept them and add their first seen location to `first_states` + # - reject them and add them to `agents_to_remove` + def get_row_first_seen(df): + return ( + df.groupby("agent_id", sort=False) + .apply(lambda gdf: gdf.iloc[0], include_groups=False) + .reset_index() + ) + + first_states = get_row_first_seen(sorted_df.query("source == 'manual'")) + + agents_to_remove = set() + # Add agents one by one if they don't overlap + for agent_id, group_df in sorted_df.query("source != 'manual'").groupby( + "agent_id", sort=False + ): + current_agent_first_state = group_df.iloc[0] + first_states["distance_to_current_agent"] = np.sqrt( + (first_states["x"] - current_agent_first_state.x) ** 2 + + (first_states["y"] - current_agent_first_state.y) ** 2 + ) + closest_state = first_states.sort_values( + by=["distance_to_current_agent"] + ).iloc[0] + # Only conservative filtering based on dimensions: + if ( + min(closest_state.width, closest_state.height) + < closest_state.distance_to_current_agent + ): + agents_to_remove.add(agent_id) + else: + first_states = pd.concat([first_states, get_row_first_seen(group_df)]) + + sorted_df = sorted_df[~sorted_df["agent_id"].isin(agents_to_remove)] + return sorted_df + + def get_agent_info( + self, scene: Scene, cache_path: Path, cache_class: Type[SceneCache] + ) -> Tuple[List[AgentMetadata], List[List[AgentMetadata]]]: + sorted_df = self._get_df_from_path(self.clip_dir[scene.name], scene.name) + + agent_list: List[AgentMetadata] = [] + agent_presence: List[List[AgentMetadata]] = [ + [] for _ in range(scene.length_timesteps) + ] + agents_to_remove = [] + for agent_id, frames in sorted_df.groupby("agent_id", sort=False)[ + ["scene_ts", "type", "length", "width", "height"] + ]: + all_frame_ids = frames["scene_ts"] + + start_frame: int = all_frame_ids.iat[0] + last_frame: int = all_frame_ids.iat[-1] + + agent_length = ( + frames["length"].iloc[0][0] + if isinstance(frames["length"].iloc[0], list) + else frames["length"].iloc[0] + ) + agent_width = ( + frames["width"].iloc[0][0] + if isinstance(frames["width"].iloc[0], list) + else frames["width"].iloc[0] + ) + agent_height = ( + frames["height"].iloc[0][0] + if isinstance(frames["height"].iloc[0], list) + else frames["height"].iloc[0] + ) + + agent_metadata = AgentMetadata( + name=agent_id, + agent_type=mads_type_to_unified_type(frames["type"].iloc[0]), + first_timestep=start_frame, + last_timestep=last_frame, + extent=FixedExtent( + length=agent_length, width=agent_width, height=agent_height + ), + ) + + agent_list.append(agent_metadata) + for frame in range( + agent_metadata.first_timestep, agent_metadata.last_timestep + ): + agent_presence[frame].append(agent_metadata) + + sorted_df = sorted_df[~sorted_df["agent_id"].isin(agents_to_remove)] + sorted_df.set_index(["agent_id", "scene_ts"], inplace=True) + + cache_class.save_agent_data( + sorted_df, + cache_path, + scene, + ) + + return agent_list, agent_presence + + def cache_map( + self, + map_name: str, + cache_path: Path, + map_cache_class: Type[SceneCache], + map_params: Dict[str, Any], + ) -> None: + """ + Stores rasterized maps to disk for later retrieval. + """ + resolution: float = map_params["px_per_m"] + print(f"Caching {map_name} Map at {resolution:.2f} px/m...", flush=True) + + vector_map = VectorMap(map_id=f"{self.name}:{map_name}") + mads_utils.populate_vector_map(vector_map, self.clip_dir[map_name]) + + map_cache_class.finalize_and_cache_map(cache_path, vector_map, map_params) + + def cache_maps( + self, + cache_path: Path, + map_cache_class: Type[SceneCache], + map_params: Dict[str, Any], + ) -> None: + """ + Stores rasterized maps to disk for later retrieval. + """ + for map_name in tqdm( + self.clip_dir.keys(), + desc=f"Caching {self.name} Maps at {map_params['px_per_m']:.2f} px/m", + position=0, + ): + self.cache_map(map_name, cache_path, map_cache_class, map_params) diff --git a/src/trajdata/dataset_specific/mads/mads_utils.py b/src/trajdata/dataset_specific/mads/mads_utils.py new file mode 100644 index 0000000..f05fbfc --- /dev/null +++ b/src/trajdata/dataset_specific/mads/mads_utils.py @@ -0,0 +1,398 @@ +import json +import os +from pathlib import Path +from typing import Any, Dict, Final, List, Optional, Tuple, Type + +import numpy as np +import pandas as pd +from scipy.interpolate import interp1d +from tqdm import tqdm +from trajdata.maps.vec_map import VectorMap +from trajdata.maps.vec_map_elements import ( + PedCrosswalk, + Polyline, + RoadEdge, + RoadLane, + TrafficSign, + WaitLine, +) + +MAX_POLYLINE_POINT_DIST = 2.0 + + +def df_expand_json(df: pd.DataFrame) -> pd.DataFrame: + """ + Expand json columns in a pandas DataFrame. + """ + columns = df.keys() + for key in columns: + cap_key = ( + "".join([x.capitalize() for x in key.split("_")]) + if key not in ["key", "version"] + else key + ) + df = df.join(pd.json_normalize(df[key]).add_prefix(f"{cap_key}.")) + + return df + + +def find_lane_polylines_parquet( + df_lane: pd.DataFrame, + df_lane_relation: pd.DataFrame, + clip_id: str, + df_wait_line: pd.DataFrame, +) -> Dict[str, Dict[str, Any]]: + lanes_dict = {} + for i in range(df_lane.shape[0]): + lane_id = df_lane["key.map_id"][i] + df_this_lane_relation = df_lane_relation[ + (df_lane_relation["key.clip_id"] == clip_id) + & (df_lane_relation["Association.subjects"] == lane_id) + ] + prev_lane_df = df_this_lane_relation[ + (df_this_lane_relation["key.kind"] == "PREVIOUS_LANE") + ] + prev_lane_id = ( + list(prev_lane_df["Association.objects"].values[0]) + if prev_lane_df.shape[0] > 0 + else {} + ) + + next_lane_df = df_this_lane_relation[ + (df_this_lane_relation["key.kind"] == "NEXT_LANE") + ] + next_lane_id = ( + list(next_lane_df["Association.objects"].values[0]) + if next_lane_df.shape[0] > 0 + else {} + ) + + right_lane_df = df_this_lane_relation[ + (df_this_lane_relation["key.kind"] == "RIGHT_LANE") + ] + right_lane_id = ( + list(right_lane_df["Association.objects"].values[0]) + if right_lane_df.shape[0] > 0 + else {} + ) + + left_lane_df = df_this_lane_relation[ + (df_this_lane_relation["key.kind"] == "LEFT_LANE") + ] + left_lane_id = ( + list(left_lane_df["Association.objects"].values[0]) + if left_lane_df.shape[0] > 0 + else {} + ) + + traffic_sign_df = df_this_lane_relation[ + (df_this_lane_relation["key.kind"] == "SIGN_TO_LANE") + ] + traffic_sign_id = ( + list(traffic_sign_df["Association.objects"].values[0]) + if traffic_sign_df.shape[0] > 0 + else {} + ) + + wait_line_df = df_wait_line[ + (df_wait_line["key.clip_id"] == clip_id) + & (df_wait_line["lane.map_id"] == lane_id) + ] + wait_line_id = ( + list(wait_line_df["key.map_id"].values) if wait_line_df.shape[0] > 0 else {} + ) + + if lane_id not in lanes_dict: + lanes_dict[lane_id] = {} + if "Lane.left_rail.x" in df_lane.keys(): + lanes_dict[lane_id]["left_rail"] = np.stack( + [ + df_lane["Lane.left_rail.x"][i], + df_lane["Lane.left_rail.y"][i], + df_lane["Lane.left_rail.z"][i], + ], + axis=1, + ) + lanes_dict[lane_id]["right_rail"] = np.stack( + [ + df_lane["Lane.right_rail.x"][i], + df_lane["Lane.right_rail.y"][i], + df_lane["Lane.right_rail.z"][i], + ], + axis=1, + ) + else: + lanes_dict[lane_id]["left_rail"] = np.stack( + [ + [pt["x"] for pt in df_lane["Lane.left_rail"][i]], + [pt["y"] for pt in df_lane["Lane.left_rail"][i]], + [pt["z"] for pt in df_lane["Lane.left_rail"][i]], + ], + axis=1, + ) + lanes_dict[lane_id]["right_rail"] = np.stack( + [ + [pt["x"] for pt in df_lane["Lane.right_rail"][i]], + [pt["y"] for pt in df_lane["Lane.right_rail"][i]], + [pt["z"] for pt in df_lane["Lane.right_rail"][i]], + ], + axis=1, + ) + + lanes_dict[lane_id]["next_lane"] = next_lane_id + lanes_dict[lane_id]["prev_lane"] = prev_lane_id + lanes_dict[lane_id]["left_lane"] = left_lane_id + lanes_dict[lane_id]["right_lane"] = right_lane_id + lanes_dict[lane_id]["traffic_sign"] = traffic_sign_id + lanes_dict[lane_id]["wait_line"] = wait_line_id + + return lanes_dict + + +def find_road_edges_parquet(df_road_edge: pd.DataFrame, clip_id: str) -> Dict[str, Any]: + road_edges_dict = {} + for i in range(df_road_edge.shape[0]): + road_edge_id = df_road_edge["key.map_id"][i] + if road_edge_id not in road_edges_dict: + if "RoadBoundary.location.x" in df_road_edge.keys(): + road_edges_dict[road_edge_id] = np.stack( + [ + df_road_edge["RoadBoundary.location.x"][i], + df_road_edge["RoadBoundary.location.y"][i], + df_road_edge["RoadBoundary.location.z"][i], + ], + axis=1, + ) + else: + road_edges_dict[road_edge_id] = np.stack( + [ + [pt["x"] for pt in df_road_edge["RoadBoundary.location"][i]], + [pt["y"] for pt in df_road_edge["RoadBoundary.location"][i]], + [pt["z"] for pt in df_road_edge["RoadBoundary.location"][i]], + ], + axis=1, + ) + return road_edges_dict + + +def find_traffic_signs_parquet( + df_traffic_sign: pd.DataFrame, clip_id: str +) -> Dict[str, Any]: + """ + return: + traffic_signs_dict: { + traffic_sign_id: { + position: np.array([x, y, z]), position of the traffic sign + type: str, traffic sign type, e.g., stop sign + } + } + """ + # TODO: including all traffic signs here which might be not linked by the lane + traffic_signs_dict = {} + for i in range(df_traffic_sign.shape[0]): + traffic_sign_id = df_traffic_sign["key.map_id"][i] + if traffic_sign_id not in traffic_signs_dict: + traffic_signs_dict[traffic_sign_id] = {} + traffic_signs_dict[traffic_sign_id]["position"] = np.array( + [ + df_traffic_sign["TrafficSign.center.x"][i], + df_traffic_sign["TrafficSign.center.y"][i], + df_traffic_sign["TrafficSign.center.z"][i], + ] + ) + traffic_signs_dict[traffic_sign_id]["type"] = df_traffic_sign[ + "TrafficSign.category" + ][i] + return traffic_signs_dict + + +def find_wait_lines_parquet(df_wait_line: pd.DataFrame, clip_id: str) -> Dict[str, Any]: + """ + return: + wait_lines_dict: { + wait_line_id: { + location: np.array([[x1, y1, z1], [x2, y2, z2], ...]), location of the wait line + category: str, wait line category, e.g., Yield/Stop + implicit: bool, unclear what this means, need to check with clipgt team + } + } + """ + wait_lines_dict = {} + for i in range(df_wait_line.shape[0]): + wait_line_id = df_wait_line["key.map_id"][i] + if wait_line_id not in wait_lines_dict: + wait_lines_dict[wait_line_id] = {} + if "WaitLine.location.x" in df_wait_line.keys(): + wait_lines_dict[wait_line_id]["location"] = np.stack( + [ + df_wait_line["WaitLine.location.x"][i], + df_wait_line["WaitLine.location.y"][i], + df_wait_line["WaitLine.location.z"][i], + ], + axis=1, + ) + else: + wait_lines_dict[wait_line_id]["location"] = np.stack( + [ + [pt["x"] for pt in df_wait_line["WaitLine.location"][i]], + [pt["y"] for pt in df_wait_line["WaitLine.location"][i]], + [pt["z"] for pt in df_wait_line["WaitLine.location"][i]], + ], + axis=1, + ) + wait_lines_dict[wait_line_id]["category"] = df_wait_line["WaitLine.category"][ + i + ] # Yield/Stop + wait_lines_dict[wait_line_id]["implicit"] = df_wait_line[ + "WaitLine.is_implicit" + ][i] # TODO: need to clarify what this specifically refers to + return wait_lines_dict + + +def interpolate_points(list1, list2): + # Determine which list is shorter + if len(list1) > len(list2): + longer, shorter = list1, list2 + else: + longer, shorter = list2, list1 + + # Extract x, y, z coordinates + x, y, z = zip(*shorter) + x_long, y_long, z_long = zip(*longer) + + # Create an array of indices for interpolation + interp_indices = np.linspace(0, len(shorter) - 1, num=len(longer)) + + # Interpolate x, y, z + x_interp = interp1d(range(len(x)), x, kind="linear")(interp_indices) + y_interp = interp1d(range(len(y)), y, kind="linear")(interp_indices) + z_interp = interp1d(range(len(z)), z, kind="linear")(interp_indices) + + # Combine interpolated coordinates + interpolated_points = list(zip(x_interp, y_interp, z_interp)) + + # Return the original longer list and the new interpolated list + return longer, interpolated_points + + +def populate_vector_map(vector_map: VectorMap, map_root) -> None: + # populate vector map from mads parquet files + # map label schema: https://docs.nvda.ai/ndas/avdnn/latest/reference/maglev/data/clip/reference/schemas/labels.html?#map-derived-labels + + maximum_bound: np.ndarray = np.full((3,), np.nan) + minimum_bound: np.ndarray = np.full((3,), np.nan) + df_lane = df_expand_json(pd.read_parquet(os.path.join(map_root, "lane.parquet"))) + df_lane_relation = df_expand_json( + pd.read_parquet(os.path.join(map_root, "association.parquet")) + ) + df_road_edge = df_expand_json( + pd.read_parquet(os.path.join(map_root, "road_boundary.parquet")) + ) + df_meta = df_expand_json(pd.read_parquet(os.path.join(map_root, "clip.parquet"))) + df_traffic_sign = df_expand_json( + pd.read_parquet(os.path.join(map_root, "traffic_sign.parquet")) + ) + # TODO: invalid traffic light data from mads + # df_traffic_light = pd.read_parquet(os.path.join(map_root, "traffic_light.parquet")) + df_wait_line = df_expand_json( + pd.read_parquet(os.path.join(map_root, "wait_line.parquet")) + ) + # wait_line id is formatted as {wait_line_id}-{lane_id} + df_wait_line["lane.map_id"] = df_wait_line["key.map_id"].map( + lambda x: x.split("-")[1] + ) + clip_id = df_meta["key.clip_id"][0] + all_lanes_dict = find_lane_polylines_parquet( + df_lane, df_lane_relation, clip_id, df_wait_line + ) + all_road_edges_dict = find_road_edges_parquet(df_road_edge, clip_id) + all_traffic_signs_dict = find_traffic_signs_parquet(df_traffic_sign, clip_id) + all_wait_lines_dict = find_wait_lines_parquet(df_wait_line, clip_id) + del df_lane, df_lane_relation, df_meta, df_road_edge + + if not all_lanes_dict: + print("No valid data available in map file") + return + + for lane_id, lane_info_dict in tqdm( + all_lanes_dict.items(), desc="Creating Vectorized Map" + ): + left_polyline = np.array(lane_info_dict["left_rail"]) + right_polyline = np.array(lane_info_dict["right_rail"]) + + midlane_pts: np.ndarray = (left_polyline + right_polyline) / 2 + + # Computing the maximum and minimum map coordinates. + maximum_bound = np.fmax(maximum_bound, left_polyline.max(axis=0)) + minimum_bound = np.fmin(minimum_bound, left_polyline.min(axis=0)) + + maximum_bound = np.fmax(maximum_bound, right_polyline.max(axis=0)) + minimum_bound = np.fmin(minimum_bound, right_polyline.min(axis=0)) + + maximum_bound = np.fmax(maximum_bound, midlane_pts.max(axis=0)) + minimum_bound = np.fmin(minimum_bound, midlane_pts.min(axis=0)) + + new_lane = RoadLane( + id=lane_id, + center=Polyline(midlane_pts).interpolate(max_dist=MAX_POLYLINE_POINT_DIST), + left_edge=Polyline(left_polyline).interpolate( + max_dist=MAX_POLYLINE_POINT_DIST + ), + right_edge=Polyline(right_polyline).interpolate( + max_dist=MAX_POLYLINE_POINT_DIST + ), + next_lanes=lane_info_dict["next_lane"], + adj_lanes_left=lane_info_dict["left_lane"], + adj_lanes_right=lane_info_dict["right_lane"], + prev_lanes=lane_info_dict["prev_lane"], + traffic_sign_ids=lane_info_dict["traffic_sign"], + wait_line_ids=lane_info_dict["wait_line"], + ) + vector_map.add_map_element(new_lane) + + for road_edge_id, road_edge_pts in tqdm( + all_road_edges_dict.items(), desc="Creating Road Edges" + ): + maximum_bound = np.fmax(maximum_bound, road_edge_pts.max(axis=0)) + minimum_bound = np.fmin(minimum_bound, road_edge_pts.min(axis=0)) + new_road_edge = RoadEdge( + id=road_edge_id, + polyline=Polyline(road_edge_pts).interpolate( + max_dist=MAX_POLYLINE_POINT_DIST + ), + ) + + vector_map.add_map_element(new_road_edge) + + for traffic_sign_id, traffic_sign_info_dict in tqdm( + all_traffic_signs_dict.items(), desc="Creating Traffic Signs" + ): + # TODO better error handling for invalid traffic sign data + if traffic_sign_id is None: + continue + new_traffic_sign = TrafficSign( + id=traffic_sign_id, + position=traffic_sign_info_dict["position"], + sign_type=traffic_sign_info_dict["type"], + ) + vector_map.add_map_element(new_traffic_sign) + + for wait_line_id, wait_line_info_dict in tqdm( + all_wait_lines_dict.items(), desc="Creating Wait Lines" + ): + # TODO better error handling for invalid wait line data + if wait_line_id is None: + continue + new_wait_line = WaitLine( + id=wait_line_id, + polyline=Polyline( + wait_line_info_dict["location"] + ), # yulongc: do we need interpolate here? + wait_line_type=wait_line_info_dict["category"], + is_implicit=wait_line_info_dict["implicit"], + ) + vector_map.add_map_element(new_wait_line) + + # Setting the map bounds. + # vector_map.extent is [min_x, min_y, min_z, max_x, max_y, max_z] + vector_map.extent = np.concatenate((minimum_bound, maximum_bound)) diff --git a/src/trajdata/dataset_specific/scene_records.py b/src/trajdata/dataset_specific/scene_records.py index 68bd5b9..0108f41 100644 --- a/src/trajdata/dataset_specific/scene_records.py +++ b/src/trajdata/dataset_specific/scene_records.py @@ -48,3 +48,11 @@ class NuPlanSceneRecord(NamedTuple): split: str # desc: str data_idx: int + + +class MadsSceneRecord(NamedTuple): + name: str + location: str + length: str + split: str + data_idx: int diff --git a/src/trajdata/utils/env_utils.py b/src/trajdata/utils/env_utils.py index 47f13b2..62bd831 100644 --- a/src/trajdata/utils/env_utils.py +++ b/src/trajdata/utils/env_utils.py @@ -44,6 +44,10 @@ def get_raw_dataset(dataset_name: str, data_dir: str) -> RawDataset: return InteractionDataset( dataset_name, data_dir, parallelizable=True, has_maps=True ) + if "mads" in dataset_name.lower(): + from trajdata.dataset_specific.mads import MADSDataset + + return MADSDataset(dataset_name, data_dir, parallelizable=True, has_maps=True) raise ValueError(f"Dataset with name '{dataset_name}' is not supported") diff --git a/tests/test_state.py b/tests/test_state.py index a954145..7ed3210 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -1,3 +1,5 @@ +import importlib +import importlib.util import unittest import numpy as np @@ -9,6 +11,8 @@ from collections import defaultdict from torch.utils.data import DataLoader +_has_nuscenes = importlib.util.find_spec("nuscenes") is not None + AgentStateArray = NP_STATE_TYPES["x,y,z,xd,yd,xdd,ydd,h"] AgentObsArray = NP_STATE_TYPES["x,y,z,xd,yd,xdd,ydd,s,c"] AgentStateTensor = TORCH_STATE_TYPES["x,y,z,xd,yd,xdd,ydd,h"] @@ -190,6 +194,7 @@ def test_tensor_ops(self): self.assertTrue(isinstance(c, float)) +@unittest.skipUnless(_has_nuscenes, "nuscenes package not installed") class TestDataset(unittest.TestCase): def test_dataloading(self): dataset = UnifiedDataset( From 90f6f66ad186f9336c40f2a427388b6cda49ba54 Mon Sep 17 00:00:00 2001 From: Maximilian Igl Date: Thu, 9 Apr 2026 11:57:52 +0200 Subject: [PATCH 4/4] Fix left road boundary placed at centerline instead of outer edge _extract_road_edges read lg.left_edge from the outermost left lane to get the left road boundary. In v1.4.3 the edge convention changed (left_edge = inner for all lanes), but _extract_road_edges was not updated, so the left boundary ended up at the centerline. Replace the indirect lookup with direct capture: after the edge-chaining loop finishes each side, current_edge is already the outermost road edge. Record it there and delete _extract_road_edges entirely. Co-Authored-By: Claude Opus 4.6 (1M context) --- .../dataset_specific/xodr/lane_processing.py | 53 +++++-------------- 1 file changed, 14 insertions(+), 39 deletions(-) diff --git a/src/trajdata/dataset_specific/xodr/lane_processing.py b/src/trajdata/dataset_specific/xodr/lane_processing.py index d8fb935..f2821e2 100644 --- a/src/trajdata/dataset_specific/xodr/lane_processing.py +++ b/src/trajdata/dataset_specific/xodr/lane_processing.py @@ -86,7 +86,7 @@ def process_road( # Precompute widths per lane lane_width_samples = _compute_lane_widths(lane_offsets, s_grid) - # Process all lanes to create geometry + # Process all lanes to create geometry and capture outer road edges _process_lane_geometry( road_id, lane_offsets, @@ -98,14 +98,12 @@ def process_road( center_z, road_headings, lane_geoms, + road_edges, min_xyz, max_xyz, sidewalks, ) - # Extract and register road edges - _extract_road_edges(road_id, lane_offsets, lane_geoms, road_edges) - # Create artificial edges for non-junction roads if not is_junction_road: _create_artificial_edges( @@ -222,6 +220,7 @@ def _process_lane_geometry( center_z: np.ndarray, road_headings: np.ndarray, lane_geoms: Dict[str, LaneGeom], + road_edges: Dict[str, np.ndarray], min_xyz: np.ndarray, max_xyz: np.ndarray, sidewalks: Dict[str, np.ndarray] = None, @@ -238,6 +237,7 @@ def _process_lane_geometry( center_y: Y coordinates of road centerline road_headings: Heading angles at each point lane_geoms: Dictionary with lane geometries (updated in place) + road_edges: Dictionary with road edge polylines (updated in place) min_xyz: Minimum extent (updated in place) max_xyz: Maximum extent (updated in place) """ @@ -316,6 +316,16 @@ def _process_lane_geometry( ) current_edge_x, current_edge_y = outer_edge_x, outer_edge_y + # After all lanes on this side, current_edge is the outermost road edge + if side_ids: + if is_left_side: + # Reverse back to road reference direction + edge = np.stack([current_edge_x[::-1], current_edge_y[::-1], center_z], axis=1) + road_edges[f"{road_id}_L"] = edge + else: + edge = np.stack([current_edge_x, current_edge_y, center_z], axis=1) + road_edges[f"{road_id}_R"] = edge + def _create_single_lane_geometry( road_id: str, @@ -386,41 +396,6 @@ def _create_single_lane_geometry( ) -def _extract_road_edges( - road_id: str, - lane_offsets: List[Tuple[int, List]], - lane_geoms: Dict[str, LaneGeom], - road_edges: Dict[str, np.ndarray], -) -> None: - """Extract outer road edges from the outermost lanes. - - Args: - road_id: ID of the road - lane_offsets: List of (lane_id, width_sections) tuples - lane_geoms: Dictionary of lane geometries - road_edges: Dictionary with road edges (updated in place) - """ - # Extract edges from outermost lanes - left_lane_ids = [lid for lid, _ in lane_offsets if lid > 0] - right_lane_ids = [lid for lid, _ in lane_offsets if lid < 0] - - # Try to get edges from outermost lanes first - if left_lane_ids: - outermost_left = max(left_lane_ids) - outermost_id = f"{road_id}_{outermost_left}" - if outermost_id in lane_geoms: - lg = lane_geoms[outermost_id] - if lg.left_edge is not None: - road_edges[f"{road_id}_L"] = lg.left_edge - - if right_lane_ids: - outermost_right = min(right_lane_ids) - outermost_id = f"{road_id}_{outermost_right}" - if outermost_id in lane_geoms: - lg = lane_geoms[outermost_id] - if lg.right_edge is not None: - road_edges[f"{road_id}_R"] = lg.right_edge - def _create_artificial_edges( road_id: str,