From c19fe8c73f5d1c5e7c5a31de26075527ff319aae Mon Sep 17 00:00:00 2001 From: Xinshuo Weng Date: Fri, 20 Mar 2026 11:35:33 -0400 Subject: [PATCH 1/3] Add source change needed by running CoC auto-labeling pipeline --- .../dataset_specific/mads/__init__.py | 20 + .../dataset_specific/mads/constant.py | 118 +++ .../dataset_specific/mads/mads_dataset.py | 820 ++++++++++++++++++ .../dataset_specific/mads/mads_utils.py | 550 ++++++++++++ .../dataset_specific/mads/tar_extractor.py | 77 ++ src/trajdata/dataset_specific/pai/__init__.py | 20 + .../dataset_specific/pai/pai_dataset.py | 399 +++++++++ .../dataset_specific/scene_records.py | 15 + src/trajdata/utils/env_utils.py | 9 + 9 files changed, 2028 insertions(+) create mode 100644 src/trajdata/dataset_specific/mads/__init__.py create mode 100644 src/trajdata/dataset_specific/mads/constant.py create mode 100644 src/trajdata/dataset_specific/mads/mads_dataset.py create mode 100644 src/trajdata/dataset_specific/mads/mads_utils.py create mode 100644 src/trajdata/dataset_specific/mads/tar_extractor.py create mode 100644 src/trajdata/dataset_specific/pai/__init__.py create mode 100644 src/trajdata/dataset_specific/pai/pai_dataset.py diff --git a/src/trajdata/dataset_specific/mads/__init__.py b/src/trajdata/dataset_specific/mads/__init__.py new file mode 100644 index 0000000..f7ed469 --- /dev/null +++ b/src/trajdata/dataset_specific/mads/__init__.py @@ -0,0 +1,20 @@ +# SPDX-FileCopyrightText: Copyright (c) <2026> NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""MADS dataset package exports.""" + +from .mads_dataset import MADSDataset + +__all__ = ["MADSDataset"] diff --git a/src/trajdata/dataset_specific/mads/constant.py b/src/trajdata/dataset_specific/mads/constant.py new file mode 100644 index 0000000..34ebd25 --- /dev/null +++ b/src/trajdata/dataset_specific/mads/constant.py @@ -0,0 +1,118 @@ +# SPDX-FileCopyrightText: Copyright (c) <2026> NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Constants and enums for the MADS dataset loader.""" + +import os +from enum import Enum +from typing import Final, FrozenSet, Optional + +# Core dataset timing/config constants. +MADS_DT: Final[float] = 0.1 +EGO_LENGTH: Final[float] = 5.2993629 +EGO_WIDTH: Final[float] = 2.11311007 +EGO_HEIGHT: Final[float] = 1.34435794 + +# by default, we support internal production v2 dataset, with clipgt-2.0.0 and above. +DATA_SRC: Final[str] = "v2" + +# Allowed MADS data source tags. 
class ObstacleClassV1(Enum):
    """Obstacle classes for MADS based on NDAS `obstacle_types.proto`.

    Each member maps a symbolic obstacle class name to its numeric code.
    All values are distinct, so a numeric code identifies at most one member.
    """

    # buf:lint:ignore ENUM_ZERO_VALUE_SUFFIX
    OBSTACLE_CLASS_INVALID = 0
    # Vehicles
    OBSTACLE_CLASS_VEHICLE_UNKNOWN = 1280
    OBSTACLE_CLASS_VEHICLE_CAR = 1281
    OBSTACLE_CLASS_VEHICLE_TRUCK = 1282
    OBSTACLE_CLASS_VEHICLE_BUS = 1283
    OBSTACLE_CLASS_VEHICLE_EMERGENCY = 1284
    OBSTACLE_CLASS_VEHICLE_CONSTRUCTION = 1285
    OBSTACLE_CLASS_VEHICLE_POLICE = 1286
    OBSTACLE_CLASS_VEHICLE_SCHOOL_BUS = 1287
    # Two wheeled vehicles
    OBSTACLE_CLASS_BIKE_UNKNOWN = 2304
    OBSTACLE_CLASS_BIKE = 2305
    OBSTACLE_CLASS_BIKE_MOTOR = 2306
    # Two wheeled vehicles with rider
    OBSTACLE_CLASS_BIKE_UNKNOWN_WITH_RIDER = 2307
    OBSTACLE_CLASS_BIKE_WITH_RIDER = 2308
    OBSTACLE_CLASS_BIKE_MOTOR_WITH_RIDER = 2309
    OBSTACLE_CLASS_TRICYCLE = 2310
    # Pedestrians
    OBSTACLE_CLASS_PEDESTRIAN_UNKNOWN = 4352
    OBSTACLE_CLASS_PEDESTRIAN_ADULT = 4353
    OBSTACLE_CLASS_PEDESTRIAN_CHILD = 4354
    OBSTACLE_CLASS_PEDESTRIAN_CONSTRUCTION_WORKER = 4355
    OBSTACLE_CLASS_PEDESTRIAN_OFFICIAL = 4356
    # Animals
    OBSTACLE_CLASS_ANIMAL_UNKNOWN = 8448
    OBSTACLE_CLASS_ANIMAL_SMALL = 8449
    OBSTACLE_CLASS_ANIMAL_MEDIUM = 8450
    OBSTACLE_CLASS_ANIMAL_LARGE = 8451
    # Objects
    OBSTACLE_CLASS_OBJECT_UNKNOWN = 16384
    OBSTACLE_CLASS_OBJECT_HAZARD = 16385
    OBSTACLE_CLASS_OBJECT_OFFICIAL = 16386
    OBSTACLE_CLASS_OBJECT_CONE = 16387
    OBSTACLE_CLASS_OBJECT_CURB = 16388
    OBSTACLE_CLASS_OBJECT_BARRIER = 16389
    OBSTACLE_CLASS_OBJECT_PHYSICAL_LANE_LINE = 16390
    OBSTACLE_CLASS_OBJECT_HARD_OVERHANG = 16391
    OBSTACLE_CLASS_OBJECT_SOFT_OVERHANG = 16392
    OBSTACLE_CLASS_OBJECT_SPEED_BUMP = 16393
    OBSTACLE_CLASS_OBJECT_POT_HOLE = 16394
    OBSTACLE_CLASS_OBJECT_NEGATIVE = 16395
    OBSTACLE_CLASS_OBJECT_FORBIDDEN = 16396
    # Other
    OBSTACLE_CLASS_UNDEFINED_STATIC = 32768
    OBSTACLE_CLASS_UNDEFINED_ANIMATE = 33024
    # Virtual obstacles
    OBSTACLE_CLASS_FORBIDDEN_LANE_LINE = 513
    OBSTACLE_CLASS_CROSSABLE_LANE_LINE = 514
    OBSTACLE_CLASS_GORE_AREA = 515
    OBSTACLE_CLASS_OCCLUSION = 516

    @classmethod
    def get_enum_name(cls, value: int) -> Optional[str]:
        """Return enum member name for a numeric value, or `None` if not found.

        Args:
            value: Numeric obstacle class code. Any value that compares equal
                to a member's code (e.g. ``1280.0``) resolves to that member.

        Returns:
            The member name, or `None` when no member has this value.
        """
        # `cls(value)` uses the Enum's value->member map instead of scanning
        # every member; this method is applied once per obstacle row.
        try:
            return cls(value).name
        except ValueError:
            return None
def compute_metadata(self, env_name: str, data_dir: str) -> EnvMetadata:
    """Scan clips and build dataset-level metadata.

    Every entry of ``data_dir`` is searched recursively for a ``clip.parquet``.
    If none is found, the first ``clipgt.2.0.0*.tar`` below the entry is
    handed to the tar extractor and its ``clip.parquet`` is used instead.
    Clips without a sibling ``egomotion_estimate.parquet`` are skipped.

    Side effects: sets ``self.data_dir``, ``self.clip_start_micros``,
    ``self.clip_duration``, ``self.clip_dir`` (in shuffled order) and
    ``self.clips_w_wm``, and creates the parent directory of the latter.

    Args:
        env_name: Environment name for trajdata metadata.
        data_dir: Root data directory containing clip folders or tar files.

    Returns:
        Constructed `EnvMetadata` for this dataset instance.
    """
    # Create scene splits
    dataset_parts: List[Tuple[str, ...]] = [("all", "train", "val")]
    self.data_dir = data_dir

    expanded_data_dir = str(Path(data_dir).expanduser())
    clip_dir = dict()
    clip_start_micros = dict()
    clip_duration = dict()
    tar_extractor = TarExtractor()
    try:
        for subdir in os.listdir(expanded_data_dir):
            # support for the clipgt2.3.0 that is extracted by the prep_parquet_given_clipids.py
            clip_files = glob.glob(
                os.path.join(expanded_data_dir, subdir) + "/**/clip.parquet",
                recursive=True,
            )
            clip_file = clip_files[0] if clip_files else None

            # if not extracted in advance, we can parse the raw tar
            if clip_file is None:
                # Check for tarred version -- force clipgt.2.0.0, TODO: make it a config option
                tar_files = glob.glob(
                    os.path.join(expanded_data_dir, subdir, "**/clipgt.2.0.0*.tar"),
                    recursive=True,
                )
                if tar_files:
                    extracted_dir = tar_extractor.get_clip_dir(Path(tar_files[0]))
                    extracted_clip_file = os.path.join(extracted_dir, "clip.parquet")
                    if os.path.exists(extracted_clip_file):
                        clip_file = extracted_clip_file

            if clip_file is None:
                # Could not find any clip file
                continue

            # Skip clips missing required ego trajectory parquet to avoid worker crashes.
            clip_parent_dir = str(Path(clip_file).parent.absolute())
            ego_file = os.path.join(clip_parent_dir, "egomotion_estimate.parquet")
            if not os.path.exists(ego_file):
                continue

            df_meta = pd.read_parquet(clip_file)
            df_meta = mads_utils.df_expand_json(df_meta)

            clip_id = df_meta["key.clip_id"][0]
            clip_dir[clip_id] = clip_parent_dir
            t0 = df_meta["key.time_range.start_micros"][0]
            tf = df_meta["key.time_range.end_micros"][0]
            clip_start_micros[clip_id] = t0
            # Timestamps are in microseconds; store the duration in seconds.
            clip_duration[clip_id] = (tf - t0) / 1e6
    finally:
        # Cleanup to free the shmem. Done in `finally` so a failed parquet
        # read or glob does not leave the extractor's scratch space behind.
        tar_extractor.cleanup()

    self.clip_start_micros = clip_start_micros
    self.clip_duration = clip_duration

    # NOTE: the shuffle uses the global `random` state, so the iteration order
    # of `self.clip_dir` is only reproducible if the caller seeds `random`.
    clip_items = list(clip_dir.items())
    random.shuffle(clip_items)
    self.clip_dir = dict(clip_items)
    # Map locations keep discovery order (the unshuffled local dict).
    clip_ids = list(clip_dir.keys())

    # Every clip is assigned to the "all" split.
    scene_split_map: Dict[str, str] = {clip_id: "all" for clip_id, _ in clip_items}

    env_metadata = EnvMetadata(
        name=env_name,
        data_dir=data_dir,
        dt=MADS_DT,
        parts=cast(List[Tuple[str]], dataset_parts),
        scene_split_map=scene_split_map,
        # The location names should match the map names used in
        # the unified data cache.
        map_locations=tuple(clip_ids),
    )

    # Define the output file path (env override supported).
    self.clips_w_wm = os.getenv("MADS_CLIPS_W_WM_PATH", "./tmp/clips_w_wm.txt")
    clips_out_dir = os.path.dirname(self.clips_w_wm) or "."
    os.makedirs(clips_out_dir, exist_ok=True)

    return env_metadata
