diff --git a/notebooks/01_observations.ipynb b/notebooks/01_observations.ipynb index 0928501adb..a3e2fbe309 100644 --- a/notebooks/01_observations.ipynb +++ b/notebooks/01_observations.ipynb @@ -16,9 +16,16 @@ "source": [ "import numpy as np\n", "import matplotlib.pyplot as plt\n", - "from pufferlib.ocean.drive.drive import Drive\n", "from pufferlib.ocean.drive import binding\n", "import pufferlib.viz\n", + "from notebooks.notebook_utils import (\n", + " COEF_NAMES,\n", + " EGO_LABELS,\n", + " make_drive_env,\n", + " notebook_dims,\n", + " random_actions,\n", + " zero_actions,\n", + ")\n", "\n", "# --- Environment configuration ---\n", "NUM_AGENTS = 64\n", @@ -34,44 +41,15 @@ "COLLISION_BEHAVIOR = 1\n", "OFFROAD_BEHAVIOR = 1\n", "SEED = 42\n", - "MAP_DIR = \"../pufferlib/resources/drive/binaries/carla\"\n", "\n", - "# --- Observation dimensions (configurable) ---\n", + "# --- Observation dimensions ---\n", "MAX_PARTNERS = 16\n", "MAX_LANES = 32\n", "MAX_BOUNDS = 32\n", "MAX_TRAFFIC = 4\n", "\n", - "# --- Derived from binding (compile-time) ---\n", - "EGO_DIM = binding.EGO_FEATURES_JERK\n", - "NUM_COEFS = binding.NUM_REWARD_COEFS\n", - "PARTNER_F = binding.PARTNER_FEATURES\n", - "ROAD_F = binding.ROAD_FEATURES\n", - "TRAFFIC_CONTROL_F = binding.TRAFFIC_CONTROL_FEATURES\n", - "NUM_TRAFFIC_CONTROL_TYPES = binding.NUM_TRAFFIC_CONTROL_TYPES\n", - "COEF_NAMES = [\n", - " \"goal_radius\",\n", - " \"collision\",\n", - " \"offroad\",\n", - " \"comfort\",\n", - " \"lane_align\",\n", - " \"lane_center\",\n", - " \"velocity\",\n", - " \"traffic_light\",\n", - " \"center_bias\",\n", - " \"vel_align\",\n", - " \"overspeed\",\n", - " \"timestep\",\n", - " \"reverse\",\n", - " \"throttle\",\n", - " \"steer\",\n", - " \"acc\",\n", - "]\n", - "\n", - "# --- Create environment ---\n", - "env = Drive(\n", + "env, obs, info = make_drive_env(\n", " num_agents=NUM_AGENTS,\n", - " num_maps=1,\n", " min_agents_per_env=NUM_AGENTS,\n", " max_agents_per_env=NUM_AGENTS,\n", " simulation_mode=SIMULATION_MODE,\n", @@ -83,7 +61,6 @@ " reward_conditioning=REWARD_CONDITIONING,\n", " reward_randomization=REWARD_RANDOMIZATION,\n", " target_type=TARGET_TYPE,\n", - " map_dir=MAP_DIR,\n", " collision_behavior=COLLISION_BEHAVIOR,\n", " offroad_behavior=OFFROAD_BEHAVIOR,\n", " obs_slots_lane_n=MAX_LANES,\n", @@ -92,14 +69,9 @@ " obs_slots_traffic_controls_n=MAX_TRAFFIC,\n", " seed=SEED,\n", ")\n", - "obs, info = env.reset(seed=SEED)\n", + "globals().update(notebook_dims(env))\n", "\n", - "# --- Derived from env ---\n", - "MAX_TARGET = env.num_target_waypoints\n", - "TARGET_F = binding.STATIC_TARGET_FEATURES if TARGET_TYPE == \"static\" else binding.DYNAMIC_TARGET_FEATURES\n", - "TARGET_DIM = MAX_TARGET * TARGET_F\n", - "\n", - "print(f\"obs shape: {obs.shape}, dtype: {obs.dtype}\")\n", + "print(f\"env ready: {N} agents, obs={obs.shape}, act_shape={ACT_SHAPE}\")\n", "print(f\"EGO_DIM={EGO_DIM}, NUM_COEFS={NUM_COEFS}, MAX_PARTNERS={MAX_PARTNERS}, PARTNER_F={PARTNER_F}\")\n", "print(f\"MAX_LANES={MAX_LANES}, MAX_BOUNDS={MAX_BOUNDS}, ROAD_F={ROAD_F}\")\n", "print(f\"MAX_TRAFFIC={MAX_TRAFFIC}, TRAFFIC_F={TRAFFIC_CONTROL_F}\")" @@ -119,7 +91,7 @@ "outputs": [], "source": [ "# Take first step so obs are populated\n", - "actions = np.zeros([env.num_agents, 1], dtype=np.int64)\n", + "actions = zero_actions(env)\n", "\n", "obs, rew, term, trunc, info = env.step(actions)\n", "\n", @@ -157,7 +129,6 @@ "source": [ "ego, target, partners, lanes, boundaries, traffic = pufferlib.viz.unpack_obs(\n", " obs[:1],\n", - " dynamics_model=DYNAMICS_MODEL,\n", " target_type=TARGET_TYPE,\n", " reward_conditioning=REWARD_CONDITIONING,\n", " num_target_waypoints=env.num_target_waypoints,\n", @@ -398,7 +369,6 @@ "source": [ "img = pufferlib.viz.plot_observation(\n", " obs[:1],\n", - " dynamics_model=DYNAMICS_MODEL,\n", " target_type=TARGET_TYPE,\n", " reward_conditioning=True,\n", " num_target_waypoints=env.num_target_waypoints,\n", @@ -472,7 +442,7 @@ "ego_history = np.zeros((N_STEPS, EGO_DIM))\n", "\n", "for t in range(N_STEPS):\n", - " actions = np.zeros([env.num_agents, 1], dtype=np.int64)\n", + " actions = zero_actions(env)\n", " obs_t, _, _, _, _ = env.step(actions)\n", " ego_history[t] = obs_t[0, :EGO_DIM]\n", "\n", diff --git a/notebooks/02_rewards.ipynb b/notebooks/02_rewards.ipynb index 32894dda91..41a8bc6b06 100644 --- a/notebooks/02_rewards.ipynb +++ b/notebooks/02_rewards.ipynb @@ -14,12 +14,18 @@ "metadata": {}, "outputs": [], "source": [ - "import os\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", - "from pufferlib.ocean.drive.drive import Drive\n", "from pufferlib.ocean.drive import binding\n", "import pufferlib.viz\n", + "from notebooks.notebook_utils import (\n", + " COEF_NAMES,\n", + " EGO_LABELS,\n", + " make_drive_env,\n", + " notebook_dims,\n", + " random_actions,\n", + " zero_actions,\n", + ")\n", "\n", "# --- Environment configuration ---\n", "NUM_AGENTS = 64\n", @@ -35,45 +41,15 @@ "COLLISION_BEHAVIOR = 1\n", "OFFROAD_BEHAVIOR = 1\n", "SEED = 42\n", - "MAP_DIR = \"../pufferlib/resources/drive/binaries/carla\"\n", "\n", - "# --- Observation dimensions (configurable) ---\n", + "# --- Observation dimensions ---\n", "MAX_PARTNERS = 16\n", "MAX_LANES = 32\n", "MAX_BOUNDS = 32\n", "MAX_TRAFFIC = 10\n", - "MAX_STOP_SIGNS = 0\n", - "\n", - "# --- Derived from binding (compile-time) ---\n", - "EGO_DIM = binding.EGO_FEATURES_JERK\n", - "NUM_COEFS = binding.NUM_REWARD_COEFS\n", - "PARTNER_F = binding.PARTNER_FEATURES\n", - "ROAD_F = binding.ROAD_FEATURES\n", - "TRAFFIC_CONTROL_F = binding.TRAFFIC_CONTROL_FEATURES\n", - "NUM_TRAFFIC_CONTROL_TYPES = binding.NUM_TRAFFIC_CONTROL_TYPES\n", - "COEF_NAMES = [\n", - " \"goal_radius\",\n", - " \"collision\",\n", - " \"offroad\",\n", - " \"comfort\",\n", - " \"lane_align\",\n", - " \"lane_center\",\n", - " \"velocity\",\n", - " \"traffic_light\",\n", - " \"center_bias\",\n", - " \"vel_align\",\n", - " \"overspeed\",\n", - " \"timestep\",\n", - " \"reverse\",\n", - " \"throttle\",\n", - " \"steer\",\n", - " \"acc\",\n", - "]\n", "\n", - "# --- Create environment ---\n", - "env = Drive(\n", + "env, obs, info = make_drive_env(\n", " num_agents=NUM_AGENTS,\n", - " num_maps=1,\n", " min_agents_per_env=NUM_AGENTS,\n", " max_agents_per_env=NUM_AGENTS,\n", " simulation_mode=SIMULATION_MODE,\n", @@ -85,25 +61,20 @@ " reward_conditioning=REWARD_CONDITIONING,\n", " reward_randomization=REWARD_RANDOMIZATION,\n", " target_type=TARGET_TYPE,\n", - " map_dir=MAP_DIR,\n", " collision_behavior=COLLISION_BEHAVIOR,\n", " offroad_behavior=OFFROAD_BEHAVIOR,\n", " obs_slots_lane_n=MAX_LANES,\n", " obs_slots_boundary_n=MAX_BOUNDS,\n", " obs_slots_partners_n=MAX_PARTNERS,\n", + " obs_slots_traffic_controls_n=MAX_TRAFFIC,\n", " seed=SEED,\n", ")\n", - "obs, info = env.reset(seed=SEED)\n", - "\n", - "# --- Derived from env ---\n", - "MAX_TARGET = env.num_target_waypoints\n", - "TARGET_F = binding.STATIC_TARGET_FEATURES if TARGET_TYPE == \"static\" else binding.DYNAMIC_TARGET_FEATURES\n", - "TARGET_DIM = MAX_TARGET * TARGET_F\n", - "N_ACTIONS = 12\n", - "N = env.num_agents\n", - "ACT_SHAPE = (N, len(env.single_action_space.nvec))\n", + "globals().update(notebook_dims(env))\n", "\n", - "print(f\"env ready: {N} agents, obs={obs.shape}, act_shape={ACT_SHAPE}\")" + "print(f\"env ready: {N} agents, obs={obs.shape}, act_shape={ACT_SHAPE}\")\n", + "print(f\"EGO_DIM={EGO_DIM}, NUM_COEFS={NUM_COEFS}, MAX_PARTNERS={MAX_PARTNERS}, PARTNER_F={PARTNER_F}\")\n", + "print(f\"MAX_LANES={MAX_LANES}, MAX_BOUNDS={MAX_BOUNDS}, ROAD_F={ROAD_F}\")\n", + "print(f\"MAX_TRAFFIC={MAX_TRAFFIC}, TRAFFIC_F={TRAFFIC_CONTROL_F}\")" ] }, { @@ -119,7 +90,7 @@ "metadata": {}, "outputs": [], "source": [ - "actions = np.zeros(ACT_SHAPE, dtype=np.int64)\n", + "actions = zero_actions(env)\n", "obs, rew, term, trunc, info = env.step(actions)\n", "\n", "print(f\"reward shape: {rew.shape}\")\n", @@ -155,7 +126,7 @@ "terms_history = np.zeros((N_STEPS, N))\n", "\n", "for t in range(N_STEPS):\n", - " actions = np.random.randint(0, N_ACTIONS, size=ACT_SHAPE)\n", + " actions = random_actions(env)\n", " obs, rew, term, trunc, info = env.step(actions)\n", " rewards_history[t] = rew\n", " terms_history[t] = term\n", @@ -228,7 +199,7 @@ "term_rewards, trunc_rewards = [], []\n", "\n", "for t in range(N_STEPS):\n", - " actions = np.random.randint(0, N_ACTIONS, size=ACT_SHAPE)\n", + " actions = random_actions(env)\n", " obs, rew, term, trunc, info = env.step(actions)\n", " for i in range(N):\n", " if term[i]:\n", @@ -279,11 +250,12 @@ "\n", "for t in range(N_STEPS):\n", " prev_obs = obs.copy()\n", - " actions = np.random.randint(0, N_ACTIONS, size=ACT_SHAPE)\n", + " actions = random_actions(env)\n", " obs, rew, term, trunc, info = env.step(actions)\n", " for i in range(N):\n", " if rew[i] >= 0.5:\n", - " goal_dist = np.sqrt(prev_obs[i, 0] ** 2 + prev_obs[i, 1] ** 2)\n", + " target_start = EGO_DIM + NUM_COEFS\n", + " goal_dist = np.sqrt(prev_obs[i, target_start] ** 2 + prev_obs[i, target_start + 1] ** 2)\n", " goal_events.append((t, i, rew[i], goal_dist))\n", "\n", "print(f\"Goal-like events (reward >= 0.5): {len(goal_events)}\")\n", diff --git a/notebooks/03_metrics.ipynb b/notebooks/03_metrics.ipynb index 16bef2c118..30b893c781 100644 --- a/notebooks/03_metrics.ipynb +++ b/notebooks/03_metrics.ipynb @@ -16,9 +16,16 @@ "source": [ "import numpy as np\n", "import matplotlib.pyplot as plt\n", - "from pufferlib.ocean.drive.drive import Drive\n", "from pufferlib.ocean.drive import binding\n", "import pufferlib.viz\n", + "from notebooks.notebook_utils import (\n", + " COEF_NAMES,\n", + " EGO_LABELS,\n", + " make_drive_env,\n", + " notebook_dims,\n", + " random_actions,\n", + " zero_actions,\n", + ")\n", "\n", "# --- Environment configuration ---\n", "NUM_AGENTS = 64\n", @@ -34,45 +41,15 @@ "COLLISION_BEHAVIOR = 1\n", "OFFROAD_BEHAVIOR = 1\n", "SEED = 42\n", - "MAP_DIR = \"../pufferlib/resources/drive/binaries/carla\"\n", "\n", - "# --- Observation dimensions (configurable) ---\n", + "# --- Observation dimensions ---\n", "MAX_PARTNERS = 16\n", "MAX_LANES = 32\n", "MAX_BOUNDS = 32\n", "MAX_TRAFFIC = 10\n", - "MAX_STOP_SIGNS = 0\n", - "\n", - "# --- Derived from binding (compile-time) ---\n", - "EGO_DIM = binding.EGO_FEATURES_JERK\n", - "NUM_COEFS = binding.NUM_REWARD_COEFS\n", - "PARTNER_F = binding.PARTNER_FEATURES\n", - "ROAD_F = binding.ROAD_FEATURES\n", - "TRAFFIC_CONTROL_F = binding.TRAFFIC_CONTROL_FEATURES\n", - "NUM_TRAFFIC_CONTROL_TYPES = binding.NUM_TRAFFIC_CONTROL_TYPES\n", - "COEF_NAMES = [\n", - " \"goal_radius\",\n", - " \"collision\",\n", - " \"offroad\",\n", - " \"comfort\",\n", - " \"lane_align\",\n", - " \"lane_center\",\n", - " \"velocity\",\n", - " \"traffic_light\",\n", - " \"center_bias\",\n", - " \"vel_align\",\n", - " \"overspeed\",\n", - " \"timestep\",\n", - " \"reverse\",\n", - " \"throttle\",\n", - " \"steer\",\n", - " \"acc\",\n", - "]\n", "\n", - "# --- Create environment ---\n", - "env = Drive(\n", + "env, obs, info = make_drive_env(\n", " num_agents=NUM_AGENTS,\n", - " num_maps=1,\n", " min_agents_per_env=NUM_AGENTS,\n", " max_agents_per_env=NUM_AGENTS,\n", " simulation_mode=SIMULATION_MODE,\n", @@ -84,25 +61,20 @@ " reward_conditioning=REWARD_CONDITIONING,\n", " reward_randomization=REWARD_RANDOMIZATION,\n", " target_type=TARGET_TYPE,\n", - " map_dir=MAP_DIR,\n", " collision_behavior=COLLISION_BEHAVIOR,\n", " offroad_behavior=OFFROAD_BEHAVIOR,\n", " obs_slots_lane_n=MAX_LANES,\n", " obs_slots_boundary_n=MAX_BOUNDS,\n", " obs_slots_partners_n=MAX_PARTNERS,\n", + " obs_slots_traffic_controls_n=MAX_TRAFFIC,\n", " seed=SEED,\n", ")\n", - "obs, info = env.reset(seed=SEED)\n", - "\n", - "# --- Derived from env ---\n", - "MAX_TARGET = env.num_target_waypoints\n", - "TARGET_F = binding.STATIC_TARGET_FEATURES if TARGET_TYPE == \"static\" else binding.DYNAMIC_TARGET_FEATURES\n", - "TARGET_DIM = MAX_TARGET * TARGET_F\n", - "N_ACTIONS = 12\n", - "N = env.num_agents\n", - "ACT_SHAPE = (N, len(env.single_action_space.nvec))\n", + "globals().update(notebook_dims(env))\n", "\n", - "print(f\"env ready: {N} agents, act_shape={ACT_SHAPE}\")" + "print(f\"env ready: {N} agents, obs={obs.shape}, act_shape={ACT_SHAPE}\")\n", + "print(f\"EGO_DIM={EGO_DIM}, NUM_COEFS={NUM_COEFS}, MAX_PARTNERS={MAX_PARTNERS}, PARTNER_F={PARTNER_F}\")\n", + "print(f\"MAX_LANES={MAX_LANES}, MAX_BOUNDS={MAX_BOUNDS}, ROAD_F={ROAD_F}\")\n", + "print(f\"MAX_TRAFFIC={MAX_TRAFFIC}, TRAFFIC_F={TRAFFIC_CONTROL_F}\")" ] }, { @@ -119,7 +91,7 @@ "outputs": [], "source": [ "for _ in range(10):\n", - " actions = np.random.randint(0, N_ACTIONS, size=ACT_SHAPE)\n", + " actions = random_actions(env)\n", " obs, rew, term, trunc, info = env.step(actions)\n", "\n", "log = binding.vec_log(env.c_envs, N)\n", @@ -152,7 +124,7 @@ "all_truncs = np.zeros((N_STEPS, N))\n", "\n", "for t in range(N_STEPS):\n", - " actions = np.random.randint(0, N_ACTIONS, size=ACT_SHAPE)\n", + " actions = random_actions(env)\n", " obs, rew, term, trunc, info = env.step(actions)\n", " all_rewards[t] = rew\n", " all_terms[t] = term\n", @@ -242,7 +214,7 @@ "xy_history = np.zeros((TRACK_STEPS, TRACK_AGENTS, 2))\n", "\n", "for t in range(TRACK_STEPS):\n", - " actions = np.random.randint(0, N_ACTIONS, size=ACT_SHAPE)\n", + " actions = random_actions(env)\n", " env.step(actions)\n", " states = env.get_global_agent_state()\n", " for i in range(TRACK_AGENTS):\n", diff --git a/notebooks/04_training.ipynb b/notebooks/04_training.ipynb index d48c5ccadd..da78c7c702 100644 --- a/notebooks/04_training.ipynb +++ b/notebooks/04_training.ipynb @@ -14,14 +14,12 @@ "metadata": {}, "outputs": [], "source": [ - "import os\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "import torch\n", "import torch.nn.functional as F\n", - "from pufferlib.ocean.drive.drive import Drive\n", "from pufferlib.ocean.drive import binding\n", - "from pufferlib.ocean.torch import Drive as DrivePolicy\n", + "from notebooks.notebook_utils import make_drive_env, make_drive_policy, notebook_dims, random_actions, zero_actions\n", "\n", "# --- Environment configuration ---\n", "NUM_AGENTS = 64\n", @@ -37,27 +35,13 @@ "COLLISION_BEHAVIOR = 1\n", "OFFROAD_BEHAVIOR = 1\n", "SEED = 42\n", - "MAP_DIR = \"../pufferlib/resources/drive/binaries/carla\"\n", - "\n", - "# --- Observation dimensions (configurable) ---\n", "MAX_PARTNERS = 16\n", "MAX_LANES = 32\n", "MAX_BOUNDS = 32\n", "MAX_TRAFFIC = 10\n", - "MAX_STOP_SIGNS = 0\n", - "\n", - "# --- Derived from binding (compile-time) ---\n", - "EGO_DIM = binding.EGO_FEATURES_JERK\n", - "NUM_COEFS = binding.NUM_REWARD_COEFS\n", - "PARTNER_F = binding.PARTNER_FEATURES\n", - "ROAD_F = binding.ROAD_FEATURES\n", - "TRAFFIC_CONTROL_F = binding.TRAFFIC_CONTROL_FEATURES\n", - "NUM_TRAFFIC_CONTROL_TYPES = binding.NUM_TRAFFIC_CONTROL_TYPES\n", - "\n", - "# --- Create environment ---\n", - "env = Drive(\n", + "\n", + "env, obs, info = make_drive_env(\n", " num_agents=NUM_AGENTS,\n", - " num_maps=1,\n", " min_agents_per_env=NUM_AGENTS,\n", " max_agents_per_env=NUM_AGENTS,\n", " simulation_mode=SIMULATION_MODE,\n", @@ -69,39 +53,18 @@ " reward_conditioning=REWARD_CONDITIONING,\n", " reward_randomization=REWARD_RANDOMIZATION,\n", " target_type=TARGET_TYPE,\n", - " map_dir=MAP_DIR,\n", " collision_behavior=COLLISION_BEHAVIOR,\n", " offroad_behavior=OFFROAD_BEHAVIOR,\n", " obs_slots_lane_n=MAX_LANES,\n", " obs_slots_boundary_n=MAX_BOUNDS,\n", " obs_slots_partners_n=MAX_PARTNERS,\n", + " obs_slots_traffic_controls_n=MAX_TRAFFIC,\n", " seed=SEED,\n", ")\n", - "obs, info = env.reset(seed=SEED)\n", - "\n", - "# --- Derived from env ---\n", - "MAX_TARGET = env.num_target_waypoints\n", - "TARGET_F = binding.STATIC_TARGET_FEATURES if TARGET_TYPE == \"static\" else binding.DYNAMIC_TARGET_FEATURES\n", - "TARGET_DIM = MAX_TARGET * TARGET_F\n", - "N_ACTIONS = 12\n", - "N = env.num_agents\n", - "ACT_SHAPE = (N, len(env.single_action_space.nvec))\n", + "globals().update(notebook_dims(env))\n", "\n", - "# --- Policy ---\n", "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", - "policy = DrivePolicy(\n", - " env,\n", - " input_size=64,\n", - " backbone_hidden_size=128,\n", - " backbone_num_layers=1,\n", - " actor_hidden_size=128,\n", - " actor_num_layers=0,\n", - " critic_hidden_size=128,\n", - " critic_num_layers=0,\n", - " encoder_gigaflow=True,\n", - " dropout=0.0,\n", - " split_network=False,\n", - ").to(device)\n", + "policy = make_drive_policy(env, device)\n", "print(f\"Policy on {device}, params: {sum(p.numel() for p in policy.parameters()):,}\")\n", "print(f\"Action dim: {policy.atn_dim}, act_shape: {ACT_SHAPE}\")" ] @@ -138,7 +101,7 @@ "metadata": {}, "outputs": [], "source": [ - "actions = np.zeros(ACT_SHAPE, dtype=np.int64)\n", + "actions = zero_actions(env)\n", "obs, rew, term, trunc, info = env.step(actions)\n", "\n", "obs_tensor = torch.FloatTensor(obs).to(device)\n", diff --git a/notebooks/05_inference.ipynb b/notebooks/05_inference.ipynb index 434c63ea33..c92d5b66fe 100644 --- a/notebooks/05_inference.ipynb +++ b/notebooks/05_inference.ipynb @@ -16,93 +16,37 @@ "metadata": {}, "outputs": [], "source": [ - "import os, ast, glob, yaml, configparser\n", - "from collections import defaultdict\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "import torch\n", "import torch.nn.functional as F\n", + "from pufferlib.ocean.drive.drive import Drive\n", + "from pufferlib.ocean.drive import binding\n", + "from pufferlib.ocean.torch import Drive as DrivePolicy, Recurrent\n", + "import pufferlib.pytorch\n", + "from notebooks.notebook_utils import (\n", + " COEF_NAMES,\n", + " EGO_LABELS,\n", + " MAP_DIR,\n", + " load_notebook_config,\n", + " make_rnn_state,\n", + " notebook_dims,\n", + " zero_actions,\n", + ")\n", "\n", - "while not os.path.exists(\"pufferlib\") and os.getcwd() != \"/\":\n", - " os.chdir(\"..\")\n", - "print(\"CWD:\", os.getcwd())\n", - "\n", - "# --- Config ---\n", - "CHECKPOINT_PATH = \"runs2/big-bs//models/model_puffer_drive_004262.pt\"\n", + "CHECKPOINT_PATH = \"/home/o-vcharrau/Workspace/PufferDrive-Valeo/runs/tomate/models/model_puffer_drive_013100.pt\"\n", "ENV_NAME = \"puffer_drive\"\n", "\n", - "\n", - "def load_notebook_config(checkpoint_path=None, env_name=\"puffer_drive\"):\n", - " \"\"\"Load config from INI defaults, optionally overlaying checkpoint's config.yaml.\"\"\"\n", - " default_ini = \"pufferlib/config/default.ini\"\n", - " env_ini = None\n", - " for path in glob.glob(\"pufferlib/config/**/*.ini\", recursive=True):\n", - " p = configparser.ConfigParser()\n", - " p.read([default_ini, path])\n", - " if p.has_option(\"base\", \"env_name\") and env_name in p[\"base\"][\"env_name\"]:\n", - " env_ini = path\n", - " break\n", - " assert env_ini, f\"No config for {env_name}\"\n", - "\n", - " def parse_val(v):\n", - " try:\n", - " return ast.literal_eval(v)\n", - " except:\n", - " return v\n", - "\n", - " args = defaultdict(dict)\n", - " for section in p.sections():\n", - " for key in p[section]:\n", - " val = parse_val(p[section][key])\n", - " if section == \"base\":\n", - " args[key] = val\n", - " else:\n", - " args[section][key] = val\n", - "\n", - " # Overlay checkpoint config.yaml if exists\n", - " if checkpoint_path:\n", - " exp_dir = os.path.dirname(os.path.dirname(checkpoint_path))\n", - " cfg_yaml = os.path.join(exp_dir, \"config.yaml\")\n", - " if os.path.exists(cfg_yaml):\n", - " print(f\"Loading config.yaml from {cfg_yaml}\")\n", - " with open(cfg_yaml) as f:\n", - " ycfg = yaml.safe_load(f)\n", - " for section in [\"env\", \"train\", \"policy\", \"rnn\"]:\n", - " if section in ycfg and isinstance(ycfg[section], dict):\n", - " for k, v in ycfg[section].items():\n", - " args[section][k] = v\n", - "\n", - " args[\"train\"][\"use_rnn\"] = args.get(\"rnn_name\") is not None\n", - " return dict(args)\n", - "\n", - "\n", "config = load_notebook_config(CHECKPOINT_PATH, ENV_NAME)\n", - "\n", - "# --- Env ---\n", - "from pufferlib.ocean.drive.drive import Drive\n", - "from pufferlib.ocean.drive import binding\n", - "\n", - "# Override for notebook: fewer agents, single env\n", "config[\"env\"][\"num_agents\"] = 64\n", "config[\"env\"][\"num_maps\"] = 8\n", "config[\"env\"][\"eval_mode\"] = 1\n", - "config[\"env\"][\"map_dir\"] = \"pufferlib/resources/drive/binaries/carla\"\n", + "config[\"env\"][\"map_dir\"] = MAP_DIR\n", "\n", "env = Drive(**config[\"env\"])\n", "obs, info = env.reset(seed=42)\n", "N = env.num_agents\n", - "\n", - "# --- Derived from binding ---\n", - "EGO_DIM = binding.EGO_FEATURES_JERK\n", - "NUM_COEFS = binding.NUM_REWARD_COEFS\n", - "PARTNER_F = binding.PARTNER_FEATURES\n", - "ROAD_F = binding.ROAD_FEATURES\n", - "TRAFFIC_CONTROL_F = binding.TRAFFIC_CONTROL_FEATURES\n", - "NUM_TRAFFIC_CONTROL_TYPES = binding.NUM_TRAFFIC_CONTROL_TYPES\n", - "\n", - "# --- Policy ---\n", - "from pufferlib.ocean.torch import Drive as DrivePolicy, Recurrent\n", - "import pufferlib.pytorch\n", + "globals().update(notebook_dims(env))\n", "\n", "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "policy = DrivePolicy(env, **config[\"policy\"]).to(device)\n", @@ -110,18 +54,16 @@ "if use_rnn:\n", " policy = Recurrent(env, policy, **config[\"rnn\"]).to(device)\n", "\n", - "# Load checkpoint weights if provided\n", "if CHECKPOINT_PATH:\n", " sd = torch.load(CHECKPOINT_PATH, map_location=device)\n", " sd = {k.replace(\"module.\", \"\"): v for k, v in sd.items()}\n", " policy.load_state_dict(sd)\n", " print(f\"Loaded checkpoint: {CHECKPOINT_PATH}\")\n", "\n", - "\n", - "# Action shape\n", "inner_policy = policy.policy if use_rnn else policy\n", "is_continuous = inner_policy.is_continuous\n", "ACT_SHAPE = (N, len(env.single_action_space.nvec)) if not is_continuous else (N, env.single_action_space.shape[0])\n", + "state = make_rnn_state(policy, N, device) if use_rnn else None\n", "\n", "print(f\"Policy on {device}, params: {sum(p.numel() for p in policy.parameters()):,}\")\n", "print(f\"Obs shape: {obs.shape}, Action space: {env.single_action_space}\")\n", @@ -144,7 +86,7 @@ "outputs": [], "source": [ "# Take one step to get fresh obs\n", - "actions = np.zeros(ACT_SHAPE, dtype=np.int64 if not is_continuous else np.float32)\n", + "actions = zero_actions(env)\n", "obs, rew, term, trunc, info = env.step(actions)\n", "\n", "obs_tensor = torch.FloatTensor(obs).to(device)\n", @@ -217,12 +159,7 @@ "def run_rollout(env, policy, deterministic=False, horizon=HORIZON):\n", " obs, _ = env.reset(seed=42)\n", " N = env.num_agents\n", - " st = {}\n", - " if use_rnn:\n", - " st = {\n", - " \"lstm_h\": torch.zeros(N, hidden_size, device=device),\n", - " \"lstm_c\": torch.zeros(N, hidden_size, device=device),\n", - " }\n", + " st = make_rnn_state(policy, N, device) if use_rnn else None\n", "\n", " buffers = {\n", " \"obs\": np.zeros((horizon, N, obs_dim), dtype=np.float32),\n", @@ -309,7 +246,6 @@ "print(dyn_model, tgt_type, rew_cond, n_tgt_wp)\n", "img = plot_observation(\n", " sample_obs,\n", - " dynamics_model=dyn_model,\n", " target_type=tgt_type,\n", " reward_conditioning=rew_cond,\n", " num_target_waypoints=n_tgt_wp,\n", @@ -343,7 +279,6 @@ "for t in range(HORIZON):\n", " ego, *_ = unpack_obs(\n", " buf_stoch[\"obs\"][t : t + 1, TRACKED_AGENT : TRACKED_AGENT + 1][0],\n", - " dynamics_model=dyn_model,\n", " target_type=tgt_type,\n", " reward_conditioning=rew_cond,\n", " num_target_waypoints=n_tgt_wp,\n", @@ -383,7 +318,7 @@ "\n", "Obs layout (all ego-centric, normalized):\n", "- **Ego**: speed, width, length, [jerk: steering, a_long, a_lat], lane_center_dist, lane_angle, speed_limit\n", - "- **Conditioning** (if enabled): 16 reward coefs (goal_radius, collision, offroad, comfort, lane_align, lane_center, velocity, traffic_light, center_bias, vel_align, overspeed, timestep, reverse, throttle, steer, acc) + target waypoints\n", + "- **Conditioning** (if enabled): 17 reward coefs (goal_radius, goal_speed, collision, offroad, comfort, lane_align, vel_align, lane_center, center_bias, velocity, reverse, stop_line, timestep, overspeed, throttle, steer, acc) + target waypoints\n", "- **Target**: static=rel_x,rel_y,rel_z per waypoint; dynamic=rel_x,rel_y,rel_z,heading_cos,heading_sin per waypoint\n", "- **Partners** (MAX_PARTNERS x 8): rel_x, rel_y, rel_z, length, width, heading_cos, heading_sin, speed\n", "- **Lanes** (MAX_LANES x 7): rel_x, rel_y, rel_z, seg_length, seg_width, dir_cos, dir_sin\n", @@ -404,7 +339,6 @@ "sample_obs = buf_stoch[\"obs\"][sample_t : sample_t + 1, TRACKED_AGENT : TRACKED_AGENT + 1][0]\n", "ego, target, partners, lanes, boundaries, traffic_controls = unpack_obs(\n", " sample_obs,\n", - " dynamics_model=dyn_model,\n", " target_type=tgt_type,\n", " reward_conditioning=rew_cond,\n", " num_target_waypoints=n_tgt_wp,\n", @@ -415,7 +349,7 @@ ")\n", "\n", "# Also unpack conditioning manually (unpack_obs doesn't return it separately)\n", - "ego_dim = binding.EGO_FEATURES_JERK if dyn_model == \"jerk\" else binding.EGO_FEATURES_CLASSIC\n", + "ego_dim = binding.EGO_FEATURES\n", "cond_dim = binding.NUM_REWARD_COEFS if rew_cond else 0\n", "cond_obs = sample_obs[0, ego_dim : ego_dim + cond_dim] if cond_dim > 0 else None\n", "\n", @@ -448,10 +382,7 @@ "layer_stats(\"TrafficControls\", traffic_controls)\n", "\n", "# --- Ego features detail ---\n", - "if dyn_model == \"jerk\":\n", - " ego_labels = [\"speed\", \"width\", \"length\", \"steering\", \"a_long\", \"a_lat\", \"lane_center\", \"lane_align\", \"speed_limit\"]\n", - "else:\n", - " ego_labels = [\"speed\", \"width\", \"length\", \"lane_center\", \"lane_align\", \"speed_limit\"]\n", + "ego_labels = EGO_LABELS\n", "\n", "print(f\"\\n--- Ego features ---\")\n", "for i, (label, val) in enumerate(zip(ego_labels, ego)):\n", @@ -459,24 +390,7 @@ "\n", "# --- Conditioning detail ---\n", "if cond_obs is not None:\n", - " cond_labels = [\n", - " \"goal_radius\",\n", - " \"collision\",\n", - " \"offroad\",\n", - " \"comfort\",\n", - " \"lane_align\",\n", - " \"lane_center\",\n", - " \"velocity\",\n", - " \"traffic_light\",\n", - " \"center_bias\",\n", - " \"vel_align\",\n", - " \"overspeed\",\n", - " \"timestep\",\n", - " \"reverse\",\n", - " \"throttle\",\n", - " \"steer\",\n", - " \"acc\",\n", - " ]\n", + " cond_labels = COEF_NAMES\n", " print(f\"\\n--- Conditioning (reward coefs, normalized) ---\")\n", " for i, (label, val) in enumerate(zip(cond_labels, cond_obs)):\n", " print(f\" [{i:>2d}] {label:>16s} = {val:.4f}\")\n", @@ -531,7 +445,7 @@ "# --- Layer-level stats across ALL agents at sample_t ---\n", "all_obs = buf_stoch[\"obs\"][sample_t] # (N, obs_dim)\n", "\n", - "ego_dim = binding.EGO_FEATURES_JERK if dyn_model == \"jerk\" else binding.EGO_FEATURES_CLASSIC\n", + "ego_dim = binding.EGO_FEATURES\n", "cond_dim = binding.NUM_REWARD_COEFS if rew_cond else 0\n", "tgt_feat = binding.STATIC_TARGET_FEATURES if tgt_type == \"static\" else binding.DYNAMIC_TARGET_FEATURES\n", "tgt_dim = n_tgt_wp * tgt_feat\n", @@ -618,7 +532,6 @@ " ob = bufs[\"obs\"][t : t + 1, agent_idx : agent_idx + 1][0]\n", " ego, tgt, part, lane, bnd, tfc = unpack_obs(\n", " ob,\n", - " dynamics_model=dyn_model,\n", " target_type=tgt_type,\n", " reward_conditioning=rew_cond,\n", " num_target_waypoints=n_tgt_wp,\n", @@ -635,7 +548,7 @@ " n_traffic.append(np.sum(np.any(tfc != 0, axis=1)))\n", "\n", " if rew_cond:\n", - " ed = binding.EGO_FEATURES_JERK if dyn_model == \"jerk\" else binding.EGO_FEATURES_CLASSIC\n", + " ed = binding.EGO_FEATURES\n", " conds.append(ob[0, ed : ed + binding.NUM_REWARD_COEFS])\n", "\n", " return {\n", @@ -678,24 +591,7 @@ "\n", "# Conditioning heatmap over time\n", "if ts[\"cond\"] is not None:\n", - " cond_labels = [\n", - " \"goal_rad\",\n", - " \"coll\",\n", - " \"offrd\",\n", - " \"comf\",\n", - " \"l_align\",\n", - " \"l_ctr\",\n", - " \"vel\",\n", - " \"traf\",\n", - " \"c_bias\",\n", - " \"v_align\",\n", - " \"ovspd\",\n", - " \"tstep\",\n", - " \"rev\",\n", - " \"throt\",\n", - " \"steer\",\n", - " \"acc\",\n", - " ]\n", + " cond_labels = COEF_NAMES\n", " im = axes[1, 0].imshow(ts[\"cond\"].T, aspect=\"auto\", cmap=\"coolwarm\", interpolation=\"nearest\")\n", " axes[1, 0].set_yticks(range(len(cond_labels)))\n", " axes[1, 0].set_yticklabels(cond_labels, fontsize=8)\n", @@ -712,7 +608,6 @@ " ob = buf_stoch[\"obs\"][t : t + 1, TRACKED_AGENT : TRACKED_AGENT + 1][0]\n", " _, _, part, _, _, _ = unpack_obs(\n", " ob,\n", - " dynamics_model=dyn_model,\n", " target_type=tgt_type,\n", " reward_conditioning=rew_cond,\n", " num_target_waypoints=n_tgt_wp,\n", @@ -898,13 +793,10 @@ "outputs": [], "source": [ "# Ego feature distributions across all agents, pooled over full rollout\n", - "ego_dim = binding.EGO_FEATURES_JERK if dyn_model == \"jerk\" else binding.EGO_FEATURES_CLASSIC\n", + "ego_dim = binding.EGO_FEATURES\n", "all_ego = buf_stoch[\"obs\"][:, :, :ego_dim].reshape(-1, ego_dim) # (H*N, ego_dim)\n", "\n", - "if dyn_model == \"jerk\":\n", - " ego_labels = [\"speed\", \"width\", \"length\", \"steering\", \"a_long\", \"a_lat\", \"lane_center\", \"lane_align\", \"speed_limit\"]\n", - "else:\n", - " ego_labels = [\"speed\", \"width\", \"length\", \"lane_center\", \"lane_align\", \"speed_limit\"]\n", + "ego_labels = EGO_LABELS\n", "\n", "fig, axes = plt.subplots(2, len(ego_labels), figsize=(3.5 * len(ego_labels), 7))\n", "\n", @@ -943,24 +835,7 @@ " cond_end = cond_start + binding.NUM_REWARD_COEFS\n", " all_cond = buf_stoch[\"obs\"][:, :, cond_start:cond_end].reshape(-1, binding.NUM_REWARD_COEFS)\n", "\n", - " cond_labels = [\n", - " \"goal_rad\",\n", - " \"coll\",\n", - " \"offrd\",\n", - " \"comf\",\n", - " \"l_align\",\n", - " \"l_ctr\",\n", - " \"vel\",\n", - " \"traf\",\n", - " \"c_bias\",\n", - " \"v_align\",\n", - " \"ovspd\",\n", - " \"tstep\",\n", - " \"rev\",\n", - " \"throt\",\n", - " \"steer\",\n", - " \"acc\",\n", - " ]\n", + " cond_labels = COEF_NAMES\n", "\n", " fig, ax = plt.subplots(figsize=(14, 5))\n", " parts = ax.violinplot(\n", @@ -999,7 +874,7 @@ "pf = PARTNER_F\n", "\n", "# Compute slices\n", - "_ego_d = binding.EGO_FEATURES_JERK if dyn_model == \"jerk\" else binding.EGO_FEATURES_CLASSIC\n", + "_ego_d = binding.EGO_FEATURES\n", "_cond_d = binding.NUM_REWARD_COEFS if rew_cond else 0\n", "_tgt_f = binding.STATIC_TARGET_FEATURES if tgt_type == \"static\" else binding.DYNAMIC_TARGET_FEATURES\n", "_tgt_d = n_tgt_wp * _tgt_f\n", diff --git a/notebooks/06_architecture.ipynb b/notebooks/06_architecture.ipynb index bdd5810763..72c53c004a 100644 --- a/notebooks/06_architecture.ipynb +++ b/notebooks/06_architecture.ipynb @@ -19,9 +19,9 @@ "import torch\n", "import torch.nn.functional as F\n", "from torchinfo import summary\n", - "from pufferlib.ocean.drive.drive import Drive\n", "from pufferlib.ocean.drive import binding\n", - "from pufferlib.ocean.torch import Drive as DrivePolicy, DriveBackbone, Recurrent\n", + "from pufferlib.ocean.torch import Drive as DrivePolicy\n", + "from notebooks.notebook_utils import make_drive_env, notebook_dims, zero_actions\n", "\n", "# --- Environment configuration ---\n", "NUM_AGENTS = 64\n", @@ -37,14 +37,10 @@ "COLLISION_BEHAVIOR = 1\n", "OFFROAD_BEHAVIOR = 1\n", "SEED = 42\n", - "MAP_DIR = \"../pufferlib/resources/drive/binaries/carla\"\n", - "\n", - "# --- Observation dimensions ---\n", "MAX_PARTNERS = 20\n", "MAX_LANES = 100\n", "MAX_BOUNDS = 50\n", "MAX_TRAFFIC = 4\n", - "MAX_STOP_SIGNS = 0\n", "\n", "# --- Policy architecture ---\n", "INPUT_SIZE = 64\n", @@ -58,18 +54,8 @@ "ENCODER_GIGAFLOW = True\n", "DROPOUT = 0.0\n", "\n", - "# --- Derived from binding ---\n", - "EGO_DIM = binding.EGO_FEATURES_JERK\n", - "NUM_COEFS = binding.NUM_REWARD_COEFS\n", - "PARTNER_F = binding.PARTNER_FEATURES\n", - "ROAD_F = binding.ROAD_FEATURES\n", - "TRAFFIC_CONTROL_F = binding.TRAFFIC_CONTROL_FEATURES\n", - "NUM_TRAFFIC_CONTROL_TYPES = binding.NUM_TRAFFIC_CONTROL_TYPES\n", - "\n", - "# --- Create environment ---\n", - "env = Drive(\n", + "env, obs, info = make_drive_env(\n", " num_agents=NUM_AGENTS,\n", - " num_maps=1,\n", " min_agents_per_env=NUM_AGENTS,\n", " max_agents_per_env=NUM_AGENTS,\n", " simulation_mode=SIMULATION_MODE,\n", @@ -81,19 +67,15 @@ " reward_conditioning=REWARD_CONDITIONING,\n", " reward_randomization=REWARD_RANDOMIZATION,\n", " target_type=TARGET_TYPE,\n", - " map_dir=MAP_DIR,\n", " collision_behavior=COLLISION_BEHAVIOR,\n", " offroad_behavior=OFFROAD_BEHAVIOR,\n", " obs_slots_lane_n=MAX_LANES,\n", " obs_slots_boundary_n=MAX_BOUNDS,\n", " obs_slots_partners_n=MAX_PARTNERS,\n", + " obs_slots_traffic_controls_n=MAX_TRAFFIC,\n", " seed=SEED,\n", ")\n", - "obs, info = env.reset(seed=SEED)\n", - "\n", - "MAX_TARGET = env.num_target_waypoints\n", - "TARGET_F = binding.STATIC_TARGET_FEATURES if TARGET_TYPE == \"static\" else binding.DYNAMIC_TARGET_FEATURES\n", - "TARGET_DIM = MAX_TARGET * TARGET_F\n", + "globals().update(notebook_dims(env))\n", "\n", "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "policy = DrivePolicy(\n", @@ -716,7 +698,7 @@ "outputs": [], "source": [ "# Run a few steps to get diverse observations\n", - "actions = np.zeros((NUM_AGENTS, len(env.single_action_space.nvec)), dtype=np.int64)\n", + "actions = zero_actions(env)\n", "all_obs = [obs]\n", "for _ in range(20):\n", " o, _, _, _, _ = env.step(actions)\n", @@ -764,68 +746,6 @@ "plt.show()" ] }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## LSTM Wrapper Architecture" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "base_policy = DrivePolicy(\n", - " env,\n", - " input_size=INPUT_SIZE,\n", - " backbone_hidden_size=BACKBONE_HIDDEN_SIZE,\n", - " backbone_num_layers=BACKBONE_NUM_LAYERS,\n", - " actor_hidden_size=ACTOR_HIDDEN_SIZE,\n", - " actor_num_layers=ACTOR_NUM_LAYERS,\n", - " critic_hidden_size=CRITIC_HIDDEN_SIZE,\n", - " critic_num_layers=CRITIC_NUM_LAYERS,\n", - " encoder_gigaflow=ENCODER_GIGAFLOW,\n", - " dropout=DROPOUT,\n", - " split_network=SPLIT_NETWORK,\n", - ").to(device)\n", - "# LSTM input_size must match backbone_hidden_size (backbone output dim)\n", - "lstm_policy = Recurrent(env, base_policy, input_size=BACKBONE_HIDDEN_SIZE, hidden_size=BACKBONE_HIDDEN_SIZE).to(device)\n", - "\n", - "base_params = sum(p.numel() for p in base_policy.parameters())\n", - "lstm_params = sum(p.numel() for p in lstm_policy.parameters())\n", - "lstm_only = lstm_params - base_params\n", - "\n", - "print(f\"Base policy params: {base_params:>10,d}\")\n", - "print(f\"LSTM wrapper params: {lstm_params:>10,d}\")\n", - "print(f\"LSTM overhead: {lstm_only:>10,d} (+{lstm_only / base_params:.1%})\")\n", - "print()\n", - "\n", - "# torchinfo can't handle LSTMWrapper (requires state dict), so manual breakdown\n", - "print(f\"{'Component':>25s} | {'Params':>10s}\")\n", - "print(\"-\" * 40)\n", - "for name, module in lstm_policy.named_children():\n", - " n = sum(p.numel() for p in module.parameters())\n", - " print(f\"{name:>25s} | {n:>10,d}\")\n", - "print()\n", - "\n", - "# Verify forward pass works with proper state\n", - "# forward_eval uses LSTMCell which expects 2D state (batch, hidden)\n", - "with torch.no_grad():\n", - " state = {\n", - " \"lstm_h\": torch.zeros(NUM_AGENTS, BACKBONE_HIDDEN_SIZE, device=device),\n", - " \"lstm_c\": torch.zeros(NUM_AGENTS, BACKBONE_HIDDEN_SIZE, device=device),\n", - " }\n", - " actions, value = lstm_policy.forward_eval(obs_tensor, state)\n", - " if isinstance(actions, (list, tuple)):\n", - " print(f\"Forward eval OK: actions={[a.shape for a in actions]}, value={value.shape}\")\n", - " else:\n", - " print(f\"Forward eval OK: actions={actions}, value={value.shape}\")\n", - "\n", - "del base_policy, lstm_policy" - ] - }, { "cell_type": "markdown", "metadata": {}, diff --git a/notebooks/__init__.py b/notebooks/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/notebooks/notebook_utils.py b/notebooks/notebook_utils.py new file mode 100644 index 0000000000..444745a2d4 --- /dev/null +++ b/notebooks/notebook_utils.py @@ -0,0 +1,165 @@ +import os +import sys +from pathlib import Path + +import numpy as np +import yaml + +import torch + +from pufferlib.ocean.drive.drive import Drive +from pufferlib.ocean.drive import binding +from pufferlib.ocean.torch import Drive as DrivePolicy +from pufferlib.pufferl import load_config + + +ROOT = Path(__file__).resolve().parents[1] +MAP_DIR = str(ROOT / "pufferlib/resources/drive/binaries/carla") + +COEF_NAMES = [ + "goal_radius", + "goal_speed", + "collision", + "offroad", + "comfort", + "lane_align", + "vel_align", + "lane_center", + "center_bias", + "velocity", + "reverse", + "stop_line", + "timestep", + "overspeed", + "throttle", + "steer", + "acc", +] + +EGO_LABELS = [ + "speed", + "width", + "length", + "steering", + "a_long", + "a_lat", + "lane_center", + "lane_align", + "speed_limit", + "stopped", +] + +DEFAULT_ENV_KWARGS = { + "num_agents": 64, + "num_maps": 1, + "min_agents_per_env": 64, + "max_agents_per_env": 64, + "simulation_mode": "gigaflow", + "dynamics_model": "jerk", + "action_type": "discrete", + "dt": 0.1, + "scenario_length": 512, + "resample_frequency": 0, + "reward_conditioning": True, + "reward_randomization": False, + "target_type": "static", + "map_dir": MAP_DIR, + "collision_behavior": 1, + "offroad_behavior": 1, + "obs_slots_lane_n": 32, + "obs_slots_boundary_n": 32, + "obs_slots_partners_n": 16, + "obs_slots_traffic_controls_n": 10, + "seed": 42, +} + +DEFAULT_POLICY_KWARGS = { + "input_size": 64, + "backbone_hidden_size": 128, + "backbone_num_layers": 1, + "actor_hidden_size": 128, + "actor_num_layers": 0, + "critic_hidden_size": 128, + "critic_num_layers": 0, + "encoder_gigaflow": True, + "dropout": 0.0, + "split_network": False, +} + + +def drive_kwargs(**overrides): + return {**DEFAULT_ENV_KWARGS, **overrides} + + +def make_drive_env(**overrides): + kwargs = drive_kwargs(**overrides) + env = Drive(**kwargs) + obs, info = env.reset(seed=kwargs["seed"]) + return env, obs, info + + +def notebook_dims(env): + return { + "EGO_DIM": env.ego_features, + "NUM_COEFS": binding.NUM_REWARD_COEFS, + "PARTNER_F": env.partner_features, + "ROAD_F": env.road_features, + "TRAFFIC_CONTROL_F": env.traffic_control_features, + "NUM_TRAFFIC_CONTROL_TYPES": binding.NUM_TRAFFIC_CONTROL_TYPES, + "MAX_PARTNERS": env.obs_slots_partners_n, + "MAX_LANES": env.obs_slots_lane_kept, + "MAX_BOUNDS": env.obs_slots_boundary_kept, + "MAX_TRAFFIC": env.obs_slots_traffic_controls_n, + "MAX_TARGET": env.num_target_waypoints, + "TARGET_F": env.target_features, + "TARGET_DIM": env.target_dim, + "N_ACTIONS": int(env.single_action_space.nvec[0]) if hasattr(env.single_action_space, "nvec") else 1, + "N": env.num_agents, + "ACT_SHAPE": action_shape(env), + } + + +def action_shape(env): + if hasattr(env.single_action_space, "nvec"): + return (env.num_agents, len(env.single_action_space.nvec)) + return (env.num_agents, env.single_action_space.shape[0]) + + +def zero_actions(env): + dtype = np.int64 if hasattr(env.single_action_space, "nvec") else np.float32 + return np.zeros(action_shape(env), dtype=dtype) + + +def random_actions(env): + if hasattr(env.single_action_space, "nvec"): + return np.stack([np.random.randint(0, n, size=env.num_agents) for n in env.single_action_space.nvec], axis=1) + return np.random.uniform(-1.0, 1.0, size=action_shape(env)).astype(np.float32) + + +def make_drive_policy(env, device, **overrides): + return DrivePolicy(env, **{**DEFAULT_POLICY_KWARGS, **overrides}).to(device) + + +def load_notebook_config(checkpoint_path=None, env_name="puffer_drive"): + argv = sys.argv + sys.argv = [argv[0]] + config = load_config(env_name) + sys.argv = argv + + if checkpoint_path: + cfg_yaml = os.path.join(os.path.dirname(os.path.dirname(checkpoint_path)), "config.yaml") + with open(cfg_yaml) as f: + ycfg = yaml.safe_load(f) + for section in ["env", "train", "policy", "rnn"]: + if section in ycfg and isinstance(ycfg[section], dict): + config[section].update(ycfg[section]) + + config["train"]["use_rnn"] = config.get("rnn_name") is not None + return config + + +def make_rnn_state(policy, n, device): + return { + "lstm_h": torch.zeros(n, policy.hidden_size, device=device), + "lstm_c": torch.zeros(n, policy.hidden_size, device=device), + } diff --git a/pufferlib/config/ocean/drive.ini b/pufferlib/config/ocean/drive.ini index 75d08f57b8..4d621e2bea 100644 --- a/pufferlib/config/ocean/drive.ini +++ b/pufferlib/config/ocean/drive.ini @@ -242,6 +242,7 @@ env.reward_randomization = False env.termination_mode = 0 env.num_agents = 512 env.target_type = "static" +env.num_target_waypoints = 4 env.goal_speed = 3.0 env.reward_collision = 3.0 env.reward_offroad = 3.0 @@ -259,7 +260,7 @@ env.obs_dropout_lane = 0.0 env.obs_dropout_boundary = 0.0 env.obs_slots_lane_n = 80 env.obs_slots_boundary_n = 80 -eval.num_scenarios = 250 +eval.num_scenarios = 0 eval.export_episode_csv = true eval.verify_coverage = true @@ -282,16 +283,17 @@ eval.render_max_steps = 200 inherits = "validation_defaults" type = "multi_scenario" render = true +render_backend = "obs_html" render_views = ["sim_state", "bev"] env.simulation_mode = "gigaflow" env.map_dir = "pufferlib/resources/drive/binaries/carla" env.num_maps = 8 env.min_agents_per_env = 40 env.max_agents_per_env = 40 -env.scenario_length = 500 -env.resample_frequency = 500 +env.scenario_length = 1000 +env.resample_frequency = 1000 eval.render_num_scenarios = 8 -eval.render_max_steps = 300 +eval.render_max_steps = 1000 ; --------------------------------------------------------------------------- ; Driving-behaviour evaluation: nuPlan scenes labeled by scene type. Each diff --git a/pufferlib/ocean/benchmark/evaluators/base.py b/pufferlib/ocean/benchmark/evaluators/base.py index 3c14eca580..b6582b7fb0 100644 --- a/pufferlib/ocean/benchmark/evaluators/base.py +++ b/pufferlib/ocean/benchmark/evaluators/base.py @@ -3,6 +3,7 @@ import time from dataclasses import dataclass, field from typing import ClassVar +from tqdm import tqdm @dataclass @@ -466,6 +467,7 @@ def _render_pass_html(self, vecenv, policy, args) -> list: device = args["train"]["device"] html_paths = [] scenarios_done = 0 + progress = tqdm(total=num_scenarios, desc=f"{self.name} triage_html", unit="html") vec = pufferlib.vector.make( make_env, @@ -516,6 +518,7 @@ def _render_pass_html(self, vecenv, policy, args) -> list: tmp_path.unlink(missing_ok=True) html_paths.append(html_path) scenarios_done += 1 + progress.update(1) if scenarios_done >= num_scenarios: break @@ -523,6 +526,7 @@ def _render_pass_html(self, vecenv, policy, args) -> list: break finally: vec.close() + progress.close() return html_paths @@ -563,9 +567,12 @@ def _render_pass_obs(self, vecenv, policy, args) -> list: render_env_kwargs.pop("render_mode", None) # obs viz reads state, no EGL device = args["train"]["device"] - use_traj = "trajectory" in str(args["env"].get("action_type", "")) html_paths = [] scenarios_done = 0 + progress = tqdm(total=num_scenarios * (max_steps + 1), desc=f"{self.name} obs_html", unit="step") + pool_method = getattr(policy, "pool_slot_counts", None) + if pool_method is None and getattr(policy, "policy", None) is not None: + pool_method = getattr(policy.policy, "pool_slot_counts", None) vec = pufferlib.vector.make( make_env, env_args=[], env_kwargs=render_env_kwargs, backend="PufferEnv", num_envs=1 @@ -580,56 +587,151 @@ def _render_pass_obs(self, vecenv, policy, args) -> list: if state: state["lstm_h"].zero_() state["lstm_c"].zero_() - agent_hist = [[] for _ in range(n_in_batch)] - traffic_hist = [[] for _ in range(n_in_batch)] - traj_hist = [[] for _ in range(n_in_batch)] - obs_hist = [[] for _ in range(n_in_batch)] + agent_caps = [int(sc["num_total_agents"]) for sc in scenarios] + traffic_caps = [int(sc["num_traffic_elements"]) for sc in scenarios] + active_counts = [int(sc["active_agent_count"]) for sc in scenarios] + max_agent_cap = max(agent_caps) + max_traffic_cap = max(max(traffic_caps), 1) if traffic_caps else 1 + obs_dim = int(ob.shape[-1]) + agent_f32 = np.zeros((n_in_batch, max_agent_cap, 12), dtype=np.float32) + agent_i32 = np.zeros((n_in_batch, max_agent_cap, 8), dtype=np.int32) + metrics_f32 = np.zeros((n_in_batch, max_agent_cap, 18), dtype=np.float32) + puffer_f32 = np.zeros((n_in_batch, max_agent_cap, 15), dtype=np.float32) + traffic_i16 = np.zeros((n_in_batch, max_traffic_cap, 3), dtype=np.int16) + agent_f32_hist = [np.zeros((max_steps, agent_caps[e], 12), dtype=np.float32) for e in range(n_in_batch)] + agent_i32_hist = [np.zeros((max_steps, agent_caps[e], 8), dtype=np.int32) for e in range(n_in_batch)] + metrics_hist = [np.zeros((max_steps, agent_caps[e], 18), dtype=np.float32) for e in range(n_in_batch)] + puffer_hist = [np.zeros((max_steps, agent_caps[e], 15), dtype=np.float32) for e in range(n_in_batch)] + traffic_hist = [ + np.zeros((max_steps, max(traffic_caps[e], 1), 3), dtype=np.int16) for e in range(n_in_batch) + ] + obs_hist = [ + np.zeros((max_steps, active_counts[e], obs_dim), dtype=np.float32) for e in range(n_in_batch) + ] + raw_action_hist = [[] for _ in range(n_in_batch)] + clipped_action_hist = [[] for _ in range(n_in_batch)] + value_hist = [[] for _ in range(n_in_batch)] + entropy_hist = [[] for _ in range(n_in_batch)] + policy_prob_hist = [[] for _ in range(n_in_batch)] + policy_mean_hist = [[] for _ in range(n_in_batch)] + policy_std_hist = [[] for _ in range(n_in_batch)] + policy_log_prob_hist = [[] for _ in range(n_in_batch)] + pool_hist = None for t in range(max_steps): - cur = vec.get_state() - start_obs_index = 0 - for e in range(n_in_batch): - sc = cur[e] - agent_hist[e].append(viz.fill_agents_state(sc, use_trajectory=use_traj)) - traffic_hist[e].append(viz.fill_traffics_state(sc, t)) - if use_traj: - traj_hist[e].append(viz.fill_trajectories(sc, t)) - if e > 0: - start_obs_index += cur[e - 1]["active_agent_count"] - step_obs = {} - for a in range(sc["active_agent_count"]): - aid = int(sc["active_agent_indices"][a]) - step_obs[aid] = viz.extract_obs_frame( - ob, sc, args, timestep=t, obs_index=start_obs_index + a, agent_idx=a, head_north=True - ) - obs_hist[e].append(step_obs) with torch.no_grad(): ob_t = torch.as_tensor(ob).to(device) - logits, _ = policy.forward_eval(ob_t, state) - action, _, _ = pufferlib.pytorch.sample_logits(logits, deterministic=True) - action = action.cpu().numpy().reshape(vec.action_space.shape) + logits, value = policy.forward_eval(ob_t, state) + pool_outputs = pool_method(ob_t, state) if pool_method is not None else {} + action, logprob, entropy = pufferlib.pytorch.sample_logits(logits, deterministic=True) + raw_action = action.cpu().numpy().reshape(vec.action_space.shape) + pool_outputs = {k: v.cpu().numpy().astype(np.int16, copy=False) for k, v in pool_outputs.items()} + if pool_hist is None and pool_outputs: + pool_hist = { + k: [ + np.zeros((max_steps, active_counts[e], values.shape[1]), dtype=np.int16) + for e in range(n_in_batch) + ] + for k, values in pool_outputs.items() + } + clipped_action = raw_action if isinstance(logits, torch.distributions.Normal): - action = np.clip(action, vec.action_space.low, vec.action_space.high) - ob, _, _, _, _ = vec.step(action) + clipped_action = np.clip(raw_action, vec.action_space.low, vec.action_space.high) + policy_outputs = { + "mean": logits.loc.cpu().numpy().reshape(vec.action_space.shape), + "std": logits.scale.cpu().numpy().reshape(vec.action_space.shape), + "log_prob": logprob.cpu().numpy().reshape(-1), + } + elif isinstance(logits, torch.Tensor): + policy_outputs = torch.softmax(logits, dim=-1).cpu().numpy() + else: + policy_outputs = torch.softmax(logits[0], dim=-1).cpu().numpy() + value_np = value.cpu().numpy().reshape(-1) + entropy_np = entropy.cpu().numpy().reshape(-1) + + vec.get_obs_html_frame(agent_f32, agent_i32, metrics_f32, puffer_f32, traffic_i16) + start_obs_index = 0 + for e in range(n_in_batch): + active_count = active_counts[e] + end_obs_index = start_obs_index + active_count + agent_cap = agent_caps[e] + traffic_cap = max(traffic_caps[e], 1) + agent_f32_hist[e][t] = agent_f32[e, :agent_cap] + agent_i32_hist[e][t] = agent_i32[e, :agent_cap] + metrics_hist[e][t] = metrics_f32[e, :agent_cap] + puffer_hist[e][t] = puffer_f32[e, :agent_cap] + traffic_hist[e][t] = traffic_i16[e, :traffic_cap] + obs_hist[e][t] = ob[start_obs_index:end_obs_index] + raw_action_hist[e].append( + np.asarray(raw_action[start_obs_index:end_obs_index], dtype=np.float32).copy() + ) + clipped_action_hist[e].append( + np.asarray(clipped_action[start_obs_index:end_obs_index], dtype=np.float32).copy() + ) + value_hist[e].append(value_np[start_obs_index:end_obs_index].copy()) + entropy_hist[e].append(entropy_np[start_obs_index:end_obs_index].copy()) + if pool_hist and pool_outputs: + for k, values in pool_outputs.items(): + pool_hist[k][e][t] = values[start_obs_index:end_obs_index] + if isinstance(policy_outputs, dict): + policy_mean_hist[e].append( + np.asarray( + policy_outputs["mean"][start_obs_index:end_obs_index], dtype=np.float32 + ).copy() + ) + policy_std_hist[e].append( + np.asarray( + policy_outputs["std"][start_obs_index:end_obs_index], dtype=np.float32 + ).copy() + ) + policy_log_prob_hist[e].append( + np.asarray( + policy_outputs["log_prob"][start_obs_index:end_obs_index], dtype=np.float32 + ).copy() + ) + else: + policy_prob_hist[e].append( + np.asarray(policy_outputs[start_obs_index:end_obs_index], dtype=np.float32).copy() + ) + start_obs_index = end_obs_index + ob, _, _, _, _ = vec.step(clipped_action) + progress.update(to_render) for e in range(to_render): map_name = os.path.basename(str(scenarios[e].get("map_name") or "map")).split(".")[0] # Numeric index last so build_gallery_index's `*_.html` # pattern matches. path = out_dir / f"{map_name}{step_suffix}_{scenarios_done:03d}.html" - viz.generate_interactive_replay( - scenarios[e], - agent_hist[e], - traffic_hist[e], - traj_hist[e], - obs_hist[e], - str(path), - head_north=True, - ) + compact_replay = { + "schema": "obs_html_compact_v1", + "env": dict(args["env"]), + "agent_f32": agent_f32_hist[e], + "agent_i32": agent_i32_hist[e], + "metrics_f32": metrics_hist[e], + "puffer_f32": puffer_hist[e], + "traffic_i16": traffic_hist[e], + "obs": obs_hist[e], + "raw_action": np.stack(raw_action_hist[e], axis=0), + "clipped_action": np.stack(clipped_action_hist[e], axis=0), + "value": np.stack(value_hist[e], axis=0), + "entropy": np.stack(entropy_hist[e], axis=0), + "policy_probs": np.stack(policy_prob_hist[e], axis=0) if policy_prob_hist[e] else None, + "policy_mean": np.stack(policy_mean_hist[e], axis=0) if policy_mean_hist[e] else None, + "policy_std": np.stack(policy_std_hist[e], axis=0) if policy_std_hist[e] else None, + "policy_log_prob": ( + np.stack(policy_log_prob_hist[e], axis=0) if policy_log_prob_hist[e] else None + ), + } + if pool_hist: + for k, hists in pool_hist.items(): + compact_replay[k] = hists[e] + viz.generate_interactive_replay(scenarios[e], compact_replay, filename=str(path)) html_paths.append(path) scenarios_done += 1 + progress.update(1) if scenarios_done >= num_scenarios: break finally: vec.close() + progress.close() if html_paths: viz.build_gallery_index(str(out_dir)) diff --git a/pufferlib/ocean/drive/binding.c b/pufferlib/ocean/drive/binding.c index 8cb99b870b..76328d9f01 100644 --- a/pufferlib/ocean/drive/binding.c +++ b/pufferlib/ocean/drive/binding.c @@ -5,6 +5,156 @@ #define MY_GET #include "../env_binding.h" +static int clipped_debug_goal_count(Drive *env, Agent *agent) { + assert(env != NULL); + assert(agent != NULL); + int goal_count = get_agent_goal_count(env, agent); + if (goal_count < 0) { + return 0; + } + if (goal_count > MAX_TARGET_WAYPOINTS) { + return MAX_TARGET_WAYPOINTS; + } + return goal_count; +} + +static PyObject *build_goal_lane_ids(Agent *agent, int goal_count) { + assert(agent != NULL); + assert(goal_count >= 0); + PyObject *lst = PyList_New(goal_count); + if (lst == NULL) { + return NULL; + } + for (int i = 0; i < goal_count; i++) { + PyObject *v = PyLong_FromLong(agent->goal_lane_ids[i]); + if (v == NULL || PyList_SetItem(lst, i, v) < 0) { + Py_XDECREF(v); + Py_DECREF(lst); + return NULL; + } + } + return lst; +} + +static PyObject *build_goal_lane_s(Agent *agent, int goal_count) { + assert(agent != NULL); + assert(goal_count >= 0); + PyObject *lst = PyList_New(goal_count); + if (lst == NULL) { + return NULL; + } + for (int i = 0; i < goal_count; i++) { + PyObject *v = PyFloat_FromDouble((double) agent->goal_lane_s[i]); + if (v == NULL || PyList_SetItem(lst, i, v) < 0) { + Py_XDECREF(v); + Py_DECREF(lst); + return NULL; + } + } + return lst; +} + +static PyObject *build_goal_positions(Agent *agent, int goal_count) { + assert(agent != NULL); + assert(goal_count >= 0); + PyObject *lst = PyList_New(goal_count); + if (lst == NULL) { + return NULL; + } + for (int i = 0; i < goal_count; i++) { + PyObject *pos = Py_BuildValue( + "(ddd)", + (double) agent->goal_positions_x[i], + (double) agent->goal_positions_y[i], + (double) agent->goal_positions_z[i]); + if (pos == NULL || PyList_SetItem(lst, i, pos) < 0) { + Py_XDECREF(pos); + Py_DECREF(lst); + return NULL; + } + } + return lst; +} + +static PyObject *build_goal_route_distances(Drive *env, Agent *agent, int goal_count) { + assert(env != NULL); + assert(agent != NULL); + PyObject *lst = PyList_New(goal_count); + if (lst == NULL) { + return NULL; + } + int from_lane_idx = agent->route_length > 0 ? agent->route[0] : agent->current_lane_idx; + float from_s = 0.0f; + if (from_lane_idx >= 0 && from_lane_idx < env->num_road_elements) { + from_s = lane_s_at_position(&env->road_elements[from_lane_idx], agent->sim_x, agent->sim_y); + } + for (int i = 0; i < goal_count; i++) { + float distance = route_distance_between_lane_positions( + env, + from_lane_idx, + from_s, + agent->goal_lane_ids[i], + agent->goal_lane_s[i]); + PyObject *v = PyFloat_FromDouble((double) distance); + if (v == NULL || PyList_SetItem(lst, i, v) < 0) { + Py_XDECREF(v); + Py_DECREF(lst); + return NULL; + } + from_lane_idx = agent->goal_lane_ids[i]; + from_s = agent->goal_lane_s[i]; + } + return lst; +} + +static int set_agent_goal_debug_fields(PyObject *agent_dict, Drive *env, Agent *agent) { + assert(agent_dict != NULL); + assert(env != NULL); + int goal_count = clipped_debug_goal_count(env, agent); + PyObject *v = PyLong_FromLong(goal_count); + if (v == NULL || PyDict_SetItemString(agent_dict, "num_active_goals", v) < 0) { + Py_XDECREF(v); + return 1; + } + Py_DECREF(v); + + v = PyLong_FromLong(agent->current_goal_idx); + if (v == NULL || PyDict_SetItemString(agent_dict, "current_goal_idx", v) < 0) { + Py_XDECREF(v); + return 1; + } + Py_DECREF(v); + + v = build_goal_lane_ids(agent, goal_count); + if (v == NULL || PyDict_SetItemString(agent_dict, "goal_lane_ids", v) < 0) { + Py_XDECREF(v); + return 1; + } + Py_DECREF(v); + + v = build_goal_lane_s(agent, goal_count); + if (v == NULL || PyDict_SetItemString(agent_dict, "goal_lane_s", v) < 0) { + Py_XDECREF(v); + return 1; + } + Py_DECREF(v); + + v = build_goal_route_distances(env, agent, goal_count); + if (v == NULL || PyDict_SetItemString(agent_dict, "goal_route_distances", v) < 0) { + Py_XDECREF(v); + return 1; + } + Py_DECREF(v); + + v = build_goal_positions(agent, goal_count); + if (v == NULL || PyDict_SetItemString(agent_dict, "goal_positions", v) < 0) { + Py_XDECREF(v); + return 1; + } + Py_DECREF(v); + return 0; +} + static int my_put(Env *env, PyObject *args, PyObject *kwargs) { PyObject *obs = PyDict_GetItemString(kwargs, "observations"); if (!PyObject_TypeCheck(obs, &PyArray_Type)) { @@ -306,8 +456,15 @@ static PyObject *my_get(PyObject *dict, Env *env) { if (!agents_list) { return NULL; } + int next_active_log_idx = 0; for (int i = 0; i < env->num_total_agents; i++) { Agent *a = &env->agents[i]; + int active_log_idx = -1; + if (env->active_agent_indices && next_active_log_idx < env->active_agent_count + && env->active_agent_indices[next_active_log_idx] == i) { + active_log_idx = next_active_log_idx; + next_active_log_idx++; + } PyObject *agent = PyDict_New(); if (!agent) { @@ -600,6 +757,76 @@ static PyObject *my_get(PyObject *dict, Env *env) { } Py_DECREF(pf); + pf = PyFloat_FromDouble((double) a->steering_angle); + if (!pf) { + Py_DECREF(agent); + Py_DECREF(agents_list); + return NULL; + } + if (PyDict_SetItemString(agent, "sim_steering", pf) < 0) { + Py_DECREF(pf); + Py_DECREF(agent); + Py_DECREF(agents_list); + return NULL; + } + Py_DECREF(pf); + + pf = PyFloat_FromDouble((double) a->a_long); + if (!pf) { + Py_DECREF(agent); + Py_DECREF(agents_list); + return NULL; + } + if (PyDict_SetItemString(agent, "accel_long", pf) < 0) { + Py_DECREF(pf); + Py_DECREF(agent); + Py_DECREF(agents_list); + return NULL; + } + Py_DECREF(pf); + + pf = PyFloat_FromDouble((double) a->a_lat); + if (!pf) { + Py_DECREF(agent); + Py_DECREF(agents_list); + return NULL; + } + if (PyDict_SetItemString(agent, "accel_lat", pf) < 0) { + Py_DECREF(pf); + Py_DECREF(agent); + Py_DECREF(agents_list); + return NULL; + } + Py_DECREF(pf); + + pf = PyFloat_FromDouble((double) a->jerk_long); + if (!pf) { + Py_DECREF(agent); + Py_DECREF(agents_list); + return NULL; + } + if (PyDict_SetItemString(agent, "jerk_long", pf) < 0) { + Py_DECREF(pf); + Py_DECREF(agent); + Py_DECREF(agents_list); + return NULL; + } + Py_DECREF(pf); + + pf = PyFloat_FromDouble((double) a->jerk_lat); + if (!pf) { + Py_DECREF(agent); + Py_DECREF(agents_list); + return NULL; + } + if (PyDict_SetItemString(agent, "jerk_lat", pf) < 0) { + Py_DECREF(pf); + Py_DECREF(agent); + Py_DECREF(agents_list); + return NULL; + } + Py_DECREF(pf); + pf = PyFloat_FromDouble((double) a->sim_length); if (!pf) { Py_DECREF(agent); @@ -727,6 +954,12 @@ static PyObject *my_get(PyObject *dict, Env *env) { } Py_DECREF(pf); + if (set_agent_goal_debug_fields(agent, env, a)) { + Py_DECREF(agent); + Py_DECREF(agents_list); + return NULL; + } + /* Status flags */ tmp = PyLong_FromLong(a->stopped); if (!tmp) { @@ -881,6 +1114,58 @@ static PyObject *my_get(PyObject *dict, Env *env) { } Py_DECREF(metrics); + if (env->compute_eval_metrics && env->logs && active_log_idx >= 0 && active_log_idx < env->logs_capacity) { + Log *log = &env->logs[active_log_idx]; + + pf = PyFloat_FromDouble((double) log->puffer_score); + if (!pf) { + Py_DECREF(agent); + Py_DECREF(agents_list); + return NULL; + } + if (PyDict_SetItemString(agent, "puffer_score", pf) < 0) { + Py_DECREF(pf); + Py_DECREF(agent); + Py_DECREF(agents_list); + return NULL; + } + Py_DECREF(pf); + + PyObject *puffer_metrics = PyDict_New(); + if (!puffer_metrics) { + Py_DECREF(agent); + Py_DECREF(agents_list); + return NULL; + } + if (assign_to_dict(puffer_metrics, "score", log->puffer_score) + || assign_to_dict(puffer_metrics, "no_at_fault", log->no_at_fault) + || assign_to_dict(puffer_metrics, "no_offroad", log->no_offroad) + || assign_to_dict(puffer_metrics, "no_red_light", log->no_red_light) + || assign_to_dict(puffer_metrics, "making_progress", log->making_progress) + || assign_to_dict(puffer_metrics, "direction_score", log->driving_direction_score) + || assign_to_dict(puffer_metrics, "ttc_puffer_rate", log->ttc_puffer_rate) + || assign_to_dict(puffer_metrics, "progress_ratio", log->progress_ratio) + || assign_to_dict(puffer_metrics, "speed_limit_compliance", log->speed_limit_compliance) + || assign_to_dict(puffer_metrics, "comfort_score", log->comfort_score) + || assign_to_dict(puffer_metrics, "multi_lane_score", log->multi_lane_score) + || assign_to_dict(puffer_metrics, "wrong_way_distance", log->wrong_way_distance) + || assign_to_dict(puffer_metrics, "speed_violation_sum", log->speed_violation_sum) + || assign_to_dict(puffer_metrics, "multiplier", log->multiplier) + || assign_to_dict(puffer_metrics, "weighted_average", log->weighted_average)) { + Py_DECREF(puffer_metrics); + Py_DECREF(agent); + Py_DECREF(agents_list); + return NULL; + } + if (PyDict_SetItemString(agent, "puffer_metrics", puffer_metrics) < 0) { + Py_DECREF(puffer_metrics); + Py_DECREF(agent); + Py_DECREF(agents_list); + return NULL; + } + Py_DECREF(puffer_metrics); + } + /* Export route information */ tmp = PyLong_FromLong(a->route_length); if (!tmp) { @@ -922,12 +1207,10 @@ static PyObject *my_get(PyObject *dict, Env *env) { return NULL; } Py_DECREF(route_list); - } else { - if (PyDict_SetItemString(agent, "route", Py_None) < 0) { - Py_DECREF(agent); - Py_DECREF(agents_list); - return NULL; - } + } else if (PyDict_SetItemString(agent, "route", Py_None) < 0) { + Py_DECREF(agent); + Py_DECREF(agents_list); + return NULL; } PyList_SetItem(agents_list, i, agent); @@ -1146,12 +1429,10 @@ static PyObject *my_get(PyObject *dict, Env *env) { return NULL; } Py_DECREF(lx); - } else { - if (PyDict_SetItemString(road, "x", Py_None) < 0) { - Py_DECREF(road); - Py_DECREF(road_list); - return NULL; - } + } else if (PyDict_SetItemString(road, "x", Py_None) < 0) { + Py_DECREF(road); + Py_DECREF(road_list); + return NULL; } if (r->y && seg_len > 0) { PyObject *ly = PyList_New(seg_len); @@ -1177,12 +1458,10 @@ static PyObject *my_get(PyObject *dict, Env *env) { return NULL; } Py_DECREF(ly); - } else { - if (PyDict_SetItemString(road, "y", Py_None) < 0) { - Py_DECREF(road); - Py_DECREF(road_list); - return NULL; - } + } else if (PyDict_SetItemString(road, "y", Py_None) < 0) { + Py_DECREF(road); + Py_DECREF(road_list); + return NULL; } if (r->z && seg_len > 0) { PyObject *lz = PyList_New(seg_len); @@ -1208,12 +1487,10 @@ static PyObject *my_get(PyObject *dict, Env *env) { return NULL; } Py_DECREF(lz); - } else { - if (PyDict_SetItemString(road, "z", Py_None) < 0) { - Py_DECREF(road); - Py_DECREF(road_list); - return NULL; - } + } else if (PyDict_SetItemString(road, "z", Py_None) < 0) { + Py_DECREF(road); + Py_DECREF(road_list); + return NULL; } /* Lane-specific fields */ @@ -1277,6 +1554,50 @@ static PyObject *my_get(PyObject *dict, Env *env) { } Py_DECREF(pf); + pf = PyFloat_FromDouble((double) r->length); + if (!pf) { + Py_DECREF(road); + Py_DECREF(road_list); + return NULL; + } + if (PyDict_SetItemString(road, "length", pf) < 0) { + Py_DECREF(pf); + Py_DECREF(road); + Py_DECREF(road_list); + return NULL; + } + Py_DECREF(pf); + + if (is_road_lane(r->type) && r->cum_lengths != NULL && seg_len > 0) { + tmp = PyList_New(seg_len); + if (!tmp) { + Py_DECREF(road); + Py_DECREF(road_list); + return NULL; + } + for (int k = 0; k < seg_len; k++) { + PyObject *fv = PyFloat_FromDouble((double) r->cum_lengths[k]); + if (!fv) { + Py_DECREF(tmp); + Py_DECREF(road); + Py_DECREF(road_list); + return NULL; + } + PyList_SET_ITEM(tmp, k, fv); + } + if (PyDict_SetItemString(road, "cum_lengths", tmp) < 0) { + Py_DECREF(tmp); + Py_DECREF(road); + Py_DECREF(road_list); + return NULL; + } + Py_DECREF(tmp); + } else if (PyDict_SetItemString(road, "cum_lengths", Py_None) < 0) { + Py_DECREF(road); + Py_DECREF(road_list); + return NULL; + } + PyList_SetItem(road_list, i, road); } if (PyDict_SetItemString(dict, "road_elements", road_list) < 0) { @@ -1373,12 +1694,10 @@ static PyObject *my_get(PyObject *dict, Env *env) { return NULL; } Py_DECREF(ls); - } else { - if (PyDict_SetItemString(traffic, "states", Py_None) < 0) { - Py_DECREF(traffic); - Py_DECREF(traffic_list); - return NULL; - } + } else if (PyDict_SetItemString(traffic, "states", Py_None) < 0) { + Py_DECREF(traffic); + Py_DECREF(traffic_list); + return NULL; } /* Stop line endpoints */ @@ -1455,12 +1774,10 @@ static PyObject *my_get(PyObject *dict, Env *env) { return NULL; } Py_DECREF(ll); - } else { - if (PyDict_SetItemString(traffic, "controlled_lanes", Py_None) < 0) { - Py_DECREF(traffic); - Py_DECREF(traffic_list); - return NULL; - } + } else if (PyDict_SetItemString(traffic, "controlled_lanes", Py_None) < 0) { + Py_DECREF(traffic); + Py_DECREF(traffic_list); + return NULL; } PyList_SetItem(traffic_list, i, traffic); @@ -1807,6 +2124,9 @@ static int my_init(Env *env, PyObject *args, PyObject *kwargs) { env->num_target_waypoints = MAX_TARGET_WAYPOINTS; } env->target_type = (int) unpack(kwargs, "target_type"); + if (env->target_type == TARGET_DIJKSTRA) { + env->num_target_waypoints = DIJKSTRA_TARGET_SLOTS; + } env->obs_slots_boundary_n = (int) unpack(kwargs, "obs_slots_boundary_n"); env->obs_slots_lane_n = (int) unpack(kwargs, "obs_slots_lane_n"); env->obs_slots_partners_n = (int) unpack(kwargs, "obs_slots_partners_n"); diff --git a/pufferlib/ocean/drive/datatypes.h b/pufferlib/ocean/drive/datatypes.h index c636b9f5c8..1f90fcdcaf 100644 --- a/pufferlib/ocean/drive/datatypes.h +++ b/pufferlib/ocean/drive/datatypes.h @@ -240,7 +240,10 @@ struct Agent { float goal_position_x; // alias = goal_positions_x[current_goal_idx] float goal_position_y; // alias = goal_positions_y[current_goal_idx] float goal_position_z; // alias = goal_positions_z[current_goal_idx] - int current_goal_idx; // index of next goal to reach (0..N-1) + int goal_lane_ids[MAX_TARGET_WAYPOINTS]; + float goal_lane_s[MAX_TARGET_WAYPOINTS]; + int num_active_goals; + int current_goal_idx; // index of next goal to reach (0..N-1) int stopped; // 0/1 -> freeze if set int removed; // 0/1 -> remove from sim if set @@ -284,6 +287,8 @@ struct RoadMapElement { int num_exits; int *exit_lanes; float speed_limit; + float length; + float *cum_lengths; }; struct TrafficControlElement { @@ -306,7 +311,6 @@ typedef struct { struct LaneGraph { int n_lanes; int *lane_ids; - float *lane_lengths; float *distances; // n_lanes * n_lanes row-major }; @@ -332,6 +336,7 @@ void free_road_element(struct RoadMapElement *element) { free(element->headings); free(element->entry_lanes); free(element->exit_lanes); + free(element->cum_lengths); } void free_traffic_element(struct TrafficControlElement *element) { @@ -341,6 +346,5 @@ void free_traffic_element(struct TrafficControlElement *element) { void free_lane_graph(struct LaneGraph *graph) { free(graph->lane_ids); - free(graph->lane_lengths); free(graph->distances); } diff --git a/pufferlib/ocean/drive/drive.c b/pufferlib/ocean/drive/drive.c index bedd365122..c25b61d789 100644 --- a/pufferlib/ocean/drive/drive.c +++ b/pufferlib/ocean/drive/drive.c @@ -145,7 +145,7 @@ void demo() { while (!WindowShouldClose()) { // Handle camera controls int (*actions)[2] = (int (*)[2]) env.actions; - forward(net, env.observations, env.actions); + forward(net, env.observations, (int *) env.actions); if (IsKeyDown(KEY_LEFT_SHIFT)) { actions[env.human_agent_idx][0] = 3; actions[env.human_agent_idx][1] = 6; diff --git a/pufferlib/ocean/drive/drive.h b/pufferlib/ocean/drive/drive.h index 67bf48d8e5..975b3b1808 100644 --- a/pufferlib/ocean/drive/drive.h +++ b/pufferlib/ocean/drive/drive.h @@ -115,10 +115,11 @@ // TARGET_TYPE modes (controls what target info is in observations) #define TARGET_STATIC 0 #define TARGET_DYNAMIC 1 +#define TARGET_DIJKSTRA 2 // Observation feature counts #define EGO_FEATURES 10 -#define ROAD_FEATURES 7 +#define ROAD_FEATURES 9 #define PARTNER_FEATURES 8 #define TRAFFIC_CONTROL_FEATURES 7 #define PADDED_OBSERVATION_VALUE -0.001f @@ -127,6 +128,10 @@ // GIGAFLOW specific #define MAX_ROUTE_LENGTH 64 +#define DIJKSTRA_TARGET_SLOTS 4 +#define DIJKSTRA_MIN_GOAL_DISTANCE 20.0f +#define DIJKSTRA_MAX_GOAL_DISTANCE 200.0f +#define DIJKSTRA_MAX_ROUTE_ATTEMPTS 64 // Traffic light generation #define TL_DEFAULT_RED_DURATION 2.0f #define TL_DEFAULT_YELLOW_DURATION 3.0f @@ -309,6 +314,7 @@ struct Drive { int num_traffic_elements; int num_objects; struct LaneGraph lane_graph; + int *lane_graph_index_by_lane; int static_agent_count; int *static_agent_indices; int expert_static_agent_count; @@ -488,6 +494,27 @@ static int traffic_control_in_scope(int type, int scope) { } } +static void build_lane_graph_index(Drive *env) { + free(env->lane_graph_index_by_lane); + env->lane_graph_index_by_lane = NULL; + if (env->num_road_elements <= 0 || env->lane_graph.n_lanes <= 0 || env->lane_graph.lane_ids == NULL) { + return; + } + env->lane_graph_index_by_lane = (int *) malloc(env->num_road_elements * sizeof(int)); + if (env->lane_graph_index_by_lane == NULL) { + return; + } + for (int i = 0; i < env->num_road_elements; i++) { + env->lane_graph_index_by_lane[i] = -1; + } + for (int i = 0; i < env->lane_graph.n_lanes; i++) { + int lane_idx = env->lane_graph.lane_ids[i]; + if (lane_idx >= 0 && lane_idx < env->num_road_elements) { + env->lane_graph_index_by_lane[lane_idx] = i; + } + } +} + static void reset_agent_metrics(Drive *env, int agent_idx) { Agent *agent = &env->agents[agent_idx]; for (int i = 0; i < NUM_METRICS; i++) { @@ -1146,12 +1173,23 @@ int load_map_binary(const char *filename, Drive *drive) { fclose(file); return -1; } + if (fread(&road->length, sizeof(float), 1, file) != 1) { + fclose(file); + return -1; + } + road->cum_lengths = (float *) malloc(slen * sizeof(float)); + if ((size_t) slen > 0 && fread(road->cum_lengths, sizeof(float), slen, file) != (size_t) slen) { + fclose(file); + return -1; + } } else { road->num_entries = 0; road->num_exits = 0; road->entry_lanes = NULL; road->exit_lanes = NULL; road->speed_limit = 0.0f; + road->length = 0.0f; + road->cum_lengths = NULL; } } @@ -1232,7 +1270,6 @@ int load_map_binary(const char *filename, Drive *drive) { } drive->lane_graph.n_lanes = n_lanes_graph; drive->lane_graph.lane_ids = NULL; - drive->lane_graph.lane_lengths = NULL; drive->lane_graph.distances = NULL; if (n_lanes_graph > 0) { drive->lane_graph.lane_ids = (int *) malloc(n_lanes_graph * sizeof(int)); @@ -1240,11 +1277,6 @@ int load_map_binary(const char *filename, Drive *drive) { fclose(file); return -1; } - drive->lane_graph.lane_lengths = (float *) malloc(n_lanes_graph * sizeof(float)); - if (fread(drive->lane_graph.lane_lengths, sizeof(float), n_lanes_graph, file) != (size_t) n_lanes_graph) { - fclose(file); - return -1; - } drive->lane_graph.distances = (float *) malloc(n_lanes_graph * n_lanes_graph * sizeof(float)); if (fread(drive->lane_graph.distances, sizeof(float), n_lanes_graph * n_lanes_graph, file) != (size_t) (n_lanes_graph * n_lanes_graph)) { @@ -1252,6 +1284,7 @@ int load_map_binary(const char *filename, Drive *drive) { return -1; } } + build_lane_graph_index(drive); // Metadata if (fread(drive->scenario_id, sizeof(char), 128, file) != 128) { @@ -1310,17 +1343,6 @@ int load_map_binary(const char *filename, Drive *drive) { // Road Utility Functions // ======================================== -// Compute the length of a lane -static float compute_lane_length(RoadMapElement *lane) { - float length = 0.0f; - for (int i = 1; i < lane->segment_length; i++) { - float dx = lane->x[i] - lane->x[i - 1]; - float dy = lane->y[i] - lane->y[i - 1]; - length += sqrtf(dx * dx + dy * dy); - } - return length; -} - // Compute the remaining distance on a lane from a given position to the end of the lane static float compute_remaining_lane_distance(RoadMapElement *lane, float pos_x, float pos_y) { // Find the closest segment to the position @@ -1355,23 +1377,9 @@ static float compute_remaining_lane_distance(RoadMapElement *lane, float pos_x, } } - // Compute remaining distance from closest point to end of lane - float remaining = 0.0f; - - // Partial distance in current segment (from t to end of segment) - float dx = lane->x[closest_seg + 1] - lane->x[closest_seg]; - float dy = lane->y[closest_seg + 1] - lane->y[closest_seg]; - float seg_len = sqrtf(dx * dx + dy * dy); - remaining += (1.0f - closest_t) * seg_len; - - // Full distance of remaining segments - for (int i = closest_seg + 1; i < lane->segment_length - 1; i++) { - dx = lane->x[i + 1] - lane->x[i]; - dy = lane->y[i + 1] - lane->y[i]; - remaining += sqrtf(dx * dx + dy * dy); - } - - return remaining; + float progress = lane->cum_lengths[closest_seg] + + closest_t * (lane->cum_lengths[closest_seg + 1] - lane->cum_lengths[closest_seg]); + return fmaxf(0.0f, lane->length - progress); } static float compute_lane_end_distance_sq(RoadMapElement *lane, float origin_x, float origin_y) { @@ -1385,6 +1393,118 @@ static float compute_lane_end_distance_sq(RoadMapElement *lane, float origin_x, return dx * dx + dy * dy; } +static int get_lane_graph_index(Drive *env, int lane_idx) { + if (lane_idx < 0 || lane_idx >= env->num_road_elements || env->lane_graph_index_by_lane == NULL) { + return -1; + } + return env->lane_graph_index_by_lane[lane_idx]; +} + +static int valid_route_distance(float distance) { + return isfinite(distance) && distance >= 0.0f && distance < 1e8f; +} + +static float lane_graph_distance(Drive *env, int from_lane_idx, int to_lane_idx) { + int from_idx = get_lane_graph_index(env, from_lane_idx); + int to_idx = get_lane_graph_index(env, to_lane_idx); + if (from_idx < 0 || to_idx < 0 || env->lane_graph.distances == NULL) { + return INFINITY; + } + return env->lane_graph.distances[from_idx * env->lane_graph.n_lanes + to_idx]; +} + +static float lane_s_at_position(RoadMapElement *lane, float pos_x, float pos_y) { + if (lane->cum_lengths == NULL || lane->segment_length < 2) { + return 0.0f; + } + + int closest_seg = 0; + float closest_t = 0.0f; + float min_dist_sq = 1e30f; + + for (int i = 0; i < lane->segment_length - 1; i++) { + float x0 = lane->x[i]; + float y0 = lane->y[i]; + float x1 = lane->x[i + 1]; + float y1 = lane->y[i + 1]; + float dx = x1 - x0; + float dy = y1 - y0; + float seg_len_sq = dx * dx + dy * dy; + float t = 0.0f; + if (seg_len_sq > 1e-6f) { + t = ((pos_x - x0) * dx + (pos_y - y0) * dy) / seg_len_sq; + t = fmaxf(0.0f, fminf(1.0f, t)); + } + float proj_x = x0 + t * dx; + float proj_y = y0 + t * dy; + float dist_sq = (pos_x - proj_x) * (pos_x - proj_x) + (pos_y - proj_y) * (pos_y - proj_y); + if (dist_sq < min_dist_sq) { + min_dist_sq = dist_sq; + closest_seg = i; + closest_t = t; + } + } + + float s = lane->cum_lengths[closest_seg] + + closest_t * (lane->cum_lengths[closest_seg + 1] - lane->cum_lengths[closest_seg]); + return clip(s, 0.0f, lane->length); +} + +static float lane_midpoint_s(RoadMapElement *lane, int geometry_idx) { + if (lane->cum_lengths == NULL || geometry_idx < 0 || geometry_idx >= lane->segment_length - 1) { + return 0.0f; + } + return 0.5f * (lane->cum_lengths[geometry_idx] + lane->cum_lengths[geometry_idx + 1]); +} + +static int lane_segment_at_s(RoadMapElement *lane, float lane_s) { + if (lane->segment_length < 2 || lane->cum_lengths == NULL) { + return 0; + } + float s = clip(lane_s, 0.0f, lane->length); + for (int i = 0; i < lane->segment_length - 1; i++) { + if (s <= lane->cum_lengths[i + 1]) { + return i; + } + } + return lane->segment_length - 2; +} + +static float route_distance_between_lane_positions( + Drive *env, + int from_lane_idx, + float from_s, + int to_lane_idx, + float to_s) { + if (from_lane_idx < 0 || from_lane_idx >= env->num_road_elements || to_lane_idx < 0 + || to_lane_idx >= env->num_road_elements) { + return INFINITY; + } + RoadMapElement *from_lane = &env->road_elements[from_lane_idx]; + RoadMapElement *to_lane = &env->road_elements[to_lane_idx]; + if (!is_drivable_road_lane(from_lane->type) || !is_drivable_road_lane(to_lane->type)) { + return INFINITY; + } + float from_clamped = clip(from_s, 0.0f, from_lane->length); + float to_clamped = clip(to_s, 0.0f, to_lane->length); + if (from_lane_idx == to_lane_idx) { + return (to_clamped >= from_clamped) ? (to_clamped - from_clamped) : INFINITY; + } + + float graph_distance = lane_graph_distance(env, from_lane_idx, to_lane_idx); + if (!valid_route_distance(graph_distance)) { + return INFINITY; + } + return fmaxf(0.0f, from_lane->length - from_clamped) + graph_distance + to_clamped; +} + +static int get_agent_goal_count(Drive *env, Agent *agent) { + if (env->target_type == TARGET_DIJKSTRA) { + return agent->num_active_goals; + } + return env->num_target_waypoints; +} + static float compute_progression(Agent *agent) { int num_wp = agent->path->num_waypoints; if (num_wp < 2) { @@ -1804,7 +1924,7 @@ static int generate_random_route( // Accumulate distance RoadMapElement *exit_lane = &env->road_elements[chosen_exit_idx]; - accumulated_distance += compute_lane_length(exit_lane); + accumulated_distance += exit_lane->length; if (chosen_exit_dist_sq > max_end_distance_sq) { max_end_distance_sq = chosen_exit_dist_sq; } @@ -1874,8 +1994,247 @@ static int compute_new_route(Drive *env, int agent_idx, int current_lane_idx) { return 1; // Success } +static void set_agent_goal_from_lane_s(Drive *env, Agent *agent, int goal_idx, int lane_idx, float lane_s) { + RoadMapElement *lane = &env->road_elements[lane_idx]; + int seg_idx = lane_segment_at_s(lane, lane_s); + float s0 = lane->cum_lengths[seg_idx]; + float s1 = lane->cum_lengths[seg_idx + 1]; + float denom = fmaxf(s1 - s0, 1e-6f); + float t = clip((lane_s - s0) / denom, 0.0f, 1.0f); + agent->goal_positions_x[goal_idx] = lane->x[seg_idx] + t * (lane->x[seg_idx + 1] - lane->x[seg_idx]); + agent->goal_positions_y[goal_idx] = lane->y[seg_idx] + t * (lane->y[seg_idx + 1] - lane->y[seg_idx]); + agent->goal_positions_z[goal_idx] = lane->z[seg_idx] + t * (lane->z[seg_idx + 1] - lane->z[seg_idx]); + agent->goal_lane_ids[goal_idx] = lane_idx; + agent->goal_lane_s[goal_idx] = clip(lane_s, 0.0f, lane->length); +} + +static int build_dijkstra_route_to_goal( + Drive *env, + int start_lane_idx, + int goal_lane_idx, + int *route, + int max_route_length) { + if (start_lane_idx < 0 || goal_lane_idx < 0 || max_route_length <= 0) { + return 0; + } + int current_lane_idx = start_lane_idx; + int route_length = 0; + route[route_length++] = current_lane_idx; + + for (int step = 0; step < DIJKSTRA_MAX_ROUTE_ATTEMPTS && route_length < max_route_length; step++) { + if (current_lane_idx == goal_lane_idx) { + return route_length; + } + RoadMapElement *current_lane = &env->road_elements[current_lane_idx]; + float current_dist = lane_graph_distance(env, current_lane_idx, goal_lane_idx); + if (!valid_route_distance(current_dist)) { + return 0; + } + + int best_exit = -1; + float best_dist = current_dist; + for (int e = 0; e < current_lane->num_exits; e++) { + int exit_lane_idx = current_lane->exit_lanes[e]; + if (exit_lane_idx < 0 || exit_lane_idx >= env->num_road_elements) { + continue; + } + if (!is_drivable_road_lane(env->road_elements[exit_lane_idx].type)) { + continue; + } + float exit_dist = lane_graph_distance(env, exit_lane_idx, goal_lane_idx); + if (valid_route_distance(exit_dist) && exit_dist < best_dist - 1e-3f) { + best_dist = exit_dist; + best_exit = exit_lane_idx; + } + } + if (best_exit == -1) { + return 0; + } + route[route_length++] = best_exit; + current_lane_idx = best_exit; + } + return current_lane_idx == goal_lane_idx ? route_length : 0; +} + +static int set_agent_route_from_buffer(Drive *env, int agent_idx, int *route, int route_length) { + Agent *agent = &env->agents[agent_idx]; + if (route_length <= 0) { + return 0; + } + free(agent->route); + agent->route = (int *) malloc(route_length * sizeof(int)); + if (agent->route == NULL) { + agent->route_length = 0; + return 0; + } + agent->route_length = route_length; + for (int i = 0; i < route_length; i++) { + agent->route[i] = route[i]; + } + agent->current_route_index = 0; + build_path(env, agent_idx); + agent->closest_path_idx_wp = 0; + agent->closest_path_idx_wp = get_closest_waypoint_index_on_path(env, agent_idx); + agent->path_progression = compute_progression(agent); + return 1; +} + +static int rebuild_dijkstra_route_to_current_goal(Drive *env, int agent_idx) { + Agent *agent = &env->agents[agent_idx]; + int goal_count = get_agent_goal_count(env, agent); + if (agent->current_goal_idx < 0 || agent->current_goal_idx >= goal_count) { + return 0; + } + int start_lane_idx = agent->current_lane_idx; + if (start_lane_idx < 0) { + start_lane_idx = agent->previous_lane_idx; + } + int goal_lane_idx = agent->goal_lane_ids[agent->current_goal_idx]; + int temp_route[MAX_ROUTE_LENGTH]; + int route_length = build_dijkstra_route_to_goal(env, start_lane_idx, goal_lane_idx, temp_route, MAX_ROUTE_LENGTH); + return set_agent_route_from_buffer(env, agent_idx, temp_route, route_length); +} + +static int compute_dijkstra_goals(Drive *env, int agent_idx) { + assert(env != NULL); + assert(agent_idx >= 0 && agent_idx < env->num_total_agents); + + Agent *agent = &env->agents[agent_idx]; + int active_goal_count = 1 + (rand() % DIJKSTRA_TARGET_SLOTS); + for (int i = 0; i < MAX_TARGET_WAYPOINTS; i++) { + agent->goal_positions_x[i] = 0.0f; + agent->goal_positions_y[i] = 0.0f; + agent->goal_positions_z[i] = 0.0f; + agent->goal_lane_ids[i] = -1; + agent->goal_lane_s[i] = 0.0f; + } + + int start_lane_idx = agent->current_lane_idx; + if (start_lane_idx < 0) { + return 0; + } + RoadMapElement *start_lane = &env->road_elements[start_lane_idx]; + float start_s = lane_s_at_position(start_lane, agent->sim_x, agent->sim_y); + + int first_route[MAX_ROUTE_LENGTH]; + int first_route_length = 0; + int first_goal_lane_idx = -1; + float first_goal_lane_s = 0.0f; + + for (int attempt = 0; attempt < 10; attempt++) { + int idx = rand() % env->num_road_elements; + if (!is_drivable_road_lane(env->road_elements[idx].type)) { + continue; + } + float graph_dist = lane_graph_distance(env, start_lane_idx, idx); + if (!valid_route_distance(graph_dist)) { + continue; + } + float lane_len = env->road_elements[idx].length; + float s = random_uniform(0.0f, lane_len); + float total_dist = graph_dist + s - start_s; + if (total_dist < DIJKSTRA_MIN_GOAL_DISTANCE) { + continue; + } + int temp_route[MAX_ROUTE_LENGTH]; + int route_length = build_dijkstra_route_to_goal(env, start_lane_idx, idx, temp_route, MAX_ROUTE_LENGTH); + if (route_length > 0) { + first_goal_lane_idx = idx; + first_route_length = route_length; + first_goal_lane_s = s; + for (int i = 0; i < route_length; i++) { + first_route[i] = temp_route[i]; + } + break; + } + } + + if (first_goal_lane_idx == -1) { + first_goal_lane_idx = start_lane_idx; + first_route_length = 1; + first_route[0] = start_lane_idx; + float fallback_s = start_s + DIJKSTRA_MIN_GOAL_DISTANCE; + float lane_len = env->road_elements[start_lane_idx].length; + if (fallback_s > lane_len) { + fallback_s = lane_len; + } + first_goal_lane_s = fallback_s; + } + + set_agent_goal_from_lane_s(env, agent, 0, first_goal_lane_idx, first_goal_lane_s); + + int current_lane_idx = first_goal_lane_idx; + float current_s = first_goal_lane_s; + int sampled_count = 1; + + for (int goal_idx = 1; goal_idx < active_goal_count; goal_idx++) { + float remaining_dist = random_uniform(DIJKSTRA_MIN_GOAL_DISTANCE + 5.0f, DIJKSTRA_MAX_GOAL_DISTANCE - 5.0f); + int next_lane_idx = current_lane_idx; + float next_s = current_s; + + int step = 0; + int max_steps = 100; + while (remaining_dist > 0.0f && step < max_steps) { + step++; + RoadMapElement *lane = &env->road_elements[next_lane_idx]; + float segment_rem = lane->length - next_s; + if (remaining_dist <= segment_rem) { + next_s += remaining_dist; + remaining_dist = 0.0f; + break; + } else { + remaining_dist -= segment_rem; + if (lane->num_exits > 0) { + int valid_exits[MAX_ROUTE_LENGTH]; + int valid_count = 0; + for (int e = 0; e < lane->num_exits && e < MAX_ROUTE_LENGTH; e++) { + int exit_lane_idx = lane->exit_lanes[e]; + if (exit_lane_idx >= 0 && exit_lane_idx < env->num_road_elements) { + if (is_drivable_road_lane(env->road_elements[exit_lane_idx].type)) { + valid_exits[valid_count++] = exit_lane_idx; + } + } + } + if (valid_count > 0) { + next_lane_idx = valid_exits[rand() % valid_count]; + next_s = 0.0f; + } else { + next_s = lane->length; + remaining_dist = 0.0f; + break; + } + } else { + next_s = lane->length; + remaining_dist = 0.0f; + break; + } + } + } + + set_agent_goal_from_lane_s(env, agent, goal_idx, next_lane_idx, next_s); + current_lane_idx = next_lane_idx; + current_s = next_s; + sampled_count++; + } + + agent->num_active_goals = sampled_count; + agent->current_goal_idx = 0; + agent->goal_position_x = agent->goal_positions_x[0]; + agent->goal_position_y = agent->goal_positions_y[0]; + agent->goal_position_z = agent->goal_positions_z[0]; + return set_agent_route_from_buffer(env, agent_idx, first_route, first_route_length); +} + static void compute_goals(Drive *env, int agent_idx) { Agent *agent = &env->agents[agent_idx]; + if (env->target_type == TARGET_DIJKSTRA) { + if (!compute_dijkstra_goals(env, agent_idx)) { + printf("[GIGAFLOW WARNING] -> Failed to compute dijkstra goals for agent %d\n", agent_idx); + agent->removed = 1; + } + return; + } + struct Path *path = agent->path; // Validate path exists @@ -1942,6 +2301,12 @@ static void compute_goals(Drive *env, int agent_idx) { agent->goal_positions_x[i] = path->waypoints[wp_idx].x; agent->goal_positions_y[i] = path->waypoints[wp_idx].y; agent->goal_positions_z[i] = path->waypoints[wp_idx].z; + int goal_lane_idx = path->waypoints[wp_idx].lane_idx; + agent->goal_lane_ids[i] = goal_lane_idx; + agent->goal_lane_s[i] = lane_s_at_position( + &env->road_elements[goal_lane_idx], + agent->goal_positions_x[i], + agent->goal_positions_y[i]); } // Reset goal index and update alias @@ -3021,14 +3386,22 @@ static int spawn_agent(Drive *env, int agent_idx, int num_agents) { agent->yaw_rate = 0.0f; update_agent_speed(agent); - // Compute initial route - if (!compute_new_route(env, agent_idx, start_lane_idx)) { - printf("[GIGAFLOW WARNING] -> Failed to compute a new route for agent %d\n", agent_idx); - return 0; // Failed to compute new goal - } + if (env->target_type == TARGET_DIJKSTRA) { + agent->current_lane_idx = start_lane_idx; + if (!compute_dijkstra_goals(env, agent_idx)) { + printf("[GIGAFLOW WARNING] -> Failed to compute dijkstra goals for agent %d\n", agent_idx); + return 0; + } + } else { + // Compute initial route + if (!compute_new_route(env, agent_idx, start_lane_idx)) { + printf("[GIGAFLOW WARNING] -> Failed to compute a new route for agent %d\n", agent_idx); + return 0; // Failed to compute new goal + } - // Compute initial goal - compute_goals(env, agent_idx); + // Compute initial goal + compute_goals(env, agent_idx); + } return 1; // Success } @@ -3373,6 +3746,9 @@ void remove_bad_trajectories(Drive *env) { void init(Drive *env) { env->human_agent_idx = 0; env->timestep = 0; + if (env->target_type == TARGET_DIJKSTRA) { + env->num_target_waypoints = DIJKSTRA_TARGET_SLOTS; + } load_map_binary(env->map_name, env); env->road_dropout_enabled = (env->obs_slots_lane_kept < env->obs_slots_lane_n) || (env->obs_slots_boundary_kept < env->obs_slots_boundary_n); @@ -3977,9 +4353,8 @@ static void compute_metrics(Drive *env, int agent_idx) { = compute_euclidean_distance(agent->sim_x, agent->sim_y, agent->goal_position_x, agent->goal_position_y); float goal_z_dist = fabsf(agent->sim_z - agent->goal_position_z); - // Goal reaching — guard against incrementing past num_target_waypoints - if (agent->current_goal_idx < env->num_target_waypoints - && distance_to_goal < agent->reward_coefs[REWARD_COEF_GOAL_RADIUS] && goal_z_dist < Z_BUFFER) { + // Goal reaching + if (distance_to_goal < agent->reward_coefs[REWARD_COEF_GOAL_RADIUS] && goal_z_dist < Z_BUFFER) { agent->metrics_array[REACHED_GOAL_IDX] = 1.0f; agent->current_goal_idx++; } @@ -4026,8 +4401,9 @@ static void compute_rewards(Drive *env, int i) { // Goal reward if (agent->metrics_array[REACHED_GOAL_IDX] > 0.0f) { float weight = 1.0f; + int goal_count = get_agent_goal_count(env, agent); if (env->simulation_mode == SIMULATION_GIGAFLOW) { - if (agent->current_goal_idx == env->num_target_waypoints + if (agent->current_goal_idx == goal_count && agent->sim_speed > agent->reward_coefs[REWARD_COEF_GOAL_SPEED]) { weight = 0.0f; } @@ -4354,6 +4730,21 @@ static int write_road_obs(Drive *env, Agent *ego, float *obs, int obs_idx, int * int lanes_found = 0; int boundaries_found = 0; + int goal_count = get_agent_goal_count(env, ego); + int goal_lane_idx = -1; + float goal_lane_s = 0.0f; + float ego_goal_route_distance = INFINITY; + if (ego->current_goal_idx >= 0 && ego->current_goal_idx < goal_count) { + goal_lane_idx = ego->goal_lane_ids[ego->current_goal_idx]; + goal_lane_s = ego->goal_lane_s[ego->current_goal_idx]; + } + if (goal_lane_idx >= 0 && ego->current_lane_idx >= 0) { + RoadMapElement *ego_lane = &env->road_elements[ego->current_lane_idx]; + float ego_lane_s = lane_s_at_position(ego_lane, ego->sim_x, ego->sim_y); + ego_goal_route_distance + = route_distance_between_lane_positions(env, ego->current_lane_idx, ego_lane_s, goal_lane_idx, goal_lane_s); + } + for (int k = 0; k < neighbor_count; k++) { if (lanes_found >= env->obs_slots_lane_n && boundaries_found >= env->obs_slots_boundary_n) { break; @@ -4418,6 +4809,28 @@ static int write_road_obs(Drive *env, Agent *ego, float *obs, int obs_idx, int * segment_dest[feature_base + 4] = LANE_WIDTH / env->obs_norm_road_seg_width_m; segment_dest[feature_base + 5] = rel_seg_dir_x; segment_dest[feature_base + 6] = rel_seg_dir_y; + if (is_lane) { + float segment_s = lane_midpoint_s(road_element, geometry_idx); + float segment_goal_distance + = route_distance_between_lane_positions(env, entity_idx, segment_s, goal_lane_idx, goal_lane_s); + if (valid_route_distance(segment_goal_distance)) { + segment_dest[feature_base + 7] = clip(segment_goal_distance / env->obs_norm_goal_offset_m, -1.0f, 1.0f); + if (valid_route_distance(ego_goal_route_distance)) { + segment_dest[feature_base + 8] = clip( + (segment_goal_distance - ego_goal_route_distance) / env->obs_norm_goal_offset_m, + -1.0f, + 1.0f); + } else { + segment_dest[feature_base + 8] = 0.0f; + } + } else { + segment_dest[feature_base + 7] = 1.0f; + segment_dest[feature_base + 8] = 0.0f; + } + } else { + segment_dest[feature_base + 7] = 0.0f; + segment_dest[feature_base + 8] = 0.0f; + } } if (env->road_dropout_enabled) { @@ -5010,8 +5423,9 @@ void c_step(Drive *env) { int agent_idx = env->active_agent_indices[i]; Agent *agent = &env->agents[agent_idx]; if (agent->metrics_array[REACHED_GOAL_IDX] > 0.0f) { - if (agent->current_goal_idx == env->num_target_waypoints) { - // Last goal reached + int goal_count = get_agent_goal_count(env, agent); + if (agent->current_goal_idx == goal_count) { + // Last goal reached - generate new set of goals env->logs[i].num_goals_reached += 1; if (env->simulation_mode == SIMULATION_REPLAY) { // Replay mode: leave current_goal_idx saturated so the @@ -5025,6 +5439,9 @@ void c_step(Drive *env) { agent->goal_position_x = agent->goal_positions_x[agent->current_goal_idx]; agent->goal_position_y = agent->goal_positions_y[agent->current_goal_idx]; agent->goal_position_z = agent->goal_positions_z[agent->current_goal_idx]; + if (env->target_type == TARGET_DIJKSTRA) { + rebuild_dijkstra_route_to_current_goal(env, agent_idx); + } } } } diff --git a/pufferlib/ocean/drive/drive.py b/pufferlib/ocean/drive/drive.py index cf3be513af..15d8e7b6c3 100644 --- a/pufferlib/ocean/drive/drive.py +++ b/pufferlib/ocean/drive/drive.py @@ -141,8 +141,12 @@ def __init__( self.target_type = binding.TARGET_STATIC elif target_type == "dynamic": self.target_type = binding.TARGET_DYNAMIC + elif target_type == "dijkstra": + self.target_type = binding.TARGET_DIJKSTRA + self.num_target_waypoints = 4 + num_target_waypoints = 4 else: - raise ValueError(f"target_type must be 'static' or 'dynamic'. Got: {target_type}") + raise ValueError(f"target_type must be 'static', 'dynamic', or 'dijkstra'. Got: {target_type}") self.collision_behavior = collision_behavior self.offroad_behavior = offroad_behavior self.traffic_light_behavior = traffic_light_behavior @@ -209,7 +213,7 @@ def __init__( self.num_reward_coefs = binding.NUM_REWARD_COEFS if reward_conditioning else 0 # Target features based on target_type - if target_type == "static": + if target_type == "static" or target_type == "dijkstra": self.target_features = binding.STATIC_TARGET_FEATURES else: self.target_features = binding.DYNAMIC_TARGET_FEATURES @@ -818,6 +822,16 @@ def get_state(self): except Exception: return binding.env_get(self.c_envs) + def get_obs_html_frame(self, agent_f32, agent_i32, metrics_f32, puffer_f32, traffic_i16): + binding.vec_get_obs_html_frame( + self.c_envs, + agent_f32, + agent_i32, + metrics_f32, + puffer_f32, + traffic_i16, + ) + def calculate_area(p1, p2, p3): # Calculate the area of the triangle using the determinant method diff --git a/pufferlib/ocean/drive/visualize.c b/pufferlib/ocean/drive/visualize.c index a8c8932b16..3c60f0c2f9 100644 --- a/pufferlib/ocean/drive/visualize.c +++ b/pufferlib/ocean/drive/visualize.c @@ -290,7 +290,7 @@ int eval_gif( .goal_speed = conf.goal_speed, .map_name = (char *) map_name, .init_step = init_step, - .max_controlled_agents = max_controlled_agents, + .num_controllable_agents = max_controlled_agents, .collision_behavior = conf.collision_behavior, .offroad_behavior = conf.offroad_behavior, .compute_eval_metrics = conf.compute_eval_metrics, diff --git a/pufferlib/ocean/env_binding.h b/pufferlib/ocean/env_binding.h index 0c5e8cc112..309dfb115f 100644 --- a/pufferlib/ocean/env_binding.h +++ b/pufferlib/ocean/env_binding.h @@ -828,6 +828,133 @@ static PyObject *vec_get(PyObject *self, PyObject *args) { return list; } +static PyObject *vec_get_obs_html_frame(PyObject *self, PyObject *args) { + if (PyTuple_Size(args) != 6) { + PyErr_SetString(PyExc_TypeError, "vec_get_obs_html_frame requires 6 arguments"); + return NULL; + } + + VecEnv *vec = unpack_vecenv(args); + if (!vec) { + return NULL; + } + + PyArrayObject *agent_f32_array = (PyArrayObject *) PyTuple_GetItem(args, 1); + PyArrayObject *agent_i32_array = (PyArrayObject *) PyTuple_GetItem(args, 2); + PyArrayObject *metrics_f32_array = (PyArrayObject *) PyTuple_GetItem(args, 3); + PyArrayObject *puffer_f32_array = (PyArrayObject *) PyTuple_GetItem(args, 4); + PyArrayObject *traffic_i16_array = (PyArrayObject *) PyTuple_GetItem(args, 5); + + if (!PyArray_Check(agent_f32_array) || !PyArray_Check(agent_i32_array) || !PyArray_Check(metrics_f32_array) + || !PyArray_Check(puffer_f32_array) || !PyArray_Check(traffic_i16_array)) { + PyErr_SetString(PyExc_TypeError, "All output arrays must be NumPy arrays"); + return NULL; + } + + memset(PyArray_DATA(agent_f32_array), 0, PyArray_NBYTES(agent_f32_array)); + memset(PyArray_DATA(agent_i32_array), 0, PyArray_NBYTES(agent_i32_array)); + memset(PyArray_DATA(metrics_f32_array), 0, PyArray_NBYTES(metrics_f32_array)); + memset(PyArray_DATA(puffer_f32_array), 0, PyArray_NBYTES(puffer_f32_array)); + memset(PyArray_DATA(traffic_i16_array), 0, PyArray_NBYTES(traffic_i16_array)); + + float *agent_f32 = (float *) PyArray_DATA(agent_f32_array); + int *agent_i32 = (int *) PyArray_DATA(agent_i32_array); + float *metrics_f32 = (float *) PyArray_DATA(metrics_f32_array); + float *puffer_f32 = (float *) PyArray_DATA(puffer_f32_array); + short *traffic_i16 = (short *) PyArray_DATA(traffic_i16_array); + + int env_cap = (int) PyArray_DIM(agent_f32_array, 0); + int env_count = vec->num_envs < env_cap ? vec->num_envs : env_cap; + int agent_cap = (int) PyArray_DIM(agent_f32_array, 1); + int agent_f32_fields = (int) PyArray_DIM(agent_f32_array, 2); + int agent_i32_fields = (int) PyArray_DIM(agent_i32_array, 2); + int metric_fields = (int) PyArray_DIM(metrics_f32_array, 2); + int puffer_fields = (int) PyArray_DIM(puffer_f32_array, 2); + int traffic_cap = (int) PyArray_DIM(traffic_i16_array, 1); + int traffic_fields = (int) PyArray_DIM(traffic_i16_array, 2); + + for (int e = 0; e < env_count; e++) { + Drive *drive = (Drive *) vec->envs[e]; + int agent_count = drive->num_total_agents < agent_cap ? drive->num_total_agents : agent_cap; + int traffic_count = drive->num_traffic_elements < traffic_cap ? drive->num_traffic_elements : traffic_cap; + + for (int i = 0; i < agent_count; i++) { + Agent *a = &drive->agents[i]; + int f32_base = (e * agent_cap + i) * agent_f32_fields; + int i32_base = (e * agent_cap + i) * agent_i32_fields; + int metrics_base = (e * agent_cap + i) * metric_fields; + + agent_f32[f32_base + 0] = a->sim_x; + agent_f32[f32_base + 1] = a->sim_y; + agent_f32[f32_base + 2] = a->sim_z; + agent_f32[f32_base + 3] = a->sim_heading; + agent_f32[f32_base + 4] = a->sim_length; + agent_f32[f32_base + 5] = a->sim_width; + agent_f32[f32_base + 6] = a->sim_speed; + agent_f32[f32_base + 7] = a->steering_angle; + agent_f32[f32_base + 8] = a->a_long; + agent_f32[f32_base + 9] = a->a_lat; + agent_f32[f32_base + 10] = a->jerk_long; + agent_f32[f32_base + 11] = a->jerk_lat; + + agent_i32[i32_base + 0] = i; + agent_i32[i32_base + 1] = a->type; + agent_i32[i32_base + 2] = a->sim_valid; + agent_i32[i32_base + 3] = a->active_agent; + agent_i32[i32_base + 4] = a->stopped; + agent_i32[i32_base + 5] = a->removed; + agent_i32[i32_base + 6] = a->current_lane_idx; + agent_i32[i32_base + 7] = -1; + + memcpy(&metrics_f32[metrics_base], a->metrics_array, sizeof(float) * NUM_METRICS); + } + + if (drive->active_agent_indices) { + for (int j = 0; j < drive->active_agent_count; j++) { + int agent_idx = drive->active_agent_indices[j]; + if (agent_idx < 0 || agent_idx >= agent_count) { + continue; + } + int i32_base = (e * agent_cap + agent_idx) * agent_i32_fields; + int puffer_base = (e * agent_cap + agent_idx) * puffer_fields; + agent_i32[i32_base + 7] = j; + + if (!drive->compute_eval_metrics || !drive->logs || j >= drive->logs_capacity) { + continue; + } + Log *log = &drive->logs[j]; + puffer_f32[puffer_base + 0] = log->puffer_score; + puffer_f32[puffer_base + 1] = log->no_at_fault; + puffer_f32[puffer_base + 2] = log->no_offroad; + puffer_f32[puffer_base + 3] = log->no_red_light; + puffer_f32[puffer_base + 4] = log->making_progress; + puffer_f32[puffer_base + 5] = log->driving_direction_score; + puffer_f32[puffer_base + 6] = log->ttc_puffer_rate; + puffer_f32[puffer_base + 7] = log->progress_ratio; + puffer_f32[puffer_base + 8] = log->speed_limit_compliance; + puffer_f32[puffer_base + 9] = log->comfort_score; + puffer_f32[puffer_base + 10] = log->multi_lane_score; + puffer_f32[puffer_base + 11] = log->wrong_way_distance; + puffer_f32[puffer_base + 12] = log->speed_violation_sum; + puffer_f32[puffer_base + 13] = log->multiplier; + puffer_f32[puffer_base + 14] = log->weighted_average; + } + } + + for (int i = 0; i < traffic_count; i++) { + TrafficControlElement *t = &drive->traffic_elements[i]; + int base = (e * traffic_cap + i) * traffic_fields; + traffic_i16[base + 0] = 1; + traffic_i16[base + 1] = (short) t->type; + if (t->states && drive->timestep >= 0 && drive->timestep < t->state_length) { + traffic_i16[base + 2] = (short) t->states[drive->timestep]; + } + } + } + + Py_RETURN_NONE; +} + static PyObject *vec_close(PyObject *self, PyObject *args) { VecEnv *vec = unpack_vecenv(args); if (!vec) { @@ -1214,6 +1341,10 @@ static PyMethodDef methods[] "Release a single env's render client without destroying the env"}, {"vec_close", vec_close, METH_VARARGS, "Close the vector of environments"}, {"vec_get", vec_get, METH_VARARGS, "Get attributes from each env in a VecEnv"}, + {"vec_get_obs_html_frame", + vec_get_obs_html_frame, + METH_VARARGS, + "Fill compact obs_html frame arrays from a VecEnv"}, {"shared", (PyCFunction) my_shared, METH_VARARGS | METH_KEYWORDS, "Shared state"}, {"get_global_agent_state", get_global_agent_state, METH_VARARGS, "Get global agent state"}, {"vec_get_global_agent_state", vec_get_global_agent_state, METH_VARARGS, "Get agent state from vectorized env"}, @@ -1267,6 +1398,7 @@ PyMODINIT_FUNC PyInit_binding(void) { PyModule_AddIntConstant(m, "NUM_REWARD_COEFS", NUM_REWARD_COEFS); PyModule_AddIntConstant(m, "TARGET_STATIC", TARGET_STATIC); PyModule_AddIntConstant(m, "TARGET_DYNAMIC", TARGET_DYNAMIC); + PyModule_AddIntConstant(m, "TARGET_DIJKSTRA", TARGET_DIJKSTRA); PyObject_SetAttrString(m, "MULTI_LANE_FULL_SCORE_TIME", PyFloat_FromDouble(MULTI_LANE_FULL_SCORE_TIME)); PyObject_SetAttrString(m, "MULTI_LANE_HALF_SCORE_TIME", PyFloat_FromDouble(MULTI_LANE_HALF_SCORE_TIME)); diff --git a/pufferlib/ocean/env_config.h b/pufferlib/ocean/env_config.h index e0b398041f..332410b788 100644 --- a/pufferlib/ocean/env_config.h +++ b/pufferlib/ocean/env_config.h @@ -105,6 +105,8 @@ static int handler(void *config, const char *section, const char *name, const ch env_config->target_type = 0; // TARGET_STATIC } else if (strcmp(value, "\"dynamic\"") == 0 || strcmp(value, "dynamic") == 0) { env_config->target_type = 1; // TARGET_DYNAMIC + } else if (strcmp(value, "\"dijkstra\"") == 0 || strcmp(value, "dijkstra") == 0) { + env_config->target_type = 2; // TARGET_DIJKSTRA } else { printf("Warning: Unknown target_type value '%s', defaulting to static\n", value); env_config->target_type = 0; diff --git a/pufferlib/ocean/torch.py b/pufferlib/ocean/torch.py index d2e3d885b2..c0e4f06cb7 100644 --- a/pufferlib/ocean/torch.py +++ b/pufferlib/ocean/torch.py @@ -191,6 +191,73 @@ def forward(self, observations, ego_dim): concat_features = torch.cat(feature_list, dim=1) return self.backbone(concat_features) + def pool_slot_counts(self, observations, ego_dim): + partner_dim = self.obs_slots_partners_n * self.partner_features_count + lane_dim = self.obs_slots_lane_kept * self.road_features_count + boundary_dim = self.obs_slots_boundary_kept * self.road_features_count + traffic_control_dim = self.obs_slots_traffic_controls_n * self.traffic_control_features_count + + slide_idx = ego_dim + self.conditioning_dim + partner_observations = observations[:, slide_idx : slide_idx + partner_dim] + slide_idx += partner_dim + lane_observations = observations[:, slide_idx : slide_idx + lane_dim] + slide_idx += lane_dim + boundary_observations = observations[:, slide_idx : slide_idx + boundary_dim] + slide_idx += boundary_dim + traffic_control_observations = observations[:, slide_idx : slide_idx + traffic_control_dim] + + counts = {} + if self.obs_slots_lane_kept > 0: + lane_objects = lane_observations.view(-1, self.obs_slots_lane_kept, self.road_features_count) + lane_winners = self.lane_encoder(lane_objects).max(dim=1).indices + lane_counts = torch.zeros( + observations.shape[0], self.obs_slots_lane_kept, device=observations.device, dtype=torch.int64 + ) + counts["pool_lane"] = lane_counts.scatter_add(1, lane_winners, torch.ones_like(lane_winners)) + if self.obs_slots_boundary_kept > 0: + boundary_objects = boundary_observations.view(-1, self.obs_slots_boundary_kept, self.road_features_count) + boundary_winners = self.boundary_encoder(boundary_objects).max(dim=1).indices + boundary_counts = torch.zeros( + observations.shape[0], self.obs_slots_boundary_kept, device=observations.device, dtype=torch.int64 + ) + counts["pool_boundary"] = boundary_counts.scatter_add( + 1, boundary_winners, torch.ones_like(boundary_winners) + ) + if self.obs_slots_partners_n > 0: + partner_objects = partner_observations.view(-1, self.obs_slots_partners_n, self.partner_features_count) + partner_winners = self.partner_encoder(partner_objects).max(dim=1).indices + partner_counts = torch.zeros( + observations.shape[0], self.obs_slots_partners_n, device=observations.device, dtype=torch.int64 + ) + counts["pool_partner"] = partner_counts.scatter_add(1, partner_winners, torch.ones_like(partner_winners)) + if self.obs_slots_traffic_controls_n > 0: + traffic_control_objects = traffic_control_observations.view( + -1, self.obs_slots_traffic_controls_n, self.traffic_control_features_count + ) + traffic_control_continuous = traffic_control_objects[:, :, : self.traffic_control_continuous_features] + traffic_control_type = traffic_control_objects[:, :, self.traffic_control_continuous_features] + traffic_control_state = traffic_control_objects[:, :, self.traffic_control_continuous_features + 1] + traffic_control_type_onehot = F.one_hot( + traffic_control_type.long(), + num_classes=binding.NUM_TRAFFIC_CONTROL_TYPES, + ).to(traffic_control_continuous.dtype) + traffic_control_state_onehot = F.one_hot( + traffic_control_state.long(), + num_classes=binding.NUM_TRAFFIC_CONTROL_STATES, + ).to(traffic_control_continuous.dtype) + traffic_control_objects = torch.cat( + [traffic_control_continuous, traffic_control_type_onehot, traffic_control_state_onehot], + dim=2, + ) + traffic_control_winners = self.traffic_control_encoder(traffic_control_objects).max(dim=1).indices + traffic_control_counts = torch.zeros( + observations.shape[0], self.obs_slots_traffic_controls_n, device=observations.device, dtype=torch.int64 + ) + counts["pool_traffic"] = traffic_control_counts.scatter_add( + 1, traffic_control_winners, torch.ones_like(traffic_control_winners) + ) + return counts + class Drive(nn.Module): def __init__( @@ -295,6 +362,9 @@ def forward_train(self, x, state=None): def forward_eval(self, x, state=None): return self.forward(x, state) + def pool_slot_counts(self, observations, state=None): + return self.actor_backbone.pool_slot_counts(observations, self.ego_dim) + # Required for PufferLib recurrent wrappers def encode_observations(self, observations, state=None): assert not self.split_network, "LSTM wrapper doesn't support split_network=True" diff --git a/pufferlib/resources/drive/binaries/carla/opendrive__Town01.bin b/pufferlib/resources/drive/binaries/carla/opendrive__Town01.bin index 7cdff08dd0..242e0a252a 100644 Binary files a/pufferlib/resources/drive/binaries/carla/opendrive__Town01.bin and b/pufferlib/resources/drive/binaries/carla/opendrive__Town01.bin differ diff --git a/pufferlib/resources/drive/binaries/carla/opendrive__Town02.bin b/pufferlib/resources/drive/binaries/carla/opendrive__Town02.bin index 5eed725922..0c1f829abb 100644 Binary files a/pufferlib/resources/drive/binaries/carla/opendrive__Town02.bin and b/pufferlib/resources/drive/binaries/carla/opendrive__Town02.bin differ diff --git a/pufferlib/resources/drive/binaries/carla/opendrive__Town03.bin b/pufferlib/resources/drive/binaries/carla/opendrive__Town03.bin index e5db89ffdd..30a47d4fcb 100644 Binary files a/pufferlib/resources/drive/binaries/carla/opendrive__Town03.bin and b/pufferlib/resources/drive/binaries/carla/opendrive__Town03.bin differ diff --git a/pufferlib/resources/drive/binaries/carla/opendrive__Town04.bin b/pufferlib/resources/drive/binaries/carla/opendrive__Town04.bin index 38d42572f1..0dcbabfcab 100644 Binary files a/pufferlib/resources/drive/binaries/carla/opendrive__Town04.bin and b/pufferlib/resources/drive/binaries/carla/opendrive__Town04.bin differ diff --git a/pufferlib/resources/drive/binaries/carla/opendrive__Town05.bin b/pufferlib/resources/drive/binaries/carla/opendrive__Town05.bin index 4836a5dc37..e7d344333a 100644 Binary files a/pufferlib/resources/drive/binaries/carla/opendrive__Town05.bin and b/pufferlib/resources/drive/binaries/carla/opendrive__Town05.bin differ diff --git a/pufferlib/resources/drive/binaries/carla/opendrive__Town06.bin b/pufferlib/resources/drive/binaries/carla/opendrive__Town06.bin index bfe8c024f5..506ed84d90 100644 Binary files a/pufferlib/resources/drive/binaries/carla/opendrive__Town06.bin and b/pufferlib/resources/drive/binaries/carla/opendrive__Town06.bin differ diff --git a/pufferlib/resources/drive/binaries/carla/opendrive__Town07.bin b/pufferlib/resources/drive/binaries/carla/opendrive__Town07.bin index 1732c594fd..ea45f1cd08 100644 Binary files a/pufferlib/resources/drive/binaries/carla/opendrive__Town07.bin and b/pufferlib/resources/drive/binaries/carla/opendrive__Town07.bin differ diff --git a/pufferlib/resources/drive/binaries/carla/opendrive__Town10HD.bin b/pufferlib/resources/drive/binaries/carla/opendrive__Town10HD.bin index bd913dee5a..a7da5036e8 100644 Binary files a/pufferlib/resources/drive/binaries/carla/opendrive__Town10HD.bin and b/pufferlib/resources/drive/binaries/carla/opendrive__Town10HD.bin differ diff --git a/pufferlib/viz.py b/pufferlib/viz.py index 1b698d1806..e13fae71c1 100644 --- a/pufferlib/viz.py +++ b/pufferlib/viz.py @@ -1,10 +1,9 @@ """Bird's Eye View visualization for PufferDrive scenarios using Matplotlib.""" import dataclasses -import weakref from typing import Optional, Tuple -import math + import re import matplotlib.figure import matplotlib.patches as mpatches @@ -16,6 +15,7 @@ import json import zlib import base64 +import struct from pufferlib.ocean.drive import binding from pufferlib.ocean.drive.drive import compute_effective_road_obs_count @@ -63,12 +63,6 @@ "#9EDAE5", ] -_figure_cache: weakref.WeakValueDictionary = weakref.WeakValueDictionary() -_map_cache = {} - -MULTI_LANE_FULL_SCORE_TIME = binding.MULTI_LANE_FULL_SCORE_TIME -MULTI_LANE_HALF_SCORE_TIME = binding.MULTI_LANE_HALF_SCORE_TIME - METRIC_LABELS = [ "collision", "offroad", @@ -100,22 +94,13 @@ class VizConfig: figsize: Tuple[float, float] = (20.0, 20.0) dpi: int = 100 show_agent_id: bool = True - show_routes: bool = False show_goal: bool = True - show_sdc_paths: bool = False - show_trajectories: bool = False goal_radius: float = 2.0 - follow_ego: bool = False - debug_metrics: bool = False - reuse_figure: bool = True def get_bounds(self, scenario) -> Tuple[float, float, float, float]: map_corners = scenario.get("map_corners") - if self.follow_ego: - ego_agent = scenario.get("agents")[-1] - cx, cy = ego_agent["sim_x"], ego_agent["sim_y"] - elif self.center is not None: + if self.center is not None: cx, cy = self.center elif map_corners and len(map_corners) >= 4: cx, cy = (map_corners[0] + map_corners[2]) / 2, (map_corners[1] + map_corners[3]) / 2 @@ -154,6 +139,10 @@ def _scale_ratio(numerator, denominator, default=1.0): return default if denominator == 0 else float(numerator) / float(denominator) +def _is_empty_obs_row(row): + return np.all(row == 0) or np.all(row == PADDED_OBSERVATION_VALUE) + + def _obs_scales( env_cfg=None, obs_norm_goal_offset_m=100.0, @@ -181,45 +170,19 @@ def _obs_scales( } -def _init_fig_ax(config: VizConfig, reuse_key: str = None, with_metrics: bool = False): - cache_key = f"{reuse_key}_{'metrics' if with_metrics else 'single'}" if reuse_key else None - - if config.reuse_figure and cache_key and cache_key in _figure_cache: - fig = _figure_cache[cache_key] - if fig and plt.fignum_exists(fig.number): - for ax in fig.axes: - ax.clear() - ax.set_facecolor(COLORS["background"]) - if with_metrics: - return fig, fig.axes[0], fig.axes[1] - return fig, fig.axes[0] - - if with_metrics: - fig, (ax_main, ax_metrics) = plt.subplots( - 1, 2, figsize=(config.figsize[0] * 1.5, config.figsize[1]), gridspec_kw={"width_ratios": [2, 1]} - ) - else: - fig, ax_main = plt.subplots() - fig.set_size_inches(config.figsize) - ax_metrics = None +def _init_fig_ax(config: VizConfig): + fig, ax_main = plt.subplots() + fig.set_size_inches(config.figsize) fig.set_dpi(config.dpi) fig.set_facecolor(COLORS["background"]) ax_main.set_facecolor(COLORS["background"]) - if ax_metrics: - ax_metrics.set_facecolor(COLORS["background"]) - if config.reuse_figure and cache_key: - _figure_cache[cache_key] = fig - - if with_metrics: - return fig, ax_main, ax_metrics return fig, ax_main -def _build_road_cache(road_elements): +def _build_road_data(road_elements): lanes, lines, edges = [], [], [] - lane_dict = {} for elem in road_elements or []: if not isinstance(elem, dict): continue @@ -229,9 +192,6 @@ def _build_road_cache(road_elements): pts = np.column_stack((np.asarray(x), np.asarray(y))) if 1 <= t <= 3: lanes.append(pts) - lid = elem.get("id") - if lid is not None: - lane_dict[lid] = pts elif 11 <= t <= 18: lines.append(pts) elif 21 <= t <= 23: @@ -240,41 +200,33 @@ def _build_road_cache(road_elements): "lanes": lanes, "lines": lines, "edges": edges, - "lane_dict": lane_dict, - "collections": None, } -def _render_roads(ax, road_cache): - if not road_cache: +def _render_roads(ax, road_data): + if not road_data: return - collections = road_cache.get("collections") - if collections is None: - collections = [] - lanes = road_cache.get("lanes") or [] - lines = road_cache.get("lines") or [] - edges = road_cache.get("edges") or [] - if lanes: - collections.append(LineCollection(lanes, colors=COLORS["lane"], linewidths=0.8, alpha=0.7, zorder=1)) - if lines: - collections.append( - LineCollection( - lines, - colors=COLORS["road_line"], - linewidths=0.8, - alpha=0.6, - linestyles=(0, (5, 5)), - zorder=2, - ) + lanes = road_data.get("lanes") or [] + lines = road_data.get("lines") or [] + edges = road_data.get("edges") or [] + if lanes: + ax.add_collection(LineCollection(lanes, colors=COLORS["lane"], linewidths=0.8, alpha=0.7, zorder=1)) + if lines: + ax.add_collection( + LineCollection( + lines, + colors=COLORS["road_line"], + linewidths=0.8, + alpha=0.6, + linestyles=(0, (5, 5)), + zorder=2, ) - if edges: - collections.append(LineCollection(edges, colors=COLORS["road_edge"], linewidths=0.8, alpha=0.8, zorder=2)) - road_cache["collections"] = collections - for collection in collections: - ax.add_collection(collection) + ) + if edges: + ax.add_collection(LineCollection(edges, colors=COLORS["road_edge"], linewidths=0.8, alpha=0.8, zorder=2)) -def _build_traffic_cache(traffic_elements): +def _build_traffic_data(traffic_elements): traffic_lights = [] # (stop_line, states) stop_signs = [] # stop_line endpoints yield_signs = [] # stop_line endpoints @@ -299,11 +251,11 @@ def _build_traffic_cache(traffic_elements): } -def _render_traffic(ax, traffic_cache, timestep): - if not traffic_cache: +def _render_traffic(ax, traffic_data, timestep): + if not traffic_data: return # Traffic lights — colored by state - for light in traffic_cache.get("traffic_lights", []): + for light in traffic_data.get("traffic_lights", []): sl = light["stop_line"] states = light["states"] state = int(states[timestep]) if states and len(states) > timestep else 0 @@ -311,7 +263,7 @@ def _render_traffic(ax, traffic_cache, timestep): ax.plot([sl[0], sl[3]], [sl[1], sl[4]], color=color, linewidth=3, solid_capstyle="butt", alpha=0.9, zorder=15) # Stop signs — red/black striped - for sl in traffic_cache.get("stop_signs", []): + for sl in traffic_data.get("stop_signs", []): ax.plot([sl[0], sl[3]], [sl[1], sl[4]], color="black", linewidth=4, solid_capstyle="butt", alpha=0.9, zorder=15) ax.plot( [sl[0], sl[3]], @@ -325,7 +277,7 @@ def _render_traffic(ax, traffic_cache, timestep): ) # Yield signs — yellow/black striped - for sl in traffic_cache.get("yield_signs", []): + for sl in traffic_data.get("yield_signs", []): ax.plot([sl[0], sl[3]], [sl[1], sl[4]], color="black", linewidth=4, solid_capstyle="butt", alpha=0.9, zorder=15) ax.plot( [sl[0], sl[3]], @@ -339,29 +291,6 @@ def _render_traffic(ax, traffic_cache, timestep): ) -def _render_routes(ax, agents, lane_dict, active_indices): - if not agents or not lane_dict: - return - - active_set = set(active_indices or []) - segments_by_color = {} - for idx, agent in enumerate(agents): - if not isinstance(agent, dict) or idx not in active_set: - continue - route = agent.get("route", []) - if not route: - continue - color = get_agent_color(agent.get("id", idx)) - segs = segments_by_color.setdefault(color, []) - for lid in route: - if lid in lane_dict: - segs.append(lane_dict[lid]) - - for color, segs in segments_by_color.items(): - if segs: - ax.add_collection(LineCollection(segs, colors=color, linewidths=2.0, alpha=0.6, linestyles="--", zorder=5)) - - def _render_agents(ax, agents, active_indices, static_indices, config, px_per_meter): if not agents: return @@ -533,240 +462,15 @@ def _render_agents(ax, agents, active_indices, static_indices, config, px_per_me ax.add_collection(PatchCollection(cyclist_patches, match_original=True)) -def _render_paths(ax, scenario): - """Render SDC planned paths.""" - for idx in range(scenario["active_agent_count"]): - x = np.array([item["x"] for item in scenario["sdc_paths"][idx]["waypoints"]]) - y = np.array([item["y"] for item in scenario["sdc_paths"][idx]["waypoints"]]) - init_idx = scenario["agents"][scenario["active_agent_indices"][idx]]["closest_path_idx_wp"] - end_idx = min(init_idx + 20, scenario["sdc_paths"][idx]["num_waypoints"] - 1) - agent_id = scenario["agents"][scenario["active_agent_indices"][idx]]["id"] - color = get_agent_color(agent_id, is_active=True) - ax.scatter(x[init_idx:end_idx], y[init_idx:end_idx], color=color, s=20) - - -def _render_trajectories(ax, scenario): - for idx in range(scenario["active_agent_count"]): - wps = scenario["trajectory_waypoints_global"][idx]["waypoints"] - x = np.array([item["x"] for item in wps]) - y = np.array([item["y"] for item in wps]) - heading = np.array([item["heading"] for item in wps]) - ax.scatter(x, y, color=np.array([0, 100, 0]) / 255.0, s=20) - ax.quiver( - x, - y, - np.cos(heading), - np.sin(heading), - color=np.array([0, 100, 0]) / 255.0, - scale_units="xy", # Use data coordinates for scaling - scale=1.0, # A scale of 1.0 means arrows of length (U,V) are plotted as such - width=0.005, - ) - - -def _render_debug_metrics_table(ax, agents, active_agent_indices, px_per_meter=10.0): - """Render a table of per-agent metrics for debugging.""" - font_size = max(10, int(px_per_meter / 5)) - - if not agents or not active_agent_indices: - ax.text(0.5, 0.5, "No active agents", ha="center", va="center", fontsize=font_size) - ax.axis("off") - return - - active_set = set(active_agent_indices) - - # Gather metrics for active agents - metrics_data = [] - for idx, agent in enumerate(agents): - if idx not in active_set: - continue - agent_id = agent["id"] - vx, vy = agent.get("sim_vx", 0), agent.get("sim_vy", 0) - speed = np.sqrt(vx**2 + vy**2) - current_lane_id = agent.get("current_lane_idx", -1) - metrics = agent.get("metrics_array", [0.0] * len(METRIC_LABELS)) - metrics_data.append( - { - "id": agent_id, - "current_lane": current_lane_id, - "speed": speed, - "lane_dist": metrics[5], - "lane_head": metrics[6], - "offroad": metrics[1], - "collision": metrics[0], - "comfort": metrics[7], - "red_light": metrics[2], - "at_fault": metrics[12], - "ttc": metrics[13], - "ttc_tfl": metrics[14], - "progress": metrics[15], - "ml_time": metrics[16] if len(metrics) > 16 else 0.0, - "color": get_agent_color(agent_id, is_active=True), - } - ) - - if not metrics_data: - ax.text(0.5, 0.5, "No active agents", ha="center", va="center", fontsize=font_size) - ax.axis("off") - return - - ax.axis("off") - ax.set_xlim(0, 1) - ax.set_ylim(0, 1) - - # Remove margins - ax.margins(0) - - # Table headers - headers = [ - "ID", - "Lane", - "LDist", - "LHead", - "Spd", - "Cmft", - "Off", - "Col", - "Red", - "AF", - "TTC", - "TTC_TFL", - "Prog", - "MLt", - ] - num_agents = len(metrics_data) - y_start, y_end = 0.95, 0.05 - row_height = min(0.06, (y_start - y_end) / (num_agents + 2)) - x_positions = np.linspace(0.02, 0.96, len(headers)) - for i, header in enumerate(headers): - ax.text(x_positions[i], y_start, header, fontsize=font_size + 2, fontweight="bold", va="top") - - for row_idx, data in enumerate(metrics_data): - y_pos = y_start - (row_idx + 1) * row_height - ax.text( - x_positions[0], y_pos, str(data["id"]), fontsize=font_size, color=data["color"], fontweight="bold", va="top" - ) - ax.text(x_positions[1], y_pos, f"{data['current_lane']:.0f}", fontsize=font_size, va="top") - ax.text(x_positions[2], y_pos, f"{data['lane_dist']:.2f}", fontsize=font_size, va="top") - ax.text(x_positions[3], y_pos, f"{data['lane_head']:.2f}", fontsize=font_size, va="top") - ax.text(x_positions[4], y_pos, f"{data['speed']:.1f}", fontsize=font_size, va="top") - ax.text( - x_positions[5], - y_pos, - f"{data['comfort']:.1f}", - fontsize=font_size, - color="red" if data["comfort"] > 0 else "green", - va="top", - ) - ax.text( - x_positions[6], - y_pos, - "+" if data["offroad"] else "-", - fontsize=font_size, - color="red" if data["offroad"] else "green", - va="top", - ) - ax.text( - x_positions[7], - y_pos, - "+" if data["collision"] else "-", - fontsize=font_size, - color="red" if data["collision"] else "green", - va="top", - ) - ax.text( - x_positions[8], - y_pos, - "+" if data["red_light"] else "-", - fontsize=font_size, - color="red" if data["red_light"] else "green", - va="top", - ) - ax.text( - x_positions[9], - y_pos, - "+" if data["at_fault"] else "-", - fontsize=font_size, - color="red" if data["at_fault"] else "green", - va="top", - ) - ax.text( - x_positions[10], - y_pos, - f"{data['ttc']:.2f}", - fontsize=font_size, - color="red" if data["ttc"] < 0.95 else "green", - va="top", - ) - ax.text( - x_positions[11], - y_pos, - f"{data['ttc_tfl']:.2f}", - fontsize=font_size, - color="red" if data["ttc_tfl"] < 0.95 else "green", - va="top", - ) - ax.text( - x_positions[12], - y_pos, - f"{data['progress']:.2f}", - fontsize=font_size, - color="green" if data["progress"] > 0.2 else "red", - va="top", - ) - ax.text( - x_positions[13], - y_pos, - f"{data['ml_time']:.1f}", - fontsize=font_size, - color="red" if data["ml_time"] > MULTI_LANE_FULL_SCORE_TIME else "green", - va="top", - ) - - ax.set_title("Active Agent Metrics + V-Max", fontsize=font_size + 4, fontweight="bold", pad=10) - - -def _get_cache_key(reuse_key): - return reuse_key - - -def _get_or_build_map_cache(cache_key, scenario): - if cache_key: - cache = _map_cache.get(cache_key) - map_name = scenario.get("map_name") - if cache and cache.get("map_name") == map_name: - return cache - road_cache = _build_road_cache(scenario.get("road_elements", [])) - traffic_cache = _build_traffic_cache(scenario.get("traffic_elements", [])) - cache = { - "map_name": map_name, - "road": road_cache, - "traffic": traffic_cache, - } - _map_cache[cache_key] = cache - return cache - - return { +def plot_simulator_state(scenario, timestep: int = 0) -> np.ndarray: + """Render simulator state to RGB image array.""" + vis_config = VizConfig() + map_data = { "map_name": scenario.get("map_name"), - "road": _build_road_cache(scenario.get("road_elements", [])), - "traffic": _build_traffic_cache(scenario.get("traffic_elements", [])), + "road": _build_road_data(scenario.get("road_elements", [])), + "traffic": _build_traffic_data(scenario.get("traffic_elements", [])), } - -def plot_simulator_state( - scenario, - timestep: int = 0, - show_trajectories: bool = False, - simulation_mode: str = None, - reuse_key: str = None, -) -> np.ndarray: - """Render simulator state to RGB image array.""" - vis_radius = None if simulation_mode == "gigaflow" or simulation_mode is None else 75.0 - vis_config = VizConfig(radius=vis_radius, show_trajectories=show_trajectories) - - cache_key = _get_cache_key(reuse_key) - map_cache = _get_or_build_map_cache(cache_key, scenario) - bounds = vis_config.get_bounds(scenario) x_min, x_max, y_min, y_max = bounds @@ -775,11 +479,7 @@ def plot_simulator_state( vis_config.figsize[1] * vis_config.dpi / (y_max - y_min), ) - if vis_config.debug_metrics: - fig, ax, ax_metrics = _init_fig_ax(vis_config, cache_key, with_metrics=True) - else: - fig, ax = _init_fig_ax(vis_config, cache_key, with_metrics=False) - ax_metrics = None + fig, ax = _init_fig_ax(vis_config) ax.set_aspect("equal") ax.set_title( @@ -788,19 +488,8 @@ def plot_simulator_state( fontweight="bold", ) - _render_roads(ax, map_cache.get("road")) - _render_traffic(ax, map_cache.get("traffic"), timestep) - if vis_config.show_routes: - _render_routes( - ax, - scenario.get("agents", []), - map_cache.get("road", {}).get("lane_dict"), - scenario.get("active_agent_indices", []), - ) - if vis_config.show_sdc_paths: - _render_paths(ax, scenario) - if vis_config.show_trajectories and timestep > 0: - _render_trajectories(ax, scenario) + _render_roads(ax, map_data.get("road")) + _render_traffic(ax, map_data.get("traffic"), timestep) _render_agents( ax, @@ -814,16 +503,7 @@ def plot_simulator_state( ax.set_xlim(x_min, x_max) ax.set_ylim(y_min, y_max) - if vis_config.debug_metrics and ax_metrics: - _render_debug_metrics_table( - ax_metrics, - scenario.get("agents", []), - scenario.get("active_agent_indices", []), - px_per_meter=px_per_meter, - ) - - close_fig = not (vis_config.reuse_figure and cache_key) - return _img_from_fig(fig, close=close_fig) + return _img_from_fig(fig) def _img_from_fig(fig: matplotlib.figure.Figure, close: bool = True) -> np.ndarray: @@ -836,23 +516,11 @@ def _img_from_fig(fig: matplotlib.figure.Figure, close: bool = True) -> np.ndarr return img -def close_figure(reuse_key: str): - if not reuse_key: - return - for suffix in ("single", "metrics"): - cache_key = f"{reuse_key}_{suffix}" - fig = _figure_cache.pop(cache_key, None) - if fig and plt.fignum_exists(fig.number): - plt.close(fig) - _map_cache.pop(reuse_key, None) - - def unpack_obs( obs_flat, - dynamics_model: int = 0, target_type: str = "static", reward_conditioning: bool = False, - num_target_waypoints: int = 5, + num_target_waypoints: int = 3, max_partners: int = 16, max_lane_segments: int = 16, max_boundary_segments: int = 16, @@ -865,7 +533,6 @@ def unpack_obs( Unpack the flattened observation into ego, map, partner, and traffic-control views. Args: obs_flat: flattened observation tensor of shape (batch_size, obs_dim) or (obs_dim,) - dynamics_model: 0 for CLASSIC, 1 for JERK Return: ego_state, target_obs, partners_obs, lane_obs, boundary_obs, traffic_controls_obs """ @@ -885,7 +552,7 @@ def unpack_obs( boundary_segment_count = compute_effective_road_obs_count(max_boundary_segments, obs_dropout_boundary) # Target obs - target_features = binding.STATIC_TARGET_FEATURES if target_type == "static" else binding.DYNAMIC_TARGET_FEATURES + target_features = binding.DYNAMIC_TARGET_FEATURES if target_type == "dynamic" else binding.STATIC_TARGET_FEATURES target_dim = num_target_waypoints * target_features # Extract ego state @@ -902,20 +569,29 @@ def unpack_obs( # Extract partners partners_start = target_end partners_end = partners_start + max_partners * partner_feature_size - partners_obs = obs_flat[:, partners_start:partners_end] - partners_obs = partners_obs.reshape(-1, max_partners, partner_feature_size) + if max_partners > 0: + partners_obs = obs_flat[:, partners_start:partners_end] + partners_obs = partners_obs.reshape(-1, max_partners, partner_feature_size) + else: + partners_obs = np.zeros((obs_flat.shape[0], 0, partner_feature_size)) # Extract lane elements lane_start = partners_end lane_end = lane_start + lane_segment_count * road_feature_size - lane_obs = obs_flat[:, lane_start:lane_end] - lane_obs = lane_obs.reshape(-1, lane_segment_count, road_feature_size) + if lane_segment_count > 0: + lane_obs = obs_flat[:, lane_start:lane_end] + lane_obs = lane_obs.reshape(-1, lane_segment_count, road_feature_size) + else: + lane_obs = np.zeros((obs_flat.shape[0], 0, road_feature_size)) # Extract boundary elements boundary_start = lane_end boundary_end = boundary_start + boundary_segment_count * road_feature_size - boundary_obs = obs_flat[:, boundary_start:boundary_end] - boundary_obs = boundary_obs.reshape(-1, boundary_segment_count, road_feature_size) + if boundary_segment_count > 0: + boundary_obs = obs_flat[:, boundary_start:boundary_end] + boundary_obs = boundary_obs.reshape(-1, boundary_segment_count, road_feature_size) + else: + boundary_obs = np.zeros((obs_flat.shape[0], 0, road_feature_size)) # Extract traffic controls traffic_start = boundary_end @@ -940,10 +616,9 @@ def unpack_obs( def plot_observation( obs, - dynamics_model="classic", target_type="static", reward_conditioning=False, - num_target_waypoints=10, + num_target_waypoints=3, max_partners=16, max_lane_segments=32, max_boundary_segments=32, @@ -962,24 +637,25 @@ def plot_observation( Args: obs: flattened observation tensor - dynamics_model: 0 for CLASSIC, 1 for JERK target_type: 0 for goal only, 1 for waypoints only, 2 for both """ fig, ax = plt.subplots(figsize=(20, 20)) + dynamics_model = _dynamics_model_name(dynamics_model) + target_type = _target_type_name(target_type) + num_target_waypoints = _target_waypoint_count(target_type, num_target_waypoints) ego_state, target_obs, partners_obs, lane_obs, boundary_obs, traffic_controls_obs = unpack_obs( obs, - dynamics_model, - target_type, - reward_conditioning, - num_target_waypoints, - max_partners, - max_lane_segments, - max_boundary_segments, - obs_slots_traffic_controls_n, - obs_dropout_lane, - obs_dropout_boundary, - agent_idx, + target_type=target_type, + reward_conditioning=reward_conditioning, + num_target_waypoints=num_target_waypoints, + max_partners=max_partners, + max_lane_segments=max_lane_segments, + max_boundary_segments=max_boundary_segments, + obs_slots_traffic_controls_n=obs_slots_traffic_controls_n, + obs_dropout_lane=obs_dropout_lane, + obs_dropout_boundary=obs_dropout_boundary, + agent_idx=agent_idx, ) scales = _obs_scales( obs_norm_goal_offset_m=obs_norm_goal_offset_m, @@ -989,7 +665,7 @@ def plot_observation( obs_norm_road_seg_length_m=obs_norm_road_seg_length_m, obs_norm_road_seg_width_m=obs_norm_road_seg_width_m, ) - target_position_scale = scales["goal_to_position"] if target_type == "static" else 1.0 + target_position_scale = scales["goal_to_position"] if target_type != "dynamic" else 1.0 ego_speed, ego_width, ego_length, steering_angle, a_long, a_lat, lcenter, lalign, speed_limit, _ = ego_state @@ -1002,26 +678,13 @@ def plot_observation( (-ego_length / 2, -ego_width / 2), ego_length, ego_width, - facecolor="#0055FF", - edgecolor="#FFD700", - linewidth=4, - alpha=0.9, + facecolor="blue", + edgecolor="black", + linewidth=2, + alpha=0.7, zorder=10, ) ) - # SDC label above the vehicle - ax.text( - 0, - ego_width / 2 + 0.03, - "SDC", - ha="center", - va="bottom", - fontsize=11, - fontweight="bold", - color="#FFD700", - bbox=dict(boxstyle="round,pad=0.2", facecolor="#0055FF", edgecolor="#FFD700", linewidth=1.5), - zorder=11, - ) # Draw target waypoints for i in range(target_obs.shape[0]): @@ -1029,7 +692,7 @@ def plot_observation( continue wp_x = target_obs[i][0] * target_position_scale wp_y = target_obs[i][1] * target_position_scale - if target_type == "static": + if target_type != "dynamic": color = "red" if i == 0 else "orange" marker = "*" if i == 0 else "o" s = 200 if i == 0 else 80 @@ -1040,7 +703,7 @@ def plot_observation( ax.scatter(wp_x, wp_y, color=color, marker=marker, s=s, zorder=15) # Add dynamics info text for JERK model - ego_info = f"Speed: {ego_speed:.2f}\nLane Centering: {lcenter:.2f}\nLane Align: {lalign:.2f}\nSpeed Limit: {speed_limit:.2f}" + ego_info = f"Speed: {ego_speed:.2f}\nSteering: {steering_angle:.3f}\nLane Centering: {lcenter:.2f}\nLane Align: {lalign:.2f}\nSpeed Limit: {speed_limit:.2f}\nStopped: {stopped:.2f}" ego_info += f"\nSteering: {steering_angle:.3f}\naccel_long: {a_long:.2f}\naccel_lat: {a_lat:.2f}" @@ -1056,7 +719,7 @@ def plot_observation( # Partner agents for i in range(partners_obs.shape[0]): - if np.all(partners_obs[i] == 0): + if _is_empty_obs_row(partners_obs[i]): continue rel_x, rel_y = partners_obs[i][0], partners_obs[i][1] length = partners_obs[i][3] * scales["veh_len_to_position"] @@ -1082,7 +745,7 @@ def plot_observation( rw2p = scales["road_width_to_position"] count_lane = 0 for i in range(lane_obs.shape[0]): - if np.all(lane_obs[i] == 0): + if _is_empty_obs_row(lane_obs[i]): continue count_lane += 1 rel_x, rel_y = lane_obs[i][0], lane_obs[i][1] @@ -1091,16 +754,15 @@ def plot_observation( color = "lightgrey" ax.scatter(rel_x, rel_y, color=color, s=10, zorder=1) ax.plot( - [rel_x + dir_cos * length / 2, rel_x - dir_cos * length / 2], - [rel_y + dir_sin * length / 2, rel_y - dir_sin * length / 2], + [rel_x + dir_cos * length, rel_x - dir_cos * length], + [rel_y + dir_sin * length, rel_y - dir_sin * length], color=color, linewidth=1, zorder=1, ) - count_boundary = 0 for i in range(boundary_obs.shape[0]): - if np.all(boundary_obs[i] == 0): + if _is_empty_obs_row(boundary_obs[i]): continue count_boundary += 1 rel_x, rel_y = boundary_obs[i][0], boundary_obs[i][1] @@ -1109,8 +771,8 @@ def plot_observation( color = "black" ax.scatter(rel_x, rel_y, color=color, s=10, zorder=1) ax.plot( - [rel_x + dir_cos * length / 2, rel_x - dir_cos * length / 2], - [rel_y + dir_sin * length / 2, rel_y - dir_sin * length / 2], + [rel_x + dir_cos * length, rel_x - dir_cos * length], + [rel_y + dir_sin * length, rel_y - dir_sin * length], color=color, linewidth=1, zorder=1, @@ -1128,10 +790,12 @@ def plot_observation( # Traffic controls for i in range(traffic_controls_obs.shape[0]): - if np.all(traffic_controls_obs[i] == 0): + if _is_empty_obs_row(traffic_controls_obs[i]): continue rel_x1, rel_y1, rel_x2, rel_y2, _, control_type, state = traffic_controls_obs[i] control_type = int(control_type) + if _traffic_control_kind(control_type) is None: + continue if control_type == binding.TRAFFIC_CONTROL_TYPE_TRAFFIC_LIGHT: ax.plot( [rel_x1, rel_x2], @@ -1174,1022 +838,480 @@ def plot_observation( return _img_from_fig(fig) -# HTML INTERACTIVE REPLAY -def fill_agents_state(scenario, use_trajectory=False): - current_agents_data = [] - active_indices = scenario.get("active_agent_indices", []) - - # Actions - if use_trajectory: - raw_actions = scenario.get("ctrl_trajectory_actions", []) - else: - raw_actions = scenario.get("actions", []) - action_map = {} - if raw_actions and len(raw_actions) == len(active_indices): - for i, agent_idx in enumerate(active_indices): - action_map[agent_idx] = raw_actions[i] - - for idx, agent in enumerate(scenario.get("agents", [])): - if not agent.get("sim_valid"): - continue - - agent_id = agent.get("id", idx) - is_active = idx in active_indices - - # Couleur - if agent.get("stopped", False): - color = "red" - else: - color = get_agent_color(agent_id, is_active) - req_acc = float(action_map[idx][0]) if idx in action_map else 0.0 - req_str = float(action_map[idx][1]) if idx in action_map else 0.0 - - # On arrondit tout pour alléger le JSON final - current_agents_data.append( - { - "id": int(agent_id), - "x": round(float(agent["sim_x"]), 2), - "y": round(float(agent["sim_y"]), 2), - "z": round(float(agent.get("sim_z", 0)), 2), - "h": round(float(agent["sim_heading"]), 3), - "l": round(float(agent["sim_length"]), 2), - "w": round(float(agent["sim_width"]), 2), - "s": round(float(agent.get("sim_speed", 0)), 2), - "st": round(float(agent.get("sim_steering", 0)), 3), - "c": color, - # Commandes - "ra": round(req_acc, 2), - "rs": round(req_str, 2), - # Compact metrics array (M1..M18) - "m": [round(float(m), 2) for m in agent.get("metrics_array")], - } - ) - - return current_agents_data - - -def fill_traffics_state(scenario, timestep): - current_traffic_data = [] - traffic_elements = scenario.get("traffic_elements", []) - for elem in traffic_elements or []: - if not isinstance(elem, dict): - continue - - t_type = elem.get("type", binding.TRAFFIC_CONTROL_TYPE_TRAFFIC_LIGHT) - sl = elem.get("stop_line") - if sl is None or len(sl) < 4: - continue - - kind = _traffic_control_kind(t_type) - if kind == "light": - states = elem.get("states", []) - state = int(states[timestep]) if states and len(states) > timestep else 0 - color = _traffic_light_color(state) - current_traffic_data.append({"type": "light", "stop_line": sl, "c": color}) - elif kind == "stop": - current_traffic_data.append({"type": "stop", "stop_line": sl, "c": "#FF0000", "c2": "#000000"}) - elif kind == "yield": - current_traffic_data.append({"type": "yield", "stop_line": sl, "c": "#FFD700", "c2": "#000000"}) - - return current_traffic_data - - -def fill_trajectories(scenario, timestep): - current_trajectories = [] - if timestep > 0: - traj_data = scenario.get("trajectory_waypoints_global", []) - active_count = scenario.get("active_agent_count", 0) - - # On itère seulement sur les agents actifs qui ont des trajectoires - for idx in range(min(len(traj_data), active_count)): - waypoints = traj_data[idx].get("waypoints", []) - pts = [] - for wp in waypoints: - pts.append([float(wp["x"]), float(wp["y"]), float(wp["heading"])]) - - current_trajectories.append(pts) - return current_trajectories - - -def extract_obs_frame(obs, scenario, args, timestep, obs_index=0, agent_idx=0, head_north=False): - ego_state, target_obs, partners_obs, lane_obs, boundary_obs, traffic_controls_obs = unpack_obs( - obs, - dynamics_model=args["env"]["dynamics_model"], - target_type=args["env"]["target_type"], - reward_conditioning=args["env"]["reward_conditioning"], - num_target_waypoints=args["env"]["num_target_waypoints"], - max_partners=args["env"]["obs_slots_partners_n"], - max_lane_segments=args["env"]["obs_slots_lane_n"], - max_boundary_segments=args["env"]["obs_slots_boundary_n"], - obs_slots_traffic_controls_n=args["env"]["obs_slots_traffic_controls_n"], - obs_dropout_lane=args["env"].get("obs_dropout_lane", 0.0), - obs_dropout_boundary=args["env"].get("obs_dropout_boundary", 0.0), - agent_idx=obs_index, - ) - scales = _obs_scales(args.get("env")) - target_position_scale = scales["goal_to_position"] if args["env"]["target_type"] == "static" else 1.0 - - # --- Rotation Helper --- - def _rot(x, y): - """Rotates coordinates 90 degrees CCW if head_north is True.""" - return (-y, x) if head_north else (x, y) - - # --- Parse Ego --- - if args["env"]["dynamics_model"] == "jerk": - ego_speed, ego_width, ego_length, steering_angle, a_long, a_lat = ego_state[:6] - else: - ego_speed, ego_width, ego_length = ego_state[:3] - steering_angle, a_long, a_lat = 0.0, 0.0, 0.0 - - ego_width *= scales["veh_width_to_position"] - ego_length *= scales["veh_len_to_position"] - - ego_data = { - "s": round(float(ego_speed), 3), - "w": round(float(ego_width), 3), - "l": round(float(ego_length), 3), - "st": round(float(steering_angle), 3), - "al": round(float(a_long), 3), - "alat": round(float(a_lat), 3), +def _pack_replay_binary(header, chunks): + packed = {} + blob_parts = [] + offset = 0 + dtype_names = { + np.dtype(np.float32): "float32", + np.dtype(np.int32): "int32", + np.dtype(np.int16): "int16", + np.dtype(np.uint8): "uint8", } - - # --- Parse Road Segments --- - rl2p = scales["road_length_to_position"] - rw2p = scales["road_width_to_position"] - - def parse_roads(roads): - res = [] - for r in roads: - if np.all(r == 0): - continue - x, y = r[0], r[1] - length, width = r[3] * rl2p, r[4] * rw2p - cos_a, sin_a = r[5], r[6] - if head_north: - x_rot, y_rot = _rot(x, y) - cos_rot, sin_rot = _rot(cos_a, sin_a) - else: - x_rot, y_rot = x, y - cos_rot, sin_rot = cos_a, sin_a - res.append( - [ - round(float(x_rot), 4), - round(float(y_rot), 4), - round(float(length), 4), - round(float(width), 4), - round(float(cos_rot), 4), - round(float(sin_rot), 4), - ] - ) - return res - - # --- Parse Partners --- - parsed_partners = [] - for p in partners_obs: - if np.all(p == 0): + for name, arr in chunks.items(): + arr = np.ascontiguousarray(arr) + dtype = dtype_names[arr.dtype] + raw = arr.tobytes() + packed[name] = {"dtype": dtype, "shape": list(arr.shape), "offset": offset, "nbytes": len(raw)} + blob_parts.append(raw) + offset += len(raw) + pad = (-offset) % 4 + if pad: + blob_parts.append(b"\0" * pad) + offset += pad + + header = dict(header) + header["chunks"] = packed + header_bytes = json.dumps(header, separators=(",", ":")).encode("utf-8") + pad = (-(4 + len(header_bytes))) % 4 + payload = struct.pack(" 0: - wps = scenario["trajectory_waypoints_local"][agent_idx]["waypoints"] - for wp in wps: - wx, wy = _rot(float(wp["x"]) / 70.0, float(wp["y"]) / 70.0) - traj_data.append({"x": round(wx, 4), "y": round(wy, 4)}) - - gps_data = [] - for g in target_obs: - if np.all(g == 0): + count = min(len(xs), len(ys)) + road_lengths.append(count) + road_types.append(draw_type) + for i in range(count): + road_points.append((float(xs[i]), float(ys[i]))) + + traffic_stop_lines = [] + traffic_types = [] + for elem in scenario.get("traffic_elements", []) or []: + if not isinstance(elem, dict): continue - gx, gy = _rot(g[0] * target_position_scale, g[1] * target_position_scale) - gps_data.append([round(float(gx), 3), round(float(gy), 3)]) + stop_line = elem.get("stop_line") or [0, 0, 0, 0, 0, 0] + traffic_stop_lines.append([float(v) for v in stop_line[:6]]) + traffic_types.append(int(elem.get("type", 0))) + + env_cfg = replay["env"] + scales = _obs_scales(env_cfg) + lane_count = compute_effective_road_obs_count(env_cfg["obs_slots_lane_n"], env_cfg.get("obs_dropout_lane", 0.0)) + boundary_count = compute_effective_road_obs_count( + env_cfg["obs_slots_boundary_n"], env_cfg.get("obs_dropout_boundary", 0.0) + ) - return { - "ego": ego_data, - "partners": parsed_partners, - "lanes": parse_roads(lane_obs), - "bounds": parse_roads(boundary_obs), - "traffic_controls": parsed_traffic_controls, - "traj": traj_data, - "gps": gps_data, + chunks = { + "road_points": np.asarray(road_points or [(0.0, 0.0)], dtype=np.float32), + "road_lengths": np.asarray(road_lengths or [0], dtype=np.int32), + "road_types": np.asarray(road_types or [0], dtype=np.int16), + "traffic_stop_lines": np.asarray(traffic_stop_lines or [[0, 0, 0, 0, 0, 0]], dtype=np.float32), + "traffic_types": np.asarray(traffic_types or [0], dtype=np.int16), + "agent_f32": replay["agent_f32"].astype(np.float32, copy=False), + "agent_i32": replay["agent_i32"].astype(np.int32, copy=False), + "metrics_f32": replay["metrics_f32"].astype(np.float32, copy=False), + "puffer_f32": replay["puffer_f32"].astype(np.float32, copy=False), + "traffic_i16": replay["traffic_i16"].astype(np.int16, copy=False), + "obs": replay["obs"].astype(np.float32, copy=False), + "raw_action": replay["raw_action"].astype(np.float32, copy=False), + "clipped_action": replay["clipped_action"].astype(np.float32, copy=False), + "value": replay["value"].astype(np.float32, copy=False), + "entropy": replay["entropy"].astype(np.float32, copy=False), } - - -def generate_interactive_replay( - scenario, - agent_history, - traffic_history, - trajectory_history, - all_agents_obs_history, - filename="replay.html", - head_north=False, -): - # --- 0. COMPRESSION HELPER --- - def pack_and_compress_data(data, decimals=3): - # Recursively round all floats to save string space - def round_floats(o): - if isinstance(o, float): - return round(o, decimals) - if isinstance(o, dict): - return {k: round_floats(v) for k, v in o.items()} - if isinstance(o, (list, tuple)): - return [round_floats(v) for v in o] - return o - - # Dump without whitespace - compact_json = json.dumps(round_floats(data), separators=(",", ":")) - - # Compress using zlib (deflate) - compressed_bytes = zlib.compress(compact_json.encode("utf-8")) - - # Return as Base64 string for safe HTML embedding - return base64.b64encode(compressed_bytes).decode("ascii") - - # --- 1. METADATA --- - raw_dyn = scenario.get("dynamics_model", 0) - dyn_str = "Jerk" if raw_dyn == 1 else "Classic" + if replay.get("policy_probs") is not None: + chunks["policy_probs"] = replay["policy_probs"].astype(np.float32, copy=False) + if replay.get("policy_mean") is not None: + chunks["policy_mean"] = replay["policy_mean"].astype(np.float32, copy=False) + chunks["policy_std"] = replay["policy_std"].astype(np.float32, copy=False) + chunks["policy_log_prob"] = replay["policy_log_prob"].astype(np.float32, copy=False) + for pool_name in ("pool_partner", "pool_lane", "pool_boundary", "pool_traffic"): + if replay.get(pool_name) is not None: + chunks[pool_name] = replay[pool_name].astype(np.int16, copy=False) + + target_type = scenario.get("target_type", env_cfg.get("target_type", "static")) + if target_type == binding.TARGET_STATIC: + target_type = "static" + elif target_type == binding.TARGET_DYNAMIC: + target_type = "dynamic" + elif target_type == binding.TARGET_DIJKSTRA: + target_type = "dijkstra" + target_features = binding.DYNAMIC_TARGET_FEATURES if target_type == "dynamic" else binding.STATIC_TARGET_FEATURES metadata = { "map_name": scenario.get("map_name", "Unknown"), "scenario_id": scenario.get("scenario_id", "Unknown"), - "dynamics_model": dyn_str, - "target_type": scenario.get("target_type", "static"), - "active_indices": str(scenario.get("active_agent_indices", [])), + "target_type": target_type, + "active_indices": scenario.get("active_agent_indices", []), + "frames": int(replay["agent_f32"].shape[0]), + "agent_cap": int(replay["agent_f32"].shape[1]), + "traffic_cap": int(replay["traffic_i16"].shape[1]), + "active_count": int(replay["obs"].shape[1]), + "obs_dim": int(replay["obs"].shape[2]), + "ego_features": int(binding.EGO_FEATURES), + "num_reward_coefs": int(binding.NUM_REWARD_COEFS), + "partner_features": int(binding.PARTNER_FEATURES), + "road_features": int(binding.ROAD_FEATURES), + "traffic_control_features": int(binding.TRAFFIC_CONTROL_FEATURES), + "action_type": env_cfg.get("action_type", "continuous"), + "dynamics_model": env_cfg.get("dynamics_model", "classic"), + "num_target_waypoints": int(env_cfg["num_target_waypoints"]), + "reward_conditioning": bool(env_cfg["reward_conditioning"]), + "max_partners": int(env_cfg["obs_slots_partners_n"]), + "lane_count": int(lane_count), + "boundary_count": int(boundary_count), + "traffic_obs_count": int(env_cfg["obs_slots_traffic_controls_n"]), + "target_features": int(target_features), + "scales": scales, + "road_polyline_count": len(road_lengths), + "traffic_static_count": len(traffic_types), } + payload = _pack_replay_binary(metadata, chunks) - # --- 2. MAP DATA --- - map_data = {"lanes": [], "lines": [], "edges": []} - for elem in scenario.get("road_elements", []): - if not isinstance(elem, dict): - continue - t = elem.get("type", 0) - if "x" in elem and "y" in elem: - pts = [[float(x), float(y)] for x, y in zip(elem["x"], elem["y"])] - if 1 <= t <= 3: - map_data["lanes"].append(pts) - elif 11 <= t <= 18: - map_data["lines"].append(pts) - elif 21 <= t <= 23: - map_data["edges"].append(pts) - - # --- 3. TEMPLATE HTML --- html_template = """ - PufferDrive Replay XXL + + PufferDrive Replay -
Unpacking Replay Data...
-
-
SPACE: Play | ARROWS: Step | ESC: Free | CLICK: Follow | ENTER: Search
- +
Loading replay...
+
-
-
⚠ COLLISION ⚠
-

☰ DRAG | Agent ?

- -
-
-
Speed
-
0.0 km/h
-
-
-
Req Acc/Str
-
0.0 / 0.0
-
+
+

Agent ?

+
+
Speed
0.0 km/h
+
Lane
-1
- - - -
Metrics Table
-
- -
Position (X/Y/Z)
-
0 , 0 , 0
-
- -
-
☰ DRAG TO MOVE | EGO-CENTRIC NN OBS
- -
- - - +
EGO-CENTRIC NN OBS
+
- - + +
- - """ - - # --- 4. PACKAGE, COMPRESS, AND INJECT --- - master_payload = { - "map": map_data, - "agents": agent_history, - "traffic": traffic_history, - "traj": trajectory_history, - "meta": metadata, - "obs": all_agents_obs_history, - "head_north": head_north, - } - - print("Compressing replay data, this might take a second...") - compressed_payload = pack_and_compress_data(master_payload, decimals=3) - - try: - final_html = html_template.replace("__COMPRESSED_PAYLOAD__", compressed_payload) - with open(filename, "w") as f: - f.write(final_html) - print(f"Success! Saved optimized replay to {filename}") - except Exception as e: - print(f"Error: {e}") + final_html = ( + html_template.replace("__B64_PAYLOAD__", payload) + .replace("__METRIC_LABELS__", json.dumps(METRIC_LABELS, separators=(",", ":"))) + .replace("__VEHICLE_COLORS__", json.dumps(VEHICLE_COLORS, separators=(",", ":"))) + ) + with open(filename, "w") as f: + f.write(final_html) def build_gallery_index(folder_path="."):