From b2fbf3a2a52faad4f4393d60698d80497201e874 Mon Sep 17 00:00:00 2001 From: Stepan Konev Date: Sun, 8 Oct 2023 23:15:29 +0200 Subject: [PATCH 1/4] Yandex Shifts Motion Dataset add --- DATASETS.md | 26 ++ README.md | 8 +- pyproject.toml | 1 + src/trajdata/augmentation/noise_histories.py | 7 +- src/trajdata/data_structures/batch_element.py | 2 + .../dataset_specific/scene_records.py | 12 + .../yandex_shifts/__init__.py | 1 + .../yandex_shifts/yandex_shifts_dataset.py | 205 ++++++++++++ .../yandex_shifts/yandex_shifts_utils.py | 294 ++++++++++++++++++ src/trajdata/utils/env_utils.py | 11 + src/trajdata/utils/raster_utils.py | 32 +- 11 files changed, 580 insertions(+), 19 deletions(-) create mode 100644 src/trajdata/dataset_specific/yandex_shifts/__init__.py create mode 100644 src/trajdata/dataset_specific/yandex_shifts/yandex_shifts_dataset.py create mode 100644 src/trajdata/dataset_specific/yandex_shifts/yandex_shifts_utils.py diff --git a/DATASETS.md b/DATASETS.md index b2f9255..07de6a0 100644 --- a/DATASETS.md +++ b/DATASETS.md @@ -76,6 +76,32 @@ It should look like this after downloading: **Note**: Not all the dataset parts need to be downloaded, only the necessary directories in [the Google Cloud Bucket](https://console.cloud.google.com/storage/browser/waymo_open_dataset_motion_v_1_1_0/uncompressed/scenario) need to be downloaded (e.g., `validation` for the validation dataset). +## Yandex Shifts Motion Prediction Dataset +Nothing special needs to be done for the Yandex Shifts Motion Prediction Dataset, simply download as per [the instructions on the dataset website](https://github.com/Shifts-Project/shifts#motion-prediction-1). + +It should look like this after downloading: +``` +/path/to/ysdc/ + ├── train/ + | ├── 000 + | | ├── 000000.pb + | | └── ... + | └── ... + ├── development/ + | ├── 000 + | | ├── 000000.pb + | | └── ... + | └── ... + └── eval/ + ├── 000 + | ├── 000000.pb + | └── ... + └── ... +``` + +**Note**: Yuo may also download a complete unpartitioned dataset. The dataset also contains prerendered examples, +which are not required for `trajdata` functioning. + ## Lyft Level 5 Nothing special needs to be done for the Lyft Level 5 dataset, simply download it as per [the instructions on the dataset website](https://woven-planet.github.io/l5kit/dataset.html). diff --git a/README.md b/README.md index a715ebc..5becea3 100644 --- a/README.md +++ b/README.md @@ -24,11 +24,14 @@ pip install "trajdata[lyft]" # For Waymo pip install "trajdata[waymo]" +# For Yandex Shifts Motion Dataset +pip install "trajdata[ysdc]" + # For INTERACTION pip install "trajdata[interaction]" # All -pip install "trajdata[nusc,lyft,waymo,interaction]" +pip install "trajdata[nusc,lyft,waymo,interaction,ysdc]" ``` Then, download the raw datasets (nuScenes, Lyft Level 5, ETH/UCY, etc.) in case you do not already have them. For more information about how to structure dataset folders/files, please see [`DATASETS.md`](./DATASETS.md). @@ -99,6 +102,9 @@ Currently, the dataloader supports interfacing with the following datasets: | Waymo Open Motion Training | `waymo_train` | `train` | N/A | Waymo Open Motion Dataset `training` split | 0.1s (10Hz) | :white_check_mark: | | Waymo Open Motion Validation | `waymo_val` | `val` | N/A | Waymo Open Motion Dataset `validation` split | 0.1s (10Hz) | :white_check_mark: | | Waymo Open Motion Testing | `waymo_test` | `test` | N/A | Waymo Open Motion Dataset `testing` split | 0.1s (10Hz) | :white_check_mark: | +| Yandex Shifts Motion Dataset Training | `ysdc_train` | `train` | N/A | Yandex Shifts Motion Dataset `training` split | 0.2s (5Hz) | :white_check_mark: | +| Yandex Shifts Motion Dataset Development | `ysdc_development` | `development` | N/A | Yandex Shifts Motion Dataset `development` split | 0.2s (5Hz) | :white_check_mark: | +| Yandex Shifts Motion Dataset Evaluation | `ysdc_eval` | `eval` | N/A | Yandex Shifts Motion Dataset `eval` split | 0.2 (5Hz) | :white_check_mark: | | Lyft Level 5 Train | `lyft_train` | `train` | `palo_alto` | Lyft Level 5 training data - part 1/2 (8.4 GB) | 0.1s (10Hz) | :white_check_mark: | | Lyft Level 5 Train Full | `lyft_train_full` | `train` | `palo_alto` | Lyft Level 5 training data - part 2/2 (70 GB) | 0.1s (10Hz) | :white_check_mark: | | Lyft Level 5 Validation | `lyft_val` | `val` | `palo_alto` | Lyft Level 5 validation data (8.2 GB) | 0.1s (10Hz) | :white_check_mark: | diff --git a/pyproject.toml b/pyproject.toml index 4cf3822..e40550a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,6 +42,7 @@ interaction = ["lanelet2==1.2.1"] lyft = ["l5kit==1.5.0"] nusc = ["nuscenes-devkit==1.1.9"] waymo = ["tensorflow==2.11.0", "waymo-open-dataset-tf-2-11-0", "intervaltree"] +ysdc = ["ysdc-dataset-api @ git+https://github.com/yandex-research/shifts.git#subdirectory=sdc"] [project.urls] "Homepage" = "https://github.com/nvr-avg/trajdata" diff --git a/src/trajdata/augmentation/noise_histories.py b/src/trajdata/augmentation/noise_histories.py index 1aca9c6..b0fa4af 100644 --- a/src/trajdata/augmentation/noise_histories.py +++ b/src/trajdata/augmentation/noise_histories.py @@ -23,8 +23,11 @@ def apply_agent(self, agent_batch: AgentBatch) -> None: ) if agent_batch.history_pad_dir == PadDirection.BEFORE: - agent_hist_noise[..., -1, :] = 0 - neigh_hist_noise[..., -1, :] = 0 + try: + agent_hist_noise[..., -1, :] = 0 + neigh_hist_noise[..., -1, :] = 0 + except IndexError: + pass else: len_mask = ~mask_up_to( agent_batch.agent_hist_len, diff --git a/src/trajdata/data_structures/batch_element.py b/src/trajdata/data_structures/batch_element.py index f18e34d..6a26e82 100644 --- a/src/trajdata/data_structures/batch_element.py +++ b/src/trajdata/data_structures/batch_element.py @@ -39,6 +39,7 @@ def __init__( self.cache: SceneCache = cache self.data_index: int = data_index self.dt: float = scene_time_agent.scene.dt + self.track_info = scene_time_agent.scene.data_access_info self.scene_ts: int = scene_time_agent.ts self.history_sec = history_sec self.future_sec = future_sec @@ -341,6 +342,7 @@ def __init__( self.data_index = data_index self.dt: float = scene_time.scene.dt self.scene_ts: int = scene_time.ts + self.track_info = scene_time.scene.data_access_info if max_agent_num is not None: scene_time.agents = scene_time.agents[:max_agent_num] diff --git a/src/trajdata/dataset_specific/scene_records.py b/src/trajdata/dataset_specific/scene_records.py index 68bd5b9..5790713 100644 --- a/src/trajdata/dataset_specific/scene_records.py +++ b/src/trajdata/dataset_specific/scene_records.py @@ -48,3 +48,15 @@ class NuPlanSceneRecord(NamedTuple): split: str # desc: str data_idx: int + + +class YandexShiftsSceneRecord(NamedTuple): + name: str + length: str + data_idx: int + day_time: str + season: str + track: str + sun_phase: str + precipitation: str + diff --git a/src/trajdata/dataset_specific/yandex_shifts/__init__.py b/src/trajdata/dataset_specific/yandex_shifts/__init__.py new file mode 100644 index 0000000..dc1fde6 --- /dev/null +++ b/src/trajdata/dataset_specific/yandex_shifts/__init__.py @@ -0,0 +1 @@ +from .yandex_shifts_dataset import YandexShiftsDataset diff --git a/src/trajdata/dataset_specific/yandex_shifts/yandex_shifts_dataset.py b/src/trajdata/dataset_specific/yandex_shifts/yandex_shifts_dataset.py new file mode 100644 index 0000000..ce5320d --- /dev/null +++ b/src/trajdata/dataset_specific/yandex_shifts/yandex_shifts_dataset.py @@ -0,0 +1,205 @@ +from collections import defaultdict +from functools import partial +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple, Type + +import pandas as pd +import tqdm + +from ysdc_dataset_api.utils import get_file_paths, scenes_generator +from ysdc_dataset_api.proto import Scene as YSDCScene +from trajdata.caching import EnvCache, SceneCache +from trajdata.data_structures import ( + EnvMetadata, Scene, SceneMetadata, SceneTag) +from trajdata.data_structures.agent import AgentMetadata +from trajdata.dataset_specific.raw_dataset import RawDataset +from trajdata.dataset_specific.scene_records import YandexShiftsSceneRecord +from trajdata.dataset_specific.yandex_shifts import yandex_shifts_utils +from trajdata.maps import VectorMap +from trajdata.utils.parallel_utils import parallel_apply +from trajdata.dataset_specific.yandex_shifts.yandex_shifts_utils import ( + read_scene_from_original_proto, get_scene_path, extract_vectorized, + extract_traffic_light_status, extract_agent_data_from_ysdc_scene) + + +def const_lambda(const_val: Any) -> Any: + return const_val + + +class YandexShiftsDataset(RawDataset): + def compute_metadata(self, env_name: str, data_dir: str) -> EnvMetadata: + if env_name == "ysdc_train": + dataset_parts = [("train",)] + scene_split_map = defaultdict( + partial(const_lambda, const_val="train")) + + elif env_name == "ysdc_development": + dataset_parts = [("development",)] + scene_split_map = defaultdict(partial( + const_lambda, const_val="development")) + + elif env_name == "ysdc_eval": + dataset_parts = [("eval",)] + scene_split_map = defaultdict(partial( + const_lambda, const_val="eval")) + + elif env_name == "ysdc_full": + dataset_parts = [("full",)] + scene_split_map = defaultdict(partial( + const_lambda, const_val="full")) + + return EnvMetadata( + name=env_name, + data_dir=data_dir, + dt=yandex_shifts_utils.YSDC_DT, + parts=dataset_parts, + scene_split_map=scene_split_map, + ) + + def load_dataset_obj(self, verbose: bool = False) -> None: + if verbose: + print(f"Loading {self.name} dataset...", flush=True) + self.dataset_obj = scenes_generator( + get_file_paths(self.metadata.data_dir)) + + def _get_matching_scenes_from_obj( + self, + scene_tag: SceneTag, + scene_desc_contains: Optional[List[str]], + env_cache: EnvCache, + ) -> List[SceneMetadata]: + all_scenes_list: List[YandexShiftsSceneRecord] = list() + scenes_list: List[SceneMetadata] = list() + for idx, scene in tqdm.tqdm( + enumerate(self.dataset_obj), + desc="Processing scenes from proto files"): + scene_name: str = scene.id + scene_split: str = self.metadata.scene_split_map[scene_name] + scene_length: int = yandex_shifts_utils.YSDC_LENGTH + # Saving all scene records for later caching. + all_scenes_list.append( + YandexShiftsSceneRecord( + scene_name, + str(scene_length), + idx, + scene.scene_tags.day_time, + scene.scene_tags.season, + scene.scene_tags.track, + scene.scene_tags.sun_phase, + scene.scene_tags.precipitation)) + if scene_split in scene_tag and scene_desc_contains is None: + scene_metadata = SceneMetadata( + env_name=self.metadata.name, + name=scene_name, + dt=self.metadata.dt, + raw_data_idx=idx, + ) + scenes_list.append(scene_metadata) + self.cache_all_scenes_list(env_cache, all_scenes_list) + return scenes_list + + def _get_matching_scenes_from_cache( + self, + scene_tag: SceneTag, + scene_desc_contains: Optional[List[str]], + env_cache: EnvCache, + ) -> List[Scene]: + all_scenes_list: List[YandexShiftsSceneRecord] = \ + env_cache.load_env_scenes_list(self.name) + scenes_list: List[SceneMetadata] = list() + for scene_record in all_scenes_list: + scene_split: str = self.metadata.scene_split_map[scene_record.name] + if scene_split in scene_tag and scene_desc_contains is None: + scene_metadata = Scene( + self.metadata, + scene_record.name, + scene_record.data_idx, + scene_record.scene_split, + scene_record.length, + scene_record.data_idx, + None, # This isn't used if everything is already cached. + ) + scenes_list.append(scene_metadata) + return scenes_list + + def get_scene(self, scene_info: SceneMetadata) -> Scene: + _, _, _, data_idx = scene_info + scene_data_from_proto: YSDCScene = read_scene_from_original_proto( + get_scene_path(self.metadata.data_dir, data_idx)) + num_history_timestamps = len(scene_data_from_proto.past_vehicle_tracks) + num_future_timestamps = \ + len(scene_data_from_proto.future_vehicle_tracks) + scene_name: str = scene_data_from_proto.id + scene_split: str = self.metadata.scene_split_map[scene_name] + scene_length: int = num_history_timestamps + num_future_timestamps + return Scene( + self.metadata, + scene_data_from_proto.id, + data_idx, + scene_split, + scene_length, + data_idx, + { + "day_time": scene_data_from_proto.scene_tags.day_time, + "season": scene_data_from_proto.scene_tags.season, + "track_location": scene_data_from_proto.scene_tags.track, + "sun_phase": scene_data_from_proto.scene_tags.sun_phase, + "precipitation": + scene_data_from_proto.scene_tags.precipitation}) + + def get_agent_info( + self, scene: Scene, cache_path: Path, cache_class: Type[SceneCache] + ) -> Tuple[List[AgentMetadata], List[List[AgentMetadata]]]: + scene_data_from_proto = read_scene_from_original_proto( + get_scene_path(self.metadata.data_dir, scene.raw_data_idx)) + scene_agents_data_df, agent_list, agent_presence = \ + extract_agent_data_from_ysdc_scene(scene_data_from_proto, scene) + cache_class.save_agent_data(scene_agents_data_df, cache_path, scene) + tls_dict = extract_traffic_light_status(scene_data_from_proto) + tls_df = pd.DataFrame( + tls_dict.values(), + index=pd.MultiIndex.from_tuples( + tls_dict.keys(), names=["lane_id", "scene_ts"] + ), + columns=["status"], + ) + cache_class.save_traffic_light_data(tls_df, cache_path, scene) + return agent_list, agent_presence + + def cache_map( + self, + data_idx: int, + cache_path: Path, + map_cache_class: Type[SceneCache], + map_params: Dict[str, Any], + ): + scene_data_from_proto = read_scene_from_original_proto( + get_scene_path(self.metadata.data_dir, data_idx)) + vector_map: VectorMap = extract_vectorized( + scene_data_from_proto.path_graph, + map_name=f"{self.name}:{data_idx}") + map_cache_class.finalize_and_cache_map( + cache_path, vector_map, map_params) + + def cache_maps( + self, + cache_path: Path, + map_cache_class: Type[SceneCache], + map_params: Dict[str, Any], + ) -> None: + num_workers: int = map_params.get("num_workers", 0) + if num_workers > 1: + parallel_apply( + partial( + self.cache_map, + cache_path=cache_path, + map_cache_class=map_cache_class, + map_params=map_params, + ), + range(len(get_file_paths(self.metadata.data_dir))), + num_workers=num_workers, + ) + + else: + for i in tqdm.trange(len(get_file_paths(self.metadata.data_dir))): + self.cache_map(i, cache_path, map_cache_class, map_params) diff --git a/src/trajdata/dataset_specific/yandex_shifts/yandex_shifts_utils.py b/src/trajdata/dataset_specific/yandex_shifts/yandex_shifts_utils.py new file mode 100644 index 0000000..cb5f0d2 --- /dev/null +++ b/src/trajdata/dataset_specific/yandex_shifts/yandex_shifts_utils.py @@ -0,0 +1,294 @@ +import os +from collections import defaultdict +from typing import Dict, List, Tuple, Union, Any +import numpy as np +import pandas as pd +from ysdc_dataset_api.proto import Scene as YSDCScene +from ysdc_dataset_api.proto.map_pb2 import PathGraph as YSDCPathGraph +from ysdc_dataset_api.proto.dataset_pb2 import VehicleTrack as YSDCVehicleTrack +from trajdata.maps import TrafficLightStatus, VectorMap +from trajdata.maps.vec_map_elements import ( + PedCrosswalk, Polyline, RoadLane, RoadArea) +from trajdata.data_structures.batch_element import ( + AgentBatchElement, SceneBatchElement) +from trajdata.data_structures.agent import ( + AgentMetadata, AgentType, VariableExtent) +from trajdata.data_structures import Scene as TRAJScene + + +YSDC_DT = 0.2 +YSDC_LENGTH = 50 + + +def fetch_season_info( + batch_element: Union[AgentBatchElement, SceneBatchElement]) -> int: + return batch_element.track_info['season'] + + +def fetch_day_time_info( + batch_element: Union[AgentBatchElement, SceneBatchElement]) -> int: + return batch_element.track_info['day_time'] + + +def fetch_track_location_info( + batch_element: Union[AgentBatchElement, SceneBatchElement]) -> int: + return batch_element.track_info['track_location'] + + +def fetch_sun_phase_info( + batch_element: Union[AgentBatchElement, SceneBatchElement]) -> int: + return batch_element.track_info['sun_phase'] + + +def fetch_precipitation_info( + batch_element: Union[AgentBatchElement, SceneBatchElement]) -> int: + return batch_element.track_info['precipitation'] + + +def read_scene_from_original_proto(path: str) -> YSDCScene: + with open(path, "rb") as f: + scene = YSDCScene() + scene.ParseFromString(f.read()) + return scene + + +def get_scene_path(data_dir: str, scene_idx: int) -> str: + return os.path.join( + data_dir, + str(scene_idx // 1000).zfill(3), + str(scene_idx).zfill(6)) + ".pb" + + +def fix_headings(agents_data_df: pd.DataFrame) -> pd.DataFrame: + headings = agents_data_df['heading'].values + previous_headings = np.roll(headings, 1) + previous_headings[0] = headings[0] + normalized_angle_diff = \ + np.abs(headings - previous_headings) % (2 * np.pi) + fixed_headings = np.where( + np.minimum(normalized_angle_diff, 2 * np.pi - normalized_angle_diff) \ + > np.pi / 2, + headings + np.pi, headings) + agents_data_df['heading'] = fixed_headings + return agents_data_df + + +def fill_missing_timestamps( + agents_data_df: pd.DataFrame, + agent_id_to_time_range: dict) -> pd.DataFrame: + filled_agents_data_df = [] + for agent_id, agent_df in agents_data_df.groupby('agent_id'): + state_idx = 0 + for ts in range( + agent_id_to_time_range[agent_id][0], + agent_id_to_time_range[agent_id][1] + 1): + if state_idx < agent_df.shape[0] and \ + agent_df.iloc[state_idx]["scene_ts"] == ts: + d = agent_df.iloc[state_idx].to_dict() + d["agent_id"] = agent_id + filled_agents_data_df.append(d) + state_idx += 1 + else: + filled_agents_data_df.append({ + "agent_id": agent_id, + "scene_ts": ts, + "x": None, + "y": None, + "z": None, + "vx": None, + "vy": None, + "ax": None, + "ay": None, + "heading": None, + "length": None, + "width": None, + "height": None,}) + return pd.DataFrame(filled_agents_data_df).sort_values( + by=["agent_id", "scene_ts"]) + + +def map_ysdc_to_trajdata_traffic_light_status( + ysdc_tl_status: int) -> TrafficLightStatus: + mapping = { + -1: TrafficLightStatus.NO_DATA, + 0: TrafficLightStatus.UNKNOWN, + 1: TrafficLightStatus.GREEN, + 2: TrafficLightStatus.GREEN, + 3: TrafficLightStatus.RED, + 4: TrafficLightStatus.RED, + 5: TrafficLightStatus.RED, + 6: TrafficLightStatus.UNKNOWN, + 7: TrafficLightStatus.UNKNOWN, + 8: TrafficLightStatus.UNKNOWN, + 9: TrafficLightStatus.UNKNOWN, + 10: TrafficLightStatus.UNKNOWN, + 11: TrafficLightStatus.RED} + return mapping[ysdc_tl_status] + + +def extract_traffic_light_status(ysdc_scene: YSDCScene) -> \ + Dict[Tuple[str, int], TrafficLightStatus]: + traffic_light_data = {} + n_states = len(ysdc_scene.past_vehicle_tracks) + \ + len(ysdc_scene.future_vehicle_tracks) + traffic_light_section_id_to_state = {} + for traffic_light in ysdc_scene.traffic_lights: + for traffic_light_section in traffic_light.sections: + traffic_light_section_id_to_state[traffic_light_section.id] = \ + traffic_light_section.state + for lane_idx, lane in enumerate(ysdc_scene.path_graph.lanes): + # YSDC dataset supports also left_section_id and right_section_id + conventional_lane_id = f"lane_{lane_idx}" + lane_main_section_id = lane.traffic_light_section_ids.main_section_id + if lane_main_section_id not in traffic_light_section_id_to_state: + traffic_light_section_id_to_state[lane_main_section_id] = -1 + ysdc_traffic_light_state = \ + traffic_light_section_id_to_state[lane_main_section_id] + conventional_traffic_light_state = \ + map_ysdc_to_trajdata_traffic_light_status(ysdc_traffic_light_state) + for ts in range(n_states): + traffic_light_data[(conventional_lane_id, ts)] = \ + conventional_traffic_light_state + return traffic_light_data + + +def extract_vectorized( + map_features: YSDCPathGraph, map_name: str) -> VectorMap: + vec_map = VectorMap(map_id=map_name) + max_pt = np.array([np.nan, np.nan]) + min_pt = np.array([np.nan, np.nan]) + + for lane_idx, lane in enumerate(map_features.lanes): + lane_centers = np.array([(node.x, node.y) for node in lane.centers]) + max_pt = np.nanmax(np.vstack([max_pt, lane_centers]), axis=0) + min_pt = np.nanmin(np.vstack([min_pt, lane_centers]), axis=0) + vec_map.add_map_element( + RoadLane( + # YSDC only has center lane + id=f"lane_{lane_idx}", + center=Polyline(lane_centers))) + + for crosswalk_idx, crosswalk in enumerate(map_features.crosswalks): + crosswalk_points = np.array([ + (node.x, node.y) for node in crosswalk.geometry.points]) + max_pt = np.nanmax(np.vstack([max_pt, crosswalk_points]), axis=0) + min_pt = np.nanmin(np.vstack([min_pt, crosswalk_points]), axis=0) + vec_map.add_map_element( + PedCrosswalk( + id=f"crosswalk_{crosswalk_idx}", + polygon=Polyline(crosswalk_points))) + + for road_polygon_idx, road_polygon in enumerate( + map_features.road_polygons): + road_polygon_points = np.array([ + (node.x, node.y) for node in road_polygon.geometry.points]) + max_pt = np.nanmax(np.vstack([max_pt, road_polygon_points]), axis=0) + min_pt = np.nanmin(np.vstack([min_pt, road_polygon_points]), axis=0) + vec_map.add_map_element( + RoadArea( + id=f"road_polygon_{road_polygon_idx}", + exterior_polygon=Polyline(road_polygon_points))) + + vec_map.extent = np.array([*min_pt, 0, *max_pt, 0]) + return vec_map + + +def prepare_agent_info_dict_from_track( + track: YSDCVehicleTrack, scene_ts: int, + entity: AgentType, is_ego: bool = False) -> Dict[str, Any]: + assert entity in [AgentType.VEHICLE, AgentType.PEDESTRIAN] + return { + "agent_id": "ego" if is_ego else str(track.track_id), + "scene_ts": scene_ts, + "x": track.position.x, + "y": track.position.y, + "z": track.position.z, + "vx": track.linear_velocity.x, + "vy": track.linear_velocity.y, + "ax": 0 if entity==AgentType.PEDESTRIAN else \ + track.linear_acceleration.x, + "ay": 0 if entity==AgentType.PEDESTRIAN else \ + track.linear_acceleration.y, + "heading": + np.arctan2(track.linear_velocity.y, track.linear_velocity.x) \ + if entity==AgentType.PEDESTRIAN else track.yaw, + "length": track.dimensions.x, + "width": track.dimensions.y, + "height": track.dimensions.z} + + +def update_time_range( + agent_id: str, timestamp: int, + agent_id_to_time_range: Dict[str, Tuple[float, float]]) -> None: + agent_id_to_time_range[agent_id] = ( + min(agent_id_to_time_range[agent_id][0], timestamp), + max(agent_id_to_time_range[agent_id][1], timestamp)) + + +def extract_agent_data_from_ysdc_scene( + ysdc_scene: YSDCScene, trajdata_scene: TRAJScene) -> \ + Tuple[ + pd.DataFrame, List[AgentMetadata], List[List[AgentMetadata]]]: + agent_list: List[AgentMetadata] = [] + agent_presence: List[List[AgentMetadata]] = [ + [] for _ in range(trajdata_scene.length_timesteps)] + scene_agents_data = defaultdict(list) + agent_id_to_time_range = defaultdict(lambda: (np.inf, -np.inf)) + agent_id_to_type = {"ego": AgentType.VEHICLE} + agents_types_data = [ + (AgentType.VEHICLE, + list(ysdc_scene.past_vehicle_tracks) + \ + list(ysdc_scene.future_vehicle_tracks)), + (AgentType.PEDESTRIAN, + list(ysdc_scene.past_pedestrian_tracks) + \ + list(ysdc_scene.future_pedestrian_tracks))] + ego_agent_data = list(ysdc_scene.past_ego_track) + \ + list(ysdc_scene.future_ego_track) + for agent_type, scene_moment_states in agents_types_data: + for timestamp, scene_moment_state in enumerate( + scene_moment_states): + for agent_moment_state in scene_moment_state.tracks: + agent_info_dict = prepare_agent_info_dict_from_track( + agent_moment_state, timestamp, agent_type) + scene_agents_data[agent_info_dict["agent_id"]] \ + .append(agent_info_dict) + update_time_range( + agent_info_dict["agent_id"], timestamp, + agent_id_to_time_range) + agent_id_to_type[str(agent_moment_state.track_id)] = \ + agent_type + for timestamp, ego_agent_moment_state in enumerate(ego_agent_data): + agent_info_dict = prepare_agent_info_dict_from_track( + ego_agent_moment_state, timestamp, AgentType.VEHICLE, True) + scene_agents_data[agent_info_dict["agent_id"]] \ + .append(agent_info_dict) + update_time_range( + agent_info_dict["agent_id"], timestamp, agent_id_to_time_range) + scene_agents_data_df = pd.DataFrame([ + item for sublist in scene_agents_data.values() for \ + item in sublist]) \ + .sort_values(by=["agent_id", "scene_ts"]) + scene_agents_data_df = fix_headings(scene_agents_data_df) + scene_agents_data_df = fill_missing_timestamps( + scene_agents_data_df, agent_id_to_time_range) + scene_agents_data_df = scene_agents_data_df.groupby( + "agent_id", group_keys=True) \ + .apply(lambda group: group.interpolate(limit_area="inside")) \ + .reset_index(drop=True) + for agent_id in agent_id_to_type.keys(): + agent_list.append(AgentMetadata( + name=agent_id, + agent_type=agent_id_to_type[agent_id], + first_timestep=agent_id_to_time_range[agent_id][0], + last_timestep=agent_id_to_time_range[agent_id][1], + extent=VariableExtent())) + for ts in range( + agent_id_to_time_range[agent_id][0], + agent_id_to_time_range[agent_id][1] + 1): + agent_presence[ts].append(AgentMetadata( + name=agent_id, + agent_type=agent_id_to_type[agent_id], + first_timestep=agent_id_to_time_range[agent_id][0], + last_timestep=agent_id_to_time_range[agent_id][1], + extent=VariableExtent())) + return scene_agents_data_df, agent_list, agent_presence diff --git a/src/trajdata/utils/env_utils.py b/src/trajdata/utils/env_utils.py index 4726537..3974b87 100644 --- a/src/trajdata/utils/env_utils.py +++ b/src/trajdata/utils/env_utils.py @@ -41,6 +41,13 @@ # with the "trajdata[waymo]" option. pass +try: + from trajdata.dataset_specific.yandex_shifts import YandexShiftsDataset +except ModuleNotFoundError: + # This can happen if the user did not install trajdata + # with the "trajdata[ysdc]" option. + pass + def get_raw_dataset(dataset_name: str, data_dir: str) -> RawDataset: if "nusc" in dataset_name: @@ -65,6 +72,10 @@ def get_raw_dataset(dataset_name: str, data_dir: str) -> RawDataset: if "waymo" in dataset_name: return WaymoDataset(dataset_name, data_dir, parallelizable=True, has_maps=True) + if "ysdc" in dataset_name: + return YandexShiftsDataset( + dataset_name, data_dir, parallelizable=True, has_maps=True) + if "interaction" in dataset_name: return InteractionDataset( dataset_name, data_dir, parallelizable=True, has_maps=True diff --git a/src/trajdata/utils/raster_utils.py b/src/trajdata/utils/raster_utils.py index 120981f..cf47669 100644 --- a/src/trajdata/utils/raster_utils.py +++ b/src/trajdata/utils/raster_utils.py @@ -199,22 +199,22 @@ def rasterize_map( line_color=(0, 255, 0), ) - # # This code helps visualize centerlines to check if the inferred headings are correct. - # center_pts = cv2_subpixel( - # transform_points( - # proto_to_np(map_elem.road_lane.center, incl_heading=False), - # raster_from_world, - # ) - # )[..., :2] - - # # Drawing lane centerlines. - # cv2.polylines( - # img=lane_line_img, - # pts=center_pts[None, :, :], - # isClosed=False, - # color=(255, 0, 0), - # **CV2_SUB_VALUES, - # ) + # This code helps visualize centerlines to check if the inferred headings are correct. + center_pts = cv2_subpixel( + map_utils.transform_points( + map_elem.center.xyz, + raster_from_world, + ) + )[..., :2] + + # Drawing lane centerlines. + cv2.polylines( + img=lane_line_img, + pts=center_pts[None, :, :], + isClosed=False, + color=(255, 0, 0), + **CV2_SUB_VALUES, + ) # headings = np.asarray(map_elem.road_lane.center.h_rad) # delta = cv2_subpixel(30*np.array([np.cos(headings[0]), np.sin(headings[0])])) From a5e15a51d2ff06e49f2dd90c40ceb48886d8549e Mon Sep 17 00:00:00 2001 From: Stepan Konev Date: Tue, 10 Oct 2023 00:44:45 +0200 Subject: [PATCH 2/4] ysdc code formatting --- src/trajdata/data_structures/batch_element.py | 1 - .../dataset_specific/scene_records.py | 1 - .../yandex_shifts/yandex_shifts_dataset.py | 72 +++-- .../yandex_shifts/yandex_shifts_utils.py | 300 ++++++++++-------- src/trajdata/utils/env_utils.py | 3 +- 5 files changed, 208 insertions(+), 169 deletions(-) diff --git a/src/trajdata/data_structures/batch_element.py b/src/trajdata/data_structures/batch_element.py index 6a26e82..80042e6 100644 --- a/src/trajdata/data_structures/batch_element.py +++ b/src/trajdata/data_structures/batch_element.py @@ -508,7 +508,6 @@ def get_agents_future( future_sec: Tuple[Optional[float], Optional[float]], nearby_agents: List[AgentMetadata], ) -> Tuple[List[StateArray], List[np.ndarray], np.ndarray]: - ( agent_futures, agent_future_extents, diff --git a/src/trajdata/dataset_specific/scene_records.py b/src/trajdata/dataset_specific/scene_records.py index 5790713..36a83e4 100644 --- a/src/trajdata/dataset_specific/scene_records.py +++ b/src/trajdata/dataset_specific/scene_records.py @@ -59,4 +59,3 @@ class YandexShiftsSceneRecord(NamedTuple): track: str sun_phase: str precipitation: str - diff --git a/src/trajdata/dataset_specific/yandex_shifts/yandex_shifts_dataset.py b/src/trajdata/dataset_specific/yandex_shifts/yandex_shifts_dataset.py index ce5320d..3dc4ade 100644 --- a/src/trajdata/dataset_specific/yandex_shifts/yandex_shifts_dataset.py +++ b/src/trajdata/dataset_specific/yandex_shifts/yandex_shifts_dataset.py @@ -9,8 +9,7 @@ from ysdc_dataset_api.utils import get_file_paths, scenes_generator from ysdc_dataset_api.proto import Scene as YSDCScene from trajdata.caching import EnvCache, SceneCache -from trajdata.data_structures import ( - EnvMetadata, Scene, SceneMetadata, SceneTag) +from trajdata.data_structures import EnvMetadata, Scene, SceneMetadata, SceneTag from trajdata.data_structures.agent import AgentMetadata from trajdata.dataset_specific.raw_dataset import RawDataset from trajdata.dataset_specific.scene_records import YandexShiftsSceneRecord @@ -18,8 +17,12 @@ from trajdata.maps import VectorMap from trajdata.utils.parallel_utils import parallel_apply from trajdata.dataset_specific.yandex_shifts.yandex_shifts_utils import ( - read_scene_from_original_proto, get_scene_path, extract_vectorized, - extract_traffic_light_status, extract_agent_data_from_ysdc_scene) + read_scene_from_original_proto, + get_scene_path, + extract_vectorized, + extract_traffic_light_status, + extract_agent_data_from_ysdc_scene, +) def const_lambda(const_val: Any) -> Any: @@ -30,23 +33,21 @@ class YandexShiftsDataset(RawDataset): def compute_metadata(self, env_name: str, data_dir: str) -> EnvMetadata: if env_name == "ysdc_train": dataset_parts = [("train",)] - scene_split_map = defaultdict( - partial(const_lambda, const_val="train")) + scene_split_map = defaultdict(partial(const_lambda, const_val="train")) elif env_name == "ysdc_development": dataset_parts = [("development",)] - scene_split_map = defaultdict(partial( - const_lambda, const_val="development")) + scene_split_map = defaultdict( + partial(const_lambda, const_val="development") + ) elif env_name == "ysdc_eval": dataset_parts = [("eval",)] - scene_split_map = defaultdict(partial( - const_lambda, const_val="eval")) + scene_split_map = defaultdict(partial(const_lambda, const_val="eval")) elif env_name == "ysdc_full": dataset_parts = [("full",)] - scene_split_map = defaultdict(partial( - const_lambda, const_val="full")) + scene_split_map = defaultdict(partial(const_lambda, const_val="full")) return EnvMetadata( name=env_name, @@ -59,8 +60,7 @@ def compute_metadata(self, env_name: str, data_dir: str) -> EnvMetadata: def load_dataset_obj(self, verbose: bool = False) -> None: if verbose: print(f"Loading {self.name} dataset...", flush=True) - self.dataset_obj = scenes_generator( - get_file_paths(self.metadata.data_dir)) + self.dataset_obj = scenes_generator(get_file_paths(self.metadata.data_dir)) def _get_matching_scenes_from_obj( self, @@ -71,8 +71,8 @@ def _get_matching_scenes_from_obj( all_scenes_list: List[YandexShiftsSceneRecord] = list() scenes_list: List[SceneMetadata] = list() for idx, scene in tqdm.tqdm( - enumerate(self.dataset_obj), - desc="Processing scenes from proto files"): + enumerate(self.dataset_obj), desc="Processing scenes from proto files" + ): scene_name: str = scene.id scene_split: str = self.metadata.scene_split_map[scene_name] scene_length: int = yandex_shifts_utils.YSDC_LENGTH @@ -86,7 +86,9 @@ def _get_matching_scenes_from_obj( scene.scene_tags.season, scene.scene_tags.track, scene.scene_tags.sun_phase, - scene.scene_tags.precipitation)) + scene.scene_tags.precipitation, + ) + ) if scene_split in scene_tag and scene_desc_contains is None: scene_metadata = SceneMetadata( env_name=self.metadata.name, @@ -104,8 +106,9 @@ def _get_matching_scenes_from_cache( scene_desc_contains: Optional[List[str]], env_cache: EnvCache, ) -> List[Scene]: - all_scenes_list: List[YandexShiftsSceneRecord] = \ - env_cache.load_env_scenes_list(self.name) + all_scenes_list: List[YandexShiftsSceneRecord] = env_cache.load_env_scenes_list( + self.name + ) scenes_list: List[SceneMetadata] = list() for scene_record in all_scenes_list: scene_split: str = self.metadata.scene_split_map[scene_record.name] @@ -125,10 +128,10 @@ def _get_matching_scenes_from_cache( def get_scene(self, scene_info: SceneMetadata) -> Scene: _, _, _, data_idx = scene_info scene_data_from_proto: YSDCScene = read_scene_from_original_proto( - get_scene_path(self.metadata.data_dir, data_idx)) + get_scene_path(self.metadata.data_dir, data_idx) + ) num_history_timestamps = len(scene_data_from_proto.past_vehicle_tracks) - num_future_timestamps = \ - len(scene_data_from_proto.future_vehicle_tracks) + num_future_timestamps = len(scene_data_from_proto.future_vehicle_tracks) scene_name: str = scene_data_from_proto.id scene_split: str = self.metadata.scene_split_map[scene_name] scene_length: int = num_history_timestamps + num_future_timestamps @@ -144,16 +147,21 @@ def get_scene(self, scene_info: SceneMetadata) -> Scene: "season": scene_data_from_proto.scene_tags.season, "track_location": scene_data_from_proto.scene_tags.track, "sun_phase": scene_data_from_proto.scene_tags.sun_phase, - "precipitation": - scene_data_from_proto.scene_tags.precipitation}) + "precipitation": scene_data_from_proto.scene_tags.precipitation, + }, + ) def get_agent_info( self, scene: Scene, cache_path: Path, cache_class: Type[SceneCache] ) -> Tuple[List[AgentMetadata], List[List[AgentMetadata]]]: scene_data_from_proto = read_scene_from_original_proto( - get_scene_path(self.metadata.data_dir, scene.raw_data_idx)) - scene_agents_data_df, agent_list, agent_presence = \ - extract_agent_data_from_ysdc_scene(scene_data_from_proto, scene) + get_scene_path(self.metadata.data_dir, scene.raw_data_idx) + ) + ( + scene_agents_data_df, + agent_list, + agent_presence, + ) = extract_agent_data_from_ysdc_scene(scene_data_from_proto, scene) cache_class.save_agent_data(scene_agents_data_df, cache_path, scene) tls_dict = extract_traffic_light_status(scene_data_from_proto) tls_df = pd.DataFrame( @@ -174,12 +182,12 @@ def cache_map( map_params: Dict[str, Any], ): scene_data_from_proto = read_scene_from_original_proto( - get_scene_path(self.metadata.data_dir, data_idx)) + get_scene_path(self.metadata.data_dir, data_idx) + ) vector_map: VectorMap = extract_vectorized( - scene_data_from_proto.path_graph, - map_name=f"{self.name}:{data_idx}") - map_cache_class.finalize_and_cache_map( - cache_path, vector_map, map_params) + scene_data_from_proto.path_graph, map_name=f"{self.name}:{data_idx}" + ) + map_cache_class.finalize_and_cache_map(cache_path, vector_map, map_params) def cache_maps( self, diff --git a/src/trajdata/dataset_specific/yandex_shifts/yandex_shifts_utils.py b/src/trajdata/dataset_specific/yandex_shifts/yandex_shifts_utils.py index cb5f0d2..1581fa9 100644 --- a/src/trajdata/dataset_specific/yandex_shifts/yandex_shifts_utils.py +++ b/src/trajdata/dataset_specific/yandex_shifts/yandex_shifts_utils.py @@ -7,12 +7,9 @@ from ysdc_dataset_api.proto.map_pb2 import PathGraph as YSDCPathGraph from ysdc_dataset_api.proto.dataset_pb2 import VehicleTrack as YSDCVehicleTrack from trajdata.maps import TrafficLightStatus, VectorMap -from trajdata.maps.vec_map_elements import ( - PedCrosswalk, Polyline, RoadLane, RoadArea) -from trajdata.data_structures.batch_element import ( - AgentBatchElement, SceneBatchElement) -from trajdata.data_structures.agent import ( - AgentMetadata, AgentType, VariableExtent) +from trajdata.maps.vec_map_elements import PedCrosswalk, Polyline, RoadLane, RoadArea +from trajdata.data_structures.batch_element import AgentBatchElement, SceneBatchElement +from trajdata.data_structures.agent import AgentMetadata, AgentType, VariableExtent from trajdata.data_structures import Scene as TRAJScene @@ -21,28 +18,33 @@ def fetch_season_info( - batch_element: Union[AgentBatchElement, SceneBatchElement]) -> int: - return batch_element.track_info['season'] + batch_element: Union[AgentBatchElement, SceneBatchElement] +) -> int: + return batch_element.track_info["season"] def fetch_day_time_info( - batch_element: Union[AgentBatchElement, SceneBatchElement]) -> int: - return batch_element.track_info['day_time'] + batch_element: Union[AgentBatchElement, SceneBatchElement] +) -> int: + return batch_element.track_info["day_time"] def fetch_track_location_info( - batch_element: Union[AgentBatchElement, SceneBatchElement]) -> int: - return batch_element.track_info['track_location'] + batch_element: Union[AgentBatchElement, SceneBatchElement] +) -> int: + return batch_element.track_info["track_location"] def fetch_sun_phase_info( - batch_element: Union[AgentBatchElement, SceneBatchElement]) -> int: - return batch_element.track_info['sun_phase'] + batch_element: Union[AgentBatchElement, SceneBatchElement] +) -> int: + return batch_element.track_info["sun_phase"] def fetch_precipitation_info( - batch_element: Union[AgentBatchElement, SceneBatchElement]) -> int: - return batch_element.track_info['precipitation'] + batch_element: Union[AgentBatchElement, SceneBatchElement] +) -> int: + return batch_element.track_info["precipitation"] def read_scene_from_original_proto(path: str) -> YSDCScene: @@ -53,62 +55,68 @@ def read_scene_from_original_proto(path: str) -> YSDCScene: def get_scene_path(data_dir: str, scene_idx: int) -> str: - return os.path.join( - data_dir, - str(scene_idx // 1000).zfill(3), - str(scene_idx).zfill(6)) + ".pb" + return ( + os.path.join(data_dir, str(scene_idx // 1000).zfill(3), str(scene_idx).zfill(6)) + + ".pb" + ) def fix_headings(agents_data_df: pd.DataFrame) -> pd.DataFrame: - headings = agents_data_df['heading'].values + headings = agents_data_df["heading"].values previous_headings = np.roll(headings, 1) previous_headings[0] = headings[0] - normalized_angle_diff = \ - np.abs(headings - previous_headings) % (2 * np.pi) + normalized_angle_diff = np.abs(headings - previous_headings) % (2 * np.pi) fixed_headings = np.where( - np.minimum(normalized_angle_diff, 2 * np.pi - normalized_angle_diff) \ - > np.pi / 2, - headings + np.pi, headings) - agents_data_df['heading'] = fixed_headings + np.minimum(normalized_angle_diff, 2 * np.pi - normalized_angle_diff) + > np.pi / 2, + headings + np.pi, + headings, + ) + agents_data_df["heading"] = fixed_headings return agents_data_df def fill_missing_timestamps( - agents_data_df: pd.DataFrame, - agent_id_to_time_range: dict) -> pd.DataFrame: + agents_data_df: pd.DataFrame, agent_id_to_time_range: dict +) -> pd.DataFrame: filled_agents_data_df = [] - for agent_id, agent_df in agents_data_df.groupby('agent_id'): + for agent_id, agent_df in agents_data_df.groupby("agent_id"): state_idx = 0 for ts in range( - agent_id_to_time_range[agent_id][0], - agent_id_to_time_range[agent_id][1] + 1): - if state_idx < agent_df.shape[0] and \ - agent_df.iloc[state_idx]["scene_ts"] == ts: + agent_id_to_time_range[agent_id][0], agent_id_to_time_range[agent_id][1] + 1 + ): + if ( + state_idx < agent_df.shape[0] + and agent_df.iloc[state_idx]["scene_ts"] == ts + ): d = agent_df.iloc[state_idx].to_dict() d["agent_id"] = agent_id filled_agents_data_df.append(d) state_idx += 1 else: - filled_agents_data_df.append({ - "agent_id": agent_id, - "scene_ts": ts, - "x": None, - "y": None, - "z": None, - "vx": None, - "vy": None, - "ax": None, - "ay": None, - "heading": None, - "length": None, - "width": None, - "height": None,}) - return pd.DataFrame(filled_agents_data_df).sort_values( - by=["agent_id", "scene_ts"]) + filled_agents_data_df.append( + { + "agent_id": agent_id, + "scene_ts": ts, + "x": None, + "y": None, + "z": None, + "vx": None, + "vy": None, + "ax": None, + "ay": None, + "heading": None, + "length": None, + "width": None, + "height": None, + } + ) + return pd.DataFrame(filled_agents_data_df).sort_values(by=["agent_id", "scene_ts"]) def map_ysdc_to_trajdata_traffic_light_status( - ysdc_tl_status: int) -> TrafficLightStatus: + ysdc_tl_status: int, +) -> TrafficLightStatus: mapping = { -1: TrafficLightStatus.NO_DATA, 0: TrafficLightStatus.UNKNOWN, @@ -122,38 +130,44 @@ def map_ysdc_to_trajdata_traffic_light_status( 8: TrafficLightStatus.UNKNOWN, 9: TrafficLightStatus.UNKNOWN, 10: TrafficLightStatus.UNKNOWN, - 11: TrafficLightStatus.RED} + 11: TrafficLightStatus.RED, + } return mapping[ysdc_tl_status] -def extract_traffic_light_status(ysdc_scene: YSDCScene) -> \ - Dict[Tuple[str, int], TrafficLightStatus]: +def extract_traffic_light_status( + ysdc_scene: YSDCScene, +) -> Dict[Tuple[str, int], TrafficLightStatus]: traffic_light_data = {} - n_states = len(ysdc_scene.past_vehicle_tracks) + \ - len(ysdc_scene.future_vehicle_tracks) + n_states = len(ysdc_scene.past_vehicle_tracks) + len( + ysdc_scene.future_vehicle_tracks + ) traffic_light_section_id_to_state = {} for traffic_light in ysdc_scene.traffic_lights: for traffic_light_section in traffic_light.sections: - traffic_light_section_id_to_state[traffic_light_section.id] = \ - traffic_light_section.state + traffic_light_section_id_to_state[ + traffic_light_section.id + ] = traffic_light_section.state for lane_idx, lane in enumerate(ysdc_scene.path_graph.lanes): # YSDC dataset supports also left_section_id and right_section_id conventional_lane_id = f"lane_{lane_idx}" lane_main_section_id = lane.traffic_light_section_ids.main_section_id if lane_main_section_id not in traffic_light_section_id_to_state: traffic_light_section_id_to_state[lane_main_section_id] = -1 - ysdc_traffic_light_state = \ - traffic_light_section_id_to_state[lane_main_section_id] - conventional_traffic_light_state = \ - map_ysdc_to_trajdata_traffic_light_status(ysdc_traffic_light_state) + ysdc_traffic_light_state = traffic_light_section_id_to_state[ + lane_main_section_id + ] + conventional_traffic_light_state = map_ysdc_to_trajdata_traffic_light_status( + ysdc_traffic_light_state + ) for ts in range(n_states): - traffic_light_data[(conventional_lane_id, ts)] = \ - conventional_traffic_light_state + traffic_light_data[ + (conventional_lane_id, ts) + ] = conventional_traffic_light_state return traffic_light_data -def extract_vectorized( - map_features: YSDCPathGraph, map_name: str) -> VectorMap: +def extract_vectorized(map_features: YSDCPathGraph, map_name: str) -> VectorMap: vec_map = VectorMap(map_id=map_name) max_pt = np.array([np.nan, np.nan]) min_pt = np.array([np.nan, np.nan]) @@ -166,36 +180,42 @@ def extract_vectorized( RoadLane( # YSDC only has center lane id=f"lane_{lane_idx}", - center=Polyline(lane_centers))) + center=Polyline(lane_centers), + ) + ) for crosswalk_idx, crosswalk in enumerate(map_features.crosswalks): - crosswalk_points = np.array([ - (node.x, node.y) for node in crosswalk.geometry.points]) + crosswalk_points = np.array( + [(node.x, node.y) for node in crosswalk.geometry.points] + ) max_pt = np.nanmax(np.vstack([max_pt, crosswalk_points]), axis=0) min_pt = np.nanmin(np.vstack([min_pt, crosswalk_points]), axis=0) vec_map.add_map_element( PedCrosswalk( - id=f"crosswalk_{crosswalk_idx}", - polygon=Polyline(crosswalk_points))) - - for road_polygon_idx, road_polygon in enumerate( - map_features.road_polygons): - road_polygon_points = np.array([ - (node.x, node.y) for node in road_polygon.geometry.points]) + id=f"crosswalk_{crosswalk_idx}", polygon=Polyline(crosswalk_points) + ) + ) + + for road_polygon_idx, road_polygon in enumerate(map_features.road_polygons): + road_polygon_points = np.array( + [(node.x, node.y) for node in road_polygon.geometry.points] + ) max_pt = np.nanmax(np.vstack([max_pt, road_polygon_points]), axis=0) min_pt = np.nanmin(np.vstack([min_pt, road_polygon_points]), axis=0) vec_map.add_map_element( RoadArea( id=f"road_polygon_{road_polygon_idx}", - exterior_polygon=Polyline(road_polygon_points))) + exterior_polygon=Polyline(road_polygon_points), + ) + ) vec_map.extent = np.array([*min_pt, 0, *max_pt, 0]) return vec_map def prepare_agent_info_dict_from_track( - track: YSDCVehicleTrack, scene_ts: int, - entity: AgentType, is_ego: bool = False) -> Dict[str, Any]: + track: YSDCVehicleTrack, scene_ts: int, entity: AgentType, is_ego: bool = False +) -> Dict[str, Any]: assert entity in [AgentType.VEHICLE, AgentType.PEDESTRIAN] return { "agent_id": "ego" if is_ego else str(track.track_id), @@ -205,90 +225,102 @@ def prepare_agent_info_dict_from_track( "z": track.position.z, "vx": track.linear_velocity.x, "vy": track.linear_velocity.y, - "ax": 0 if entity==AgentType.PEDESTRIAN else \ - track.linear_acceleration.x, - "ay": 0 if entity==AgentType.PEDESTRIAN else \ - track.linear_acceleration.y, - "heading": - np.arctan2(track.linear_velocity.y, track.linear_velocity.x) \ - if entity==AgentType.PEDESTRIAN else track.yaw, + "ax": 0 if entity == AgentType.PEDESTRIAN else track.linear_acceleration.x, + "ay": 0 if entity == AgentType.PEDESTRIAN else track.linear_acceleration.y, + "heading": np.arctan2(track.linear_velocity.y, track.linear_velocity.x) + if entity == AgentType.PEDESTRIAN + else track.yaw, "length": track.dimensions.x, "width": track.dimensions.y, - "height": track.dimensions.z} + "height": track.dimensions.z, + } def update_time_range( - agent_id: str, timestamp: int, - agent_id_to_time_range: Dict[str, Tuple[float, float]]) -> None: + agent_id: str, + timestamp: int, + agent_id_to_time_range: Dict[str, Tuple[float, float]], +) -> None: agent_id_to_time_range[agent_id] = ( min(agent_id_to_time_range[agent_id][0], timestamp), - max(agent_id_to_time_range[agent_id][1], timestamp)) + max(agent_id_to_time_range[agent_id][1], timestamp), + ) def extract_agent_data_from_ysdc_scene( - ysdc_scene: YSDCScene, trajdata_scene: TRAJScene) -> \ - Tuple[ - pd.DataFrame, List[AgentMetadata], List[List[AgentMetadata]]]: + ysdc_scene: YSDCScene, trajdata_scene: TRAJScene +) -> Tuple[pd.DataFrame, List[AgentMetadata], List[List[AgentMetadata]]]: agent_list: List[AgentMetadata] = [] agent_presence: List[List[AgentMetadata]] = [ - [] for _ in range(trajdata_scene.length_timesteps)] + [] for _ in range(trajdata_scene.length_timesteps) + ] scene_agents_data = defaultdict(list) agent_id_to_time_range = defaultdict(lambda: (np.inf, -np.inf)) agent_id_to_type = {"ego": AgentType.VEHICLE} agents_types_data = [ - (AgentType.VEHICLE, - list(ysdc_scene.past_vehicle_tracks) + \ - list(ysdc_scene.future_vehicle_tracks)), - (AgentType.PEDESTRIAN, - list(ysdc_scene.past_pedestrian_tracks) + \ - list(ysdc_scene.future_pedestrian_tracks))] - ego_agent_data = list(ysdc_scene.past_ego_track) + \ - list(ysdc_scene.future_ego_track) + ( + AgentType.VEHICLE, + list(ysdc_scene.past_vehicle_tracks) + + list(ysdc_scene.future_vehicle_tracks), + ), + ( + AgentType.PEDESTRIAN, + list(ysdc_scene.past_pedestrian_tracks) + + list(ysdc_scene.future_pedestrian_tracks), + ), + ] + ego_agent_data = list(ysdc_scene.past_ego_track) + list(ysdc_scene.future_ego_track) for agent_type, scene_moment_states in agents_types_data: - for timestamp, scene_moment_state in enumerate( - scene_moment_states): + for timestamp, scene_moment_state in enumerate(scene_moment_states): for agent_moment_state in scene_moment_state.tracks: agent_info_dict = prepare_agent_info_dict_from_track( - agent_moment_state, timestamp, agent_type) - scene_agents_data[agent_info_dict["agent_id"]] \ - .append(agent_info_dict) + agent_moment_state, timestamp, agent_type + ) + scene_agents_data[agent_info_dict["agent_id"]].append(agent_info_dict) update_time_range( - agent_info_dict["agent_id"], timestamp, - agent_id_to_time_range) - agent_id_to_type[str(agent_moment_state.track_id)] = \ - agent_type + agent_info_dict["agent_id"], timestamp, agent_id_to_time_range + ) + agent_id_to_type[str(agent_moment_state.track_id)] = agent_type for timestamp, ego_agent_moment_state in enumerate(ego_agent_data): agent_info_dict = prepare_agent_info_dict_from_track( - ego_agent_moment_state, timestamp, AgentType.VEHICLE, True) - scene_agents_data[agent_info_dict["agent_id"]] \ - .append(agent_info_dict) + ego_agent_moment_state, timestamp, AgentType.VEHICLE, True + ) + scene_agents_data[agent_info_dict["agent_id"]].append(agent_info_dict) update_time_range( - agent_info_dict["agent_id"], timestamp, agent_id_to_time_range) - scene_agents_data_df = pd.DataFrame([ - item for sublist in scene_agents_data.values() for \ - item in sublist]) \ - .sort_values(by=["agent_id", "scene_ts"]) + agent_info_dict["agent_id"], timestamp, agent_id_to_time_range + ) + scene_agents_data_df = pd.DataFrame( + [item for sublist in scene_agents_data.values() for item in sublist] + ).sort_values(by=["agent_id", "scene_ts"]) scene_agents_data_df = fix_headings(scene_agents_data_df) scene_agents_data_df = fill_missing_timestamps( - scene_agents_data_df, agent_id_to_time_range) - scene_agents_data_df = scene_agents_data_df.groupby( - "agent_id", group_keys=True) \ - .apply(lambda group: group.interpolate(limit_area="inside")) \ + scene_agents_data_df, agent_id_to_time_range + ) + scene_agents_data_df = ( + scene_agents_data_df.groupby("agent_id", group_keys=True) + .apply(lambda group: group.interpolate(limit_area="inside")) .reset_index(drop=True) + ) for agent_id in agent_id_to_type.keys(): - agent_list.append(AgentMetadata( - name=agent_id, - agent_type=agent_id_to_type[agent_id], - first_timestep=agent_id_to_time_range[agent_id][0], - last_timestep=agent_id_to_time_range[agent_id][1], - extent=VariableExtent())) - for ts in range( - agent_id_to_time_range[agent_id][0], - agent_id_to_time_range[agent_id][1] + 1): - agent_presence[ts].append(AgentMetadata( + agent_list.append( + AgentMetadata( name=agent_id, agent_type=agent_id_to_type[agent_id], first_timestep=agent_id_to_time_range[agent_id][0], last_timestep=agent_id_to_time_range[agent_id][1], - extent=VariableExtent())) + extent=VariableExtent(), + ) + ) + for ts in range( + agent_id_to_time_range[agent_id][0], agent_id_to_time_range[agent_id][1] + 1 + ): + agent_presence[ts].append( + AgentMetadata( + name=agent_id, + agent_type=agent_id_to_type[agent_id], + first_timestep=agent_id_to_time_range[agent_id][0], + last_timestep=agent_id_to_time_range[agent_id][1], + extent=VariableExtent(), + ) + ) return scene_agents_data_df, agent_list, agent_presence diff --git a/src/trajdata/utils/env_utils.py b/src/trajdata/utils/env_utils.py index 3974b87..bee3ef6 100644 --- a/src/trajdata/utils/env_utils.py +++ b/src/trajdata/utils/env_utils.py @@ -74,7 +74,8 @@ def get_raw_dataset(dataset_name: str, data_dir: str) -> RawDataset: if "ysdc" in dataset_name: return YandexShiftsDataset( - dataset_name, data_dir, parallelizable=True, has_maps=True) + dataset_name, data_dir, parallelizable=True, has_maps=True + ) if "interaction" in dataset_name: return InteractionDataset( From 06f73afa80fc9fc4fc5130e9ce7da7a77b5867ba Mon Sep 17 00:00:00 2001 From: Stepan Konev Date: Thu, 19 Oct 2023 23:58:54 +0200 Subject: [PATCH 3/4] minor fix for existing cache --- .../dataset_specific/yandex_shifts/yandex_shifts_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/trajdata/dataset_specific/yandex_shifts/yandex_shifts_dataset.py b/src/trajdata/dataset_specific/yandex_shifts/yandex_shifts_dataset.py index 3dc4ade..041cf43 100644 --- a/src/trajdata/dataset_specific/yandex_shifts/yandex_shifts_dataset.py +++ b/src/trajdata/dataset_specific/yandex_shifts/yandex_shifts_dataset.py @@ -117,7 +117,7 @@ def _get_matching_scenes_from_cache( self.metadata, scene_record.name, scene_record.data_idx, - scene_record.scene_split, + scene_split, scene_record.length, scene_record.data_idx, None, # This isn't used if everything is already cached. From c9a154fa676818de2320e20d386b0e50b5300e42 Mon Sep 17 00:00:00 2001 From: Stepan Konev Date: Sat, 21 Oct 2023 03:11:52 +0200 Subject: [PATCH 4/4] add resetting index for cache table --- .../dataset_specific/yandex_shifts/yandex_shifts_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/trajdata/dataset_specific/yandex_shifts/yandex_shifts_utils.py b/src/trajdata/dataset_specific/yandex_shifts/yandex_shifts_utils.py index 1581fa9..ef38430 100644 --- a/src/trajdata/dataset_specific/yandex_shifts/yandex_shifts_utils.py +++ b/src/trajdata/dataset_specific/yandex_shifts/yandex_shifts_utils.py @@ -300,6 +300,7 @@ def extract_agent_data_from_ysdc_scene( scene_agents_data_df.groupby("agent_id", group_keys=True) .apply(lambda group: group.interpolate(limit_area="inside")) .reset_index(drop=True) + .set_index(["agent_id", "scene_ts"]) ) for agent_id in agent_id_to_type.keys(): agent_list.append(