def _get_matching_scenes_from_cache(
    self,
    scene_tag: SceneTag,
    scene_desc_contains: Optional[List[str]],
    env_cache: EnvCache,
) -> List[Scene]:
    """Build `Scene` objects from the scene records stored in the env cache.

    A record is kept only when its split is contained in ``scene_tag`` and no
    description filter was requested (``scene_desc_contains is None``).

    Args:
        scene_tag: Scene tag filter.
        scene_desc_contains: Optional description filter.
        env_cache: Environment cache containing scene records.

    Returns:
        Matching `Scene` entries, in cached-record order.
    """
    cached_records = cast(
        List[MadsSceneRecord], env_cache.load_env_scenes_list(self.name)
    )
    no_desc_filter = scene_desc_contains is None

    return [
        Scene(
            self.metadata,
            name,
            location,
            split,
            int(length),  # the record's length field is converted to int here
            raw_idx,
            None,  # This isn't used if everything is already cached.
        )
        for name, location, length, split, raw_idx in cached_records
        if split in scene_tag and no_desc_filter
    ]
+ """ + + resolved_data_src = resolve_data_src(data_src) + + if resolved_data_src == "v2": + obstacle_path = os.path.join(scene_path, "object_fused.parquet") + if not os.path.exists(obstacle_path): + if verbose: + print( + f"Skipping scene {scene_name} as it does not have obstacle data.", + flush=True, + ) + dynamic_df = pd.DataFrame() + else: + dynamic_df = pd.read_parquet(obstacle_path) + try: + dynamic_df = dynamic_df[(dynamic_df["key.clip_id"] == scene_name)] + dynamic_df["object_fused.obstacle_class"] = dynamic_df[ + "object_fused.obstacle_class" + ].apply(ObstacleClassV1.get_enum_name) + dynamic_df = mads_utils.add_quaternion_from_direction( + dynamic_df, + "object_fused.obstacle_direction.x", + "object_fused.obstacle_direction.y", + "object_fused.obstacle_direction.z", + ) + dynamic_df["agent_id"] = dynamic_df[ + "object_fused.obstacle_id" + ].astype(str) + dynamic_df["length"] = ( + dynamic_df["object_fused.cuboid_3D_halfAxisXYZ.x"].values * 2 + ) + dynamic_df["width"] = ( + dynamic_df["object_fused.cuboid_3D_halfAxisXYZ.y"].values * 2 + ) + dynamic_df["height"] = ( + dynamic_df["object_fused.cuboid_3D_halfAxisXYZ.z"].values * 2 + ) + except Exception: + if verbose: + print( + f"Skipping scene {scene_name} as it does not have obstacle data.", + flush=True, + ) + dynamic_df = pd.DataFrame() + + # read ego motion data + ego_df = pd.read_parquet( + os.path.join(scene_path, f"egomotion_estimate.parquet") + ) + ego_df = mads_utils.df_expand_json(ego_df) + egomotion_key = "EgomotionEstimate" + + elif resolved_data_src == "pai": + raise NotImplementedError( + "DATA_SRC='pai' placeholder is defined but parsing is not implemented yet." 
+ ) + else: + raise ValueError( + f"{resolved_data_src} not supported in trajdata mads dataset" + ) + + assert ego_df[f"{egomotion_key}.name"].unique().size == 1 + ego_df = ego_df.assign( + length=EGO_LENGTH, + width=EGO_WIDTH, + height=EGO_HEIGHT, + type="automobile", + agent_id="ego", + source="manual", + ) + ego_df["key.label_class_id"] = ego_df[f"{egomotion_key}.name"].iat[0] + + assert ego_df[f"{egomotion_key}.name"].unique().size == 1 + ego_df = ego_df.drop(columns=[f"{egomotion_key}.name"]) + + # re-naming the fields + ego_df.rename( + columns={ + f"{egomotion_key}.location.x": "x", + f"{egomotion_key}.location.y": "y", + f"{egomotion_key}.location.z": "z", + f"{egomotion_key}.orientation.x": "qx", + f"{egomotion_key}.orientation.y": "qy", + f"{egomotion_key}.orientation.z": "qz", + f"{egomotion_key}.orientation.w": "qw", + }, + inplace=True, + ) + dynamic_df.rename( + columns={ + "Obstacle.center.x": "x", + "Obstacle.center.y": "y", + "Obstacle.center.z": "z", + "object_fused.cuboid_3D_center.x": "x", + "object_fused.cuboid_3D_center.y": "y", + "object_fused.cuboid_3D_center.z": "z", + "Obstacle.orientation.x": "qx", + "Obstacle.orientation.y": "qy", + "Obstacle.orientation.z": "qz", + "Obstacle.orientation.w": "qw", + "Obstacle.size.x": "length", + "Obstacle.size.y": "width", + "Obstacle.size.z": "height", + "Obstacle.size.dimX": "length", + "Obstacle.size.dimY": "width", + "Obstacle.size.dimZ": "height", + "Obstacle.category": "type", + "object_fused.obstacle_class": "type", + }, + inplace=True, + ) + + # timestamp + t0 = ego_df["key.timestamp_micros"].iat[0] + tf = ego_df["key.timestamp_micros"].iat[-1] + if verbose: + print("dynamic_df.empty", dynamic_df.empty) + + # Only select relevant dynamic data + dynamic_df = pd.concat([ego_df, dynamic_df]) + dynamic_df = dynamic_df[(dynamic_df["key.timestamp_micros"] <= tf)] + dynamic_df["rel_time_seconds"] = (dynamic_df["key.timestamp_micros"] - t0) / 1e6 + + interpolated_dfs: List[pd.DataFrame] = [] + for 
group_name, group_df_raw in dynamic_df.groupby( + ["key.clip_id", "key.label_class_id", "agent_id"] + ): + group_df: pd.DataFrame = cast(pd.DataFrame, group_df_raw) + group_df = group_df.sort_values(by=["rel_time_seconds"]) + duplicated = group_df.duplicated(subset=["rel_time_seconds"]) + duplicated_mask = np.asarray(duplicated, dtype=bool) + + if duplicated_mask.sum() > 0: + if verbose: + print(f"Duplicated timestamps found for agent: {group_name}") + group_df = group_df.loc[~duplicated_mask] + + min_time = group_df["rel_time_seconds"].min() + min_step = int(np.ceil(min_time / MADS_DT)) + max_time = group_df["rel_time_seconds"].max() + max_step = int(np.floor(max_time / MADS_DT)) + target_steps = np.arange(min_step, max_step + 1) + target_times = target_steps * MADS_DT + + if max_step - min_step + 1 < MIN_FRAMES: + continue + + def _interp(col_name): + x = group_df["rel_time_seconds"] + y = group_df[col_name] + if USE_CUBIC_INTERPOLATION: + return CubicSpline(x, y)(target_times) + return np.interp(target_times, x, y) + + # [N, 4] + quats_tensor = np.stack( + [group_df["qx"], group_df["qy"], group_df["qz"], group_df["qw"],], + axis=1, + ) + # Takes in scalar-last quaternion (x, y, z, w) + r = R.from_quat(quats_tensor) + slerp = Slerp(group_df["rel_time_seconds"], r) + interp_r = slerp(target_times) + headings = interp_r.as_euler("zyx", degrees=False)[:, 0] + # Scalar-last + # interp_quats = interp_r.as_quat() + + df = pd.DataFrame( + { + "key.clip_id": group_name[0], + "key.label_class_id": group_name[1], + "agent_id": group_name[2], + "scene_ts": target_steps, + "rel_time_seconds": target_times, + "x": _interp("x"), + "y": _interp("y"), + "z": _interp("z"), + # "qx": interp_quats[:, 0], + # "qy": interp_quats[:, 1], + # "qz": interp_quats[:, 2], + # "qw": interp_quats[:, 3], + "heading": headings, + # We interpolate this as this might change! + # In particular, I found this to change for manual labels. 
+ "length": _interp("length"), + "width": _interp("width"), + "height": _interp("height"), + # "length": group_df["length"].iat[0], + # "width": group_df["width"].iat[0], + # "height": group_df["height"].iat[0], + "type": group_df["type"].iat[0], + "source": group_df["source"].iat[0], + } + ) + + df["vx"] = df["x"].diff() / MADS_DT + df["vy"] = df["y"].diff() / MADS_DT + + # Calculate ego accelerations 'ax' and 'ay' + df["ax"] = df["vx"].diff() / MADS_DT + df["ay"] = df["vy"].diff() / MADS_DT + + # Replace infinity with nan for later nan handling + df["ax"] = df["ax"].replace([np.inf, -np.inf], np.nan) + df["ay"] = df["ay"].replace([np.inf, -np.inf], np.nan) + + # The first row of ax and ay is NaN, fill in values where NaN exists + df["vx"] = df["vx"].bfill().ffill() + df["vy"] = df["vy"].bfill().ffill() + df["ax"] = df["ax"].bfill().ffill() + df["ay"] = df["ay"].bfill().ffill() + + interpolated_dfs.append(df) + interpolated_df: pd.DataFrame = pd.concat(interpolated_dfs).reset_index(drop=True) + assert int(interpolated_df.duplicated(subset=["scene_ts", "agent_id"]).sum()) == 0 + + T = (tf - t0) / (1e6 * MADS_DT) + scene_ts_series: pd.Series = cast(pd.Series, interpolated_df.loc[:, "scene_ts"]) + valid_scene_ts_mask: pd.Series = cast( + pd.Series, (scene_ts_series >= 0) & (scene_ts_series <= T) + ) + valid_scene_ts_mask_np = np.asarray(valid_scene_ts_mask, dtype=bool) + interpolated_df = interpolated_df.loc[valid_scene_ts_mask_np] + + # Sort by distance to ego + ego_start = interpolated_df.query("agent_id == 'ego' and scene_ts == 0") + ego_x = ego_start["x"].iat[0] + ego_y = ego_start["y"].iat[0] + + unique_distances = set() + + def get_group_distance_to_ego(group_df: pd.DataFrame) -> pd.DataFrame: + min_ts = group_df["scene_ts"].min() + agent_start = group_df.query(f"scene_ts == {min_ts}") + agent_x = agent_start["x"].iat[0] + agent_y = agent_start["y"].iat[0] + distance_to_ego = np.sqrt((ego_x - agent_x) ** 2 + (ego_y - agent_y) ** 2) + assert distance_to_ego not 
in unique_distances + unique_distances.add(distance_to_ego) + group_df["distance_to_ego"] = distance_to_ego + return group_df + + sorted_df: pd.DataFrame = cast( + pd.DataFrame, + cast(Any, interpolated_df.groupby("agent_id")).apply( + get_group_distance_to_ego, include_groups=False + ), + ) + sorted_df = ( + sorted_df.sort_values(by=["distance_to_ego", "agent_id", "scene_ts"]) + .reset_index() + .drop(columns=["level_1"]) + ) + + # Filter out agents that are too close to each other + # Strategy: For simplicity and speed we only compare the xy locations of agents + # when they are first seen. This might ofc missing cases when the agent moves + # and the 'ghost' object appears later. + # We start by adding all agents with gt labels. Then, we iterate over the rest + # of the agents and either: + # - accept them and add their first seen location to `first_states` + # - reject them and add them to `agents_to_remove` + def get_row_first_seen(df: pd.DataFrame) -> pd.DataFrame: + return ( + cast(Any, df.groupby("agent_id", sort=False)).apply( + lambda gdf: gdf.iloc[0], include_groups=False + ) + .reset_index() + ) + + first_states = get_row_first_seen(sorted_df.query("source == 'manual'")) + + agents_to_remove: set[str] = set() + # Add agents one by one if they don't overlap + for agent_id, group_df in sorted_df.query("source != 'manual'").groupby( + "agent_id", sort=False + ): + current_agent_first_state = group_df.iloc[0] + first_states["distance_to_current_agent"] = np.sqrt( + (first_states["x"] - current_agent_first_state.x) ** 2 + + (first_states["y"] - current_agent_first_state.y) ** 2 + ) + closest_state = first_states.sort_values( + by=["distance_to_current_agent"] + ).iloc[0] + # Only conservative filtering based on dimensions: + if ( + min(closest_state.width, closest_state.height) + > closest_state.distance_to_current_agent + ): + agents_to_remove.add(agent_id) + else: + first_states = pd.concat([first_states, get_row_first_seen(group_df)]) + + agent_id_series 
def get_agent_info(
    self, scene: Scene, cache_path: Path, cache_class: Type[SceneCache]
) -> Tuple[List[AgentMetadata], List[List[AgentMetadata]]]:
    """Collect agent metadata and per-timestep presence, and cache agent data.

    Loads the scene dataframe via `get_df_from_path`, creates one
    `AgentMetadata` per ``agent_id`` group, saves the dataframe (indexed by
    ``(agent_id, scene_ts)``) through ``cache_class.save_agent_data``, and
    appends the scene name to the file at ``self.clips_w_wm`` when any
    non-ego agent is present.

    Args:
        scene: Scene whose agents are extracted; ``scene.length_timesteps``
            sizes the presence list.
        cache_path: Cache root passed to ``cache_class.save_agent_data``.
        cache_class: Scene cache implementation used for saving.

    Returns:
        ``(agent_list, agent_presence)`` where ``agent_presence[t]`` holds the
        agents observed at timestep ``t``. Both ``first_timestep`` and
        ``last_timestep`` are treated as inclusive.
    """
    sorted_df = self.get_df_from_path(self.clip_dir[scene.name], scene.name)

    def _first_scalar(values: pd.Series) -> Any:
        """Return the first entry, unwrapping a list to its first element."""
        first = values.iloc[0]
        return first[0] if isinstance(first, list) else first

    # Highest valid index into `agent_presence`.
    max_timestep = scene.length_timesteps - 1

    contain_obstacles: bool = False
    agent_list: List[AgentMetadata] = []
    agent_presence: List[List[AgentMetadata]] = [
        [] for _ in range(scene.length_timesteps)
    ]
    for agent_id, frames in sorted_df.groupby("agent_id", sort=False)[
        ["scene_ts", "type", "length", "width", "height"]
    ]:
        all_frame_ids = frames["scene_ts"]

        start_frame = int(all_frame_ids.iat[0])
        # XW: added for v0tar data, sometimes the clip duration is short.
        # Clamp to the last valid index (not the length) so that the
        # inclusive loop below can never index past `agent_presence`.
        last_frame = min(int(all_frame_ids.iat[-1]), max_timestep)

        if agent_id != "ego":
            contain_obstacles = True

        agent_metadata = AgentMetadata(
            name=agent_id,
            agent_type=mads_utils.mads_type_to_unified_type(frames["type"].iloc[0]),
            first_timestep=start_frame,
            last_timestep=last_frame,
            extent=FixedExtent(
                length=_first_scalar(frames["length"]),
                width=_first_scalar(frames["width"]),
                height=_first_scalar(frames["height"]),
            ),
        )

        agent_list.append(agent_metadata)
        # `last_timestep` is the last observed scene_ts, so it must be
        # included; `range(first, last)` silently dropped that final frame.
        for frame in range(
            agent_metadata.first_timestep, agent_metadata.last_timestep + 1
        ):
            agent_presence[frame].append(agent_metadata)

    sorted_df = sorted_df.set_index(["agent_id", "scene_ts"])

    cache_class.save_agent_data(
        sorted_df, cache_path, scene,
    )

    # also save clip ids that contain WM data
    if contain_obstacles:
        with open(self.clips_w_wm, "a", encoding="utf-8") as f:
            f.write(f"{scene.name}\n")

    return agent_list, agent_presence
def cache_maps(
    self,
    cache_path: Path,
    map_cache_class: Type[SceneCache],
    map_params: Dict[str, Any],
    resume: bool = True,
) -> None:
    """Cache the map of every known clip.

    Args:
        cache_path: Root directory of the trajdata cache.
        map_cache_class: Cache implementation forwarded to `cache_map`.
        map_params: Map parameters; ``num_workers`` (default 0) selects the
            parallel path when greater than 1, and ``px_per_m`` is read for
            the progress-bar label on the serial path.
        resume: When true, clips whose ``<clip_id>_4.00px_m.dill`` file
            already exists under ``<cache_path>/<env name>/maps`` are skipped.
    """

    def _map_is_missing(clip_id: str) -> bool:
        """Tell whether the cached map file for ``clip_id`` is absent."""
        cached_map = os.path.join(
            cache_path, self.metadata.name, "maps", f"{clip_id}_4.00px_m.dill"
        )
        return not os.path.exists(cached_map)

    # select the ones that are not finished
    if resume:
        pending_clips = [cid for cid in self.clip_dir if _map_is_missing(cid)]
    else:
        pending_clips = self.clip_dir.keys()

    worker_count: int = map_params.get("num_workers", 0)
    if worker_count > 1:
        cache_one_map = partial(
            self.cache_map,
            cache_path=cache_path,
            map_cache_class=map_cache_class,
            map_params=map_params,
        )
        parallel_apply(cache_one_map, pending_clips, num_workers=worker_count)
    else:
        progress_label = (
            f"Caching {self.name} Maps at {map_params['px_per_m']:.2f} px/m"
        )
        for clip_id in tqdm(pending_clips, desc=progress_label, position=0):
            self.cache_map(clip_id, cache_path, map_cache_class, map_params)

    print("Caching Map finished", flush=True)
"float64"]) + ego_df_any = cast(Any, ego_df) + selected_df: pd.DataFrame = cast(pd.DataFrame, ego_df_any.select_dtypes(include=include_dtypes)) + object_columns: List[str] = cast(List[str], [str(c) for c in list(selected_df.columns)]) + print("\nColumns with object dtype:", object_columns) + + # Analyze each object column + for col in object_columns: + print(f"\nAnalyzing column: {col}") + + # Get unique types in the column + unique_types = ego_df[col].map(type).unique() + print("Unique data types:", unique_types) + + # Display a few sample values + sample_values = ego_df[col].dropna().sample(min(5, len(ego_df)), random_state=42) + print("Sample values:", sample_values.tolist()) + + # Show descriptive statistics + print("\n--- Descriptive Statistics ---") + print(ego_df.describe(include="all")) + + +if __name__ == "__main__": # pyright: ignore[reportUnreachableCode] + parser = argparse.ArgumentParser(description="Inspect MADS dataframe parsing") + parser.add_argument( + "--data-src", + type=str, + default=None, + choices=sorted(SUPPORTED_DATA_SRCS), + help="Override data source (default: env MADS_DATA_SRC or constant DATA_SRC).", + ) + args = parser.parse_args() + _debug_dump_scene_df(data_src=args.data_src) diff --git a/src/trajdata/dataset_specific/mads/mads_utils.py b/src/trajdata/dataset_specific/mads/mads_utils.py new file mode 100644 index 0000000..bcae2bb --- /dev/null +++ b/src/trajdata/dataset_specific/mads/mads_utils.py @@ -0,0 +1,550 @@ +# SPDX-FileCopyrightText: Copyright (c) <2026> NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
def df_expand_json(df: pd.DataFrame) -> pd.DataFrame:
    """Expand nested/json-like columns into flattened dotted columns.

    Every column of ``df`` is normalized with ``pd.json_normalize`` and the
    flattened columns are appended after the original ones. New columns are
    prefixed with the CamelCase form of the source column name
    (``dw_lane`` -> ``DwLane.``); ``key`` and ``version`` keep their name.

    Args:
        df: Input DataFrame with nested object columns.

    Returns:
        A new DataFrame holding the original and the flattened columns, with
        the same index as ``df``.
    """
    original_index = df.index
    # json_normalize numbers its rows 0..n-1. Joining on a positional index
    # pairs row i with record i even when ``df`` carries arbitrary (filtered,
    # non-contiguous or duplicated) index labels; the caller's index is
    # restored at the end.
    expanded = df.reset_index(drop=True)

    # Use explicit string column names for static type checkers.
    columns = [str(col) for col in df.columns]
    for key in columns:
        if key in ("key", "version"):
            cap_key = key
        else:
            cap_key = "".join(part.capitalize() for part in key.split("_"))
        flattened = pd.json_normalize(expanded.loc[:, key].tolist())
        expanded = expanded.join(flattened.add_prefix(f"{cap_key}."))

    expanded.index = original_index
    return expanded
+ """ + + x, y, _ = v + # compute yaw angle + yaw = math.atan2(y, x) + half = yaw * 0.5 + + w = math.cos(half) + # rotation axis is Z, so only z component non‑zero + return (0.0, 0.0, math.sin(half), w) + +def add_quaternion_from_direction(df, x_col='x', y_col='y', z_col='z'): + """ + Given a DataFrame with 3D vectors in columns x_col, y_col, z_col, + compute for each row the quaternion (w,x,y,z) that rotates +Z → (x,y,z), + and returns a new DataFrame with added columns: qw, qx, qy, qz. + """ + # apply row‑wise + qs = df.apply( + lambda row: quaternion_from_direction((row[x_col], row[y_col], row[z_col])), + axis=1, + result_type='expand' + ) + qs.columns = ['qx', 'qy', 'qz', 'qw'] + return df.join(qs) + +def mads_type_to_unified_type(mads_type: str) -> AgentType: + if mads_type.startswith("person"): + return AgentType.PEDESTRIAN + elif mads_type == "automobile": + return AgentType.VEHICLE + elif mads_type == "other_vehicle": + return AgentType.VEHICLE + elif mads_type.startswith("cycle"): + return AgentType.BICYCLE + elif mads_type.startswith("motorcycle"): + return AgentType.MOTORCYCLE + # v1 type + elif "VEHICLE" in mads_type: + return AgentType.VEHICLE + elif "PEDESTRIAN" in mads_type: + return AgentType.PEDESTRIAN + elif "BIKE_MOTOR" in mads_type: + return AgentType.MOTORCYCLE + elif "BIKE" in mads_type: + return AgentType.BICYCLE + else: + return AgentType.UNKNOWN + + +def _preprocess_lanes( + positions: np.ndarray, lane_ids: np.ndarray, +) -> Dict[str, np.ndarray]: + """Build lane-id groups by relative position label. + + Args: + positions: Lane position labels from `dw_lane.position`. + lane_ids: Lane IDs aligned with `positions`. + + Returns: + Dictionary with `ego` and optional `left`/`right` lane-id arrays. + """ + assert len(positions) == len( + lane_ids + ), "positions and lane_ids must have the same length" + + lane_info: Dict[str, np.ndarray] = {} + + # From av/avprotos/lane.proto `LanePosition`. 
+ left_indices = np.where(positions == 2)[0] + right_indices = np.where(positions == 3)[0] + ego_indices = np.where(positions == 1)[0] + + lane_info["ego"] = lane_ids[ego_indices] + + if len(left_indices) > 0: + lane_info["left"] = lane_ids[left_indices] + if len(right_indices) > 0: + lane_info["right"] = lane_ids[right_indices] + + return lane_info + + +def _process_single_lane( + lane_data: Dict[str, Any], + lane_chunks_details: pd.DataFrame, + lane_info: Dict[str, np.ndarray], + lane_ids_in_order: Sequence[int], + rotation: np.ndarray, + translation: np.ndarray, +) -> Tuple[int, Dict[str, Any]]: + """Process one lane row into the intermediate lane map structure. + + Args: + lane_data: Per-lane fields extracted from `dw_lane`. + lane_chunks_details: Lane-chunk details table. + lane_info: Lane groups with keys `ego`, and optionally `left`/`right`. + lane_ids_in_order: Lane IDs aligned to continuation indices. + rotation: 3x3 rotation matrix. + translation: 3D translation vector. + + Returns: + Tuple `(original_index, processed_lane_properties)`. 
+ """ + idx = lane_data["idx"] + + live_path_map_lane: Dict[str, Any] = { + "beginType": int(lane_data["begin_type"]), + "conf": 1.25, + "endType": int(lane_data["end_type"]), + "id": int(lane_data["current_id"]), + "isBiDir": bool(lane_data["bidirectional"]), + "isTurn": bool(lane_data["turns_allowed"]), + "laneClass": int(lane_data["lane_class"]), + "laneClassConf": 1.0, + "ts": int(lane_data["timestamp"]), + } + + lane_id = int(lane_data["current_id"]) + if lane_id in lane_info.get("ego", np.array([], dtype=np.int64)): + right_lane_id = lane_info.get("right") + left_lane_id = lane_info.get("left") + if right_lane_id is not None: + live_path_map_lane["laneChangeRightIds"] = right_lane_id.tolist() + if left_lane_id is not None: + live_path_map_lane["laneChangeLeftIds"] = left_lane_id.tolist() + + continuation_array = lane_data["continuation_array"] + if len(continuation_array) > 0: + live_path_map_lane["laneSuccessorIds"] = [ + lane_ids_in_order[idx_tmp] for idx_tmp in continuation_array + ] + + chunk_indices = lane_data["chunk_indices"] + live_path_map_lane["laneGeometry"] = _extract_lane_geometry( + lane_chunks_details, chunk_indices, rotation, translation, + ) + + # Preserve original order via index for caller-side sorting. + return (idx, live_path_map_lane) + + +def _extract_lane_geometry( + lane_chunks_details: pd.DataFrame, + chunk_indices: Sequence[int], + rotation: np.ndarray, + translation: np.ndarray, +) -> List[Dict[str, Any]]: + """Extract and transform lane geometry samples for one lane. + + Args: + lane_chunks_details: Lane chunk table for one timestamp. + chunk_indices: Indices of chunk rows to use. + rotation: 3x3 rotation matrix. + translation: 3D translation vector. + + Returns: + List of geometry dictionaries per sampled point. + """ + filtered_lane_chunks = lane_chunks_details.iloc[list(chunk_indices)] + # Check if df_expand_json has been called (clipgt-2.0.0). 
+ chunk_key = "lane_chunk" + if "LaneChunk.center" in filtered_lane_chunks: + chunk_key = "LaneChunk" + + # Keep a best-effort reference length for robust fallback arrays. + # This preserves prior behavior intent (matching other per-point arrays) + # while avoiding undefined-variable fallback paths. + reference_len = 0 + + def flatten_column(name: str) -> np.ndarray: + """Flatten a ListArray column and convert to float32.""" + nonlocal reference_len + try: + # If expanded columns are present, use them directly. + if name in filtered_lane_chunks: + col = filtered_lane_chunks[name].to_list() + else: + # clipgt-2.0.0 fallback: each cell can be a list of json objects. + col = [] + splitname = name.rsplit(".", 1) + for row in filtered_lane_chunks[splitname[0]]: + col.append(pd.json_normalize(row)[splitname[1]]) + flat = np.concatenate(col).astype(np.float32) + reference_len = max(reference_len, int(flat.shape[0])) + return flat + except (KeyError, ValueError): + # Keep historical fallback behavior intent with a safe reference length. 
+ return np.zeros(reference_len, dtype=np.float32) + + flat_cx = flatten_column(f"{chunk_key}.center.x") + flat_cy = flatten_column(f"{chunk_key}.center.y") + flat_cz = flatten_column(f"{chunk_key}.center.z") + + flat_lx = flatten_column(f"{chunk_key}.left.x") + flat_ly = flatten_column(f"{chunk_key}.left.y") + flat_lz = flatten_column(f"{chunk_key}.left.z") + + flat_rx = flatten_column(f"{chunk_key}.right.x") + flat_ry = flatten_column(f"{chunk_key}.right.y") + flat_rz = flatten_column(f"{chunk_key}.right.z") + + centers = np.stack([flat_cx, flat_cy, flat_cz], axis=1) + lefts = np.stack([flat_lx, flat_ly, flat_lz], axis=1) + rights = np.stack([flat_rx, flat_ry, flat_rz], axis=1) + + centers = centers @ rotation.T + translation + lefts = lefts @ rotation.T + translation + rights = rights @ rotation.T + translation + + def flatten_int(name: str) -> np.ndarray: + """Flatten a ListArray column and convert to int32.""" + try: + col = filtered_lane_chunks[name].to_list() + flat = np.concatenate(col).astype(np.int32) + return flat + except (KeyError, ValueError): + return np.zeros(len(centers), dtype=np.int32) + + left_color = flatten_int(f"{chunk_key}.leftColor") + left_style = flatten_int(f"{chunk_key}.leftStyle") + left_type = flatten_int(f"{chunk_key}.leftType") + right_color = flatten_int(f"{chunk_key}.rightColor") + right_style = flatten_int(f"{chunk_key}.rightStyle") + right_type = flatten_int(f"{chunk_key}.rightType") + + lane_geometry = [ + { + "centerNormal": [0.0, 0.0, 0.0], + "centerXYZ": centers[i].tolist(), + "comSpeed": 0.0, + "leftColor": int(left_color[i]), + "leftStyle": int(left_style[i]), + "leftType": int(left_type[i]), + "leftXYZ": lefts[i].tolist(), + "maxSpeed": 0.0, + "rightColor": int(right_color[i]), + "rightStyle": int(right_style[i]), + "rightType": int(right_type[i]), + "rightXYZ": rights[i].tolist(), + } + for i in range(len(centers)) + ] + return lane_geometry + + +def _process_lanes_parallel( + lane_table: pd.DataFrame, + 
lane_patch_table: pd.DataFrame, + rotation: np.ndarray, + translation: np.ndarray, +) -> List[Any]: + """Process all lanes for one timestamp, optionally in parallel. + + Args: + lane_table: `dw_lane` rows for one timestamp. + lane_patch_table: `lane_chunk` rows for one timestamp. + rotation: 3x3 rotation matrix. + translation: 3D translation vector. + + Returns: + List of per-lane dictionaries sorted by original lane index. + """ + num_lanes = len(lane_table) + + # Check if df_expand_json has been called (clipgt-2.0.0). + lane_key = "dw_lane" + if "dw_lane.laneClass" not in lane_table: + lane_key = "DwLane" + + timestamp_col = lane_table["key.timestamp_micros"].to_numpy() + lane_class_col = lane_table[f"{lane_key}.laneClass"].to_numpy() + begin_type_col = lane_table[f"{lane_key}.beginType"].to_numpy() + end_type_col = lane_table[f"{lane_key}.endType"].to_numpy() + current_id_col = lane_table[f"{lane_key}.currentId"].to_numpy() + position_col = lane_table[f"{lane_key}.position"].to_numpy() + bidirectional_col = lane_table[f"{lane_key}.bidirectional"].to_numpy() + turns_allowed_col = lane_table[f"{lane_key}.turnsAllowed"].to_numpy() + continuation_array = lane_table[ + f"{lane_key}.lgwm_lane_continuation_array" + ].to_numpy() + chunk_indices = lane_table[f"{lane_key}.chunkIndices"].to_numpy() + + lane_info = _preprocess_lanes(position_col, current_id_col) + + lane_data: List[Dict[str, Any]] = [ + { + "idx": i, + "timestamp": timestamp_col[i], + "lane_class": lane_class_col[i], + "begin_type": begin_type_col[i], + "end_type": end_type_col[i], + "current_id": current_id_col[i], + "bidirectional": bidirectional_col[i], + "turns_allowed": turns_allowed_col[i], + "continuation_array": continuation_array[i], + "chunk_indices": chunk_indices[i], + } + for i in range(num_lanes) + ] + + process_func = partial( + _process_single_lane, + lane_chunks_details=lane_patch_table, + lane_info=lane_info, + lane_ids_in_order=current_id_col, + rotation=rotation, + 
def populate_vector_map(vector_map: VectorMap, map_root: str) -> None:
    """Populate `vector_map` from MADS lane parquet files.

    The lane tables are sampled every `sample_rate` distinct timestamps. Each
    sampled lane becomes one `RoadLane` whose id is suffixed with the sampled
    frame index, so the same lane seen at two timestamps yields two map
    elements. Lanes without any geometry sample are skipped. The map extent
    is set from the bounds of all added lanes.

    Args:
        vector_map: Target trajdata vector map to populate in-place.
        map_root: Directory containing `dw_lane.parquet` and `lane_chunk.parquet`.

    Returns:
        None.

    Raises:
        FileNotFoundError: If either parquet file is missing under `map_root`.
    """
    ts_col = "key.timestamp_micros"

    dw_lane_path = os.path.join(map_root, "dw_lane.parquet")
    lane_chunk_path = os.path.join(map_root, "lane_chunk.parquet")
    if not (os.path.exists(dw_lane_path) and os.path.exists(lane_chunk_path)):
        raise FileNotFoundError(
            f"Missing lane parquet(s) under {map_root}: "
            f"dw_lane.parquet={os.path.exists(dw_lane_path)}, "
            f"lane_chunk.parquet={os.path.exists(lane_chunk_path)}"
        )

    df_dw_lane = pd.read_parquet(dw_lane_path)
    df_lane_chunk = pd.read_parquet(lane_chunk_path)

    # Check if df_expand_json needs to be called (clipgt-2.0.0).
    if ts_col not in df_dw_lane:
        df_dw_lane = df_expand_json(df_dw_lane)
    if ts_col not in df_lane_chunk:
        df_lane_chunk = df_expand_json(df_lane_chunk)

    if ts_col in df_dw_lane:
        # Keep timestamp handling explicit to avoid Series/ndarray union
        # issues in type stubs.
        ts_numeric = pd.Series(pd.to_numeric(df_dw_lane[ts_col], errors="coerce"))
        ts_valid = ts_numeric.loc[ts_numeric.notna()].to_numpy()
        ts_list: np.ndarray = np.unique(np.asarray(ts_valid, dtype=np.int64))
    else:
        ts_list = np.array([], dtype=np.int64)

    # Geometry is kept in the frame stored in the parquet files (identity).
    rotation = np.eye(3, dtype=np.float32)
    translation = np.zeros(3, dtype=np.float32)

    sample_rate = 20  # 20 * 0.1s = 2s
    polyline_sources = (
        ("center", "centerXYZ"),
        ("left_edge", "leftXYZ"),
        ("right_edge", "rightXYZ"),
    )
    relation_sources = (
        ("next_lanes", "laneSuccessorIds"),
        ("prev_lanes", "lanePredecessorIds"),
        ("adj_lanes_left", "laneChangeLeftIds"),
        ("adj_lanes_right", "laneChangeRightIds"),
    )

    lanes: Dict[str, Dict[str, Any]] = {}
    for sample_idx, ts in enumerate(ts_list[::sample_rate]):
        # Lane ids are scoped by the index of the sampled timestamp within
        # ``ts_list`` (not by the timestamp value itself).
        frame_idx = sample_idx * sample_rate
        frame_lanes = _process_lanes_parallel(
            lane_table=df_dw_lane.loc[df_dw_lane[ts_col] == ts],
            lane_patch_table=df_lane_chunk.loc[df_lane_chunk[ts_col] == ts],
            rotation=rotation,
            translation=translation,
        )

        for lane in frame_lanes:
            geometry = lane["laneGeometry"]
            if not geometry:
                # A lane without sampled points has no bounds (np.max would
                # raise on the empty array) and no polyline to add.
                continue

            entry: Dict[str, Any] = {
                field: np.asarray(
                    [point[source] for point in geometry], dtype=np.float32
                ).reshape(-1, 3)
                for field, source in polyline_sources
            }
            for field, source in relation_sources:
                entry[field] = [
                    f"{other_id}_{frame_idx}" for other_id in lane.get(source, [])
                ]
            lanes[f"{lane['id']}_{frame_idx}"] = entry

    if not lanes:
        print("No valid data available in map file")
        return

    maximum_bound: np.ndarray = np.full((3,), np.nan)
    minimum_bound: np.ndarray = np.full((3,), np.nan)

    # Creating Vectorized Map.
    for lane_id, entry in lanes.items():
        # fmax/fmin ignore NaN, so the NaN-initialised bounds are replaced by
        # the first real coordinates.
        for polyline in (entry["left_edge"], entry["right_edge"], entry["center"]):
            maximum_bound = np.fmax(maximum_bound, np.max(polyline, axis=0))
            minimum_bound = np.fmin(minimum_bound, np.min(polyline, axis=0))

        vector_map.add_map_element(
            RoadLane(
                id=lane_id,
                center=Polyline(entry["center"]).interpolate(
                    max_dist=MAX_POLYLINE_POINT_DIST
                ),
                left_edge=Polyline(entry["left_edge"]).interpolate(
                    max_dist=MAX_POLYLINE_POINT_DIST
                ),
                right_edge=Polyline(entry["right_edge"]).interpolate(
                    max_dist=MAX_POLYLINE_POINT_DIST
                ),
                next_lanes=entry["next_lanes"],
                adj_lanes_left=entry["adj_lanes_left"],
                adj_lanes_right=entry["adj_lanes_right"],
                prev_lanes=entry["prev_lanes"],
            )
        )

    # vector_map.extent is [min_x, min_y, min_z, max_x, max_y, max_z]
    vector_map.extent = np.concatenate((minimum_bound, maximum_bound))
class TarExtractor:
    """Extract one tar archive at a time into a scratch directory.

    The scratch directory lives under ``base_dir`` (``/dev/shm`` by default).
    Only the most recently requested archive is kept: asking for a different
    archive removes the previous extraction first.
    """

    def __init__(self, base_dir: pathlib.Path = pathlib.Path("/dev/shm")):
        """Initialize extractor with empty state, ready to extract a tar file.

        Args:
            base_dir: Directory under which extraction directories are created.
        """
        self._base_dir: pathlib.Path = pathlib.Path(base_dir)
        self._filename: Optional[pathlib.Path] = None
        self._extracted_tar_file: Optional[pathlib.Path] = None
        self._temp_dir: Optional[pathlib.Path] = None

    def get_clip_dir(self, tar_filepath: pathlib.Path) -> pathlib.Path:
        """Get the directory holding the extracted contents of a tar file.

        Args:
            tar_filepath: Path to the tar file to extract.

        Returns:
            Path to the extraction directory. The same directory is returned
            for repeated calls with the same tar file.

        Raises:
            RuntimeError: If no extraction directory exists after extraction.
        """
        tar_filepath = pathlib.Path(tar_filepath)

        # Return cached directory if already extracted
        if self._extracted_tar_file == tar_filepath and self._temp_dir:
            return self._temp_dir

        self._extract(tar_filepath)
        if self._temp_dir is None:
            raise RuntimeError(f"Failed to extract tar file: {tar_filepath}")
        return self._temp_dir

    def _extract(self, tar_file: pathlib.Path):
        """Extract tar to a fresh temporary directory under the base directory.

        Any previous extraction is removed first. On failure the partially
        extracted directory is removed and the error is re-raised.

        Args:
            tar_file: Path to tar file to extract.
        """
        self.cleanup()

        # Create temp directory
        clip_id = tar_file.parent.name
        untar_id = uuid.uuid4().hex[:10]
        self._temp_dir = self._base_dir / f"untar_{untar_id}_of_clip_{clip_id}"
        self._temp_dir.mkdir(parents=True, exist_ok=True)

        try:
            with tarfile.open(tar_file, "r") as tar:
                if hasattr(tarfile, "data_filter"):
                    # Refuse members that would escape the target directory
                    # (absolute paths, "..", links pointing outside).
                    tar.extractall(self._temp_dir, filter="data")
                else:
                    # NOTE: interpreters without extraction filters extract
                    # members unchecked; only use trusted archives there.
                    tar.extractall(self._temp_dir)
            self._extracted_tar_file = tar_file
        except Exception:
            self.cleanup()
            raise

    def cleanup(self):
        """Remove temporary directory and reset state."""
        # getattr: cleanup may run from __del__ on a partially built object.
        temp_dir = getattr(self, "_temp_dir", None)
        if temp_dir:
            shutil.rmtree(temp_dir, ignore_errors=True)
        self._temp_dir = None
        self._extracted_tar_file = None

    def __del__(self):
        """Clean up temporary files when object is destroyed."""
        try:
            self.cleanup()
        except Exception:
            # Best effort only: during interpreter shutdown module globals
            # may already be gone, and a finalizer must not raise.
            pass
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import math +import os +import random +import zipfile +from pathlib import Path +from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Type, cast + +import numpy as np +import pandas as pd +from scipy.spatial.transform import Rotation as R +from scipy.spatial.transform import Slerp + +from trajdata.caching import EnvCache, SceneCache +from trajdata.data_structures.agent import AgentMetadata, FixedExtent +from trajdata.data_structures.environment import EnvMetadata +from trajdata.data_structures.scene_metadata import Scene, SceneMetadata +from trajdata.data_structures.scene_tag import SceneTag +from trajdata.dataset_specific.mads import mads_utils +from trajdata.dataset_specific.mads.constant import ( + EGO_HEIGHT, + EGO_LENGTH, + EGO_WIDTH, + MADS_DT, + MIN_FRAMES, +) +from trajdata.dataset_specific.raw_dataset import RawDataset +from trajdata.dataset_specific.scene_records import PAISceneRecord + +_ZIP_MEMBER_SEP = "::" + +class PAIDataset(RawDataset): + def compute_metadata(self, env_name: str, data_dir: str) -> EnvMetadata: + dataset_parts: List[Tuple[str, ...]] = [("all", "train", "val")] + + self.data_dir = str(Path(data_dir).expanduser()) + egomotion_dir = Path(self.data_dir) / "labels" / "egomotion" + if not egomotion_dir.exists(): + raise FileNotFoundError( + f"Expected PAI egomotion directory at {egomotion_dir}." 
+ ) + + clip_dir: Dict[str, str] = {} + clip_duration: Dict[str, float] = {} + clip_start_micros: Dict[str, int] = {} + + for zip_path in sorted(egomotion_dir.glob("*.zip")): + with zipfile.ZipFile(zip_path, "r") as zf: + for member in zf.namelist(): + if not member.endswith(".parquet"): + continue + + file_name = Path(member).name + if not file_name: + continue + + clip_id = file_name.split(".")[0] + if not clip_id or clip_id in clip_dir: + continue + + clip_dir[clip_id] = f"{zip_path}{_ZIP_MEMBER_SEP}{member}" + clip_start_micros[clip_id] = 0 + clip_duration[clip_id] = 20.0 + + if not clip_dir: + raise FileNotFoundError( + f"No egomotion parquet files were found under {egomotion_dir}." + ) + + self.clip_start_micros = clip_start_micros + self.clip_duration = clip_duration + + clip_items = list(clip_dir.items()) + random.shuffle(clip_items) + self.clip_dir = dict(clip_items) + + all_clips = [clip_id for clip_id, _ in clip_items] + scene_split_map: Dict[str, str] = {clip_id: "all" for clip_id in all_clips} + + return EnvMetadata( + name=env_name, + data_dir=self.data_dir, + dt=MADS_DT, + parts=cast(List[Tuple[str]], dataset_parts), + scene_split_map=scene_split_map, + map_locations=tuple(all_clips), + ) + + def load_dataset_obj(self, verbose: bool = False) -> None: + pass + + def _get_matching_scenes_from_obj( + self, + scene_tag: SceneTag, + scene_desc_contains: Optional[List[str]], + env_cache: EnvCache, + ) -> List[SceneMetadata]: + all_scenes_list: List[PAISceneRecord] = [] + scenes_list: List[SceneMetadata] = [] + + for idx, (clip_id, _clip_path) in enumerate(self.clip_dir.items()): + scene_split = self.metadata.scene_split_map[clip_id] + scene_length: int = int(self.clip_duration[clip_id] / MADS_DT) + + if scene_length > 1: + all_scenes_list.append( + PAISceneRecord(clip_id, "pai", str(scene_length), scene_split, idx) + ) + + if (scene_split in scene_tag) and scene_desc_contains is None: + scenes_list.append( + SceneMetadata( + env_name=self.metadata.name, + 
name=clip_id, + dt=self.metadata.dt, + raw_data_idx=idx, + ) + ) + + self.cache_all_scenes_list(env_cache, cast(List[NamedTuple], all_scenes_list)) + return scenes_list + + def _get_matching_scenes_from_cache( + self, + scene_tag: SceneTag, + scene_desc_contains: Optional[List[str]], + env_cache: EnvCache, + ) -> List[Scene]: + all_scenes_list = cast( + List[PAISceneRecord], env_cache.load_env_scenes_list(self.name) + ) + + scenes_list: List[Scene] = [] + for scene_record in all_scenes_list: + scene_name, scene_location, scene_length, scene_split, data_idx = scene_record + + if scene_split in scene_tag and scene_desc_contains is None: + scenes_list.append( + Scene( + self.metadata, + scene_name, + scene_location, + scene_split, + int(scene_length), + data_idx, + None, + ) + ) + + return scenes_list + + def get_scene(self, scene_info: SceneMetadata) -> Scene: + scene_name = scene_info.name + data_idx = scene_info.raw_data_idx + scene_split: str = self.metadata.scene_split_map[scene_name] + scene_length: int = int(self.clip_duration[scene_name] / MADS_DT) + 1 + + return Scene( + self.metadata, + scene_name, + "pai", + scene_split, + scene_length, + data_idx, + None, + ) + + @staticmethod + def _load_ego_df(scene_path: str) -> pd.DataFrame: + if _ZIP_MEMBER_SEP in scene_path: + zip_path, member = scene_path.split(_ZIP_MEMBER_SEP, maxsplit=1) + with zipfile.ZipFile(zip_path, "r") as zf: + with zf.open(member, "r") as f: + return pd.read_parquet(f) + + if os.path.isdir(scene_path): + estimated = os.path.join(scene_path, "egomotion_estimate.parquet") + if os.path.exists(estimated): + return pd.read_parquet(estimated) + + egomotion = os.path.join(scene_path, "egomotion.parquet") + if os.path.exists(egomotion): + return pd.read_parquet(egomotion) + + return pd.read_parquet(scene_path) + + @staticmethod + def get_df_from_path( + scene_path: str, + scene_name: str, + verbose: bool = False, + ) -> pd.DataFrame: + ego_df = PAIDataset._load_ego_df(scene_path) + ego_df = 
mads_utils.df_expand_json(ego_df) + + timestamp_col = "timestamp" + if timestamp_col not in ego_df.columns: + timestamp_col = "key.timestamp_micros" + + if timestamp_col not in ego_df.columns: + raise KeyError( + f"Expected timestamp column in egomotion parquet for scene {scene_name}." + ) + + if "key.clip_id" not in ego_df.columns: + ego_df["key.clip_id"] = scene_name + + col_map = { + "x": "x", + "y": "y", + "z": "z", + "qx": "qx", + "qy": "qy", + "qz": "qz", + "qw": "qw", + } + prefixed_map = { + "EgomotionEstimate.location.x": "x", + "EgomotionEstimate.location.y": "y", + "EgomotionEstimate.location.z": "z", + "EgomotionEstimate.orientation.x": "qx", + "EgomotionEstimate.orientation.y": "qy", + "EgomotionEstimate.orientation.z": "qz", + "EgomotionEstimate.orientation.w": "qw", + } + + if all(key in ego_df.columns for key in col_map.keys()): + normalized = ego_df.rename(columns=col_map).copy() + elif all(key in ego_df.columns for key in prefixed_map.keys()): + normalized = ego_df.rename(columns=prefixed_map).copy() + else: + missing = [ + key + for key in [ + "x", + "y", + "z", + "qx", + "qy", + "qz", + "qw", + "EgomotionEstimate.location.x", + "EgomotionEstimate.location.y", + "EgomotionEstimate.location.z", + "EgomotionEstimate.orientation.x", + "EgomotionEstimate.orientation.y", + "EgomotionEstimate.orientation.z", + "EgomotionEstimate.orientation.w", + ] + if key not in ego_df.columns + ] + raise KeyError( + f"Could not find expected ego pose columns for scene {scene_name}. 
Missing: {missing[:6]}" + ) + + normalized = normalized.sort_values(by=[timestamp_col]).drop_duplicates( + subset=[timestamp_col], keep="first" + ) + + t0 = int(normalized[timestamp_col].iat[0]) + tf = int(normalized[timestamp_col].iat[-1]) + normalized["key.timestamp_micros"] = normalized[timestamp_col] + normalized["rel_time_seconds"] = (normalized["key.timestamp_micros"] - t0) / 1e6 + + min_time = float(normalized["rel_time_seconds"].min()) + max_time = float(normalized["rel_time_seconds"].max()) + min_step = int(math.ceil(min_time / MADS_DT)) + max_step = int(math.floor(max_time / MADS_DT)) + + target_steps = np.arange(min_step, max_step + 1) + target_times = target_steps * MADS_DT + + if target_steps.size < MIN_FRAMES: + raise ValueError( + f"Scene {scene_name} is too short after interpolation ({target_steps.size} < {MIN_FRAMES})." + ) + + def _interp(col_name: str) -> np.ndarray: + return np.interp( + target_times, + normalized["rel_time_seconds"].to_numpy(), + normalized[col_name].to_numpy(), + ) + + quats_tensor = np.stack( + [ + normalized["qx"].to_numpy(), + normalized["qy"].to_numpy(), + normalized["qz"].to_numpy(), + normalized["qw"].to_numpy(), + ], + axis=1, + ) + interp_r = Slerp(normalized["rel_time_seconds"], R.from_quat(quats_tensor))( + target_times + ) + headings = interp_r.as_euler("zyx", degrees=False)[:, 0] + + df = pd.DataFrame( + { + "key.clip_id": scene_name, + "key.label_class_id": "ego", + "agent_id": "ego", + "scene_ts": target_steps, + "rel_time_seconds": target_times, + "x": _interp("x"), + "y": _interp("y"), + "z": _interp("z"), + "heading": headings, + "length": EGO_LENGTH, + "width": EGO_WIDTH, + "height": EGO_HEIGHT, + "type": "automobile", + "source": "manual", + } + ) + + if "vx" in normalized.columns and "vy" in normalized.columns: + df["vx"] = _interp("vx") + df["vy"] = _interp("vy") + else: + df["vx"] = df["x"].diff() / MADS_DT + df["vy"] = df["y"].diff() / MADS_DT + + if "ax" in normalized.columns and "ay" in 
normalized.columns: + df["ax"] = _interp("ax") + df["ay"] = _interp("ay") + else: + df["ax"] = df["vx"].diff() / MADS_DT + df["ay"] = df["vy"].diff() / MADS_DT + + df["vx"] = df["vx"].replace([np.inf, -np.inf], np.nan).bfill().ffill() + df["vy"] = df["vy"].replace([np.inf, -np.inf], np.nan).bfill().ffill() + df["ax"] = df["ax"].replace([np.inf, -np.inf], np.nan).bfill().ffill() + df["ay"] = df["ay"].replace([np.inf, -np.inf], np.nan).bfill().ffill() + + t_horizon = (tf - t0) / (1e6 * MADS_DT) + return cast( + pd.DataFrame, + df[(df["scene_ts"] >= 0) & (df["scene_ts"] <= t_horizon)].reset_index( + drop=True + ), + ) + + def get_agent_info( + self, scene: Scene, cache_path: Path, cache_class: Type[SceneCache] + ) -> Tuple[List[AgentMetadata], List[List[AgentMetadata]]]: + ego_df = self.get_df_from_path(self.clip_dir[scene.name], scene.name) + ego_df.set_index(["agent_id", "scene_ts"], inplace=True) + + ego_metadata = AgentMetadata( + name="ego", + agent_type=mads_utils.mads_type_to_unified_type("automobile"), + first_timestep=0, + last_timestep=min(scene.length_timesteps, int(ego_df.index.get_level_values(1).max())), + extent=FixedExtent(length=EGO_LENGTH, width=EGO_WIDTH, height=EGO_HEIGHT), + ) + + agent_presence: List[List[AgentMetadata]] = [ + [] for _ in range(scene.length_timesteps) + ] + for frame in range(ego_metadata.first_timestep, ego_metadata.last_timestep): + agent_presence[frame].append(ego_metadata) + + cache_class.save_agent_data(ego_df, cache_path, scene) + return [ego_metadata], agent_presence + + def cache_map( + self, + map_name: str, + cache_path: Path, + map_cache_class: Type[SceneCache], + map_params: Dict[str, Any], + verbose: bool = False, + ) -> None: + return + + def cache_maps( + self, + cache_path: Path, + map_cache_class: Type[SceneCache], + map_params: Dict[str, Any], + resume: bool = True, + ) -> None: + return diff --git a/src/trajdata/dataset_specific/scene_records.py b/src/trajdata/dataset_specific/scene_records.py index 
class MadsSceneRecord(NamedTuple):
    """Immutable five-field record describing one MADS scene.

    Field order is part of the tuple contract:
    ``(name, location, length, split, data_idx)``.
    """

    name: str
    location: str
    # NOTE(review): typed as ``str`` here; presumably a stringified length
    # that callers convert themselves - confirm against the producers.
    length: str
    split: str
    data_idx: int


class PAISceneRecord(NamedTuple):
    """Immutable five-field record describing one PAI scene.

    Field order is part of the tuple contract:
    ``(name, location, length, split, data_idx)``.
    """

    name: str
    location: str
    # NOTE(review): typed as ``str`` here; presumably a stringified length
    # that callers convert themselves - confirm against the producers.
    length: str
    split: str
    data_idx: int
a/src/trajdata/dataset_specific/mads/__init__.py +++ b/src/trajdata/dataset_specific/mads/__init__.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +# """MADS dataset package exports.""" diff --git a/src/trajdata/dataset_specific/mads/constant.py b/src/trajdata/dataset_specific/mads/constant.py index 34ebd25..3d57f72 100644 --- a/src/trajdata/dataset_specific/mads/constant.py +++ b/src/trajdata/dataset_specific/mads/constant.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +# """Constants and enums for the MADS dataset loader.""" @@ -25,7 +26,7 @@ EGO_WIDTH: Final[float] = 2.11311007 EGO_HEIGHT: Final[float] = 1.34435794 -# by default, we support internal production v2 dataset, with clipgt-2.0.0 and above. +# by default, we support v2 data, version clipgt-2.0.0, clipgt-2.3.0 and above. DATA_SRC: Final[str] = "v2" # Allowed MADS data source tags. @@ -53,7 +54,7 @@ def resolve_data_src(cli_override: Optional[str] = None) -> str: class ObstacleClassV1(Enum): - """Obstacle classes for MADS based on NDAS `obstacle_types.proto`.""" + """Obstacle classes for MADS.""" # buf:lint:ignore ENUM_ZERO_VALUE_SUFFIX OBSTACLE_CLASS_INVALID = 0 diff --git a/src/trajdata/dataset_specific/mads/mads_dataset.py b/src/trajdata/dataset_specific/mads/mads_dataset.py index 6f8046b..33223b7 100644 --- a/src/trajdata/dataset_specific/mads/mads_dataset.py +++ b/src/trajdata/dataset_specific/mads/mads_dataset.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+# import argparse import glob @@ -124,7 +125,7 @@ def compute_metadata(self, env_name: str, data_dir: str) -> EnvMetadata: # get all clip ids self.clip_duration = clip_duration clip_items = list(clip_dir.items()) - random.shuffle(clip_items) + random.Random(42).shuffle(clip_items) self.clip_dir = dict(clip_items) clip_ids = list(clip_dir.keys()) all_clips = [clip_id for clip_id, _ in clip_items[:]] @@ -278,7 +279,7 @@ def get_scene(self, scene_info: SceneMetadata) -> Scene: ) @staticmethod - def get_df_from_path( + def get_ego_df_from_path( scene_path: str, scene_name: str, verbose: bool = False, @@ -598,7 +599,7 @@ def get_row_first_seen(df: pd.DataFrame) -> pd.DataFrame: def get_agent_info( self, scene: Scene, cache_path: Path, cache_class: Type[SceneCache] ) -> Tuple[List[AgentMetadata], List[List[AgentMetadata]]]: - sorted_df = self.get_df_from_path(self.clip_dir[scene.name], scene.name) + sorted_df = self.get_ego_df_from_path(self.clip_dir[scene.name], scene.name) contain_obstacles: bool = False agent_list: List[AgentMetadata] = [] @@ -761,12 +762,10 @@ def cache_maps( def _debug_dump_scene_df(data_src: Optional[str] = None) -> None: """Debug helper to inspect one scene dataframe when run as a script.""" - scene_path = ( - "/lustre/fsw/portfolios/nvr/users/xweng/agentdriver_alpamayo/data/new_data" - ) + scene_path = 'path/to/source/data' scene_name = "762e063d-6eb9-43ae-959c-e53af10b53f9" scene_path = os.path.join(scene_path, scene_name) - ego_df: pd.DataFrame = MADSDataset.get_df_from_path( + ego_df: pd.DataFrame = MADSDataset.get_ego_df_from_path( scene_path, scene_name, verbose=True, data_src=data_src ) diff --git a/src/trajdata/dataset_specific/mads/mads_utils.py b/src/trajdata/dataset_specific/mads/mads_utils.py index bcae2bb..6e35977 100644 --- a/src/trajdata/dataset_specific/mads/mads_utils.py +++ b/src/trajdata/dataset_specific/mads/mads_utils.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. +# import multiprocessing import os diff --git a/src/trajdata/dataset_specific/mads/tar_extractor.py b/src/trajdata/dataset_specific/mads/tar_extractor.py index 9c2204c..407ca00 100644 --- a/src/trajdata/dataset_specific/mads/tar_extractor.py +++ b/src/trajdata/dataset_specific/mads/tar_extractor.py @@ -1,4 +1,19 @@ -# Copyright (c) 2025 NVIDIA CORPORATION. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) <2026> NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + """Tar extraction utilities for component data files. Provides TarExtractor class for extracting tar archives to RAM filesystem diff --git a/src/trajdata/dataset_specific/pai/pai_dataset.py b/src/trajdata/dataset_specific/pai/pai_dataset.py index 25670ec..8385a40 100644 --- a/src/trajdata/dataset_specific/pai/pai_dataset.py +++ b/src/trajdata/dataset_specific/pai/pai_dataset.py @@ -51,7 +51,7 @@ def compute_metadata(self, env_name: str, data_dir: str) -> EnvMetadata: egomotion_dir = Path(self.data_dir) / "labels" / "egomotion" if not egomotion_dir.exists(): raise FileNotFoundError( - f"Expected PAI egomotion directory at {egomotion_dir}." + f"Expected PhysicalAI-AV egomotion directory at {egomotion_dir}." 
) clip_dir: Dict[str, str] = {} @@ -85,7 +85,7 @@ def compute_metadata(self, env_name: str, data_dir: str) -> EnvMetadata: self.clip_duration = clip_duration clip_items = list(clip_dir.items()) - random.shuffle(clip_items) + random.Random(42).shuffle(clip_items) self.clip_dir = dict(clip_items) all_clips = [clip_id for clip_id, _ in clip_items] @@ -199,7 +199,7 @@ def _load_ego_df(scene_path: str) -> pd.DataFrame: return pd.read_parquet(scene_path) @staticmethod - def get_df_from_path( + def get_ego_df_from_path( scene_path: str, scene_name: str, verbose: bool = False, @@ -243,24 +243,10 @@ def get_df_from_path( elif all(key in ego_df.columns for key in prefixed_map.keys()): normalized = ego_df.rename(columns=prefixed_map).copy() else: + expected_cols = [*col_map.keys(), *prefixed_map.keys()] missing = [ key - for key in [ - "x", - "y", - "z", - "qx", - "qy", - "qz", - "qw", - "EgomotionEstimate.location.x", - "EgomotionEstimate.location.y", - "EgomotionEstimate.location.z", - "EgomotionEstimate.orientation.x", - "EgomotionEstimate.orientation.y", - "EgomotionEstimate.orientation.z", - "EgomotionEstimate.orientation.w", - ] + for key in expected_cols if key not in ego_df.columns ] raise KeyError( @@ -359,7 +345,7 @@ def _interp(col_name: str) -> np.ndarray: def get_agent_info( self, scene: Scene, cache_path: Path, cache_class: Type[SceneCache] ) -> Tuple[List[AgentMetadata], List[List[AgentMetadata]]]: - ego_df = self.get_df_from_path(self.clip_dir[scene.name], scene.name) + ego_df = self.get_ego_df_from_path(self.clip_dir[scene.name], scene.name) ego_df.set_index(["agent_id", "scene_ts"], inplace=True) ego_metadata = AgentMetadata( diff --git a/src/trajdata/utils/env_utils.py b/src/trajdata/utils/env_utils.py index be12cbe..4663f58 100644 --- a/src/trajdata/utils/env_utils.py +++ b/src/trajdata/utils/env_utils.py @@ -58,10 +58,11 @@ def get_raw_dataset(dataset_name: str, data_dir: str) -> RawDataset: if "mads" in dataset_name.lower(): from 
trajdata.dataset_specific.mads import MADSDataset - return MADSDataset(dataset_name, data_dir, parallelizable=True, has_maps=False) + return MADSDataset(dataset_name, data_dir, parallelizable=True, has_maps=True) if "pai" in dataset_name.lower(): from trajdata.dataset_specific.pai import PAIDataset + return PAIDataset(dataset_name, data_dir, parallelizable=True, has_maps=False) raise ValueError(f"Dataset with name '{dataset_name}' is not supported") From de5f1e7190e4af5428a5e91c4e5644b24613671b Mon Sep 17 00:00:00 2001 From: Xinshuo Weng Date: Mon, 23 Mar 2026 23:05:10 -0400 Subject: [PATCH 3/3] address comments and fix minor bugs --- .../dataset_specific/mads/constant.py | 1 - .../dataset_specific/mads/mads_dataset.py | 74 +++--- .../dataset_specific/mads/mads_utils.py | 224 +++++++++++++++++- .../dataset_specific/pai/pai_dataset.py | 6 +- 4 files changed, 254 insertions(+), 51 deletions(-) diff --git a/src/trajdata/dataset_specific/mads/constant.py b/src/trajdata/dataset_specific/mads/constant.py index 3d57f72..9c73cc3 100644 --- a/src/trajdata/dataset_specific/mads/constant.py +++ b/src/trajdata/dataset_specific/mads/constant.py @@ -50,7 +50,6 @@ def resolve_data_src(cli_override: Optional[str] = None) -> str: # Minimum frames for an agent to be considered. 
MIN_FRAMES: Final[int] = 10 -USE_CUBIC_INTERPOLATION: Final[bool] = False class ObstacleClassV1(Enum): diff --git a/src/trajdata/dataset_specific/mads/mads_dataset.py b/src/trajdata/dataset_specific/mads/mads_dataset.py index 33223b7..b51049b 100644 --- a/src/trajdata/dataset_specific/mads/mads_dataset.py +++ b/src/trajdata/dataset_specific/mads/mads_dataset.py @@ -43,7 +43,6 @@ MADS_DT, MIN_FRAMES, SUPPORTED_DATA_SRCS, - USE_CUBIC_INTERPOLATION, ObstacleClassV1, resolve_data_src, ) @@ -259,7 +258,6 @@ def _get_matching_scenes_from_cache( def get_scene(self, scene_info: SceneMetadata) -> Scene: """Create a trajdata `Scene` from metadata.""" - # Type hinting for scene_info is not working properly in python 3.10 scene_name = scene_info.name data_idx = scene_info.raw_data_idx @@ -279,11 +277,12 @@ def get_scene(self, scene_info: SceneMetadata) -> Scene: ) @staticmethod - def get_ego_df_from_path( + def get_ego_agent_df_from_path( scene_path: str, scene_name: str, verbose: bool = False, data_src: Optional[str] = None, + use_cubic_interpolation: bool = False, ) -> pd.DataFrame: """Load and normalize ego/dynamic data into a time-aligned dataframe. @@ -292,6 +291,7 @@ def get_ego_df_from_path( scene_name: Clip ID. verbose: Whether to print debug information. data_src: Optional dataset source override (e.g., v2 or pai). + use_cubic_interpolation: Whether to use cubic spline for xyz/size interpolation. Returns: Normalized dataframe with one row per `(agent_id, scene_ts)`. 
@@ -367,8 +367,6 @@ def get_ego_df_from_path( source="manual", ) ego_df["key.label_class_id"] = ego_df[f"{egomotion_key}.name"].iat[0] - - assert ego_df[f"{egomotion_key}.name"].unique().size == 1 ego_df = ego_df.drop(columns=[f"{egomotion_key}.name"]) # re-naming the fields @@ -411,8 +409,6 @@ def get_ego_df_from_path( # timestamp t0 = ego_df["key.timestamp_micros"].iat[0] tf = ego_df["key.timestamp_micros"].iat[-1] - if verbose: - print("dynamic_df.empty", dynamic_df.empty) # Only select relevant dynamic data dynamic_df = pd.concat([ego_df, dynamic_df]) @@ -446,7 +442,7 @@ def get_ego_df_from_path( def _interp(col_name): x = group_df["rel_time_seconds"] y = group_df[col_name] - if USE_CUBIC_INTERPOLATION: + if use_cubic_interpolation: return CubicSpline(x, y)(target_times) return np.interp(target_times, x, y) @@ -460,8 +456,6 @@ def _interp(col_name): slerp = Slerp(group_df["rel_time_seconds"], r) interp_r = slerp(target_times) headings = interp_r.as_euler("zyx", degrees=False)[:, 0] - # Scalar-last - # interp_quats = interp_r.as_quat() df = pd.DataFrame( { @@ -473,19 +467,12 @@ def _interp(col_name): "x": _interp("x"), "y": _interp("y"), "z": _interp("z"), - # "qx": interp_quats[:, 0], - # "qy": interp_quats[:, 1], - # "qz": interp_quats[:, 2], - # "qw": interp_quats[:, 3], "heading": headings, # We interpolate this as this might change! # In particular, I found this to change for manual labels. "length": _interp("length"), "width": _interp("width"), "height": _interp("height"), - # "length": group_df["length"].iat[0], - # "width": group_df["width"].iat[0], - # "height": group_df["height"].iat[0], "type": group_df["type"].iat[0], "source": group_df["source"].iat[0], } @@ -494,19 +481,18 @@ def _interp(col_name): df["vx"] = df["x"].diff() / MADS_DT df["vy"] = df["y"].diff() / MADS_DT + # Clean velocity first, then derive acceleration from cleaned velocity + # to keep kinematic consistency (ax ~= dvx/dt, ay ~= dvy/dt). 
+ df["vx"] = df["vx"].replace([np.inf, -np.inf], np.nan).bfill().ffill() + df["vy"] = df["vy"].replace([np.inf, -np.inf], np.nan).bfill().ffill() + # Calculate ego accelerations 'ax' and 'ay' df["ax"] = df["vx"].diff() / MADS_DT df["ay"] = df["vy"].diff() / MADS_DT - # Replace infinity with nan for later nan handling - df["ax"] = df["ax"].replace([np.inf, -np.inf], np.nan) - df["ay"] = df["ay"].replace([np.inf, -np.inf], np.nan) - - # The first row of ax and ay is NaN, fill in values where NaN exists - df["vx"] = df["vx"].bfill().ffill() - df["vy"] = df["vy"].bfill().ffill() - df["ax"] = df["ax"].bfill().ffill() - df["ay"] = df["ay"].bfill().ffill() + # Keep finite accelerations; boundary diff NaNs default to 0. + df["ax"] = df["ax"].replace([np.inf, -np.inf], np.nan).fillna(0.0) + df["ay"] = df["ay"].replace([np.inf, -np.inf], np.nan).fillna(0.0) interpolated_dfs.append(df) interpolated_df: pd.DataFrame = pd.concat(interpolated_dfs).reset_index(drop=True) @@ -517,8 +503,7 @@ def _interp(col_name): valid_scene_ts_mask: pd.Series = cast( pd.Series, (scene_ts_series >= 0) & (scene_ts_series <= T) ) - valid_scene_ts_mask_np = np.asarray(valid_scene_ts_mask, dtype=bool) - interpolated_df = interpolated_df.loc[valid_scene_ts_mask_np] + interpolated_df = cast(pd.DataFrame, interpolated_df.loc[valid_scene_ts_mask]) # Sort by distance to ego ego_start = interpolated_df.query("agent_id == 'ego' and scene_ts == 0") @@ -552,7 +537,7 @@ def get_group_distance_to_ego(group_df: pd.DataFrame) -> pd.DataFrame: # Filter out agents that are too close to each other # Strategy: For simplicity and speed we only compare the xy locations of agents - # when they are first seen. This might ofc missing cases when the agent moves + # when they are first seen. This might miss cases where the agent moves # and the 'ghost' object appears later. # We start by adding all agents with gt labels. 
Then, we iterate over the rest # of the agents and either: @@ -599,7 +584,7 @@ def get_row_first_seen(df: pd.DataFrame) -> pd.DataFrame: def get_agent_info( self, scene: Scene, cache_path: Path, cache_class: Type[SceneCache] ) -> Tuple[List[AgentMetadata], List[List[AgentMetadata]]]: - sorted_df = self.get_ego_df_from_path(self.clip_dir[scene.name], scene.name) + sorted_df = self.get_ego_agent_df_from_path(self.clip_dir[scene.name], scene.name) contain_obstacles: bool = False agent_list: List[AgentMetadata] = [] @@ -679,8 +664,9 @@ def cache_map( ) -> None: """Cache one map into trajdata map cache if not already present.""" + px_per_m: float = float(map_params.get("px_per_m", 4.0)) save_file = os.path.join( - cache_path, self.metadata.name, "maps", f"{map_name}_4.00px_m.dill" + cache_path, self.metadata.name, "maps", f"{map_name}_{px_per_m:.2f}px_m.dill" ) if os.path.exists(save_file): if verbose: @@ -721,13 +707,15 @@ def cache_maps( resume: bool = True, ) -> None: """Cache maps for all clips, optionally skipping already cached maps.""" + px_per_m: float = float(map_params.get("px_per_m", 4.0)) + # select the ones that are not finished if resume: clip_dir_need_map = [] for clip_id in self.clip_dir.keys(): map_file = os.path.join( - cache_path, self.metadata.name, "maps", f"{clip_id}_4.00px_m.dill" + cache_path, self.metadata.name, "maps", f"{clip_id}_{px_per_m:.2f}px_m.dill" ) if not os.path.exists(map_file): clip_dir_need_map.append(clip_id) @@ -751,7 +739,7 @@ def cache_maps( else: for map_name in tqdm( clip_list, - desc=f"Caching {self.name} Maps at {map_params['px_per_m']:.2f} px/m", + desc=f"Caching {self.name} Maps at {px_per_m:.2f} px/m", position=0, ): self.cache_map(map_name, cache_path, map_cache_class, map_params) @@ -765,27 +753,27 @@ def _debug_dump_scene_df(data_src: Optional[str] = None) -> None: scene_path = 'path/to/source/data' scene_name = "762e063d-6eb9-43ae-959c-e53af10b53f9" scene_path = os.path.join(scene_path, scene_name) - ego_df: 
pd.DataFrame = MADSDataset.get_ego_df_from_path( + ego_agent_df: pd.DataFrame = MADSDataset.get_ego_agent_df_from_path( scene_path, scene_name, verbose=True, data_src=data_src ) # Display basic information print("\n--- DataFrame Shape (rows, columns) ---") - print(ego_df.shape) + print(ego_agent_df.shape) print("\n--- Column Names ---") - print(ego_df.columns.tolist()) + print(ego_agent_df.columns.tolist()) print("\n--- First 5 Rows ---") - print(ego_df.head()) + print(ego_agent_df.head()) print("\n--- DataFrame Info ---") - print(ego_df.info()) + print(ego_agent_df.info()) # Identify object columns include_dtypes = cast(Any, ["object", "int64", "float64"]) - ego_df_any = cast(Any, ego_df) - selected_df: pd.DataFrame = cast(pd.DataFrame, ego_df_any.select_dtypes(include=include_dtypes)) + ego_agent_df_any = cast(Any, ego_agent_df) + selected_df: pd.DataFrame = cast(pd.DataFrame, ego_agent_df_any.select_dtypes(include=include_dtypes)) object_columns: List[str] = cast(List[str], [str(c) for c in list(selected_df.columns)]) print("\nColumns with object dtype:", object_columns) @@ -794,16 +782,16 @@ def _debug_dump_scene_df(data_src: Optional[str] = None) -> None: print(f"\nAnalyzing column: {col}") # Get unique types in the column - unique_types = ego_df[col].map(type).unique() + unique_types = ego_agent_df[col].map(type).unique() print("Unique data types:", unique_types) # Display a few sample values - sample_values = ego_df[col].dropna().sample(min(5, len(ego_df)), random_state=42) + sample_values = ego_agent_df[col].dropna().sample(min(5, len(ego_agent_df)), random_state=42) print("Sample values:", sample_values.tolist()) # Show descriptive statistics print("\n--- Descriptive Statistics ---") - print(ego_df.describe(include="all")) + print(ego_agent_df.describe(include="all")) if __name__ == "__main__": # pyright: ignore[reportUnreachableCode] diff --git a/src/trajdata/dataset_specific/mads/mads_utils.py b/src/trajdata/dataset_specific/mads/mads_utils.py index 
def _validate_map_timestamp_sampling(ts_list: np.ndarray, map_root: str) -> None:
    """Write debug artifacts when map timestamps contain large gaps.

    The check is opt-in: nothing is computed or emitted unless the
    ``MADS_TS_DEBUG_DIR`` environment variable holds a non-blank path. When it
    does, and a raw step or an every-20th-entry step exceeds 1.5x its expected
    duration, a CSV (per-timestamp deltas) and a JSON summary are written to
    ``<MADS_TS_DEBUG_DIR>/<basename of map_root>`` and a warning is issued.

    Args:
        ts_list: Timestamps in microseconds, expected in ascending order.
        map_root: Clip directory; only its basename is used, to name the
            debug sub-directory.

    Returns:
        None. Sampling itself is never altered by this function.
    """
    # Silent by default: bail out before doing any array work.
    debug_root = (os.getenv("MADS_TS_DEBUG_DIR") or "").strip()
    if not debug_root:
        return

    expected_raw_delta_micros = 100_000
    sample_rate = 20
    expected_sampled_delta_micros = expected_raw_delta_micros * sample_rate
    raw_tolerance_micros = 500
    sampled_tolerance_micros = 10_000
    hard_gap_factor = 1.5

    # Signed 64-bit avoids wrap-around in the differences below when the
    # timestamps arrive as an unsigned integer array.
    ts = np.asarray(ts_list).astype(np.int64)
    raw_deltas = np.diff(ts)
    sampled_deltas = np.diff(ts[::sample_rate])

    raw_hard_gap_threshold = int(expected_raw_delta_micros * hard_gap_factor)
    sampled_hard_gap_threshold = int(expected_sampled_delta_micros * hard_gap_factor)
    raw_hard_gap_mask = raw_deltas > raw_hard_gap_threshold
    sampled_hard_gap_mask = sampled_deltas > sampled_hard_gap_threshold
    has_meaningful_gaps = bool(
        np.any(raw_hard_gap_mask) or np.any(sampled_hard_gap_mask)
    )
    if not has_meaningful_gaps:
        return

    raw_abs_err = np.abs(raw_deltas - expected_raw_delta_micros)
    sampled_abs_err = np.abs(sampled_deltas - expected_sampled_delta_micros)
    raw_nominal = raw_abs_err <= raw_tolerance_micros
    sampled_nominal = sampled_abs_err <= sampled_tolerance_micros

    clip_tag = os.path.basename(os.path.normpath(map_root)) or "unknown_clip"
    debug_dir = os.path.join(debug_root, clip_tag.replace(" ", "_"))

    is_sampled = np.zeros(ts.size, dtype=bool)
    is_sampled[::sample_rate] = True

    # The last timestamp has no successor, hence the trailing None.
    delta_to_next: List[Optional[int]] = [*raw_deltas.tolist(), None]
    nominal_step_to_next: List[Optional[bool]] = [*raw_nominal.tolist(), None]

    debug_df = pd.DataFrame(
        {
            "idx": np.arange(ts.size, dtype=np.int64),
            "timestamp_micros": ts,
            "delta_to_next_micros": delta_to_next,
            "is_sampled": is_sampled,
            "is_nominal_step_to_next": nominal_step_to_next,
        }
    )

    def _extreme(values: np.ndarray, reducer: Any) -> Optional[int]:
        """Return ``reducer(values)`` as an int, or None for an empty array."""
        return int(reducer(values)) if values.size > 0 else None

    summary: Dict[str, Any] = {
        "num_timestamps": int(ts.size),
        "expected_raw_delta_micros": int(expected_raw_delta_micros),
        "expected_sampled_delta_micros": int(expected_sampled_delta_micros),
        "sample_rate_steps": int(sample_rate),
        "raw_tolerance_micros": int(raw_tolerance_micros),
        "sampled_tolerance_micros": int(sampled_tolerance_micros),
        "raw_within_tolerance": bool(np.all(raw_nominal)),
        "sampled_within_tolerance": bool(np.all(sampled_nominal)),
        # Kept for output compatibility: artifacts are only ever written in
        # explicit debug mode, so this is always true.
        "force_debug": True,
        "num_raw_out_of_tolerance": int(np.sum(~raw_nominal)),
        "num_sampled_out_of_tolerance": int(np.sum(~sampled_nominal)),
        "raw_hard_gap_threshold_micros": int(raw_hard_gap_threshold),
        "sampled_hard_gap_threshold_micros": int(sampled_hard_gap_threshold),
        "num_raw_hard_gaps": int(np.sum(raw_hard_gap_mask)),
        "num_sampled_hard_gaps": int(np.sum(sampled_hard_gap_mask)),
        "has_meaningful_gaps": has_meaningful_gaps,
        "raw_delta_min_micros": _extreme(raw_deltas, np.min),
        "raw_delta_max_micros": _extreme(raw_deltas, np.max),
        "sampled_delta_min_micros": _extreme(sampled_deltas, np.min),
        "sampled_delta_max_micros": _extreme(sampled_deltas, np.max),
    }

    try:
        os.makedirs(debug_dir, exist_ok=True)
        debug_df.to_csv(os.path.join(debug_dir, "mads_ts_debug.csv"), index=False)
        summary_json_path = os.path.join(debug_dir, "mads_ts_debug_summary.json")
        with open(summary_json_path, "w", encoding="utf-8") as f:
            json.dump(summary, f, indent=2, sort_keys=True)
    except OSError as exc:
        # Best effort: a read-only or missing debug location must not abort
        # map loading.
        warnings.warn(
            "Timestamp validation found issues, but failed to write debug files "
            f"to {debug_dir}: {exc}",
            stacklevel=2,
        )
    else:
        warnings.warn(
            "Detected meaningful map timestamp gaps; sampling remains unchanged "
            f"(every {sample_rate}th entry). Debug files saved under: "
            f"{debug_dir}",
            stacklevel=2,
        )


def _select_strict_2s_timestamps(
    ts_list: np.ndarray,
    *,
    grid_step_micros: int = 2_000_000,
    base_dt_micros: int = 100_000,
    max_snap_micros: Optional[int] = None,
) -> Tuple[np.ndarray, np.ndarray, Dict[str, Any]]:
    """Select timestamps on a strict time grid with nearest-neighbor snapping.

    Grid targets run from the first timestamp to the last in steps of
    ``grid_step_micros`` (2 s by default). Each target is snapped to the
    closest entry of ``ts_list``; ties go to the earlier entry. Targets whose
    snap distance exceeds ``max_snap_micros`` are dropped, and consecutive
    targets that snap to the same source timestamp are collapsed to the first.

    Args:
        ts_list: Source timestamps in microseconds. Must be sorted ascending
            (``np.searchsorted`` is used); this is not checked.
        grid_step_micros: Spacing of the target grid, in microseconds.
        base_dt_micros: Duration of one output frame, in microseconds; used to
            convert kept targets into frame indices.
        max_snap_micros: Largest accepted snap distance. ``None`` reads the
            ``MADS_MAP_STRICT_MAX_SNAP_US`` environment variable, defaulting
            to 150000.

    Returns:
        ``(selected_ts, selected_frame_idx, stats)``: the kept source
        timestamps (int64), the frame index of each kept grid target relative
        to the first timestamp, and a dict of coverage / snap-error statistics.

    Raises:
        ValueError: If a step size is not positive or the environment
            override is not an integer.
    """
    if grid_step_micros <= 0 or base_dt_micros <= 0:
        raise ValueError(
            "grid_step_micros and base_dt_micros must be positive, got "
            f"{grid_step_micros} and {base_dt_micros}"
        )

    if max_snap_micros is None:
        raw_limit = os.getenv("MADS_MAP_STRICT_MAX_SNAP_US", "150000")
        try:
            max_snap_micros = int(raw_limit)
        except ValueError as exc:
            raise ValueError(
                "MADS_MAP_STRICT_MAX_SNAP_US must be an integer number of "
                f"microseconds, got {raw_limit!r}"
            ) from exc

    stats: Dict[str, Any] = {
        "grid_step_micros": grid_step_micros,
        "max_snap_micros": max_snap_micros,
        "target_count": 0,
        "selected_count": 0,
        "coverage": 0.0,
        "max_abs_snap_error_micros": None,
        # Number of base frames per grid step (20 with the defaults).
        "sample_rate_steps": grid_step_micros // base_dt_micros,
    }

    # Signed 64-bit keeps all the differences below exact and wrap-free.
    ts = np.asarray(ts_list).astype(np.int64)

    # 1) Fast exit for empty clips.
    if ts.size == 0:
        return np.array([], dtype=np.int64), np.array([], dtype=np.int64), stats

    # 2) Build ideal grid targets from the first to the last timestamp.
    start_ts = int(ts[0])
    end_ts = int(ts[-1])
    targets = np.arange(start_ts, end_ts + 1, grid_step_micros, dtype=np.int64)

    # 3) For each target, find its nearest left/right neighbours in ts.
    right_idx = np.searchsorted(ts, targets, side="left")
    left_idx = np.clip(right_idx - 1, 0, ts.size - 1)
    right_idx_clipped = np.clip(right_idx, 0, ts.size - 1)
    left_ts = ts[left_idx]
    right_ts = ts[right_idx_clipped]

    # 4) Snap to the strictly closer neighbour (ties keep the left one).
    choose_right = (right_idx < ts.size) & (
        np.abs(right_ts - targets) < np.abs(left_ts - targets)
    )
    chosen_ts = ts[np.where(choose_right, right_idx_clipped, left_idx)]

    # 5) Keep only snaps within tolerance.
    abs_snap_err = np.abs(chosen_ts - targets)
    keep_mask = abs_snap_err <= max_snap_micros
    kept_targets = targets[keep_mask]
    kept_ts = chosen_ts[keep_mask]

    # 6) Deduplicate when neighbouring targets snapped to the same source.
    if kept_ts.size > 0:
        unique_mask = np.concatenate(([True], kept_ts[1:] != kept_ts[:-1]))
        kept_targets = kept_targets[unique_mask]
        kept_ts = kept_ts[unique_mask]

    # 7) Express kept targets in frame units and summarize the outcome.
    selected_frame_idx = ((kept_targets - start_ts) // base_dt_micros).astype(np.int64)
    stats["target_count"] = int(targets.size)
    stats["selected_count"] = int(kept_ts.size)
    stats["coverage"] = (
        float(np.sum(keep_mask) / targets.size) if targets.size > 0 else 0.0
    )
    stats["max_abs_snap_error_micros"] = (
        int(abs_snap_err[keep_mask].max()) if np.any(keep_mask) else None
    )
    return kept_ts.astype(np.int64), selected_frame_idx, stats
" + f"map_root={map_root}, stats={strict_stats}", + stacklevel=2, + ) + live_path_map_lanes: List[Tuple[int, List[Any]]] = [] - for idx, ts in enumerate(ts_list[::sample_rate]): + for frame_idx, ts in zip(selected_frame_idx, selected_ts): dw_mask = df_dw_lane["key.timestamp_micros"] == ts chunk_mask = df_lane_chunk["key.timestamp_micros"] == ts df_dw_lane_ts: pd.DataFrame = df_dw_lane.loc[dw_mask] @@ -442,7 +658,7 @@ def populate_vector_map(vector_map: VectorMap, map_root: str) -> None: live_path_map_lanes.append( ( - idx * sample_rate, + int(frame_idx), _process_lanes_parallel( lane_table=df_dw_lane_ts, lane_patch_table=df_lane_chunk_ts, diff --git a/src/trajdata/dataset_specific/pai/pai_dataset.py b/src/trajdata/dataset_specific/pai/pai_dataset.py index 8385a40..7bef67d 100644 --- a/src/trajdata/dataset_specific/pai/pai_dataset.py +++ b/src/trajdata/dataset_specific/pai/pai_dataset.py @@ -97,11 +97,11 @@ def compute_metadata(self, env_name: str, data_dir: str) -> EnvMetadata: dt=MADS_DT, parts=cast(List[Tuple[str]], dataset_parts), scene_split_map=scene_split_map, - map_locations=tuple(all_clips), + map_locations=cast(Tuple[str], tuple(all_clips)), ) def load_dataset_obj(self, verbose: bool = False) -> None: - pass + self.dataset_obj = True def _get_matching_scenes_from_obj( self, @@ -250,7 +250,7 @@ def get_ego_df_from_path( if key not in ego_df.columns ] raise KeyError( - f"Could not find expected ego pose columns for scene {scene_name}. Missing: {missing[:6]}" + f"Could not find expected ego pose columns for scene {scene_name}. Missing: {missing}" ) normalized = normalized.sort_values(by=[timestamp_col]).drop_duplicates(