From 4a60b300f7959e1a386a107acbdeed5bfe9f2e03 Mon Sep 17 00:00:00 2001 From: Valentin Charraut Date: Tue, 2 Jun 2026 22:28:42 +0200 Subject: [PATCH 01/10] Refactor DrivePolicy architecture and configuration - Updated DrivePolicy to use shared network architecture instead of split network. - Changed input size parameters to specific sizes for ego, partner, lane, boundary, traffic control, and conditioning inputs. - Modified encoder configuration to include activation functions and layer normalization options. - Removed gigaflow architecture in favor of a more flexible encoder design. - Adjusted observation size calculations to include counts of various features. - Updated environment bindings and configuration files to reflect new parameter names and structures. - Enhanced the DriveBackbone class to support new encoder configurations and pooling mechanisms. - Updated the Drive class to accommodate changes in the backbone initialization and observation encoding. --- notebooks/01_observations.ipynb | 4 +- notebooks/05_inference.ipynb | 43 ++++--- notebooks/06_architecture.ipynb | 147 ++++++++++++---------- notebooks/notebook_utils.py | 15 ++- pufferlib/config/ocean/drive.ini | 23 ++-- pufferlib/ocean/drive/drive.h | 66 ++++------ pufferlib/ocean/drive/drive.py | 6 +- pufferlib/ocean/env_binding.h | 1 + pufferlib/ocean/torch.py | 202 ++++++++++++++++++++----------- 9 files changed, 293 insertions(+), 214 deletions(-) diff --git a/notebooks/01_observations.ipynb b/notebooks/01_observations.ipynb index a72fcaa6f3..7931e59af9 100644 --- a/notebooks/01_observations.ipynb +++ b/notebooks/01_observations.ipynb @@ -165,6 +165,8 @@ "idx += env.obs_slots_traffic_controls_n * env.traffic_control_features\n", "assert np.allclose(traffic_manual, traffic), \"traffic mismatch\"\n", "\n", + "idx += 4 # padding at end of obs\n", + "\n", "assert idx == obs.shape[1], f\"obs size mismatch: used {idx}, total {obs.shape[1]}\"\n", "print(f\"All slices match. Total features used: {idx}\")" ] @@ -214,7 +216,7 @@ "metadata": {}, "outputs": [], "source": [ - "partner_labels = [\"rel_x\", \"rel_y\", \"rel_z\", \"length\", \"width\", \"heading_cos\", \"heading_sin\", \"speed\"]\n", + "partner_labels = [\"rel_x\", \"rel_y\", \"rel_z\", \"length\", \"width\", \"heading_cos\", \"heading_sin\", \"rel_vx\", \"rel_vy\"]\n", "active_mask = ~np.all(partners == 0, axis=1)\n", "n_active = active_mask.sum()\n", "print(f\"Active partners: {n_active}/{env.obs_slots_partners_n}\")\n", diff --git a/notebooks/05_inference.ipynb b/notebooks/05_inference.ipynb index c1f99ec009..9e2653d828 100644 --- a/notebooks/05_inference.ipynb +++ b/notebooks/05_inference.ipynb @@ -299,7 +299,7 @@ "- **Ego**: speed, width, length, [jerk: steering, a_long, a_lat], lane_center_dist, lane_angle, speed_limit\n", "- **Conditioning** (if enabled): 17 reward coefs (goal_radius, goal_speed, collision, offroad, comfort, lane_align, vel_align, lane_center, center_bias, velocity, reverse, stop_line, timestep, overspeed, throttle, steer, acc) + target waypoints\n", "- **Target**: static=rel_x,rel_y,rel_z per waypoint; dynamic=rel_x,rel_y,rel_z,heading_cos,heading_sin per waypoint\n", - "- **Partners** (MAX_PARTNERS x 8): rel_x, rel_y, rel_z, length, width, heading_cos, heading_sin, speed\n", + "- **Partners** (MAX_PARTNERS x 9): rel_x, rel_y, rel_z, length, width, heading_cos, heading_sin, rel_vx, rel_vy\n", "- **Lanes** (MAX_LANES x 7): rel_x, rel_y, rel_z, seg_length, seg_width, dir_cos, dir_sin\n", "- **Boundaries** (MAX_BOUNDS x 7): same as lanes\n", "- **Traffic controls** (MAX_TRAFFIC x 7): rel_x1, rel_y1, rel_x2, rel_y2, rel_z, type, state" @@ -390,9 +390,9 @@ "# --- Partner summary ---\n", "n_visible = np.sum(np.any(partners != 0, axis=1))\n", "print(f\"\\n--- Partners: {n_visible}/{partners.shape[0]} visible ---\")\n", - "partner_labels = [\"rel_x\", \"rel_y\", \"rel_z\", \"width\", \"length\", \"heading_cos\", \"heading_sin\", \"speed\"]\n", + "partner_labels = [\"rel_x\", \"rel_y\", \"rel_z\", \"length\", \"width\", \"heading_cos\", \"heading_sin\", \"rel_vx\", \"rel_vy\"]\n", "for p in range(min(int(n_visible), 5)):\n", - " vals = \", \".join(f\"{partner_labels[j]}={partners[p, j]:.3f}\" for j in range(8))\n", + " vals = \", \".join(f\"{partner_labels[j]}={partners[p, j]:.3f}\" for j in range(env.partner_features))\n", " print(f\" [{p}] {vals}\")\n", "if n_visible > 5:\n", " print(f\" ... ({n_visible - 5} more)\")\n", @@ -689,12 +689,14 @@ "for i in range(partners.shape[0]):\n", " if np.allclose(partners[i], 0):\n", " continue\n", - " rx, ry, rz, w, l, hc, hs, spd = partners[i]\n", + " rx, ry, rz, length, width, hc, hs, rel_vx, rel_vy = partners[i]\n", " heading = np.arctan2(hs, hc)\n", - " rect = Rectangle((-l / 2, -w / 2), l, w, facecolor=\"orange\", edgecolor=\"black\", alpha=0.6, zorder=9)\n", + " rect = Rectangle(\n", + " (-length / 2, -width / 2), length, width, facecolor=\"orange\", edgecolor=\"black\", alpha=0.6, zorder=9\n", + " )\n", " rect.set_transform(plt.matplotlib.transforms.Affine2D().rotate(heading).translate(rx, ry) + ax.transData)\n", " ax.add_patch(rect)\n", - " ax.annotate(f\"{spd:.2f}\", (rx, ry), fontsize=7, ha=\"center\", color=\"darkred\", zorder=12)\n", + " ax.annotate(f\"{rel_vx:.2f},{rel_vy:.2f}\", (rx, ry), fontsize=7, ha=\"center\", color=\"darkred\", zorder=12)\n", "part_mask = np.any(partners != 0, axis=1)\n", "if part_mask.any():\n", " ax.scatter(\n", @@ -847,7 +849,7 @@ "outputs": [], "source": [ "# Partner per-feature distributions (pooled over all agents + timesteps, visible only)\n", - "partner_labels = [\"rel_x\", \"rel_y\", \"rel_z\", \"width\", \"length\", \"heading_cos\", \"heading_sin\", \"speed\"]\n", + "partner_labels = [\"rel_x\", \"rel_y\", \"rel_z\", \"length\", \"width\", \"heading_cos\", \"heading_sin\", \"rel_vx\", \"rel_vy\"]\n", "obs_slots_partners_n = env.obs_slots_partners_n\n", "pf = env.partner_features\n", "\n", @@ -861,17 +863,17 @@ "\n", "all_partners = buf_stoch[\"obs\"][:, :, _p_start:_p_end].reshape(\n", " -1, obs_slots_partners_n, pf\n", - ") # (H*N, obs_slots_partners_n, 8)\n", + ") # (H*N, obs_slots_partners_n, 9)\n", "# Mask: partner is visible if any feature != 0\n", "visible_mask = np.any(all_partners != 0, axis=2) # (H*N, 16)\n", - "visible_partners = all_partners[visible_mask] # (K, 8) — all visible partner observations\n", + "visible_partners = all_partners[visible_mask] # (K, 9) — all visible partner observations\n", "\n", "print(\n", " f\"Total partner obs: {all_partners.shape[0] * obs_slots_partners_n}, visible: {len(visible_partners)} \"\n", " f\"({100 * len(visible_partners) / (all_partners.shape[0] * obs_slots_partners_n):.1f}%)\"\n", ")\n", "\n", - "fig, axes = plt.subplots(3, 3, figsize=(21, 10))\n", + "fig, axes = plt.subplots(2, 5, figsize=(21, 8))\n", "axes = axes.flatten()\n", "\n", "for i, label in enumerate(partner_labels):\n", @@ -883,12 +885,13 @@ " axes[i].tick_params(labelsize=7)\n", "\n", "# rel_x vs rel_y scatter in last panel\n", - "axes[8].scatter(visible_partners[:, 0], visible_partners[:, 1], s=1, alpha=0.15, color=\"darkorange\")\n", - "axes[8].set_xlabel(\"rel_x\")\n", - "axes[8].set_ylabel(\"rel_y\")\n", - "axes[8].set_title(\"Partner positions (ego frame)\")\n", - "axes[8].set_aspect(\"equal\")\n", - "axes[8].grid(True, alpha=0.3)\n", + "pos_ax = axes[len(partner_labels)]\n", + "pos_ax.scatter(visible_partners[:, 0], visible_partners[:, 1], s=1, alpha=0.15, color=\"darkorange\")\n", + "pos_ax.set_xlabel(\"rel_x\")\n", + "pos_ax.set_ylabel(\"rel_y\")\n", + "pos_ax.set_title(\"Partner positions (ego frame)\")\n", + "pos_ax.set_aspect(\"equal\")\n", + "pos_ax.grid(True, alpha=0.3)\n", "\n", "fig.suptitle(\"Partner features: all visible, full rollout\", fontsize=13)\n", "plt.tight_layout()\n", @@ -1464,7 +1467,7 @@ "source": [ "## Encoder analysis — what the policy encodes\n", "\n", - "Each obs layer has its own encoder projecting raw features → `input_size` embedding:\n", + "Each obs layer has its own encoder projecting raw features → embedding width:\n", "- **ego** and **conditioning** (reward coefs + target): single vector, no pooling.\n", "- **partners / lanes / boundaries / traffic**: per-slot encoder, padded slots masked to `-inf`, then **max-pooled** across slots → one embedding. Fully-padded layers are zeroed.\n", "\n", @@ -1522,10 +1525,10 @@ "for name, mod, rin, nslots, is_set in enc_inventory:\n", " nparam = sum(p.numel() for p in mod.parameters())\n", " print(\n", - " f\"{name:>13s} | {rin:>6d} | {bb.input_size:>7d} | {nslots:>5d} | {('max' if is_set else '-'):>6s} | {nparam:>9,d}\"\n", + " f\"{name:>13s} | {rin:>6d} | {mod[-1].out_features:>7d} | {nslots:>5d} | {('max' if is_set else '-'):>6s} | {nparam:>9,d}\"\n", " )\n", "print(\n", - " f\"\\nBackbone input = {len(enc_inventory)} x {bb.input_size} = {len(enc_inventory) * bb.input_size} -> backbone -> {bb.out_dim}\"\n", + " f\"\\nBackbone input = {sum(mod[-1].out_features for _, mod, _, _, _ in enc_inventory)} -> backbone -> {bb.out_dim}\"\n", ")\n", "\n", "# Capture pre-pool encoder outputs via forward hooks\n", @@ -1581,7 +1584,7 @@ " masked = captured[name].masked_fill(pad[name].unsqueeze(2), -torch.inf)\n", " vm = (~pad[name]).any(dim=1)\n", " valid_sample[name] = vm\n", - " winners[name] = masked.max(dim=1).indices # (B, input_size): winning slot per dim\n", + " winners[name] = masked.max(dim=1).indices # (B, embedding dim): winning slot per dim\n", " pooled[name] = torch.where(vm.unsqueeze(1), masked.max(dim=1).values, torch.zeros_like(masked.max(dim=1).values))\n", "\n", "for name in (\"ego\", \"conditioning\"):\n", diff --git a/notebooks/06_architecture.ipynb b/notebooks/06_architecture.ipynb index 8df1731aba..260b636444 100644 --- a/notebooks/06_architecture.ipynb +++ b/notebooks/06_architecture.ipynb @@ -31,35 +31,44 @@ "ACTOR_NUM_LAYERS = 3\n", "CRITIC_HIDDEN_SIZE = 64\n", "CRITIC_NUM_LAYERS = 2\n", - "SPLIT_NETWORK = False\n", - "ENCODER_GIGAFLOW = True\n", - "DROPOUT = 0.0\n", + "SHARED_NETWORK = True\n", + "ENCODER_ACTIVATION = \"tanh\"\n", + "ENCODER_LAYER_NORM = True\n", + "BACKBONE_ACTIVATION = \"gelu\"\n", + "BACKBONE_LAYER_NORM = False\n", "\n", "env, obs, info = make_drive_env()\n", "\n", "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "policy = DrivePolicy(\n", " env,\n", - " input_size=INPUT_SIZE,\n", + " ego_input_size=INPUT_SIZE,\n", + " partner_input_size=INPUT_SIZE,\n", + " lane_input_size=INPUT_SIZE,\n", + " boundary_input_size=INPUT_SIZE,\n", + " traffic_control_input_size=INPUT_SIZE,\n", + " conditioning_input_size=INPUT_SIZE,\n", " backbone_hidden_size=BACKBONE_HIDDEN_SIZE,\n", " backbone_num_layers=BACKBONE_NUM_LAYERS,\n", " actor_hidden_size=ACTOR_HIDDEN_SIZE,\n", " actor_num_layers=ACTOR_NUM_LAYERS,\n", " critic_hidden_size=CRITIC_HIDDEN_SIZE,\n", " critic_num_layers=CRITIC_NUM_LAYERS,\n", - " split_network=SPLIT_NETWORK,\n", - " encoder_gigaflow=ENCODER_GIGAFLOW,\n", - " dropout=DROPOUT,\n", + " encoder_activation=ENCODER_ACTIVATION,\n", + " encoder_layer_norm=ENCODER_LAYER_NORM,\n", + " backbone_activation=BACKBONE_ACTIVATION,\n", + " backbone_layer_norm=BACKBONE_LAYER_NORM,\n", + " shared_network=SHARED_NETWORK,\n", ").to(device)\n", "\n", "print(f\"Device: {device}\")\n", "print(f\"Obs dim: {obs.shape[1]}\")\n", "print(f\"Action dim: {policy.atn_dim}\")\n", - "print(f\"Split network: {SPLIT_NETWORK}\")\n", + "print(f\"Shared network: {SHARED_NETWORK}\")\n", "print(f\"Backbone: {BACKBONE_HIDDEN_SIZE} x {BACKBONE_NUM_LAYERS}L\")\n", "print(f\"Actor: {ACTOR_HIDDEN_SIZE} x {ACTOR_NUM_LAYERS}L\")\n", "print(f\"Critic: {CRITIC_HIDDEN_SIZE} x {CRITIC_NUM_LAYERS}L\")\n", - "print(f\"Encoder gigaflow: {ENCODER_GIGAFLOW}, Dropout: {DROPOUT}\")" + "print(f\"Encoder: {ENCODER_ACTIVATION}, LayerNorm: {ENCODER_LAYER_NORM}\")" ] }, { @@ -95,20 +104,19 @@ "backbone = policy.actor_backbone\n", "cond_dim = backbone.conditioning_dim\n", "\n", - "# Collect encoder info — encoder_gigaflow adds Tanh+Dropout between LN and second Linear\n", - "# ego, partner, conditioning use encoder_gigaflow; lane, boundary, traffic_ctrl use dropout\n", + "# Collect encoder info\n", "encoders = [\n", - " (\"ego\", env.ego_features, 1, \"direct\", ENCODER_GIGAFLOW),\n", - " (\"conditioning\", cond_dim, 1, \"direct\", ENCODER_GIGAFLOW) if cond_dim > 0 else None,\n", - " (\"partner\", env.partner_features, env.obs_slots_partners_n, \"max-pool\", ENCODER_GIGAFLOW),\n", - " (\"lane\", env.road_features, env.obs_slots_lane_kept, \"max-pool\", ENCODER_GIGAFLOW),\n", - " (\"boundary\", env.road_features, env.obs_slots_boundary_kept, \"max-pool\", ENCODER_GIGAFLOW),\n", + " (\"ego\", env.ego_features, 1, \"direct\", INPUT_SIZE),\n", + " (\"conditioning\", cond_dim, 1, \"direct\", INPUT_SIZE) if cond_dim > 0 else None,\n", + " (\"partner\", env.partner_features, env.obs_slots_partners_n, \"max-pool\", INPUT_SIZE),\n", + " (\"lane\", env.road_features, env.obs_slots_lane_kept, \"max-pool\", INPUT_SIZE),\n", + " (\"boundary\", env.road_features, env.obs_slots_boundary_kept, \"max-pool\", INPUT_SIZE),\n", " (\n", " \"traffic_ctrl\",\n", " env.traffic_control_features - 2 + binding.NUM_TRAFFIC_CONTROL_TYPES + binding.NUM_TRAFFIC_CONTROL_STATES,\n", " env.obs_slots_traffic_controls_n,\n", " \"max-pool (onehot)\",\n", - " ENCODER_GIGAFLOW,\n", + " INPUT_SIZE,\n", " ),\n", "]\n", "encoders = [e for e in encoders if e is not None]\n", @@ -123,36 +131,25 @@ "colors = plt.cm.Set2(np.linspace(0, 1, n_enc))\n", "\n", "# Draw encoders\n", - "for i, ((name, in_f, n_obj, agg, gigaflow), y, c) in enumerate(zip(encoders, y_positions, colors)):\n", + "for i, ((name, in_f, n_obj, agg, out_size), y, c) in enumerate(zip(encoders, y_positions, colors)):\n", " # Input box\n", " label = f\"{name}\\n{n_obj}x{in_f}\" if n_obj > 1 else f\"{name}\\n{in_f}\"\n", " ax.add_patch(plt.Rectangle((0.2, y - 0.3), 1.6, 0.6, facecolor=c, edgecolor=\"black\", lw=1.2, alpha=0.8))\n", " ax.text(1.0, y, label, ha=\"center\", va=\"center\", fontsize=8, fontweight=\"bold\")\n", "\n", - " # Encoder box — show gigaflow arch vs standard\n", + " # Encoder box\n", " ax.add_patch(plt.Rectangle((2.5, y - 0.25), 2.0, 0.5, facecolor=\"lightyellow\", edgecolor=\"black\", lw=1))\n", - " ax.text(3.5, y + 0.05, f\"Linear({in_f},{INPUT_SIZE})\", ha=\"center\", va=\"center\", fontsize=7)\n", - " if gigaflow:\n", - " ax.text(\n", - " 3.5,\n", - " y - 0.12,\n", - " f\"LN+Tanh+Drop+Linear({INPUT_SIZE},{INPUT_SIZE})\",\n", - " ha=\"center\",\n", - " va=\"center\",\n", - " fontsize=5.5,\n", - " color=\"darkgreen\",\n", - " )\n", - " else:\n", - " drop_str = f\"+Drop({DROPOUT})\" if DROPOUT > 0 and name not in (\"ego\", \"partner\", \"conditioning\") else \"\"\n", - " ax.text(\n", - " 3.5,\n", - " y - 0.12,\n", - " f\"LN{drop_str}+Linear({INPUT_SIZE},{INPUT_SIZE})\",\n", - " ha=\"center\",\n", - " va=\"center\",\n", - " fontsize=6,\n", - " color=\"gray\",\n", - " )\n", + " ax.text(3.5, y + 0.05, f\"Linear({in_f},{out_size})\", ha=\"center\", va=\"center\", fontsize=7)\n", + " ln_label = \"LN+\" if ENCODER_LAYER_NORM else \"\"\n", + " ax.text(\n", + " 3.5,\n", + " y - 0.12,\n", + " f\"{ln_label}{ENCODER_ACTIVATION}+Linear({out_size},{out_size})\",\n", + " ha=\"center\",\n", + " va=\"center\",\n", + " fontsize=6,\n", + " color=\"gray\",\n", + " )\n", "\n", " # Aggregation\n", " if n_obj > 1:\n", @@ -194,14 +191,13 @@ "ax.annotate(\"\", xy=(9.0, 6.0), xytext=(8.8, 5.3), arrowprops=dict(arrowstyle=\"->\", lw=1.2))\n", "ax.annotate(\"\", xy=(9.0, 4.0), xytext=(8.8, 4.7), arrowprops=dict(arrowstyle=\"->\", lw=1.2))\n", "\n", - "split_label = \"SPLIT\" if SPLIT_NETWORK else \"SHARED\"\n", + "split_label = \"SHARED\" if SHARED_NETWORK else \"SPLIT\"\n", "ax.text(8.9, 4.55, split_label, ha=\"center\", va=\"center\", fontsize=7, color=\"red\", fontweight=\"bold\")\n", "\n", - "gigaflow_label = \"GIGAFLOW\" if ENCODER_GIGAFLOW else \"STANDARD\"\n", "ax.text(\n", " 5.0,\n", " 0.3,\n", - " f\"Encoder mode: {gigaflow_label} | Dropout: {DROPOUT}\",\n", + " f\"Encoder: {ENCODER_ACTIVATION} | LayerNorm: {ENCODER_LAYER_NORM}\",\n", " ha=\"center\",\n", " va=\"center\",\n", " fontsize=8,\n", @@ -210,7 +206,7 @@ ")\n", "\n", "ax.set_title(\n", - " f\"DrivePolicy Architecture (input_size={INPUT_SIZE}, backbone={BACKBONE_HIDDEN_SIZE})\",\n", + " f\"DrivePolicy Architecture (encoder_size={INPUT_SIZE}, backbone={BACKBONE_HIDDEN_SIZE})\",\n", " fontsize=12,\n", " fontweight=\"bold\",\n", ")\n", @@ -258,7 +254,7 @@ " print(f\"{n:>25s} | {c:>10,d} | {c / total:>5.1%}\")\n", "print(\"-\" * 48)\n", "print(f\"{'TOTAL':>25s} | {total:>10,d}\")\n", - "if SPLIT_NETWORK:\n", + "if not SHARED_NETWORK:\n", " critic_bb = count_params(policy.critic_backbone)\n", " print(f\"{'+ critic_backbone':>25s} | {critic_bb:>10,d}\")\n", " print(f\"{'GRAND TOTAL':>25s} | {total + critic_bb:>10,d}\")\n", @@ -545,12 +541,12 @@ "outputs": [], "source": [ "configs = [\n", - " {\"name\": \"tiny\", \"input_size\": 32, \"backbone_hidden_size\": 64},\n", - " {\"name\": \"small\", \"input_size\": 64, \"backbone_hidden_size\": 128},\n", - " {\"name\": \"medium\", \"input_size\": 128, \"backbone_hidden_size\": 256, \"backbone_num_layers\": 2},\n", + " {\"name\": \"tiny\", \"encoder_size\": 32, \"backbone_hidden_size\": 64},\n", + " {\"name\": \"small\", \"encoder_size\": 64, \"backbone_hidden_size\": 128},\n", + " {\"name\": \"medium\", \"encoder_size\": 128, \"backbone_hidden_size\": 256, \"backbone_num_layers\": 2},\n", " {\n", " \"name\": \"large\",\n", - " \"input_size\": 128,\n", + " \"encoder_size\": 128,\n", " \"backbone_hidden_size\": 512,\n", " \"backbone_num_layers\": 2,\n", " \"actor_num_layers\": 2,\n", @@ -560,7 +556,7 @@ " },\n", " {\n", " \"name\": \"xlarge\",\n", - " \"input_size\": 256,\n", + " \"encoder_size\": 256,\n", " \"backbone_hidden_size\": 1024,\n", " \"backbone_num_layers\": 3,\n", " \"actor_num_layers\": 2,\n", @@ -568,32 +564,50 @@ " \"critic_num_layers\": 2,\n", " \"critic_hidden_size\": 512,\n", " },\n", - " {\"name\": \"small+giga\", \"input_size\": 64, \"backbone_hidden_size\": 128, \"encoder_gigaflow\": True, \"dropout\": 0.1},\n", + " {\"name\": \"small+tanh\", \"encoder_size\": 64, \"backbone_hidden_size\": 128, \"encoder_activation\": \"tanh\"},\n", " {\n", - " \"name\": \"medium+giga\",\n", - " \"input_size\": 128,\n", + " \"name\": \"medium+tanh\",\n", + " \"encoder_size\": 128,\n", " \"backbone_hidden_size\": 256,\n", " \"backbone_num_layers\": 2,\n", - " \"encoder_gigaflow\": True,\n", - " \"dropout\": 0.1,\n", + " \"encoder_activation\": \"tanh\",\n", " },\n", "]\n", "\n", "POLICY_DEFAULTS = {\n", + " \"ego_input_size\": 64,\n", + " \"partner_input_size\": 64,\n", + " \"lane_input_size\": 64,\n", + " \"boundary_input_size\": 64,\n", + " \"traffic_control_input_size\": 64,\n", + " \"conditioning_input_size\": 64,\n", " \"backbone_num_layers\": 1,\n", " \"actor_hidden_size\": 128,\n", " \"actor_num_layers\": 0,\n", " \"critic_hidden_size\": 128,\n", " \"critic_num_layers\": 0,\n", - " \"encoder_gigaflow\": False,\n", - " \"dropout\": 0.0,\n", - " \"split_network\": False,\n", + " \"encoder_activation\": \"relu\",\n", + " \"encoder_layer_norm\": True,\n", + " \"backbone_activation\": \"gelu\",\n", + " \"backbone_layer_norm\": False,\n", + " \"shared_network\": True,\n", "}\n", "\n", "results = []\n", "for cfg in configs:\n", - " name = cfg.pop(\"name\")\n", - " full_cfg = {**POLICY_DEFAULTS, **cfg}\n", + " name = cfg[\"name\"]\n", + " encoder_size = cfg.get(\"encoder_size\", POLICY_DEFAULTS[\"ego_input_size\"])\n", + " full_cfg = {**POLICY_DEFAULTS, **{k: v for k, v in cfg.items() if k not in (\"name\", \"encoder_size\")}}\n", + " full_cfg.update(\n", + " {\n", + " \"ego_input_size\": encoder_size,\n", + " \"partner_input_size\": encoder_size,\n", + " \"lane_input_size\": encoder_size,\n", + " \"boundary_input_size\": encoder_size,\n", + " \"traffic_control_input_size\": encoder_size,\n", + " \"conditioning_input_size\": encoder_size,\n", + " }\n", + " )\n", " p = DrivePolicy(env, **full_cfg).to(device)\n", " n_params = sum(pp.numel() for pp in p.parameters())\n", "\n", @@ -607,17 +621,16 @@ " torch.cuda.synchronize()\n", " ms_per_fwd = (time.time() - t0) / 100 * 1000\n", "\n", - " results.append({\"name\": name, \"params\": n_params, \"ms/fwd\": ms_per_fwd, **cfg})\n", - " cfg[\"name\"] = name # restore\n", + " results.append({\"name\": name, \"encoder_size\": encoder_size, \"params\": n_params, \"ms/fwd\": ms_per_fwd, **full_cfg})\n", " del p\n", "\n", "print(\n", - " f\"{'Config':>12s} | {'input':>5s} | {'bb_h':>5s} | {'bb_L':>4s} | {'act_h':>5s} | {'act_L':>5s} | {'crt_h':>5s} | {'crt_L':>5s} | {'giga':>5s} | {'Params':>10s} | {'ms/fwd':>8s}\"\n", + " f\"{'Config':>12s} | {'enc':>5s} | {'bb_h':>5s} | {'bb_L':>4s} | {'act_h':>5s} | {'act_L':>5s} | {'crt_h':>5s} | {'crt_L':>5s} | {'enc_act':>7s} | {'Params':>10s} | {'ms/fwd':>8s}\"\n", ")\n", "print(\"-\" * 105)\n", "for r in results:\n", " print(\n", - " f\"{r['name']:>12s} | {r['input_size']:>5d} | {r.get('backbone_hidden_size', 1024):>5d} | {r.get('backbone_num_layers', 1):>4d} | {r.get('actor_hidden_size', 1024):>5d} | {r.get('actor_num_layers', 1):>5d} | {r.get('critic_hidden_size', 1024):>5d} | {r.get('critic_num_layers', 1):>5d} | {str(r.get('encoder_gigaflow', False)):>5s} | {r['params']:>10,d} | {r['ms/fwd']:>7.2f}ms\"\n", + " f\"{r['name']:>12s} | {r['encoder_size']:>5d} | {r['backbone_hidden_size']:>5d} | {r['backbone_num_layers']:>4d} | {r['actor_hidden_size']:>5d} | {r['actor_num_layers']:>5d} | {r['critic_hidden_size']:>5d} | {r['critic_num_layers']:>5d} | {r['encoder_activation']:>7s} | {r['params']:>10,d} | {r['ms/fwd']:>7.2f}ms\"\n", " )\n", "\n", "fig, axes = plt.subplots(1, 2, figsize=(14, 4))\n", @@ -625,11 +638,11 @@ "params = [r[\"params\"] for r in results]\n", "times = [r[\"ms/fwd\"] for r in results]\n", "\n", - "bar_colors = [\"coral\" if r.get(\"encoder_gigaflow\") else \"steelblue\" for r in results]\n", + "bar_colors = [\"coral\" if r[\"encoder_activation\"] == \"tanh\" else \"steelblue\" for r in results]\n", "\n", "axes[0].bar(names, params, color=bar_colors, edgecolor=\"black\")\n", "axes[0].set_ylabel(\"Parameters\")\n", - "axes[0].set_title(\"Parameter Count (orange=gigaflow)\")\n", + "axes[0].set_title(\"Parameter Count (orange=tanh encoder)\")\n", "axes[0].tick_params(axis=\"x\", rotation=30)\n", "for i, v in enumerate(params):\n", " axes[0].text(i, v, f\"{v:,}\", ha=\"center\", va=\"bottom\", fontsize=7)\n", diff --git a/notebooks/notebook_utils.py b/notebooks/notebook_utils.py index 02164349f7..100d15c011 100644 --- a/notebooks/notebook_utils.py +++ b/notebooks/notebook_utils.py @@ -85,16 +85,23 @@ } DEFAULT_POLICY_KWARGS = { - "input_size": 64, + "ego_input_size": 64, + "partner_input_size": 64, + "lane_input_size": 64, + "boundary_input_size": 64, + "traffic_control_input_size": 64, + "conditioning_input_size": 64, "backbone_hidden_size": 128, "backbone_num_layers": 1, "actor_hidden_size": 128, "actor_num_layers": 0, "critic_hidden_size": 128, "critic_num_layers": 0, - "encoder_gigaflow": True, - "dropout": 0.0, - "split_network": False, + "encoder_activation": "tanh", + "encoder_layer_norm": True, + "backbone_activation": "gelu", + "backbone_layer_norm": False, + "shared_network": True, } diff --git a/pufferlib/config/ocean/drive.ini b/pufferlib/config/ocean/drive.ini index 813d1970eb..63d87ac303 100644 --- a/pufferlib/config/ocean/drive.ini +++ b/pufferlib/config/ocean/drive.ini @@ -11,13 +11,23 @@ num_workers = auto batch_size = auto [policy] -; Encoder layer -input_size = 64 -encoder_gigaflow = True -dropout = 0.0 +; Encoder layer (per-encoder embedding width) +ego_input_size = 64 +partner_input_size = 64 +lane_input_size = 64 +boundary_input_size = 64 +traffic_control_input_size = 64 +conditioning_input_size = 64 +; Encoder activation - options: "relu", "tanh", "gelu" +encoder_activation = "relu" +encoder_layer_norm = True +mask_padded_observations = False ; Shared backbone layer backbone_hidden_size = 512 backbone_num_layers = 4 +; Backbone activation - options: "relu", "tanh", "gelu" +backbone_activation = "gelu" +backbone_layer_norm = False ; Actor head layer actor_hidden_size = 512 actor_num_layers = 0 @@ -25,7 +35,7 @@ actor_num_layers = 0 critic_hidden_size = 512 critic_num_layers = 0 ; Dual or shared actor-critic backbone -split_network = False +shared_network = True [rnn] input_size = 512 @@ -98,7 +108,7 @@ reward_lane_align = 0.025 reward_vel_align = 1.0 reward_lane_center = 0.0038 reward_center_bias = 0.0 -reward_velocity = 0.1 +reward_velocity = 0.0025 reward_reverse = 0.005 reward_timestep = 0.000025 reward_overspeed = 0.05 @@ -170,7 +180,6 @@ adam_beta2 = 0.999 adam_eps = 1e-8 vtrace_c_clip = 1 vtrace_rho_clip = 1 -ppo_granularity = auto adv_sampling_prio_alpha = 0.8499999999999999 adv_sampling_prio_beta0 = 0.8499999999999999 adv_filter_ewma_beta = 0.25 diff --git a/pufferlib/ocean/drive/drive.h b/pufferlib/ocean/drive/drive.h index 1dc87e6fd3..2397310f36 100644 --- a/pufferlib/ocean/drive/drive.h +++ b/pufferlib/ocean/drive/drive.h @@ -119,9 +119,9 @@ // Observation feature counts #define EGO_FEATURES 10 #define ROAD_FEATURES 7 -#define PARTNER_FEATURES 8 +#define PARTNER_FEATURES 9 #define TRAFFIC_CONTROL_FEATURES 7 -#define PADDED_OBSERVATION_VALUE -0.001f +#define OBS_COUNT_FEATURES 4 #define STATIC_TARGET_FEATURES 3 #define DYNAMIC_TARGET_FEATURES 5 @@ -3700,29 +3700,8 @@ static int compute_observation_size(Drive *env) { : 0; return EGO_FEATURES + PARTNER_FEATURES * env->obs_slots_partners_n + ROAD_FEATURES * (env->obs_slots_lane_kept + env->obs_slots_boundary_kept) - + TRAFFIC_CONTROL_FEATURES * env->obs_slots_traffic_controls_n + env->reward_conditioning * NUM_REWARD_COEFS - + env->num_target_waypoints * target_features; -} - -// Fill `rows` x `features` observation slots with the padding sentinel. -static inline void fill_padded_observation_rows(float *obs, int rows, int features) { - for (int r = 0; r < rows; r++) { - for (int c = 0; c < features; c++) { - obs[r * features + c] = PADDED_OBSERVATION_VALUE; - } - } -} - -// Pad `rows` traffic-control slots with the sentinel; type/state columns set to NONE/UNKNOWN. -static inline void fill_padded_traffic_control_rows(float *obs, int rows) { - for (int r = 0; r < rows; r++) { - int base = r * TRAFFIC_CONTROL_FEATURES; - for (int c = 0; c < TRAFFIC_CONTROL_FEATURES - 2; c++) { - obs[base + c] = PADDED_OBSERVATION_VALUE; - } - obs[base + TRAFFIC_CONTROL_FEATURES - 2] = TRAFFIC_CONTROL_TYPE_NONE; - obs[base + TRAFFIC_CONTROL_FEATURES - 1] = TRAFFIC_CONTROL_STATE_UNKNOWN; - } + + TRAFFIC_CONTROL_FEATURES * env->obs_slots_traffic_controls_n + OBS_COUNT_FEATURES + + env->reward_conditioning * NUM_REWARD_COEFS + env->num_target_waypoints * target_features; } void allocate(Drive *env) { @@ -4473,7 +4452,7 @@ static int write_reward_target_obs(Drive *env, Agent *ego, float *obs, int obs_i static int write_partner_obs(Drive *env, Agent *ego, int agent_idx, float *obs, int obs_idx, int *partner_count) { if (ego->is_blind_partner && random_uniform(0.0f, 1.0f) < env->partner_blindness_trigger_prob) { int partner_obs_stride = env->obs_slots_partners_n * PARTNER_FEATURES; - fill_padded_observation_rows(&obs[obs_idx], env->obs_slots_partners_n, PARTNER_FEATURES); + memset(&obs[obs_idx], 0, partner_obs_stride * sizeof(float)); *partner_count = 0; return obs_idx + partner_obs_stride; } @@ -4534,7 +4513,7 @@ static int write_partner_obs(Drive *env, Agent *ego, int agent_idx, float *obs, float rel_x, rel_y, rel_heading_x, rel_heading_y, rel_vx, rel_vy; project_vector_to_ego_frame(ego, nearby_agents[j].dx, nearby_agents[j].dy, &rel_x, &rel_y); project_vector_to_ego_frame(ego, other->cos_heading, other->sin_heading, &rel_heading_x, &rel_heading_y); - project_vector_to_ego_frame(ego, other->sim_vx, other->sim_vy, &rel_vx, &rel_vy); + project_vector_to_ego_frame(ego, other->sim_vx - ego->sim_vx, other->sim_vy - ego->sim_vy, &rel_vx, &rel_vy); obs[obs_idx++] = rel_x / env->obs_norm_xy_offset_m; obs[obs_idx++] = rel_y / env->obs_norm_xy_offset_m; obs[obs_idx++] = nearby_agents[j].dz / Z_BUFFER; @@ -4542,12 +4521,12 @@ static int write_partner_obs(Drive *env, Agent *ego, int agent_idx, float *obs, obs[obs_idx++] = other->sim_width / env->obs_norm_veh_width_m; obs[obs_idx++] = rel_heading_x; obs[obs_idx++] = rel_heading_y; - obs[obs_idx++] = other->sim_speed_signed / MAX_SPEED; + obs[obs_idx++] = rel_vx / (2.0f * MAX_SPEED); + obs[obs_idx++] = rel_vy / (2.0f * MAX_SPEED); partners_written++; } *partner_count = partners_written; - fill_padded_observation_rows(&obs[obs_idx], env->obs_slots_partners_n - partners_written, PARTNER_FEATURES); return obs_idx + (env->obs_slots_partners_n - partners_written) * PARTNER_FEATURES; } @@ -4646,28 +4625,28 @@ static int write_road_obs(Drive *env, Agent *ego, float *obs, int obs_idx, int * subsample_road_observation_rows(lanes_buffer, lanes_found, lanes_to_copy); subsample_road_observation_rows(boundaries_buffer, boundaries_found, boundaries_to_copy); memcpy(&obs[lane_obs_idx], lanes_buffer, lanes_to_copy * ROAD_FEATURES * sizeof(float)); - fill_padded_observation_rows( + memset( &obs[lane_obs_idx + lanes_to_copy * ROAD_FEATURES], - env->obs_slots_lane_kept - lanes_to_copy, - ROAD_FEATURES); + 0, + (env->obs_slots_lane_kept - lanes_to_copy) * ROAD_FEATURES * sizeof(float)); memcpy(&obs[boundary_obs_idx], boundaries_buffer, boundaries_to_copy * ROAD_FEATURES * sizeof(float)); - fill_padded_observation_rows( + memset( &obs[boundary_obs_idx + boundaries_to_copy * ROAD_FEATURES], - env->obs_slots_boundary_kept - boundaries_to_copy, - ROAD_FEATURES); + 0, + (env->obs_slots_boundary_kept - boundaries_to_copy) * ROAD_FEATURES * sizeof(float)); return obs_idx; } *lane_count = lanes_found; *boundary_count = boundaries_found; - fill_padded_observation_rows( + memset( &obs[lane_obs_idx + lanes_found * ROAD_FEATURES], - env->obs_slots_lane_kept - lanes_found, - ROAD_FEATURES); - fill_padded_observation_rows( + 0, + (env->obs_slots_lane_kept - lanes_found) * ROAD_FEATURES * sizeof(float)); + memset( &obs[boundary_obs_idx + boundaries_found * ROAD_FEATURES], - env->obs_slots_boundary_kept - boundaries_found, - ROAD_FEATURES); + 0, + (env->obs_slots_boundary_kept - boundaries_found) * ROAD_FEATURES * sizeof(float)); return obs_idx; } @@ -4740,7 +4719,6 @@ static int write_traffic_control_obs(Drive *env, Agent *ego, float *obs, int obs } *traffic_control_count = controls_written; - fill_padded_traffic_control_rows(&obs[obs_idx], env->obs_slots_traffic_controls_n - controls_written); return obs_idx + (env->obs_slots_traffic_controls_n - controls_written) * TRAFFIC_CONTROL_FEATURES; } @@ -4763,6 +4741,10 @@ static void compute_observations(Drive *env) { obs_idx = write_partner_obs(env, ego, i, obs, obs_idx, &partner_count); obs_idx = write_road_obs(env, ego, obs, obs_idx, &lane_count, &boundary_count); obs_idx = write_traffic_control_obs(env, ego, obs, obs_idx, &traffic_control_count); + obs[obs_idx++] = (float) lane_count; + obs[obs_idx++] = (float) boundary_count; + obs[obs_idx++] = (float) partner_count; + obs[obs_idx++] = (float) traffic_control_count; assert(obs_idx == obs_per_agent); } } diff --git a/pufferlib/ocean/drive/drive.py b/pufferlib/ocean/drive/drive.py index 577d42252e..1884eaed8a 100644 --- a/pufferlib/ocean/drive/drive.py +++ b/pufferlib/ocean/drive/drive.py @@ -81,7 +81,7 @@ def __init__( reward_conditioning=False, reward_randomization=False, compute_eval_metrics=True, - split_network=False, + shared_network=True, obs_slots_lane_n=32, obs_slots_boundary_n=32, obs_slots_partners_n=16, @@ -113,7 +113,7 @@ def __init__( self.reward_conditioning = reward_conditioning self.reward_randomization = reward_randomization self.compute_eval_metrics = compute_eval_metrics - self.split_network = split_network + self.shared_network = shared_network self.render_mode = render_mode self.num_maps = num_maps self.report_interval = report_interval @@ -210,6 +210,7 @@ def __init__( self.partner_features = binding.PARTNER_FEATURES self.road_features = binding.ROAD_FEATURES self.traffic_control_features = binding.TRAFFIC_CONTROL_FEATURES + self.obs_count_features = binding.OBS_COUNT_FEATURES self.num_reward_coefs = binding.NUM_REWARD_COEFS if reward_conditioning else 0 # Target features based on target_type @@ -227,6 +228,7 @@ def __init__( + self.obs_slots_lane_kept * self.road_features + self.obs_slots_boundary_kept * self.road_features + self.obs_slots_traffic_controls_n * self.traffic_control_features + + self.obs_count_features ) self.single_observation_space = gymnasium.spaces.Box(low=-1, high=1, shape=(self.num_obs,), dtype=np.float32) diff --git a/pufferlib/ocean/env_binding.h b/pufferlib/ocean/env_binding.h index 9c6cb56f6b..52f22173ce 100644 --- a/pufferlib/ocean/env_binding.h +++ b/pufferlib/ocean/env_binding.h @@ -1380,6 +1380,7 @@ PyMODINIT_FUNC PyInit_binding(void) { PyModule_AddIntConstant(m, "ROAD_FEATURES", ROAD_FEATURES); PyModule_AddIntConstant(m, "PARTNER_FEATURES", PARTNER_FEATURES); PyModule_AddIntConstant(m, "TRAFFIC_CONTROL_FEATURES", TRAFFIC_CONTROL_FEATURES); + PyModule_AddIntConstant(m, "OBS_COUNT_FEATURES", OBS_COUNT_FEATURES); PyModule_AddIntConstant(m, "NUM_TRAFFIC_CONTROL_TYPES", NUM_TRAFFIC_CONTROL_TYPES); PyModule_AddIntConstant(m, "NUM_TRAFFIC_CONTROL_STATES", NUM_TRAFFIC_CONTROL_STATES); PyModule_AddIntConstant(m, "TRAFFIC_CONTROL_TYPE_NONE", TRAFFIC_CONTROL_TYPE_NONE); diff --git a/pufferlib/ocean/torch.py b/pufferlib/ocean/torch.py index c0e4f06cb7..d86edbce36 100644 --- a/pufferlib/ocean/torch.py +++ b/pufferlib/ocean/torch.py @@ -11,6 +11,8 @@ Recurrent = pufferlib.models.LSTMWrapper +ACTIVATIONS = {"relu": nn.ReLU, "tanh": nn.Tanh, "gelu": nn.GELU} + class DriveBackbone(nn.Module): """ @@ -19,34 +21,55 @@ class DriveBackbone(nn.Module): - Split Actor/Critic (configurable) """ - def _create_encoder(self, in_features, input_size, encoder_gigaflow, dropout=0.0): - if encoder_gigaflow: - return nn.Sequential( - pufferlib.pytorch.layer_init(nn.Linear(in_features, input_size)), - nn.LayerNorm(input_size), - nn.Tanh(), - nn.Dropout(dropout), - pufferlib.pytorch.layer_init(nn.Linear(input_size, input_size)), - ) - else: - return nn.Sequential( - pufferlib.pytorch.layer_init(nn.Linear(in_features, input_size)), - nn.LayerNorm(input_size), - pufferlib.pytorch.layer_init(nn.Linear(input_size, input_size)), - ) + def _create_encoder(self, in_features, out_size): + layers = [pufferlib.pytorch.layer_init(nn.Linear(in_features, out_size))] + if self.encoder_layer_norm: + layers.append(nn.LayerNorm(out_size)) + layers.append(self.encoder_act_cls()) + layers.append(pufferlib.pytorch.layer_init(nn.Linear(out_size, out_size))) + return nn.Sequential(*layers) + + def _encode_and_pool(self, objects, valid_counts, encoder, out_size): + if not self.mask_padded_observations: + return encoder(objects).max(dim=1).values + + valid_mask = torch.arange(objects.shape[1], device=objects.device) < valid_counts.unsqueeze(1) + encoded_objects = objects.new_full( + (objects.shape[0], objects.shape[1], out_size), + torch.finfo(objects.dtype).min, + ) + encoded_objects[valid_mask] = encoder(objects[valid_mask]) + pooled = encoded_objects.amax(dim=1) + return torch.where(valid_counts.unsqueeze(1) == 0, encoded_objects.new_zeros(()), pooled) def __init__( self, env, - input_size, + ego_input_size, + partner_input_size, + lane_input_size, + boundary_input_size, + traffic_control_input_size, + conditioning_input_size, backbone_hidden_size, backbone_num_layers, ego_dim, - encoder_gigaflow, - dropout, + encoder_activation, + encoder_layer_norm, + backbone_activation, + backbone_layer_norm, + mask_padded_observations, ): super().__init__() - self.input_size = input_size + self.encoder_act_cls = ACTIVATIONS[encoder_activation] + self.encoder_layer_norm = encoder_layer_norm + self.ego_dim = ego_dim + self.ego_input_size = ego_input_size + self.partner_input_size = partner_input_size + self.lane_input_size = lane_input_size + self.boundary_input_size = boundary_input_size + self.traffic_control_input_size = traffic_control_input_size + self.conditioning_input_size = conditioning_input_size # Observation dimensions from environment config self.obs_slots_partners_n = env.obs_slots_partners_n @@ -64,55 +87,47 @@ def __init__( + binding.NUM_TRAFFIC_CONTROL_TYPES + binding.NUM_TRAFFIC_CONTROL_STATES ) + self.obs_count_features = binding.OBS_COUNT_FEATURES + self.mask_padded_observations = mask_padded_observations # Conditioning size (reward coefficients + target info) self.conditioning_dim = env.num_reward_coefs + env.target_dim - num_feature_sets = 1 - # 1. observations Encoders - # Each encoder projects raw features into a common input_size embedding space - self.ego_encoder = self._create_encoder(ego_dim, input_size, encoder_gigaflow) + # Each encoder projects raw features into its own embedding space + self.ego_encoder = self._create_encoder(ego_dim, ego_input_size) + encoders_out = ego_input_size if self.obs_slots_lane_kept > 0: - self.lane_encoder = self._create_encoder( - self.road_features_count, - input_size, - encoder_gigaflow, - dropout=dropout, - ) - num_feature_sets += 1 + self.lane_encoder = self._create_encoder(self.road_features_count, lane_input_size) + encoders_out += lane_input_size if self.obs_slots_boundary_kept > 0: - self.boundary_encoder = self._create_encoder( - self.road_features_count, - input_size, - encoder_gigaflow, - dropout=dropout, - ) - num_feature_sets += 1 + self.boundary_encoder = self._create_encoder(self.road_features_count, boundary_input_size) + encoders_out += boundary_input_size if self.obs_slots_partners_n > 0: - self.partner_encoder = self._create_encoder(self.partner_features_count, input_size, encoder_gigaflow) - num_feature_sets += 1 + self.partner_encoder = self._create_encoder(self.partner_features_count, partner_input_size) + encoders_out += partner_input_size if self.obs_slots_traffic_controls_n > 0: self.traffic_control_encoder = self._create_encoder( - self.traffic_control_features_after_onehot, - input_size, - encoder_gigaflow, + self.traffic_control_features_after_onehot, traffic_control_input_size ) - num_feature_sets += 1 + encoders_out += traffic_control_input_size if self.conditioning_dim > 0: - self.conditioning_encoder = self._create_encoder(self.conditioning_dim, input_size, encoder_gigaflow) - num_feature_sets += 1 + self.conditioning_encoder = self._create_encoder(self.conditioning_dim, conditioning_input_size) + encoders_out += conditioning_input_size # 2. Main Backbone MLP + backbone_act_cls = ACTIVATIONS[backbone_activation] backbone_layers = [] - bb_in = num_feature_sets * input_size + bb_in = encoders_out for _ in range(backbone_num_layers): - backbone_layers.append(nn.GELU()) + backbone_layers.append(backbone_act_cls()) backbone_layers.append(pufferlib.pytorch.layer_init(nn.Linear(bb_in, backbone_hidden_size))) + if backbone_layer_norm: + backbone_layers.append(nn.LayerNorm(backbone_hidden_size)) bb_in = backbone_hidden_size - # Add final GELU before heads - backbone_layers.append(nn.GELU()) + # Add final activation before heads + backbone_layers.append(backbone_act_cls()) self.backbone = nn.Sequential(*backbone_layers) - self.out_dim = backbone_hidden_size if backbone_num_layers > 0 else num_feature_sets * input_size + self.out_dim = backbone_hidden_size if backbone_num_layers > 0 else encoders_out def forward(self, observations, ego_dim): # Extract and slice observations from the flat buffer @@ -137,6 +152,20 @@ def forward(self, observations, ego_dim): slide_idx += boundary_dim traffic_control_observations = observations[:, slide_idx : slide_idx + traffic_control_dim] + count_observations = observations[ + :, slide_idx + traffic_control_dim : slide_idx + traffic_control_dim + self.obs_count_features + ] + lane_counts, boundary_counts, partner_counts, traffic_control_counts = [ + count_observations[:, i].long().clamp_(0, capacity) + for i, capacity in enumerate( + ( + self.obs_slots_lane_kept, + self.obs_slots_boundary_kept, + self.obs_slots_partners_n, + self.obs_slots_traffic_controls_n, + ) + ) + ] # Encode Ego State ego_features = self.ego_encoder(ego_observations) @@ -146,17 +175,27 @@ def forward(self, observations, ego_dim): # Encode Lanes and Boundaries separately if self.obs_slots_lane_kept > 0: lane_objects = lane_observations.view(-1, self.obs_slots_lane_kept, self.road_features_count) - lane_features = self.lane_encoder(lane_objects).max(dim=1).values + lane_features = self._encode_and_pool(lane_objects, lane_counts, self.lane_encoder, self.lane_input_size) feature_list.append(lane_features) if self.obs_slots_boundary_kept > 0: boundary_objects = boundary_observations.view(-1, self.obs_slots_boundary_kept, self.road_features_count) - boundary_features = self.boundary_encoder(boundary_objects).max(dim=1).values + boundary_features = self._encode_and_pool( + boundary_objects, + boundary_counts, + self.boundary_encoder, + self.boundary_input_size, + ) feature_list.append(boundary_features) # Encode Partners if self.obs_slots_partners_n > 0: partner_objects = partner_observations.view(-1, self.obs_slots_partners_n, self.partner_features_count) - partner_features = self.partner_encoder(partner_objects).max(dim=1).values + partner_features = self._encode_and_pool( + partner_objects, + partner_counts, + self.partner_encoder, + self.partner_input_size, + ) feature_list.append(partner_features) # Encode Traffic Controls @@ -179,7 +218,12 @@ def forward(self, observations, ego_dim): [traffic_control_continuous, traffic_control_type_onehot, traffic_control_state_onehot], dim=2, ) - traffic_control_features = self.traffic_control_encoder(traffic_control_objects).max(dim=1).values + traffic_control_features = self._encode_and_pool( + traffic_control_objects, + traffic_control_counts, + self.traffic_control_encoder, + self.traffic_control_input_size, + ) feature_list.append(traffic_control_features) # Add optional features if enabled @@ -263,43 +307,59 @@ class Drive(nn.Module): def __init__( self, env, - input_size: int, + ego_input_size: int, + partner_input_size: int, + lane_input_size: int, + boundary_input_size: int, + traffic_control_input_size: int, + conditioning_input_size: int, backbone_hidden_size: int, backbone_num_layers: int, actor_hidden_size: int, actor_num_layers: int, critic_hidden_size: int, critic_num_layers: int, - encoder_gigaflow: bool, - dropout: int, - split_network: bool, + encoder_activation: str, + encoder_layer_norm: bool, + backbone_activation: str, + backbone_layer_norm: bool, + shared_network: bool, + mask_padded_observations: bool = True, ): super().__init__() # Configuration flags from policy kwargs - self.split_network = split_network + self.shared_network = shared_network self.ego_dim = env.ego_features # Prepare arguments for the Backbone backbone_args = { "env": env, - "input_size": input_size, + "ego_input_size": ego_input_size, + "partner_input_size": partner_input_size, + "lane_input_size": lane_input_size, + "boundary_input_size": boundary_input_size, + "traffic_control_input_size": traffic_control_input_size, + "conditioning_input_size": conditioning_input_size, "backbone_hidden_size": backbone_hidden_size, "backbone_num_layers": backbone_num_layers, "ego_dim": self.ego_dim, - "encoder_gigaflow": encoder_gigaflow, - "dropout": dropout, + "encoder_activation": encoder_activation, + "encoder_layer_norm": encoder_layer_norm, + "backbone_activation": backbone_activation, + "backbone_layer_norm": backbone_layer_norm, + "mask_padded_observations": mask_padded_observations, } # Instantiate backbones self.actor_backbone = DriveBackbone(**backbone_args) - # If split_network is True, create a separate backbone for the critic. - # Otherwise, share the same backbone for both. - if self.split_network: - self.critic_backbone = DriveBackbone(**backbone_args) - else: + # If using shared network, critic backbone is the same as actor backbone. + # Otherwise, create a separate critic backbone with the same architecture. + if self.shared_network: self.critic_backbone = self.actor_backbone + else: + self.critic_backbone = DriveBackbone(**backbone_args) # Setup action and value heads self.is_continuous = isinstance(env.single_action_space, pufferlib.spaces.Box) @@ -337,10 +397,10 @@ def forward(self, observations, state=None): actor_hidden = self.actor_backbone(observations, self.ego_dim) # Forward pass for critic (may use separate backbone) - if self.split_network: - critic_hidden = self.critic_backbone(observations, self.ego_dim) - else: + if self.shared_network: critic_hidden = actor_hidden + else: + critic_hidden = self.critic_backbone(observations, self.ego_dim) # Compute actions if self.is_continuous: @@ -367,7 +427,7 @@ def pool_slot_counts(self, observations, state=None): # Required for PufferLib recurrent wrappers def encode_observations(self, observations, state=None): - assert not self.split_network, "LSTM wrapper doesn't support split_network=True" + assert self.shared_network, "LSTM wrapper requires shared_network=True" return self.actor_backbone(observations, self.ego_dim) def decode_actions(self, hidden): From a9952e57cc271c361a51c4b051869182323d4b7f Mon Sep 17 00:00:00 2001 From: Valentin Charraut Date: Tue, 2 Jun 2026 23:32:37 +0200 Subject: [PATCH 02/10] Update tests --- .../data/drive_rollout_golden.json | 14 ++++- .../smoke_tests/data/drive_smoke_golden.json | 57 ++++++++++++------- tests/smoke_tests/test_drive_train.py | 7 ++- 3 files changed, 55 insertions(+), 23 deletions(-) diff --git a/tests/smoke_tests/data/drive_rollout_golden.json b/tests/smoke_tests/data/drive_rollout_golden.json index 8f62e6e37b..0c5a06b948 100644 --- a/tests/smoke_tests/data/drive_rollout_golden.json +++ b/tests/smoke_tests/data/drive_rollout_golden.json @@ -6,12 +6,24 @@ "comfort_violation_count": 0.7257332022373493, "dnf_rate": 0.5625, "episode_length": 46.42307692307692, - "episode_return": -2.3540108112188487, + "episode_return": -2.4922777001674357, "lane_center_rate": 0.7063253728243021, "n": 16.0, "num_goals_reached": 0.0, "offroad_rate": 0.3894230769230769, "red_light_violation_rate": 0.038461538461538464, + "reward_components/ade": 0.0, + "reward_components/collision": -0.015820913589917697, + "reward_components/comfort": -1.690985042315263, + "reward_components/goal": 0.0, + "reward_components/lane_align": -0.14203165815426752, + "reward_components/lane_center": -0.0026557931694417046, + "reward_components/offroad": -0.5841346153846154, + "reward_components/overspeed": 0.0, + "reward_components/red_light": -0.038461538461538464, + "reward_components/reverse": -0.01809134094331127, + "reward_components/timestep": -0.0001000300175152146, + "reward_components/velocity": 2.3821118632510593e-06, "score": 0.0, "velocity_progress_sum": 0.00018323936428015048 }, diff --git a/tests/smoke_tests/data/drive_smoke_golden.json b/tests/smoke_tests/data/drive_smoke_golden.json index 2825c653ba..b3e2c6865e 100644 --- a/tests/smoke_tests/data/drive_smoke_golden.json +++ b/tests/smoke_tests/data/drive_smoke_golden.json @@ -1,33 +1,48 @@ { "env": { - "avg_distance_per_infraction": 13.20275001525879, - "avg_speed_per_agent": 1.3609784364700317, - "collision_rate": 0.0125, - "comfort_violation_count": 0.7306405305862427, - "dnf_rate": 0.55, - "episode_length": 43.6, - "episode_return": -2.2548463344573975, - "lane_center_rate": 0.6961542010307312, + "avg_distance_per_infraction": 13.18341121673584, + "avg_speed_per_agent": 1.3711345672607422, + "collision_rate": 0.0375, + "comfort_violation_count": 0.7263497829437255, + "dnf_rate": 0.5625, + "episode_length": 41.8, + "episode_return": -2.334346318244934, + "lane_center_rate": 0.7175199031829834, "n": 16.0, "num_goals_reached": 0.0, - "offroad_rate": 0.3625, - "red_light_violation_rate": 0.075, + "obs/max": 80.0, + "obs/mean": 0.2281341243069619, + "obs/min": -1.236907229758799, + "offroad_rate": 0.375, + "red_light_violation_rate": 0.025, + "reward_components/ade": 0.0, + "reward_components/collision": -0.06195625364780426, + "reward_components/comfort": -1.5356246471405028, + "reward_components/goal": 0.0, + "reward_components/lane_align": -0.13004764765501023, + "reward_components/lane_center": -0.002508045267313719, + "reward_components/offroad": -0.5625, + "reward_components/overspeed": 0.0, + "reward_components/red_light": -0.025, + "reward_components/reverse": -0.01661874633282423, + "reward_components/timestep": -9.165622614091263e-05, + "reward_components/velocity": 0.0, "score": 0.0, "velocity_progress_sum": 0.0 }, "losses": { - "approx_kl": 0.0004938042755903942, + "approx_kl": 0.0002155543354872082, "clipfrac": 0.0, - "ema_max": 1.191539317369461, - "entropy": 2.4832939420427596, - "explained_variance": 0.18538302183151245, - "filter_threshold": 0.01191539317369461, - "filtered_fraction": 0.023398328690807824, - "kept_fraction": 0.9766016713091922, - "masked_fraction": 0.12353515625, - "old_approx_kl": 0.000515544121818883, - "policy_loss": -0.0005634417258469122, - "value_loss": 0.08380150688546044 + "ema_max": 1.2262159883975983, + "entropy": 2.484058209827968, + "explained_variance": 0.16326427459716797, + "filter_threshold": 0.012262159883975983, + "filtered_fraction": 0.024684585847504104, + "kept_fraction": 0.9753154141524959, + "masked_fraction": 0.10986328125, + "old_approx_kl": -0.0011627360114029475, + "policy_loss": -0.0004513516489948545, + "value_loss": 0.08172488159367017 }, "meta": { "bptt_horizon": 64, diff --git a/tests/smoke_tests/test_drive_train.py b/tests/smoke_tests/test_drive_train.py index 038e974211..554d1875ae 100644 --- a/tests/smoke_tests/test_drive_train.py +++ b/tests/smoke_tests/test_drive_train.py @@ -148,7 +148,12 @@ def _build_config(): _set_existing( args["policy"], { - "input_size": 32, + "ego_input_size": 32, + "partner_input_size": 32, + "lane_input_size": 32, + "boundary_input_size": 32, + "traffic_control_input_size": 32, + "conditioning_input_size": 32, "backbone_hidden_size": 32, "actor_hidden_size": 32, "critic_hidden_size": 32, From 212fd0e3e4d79e8a546bc8d0a1fcd7891beca7cd Mon Sep 17 00:00:00 2001 From: Valentin Charraut <16002514+vcharraut@users.noreply.github.com> Date: Tue, 2 Jun 2026 23:42:41 +0200 Subject: [PATCH 03/10] Potential fix for pull request finding Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- notebooks/01_observations.ipynb | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/notebooks/01_observations.ipynb b/notebooks/01_observations.ipynb index 7931e59af9..3197e06c39 100644 --- a/notebooks/01_observations.ipynb +++ b/notebooks/01_observations.ipynb @@ -165,8 +165,7 @@ "idx += env.obs_slots_traffic_controls_n * env.traffic_control_features\n", "assert np.allclose(traffic_manual, traffic), \"traffic mismatch\"\n", "\n", - "idx += 4 # padding at end of obs\n", - "\n", + idx += 4 # appended slot-count features at end of obs "assert idx == obs.shape[1], f\"obs size mismatch: used {idx}, total {obs.shape[1]}\"\n", "print(f\"All slices match. Total features used: {idx}\")" ] From b3c91f0045dc8ed91bf503f676579fc3828ca888 Mon Sep 17 00:00:00 2001 From: Valentin Charraut Date: Tue, 2 Jun 2026 23:45:55 +0200 Subject: [PATCH 04/10] Rename mask_padded_observations to mask_padded_features for consistency --- pufferlib/config/ocean/drive.ini | 2 +- pufferlib/ocean/torch.py | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/pufferlib/config/ocean/drive.ini b/pufferlib/config/ocean/drive.ini index 63d87ac303..0740b5896b 100644 --- a/pufferlib/config/ocean/drive.ini +++ b/pufferlib/config/ocean/drive.ini @@ -21,7 +21,7 @@ conditioning_input_size = 64 ; Encoder activation - options: "relu", "tanh", "gelu" encoder_activation = "relu" encoder_layer_norm = True -mask_padded_observations = False +mask_padded_features = False ; Shared backbone layer backbone_hidden_size = 512 backbone_num_layers = 4 diff --git a/pufferlib/ocean/torch.py b/pufferlib/ocean/torch.py index d86edbce36..1bc828bf5e 100644 --- a/pufferlib/ocean/torch.py +++ b/pufferlib/ocean/torch.py @@ -30,7 +30,7 @@ def _create_encoder(self, in_features, out_size): return nn.Sequential(*layers) def _encode_and_pool(self, objects, valid_counts, encoder, out_size): - if not self.mask_padded_observations: + if not self.mask_padded_features: return encoder(objects).max(dim=1).values valid_mask = torch.arange(objects.shape[1], device=objects.device) < valid_counts.unsqueeze(1) @@ -58,7 +58,7 @@ def __init__( encoder_layer_norm, backbone_activation, backbone_layer_norm, - mask_padded_observations, + mask_padded_features, ): super().__init__() self.encoder_act_cls = ACTIVATIONS[encoder_activation] @@ -88,7 +88,7 @@ def __init__( + binding.NUM_TRAFFIC_CONTROL_STATES ) self.obs_count_features = binding.OBS_COUNT_FEATURES - self.mask_padded_observations = mask_padded_observations + self.mask_padded_features = mask_padded_features # Conditioning size (reward coefficients + target info) self.conditioning_dim = env.num_reward_coefs + env.target_dim @@ -324,7 +324,7 @@ def __init__( backbone_activation: str, backbone_layer_norm: bool, shared_network: bool, - mask_padded_observations: bool = True, + mask_padded_features: bool, ): super().__init__() @@ -348,7 +348,7 @@ def __init__( "encoder_layer_norm": encoder_layer_norm, "backbone_activation": backbone_activation, "backbone_layer_norm": backbone_layer_norm, - "mask_padded_observations": mask_padded_observations, + "mask_padded_features": mask_padded_features, } # Instantiate backbones From 718613194e02c619c55bdd5f669b0fb1e3569917 Mon Sep 17 00:00:00 2001 From: Valentin Charraut Date: Tue, 2 Jun 2026 23:52:48 +0200 Subject: [PATCH 05/10] Rename OBS_COUNT_FEATURES to OBS_SLOT_NUM_TYPES for consistency across the codebase --- pufferlib/ocean/drive/drive.h | 4 ++-- pufferlib/ocean/drive/drive.py | 4 ++-- pufferlib/ocean/env_binding.h | 2 +- pufferlib/ocean/torch.py | 4 ++-- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/pufferlib/ocean/drive/drive.h b/pufferlib/ocean/drive/drive.h index 2397310f36..5303799b1a 100644 --- a/pufferlib/ocean/drive/drive.h +++ b/pufferlib/ocean/drive/drive.h @@ -121,9 +121,9 @@ #define ROAD_FEATURES 7 #define PARTNER_FEATURES 9 #define TRAFFIC_CONTROL_FEATURES 7 -#define OBS_COUNT_FEATURES 4 #define STATIC_TARGET_FEATURES 3 #define DYNAMIC_TARGET_FEATURES 5 +#define OBS_SLOT_NUM_TYPES 4 // GIGAFLOW specific #define MAX_ROUTE_LENGTH 64 @@ -3700,7 +3700,7 @@ static int compute_observation_size(Drive *env) { : 0; return EGO_FEATURES + PARTNER_FEATURES * env->obs_slots_partners_n + ROAD_FEATURES * (env->obs_slots_lane_kept + env->obs_slots_boundary_kept) - + TRAFFIC_CONTROL_FEATURES * env->obs_slots_traffic_controls_n + OBS_COUNT_FEATURES + + TRAFFIC_CONTROL_FEATURES * env->obs_slots_traffic_controls_n + OBS_SLOT_NUM_TYPES + env->reward_conditioning * NUM_REWARD_COEFS + env->num_target_waypoints * target_features; } diff --git a/pufferlib/ocean/drive/drive.py b/pufferlib/ocean/drive/drive.py index 1884eaed8a..59ed0aa3eb 100644 --- a/pufferlib/ocean/drive/drive.py +++ b/pufferlib/ocean/drive/drive.py @@ -210,7 +210,7 @@ def __init__( self.partner_features = binding.PARTNER_FEATURES self.road_features = binding.ROAD_FEATURES self.traffic_control_features = binding.TRAFFIC_CONTROL_FEATURES - self.obs_count_features = binding.OBS_COUNT_FEATURES + self.obs_slot_num_types = binding.OBS_SLOT_NUM_TYPES self.num_reward_coefs = binding.NUM_REWARD_COEFS if reward_conditioning else 0 # Target features based on target_type @@ -228,7 +228,7 @@ def __init__( + self.obs_slots_lane_kept * self.road_features + self.obs_slots_boundary_kept * self.road_features + self.obs_slots_traffic_controls_n * self.traffic_control_features - + self.obs_count_features + + self.obs_slot_num_types ) self.single_observation_space = gymnasium.spaces.Box(low=-1, high=1, shape=(self.num_obs,), dtype=np.float32) diff --git a/pufferlib/ocean/env_binding.h b/pufferlib/ocean/env_binding.h index 52f22173ce..a45038d8b5 100644 --- a/pufferlib/ocean/env_binding.h +++ b/pufferlib/ocean/env_binding.h @@ -1380,7 +1380,7 @@ PyMODINIT_FUNC PyInit_binding(void) { PyModule_AddIntConstant(m, "ROAD_FEATURES", ROAD_FEATURES); PyModule_AddIntConstant(m, "PARTNER_FEATURES", PARTNER_FEATURES); PyModule_AddIntConstant(m, "TRAFFIC_CONTROL_FEATURES", TRAFFIC_CONTROL_FEATURES); - PyModule_AddIntConstant(m, "OBS_COUNT_FEATURES", OBS_COUNT_FEATURES); + PyModule_AddIntConstant(m, "OBS_SLOT_NUM_TYPES", OBS_SLOT_NUM_TYPES); PyModule_AddIntConstant(m, "NUM_TRAFFIC_CONTROL_TYPES", NUM_TRAFFIC_CONTROL_TYPES); PyModule_AddIntConstant(m, "NUM_TRAFFIC_CONTROL_STATES", NUM_TRAFFIC_CONTROL_STATES); PyModule_AddIntConstant(m, "TRAFFIC_CONTROL_TYPE_NONE", TRAFFIC_CONTROL_TYPE_NONE); diff --git a/pufferlib/ocean/torch.py b/pufferlib/ocean/torch.py index 1bc828bf5e..1c1f265d4b 100644 --- a/pufferlib/ocean/torch.py +++ b/pufferlib/ocean/torch.py @@ -87,7 +87,7 @@ def __init__( + binding.NUM_TRAFFIC_CONTROL_TYPES + binding.NUM_TRAFFIC_CONTROL_STATES ) - self.obs_count_features = binding.OBS_COUNT_FEATURES + self.obs_slot_num_types = binding.OBS_SLOT_NUM_TYPES self.mask_padded_features = mask_padded_features # Conditioning size (reward coefficients + target info) self.conditioning_dim = env.num_reward_coefs + env.target_dim @@ -153,7 +153,7 @@ def forward(self, observations, ego_dim): traffic_control_observations = observations[:, slide_idx : slide_idx + traffic_control_dim] count_observations = observations[ - :, slide_idx + traffic_control_dim : slide_idx + traffic_control_dim + self.obs_count_features + :, slide_idx + traffic_control_dim : slide_idx + traffic_control_dim + self.obs_slot_num_types ] lane_counts, boundary_counts, partner_counts, traffic_control_counts = [ count_observations[:, i].long().clamp_(0, capacity) From e1d817f51f6d93c0cd930df780e2bd74243ede24 Mon Sep 17 00:00:00 2001 From: Valentin Charraut Date: Wed, 3 Jun 2026 00:23:22 +0200 Subject: [PATCH 06/10] Refactor conditioning terminology to target in training, inference, architecture, and utility files - Updated references from `conditioning_dim` to `target_dim` in training and inference notebooks. - Changed `conditioning_input_size` to `target_input_size` in configuration files and utility scripts. - Adjusted encoder creation and usage in the DriveBackbone class to reflect the new target terminology. - Ensured consistency across all relevant files to improve clarity and maintainability. --- notebooks/04_training.ipynb | 4 ++-- notebooks/05_inference.ipynb | 6 +++--- notebooks/06_architecture.ipynb | 18 ++++++++--------- notebooks/notebook_utils.py | 2 +- pufferlib/config/ocean/drive.ini | 2 +- pufferlib/ocean/torch.py | 28 +++++++++++++-------------- tests/smoke_tests/test_drive_train.py | 2 +- 7 files changed, 31 insertions(+), 31 deletions(-) diff --git a/notebooks/04_training.ipynb b/notebooks/04_training.ipynb index 7a79a8e0ae..b1da07df42 100644 --- a/notebooks/04_training.ipynb +++ b/notebooks/04_training.ipynb @@ -146,7 +146,7 @@ " f\"ego_obs: shape={ego_obs.shape}, NaN={torch.isnan(ego_obs).sum().item()}, range=[{ego_obs.min():.3f}, {ego_obs.max():.3f}]\"\n", ")\n", "\n", - "cond_dim = backbone.conditioning_dim\n", + "cond_dim = backbone.target_dim\n", "if cond_dim > 0:\n", " cond_obs = x[:, slide_idx : slide_idx + cond_dim]\n", " slide_idx += cond_dim\n", @@ -180,7 +180,7 @@ "\n", "if cond_dim > 0:\n", " with torch.no_grad():\n", - " cond_enc = backbone.conditioning_encoder(cond_obs)\n", + " cond_enc = backbone.target_encoder(cond_obs)\n", " print(\n", " f\"{'cond':>10s}_enc: NaN={torch.isnan(cond_enc).sum().item()}, dead={((cond_enc.abs().sum(dim=0) == 0).sum().item())}, range=[{cond_enc.min():.3f}, {cond_enc.max():.3f}]\"\n", " )" diff --git a/notebooks/05_inference.ipynb b/notebooks/05_inference.ipynb index 9e2653d828..c605aaffee 100644 --- a/notebooks/05_inference.ipynb +++ b/notebooks/05_inference.ipynb @@ -1514,8 +1514,8 @@ " True,\n", " )\n", " )\n", - "if bb.conditioning_dim > 0:\n", - " enc_inventory.append((\"conditioning\", bb.conditioning_encoder, bb.conditioning_dim, 1, False))\n", + "if bb.target_dim > 0:\n", + " enc_inventory.append((\"conditioning\", bb.target_encoder, bb.target_dim, 1, False))\n", "\n", "enc_names = [n for n, *_ in enc_inventory]\n", "set_encs = [n for n, _, _, _, is_set in enc_inventory if is_set]\n", @@ -1554,7 +1554,7 @@ "lane_dim = bb.obs_slots_lane_kept * bb.road_features_count\n", "boundary_dim = bb.obs_slots_boundary_kept * bb.road_features_count\n", "traffic_dim = bb.obs_slots_traffic_controls_n * bb.traffic_control_features_count\n", - "_s = ego_dim + bb.conditioning_dim\n", + "_s = ego_dim + bb.target_dim\n", "sl = {}\n", "sl[\"partner\"] = (_s, _s + partner_dim, bb.obs_slots_partners_n, bb.partner_features_count)\n", "_s += partner_dim\n", diff --git a/notebooks/06_architecture.ipynb b/notebooks/06_architecture.ipynb index 260b636444..7a799a2abf 100644 --- a/notebooks/06_architecture.ipynb +++ b/notebooks/06_architecture.ipynb @@ -47,7 +47,7 @@ " lane_input_size=INPUT_SIZE,\n", " boundary_input_size=INPUT_SIZE,\n", " traffic_control_input_size=INPUT_SIZE,\n", - " conditioning_input_size=INPUT_SIZE,\n", + " target_input_size=INPUT_SIZE,\n", " backbone_hidden_size=BACKBONE_HIDDEN_SIZE,\n", " backbone_num_layers=BACKBONE_NUM_LAYERS,\n", " actor_hidden_size=ACTOR_HIDDEN_SIZE,\n", @@ -102,7 +102,7 @@ "outputs": [], "source": [ "backbone = policy.actor_backbone\n", - "cond_dim = backbone.conditioning_dim\n", + "cond_dim = backbone.target_dim\n", "\n", "# Collect encoder info\n", "encoders = [\n", @@ -239,8 +239,8 @@ " \"partner_encoder\": backbone.partner_encoder,\n", " \"traffic_ctrl_encoder\": backbone.traffic_control_encoder,\n", "}\n", - "if backbone.conditioning_dim > 0:\n", - " components[\"conditioning_encoder\"] = backbone.conditioning_encoder\n", + "if backbone.target_dim > 0:\n", + " components[\"target_encoder\"] = backbone.target_encoder\n", "components[\"backbone_mlp\"] = backbone.backbone\n", "components[\"actor_head\"] = policy.actor_head\n", "components[\"critic_head\"] = policy.critic_head\n", @@ -287,7 +287,7 @@ "backbone = policy.actor_backbone\n", "\n", "slide_idx = env.ego_features\n", - "cond_dim = backbone.conditioning_dim\n", + "cond_dim = backbone.target_dim\n", "partner_dim = env.obs_slots_partners_n * env.partner_features\n", "lane_dim = env.obs_slots_lane_kept * env.road_features\n", "boundary_dim = env.obs_slots_boundary_kept * env.road_features\n", @@ -330,7 +330,7 @@ " print(f\" ego_encoder: {ego_obs.shape} -> {ego_enc.shape}\")\n", "\n", " if cond_dim > 0:\n", - " cond_enc = backbone.conditioning_encoder(cond_obs)\n", + " cond_enc = backbone.target_encoder(cond_obs)\n", " print(f\" cond_encoder: {cond_obs.shape} -> {cond_enc.shape}\")\n", "\n", " p_reshaped = partner_obs.view(-1, env.obs_slots_partners_n, env.partner_features)\n", @@ -436,7 +436,7 @@ " activations[\"ego\"] = backbone.ego_encoder(obs_tensor[:, : env.ego_features])\n", "\n", " if cond_dim > 0:\n", - " activations[\"conditioning\"] = backbone.conditioning_encoder(obs_tensor[:, slide : slide + cond_dim])\n", + " activations[\"conditioning\"] = backbone.target_encoder(obs_tensor[:, slide : slide + cond_dim])\n", " slide += cond_dim\n", "\n", " p_obs = obs_tensor[:, slide : slide + partner_dim].view(-1, env.obs_slots_partners_n, env.partner_features)\n", @@ -580,7 +580,7 @@ " \"lane_input_size\": 64,\n", " \"boundary_input_size\": 64,\n", " \"traffic_control_input_size\": 64,\n", - " \"conditioning_input_size\": 64,\n", + " \"target_input_size\": 64,\n", " \"backbone_num_layers\": 1,\n", " \"actor_hidden_size\": 128,\n", " \"actor_num_layers\": 0,\n", @@ -605,7 +605,7 @@ " \"lane_input_size\": encoder_size,\n", " \"boundary_input_size\": encoder_size,\n", " \"traffic_control_input_size\": encoder_size,\n", - " \"conditioning_input_size\": encoder_size,\n", + " \"target_input_size\": encoder_size,\n", " }\n", " )\n", " p = DrivePolicy(env, **full_cfg).to(device)\n", diff --git a/notebooks/notebook_utils.py b/notebooks/notebook_utils.py index 100d15c011..f7f61f459c 100644 --- a/notebooks/notebook_utils.py +++ b/notebooks/notebook_utils.py @@ -90,7 +90,7 @@ "lane_input_size": 64, "boundary_input_size": 64, "traffic_control_input_size": 64, - "conditioning_input_size": 64, + "target_input_size": 64, "backbone_hidden_size": 128, "backbone_num_layers": 1, "actor_hidden_size": 128, diff --git a/pufferlib/config/ocean/drive.ini b/pufferlib/config/ocean/drive.ini index 0740b5896b..dfff41aa3b 100644 --- a/pufferlib/config/ocean/drive.ini +++ b/pufferlib/config/ocean/drive.ini @@ -17,7 +17,7 @@ partner_input_size = 64 lane_input_size = 64 boundary_input_size = 64 traffic_control_input_size = 64 -conditioning_input_size = 64 +target_input_size = 64 ; Encoder activation - options: "relu", "tanh", "gelu" encoder_activation = "relu" encoder_layer_norm = True diff --git a/pufferlib/ocean/torch.py b/pufferlib/ocean/torch.py index 1c1f265d4b..3e33d32420 100644 --- a/pufferlib/ocean/torch.py +++ b/pufferlib/ocean/torch.py @@ -50,7 +50,7 @@ def __init__( lane_input_size, boundary_input_size, traffic_control_input_size, - conditioning_input_size, + target_input_size, backbone_hidden_size, backbone_num_layers, ego_dim, @@ -69,7 +69,7 @@ def __init__( self.lane_input_size = lane_input_size self.boundary_input_size = boundary_input_size self.traffic_control_input_size = traffic_control_input_size - self.conditioning_input_size = conditioning_input_size + self.target_input_size = target_input_size # Observation dimensions from environment config self.obs_slots_partners_n = env.obs_slots_partners_n @@ -90,7 +90,7 @@ def __init__( self.obs_slot_num_types = binding.OBS_SLOT_NUM_TYPES self.mask_padded_features = mask_padded_features # Conditioning size (reward coefficients + target info) - self.conditioning_dim = env.num_reward_coefs + env.target_dim + self.target_dim = env.num_reward_coefs + env.goal_dim # 1. observations Encoders # Each encoder projects raw features into its own embedding space @@ -110,9 +110,9 @@ def __init__( self.traffic_control_features_after_onehot, traffic_control_input_size ) encoders_out += traffic_control_input_size - if self.conditioning_dim > 0: - self.conditioning_encoder = self._create_encoder(self.conditioning_dim, conditioning_input_size) - encoders_out += conditioning_input_size + if self.target_dim > 0: + self.target_encoder = self._create_encoder(self.target_dim, target_input_size) + encoders_out += target_input_size # 2. Main Backbone MLP backbone_act_cls = ACTIVATIONS[backbone_activation] @@ -139,8 +139,8 @@ def forward(self, observations, ego_dim): slide_idx = ego_dim ego_observations = observations[:, :slide_idx] - conditioning_observations = observations[:, slide_idx : slide_idx + self.conditioning_dim] - slide_idx += self.conditioning_dim + target_observations = observations[:, slide_idx : slide_idx + self.target_dim] + slide_idx += self.target_dim partner_observations = observations[:, slide_idx : slide_idx + partner_dim] slide_idx += partner_dim @@ -227,9 +227,9 @@ def forward(self, observations, ego_dim): feature_list.append(traffic_control_features) # Add optional features if enabled - if self.conditioning_dim > 0: - conditioning_features = self.conditioning_encoder(conditioning_observations) - feature_list.append(conditioning_features) + if self.target_dim > 0: + target_features = self.target_encoder(target_observations) + feature_list.append(target_features) # Concatenate all features and pass through main backbone concat_features = torch.cat(feature_list, dim=1) @@ -241,7 +241,7 @@ def pool_slot_counts(self, observations, ego_dim): boundary_dim = self.obs_slots_boundary_kept * self.road_features_count traffic_control_dim = self.obs_slots_traffic_controls_n * self.traffic_control_features_count - slide_idx = ego_dim + self.conditioning_dim + slide_idx = ego_dim + self.target_dim partner_observations = observations[:, slide_idx : slide_idx + partner_dim] slide_idx += partner_dim lane_observations = observations[:, slide_idx : slide_idx + lane_dim] @@ -312,7 +312,7 @@ def __init__( lane_input_size: int, boundary_input_size: int, traffic_control_input_size: int, - conditioning_input_size: int, + target_input_size: int, backbone_hidden_size: int, backbone_num_layers: int, actor_hidden_size: int, @@ -340,7 +340,7 @@ def __init__( "lane_input_size": lane_input_size, "boundary_input_size": boundary_input_size, "traffic_control_input_size": traffic_control_input_size, - "conditioning_input_size": conditioning_input_size, + "target_input_size": target_input_size, "backbone_hidden_size": backbone_hidden_size, "backbone_num_layers": backbone_num_layers, "ego_dim": self.ego_dim, diff --git a/tests/smoke_tests/test_drive_train.py b/tests/smoke_tests/test_drive_train.py index 554d1875ae..a0aba76048 100644 --- a/tests/smoke_tests/test_drive_train.py +++ b/tests/smoke_tests/test_drive_train.py @@ -153,7 +153,7 @@ def _build_config(): "lane_input_size": 32, "boundary_input_size": 32, "traffic_control_input_size": 32, - "conditioning_input_size": 32, + "target_input_size": 32, "backbone_hidden_size": 32, "actor_hidden_size": 32, "critic_hidden_size": 32, From 77ca722fdf34b7c1ef576a1f806cc42508688401 Mon Sep 17 00:00:00 2001 From: Valentin Charraut Date: Wed, 3 Jun 2026 17:31:16 +0200 Subject: [PATCH 07/10] Add neural network architecture notebook and initialize notebooks package - Created a new notebook `06_architecture.py` for visualizing and analyzing the DrivePolicy architecture, including model summary, encoder breakdown, forward pass tracing, weight distributions, and architecture comparisons. - Initialized the `notebooks` package with an empty `__init__.py` file. --- notebooks/01_observations.ipynb | 485 --------- notebooks/01_observations.py | 375 +++++++ notebooks/02_rewards.ipynb | 332 ------ notebooks/02_rewards.py | 241 +++++ notebooks/03_metrics.ipynb | 330 ------ notebooks/03_metrics.py | 235 ++++ notebooks/04_training.ipynb | 549 ---------- notebooks/04_training.py | 418 ++++++++ notebooks/05_inference.ipynb | 1782 ------------------------------- notebooks/05_inference.py | 1576 +++++++++++++++++++++++++++ notebooks/06_architecture.ipynb | 813 -------------- notebooks/06_architecture.py | 694 ++++++++++++ 12 files changed, 3539 insertions(+), 4291 deletions(-) delete mode 100644 notebooks/01_observations.ipynb create mode 100644 notebooks/01_observations.py delete mode 100644 notebooks/02_rewards.ipynb create mode 100644 notebooks/02_rewards.py delete mode 100644 notebooks/03_metrics.ipynb create mode 100644 notebooks/03_metrics.py delete mode 100644 notebooks/04_training.ipynb create mode 100644 notebooks/04_training.py delete mode 100644 notebooks/05_inference.ipynb create mode 100644 notebooks/05_inference.py delete mode 100644 notebooks/06_architecture.ipynb create mode 100644 notebooks/06_architecture.py diff --git a/notebooks/01_observations.ipynb b/notebooks/01_observations.ipynb deleted file mode 100644 index 3197e06c39..0000000000 --- a/notebooks/01_observations.ipynb +++ /dev/null @@ -1,485 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# 01 - Observation Pipeline Debug\n", - "Verify obs vector is correctly packed, normalized, interpretable." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import numpy as np\n", - "import matplotlib.pyplot as plt\n", - "from pufferlib.viz import plot_observation, plot_simulator_state, unpack_obs\n", - "from notebooks.notebook_utils import COEF_NAMES, make_drive_env, zero_actions\n", - "\n", - "env, obs, info = make_drive_env()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Raw obs inspection" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Take first step so obs are populated\n", - "actions = zero_actions(env)\n", - "\n", - "obs, rew, term, trunc, info = env.step(actions)\n", - "\n", - "print(f\"shape: {obs.shape}, dtype: {obs.dtype}\")\n", - "print(f\"min: {obs.min():.4f}, max: {obs.max():.4f}, mean: {obs.mean():.4f}, std: {obs.std():.4f}\")\n", - "print(f\"NaN: {np.isnan(obs).sum()}, Inf: {np.isinf(obs).sum()}\")\n", - "print(f\"% zeros: {(obs == 0).mean() * 100:.1f}%\")\n", - "print(f\"% outside [-1,1]: {((obs < -1) | (obs > 1)).mean() * 100:.2f}%\")\n", - "\n", - "fig, axes = plt.subplots(1, 2, figsize=(14, 4))\n", - "axes[0].hist(obs.flatten(), bins=100, edgecolor=\"black\", alpha=0.7)\n", - "axes[0].set_title(\"Full obs distribution\")\n", - "axes[0].set_xlabel(\"Value\")\n", - "# Per-agent: show obs[0] vs obs[1]\n", - "for i in range(min(4, obs.shape[0])):\n", - " axes[1].plot(obs[i], alpha=0.5, label=f\"agent {i}\")\n", - "axes[1].set_title(\"Obs vector by index (first 4 agents)\")\n", - "axes[1].legend()\n", - "plt.tight_layout()\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Unpack with pufferlib.viz.unpack_obs" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "ego, target, partners, lanes, boundaries, traffic = unpack_obs(\n", - " obs[:1],\n", - " target_type=env.target_type,\n", - " reward_conditioning=env.reward_conditioning,\n", - " num_target_waypoints=env.num_target_waypoints,\n", - " obs_slots_partners_n=env.obs_slots_partners_n,\n", - " obs_slots_lane_n=env.obs_slots_lane_kept,\n", - " obs_slots_boundary_n=env.obs_slots_boundary_kept,\n", - " obs_slots_traffic_controls_n=env.obs_slots_traffic_controls_n,\n", - ")\n", - "print(f\"ego: {ego.shape} = {ego}\")\n", - "print(f\"target: {target.shape}\")\n", - "print(f\"partners: {partners.shape}\")\n", - "print(f\"lanes: {lanes.shape}\")\n", - "print(f\"boundaries: {boundaries.shape}\")\n", - "print(f\"traffic: {traffic.shape}\")\n", - "\n", - "\n", - "labels = [\n", - " \"speed\",\n", - " \"width\",\n", - " \"length\",\n", - " \"steering\",\n", - " \"a_long\",\n", - " \"a_lat\",\n", - " \"lane_center_dist_01\",\n", - " \"lane_heading_cos\",\n", - " \"speed_limit\",\n", - "]\n", - "for name, val in zip(labels, ego):\n", - " print(f\" {name}: {val:.4f}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Manual slice verification" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "o = obs[0] # first agent flat obs\n", - "idx = 0\n", - "\n", - "# Ego\n", - "ego_manual = o[idx : idx + env.ego_features]\n", - "idx += env.ego_features\n", - "assert np.allclose(ego_manual, ego), f\"ego mismatch: {ego_manual} vs {ego}\"\n", - "\n", - "# Reward conditioning coefs\n", - "coefs_manual = o[idx : idx + env.num_reward_coefs]\n", - "idx += env.num_reward_coefs\n", - "\n", - "# Target\n", - "target_manual = o[idx : idx + env.num_target_waypoints * env.target_features].reshape(\n", - " env.num_target_waypoints, env.target_features\n", - ")\n", - "idx += env.num_target_waypoints * env.target_features\n", - "assert np.allclose(target_manual, target), \"target mismatch\"\n", - "\n", - "# Partners\n", - "partners_manual = o[idx : idx + env.obs_slots_partners_n * env.partner_features].reshape(\n", - " env.obs_slots_partners_n, env.partner_features\n", - ")\n", - "idx += env.obs_slots_partners_n * env.partner_features\n", - "assert np.allclose(partners_manual, partners), \"partners mismatch\"\n", - "\n", - "# Lanes\n", - "lanes_manual = o[idx : idx + env.obs_slots_lane_kept * env.road_features].reshape(\n", - " env.obs_slots_lane_kept, env.road_features\n", - ")\n", - "idx += env.obs_slots_lane_kept * env.road_features\n", - "assert np.allclose(lanes_manual, lanes), \"lanes mismatch\"\n", - "\n", - "# Boundaries\n", - "bounds_manual = o[idx : idx + env.obs_slots_boundary_kept * env.road_features].reshape(\n", - " env.obs_slots_boundary_kept, env.road_features\n", - ")\n", - "idx += env.obs_slots_boundary_kept * env.road_features\n", - "assert np.allclose(bounds_manual, boundaries), \"boundaries mismatch\"\n", - "\n", - "# Traffic\n", - "traffic_manual = o[idx : idx + env.obs_slots_traffic_controls_n * env.traffic_control_features].reshape(\n", - " env.obs_slots_traffic_controls_n, env.traffic_control_features\n", - ")\n", - "idx += env.obs_slots_traffic_controls_n * env.traffic_control_features\n", - "assert np.allclose(traffic_manual, traffic), \"traffic mismatch\"\n", - "\n", - idx += 4 # appended slot-count features at end of obs - "assert idx == obs.shape[1], f\"obs size mismatch: used {idx}, total {obs.shape[1]}\"\n", - "print(f\"All slices match. Total features used: {idx}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Reward conditioning coefficients" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "coefs = obs[0, env.ego_features : env.ego_features + env.num_reward_coefs]\n", - "fig, ax = plt.subplots(figsize=(12, 4))\n", - "bars = ax.bar(range(env.num_reward_coefs), coefs, tick_label=COEF_NAMES)\n", - "ax.set_ylabel(\"Normalized coef value\")\n", - "ax.set_title(\"Reward conditioning coefficients (agent 0)\")\n", - "plt.xticks(rotation=45, ha=\"right\")\n", - "for bar, val in zip(bars, coefs):\n", - " ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height(), f\"{val:.3f}\", ha=\"center\", va=\"bottom\", fontsize=8)\n", - "plt.tight_layout()\n", - "plt.show()\n", - "\n", - "# Compare across agents\n", - "all_coefs = obs[:, env.ego_features : env.ego_features + env.num_reward_coefs]\n", - "print(\"Coef stats across agents:\")\n", - "for i, name in enumerate(COEF_NAMES):\n", - " c = all_coefs[:, i]\n", - " print(f\" {name:15s}: mean={c.mean():.3f} std={c.std():.3f} min={c.min():.3f} max={c.max():.3f}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Partner observations" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "partner_labels = [\"rel_x\", \"rel_y\", \"rel_z\", \"length\", \"width\", \"heading_cos\", \"heading_sin\", \"rel_vx\", \"rel_vy\"]\n", - "active_mask = ~np.all(partners == 0, axis=1)\n", - "n_active = active_mask.sum()\n", - "print(f\"Active partners: {n_active}/{env.obs_slots_partners_n}\")\n", - "\n", - "fig, axes = plt.subplots(1, 2, figsize=(14, 5))\n", - "\n", - "# Heatmap\n", - "im = axes[0].imshow(partners, aspect=\"auto\", cmap=\"RdBu_r\", vmin=-1, vmax=1)\n", - "axes[0].set_xticks(range(env.partner_features))\n", - "axes[0].set_xticklabels(partner_labels, rotation=45, ha=\"right\")\n", - "axes[0].set_ylabel(\"Partner index\")\n", - "axes[0].set_title(f\"Partner obs heatmap ({n_active} active)\")\n", - "plt.colorbar(im, ax=axes[0])\n", - "\n", - "# Scatter in ego frame\n", - "active_partners = partners[active_mask]\n", - "if len(active_partners) > 0:\n", - " axes[1].scatter(active_partners[:, 0], active_partners[:, 1], c=\"gray\", s=100, edgecolors=\"black\")\n", - " for i, p in enumerate(active_partners):\n", - " axes[1].annotate(str(i), (p[0], p[1]), fontsize=8, ha=\"center\", va=\"bottom\")\n", - "axes[1].scatter(0, 0, c=\"blue\", s=200, marker=\"s\", label=\"ego\", zorder=10)\n", - "axes[1].set_xlabel(\"rel_x\")\n", - "axes[1].set_ylabel(\"rel_y\")\n", - "axes[1].set_title(\"Partners in ego frame\")\n", - "axes[1].legend()\n", - "axes[1].set_aspect(\"equal\")\n", - "axes[1].grid(True, alpha=0.3)\n", - "plt.tight_layout()\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Lane / boundary segments" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "road_labels = [\"rel_x\", \"rel_y\", \"rel_z\", \"length\", \"width\", \"dir_cos\", \"dir_sin\"]\n", - "\n", - "lane_active = ~np.all(lanes == 0, axis=1)\n", - "bound_active = ~np.all(boundaries == 0, axis=1)\n", - "print(\n", - " f\"Active lanes: {lane_active.sum()}/{env.obs_slots_lane_kept}, boundaries: {bound_active.sum()}/{env.obs_slots_boundary_kept}\"\n", - ")\n", - "\n", - "fig, ax = plt.subplots(figsize=(10, 10))\n", - "\n", - "# Mirror the canonical road rendering in pufferlib.viz.plot_observation\n", - "for seg in lanes[lane_active]:\n", - " x, y, z, length, width, dc, ds = seg\n", - " ax.scatter(x, y, color=\"lightgrey\", s=10, zorder=1)\n", - " ax.plot(\n", - " [x + dc * length / 2, x - dc * length / 2],\n", - " [y + ds * length / 2, y - ds * length / 2],\n", - " color=\"lightgrey\",\n", - " linewidth=1,\n", - " zorder=1,\n", - " )\n", - "\n", - "for seg in boundaries[bound_active]:\n", - " x, y, z, length, width, dc, ds = seg\n", - " ax.scatter(x, y, color=\"black\", s=10, zorder=1)\n", - " ax.plot(\n", - " [x + dc * length / 2, x - dc * length / 2],\n", - " [y + ds * length / 2, y - ds * length / 2],\n", - " color=\"black\",\n", - " linewidth=1,\n", - " zorder=1,\n", - " )\n", - "\n", - "ax.scatter(0, 0, color=\"blue\", s=200, marker=\"s\", label=\"ego\", zorder=10)\n", - "ax.text(\n", - " 0.12,\n", - " 0.95,\n", - " f\"Lanes: {lane_active.sum()}\\nBoundaries: {bound_active.sum()}\",\n", - " transform=ax.transAxes,\n", - " fontsize=10,\n", - " verticalalignment=\"top\",\n", - " bbox=dict(boxstyle=\"round\", facecolor=\"wheat\", alpha=0.8),\n", - ")\n", - "ax.axis((-1, 1, -1, 1))\n", - "ax.set_aspect(\"equal\", adjustable=\"box\")\n", - "ax.set_xlabel(\"X (ego frame)\")\n", - "ax.set_ylabel(\"Y (ego frame)\")\n", - "ax.set_title(\"Lane + boundary segments in ego frame\")\n", - "plt.tight_layout()\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Ego-centric view (pufferlib.viz)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "img = plot_observation(\n", - " obs[:1],\n", - " target_type=env.target_type,\n", - " reward_conditioning=env.reward_conditioning,\n", - " num_target_waypoints=env.num_target_waypoints,\n", - " obs_slots_partners_n=env.obs_slots_partners_n,\n", - " obs_slots_lane_n=env.obs_slots_lane_kept,\n", - " obs_slots_boundary_n=env.obs_slots_boundary_kept,\n", - " obs_slots_traffic_controls_n=env.obs_slots_traffic_controls_n,\n", - ")\n", - "fig, ax = plt.subplots(figsize=(10, 10))\n", - "ax.imshow(img)\n", - "ax.axis(\"off\")\n", - "ax.set_title(\"Ego-centric observation (agent 0)\")\n", - "plt.tight_layout()\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Bird's eye view (simulator state)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "scenarios = env.get_state()\n", - "# get_state returns a list of scenario dicts (one per sub-env) or a single dict\n", - "if isinstance(scenarios, list):\n", - " scenario = scenarios[0]\n", - "else:\n", - " scenario = scenarios\n", - "\n", - "img = plot_simulator_state(scenario)\n", - "fig, ax = plt.subplots(figsize=(12, 12))\n", - "ax.imshow(img)\n", - "ax.axis(\"off\")\n", - "ax.set_title(\"Bird's eye view\")\n", - "plt.tight_layout()\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Multi-step: ego features over time" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "N_STEPS = 20\n", - "ego_history = np.zeros((N_STEPS, env.ego_features))\n", - "\n", - "for t in range(N_STEPS):\n", - " actions = zero_actions(env)\n", - " obs_t, _, _, _, _ = env.step(actions)\n", - " ego_history[t] = obs_t[0, : env.ego_features]\n", - "\n", - "fig, axes = plt.subplots(2, 2, figsize=(14, 8))\n", - "# Speed\n", - "axes[0, 0].plot(ego_history[:, 0])\n", - "axes[0, 0].set_title(\"speed\")\n", - "axes[0, 0].set_xlabel(\"step\")\n", - "# Steering\n", - "axes[0, 1].plot(ego_history[:, 3])\n", - "axes[0, 1].set_title(\"steering\")\n", - "axes[0, 1].set_xlabel(\"step\")\n", - "# a_long\n", - "axes[1, 0].plot(ego_history[:, 4])\n", - "axes[1, 0].set_title(\"a_long\")\n", - "axes[1, 0].set_xlabel(\"step\")\n", - "# a_lat\n", - "axes[1, 1].plot(ego_history[:, 5])\n", - "axes[1, 1].set_title(\"a_lat\")\n", - "axes[1, 1].set_xlabel(\"step\")\n", - "plt.suptitle(\"Agent 0 ego features over 20 steps (no-op action)\")\n", - "plt.tight_layout()\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Cross-agent distributions" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Current obs across all agents\n", - "# Ego features (jerk): speed(0), width(1), length(2), steering(3), a_long(4), a_lat(5), lane_center(6), lane_heading(7), speed_limit(8)\n", - "speeds = obs[:, 0] # speed is at index 0\n", - "\n", - "# Target waypoints start after ego + reward coefs\n", - "target_start = env.ego_features + env.num_reward_coefs\n", - "# Each target waypoint has TARGET_F features; first two are rel_x, rel_y\n", - "first_target_x = obs[:, target_start]\n", - "first_target_y = obs[:, target_start + 1]\n", - "target_dists = np.sqrt(first_target_x**2 + first_target_y**2)\n", - "\n", - "# Count active partners per agent\n", - "partner_start = env.ego_features + env.num_reward_coefs + env.num_target_waypoints * env.target_features\n", - "partner_end = partner_start + env.obs_slots_partners_n * env.partner_features\n", - "all_partners = obs[:, partner_start:partner_end].reshape(-1, env.obs_slots_partners_n, env.partner_features)\n", - "partner_counts = (~np.all(all_partners == 0, axis=2)).sum(axis=1)\n", - "\n", - "fig, axes = plt.subplots(1, 3, figsize=(15, 4))\n", - "axes[0].hist(speeds, bins=20, edgecolor=\"black\", alpha=0.7)\n", - "axes[0].set_title(f\"Speed distribution (N={len(speeds)})\")\n", - "axes[0].set_xlabel(\"speed\")\n", - "\n", - "axes[1].hist(target_dists, bins=20, edgecolor=\"black\", alpha=0.7, color=\"orange\")\n", - "axes[1].set_title(\"Distance to first target waypoint\")\n", - "axes[1].set_xlabel(\"distance\")\n", - "\n", - "axes[2].hist(partner_counts, bins=range(env.obs_slots_partners_n + 2), edgecolor=\"black\", alpha=0.7, color=\"green\")\n", - "axes[2].set_title(\"Active partners per agent\")\n", - "axes[2].set_xlabel(\"count\")\n", - "plt.tight_layout()\n", - "plt.show()" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": ".venv", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} diff --git a/notebooks/01_observations.py b/notebooks/01_observations.py new file mode 100644 index 0000000000..7a60c9825b --- /dev/null +++ b/notebooks/01_observations.py @@ -0,0 +1,375 @@ +# --- +# jupyter: +# jupytext: +# text_representation: +# extension: .py +# format_name: percent +# format_version: '1.3' +# jupytext_version: 1.19.3 +# kernelspec: +# display_name: .venv +# language: python +# name: python3 +# --- + +# %% [markdown] +# # 01 - Observation Pipeline Debug +# Verify obs vector is correctly packed, normalized, interpretable. + +# %% +import numpy as np +import matplotlib.pyplot as plt +from pufferlib.viz import plot_observation, plot_simulator_state, unpack_obs +from notebooks.notebook_utils import COEF_NAMES, make_drive_env, zero_actions + +env, obs, info = make_drive_env() + +# %% [markdown] +# ## Raw obs inspection + +# %% +# Take first step so obs are populated +actions = zero_actions(env) + +obs, rew, term, trunc, info = env.step(actions) + +print(f"shape: {obs.shape}, dtype: {obs.dtype}") +print(f"min: {obs.min():.4f}, max: {obs.max():.4f}, mean: {obs.mean():.4f}, std: {obs.std():.4f}") +print(f"NaN: {np.isnan(obs).sum()}, Inf: {np.isinf(obs).sum()}") +print(f"% zeros: {(obs == 0).mean() * 100:.1f}%") +print(f"% outside [-1,1]: {((obs < -1) | (obs > 1)).mean() * 100:.2f}%") + +fig, axes = plt.subplots(1, 2, figsize=(14, 4)) +axes[0].hist(obs.flatten(), bins=100, edgecolor="black", alpha=0.7) +axes[0].set_title("Full obs distribution") +axes[0].set_xlabel("Value") +# Per-agent: show obs[0] vs obs[1] +for i in range(min(4, obs.shape[0])): + axes[1].plot(obs[i], alpha=0.5, label=f"agent {i}") +axes[1].set_title("Obs vector by index (first 4 agents)") +axes[1].legend() +plt.tight_layout() +plt.show() + +# %% [markdown] +# ## Unpack with pufferlib.viz.unpack_obs + +# %% +ego, target, partners, lanes, boundaries, traffic = unpack_obs( + obs[:1], + target_type=env.target_type, + reward_conditioning=env.reward_conditioning, + num_target_waypoints=env.num_target_waypoints, + obs_slots_partners_n=env.obs_slots_partners_n, + obs_slots_lane_n=env.obs_slots_lane_kept, + obs_slots_boundary_n=env.obs_slots_boundary_kept, + obs_slots_traffic_controls_n=env.obs_slots_traffic_controls_n, +) +print(f"ego: {ego.shape} = {ego}") +print(f"target: {target.shape}") +print(f"partners: {partners.shape}") +print(f"lanes: {lanes.shape}") +print(f"boundaries: {boundaries.shape}") +print(f"traffic: {traffic.shape}") + + +labels = [ + "speed", + "width", + "length", + "steering", + "a_long", + "a_lat", + "lane_center_dist_01", + "lane_heading_cos", + "speed_limit", +] +for name, val in zip(labels, ego): + print(f" {name}: {val:.4f}") + +# %% [markdown] +# ## Manual slice verification + +# %% +o = obs[0] # first agent flat obs +idx = 0 + +# Ego +ego_manual = o[idx : idx + env.ego_features] +idx += env.ego_features +assert np.allclose(ego_manual, ego), f"ego mismatch: {ego_manual} vs {ego}" + +# Reward conditioning coefs +coefs_manual = o[idx : idx + env.num_reward_coefs] +idx += env.num_reward_coefs + +# Target +target_manual = o[idx : idx + env.num_target_waypoints * env.target_features].reshape( + env.num_target_waypoints, env.target_features +) +idx += env.num_target_waypoints * env.target_features +assert np.allclose(target_manual, target), "target mismatch" + +# Partners +partners_manual = o[idx : idx + env.obs_slots_partners_n * env.partner_features].reshape( + env.obs_slots_partners_n, env.partner_features +) +idx += env.obs_slots_partners_n * env.partner_features +assert np.allclose(partners_manual, partners), "partners mismatch" + +# Lanes +lanes_manual = o[idx : idx + env.obs_slots_lane_kept * env.road_features].reshape( + env.obs_slots_lane_kept, env.road_features +) +idx += env.obs_slots_lane_kept * env.road_features +assert np.allclose(lanes_manual, lanes), "lanes mismatch" + +# Boundaries +bounds_manual = o[idx : idx + env.obs_slots_boundary_kept * env.road_features].reshape( + env.obs_slots_boundary_kept, env.road_features +) +idx += env.obs_slots_boundary_kept * env.road_features +assert np.allclose(bounds_manual, boundaries), "boundaries mismatch" + +# Traffic +traffic_manual = o[idx : idx + env.obs_slots_traffic_controls_n * env.traffic_control_features].reshape( + env.obs_slots_traffic_controls_n, env.traffic_control_features +) +idx += env.obs_slots_traffic_controls_n * env.traffic_control_features +assert np.allclose(traffic_manual, traffic), "traffic mismatch" + +idx += 4 # appended slot-count features at end of obs +assert idx == obs.shape[1], f"obs size mismatch: used {idx}, total {obs.shape[1]}" +print(f"All slices match. Total features used: {idx}") + +# %% [markdown] +# ## Reward conditioning coefficients + +# %% +coefs = obs[0, env.ego_features : env.ego_features + env.num_reward_coefs] +fig, ax = plt.subplots(figsize=(12, 4)) +bars = ax.bar(range(env.num_reward_coefs), coefs, tick_label=COEF_NAMES) +ax.set_ylabel("Normalized coef value") +ax.set_title("Reward conditioning coefficients (agent 0)") +plt.xticks(rotation=45, ha="right") +for bar, val in zip(bars, coefs): + ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height(), f"{val:.3f}", ha="center", va="bottom", fontsize=8) +plt.tight_layout() +plt.show() + +# Compare across agents +all_coefs = obs[:, env.ego_features : env.ego_features + env.num_reward_coefs] +print("Coef stats across agents:") +for i, name in enumerate(COEF_NAMES): + c = all_coefs[:, i] + print(f" {name:15s}: mean={c.mean():.3f} std={c.std():.3f} min={c.min():.3f} max={c.max():.3f}") + +# %% [markdown] +# ## Partner observations + +# %% +partner_labels = [ + "rel_x", + "rel_y", + "rel_z", + "length", + "width", + "heading_cos", + "heading_sin", + "rel_vx", + "rel_vy", + "seconds_stopped", +] +active_mask = ~np.all(partners == 0, axis=1) +n_active = active_mask.sum() +print(f"Active partners: {n_active}/{env.obs_slots_partners_n}") + +fig, axes = plt.subplots(1, 2, figsize=(14, 5)) + +# Heatmap +im = axes[0].imshow(partners, aspect="auto", cmap="RdBu_r", vmin=-1, vmax=1) +axes[0].set_xticks(range(env.partner_features)) +axes[0].set_xticklabels(partner_labels, rotation=45, ha="right") +axes[0].set_ylabel("Partner index") +axes[0].set_title(f"Partner obs heatmap ({n_active} active)") +plt.colorbar(im, ax=axes[0]) + +# Scatter in ego frame +active_partners = partners[active_mask] +if len(active_partners) > 0: + axes[1].scatter(active_partners[:, 0], active_partners[:, 1], c="gray", s=100, edgecolors="black") + for i, p in enumerate(active_partners): + axes[1].annotate(str(i), (p[0], p[1]), fontsize=8, ha="center", va="bottom") +axes[1].scatter(0, 0, c="blue", s=200, marker="s", label="ego", zorder=10) +axes[1].set_xlabel("rel_x") +axes[1].set_ylabel("rel_y") +axes[1].set_title("Partners in ego frame") +axes[1].legend() +axes[1].set_aspect("equal") +axes[1].grid(True, alpha=0.3) +plt.tight_layout() +plt.show() + +# %% [markdown] +# ## Lane / boundary segments + +# %% +road_labels = ["rel_x", "rel_y", "rel_z", "length", "width", "dir_cos", "dir_sin"] + +lane_active = ~np.all(lanes == 0, axis=1) +bound_active = ~np.all(boundaries == 0, axis=1) +print( + f"Active lanes: {lane_active.sum()}/{env.obs_slots_lane_kept}, boundaries: {bound_active.sum()}/{env.obs_slots_boundary_kept}" +) + +fig, ax = plt.subplots(figsize=(10, 10)) + +# Mirror the canonical road rendering in pufferlib.viz.plot_observation +for seg in lanes[lane_active]: + x, y, z, length, width, dc, ds = seg + ax.scatter(x, y, color="lightgrey", s=10, zorder=1) + ax.plot( + [x + dc * length / 2, x - dc * length / 2], + [y + ds * length / 2, y - ds * length / 2], + color="lightgrey", + linewidth=1, + zorder=1, + ) + +for seg in boundaries[bound_active]: + x, y, z, length, width, dc, ds = seg + ax.scatter(x, y, color="black", s=10, zorder=1) + ax.plot( + [x + dc * length / 2, x - dc * length / 2], + [y + ds * length / 2, y - ds * length / 2], + color="black", + linewidth=1, + zorder=1, + ) + +ax.scatter(0, 0, color="blue", s=200, marker="s", label="ego", zorder=10) +ax.text( + 0.12, + 0.95, + f"Lanes: {lane_active.sum()}\nBoundaries: {bound_active.sum()}", + transform=ax.transAxes, + fontsize=10, + verticalalignment="top", + bbox=dict(boxstyle="round", facecolor="wheat", alpha=0.8), +) +ax.axis((-1, 1, -1, 1)) +ax.set_aspect("equal", adjustable="box") +ax.set_xlabel("X (ego frame)") +ax.set_ylabel("Y (ego frame)") +ax.set_title("Lane + boundary segments in ego frame") +plt.tight_layout() +plt.show() + +# %% [markdown] +# ## Ego-centric view (pufferlib.viz) + +# %% +img = plot_observation( + obs[:1], + target_type=env.target_type, + reward_conditioning=env.reward_conditioning, + num_target_waypoints=env.num_target_waypoints, + obs_slots_partners_n=env.obs_slots_partners_n, + obs_slots_lane_n=env.obs_slots_lane_kept, + obs_slots_boundary_n=env.obs_slots_boundary_kept, + obs_slots_traffic_controls_n=env.obs_slots_traffic_controls_n, +) +fig, ax = plt.subplots(figsize=(10, 10)) +ax.imshow(img) +ax.axis("off") +ax.set_title("Ego-centric observation (agent 0)") +plt.tight_layout() +plt.show() + +# %% [markdown] +# ## Bird's eye view (simulator state) + +# %% +scenarios = env.get_state() +# get_state returns a list of scenario dicts (one per sub-env) or a single dict +if isinstance(scenarios, list): + scenario = scenarios[0] +else: + scenario = scenarios + +img = plot_simulator_state(scenario) +fig, ax = plt.subplots(figsize=(12, 12)) +ax.imshow(img) +ax.axis("off") +ax.set_title("Bird's eye view") +plt.tight_layout() +plt.show() + +# %% [markdown] +# ## Multi-step: ego features over time + +# %% +N_STEPS = 20 +ego_history = np.zeros((N_STEPS, env.ego_features)) + +for t in range(N_STEPS): + actions = zero_actions(env) + obs_t, _, _, _, _ = env.step(actions) + ego_history[t] = obs_t[0, : env.ego_features] + +fig, axes = plt.subplots(2, 2, figsize=(14, 8)) +# Speed +axes[0, 0].plot(ego_history[:, 0]) +axes[0, 0].set_title("speed") +axes[0, 0].set_xlabel("step") +# Steering +axes[0, 1].plot(ego_history[:, 3]) +axes[0, 1].set_title("steering") +axes[0, 1].set_xlabel("step") +# a_long +axes[1, 0].plot(ego_history[:, 4]) +axes[1, 0].set_title("a_long") +axes[1, 0].set_xlabel("step") +# a_lat +axes[1, 1].plot(ego_history[:, 5]) +axes[1, 1].set_title("a_lat") +axes[1, 1].set_xlabel("step") +plt.suptitle("Agent 0 ego features over 20 steps (no-op action)") +plt.tight_layout() +plt.show() + +# %% [markdown] +# ## Cross-agent distributions + +# %% +# Current obs across all agents +# Ego features (jerk): speed(0), width(1), length(2), steering(3), a_long(4), a_lat(5), lane_center(6), lane_heading(7), speed_limit(8) +speeds = obs[:, 0] # speed is at index 0 + +# Target waypoints start after ego + reward coefs +target_start = env.ego_features + env.num_reward_coefs +# Each target waypoint has TARGET_F features; first two are rel_x, rel_y +first_target_x = obs[:, target_start] +first_target_y = obs[:, target_start + 1] +target_dists = np.sqrt(first_target_x**2 + first_target_y**2) + +# Count active partners per agent +partner_start = env.ego_features + env.num_reward_coefs + env.num_target_waypoints * env.target_features +partner_end = partner_start + env.obs_slots_partners_n * env.partner_features +all_partners = obs[:, partner_start:partner_end].reshape(-1, env.obs_slots_partners_n, env.partner_features) +partner_counts = (~np.all(all_partners == 0, axis=2)).sum(axis=1) + +fig, axes = plt.subplots(1, 3, figsize=(15, 4)) +axes[0].hist(speeds, bins=20, edgecolor="black", alpha=0.7) +axes[0].set_title(f"Speed distribution (N={len(speeds)})") +axes[0].set_xlabel("speed") + +axes[1].hist(target_dists, bins=20, edgecolor="black", alpha=0.7, color="orange") +axes[1].set_title("Distance to first target waypoint") +axes[1].set_xlabel("distance") + +axes[2].hist(partner_counts, bins=range(env.obs_slots_partners_n + 2), edgecolor="black", alpha=0.7, color="green") +axes[2].set_title("Active partners per agent") +axes[2].set_xlabel("count") +plt.tight_layout() +plt.show() diff --git a/notebooks/02_rewards.ipynb b/notebooks/02_rewards.ipynb deleted file mode 100644 index 63372f8f9d..0000000000 --- a/notebooks/02_rewards.ipynb +++ /dev/null @@ -1,332 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# 02 - Reward Signals Debug\n", - "Understand reward magnitudes, components, and correlation with agent behavior." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import numpy as np\n", - "import matplotlib.pyplot as plt\n", - "from notebooks.notebook_utils import COEF_NAMES, make_drive_env, random_actions, zero_actions\n", - "\n", - "env, obs, info = make_drive_env()\n", - "\n", - "print(\n", - " f\"env ready: {env.num_agents} agents, obs={obs.shape}, act_shape={(env.num_agents, len(env.single_action_space.nvec))}\"\n", - ")\n", - "print(\n", - " f\"ego_features={env.ego_features}, num_reward_coefs={env.num_reward_coefs}, obs_slots_partners_n={env.obs_slots_partners_n}, partner_features={env.partner_features}\"\n", - ")\n", - "print(\n", - " f\"obs_slots_lane_kept={env.obs_slots_lane_kept}, obs_slots_boundary_kept={env.obs_slots_boundary_kept}, road_features={env.road_features}\"\n", - ")\n", - "print(\n", - " f\"obs_slots_traffic_controls_n={env.obs_slots_traffic_controls_n}, traffic_control_features={env.traffic_control_features}\"\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Single step: no-op reward distribution" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "actions = zero_actions(env)\n", - "obs, rew, term, trunc, info = env.step(actions)\n", - "\n", - "print(f\"reward shape: {rew.shape}\")\n", - "print(f\"min: {rew.min():.6f}, max: {rew.max():.6f}, mean: {rew.mean():.6f}, std: {rew.std():.6f}\")\n", - "print(f\"NaN: {np.isnan(rew).sum()}, all zero: {(rew == 0).all()}\")\n", - "print(f\"terminals: {term.sum()}, truncations: {trunc.sum()}\")\n", - "\n", - "fig, ax = plt.subplots(figsize=(8, 4))\n", - "ax.bar(range(len(rew)), rew, color=[\"red\" if r < 0 else \"green\" for r in rew])\n", - "ax.set_xlabel(\"Agent index\")\n", - "ax.set_ylabel(\"Reward\")\n", - "ax.set_title(\"Single step reward (no-op action)\")\n", - "ax.axhline(0, color=\"black\", lw=0.5)\n", - "plt.tight_layout()\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 100-step rollout: reward heatmap and cumulative returns" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "N_STEPS = 100\n", - "rewards_history = np.zeros((N_STEPS, env.num_agents))\n", - "terms_history = np.zeros((N_STEPS, env.num_agents))\n", - "\n", - "for t in range(N_STEPS):\n", - " actions = random_actions(env)\n", - " obs, rew, term, trunc, info = env.step(actions)\n", - " rewards_history[t] = rew\n", - " terms_history[t] = term\n", - "\n", - "fig, axes = plt.subplots(1, 3, figsize=(18, 5))\n", - "\n", - "axes[0].plot(rewards_history.mean(axis=1))\n", - "axes[0].set_xlabel(\"Step\")\n", - "axes[0].set_ylabel(\"Mean reward\")\n", - "axes[0].set_title(\"Mean reward per step\")\n", - "\n", - "im = axes[1].imshow(rewards_history.T, aspect=\"auto\", cmap=\"RdYlGn\", interpolation=\"nearest\")\n", - "axes[1].set_xlabel(\"Step\")\n", - "axes[1].set_ylabel(\"Agent\")\n", - "axes[1].set_title(\"Reward heatmap (steps x agents)\")\n", - "plt.colorbar(im, ax=axes[1])\n", - "\n", - "cum_returns = rewards_history.cumsum(axis=0)\n", - "for i in range(min(8, env.num_agents)):\n", - " axes[2].plot(cum_returns[:, i], alpha=0.6, label=f\"agent {i}\")\n", - "axes[2].set_xlabel(\"Step\")\n", - "axes[2].set_ylabel(\"Cumulative return\")\n", - "axes[2].set_title(\"Cumulative returns\")\n", - "axes[2].legend(fontsize=7)\n", - "plt.tight_layout()\n", - "plt.show()\n", - "\n", - "print(f\"Total reward stats: mean={rewards_history.mean():.5f}, std={rewards_history.std():.5f}\")\n", - "print(f\"Per-episode return (100 steps): mean={cum_returns[-1].mean():.3f}, std={cum_returns[-1].std():.3f}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Reward coefficient inspection" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "all_coefs = obs[:, env.ego_features : env.ego_features + env.num_reward_coefs]\n", - "print(f\"Reward coefs shape: {all_coefs.shape}\")\n", - "print()\n", - "print(f\"{'Coef':>15s} | {'mean':>8s} {'std':>8s} {'min':>8s} {'max':>8s}\")\n", - "print(\"-\" * 55)\n", - "for i, name in enumerate(COEF_NAMES):\n", - " c = all_coefs[:, i]\n", - " print(f\"{name:>15s} | {c.mean():8.4f} {c.std():8.4f} {c.min():8.4f} {c.max():8.4f}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Terminal analysis" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "N_STEPS = 200\n", - "term_steps, trunc_steps = [], []\n", - "term_rewards, trunc_rewards = [], []\n", - "\n", - "for t in range(N_STEPS):\n", - " actions = random_actions(env)\n", - " obs, rew, term, trunc, info = env.step(actions)\n", - " for i in range(env.num_agents):\n", - " if term[i]:\n", - " term_steps.append(t)\n", - " term_rewards.append(rew[i])\n", - " if trunc[i]:\n", - " trunc_steps.append(t)\n", - " trunc_rewards.append(rew[i])\n", - "\n", - "print(f\"Terminals: {len(term_steps)}, Truncations: {len(trunc_steps)}\")\n", - "if term_rewards:\n", - " tr = np.array(term_rewards)\n", - " print(f\"Terminal reward: mean={tr.mean():.4f}, std={tr.std():.4f}\")\n", - " n_positive = (tr > 0).sum()\n", - " n_negative = (tr < 0).sum()\n", - " n_zero = (tr == 0).sum()\n", - " print(f\" positive: {n_positive}, negative: {n_negative}, zero: {n_zero}\")\n", - "\n", - "fig, ax = plt.subplots(figsize=(10, 4))\n", - "if term_steps:\n", - " ax.scatter(term_steps, term_rewards, c=\"red\", s=20, alpha=0.5, label=f\"terminal ({len(term_steps)})\")\n", - "if trunc_steps:\n", - " ax.scatter(trunc_steps, trunc_rewards, c=\"blue\", s=20, alpha=0.5, label=f\"truncation ({len(trunc_steps)})\")\n", - "ax.axhline(0, color=\"black\", lw=0.5)\n", - "ax.set_xlabel(\"Step\")\n", - "ax.set_ylabel(\"Reward at terminal/truncation\")\n", - "ax.set_title(\"Terminal events over 200 steps\")\n", - "ax.legend()\n", - "plt.tight_layout()\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Goal detection: high reward events" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "N_STEPS = 512\n", - "goal_events = []\n", - "\n", - "for t in range(N_STEPS):\n", - " prev_obs = obs.copy()\n", - " actions = random_actions(env)\n", - " obs, rew, term, trunc, info = env.step(actions)\n", - " for i in range(env.num_agents):\n", - " if rew[i] >= 0.5:\n", - " target_start = env.ego_features + env.num_reward_coefs\n", - " goal_dist = np.sqrt(prev_obs[i, target_start] ** 2 + prev_obs[i, target_start + 1] ** 2)\n", - " goal_events.append((t, i, rew[i], goal_dist))\n", - "\n", - "print(f\"Goal-like events (reward >= 0.5): {len(goal_events)}\")\n", - "if goal_events:\n", - " events = np.array(goal_events)\n", - " print(f\"Reward range: [{events[:, 2].min():.3f}, {events[:, 2].max():.3f}]\")\n", - " print(f\"Goal distance at event: mean={events[:, 3].mean():.3f}, std={events[:, 3].std():.3f}\")\n", - "\n", - " fig, axes = plt.subplots(1, 2, figsize=(12, 4))\n", - " axes[0].hist(events[:, 2], bins=20, edgecolor=\"black\", alpha=0.7, color=\"gold\")\n", - " axes[0].set_title(\"Reward magnitude at goal events\")\n", - " axes[0].set_xlabel(\"Reward\")\n", - " axes[1].scatter(events[:, 3], events[:, 2], alpha=0.5)\n", - " axes[1].set_xlabel(\"Goal distance before event\")\n", - " axes[1].set_ylabel(\"Reward\")\n", - " axes[1].set_title(\"Goal distance vs reward\")\n", - " plt.tight_layout()\n", - " plt.show()\n", - "else:\n", - " print(\"No goal events detected in 512 steps with random actions\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Reward scale for PPO" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "all_rewards = rewards_history.flatten()\n", - "episodic_returns = rewards_history.sum(axis=0)\n", - "\n", - "fig, axes = plt.subplots(1, 2, figsize=(12, 4))\n", - "axes[0].hist(all_rewards[all_rewards != 0], bins=50, edgecolor=\"black\", alpha=0.7)\n", - "axes[0].set_title(f\"Per-step reward distribution (non-zero, N={(all_rewards != 0).sum()})\")\n", - "axes[0].set_xlabel(\"Reward\")\n", - "\n", - "axes[1].hist(episodic_returns, bins=20, edgecolor=\"black\", alpha=0.7, color=\"purple\")\n", - "axes[1].set_title(f\"Episodic return (100 steps): mean={episodic_returns.mean():.3f}\")\n", - "axes[1].set_xlabel(\"Return\")\n", - "plt.tight_layout()\n", - "plt.show()\n", - "\n", - "print(f\"Reward magnitude range: [{all_rewards.min():.5f}, {all_rewards.max():.5f}]\")\n", - "print(f\"Mean episodic return: {episodic_returns.mean():.4f} +/- {episodic_returns.std():.4f}\")\n", - "if abs(episodic_returns.mean()) > 10:\n", - " print(\"WARNING: large episodic returns, consider scaling\")\n", - "if episodic_returns.std() < 1e-6:\n", - " print(\"WARNING: near-zero return variance\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Action-reward correlation" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "STEPS_PER_ACTION = 20\n", - "action_rewards = {}\n", - "\n", - "for a in range(env.single_action_space.nvec[0]):\n", - " rews = []\n", - " for _ in range(STEPS_PER_ACTION):\n", - " actions = np.full((env.num_agents, len(env.single_action_space.nvec)), a, dtype=np.int64)\n", - " obs, rew, term, trunc, info = env.step(actions)\n", - " rews.append(rew.mean())\n", - " action_rewards[a] = np.mean(rews)\n", - "\n", - "fig, ax = plt.subplots(figsize=(10, 5))\n", - "actions_list = sorted(action_rewards.keys())\n", - "means = [action_rewards[a] for a in actions_list]\n", - "colors = [\"green\" if m > 0 else \"red\" for m in means]\n", - "labels = [f\"{a // 3}L,{a % 3}R\" for a in actions_list]\n", - "ax.bar(range(len(actions_list)), means, tick_label=labels, color=colors, edgecolor=\"black\")\n", - "ax.set_xlabel(\"Action (longitudinal, lateral)\")\n", - "ax.set_ylabel(\"Mean reward\")\n", - "ax.set_title(f\"Mean reward per action over {STEPS_PER_ACTION} steps\")\n", - "ax.axhline(0, color=\"black\", lw=0.5)\n", - "plt.tight_layout()\n", - "plt.show()" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": ".venv", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} diff --git a/notebooks/02_rewards.py b/notebooks/02_rewards.py new file mode 100644 index 0000000000..536f2e0850 --- /dev/null +++ b/notebooks/02_rewards.py @@ -0,0 +1,241 @@ +# --- +# jupyter: +# jupytext: +# text_representation: +# extension: .py +# format_name: percent +# format_version: '1.3' +# jupytext_version: 1.19.3 +# kernelspec: +# display_name: .venv +# language: python +# name: python3 +# --- + +# %% [markdown] +# # 02 - Reward Signals Debug +# Understand reward magnitudes, components, and correlation with agent behavior. + +# %% +import numpy as np +import matplotlib.pyplot as plt +from notebooks.notebook_utils import COEF_NAMES, make_drive_env, random_actions, zero_actions + +env, obs, info = make_drive_env() + +print( + f"env ready: {env.num_agents} agents, obs={obs.shape}, act_shape={(env.num_agents, len(env.single_action_space.nvec))}" +) +print( + f"ego_features={env.ego_features}, num_reward_coefs={env.num_reward_coefs}, obs_slots_partners_n={env.obs_slots_partners_n}, partner_features={env.partner_features}" +) +print( + f"obs_slots_lane_kept={env.obs_slots_lane_kept}, obs_slots_boundary_kept={env.obs_slots_boundary_kept}, road_features={env.road_features}" +) +print( + f"obs_slots_traffic_controls_n={env.obs_slots_traffic_controls_n}, traffic_control_features={env.traffic_control_features}" +) + +# %% [markdown] +# ## Single step: no-op reward distribution + +# %% +actions = zero_actions(env) +obs, rew, term, trunc, info = env.step(actions) + +print(f"reward shape: {rew.shape}") +print(f"min: {rew.min():.6f}, max: {rew.max():.6f}, mean: {rew.mean():.6f}, std: {rew.std():.6f}") +print(f"NaN: {np.isnan(rew).sum()}, all zero: {(rew == 0).all()}") +print(f"terminals: {term.sum()}, truncations: {trunc.sum()}") + +fig, ax = plt.subplots(figsize=(8, 4)) +ax.bar(range(len(rew)), rew, color=["red" if r < 0 else "green" for r in rew]) +ax.set_xlabel("Agent index") +ax.set_ylabel("Reward") +ax.set_title("Single step reward (no-op action)") +ax.axhline(0, color="black", lw=0.5) +plt.tight_layout() +plt.show() + +# %% [markdown] +# ## 100-step rollout: reward heatmap and cumulative returns + +# %% +N_STEPS = 100 +rewards_history = np.zeros((N_STEPS, env.num_agents)) +terms_history = np.zeros((N_STEPS, env.num_agents)) + +for t in range(N_STEPS): + actions = random_actions(env) + obs, rew, term, trunc, info = env.step(actions) + rewards_history[t] = rew + terms_history[t] = term + +fig, axes = plt.subplots(1, 3, figsize=(18, 5)) + +axes[0].plot(rewards_history.mean(axis=1)) +axes[0].set_xlabel("Step") +axes[0].set_ylabel("Mean reward") +axes[0].set_title("Mean reward per step") + +im = axes[1].imshow(rewards_history.T, aspect="auto", cmap="RdYlGn", interpolation="nearest") +axes[1].set_xlabel("Step") +axes[1].set_ylabel("Agent") +axes[1].set_title("Reward heatmap (steps x agents)") +plt.colorbar(im, ax=axes[1]) + +cum_returns = rewards_history.cumsum(axis=0) +for i in range(min(8, env.num_agents)): + axes[2].plot(cum_returns[:, i], alpha=0.6, label=f"agent {i}") +axes[2].set_xlabel("Step") +axes[2].set_ylabel("Cumulative return") +axes[2].set_title("Cumulative returns") +axes[2].legend(fontsize=7) +plt.tight_layout() +plt.show() + +print(f"Total reward stats: mean={rewards_history.mean():.5f}, std={rewards_history.std():.5f}") +print(f"Per-episode return (100 steps): mean={cum_returns[-1].mean():.3f}, std={cum_returns[-1].std():.3f}") + +# %% [markdown] +# ## Reward coefficient inspection + +# %% +all_coefs = obs[:, env.ego_features : env.ego_features + env.num_reward_coefs] +print(f"Reward coefs shape: {all_coefs.shape}") +print() +print(f"{'Coef':>15s} | {'mean':>8s} {'std':>8s} {'min':>8s} {'max':>8s}") +print("-" * 55) +for i, name in enumerate(COEF_NAMES): + c = all_coefs[:, i] + print(f"{name:>15s} | {c.mean():8.4f} {c.std():8.4f} {c.min():8.4f} {c.max():8.4f}") + +# %% [markdown] +# ## Terminal analysis + +# %% +N_STEPS = 200 +term_steps, trunc_steps = [], [] +term_rewards, trunc_rewards = [], [] + +for t in range(N_STEPS): + actions = random_actions(env) + obs, rew, term, trunc, info = env.step(actions) + for i in range(env.num_agents): + if term[i]: + term_steps.append(t) + term_rewards.append(rew[i]) + if trunc[i]: + trunc_steps.append(t) + trunc_rewards.append(rew[i]) + +print(f"Terminals: {len(term_steps)}, Truncations: {len(trunc_steps)}") +if term_rewards: + tr = np.array(term_rewards) + print(f"Terminal reward: mean={tr.mean():.4f}, std={tr.std():.4f}") + n_positive = (tr > 0).sum() + n_negative = (tr < 0).sum() + n_zero = (tr == 0).sum() + print(f" positive: {n_positive}, negative: {n_negative}, zero: {n_zero}") + +fig, ax = plt.subplots(figsize=(10, 4)) +if term_steps: + ax.scatter(term_steps, term_rewards, c="red", s=20, alpha=0.5, label=f"terminal ({len(term_steps)})") +if trunc_steps: + ax.scatter(trunc_steps, trunc_rewards, c="blue", s=20, alpha=0.5, label=f"truncation ({len(trunc_steps)})") +ax.axhline(0, color="black", lw=0.5) +ax.set_xlabel("Step") +ax.set_ylabel("Reward at terminal/truncation") +ax.set_title("Terminal events over 200 steps") +ax.legend() +plt.tight_layout() +plt.show() + +# %% [markdown] +# ## Goal detection: high reward events + +# %% +N_STEPS = 512 +goal_events = [] + +for t in range(N_STEPS): + prev_obs = obs.copy() + actions = random_actions(env) + obs, rew, term, trunc, info = env.step(actions) + for i in range(env.num_agents): + if rew[i] >= 0.5: + target_start = env.ego_features + env.num_reward_coefs + goal_dist = np.sqrt(prev_obs[i, target_start] ** 2 + prev_obs[i, target_start + 1] ** 2) + goal_events.append((t, i, rew[i], goal_dist)) + +print(f"Goal-like events (reward >= 0.5): {len(goal_events)}") +if goal_events: + events = np.array(goal_events) + print(f"Reward range: [{events[:, 2].min():.3f}, {events[:, 2].max():.3f}]") + print(f"Goal distance at event: mean={events[:, 3].mean():.3f}, std={events[:, 3].std():.3f}") + + fig, axes = plt.subplots(1, 2, figsize=(12, 4)) + axes[0].hist(events[:, 2], bins=20, edgecolor="black", alpha=0.7, color="gold") + axes[0].set_title("Reward magnitude at goal events") + axes[0].set_xlabel("Reward") + axes[1].scatter(events[:, 3], events[:, 2], alpha=0.5) + axes[1].set_xlabel("Goal distance before event") + axes[1].set_ylabel("Reward") + axes[1].set_title("Goal distance vs reward") + plt.tight_layout() + plt.show() +else: + print("No goal events detected in 512 steps with random actions") + +# %% [markdown] +# ## Reward scale for PPO + +# %% +all_rewards = rewards_history.flatten() +episodic_returns = rewards_history.sum(axis=0) + +fig, axes = plt.subplots(1, 2, figsize=(12, 4)) +axes[0].hist(all_rewards[all_rewards != 0], bins=50, edgecolor="black", alpha=0.7) +axes[0].set_title(f"Per-step reward distribution (non-zero, N={(all_rewards != 0).sum()})") +axes[0].set_xlabel("Reward") + +axes[1].hist(episodic_returns, bins=20, edgecolor="black", alpha=0.7, color="purple") +axes[1].set_title(f"Episodic return (100 steps): mean={episodic_returns.mean():.3f}") +axes[1].set_xlabel("Return") +plt.tight_layout() +plt.show() + +print(f"Reward magnitude range: [{all_rewards.min():.5f}, {all_rewards.max():.5f}]") +print(f"Mean episodic return: {episodic_returns.mean():.4f} +/- {episodic_returns.std():.4f}") +if abs(episodic_returns.mean()) > 10: + print("WARNING: large episodic returns, consider scaling") +if episodic_returns.std() < 1e-6: + print("WARNING: near-zero return variance") + +# %% [markdown] +# ## Action-reward correlation + +# %% +STEPS_PER_ACTION = 20 +action_rewards = {} + +for a in range(env.single_action_space.nvec[0]): + rews = [] + for _ in range(STEPS_PER_ACTION): + actions = np.full((env.num_agents, len(env.single_action_space.nvec)), a, dtype=np.int64) + obs, rew, term, trunc, info = env.step(actions) + rews.append(rew.mean()) + action_rewards[a] = np.mean(rews) + +fig, ax = plt.subplots(figsize=(10, 5)) +actions_list = sorted(action_rewards.keys()) +means = [action_rewards[a] for a in actions_list] +colors = ["green" if m > 0 else "red" for m in means] +labels = [f"{a // 3}L,{a % 3}R" for a in actions_list] +ax.bar(range(len(actions_list)), means, tick_label=labels, color=colors, edgecolor="black") +ax.set_xlabel("Action (longitudinal, lateral)") +ax.set_ylabel("Mean reward") +ax.set_title(f"Mean reward per action over {STEPS_PER_ACTION} steps") +ax.axhline(0, color="black", lw=0.5) +plt.tight_layout() +plt.show() diff --git a/notebooks/03_metrics.ipynb b/notebooks/03_metrics.ipynb deleted file mode 100644 index d2e5084a66..0000000000 --- a/notebooks/03_metrics.ipynb +++ /dev/null @@ -1,330 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# 03 - Episode Metrics & Logging Debug\n", - "Verify vec_log returns correct metrics, aggregation is sane, episode boundaries handled." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import numpy as np\n", - "import matplotlib.pyplot as plt\n", - "from pufferlib.ocean.drive import binding\n", - "from notebooks.notebook_utils import make_drive_env, random_actions\n", - "\n", - "\n", - "env, obs, info = make_drive_env()\n", - "\n", - "print(\n", - " f\"env ready: {env.num_agents} agents, obs={obs.shape}, act_shape={(env.num_agents, len(env.single_action_space.nvec))}\"\n", - ")\n", - "print(\n", - " f\"ego_features={env.ego_features}, num_reward_coefs={env.num_reward_coefs}, obs_slots_partners_n={env.obs_slots_partners_n}, partner_features={env.partner_features}\"\n", - ")\n", - "print(\n", - " f\"obs_slots_lane_kept={env.obs_slots_lane_kept}, obs_slots_boundary_kept={env.obs_slots_boundary_kept}, road_features={env.road_features}\"\n", - ")\n", - "print(\n", - " f\"obs_slots_traffic_controls_n={env.obs_slots_traffic_controls_n}, traffic_control_features={env.traffic_control_features}\"\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Single vec_log call" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "for _ in range(10):\n", - " actions = random_actions(env)\n", - " obs, rew, term, trunc, info = env.step(actions)\n", - "\n", - "log = binding.vec_log(env.c_envs, env.num_agents)\n", - "print(f\"vec_log type: {type(log)}\")\n", - "if log:\n", - " print(f\"Keys: {sorted(log.keys())}\")\n", - " for k, v in sorted(log.items()):\n", - " print(f\" {k}: {v}\")\n", - "else:\n", - " print(\"vec_log returned empty/None\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 512-step collection: all info dicts" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "N_STEPS = 512\n", - "all_logs = []\n", - "all_rewards = np.zeros((N_STEPS, env.num_agents))\n", - "all_terms = np.zeros((N_STEPS, env.num_agents))\n", - "all_truncs = np.zeros((N_STEPS, env.num_agents))\n", - "\n", - "for t in range(N_STEPS):\n", - " actions = random_actions(env)\n", - " obs, rew, term, trunc, info = env.step(actions)\n", - " all_rewards[t] = rew\n", - " all_terms[t] = term\n", - " all_truncs[t] = trunc\n", - " if info:\n", - " for log_entry in info:\n", - " log_entry[\"_step\"] = t\n", - " all_logs.append(log_entry)\n", - "\n", - "print(f\"Collected {len(all_logs)} log entries over {N_STEPS} steps\")\n", - "if all_logs:\n", - " keys = set()\n", - " for log in all_logs:\n", - " keys.update(log.keys())\n", - " keys.discard(\"_step\")\n", - " print(f\"\\n{'Metric':>25s} | {'count':>5s} {'mean':>10s} {'std':>10s} {'min':>10s} {'max':>10s}\")\n", - " print(\"-\" * 75)\n", - " for k in sorted(keys):\n", - " vals = [log[k] for log in all_logs if k in log and isinstance(log[k], (int, float))]\n", - " if vals:\n", - " v = np.array(vals)\n", - " print(f\"{k:>25s} | {len(v):5d} {v.mean():10.4f} {v.std():10.4f} {v.min():10.4f} {v.max():10.4f}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Metric definitions reference\n", - "\n", - "| Metric | Description |\n", - "|--------|-------------|\n", - "| score | Goals reached cleanly (no collision/offroad) |\n", - "| collision_rate | Fraction of agents that collided |\n", - "| offroad_rate | Fraction of agents that went off-road |\n", - "| completion_rate | Fraction that reached goal (even with collision/offroad) |\n", - "| lane_heading_aligned_rate | Fraction of steps with cos(theta) >= 0.5 (within ~60 deg of lane heading) |\n", - "| lane_center_rate | Lane centering metric average (same as reward term) |\n", - "| avg_collisions_per_agent | Average collision events per agent per episode |" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Terminal / truncation timeline" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "term_per_step = all_terms.sum(axis=1)\n", - "trunc_per_step = all_truncs.sum(axis=1)\n", - "\n", - "fig, ax = plt.subplots(figsize=(14, 4))\n", - "ax.plot(term_per_step, label=\"terminals\", alpha=0.7, color=\"red\")\n", - "ax.plot(trunc_per_step, label=\"truncations\", alpha=0.7, color=\"blue\")\n", - "ax.set_xlabel(\"Step\")\n", - "ax.set_ylabel(\"Count\")\n", - "ax.set_title(\"Terminal/truncation events per step\")\n", - "ax.legend()\n", - "plt.tight_layout()\n", - "plt.show()\n", - "\n", - "print(f\"Total terminals: {all_terms.sum():.0f}, truncations: {all_truncs.sum():.0f}\")\n", - "print(f\"Terminals per step: mean={term_per_step.mean():.2f}, max={term_per_step.max():.0f}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Agent lifecycle trajectories" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "TRACK_STEPS = 100\n", - "TRACK_AGENTS = min(5, env.num_agents)\n", - "xy_history = np.zeros((TRACK_STEPS, TRACK_AGENTS, 2))\n", - "\n", - "for t in range(TRACK_STEPS):\n", - " actions = random_actions(env)\n", - " env.step(actions)\n", - " states = env.get_global_agent_state()\n", - " for i in range(TRACK_AGENTS):\n", - " xy_history[t, i, 0] = states[\"x\"][i]\n", - " xy_history[t, i, 1] = states[\"y\"][i]\n", - "\n", - "fig, ax = plt.subplots(figsize=(10, 10))\n", - "for i in range(TRACK_AGENTS):\n", - " ax.plot(xy_history[:, i, 0], xy_history[:, i, 1], \"-o\", markersize=2, alpha=0.7, label=f\"agent {i}\")\n", - " ax.scatter(xy_history[0, i, 0], xy_history[0, i, 1], s=100, marker=\"s\", zorder=10)\n", - " ax.scatter(xy_history[-1, i, 0], xy_history[-1, i, 1], s=100, marker=\"*\", zorder=10)\n", - "ax.set_xlabel(\"x\")\n", - "ax.set_ylabel(\"y\")\n", - "ax.set_title(f\"{TRACK_AGENTS} agent trajectories over {TRACK_STEPS} steps\")\n", - "ax.legend()\n", - "ax.set_aspect(\"equal\")\n", - "ax.grid(True, alpha=0.3)\n", - "plt.tight_layout()\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Consistency checks" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "if all_logs:\n", - " passed = 0\n", - " failed = 0\n", - " for log in all_logs:\n", - " if \"score\" in log and \"completion_rate\" in log:\n", - " if log[\"score\"] > log[\"completion_rate\"] + 1e-6:\n", - " print(\n", - " f\"FAIL: score ({log['score']:.4f}) > completion_rate ({log['completion_rate']:.4f}) at step {log['_step']}\"\n", - " )\n", - " failed += 1\n", - " else:\n", - " passed += 1\n", - " for rate_key in [\"collision_rate\", \"offroad_rate\", \"completion_rate\", \"score\"]:\n", - " if rate_key in log:\n", - " v = log[rate_key]\n", - " if v < -1e-6 or v > 1.0 + 1e-6:\n", - " print(f\"FAIL: {rate_key} = {v:.4f} outside [0,1] at step {log['_step']}\")\n", - " failed += 1\n", - " else:\n", - " passed += 1\n", - " print(f\"\\nConsistency checks: {passed} passed, {failed} failed\")\n", - "else:\n", - " print(\"No logs to check\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Gigaflow agent dynamics" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "episode_lengths = []\n", - "agent_step_count = np.zeros(env.num_agents)\n", - "active_counts = []\n", - "\n", - "for t in range(N_STEPS):\n", - " active = (~np.all(all_rewards[: t + 1] == 0, axis=0) if t > 0 else np.ones(env.num_agents, dtype=bool)).sum()\n", - " active_counts.append(active)\n", - " for i in range(env.num_agents):\n", - " agent_step_count[i] += 1\n", - " if all_terms[t, i] or all_truncs[t, i]:\n", - " episode_lengths.append(agent_step_count[i])\n", - " agent_step_count[i] = 0\n", - "\n", - "fig, axes = plt.subplots(1, 2, figsize=(14, 4))\n", - "axes[0].plot(active_counts)\n", - "axes[0].set_xlabel(\"Step\")\n", - "axes[0].set_ylabel(\"Active agents\")\n", - "axes[0].set_title(\"Active agent count over time\")\n", - "\n", - "if episode_lengths:\n", - " axes[1].hist(episode_lengths, bins=30, edgecolor=\"black\", alpha=0.7)\n", - " axes[1].set_xlabel(\"Episode length (steps)\")\n", - " axes[1].set_title(f\"Episode length distribution (N={len(episode_lengths)})\")\n", - " print(f\"Episode lengths: mean={np.mean(episode_lengths):.1f}, median={np.median(episode_lengths):.1f}\")\n", - "else:\n", - " print(\"No episodes completed\")\n", - "plt.tight_layout()\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Score vs cumulative reward" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "if all_logs and \"score\" in all_logs[0]:\n", - " scores = [log[\"score\"] for log in all_logs if \"score\" in log]\n", - " log_steps = [log[\"_step\"] for log in all_logs if \"score\" in log]\n", - " cum_rew_at_log = [all_rewards[: t + 1].sum() / env.num_agents for t in log_steps]\n", - "\n", - " fig, ax = plt.subplots(figsize=(8, 6))\n", - " ax.scatter(cum_rew_at_log, scores, alpha=0.5)\n", - " ax.set_xlabel(\"Avg cumulative reward up to step\")\n", - " ax.set_ylabel(\"Score\")\n", - " ax.set_title(\"Score vs cumulative reward\")\n", - " ax.grid(True, alpha=0.3)\n", - " plt.tight_layout()\n", - " plt.show()\n", - "else:\n", - " print(\"No score data available\")" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": ".venv", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} diff --git a/notebooks/03_metrics.py b/notebooks/03_metrics.py new file mode 100644 index 0000000000..3c04a7dd81 --- /dev/null +++ b/notebooks/03_metrics.py @@ -0,0 +1,235 @@ +# --- +# jupyter: +# jupytext: +# text_representation: +# extension: .py +# format_name: percent +# format_version: '1.3' +# jupytext_version: 1.19.3 +# kernelspec: +# display_name: .venv +# language: python +# name: python3 +# --- + +# %% [markdown] +# # 03 - Episode Metrics & Logging Debug +# Verify vec_log returns correct metrics, aggregation is sane, episode boundaries handled. + +# %% +import numpy as np +import matplotlib.pyplot as plt +from pufferlib.ocean.drive import binding +from notebooks.notebook_utils import make_drive_env, random_actions + + +env, obs, info = make_drive_env() + +print( + f"env ready: {env.num_agents} agents, obs={obs.shape}, act_shape={(env.num_agents, len(env.single_action_space.nvec))}" +) +print( + f"ego_features={env.ego_features}, num_reward_coefs={env.num_reward_coefs}, obs_slots_partners_n={env.obs_slots_partners_n}, partner_features={env.partner_features}" +) +print( + f"obs_slots_lane_kept={env.obs_slots_lane_kept}, obs_slots_boundary_kept={env.obs_slots_boundary_kept}, road_features={env.road_features}" +) +print( + f"obs_slots_traffic_controls_n={env.obs_slots_traffic_controls_n}, traffic_control_features={env.traffic_control_features}" +) + +# %% [markdown] +# ## Single vec_log call + +# %% +for _ in range(10): + actions = random_actions(env) + obs, rew, term, trunc, info = env.step(actions) + +log = binding.vec_log(env.c_envs, env.num_agents) +print(f"vec_log type: {type(log)}") +if log: + print(f"Keys: {sorted(log.keys())}") + for k, v in sorted(log.items()): + print(f" {k}: {v}") +else: + print("vec_log returned empty/None") + +# %% [markdown] +# ## 512-step collection: all info dicts + +# %% +N_STEPS = 512 +all_logs = [] +all_rewards = np.zeros((N_STEPS, env.num_agents)) +all_terms = np.zeros((N_STEPS, env.num_agents)) +all_truncs = np.zeros((N_STEPS, env.num_agents)) + +for t in range(N_STEPS): + actions = random_actions(env) + obs, rew, term, trunc, info = env.step(actions) + all_rewards[t] = rew + all_terms[t] = term + all_truncs[t] = trunc + if info: + for log_entry in info: + log_entry["_step"] = t + all_logs.append(log_entry) + +print(f"Collected {len(all_logs)} log entries over {N_STEPS} steps") +if all_logs: + keys = set() + for log in all_logs: + keys.update(log.keys()) + keys.discard("_step") + print(f"\n{'Metric':>25s} | {'count':>5s} {'mean':>10s} {'std':>10s} {'min':>10s} {'max':>10s}") + print("-" * 75) + for k in sorted(keys): + vals = [log[k] for log in all_logs if k in log and isinstance(log[k], (int, float))] + if vals: + v = np.array(vals) + print(f"{k:>25s} | {len(v):5d} {v.mean():10.4f} {v.std():10.4f} {v.min():10.4f} {v.max():10.4f}") + +# %% [markdown] +# ## Metric definitions reference +# +# | Metric | Description | +# |--------|-------------| +# | score | Goals reached cleanly (no collision/offroad) | +# | collision_rate | Fraction of agents that collided | +# | offroad_rate | Fraction of agents that went off-road | +# | completion_rate | Fraction that reached goal (even with collision/offroad) | +# | lane_heading_aligned_rate | Fraction of steps with cos(theta) >= 0.5 (within ~60 deg of lane heading) | +# | lane_center_rate | Lane centering metric average (same as reward term) | +# | avg_collisions_per_agent | Average collision events per agent per episode | + +# %% [markdown] +# ## Terminal / truncation timeline + +# %% +term_per_step = all_terms.sum(axis=1) +trunc_per_step = all_truncs.sum(axis=1) + +fig, ax = plt.subplots(figsize=(14, 4)) +ax.plot(term_per_step, label="terminals", alpha=0.7, color="red") +ax.plot(trunc_per_step, label="truncations", alpha=0.7, color="blue") +ax.set_xlabel("Step") +ax.set_ylabel("Count") +ax.set_title("Terminal/truncation events per step") +ax.legend() +plt.tight_layout() +plt.show() + +print(f"Total terminals: {all_terms.sum():.0f}, truncations: {all_truncs.sum():.0f}") +print(f"Terminals per step: mean={term_per_step.mean():.2f}, max={term_per_step.max():.0f}") + +# %% [markdown] +# ## Agent lifecycle trajectories + +# %% +TRACK_STEPS = 100 +TRACK_AGENTS = min(5, env.num_agents) +xy_history = np.zeros((TRACK_STEPS, TRACK_AGENTS, 2)) + +for t in range(TRACK_STEPS): + actions = random_actions(env) + env.step(actions) + states = env.get_global_agent_state() + for i in range(TRACK_AGENTS): + xy_history[t, i, 0] = states["x"][i] + xy_history[t, i, 1] = states["y"][i] + +fig, ax = plt.subplots(figsize=(10, 10)) +for i in range(TRACK_AGENTS): + ax.plot(xy_history[:, i, 0], xy_history[:, i, 1], "-o", markersize=2, alpha=0.7, label=f"agent {i}") + ax.scatter(xy_history[0, i, 0], xy_history[0, i, 1], s=100, marker="s", zorder=10) + ax.scatter(xy_history[-1, i, 0], xy_history[-1, i, 1], s=100, marker="*", zorder=10) +ax.set_xlabel("x") +ax.set_ylabel("y") +ax.set_title(f"{TRACK_AGENTS} agent trajectories over {TRACK_STEPS} steps") +ax.legend() +ax.set_aspect("equal") +ax.grid(True, alpha=0.3) +plt.tight_layout() +plt.show() + +# %% [markdown] +# ## Consistency checks + +# %% +if all_logs: + passed = 0 + failed = 0 + for log in all_logs: + if "score" in log and "completion_rate" in log: + if log["score"] > log["completion_rate"] + 1e-6: + print( + f"FAIL: score ({log['score']:.4f}) > completion_rate ({log['completion_rate']:.4f}) at step {log['_step']}" + ) + failed += 1 + else: + passed += 1 + for rate_key in ["collision_rate", "offroad_rate", "completion_rate", "score"]: + if rate_key in log: + v = log[rate_key] + if v < -1e-6 or v > 1.0 + 1e-6: + print(f"FAIL: {rate_key} = {v:.4f} outside [0,1] at step {log['_step']}") + failed += 1 + else: + passed += 1 + print(f"\nConsistency checks: {passed} passed, {failed} failed") +else: + print("No logs to check") + +# %% [markdown] +# ## Gigaflow agent dynamics + +# %% +episode_lengths = [] +agent_step_count = np.zeros(env.num_agents) +active_counts = [] + +for t in range(N_STEPS): + active = (~np.all(all_rewards[: t + 1] == 0, axis=0) if t > 0 else np.ones(env.num_agents, dtype=bool)).sum() + active_counts.append(active) + for i in range(env.num_agents): + agent_step_count[i] += 1 + if all_terms[t, i] or all_truncs[t, i]: + episode_lengths.append(agent_step_count[i]) + agent_step_count[i] = 0 + +fig, axes = plt.subplots(1, 2, figsize=(14, 4)) +axes[0].plot(active_counts) +axes[0].set_xlabel("Step") +axes[0].set_ylabel("Active agents") +axes[0].set_title("Active agent count over time") + +if episode_lengths: + axes[1].hist(episode_lengths, bins=30, edgecolor="black", alpha=0.7) + axes[1].set_xlabel("Episode length (steps)") + axes[1].set_title(f"Episode length distribution (N={len(episode_lengths)})") + print(f"Episode lengths: mean={np.mean(episode_lengths):.1f}, median={np.median(episode_lengths):.1f}") +else: + print("No episodes completed") +plt.tight_layout() +plt.show() + +# %% [markdown] +# ## Score vs cumulative reward + +# %% +if all_logs and "score" in all_logs[0]: + scores = [log["score"] for log in all_logs if "score" in log] + log_steps = [log["_step"] for log in all_logs if "score" in log] + cum_rew_at_log = [all_rewards[: t + 1].sum() / env.num_agents for t in log_steps] + + fig, ax = plt.subplots(figsize=(8, 6)) + ax.scatter(cum_rew_at_log, scores, alpha=0.5) + ax.set_xlabel("Avg cumulative reward up to step") + ax.set_ylabel("Score") + ax.set_title("Score vs cumulative reward") + ax.grid(True, alpha=0.3) + plt.tight_layout() + plt.show() +else: + print("No score data available") diff --git a/notebooks/04_training.ipynb b/notebooks/04_training.ipynb deleted file mode 100644 index b1da07df42..0000000000 --- a/notebooks/04_training.ipynb +++ /dev/null @@ -1,549 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# 04 - RL Training Loop Debug\n", - "End-to-end data flow from env -> policy -> loss. Debug encoding, sampling, advantages, gradients." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import numpy as np\n", - "import matplotlib.pyplot as plt\n", - "import torch\n", - "import torch.nn.functional as F\n", - "from notebooks.notebook_utils import make_drive_env, make_drive_policy, zero_actions\n", - "\n", - "env, obs, info = make_drive_env()\n", - "\n", - "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", - "policy = make_drive_policy(env, device)\n", - "print(f\"Policy on {device}, params: {sum(p.numel() for p in policy.parameters()):,}\")\n", - "print(f\"Action dim: {policy.atn_dim}, act_shape: {(env.num_agents, len(env.single_action_space.nvec))}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Optional: load checkpoint" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# CHECKPOINT_PATH = ''\n", - "# state_dict = torch.load(CHECKPOINT_PATH, map_location=device)\n", - "# state_dict = {k.replace(\"module.\", \"\"): v for k, v in state_dict.items()}\n", - "# print('Checkpoint loaded')" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Encode observations" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "actions = zero_actions(env)\n", - "obs, rew, term, trunc, info = env.step(actions)\n", - "\n", - "obs_tensor = torch.FloatTensor(obs).to(device)\n", - "with torch.no_grad():\n", - " hidden = policy.encode_observations(obs_tensor)\n", - "\n", - "print(f\"Hidden shape: {hidden.shape}\")\n", - "print(f\"Hidden stats: min={hidden.min():.4f}, max={hidden.max():.4f}, mean={hidden.mean():.4f}\")\n", - "print(f\"NaN in hidden: {torch.isnan(hidden).sum().item()}\")\n", - "print(f\"Dead neurons (always 0): {(hidden.abs().sum(dim=0) == 0).sum().item()}/{hidden.shape[1]}\")\n", - "print(f\"% near-zero (<1e-6): {(hidden.abs() < 1e-6).float().mean().item() * 100:.1f}%\")\n", - "\n", - "fig, ax = plt.subplots(figsize=(10, 4))\n", - "ax.hist(hidden.cpu().numpy().flatten(), bins=50, edgecolor=\"black\", alpha=0.7)\n", - "ax.set_title(\"Hidden activation distribution\")\n", - "ax.set_xlabel(\"Activation value\")\n", - "plt.tight_layout()\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Action sampling" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "with torch.no_grad():\n", - " action_logits, value = policy.decode_actions(hidden)\n", - "\n", - "for i, logit in enumerate(action_logits):\n", - " print(f\"Action head {i}: shape={logit.shape}\")\n", - " probs = F.softmax(logit, dim=-1)\n", - " entropy = -(probs * probs.log()).sum(dim=-1).mean()\n", - " max_entropy = np.log(logit.shape[-1])\n", - " print(f\" Entropy: {entropy:.4f} / {max_entropy:.4f} (max) = {entropy / max_entropy:.2%}\")\n", - " print(f\" Logit range: [{logit.min():.3f}, {logit.max():.3f}]\")\n", - "\n", - "print(f\"\\nValue: mean={value.mean():.4f}, std={value.std():.4f}\")\n", - "\n", - "fig, axes = plt.subplots(1, 2, figsize=(14, 4))\n", - "probs = F.softmax(action_logits[0], dim=-1)\n", - "mean_probs = probs.mean(dim=0).cpu().numpy()\n", - "axes[0].bar(range(len(mean_probs)), mean_probs, edgecolor=\"black\", alpha=0.7)\n", - "axes[0].axhline(1.0 / len(mean_probs), color=\"red\", ls=\"--\", label=\"uniform\")\n", - "axes[0].set_xlabel(\"Action\")\n", - "axes[0].set_ylabel(\"Probability\")\n", - "axes[0].set_title(\"Mean action probabilities\")\n", - "axes[0].legend()\n", - "\n", - "axes[1].hist(value.cpu().numpy().flatten(), bins=20, edgecolor=\"black\", alpha=0.7, color=\"purple\")\n", - "axes[1].set_title(\"Value predictions\")\n", - "axes[1].set_xlabel(\"Value\")\n", - "plt.tight_layout()\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Manual encode trace: check each encoder for NaN" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "x = obs_tensor\n", - "backbone = policy.actor_backbone\n", - "slide_idx = env.ego_features\n", - "\n", - "ego_obs = x[:, :slide_idx]\n", - "print(\n", - " f\"ego_obs: shape={ego_obs.shape}, NaN={torch.isnan(ego_obs).sum().item()}, range=[{ego_obs.min():.3f}, {ego_obs.max():.3f}]\"\n", - ")\n", - "\n", - "cond_dim = backbone.target_dim\n", - "if cond_dim > 0:\n", - " cond_obs = x[:, slide_idx : slide_idx + cond_dim]\n", - " slide_idx += cond_dim\n", - " print(f\"cond_obs: shape={cond_obs.shape}, NaN={torch.isnan(cond_obs).sum().item()}\")\n", - "\n", - "partner_dim = env.obs_slots_partners_n * env.partner_features\n", - "lane_dim = env.obs_slots_lane_kept * env.road_features\n", - "boundary_dim = env.obs_slots_boundary_kept * env.road_features\n", - "\n", - "partner_obs = x[:, slide_idx : slide_idx + partner_dim]\n", - "slide_idx += partner_dim\n", - "lane_obs = x[:, slide_idx : slide_idx + lane_dim]\n", - "slide_idx += lane_dim\n", - "boundary_obs = x[:, slide_idx : slide_idx + boundary_dim]\n", - "slide_idx += boundary_dim\n", - "\n", - "with torch.no_grad():\n", - " ego_enc = backbone.ego_encoder(ego_obs)\n", - " partner_enc, _ = backbone.partner_encoder(partner_obs.view(-1, env.obs_slots_partners_n, env.partner_features)).max(\n", - " dim=1\n", - " )\n", - " lane_enc, _ = backbone.lane_encoder(lane_obs.view(-1, env.obs_slots_lane_kept, env.road_features)).max(dim=1)\n", - " bound_enc, _ = backbone.boundary_encoder(boundary_obs.view(-1, env.obs_slots_boundary_kept, env.road_features)).max(\n", - " dim=1\n", - " )\n", - "\n", - "for name, enc in [(\"ego\", ego_enc), (\"partner\", partner_enc), (\"lane\", lane_enc), (\"boundary\", bound_enc)]:\n", - " print(\n", - " f\"{name:>10s}_enc: NaN={torch.isnan(enc).sum().item()}, dead={((enc.abs().sum(dim=0) == 0).sum().item())}, range=[{enc.min():.3f}, {enc.max():.3f}]\"\n", - " )\n", - "\n", - "if cond_dim > 0:\n", - " with torch.no_grad():\n", - " cond_enc = backbone.target_encoder(cond_obs)\n", - " print(\n", - " f\"{'cond':>10s}_enc: NaN={torch.isnan(cond_enc).sum().item()}, dead={((cond_enc.abs().sum(dim=0) == 0).sum().item())}, range=[{cond_enc.min():.3f}, {cond_enc.max():.3f}]\"\n", - " )" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Forward-backward: fake advantage, loss, grads" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "policy.train()\n", - "optimizer = torch.optim.Adam(policy.parameters(), lr=3e-4)\n", - "\n", - "action_logits_list, value = policy(obs_tensor)\n", - "\n", - "fake_actions = torch.randint(0, env.single_action_space.nvec[0], (env.num_agents,), device=device)\n", - "fake_advantages = torch.randn(env.num_agents, device=device)\n", - "fake_returns = torch.randn(env.num_agents, device=device)\n", - "fake_old_logprobs = torch.randn(env.num_agents, device=device)\n", - "\n", - "logits = action_logits_list[0]\n", - "dist = torch.distributions.Categorical(logits=logits)\n", - "new_logprobs = dist.log_prob(fake_actions)\n", - "entropy = dist.entropy()\n", - "\n", - "ratio = torch.exp(new_logprobs - fake_old_logprobs)\n", - "clip_coef = 0.2\n", - "pg_loss1 = -fake_advantages * ratio\n", - "pg_loss2 = -fake_advantages * torch.clamp(ratio, 1 - clip_coef, 1 + clip_coef)\n", - "pg_loss = torch.max(pg_loss1, pg_loss2).mean()\n", - "v_loss = 0.5 * ((value.squeeze() - fake_returns) ** 2).mean()\n", - "entropy_loss = entropy.mean()\n", - "loss = pg_loss + 0.5 * v_loss - 0.01 * entropy_loss\n", - "\n", - "print(f\"pg_loss: {pg_loss.item():.4f}\")\n", - "print(f\"v_loss: {v_loss.item():.4f}\")\n", - "print(f\"entropy: {entropy_loss.item():.4f}\")\n", - "print(f\"total: {loss.item():.4f}\")\n", - "print(f\"ratio: mean={ratio.mean():.4f}, std={ratio.std():.4f}\")\n", - "\n", - "optimizer.zero_grad()\n", - "loss.backward()\n", - "total_grad_norm = torch.nn.utils.clip_grad_norm_(policy.parameters(), float(\"inf\"))\n", - "print(f\"\\nTotal grad norm: {total_grad_norm:.4f}\")\n", - "print(f\"NaN in loss: {torch.isnan(loss).item()}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Gradient flow: per-parameter analysis" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "print(f\"{'Parameter':>45s} | {'shape':>20s} | {'grad_norm':>10s} {'grad_mean':>10s} {'grad_max':>10s} | flag\")\n", - "print(\"-\" * 120)\n", - "for name, param in policy.named_parameters():\n", - " if param.grad is not None:\n", - " g = param.grad\n", - " norm = g.norm().item()\n", - " mean = g.mean().item()\n", - " mx = g.abs().max().item()\n", - " flag = \"\"\n", - " if norm == 0:\n", - " flag = \"ZERO GRAD\"\n", - " elif norm > 100:\n", - " flag = \"EXPLODING\"\n", - " elif norm < 1e-7:\n", - " flag = \"VANISHING\"\n", - " print(f\"{name:>45s} | {str(list(param.shape)):>20s} | {norm:10.6f} {mean:10.6f} {mx:10.6f} | {flag}\")\n", - " else:\n", - " print(f\"{name:>45s} | {str(list(param.shape)):>20s} | NO GRAD\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Experience buffer simulation: 128-step rollout" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "HORIZON = 128\n", - "obs_dim = obs.shape[1]\n", - "\n", - "obs_buf = np.zeros((HORIZON, env.num_agents, obs_dim), dtype=np.float32)\n", - "act_buf = np.zeros((HORIZON, env.num_agents), dtype=np.int64)\n", - "rew_buf = np.zeros((HORIZON, env.num_agents), dtype=np.float32)\n", - "val_buf = np.zeros((HORIZON, env.num_agents), dtype=np.float32)\n", - "logp_buf = np.zeros((HORIZON, env.num_agents), dtype=np.float32)\n", - "done_buf = np.zeros((HORIZON, env.num_agents), dtype=np.float32)\n", - "\n", - "policy.eval()\n", - "for t in range(HORIZON):\n", - " obs_t = torch.FloatTensor(obs).to(device)\n", - " with torch.no_grad():\n", - " logits_list, val = policy(obs_t)\n", - " dist = torch.distributions.Categorical(logits=logits_list[0])\n", - " act = dist.sample()\n", - " logp = dist.log_prob(act)\n", - "\n", - " obs_buf[t] = obs\n", - " act_buf[t] = act.cpu().numpy()\n", - " val_buf[t] = val.squeeze().cpu().numpy()\n", - " logp_buf[t] = logp.cpu().numpy()\n", - "\n", - " # Reshape (N,) -> (N, 1) for env.step with MultiDiscrete\n", - " env_actions = act.cpu().numpy().reshape(env.num_agents, len(env.single_action_space.nvec))\n", - " obs, rew, term, trunc, info = env.step(env_actions)\n", - " rew_buf[t] = rew\n", - " done_buf[t] = term | trunc\n", - "\n", - "print(f\"Buffer shapes: obs={obs_buf.shape}, act={act_buf.shape}, rew={rew_buf.shape}\")\n", - "print(f\"Reward stats: mean={rew_buf.mean():.5f}, std={rew_buf.std():.5f}\")\n", - "print(f\"Value stats: mean={val_buf.mean():.5f}, std={val_buf.std():.5f}\")\n", - "print(f\"Done count: {done_buf.sum():.0f}\")\n", - "print(f\"LogProb stats: mean={logp_buf.mean():.4f}, std={logp_buf.std():.4f}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## GAE advantage computation" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "gamma, lam = 0.98, 0.95\n", - "advantages = np.zeros_like(rew_buf)\n", - "\n", - "last_gae = np.zeros(env.num_agents)\n", - "for t in reversed(range(HORIZON - 1)):\n", - " next_non_terminal = 1.0 - done_buf[t + 1]\n", - " delta = rew_buf[t + 1] + gamma * val_buf[t + 1] * next_non_terminal - val_buf[t]\n", - " last_gae = delta + gamma * lam * last_gae * next_non_terminal\n", - " advantages[t] = last_gae\n", - "\n", - "returns = advantages + val_buf\n", - "\n", - "print(f\"Advantages: mean={advantages.mean():.5f}, std={advantages.std():.5f}\")\n", - "print(f\"Returns: mean={returns.mean():.5f}, std={returns.std():.5f}\")\n", - "print(f\"Advantage vs Return corr: {np.corrcoef(advantages.flatten(), returns.flatten())[0, 1]:.4f}\")\n", - "\n", - "fig, axes = plt.subplots(1, 4, figsize=(18, 4))\n", - "axes[0].hist(advantages.flatten(), bins=50, edgecolor=\"black\", alpha=0.7)\n", - "axes[0].set_title(f\"Advantage distribution (std={advantages.std():.4f})\")\n", - "\n", - "axes[1].hist(returns.flatten(), bins=50, edgecolor=\"black\", alpha=0.7, color=\"orange\")\n", - "axes[1].set_title(\"Returns distribution\")\n", - "\n", - "axes[2].plot(advantages.mean(axis=1))\n", - "axes[2].set_xlabel(\"Step\")\n", - "axes[2].set_ylabel(\"Mean advantage\")\n", - "axes[2].set_title(\"Mean advantage over time\")\n", - "\n", - "axes[3].plot(done_buf.mean(axis=1), color=\"orange\")\n", - "axes[3].set_xlabel(\"Step\")\n", - "axes[3].set_ylabel(\"Mean done\")\n", - "axes[3].set_title(\"Mean done over time\")\n", - "\n", - "plt.tight_layout()\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## PPO loss components" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "MB = 16\n", - "mb_obs = torch.FloatTensor(obs_buf[:MB].reshape(-1, obs_dim)).to(device)\n", - "mb_act = torch.LongTensor(act_buf[:MB].flatten()).to(device)\n", - "mb_old_logp = torch.FloatTensor(logp_buf[:MB].flatten()).to(device)\n", - "mb_adv = torch.FloatTensor(advantages[:MB].flatten()).to(device)\n", - "mb_ret = torch.FloatTensor(returns[:MB].flatten()).to(device)\n", - "mb_old_val = torch.FloatTensor(val_buf[:MB].flatten()).to(device)\n", - "\n", - "mb_adv = (mb_adv - mb_adv.mean()) / (mb_adv.std() + 1e-8)\n", - "\n", - "policy.train()\n", - "logits_list, newvalue = policy(mb_obs)\n", - "newvalue = newvalue.squeeze()\n", - "dist = torch.distributions.Categorical(logits=logits_list[0])\n", - "new_logp = dist.log_prob(mb_act)\n", - "entropy = dist.entropy()\n", - "\n", - "ratio = torch.exp(new_logp - mb_old_logp)\n", - "print(f\"Ratio: mean={ratio.mean():.4f}, std={ratio.std():.4f}, min={ratio.min():.4f}, max={ratio.max():.4f}\")\n", - "if ratio.mean() < 0.5 or ratio.mean() > 2.0:\n", - " print(\"WARNING: ratio far from 1.0, policy may have diverged\")\n", - "\n", - "clip_coef = 0.2\n", - "pg_loss1 = -mb_adv * ratio\n", - "pg_loss2 = -mb_adv * torch.clamp(ratio, 1 - clip_coef, 1 + clip_coef)\n", - "pg_loss = torch.max(pg_loss1, pg_loss2).mean()\n", - "\n", - "vf_clip = 0.2\n", - "v_clipped = mb_old_val + torch.clamp(newvalue - mb_old_val, -vf_clip, vf_clip)\n", - "v_loss_unclipped = (newvalue - mb_ret) ** 2\n", - "v_loss_clipped = (v_clipped - mb_ret) ** 2\n", - "v_loss = 0.5 * torch.max(v_loss_unclipped, v_loss_clipped).mean()\n", - "\n", - "entropy_loss = entropy.mean()\n", - "\n", - "print(f\"\\npg_loss: {pg_loss.item():.6f}\")\n", - "print(f\"v_loss: {v_loss.item():.6f}\")\n", - "print(f\"entropy: {entropy_loss.item():.6f} (max={np.log(env.single_action_space.nvec[0]):.4f})\")\n", - "print(f\"total: {(pg_loss + 0.5 * v_loss - 0.01 * entropy_loss).item():.6f}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 5-epoch sanity training" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "optimizer = torch.optim.Adam(policy.parameters(), lr=3e-4)\n", - "all_obs = torch.FloatTensor(obs_buf.reshape(-1, obs_dim)).to(device)\n", - "all_act = torch.LongTensor(act_buf.flatten()).to(device)\n", - "all_old_logp = torch.FloatTensor(logp_buf.flatten()).to(device)\n", - "all_adv = torch.FloatTensor(advantages.flatten()).to(device)\n", - "all_ret = torch.FloatTensor(returns.flatten()).to(device)\n", - "\n", - "all_adv = (all_adv - all_adv.mean()) / (all_adv.std() + 1e-8)\n", - "\n", - "N_EPOCHS = 5\n", - "history = {\"pg_loss\": [], \"v_loss\": [], \"entropy\": [], \"kl\": []}\n", - "\n", - "policy.train()\n", - "for epoch in range(N_EPOCHS):\n", - " logits_list, newval = policy(all_obs)\n", - " newval = newval.squeeze()\n", - " dist = torch.distributions.Categorical(logits=logits_list[0])\n", - " new_logp = dist.log_prob(all_act)\n", - " ent = dist.entropy().mean()\n", - "\n", - " ratio = torch.exp(new_logp - all_old_logp)\n", - " approx_kl = (all_old_logp - new_logp).mean()\n", - "\n", - " pg1 = -all_adv * ratio\n", - " pg2 = -all_adv * torch.clamp(ratio, 0.8, 1.2)\n", - " pg = torch.max(pg1, pg2).mean()\n", - " vl = 0.5 * ((newval - all_ret) ** 2).mean()\n", - " loss = pg + 0.5 * vl - 0.01 * ent\n", - "\n", - " optimizer.zero_grad()\n", - " loss.backward()\n", - " torch.nn.utils.clip_grad_norm_(policy.parameters(), 0.5)\n", - " optimizer.step()\n", - "\n", - " history[\"pg_loss\"].append(pg.item())\n", - " history[\"v_loss\"].append(vl.item())\n", - " history[\"entropy\"].append(ent.item())\n", - " history[\"kl\"].append(approx_kl.item())\n", - " print(f\"Epoch {epoch}: pg={pg.item():.5f}, v={vl.item():.5f}, ent={ent.item():.4f}, kl={approx_kl.item():.5f}\")\n", - "\n", - "fig, axes = plt.subplots(1, 4, figsize=(16, 3))\n", - "for i, (key, color) in enumerate(zip(history.keys(), [\"red\", \"blue\", \"green\", \"orange\"])):\n", - " axes[i].plot(history[key], \"-o\", color=color)\n", - " axes[i].set_title(key)\n", - " axes[i].set_xlabel(\"Epoch\")\n", - "plt.tight_layout()\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Value accuracy: predicted vs actual returns" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "policy.eval()\n", - "with torch.no_grad():\n", - " _, pred_values = policy(all_obs)\n", - "pred_values = pred_values.squeeze().cpu().numpy()\n", - "actual_returns = returns.flatten()\n", - "\n", - "var_actual = np.var(actual_returns)\n", - "explained_var = 1 - np.var(actual_returns - pred_values) / (var_actual + 1e-8) if var_actual > 1e-8 else 0.0\n", - "\n", - "fig, ax = plt.subplots(figsize=(7, 7))\n", - "ax.scatter(actual_returns, pred_values, alpha=0.3, s=10)\n", - "lims = [min(actual_returns.min(), pred_values.min()), max(actual_returns.max(), pred_values.max())]\n", - "ax.plot(lims, lims, \"r--\", label=\"perfect\")\n", - "ax.set_xlabel(\"Actual return\")\n", - "ax.set_ylabel(\"Predicted value\")\n", - "ax.set_title(f\"Value accuracy (explained var: {explained_var:.4f})\")\n", - "ax.legend()\n", - "ax.grid(True, alpha=0.3)\n", - "plt.tight_layout()\n", - "plt.show()\n", - "\n", - "print(f\"Explained variance: {explained_var:.4f}\")\n", - "print(f\"Value MSE: {np.mean((actual_returns - pred_values) ** 2):.6f}\")\n", - "if explained_var < 0:\n", - " print(\"WARNING: negative explained variance, value head worse than predicting mean\")" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": ".venv", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} diff --git a/notebooks/04_training.py b/notebooks/04_training.py new file mode 100644 index 0000000000..c1378bd2b1 --- /dev/null +++ b/notebooks/04_training.py @@ -0,0 +1,418 @@ +# --- +# jupyter: +# jupytext: +# text_representation: +# extension: .py +# format_name: percent +# format_version: '1.3' +# jupytext_version: 1.19.3 +# kernelspec: +# display_name: .venv +# language: python +# name: python3 +# --- + +# %% [markdown] +# # 04 - RL Training Loop Debug +# End-to-end data flow from env -> policy -> loss. Debug encoding, sampling, advantages, gradients. + +# %% +import numpy as np +import matplotlib.pyplot as plt +import torch +import torch.nn.functional as F +from notebooks.notebook_utils import make_drive_env, make_drive_policy, zero_actions + +env, obs, info = make_drive_env() + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +policy = make_drive_policy(env, device) +print(f"Policy on {device}, params: {sum(p.numel() for p in policy.parameters()):,}") +print(f"Action dim: {policy.atn_dim}, act_shape: {(env.num_agents, len(env.single_action_space.nvec))}") + +# %% [markdown] +# ### Optional: load checkpoint + +# %% +# CHECKPOINT_PATH = '' +# state_dict = torch.load(CHECKPOINT_PATH, map_location=device) +# state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()} +# print('Checkpoint loaded') + +# %% [markdown] +# ## Encode observations + +# %% +actions = zero_actions(env) +obs, rew, term, trunc, info = env.step(actions) + +obs_tensor = torch.FloatTensor(obs).to(device) +with torch.no_grad(): + hidden = policy.encode_observations(obs_tensor) + +print(f"Hidden shape: {hidden.shape}") +print(f"Hidden stats: min={hidden.min():.4f}, max={hidden.max():.4f}, mean={hidden.mean():.4f}") +print(f"NaN in hidden: {torch.isnan(hidden).sum().item()}") +print(f"Dead neurons (always 0): {(hidden.abs().sum(dim=0) == 0).sum().item()}/{hidden.shape[1]}") +print(f"% near-zero (<1e-6): {(hidden.abs() < 1e-6).float().mean().item() * 100:.1f}%") + +fig, ax = plt.subplots(figsize=(10, 4)) +ax.hist(hidden.cpu().numpy().flatten(), bins=50, edgecolor="black", alpha=0.7) +ax.set_title("Hidden activation distribution") +ax.set_xlabel("Activation value") +plt.tight_layout() +plt.show() + +# %% [markdown] +# ## Action sampling + +# %% +with torch.no_grad(): + action_logits, value = policy.decode_actions(hidden) + +for i, logit in enumerate(action_logits): + print(f"Action head {i}: shape={logit.shape}") + probs = F.softmax(logit, dim=-1) + entropy = -(probs * probs.log()).sum(dim=-1).mean() + max_entropy = np.log(logit.shape[-1]) + print(f" Entropy: {entropy:.4f} / {max_entropy:.4f} (max) = {entropy / max_entropy:.2%}") + print(f" Logit range: [{logit.min():.3f}, {logit.max():.3f}]") + +print(f"\nValue: mean={value.mean():.4f}, std={value.std():.4f}") + +fig, axes = plt.subplots(1, 2, figsize=(14, 4)) +probs = F.softmax(action_logits[0], dim=-1) +mean_probs = probs.mean(dim=0).cpu().numpy() +axes[0].bar(range(len(mean_probs)), mean_probs, edgecolor="black", alpha=0.7) +axes[0].axhline(1.0 / len(mean_probs), color="red", ls="--", label="uniform") +axes[0].set_xlabel("Action") +axes[0].set_ylabel("Probability") +axes[0].set_title("Mean action probabilities") +axes[0].legend() + +axes[1].hist(value.cpu().numpy().flatten(), bins=20, edgecolor="black", alpha=0.7, color="purple") +axes[1].set_title("Value predictions") +axes[1].set_xlabel("Value") +plt.tight_layout() +plt.show() + +# %% [markdown] +# ## Manual encode trace: check each encoder for NaN + +# %% +x = obs_tensor +backbone = policy.actor_backbone +slide_idx = env.ego_features + +ego_obs = x[:, :slide_idx] +print( + f"ego_obs: shape={ego_obs.shape}, NaN={torch.isnan(ego_obs).sum().item()}, range=[{ego_obs.min():.3f}, {ego_obs.max():.3f}]" +) + +cond_dim = backbone.target_dim +if cond_dim > 0: + cond_obs = x[:, slide_idx : slide_idx + cond_dim] + slide_idx += cond_dim + print(f"cond_obs: shape={cond_obs.shape}, NaN={torch.isnan(cond_obs).sum().item()}") + +partner_dim = env.obs_slots_partners_n * env.partner_features +lane_dim = env.obs_slots_lane_kept * env.road_features +boundary_dim = env.obs_slots_boundary_kept * env.road_features + +partner_obs = x[:, slide_idx : slide_idx + partner_dim] +slide_idx += partner_dim +lane_obs = x[:, slide_idx : slide_idx + lane_dim] +slide_idx += lane_dim +boundary_obs = x[:, slide_idx : slide_idx + boundary_dim] +slide_idx += boundary_dim + +with torch.no_grad(): + ego_enc = backbone.ego_encoder(ego_obs) + partner_enc, _ = backbone.partner_encoder(partner_obs.view(-1, env.obs_slots_partners_n, env.partner_features)).max( + dim=1 + ) + lane_enc, _ = backbone.lane_encoder(lane_obs.view(-1, env.obs_slots_lane_kept, env.road_features)).max(dim=1) + bound_enc, _ = backbone.boundary_encoder(boundary_obs.view(-1, env.obs_slots_boundary_kept, env.road_features)).max( + dim=1 + ) + +for name, enc in [("ego", ego_enc), ("partner", partner_enc), ("lane", lane_enc), ("boundary", bound_enc)]: + print( + f"{name:>10s}_enc: NaN={torch.isnan(enc).sum().item()}, dead={((enc.abs().sum(dim=0) == 0).sum().item())}, range=[{enc.min():.3f}, {enc.max():.3f}]" + ) + +if cond_dim > 0: + with torch.no_grad(): + cond_enc = backbone.target_encoder(cond_obs) + print( + f"{'cond':>10s}_enc: NaN={torch.isnan(cond_enc).sum().item()}, dead={((cond_enc.abs().sum(dim=0) == 0).sum().item())}, range=[{cond_enc.min():.3f}, {cond_enc.max():.3f}]" + ) + +# %% [markdown] +# ## Forward-backward: fake advantage, loss, grads + +# %% +policy.train() +optimizer = torch.optim.Adam(policy.parameters(), lr=3e-4) + +action_logits_list, value = policy(obs_tensor) + +fake_actions = torch.randint(0, env.single_action_space.nvec[0], (env.num_agents,), device=device) +fake_advantages = torch.randn(env.num_agents, device=device) +fake_returns = torch.randn(env.num_agents, device=device) +fake_old_logprobs = torch.randn(env.num_agents, device=device) + +logits = action_logits_list[0] +dist = torch.distributions.Categorical(logits=logits) +new_logprobs = dist.log_prob(fake_actions) +entropy = dist.entropy() + +ratio = torch.exp(new_logprobs - fake_old_logprobs) +clip_coef = 0.2 +pg_loss1 = -fake_advantages * ratio +pg_loss2 = -fake_advantages * torch.clamp(ratio, 1 - clip_coef, 1 + clip_coef) +pg_loss = torch.max(pg_loss1, pg_loss2).mean() +v_loss = 0.5 * ((value.squeeze() - fake_returns) ** 2).mean() +entropy_loss = entropy.mean() +loss = pg_loss + 0.5 * v_loss - 0.01 * entropy_loss + +print(f"pg_loss: {pg_loss.item():.4f}") +print(f"v_loss: {v_loss.item():.4f}") +print(f"entropy: {entropy_loss.item():.4f}") +print(f"total: {loss.item():.4f}") +print(f"ratio: mean={ratio.mean():.4f}, std={ratio.std():.4f}") + +optimizer.zero_grad() +loss.backward() +total_grad_norm = torch.nn.utils.clip_grad_norm_(policy.parameters(), float("inf")) +print(f"\nTotal grad norm: {total_grad_norm:.4f}") +print(f"NaN in loss: {torch.isnan(loss).item()}") + +# %% [markdown] +# ## Gradient flow: per-parameter analysis + +# %% +print(f"{'Parameter':>45s} | {'shape':>20s} | {'grad_norm':>10s} {'grad_mean':>10s} {'grad_max':>10s} | flag") +print("-" * 120) +for name, param in policy.named_parameters(): + if param.grad is not None: + g = param.grad + norm = g.norm().item() + mean = g.mean().item() + mx = g.abs().max().item() + flag = "" + if norm == 0: + flag = "ZERO GRAD" + elif norm > 100: + flag = "EXPLODING" + elif norm < 1e-7: + flag = "VANISHING" + print(f"{name:>45s} | {str(list(param.shape)):>20s} | {norm:10.6f} {mean:10.6f} {mx:10.6f} | {flag}") + else: + print(f"{name:>45s} | {str(list(param.shape)):>20s} | NO GRAD") + +# %% [markdown] +# ## Experience buffer simulation: 128-step rollout + +# %% +HORIZON = 128 +obs_dim = obs.shape[1] + +obs_buf = np.zeros((HORIZON, env.num_agents, obs_dim), dtype=np.float32) +act_buf = np.zeros((HORIZON, env.num_agents), dtype=np.int64) +rew_buf = np.zeros((HORIZON, env.num_agents), dtype=np.float32) +val_buf = np.zeros((HORIZON, env.num_agents), dtype=np.float32) +logp_buf = np.zeros((HORIZON, env.num_agents), dtype=np.float32) +done_buf = np.zeros((HORIZON, env.num_agents), dtype=np.float32) + +policy.eval() +for t in range(HORIZON): + obs_t = torch.FloatTensor(obs).to(device) + with torch.no_grad(): + logits_list, val = policy(obs_t) + dist = torch.distributions.Categorical(logits=logits_list[0]) + act = dist.sample() + logp = dist.log_prob(act) + + obs_buf[t] = obs + act_buf[t] = act.cpu().numpy() + val_buf[t] = val.squeeze().cpu().numpy() + logp_buf[t] = logp.cpu().numpy() + + # Reshape (N,) -> (N, 1) for env.step with MultiDiscrete + env_actions = act.cpu().numpy().reshape(env.num_agents, len(env.single_action_space.nvec)) + obs, rew, term, trunc, info = env.step(env_actions) + rew_buf[t] = rew + done_buf[t] = term | trunc + +print(f"Buffer shapes: obs={obs_buf.shape}, act={act_buf.shape}, rew={rew_buf.shape}") +print(f"Reward stats: mean={rew_buf.mean():.5f}, std={rew_buf.std():.5f}") +print(f"Value stats: mean={val_buf.mean():.5f}, std={val_buf.std():.5f}") +print(f"Done count: {done_buf.sum():.0f}") +print(f"LogProb stats: mean={logp_buf.mean():.4f}, std={logp_buf.std():.4f}") + +# %% [markdown] +# ## GAE advantage computation + +# %% +gamma, lam = 0.98, 0.95 +advantages = np.zeros_like(rew_buf) + +last_gae = np.zeros(env.num_agents) +for t in reversed(range(HORIZON - 1)): + next_non_terminal = 1.0 - done_buf[t + 1] + delta = rew_buf[t + 1] + gamma * val_buf[t + 1] * next_non_terminal - val_buf[t] + last_gae = delta + gamma * lam * last_gae * next_non_terminal + advantages[t] = last_gae + +returns = advantages + val_buf + +print(f"Advantages: mean={advantages.mean():.5f}, std={advantages.std():.5f}") +print(f"Returns: mean={returns.mean():.5f}, std={returns.std():.5f}") +print(f"Advantage vs Return corr: {np.corrcoef(advantages.flatten(), returns.flatten())[0, 1]:.4f}") + +fig, axes = plt.subplots(1, 4, figsize=(18, 4)) +axes[0].hist(advantages.flatten(), bins=50, edgecolor="black", alpha=0.7) +axes[0].set_title(f"Advantage distribution (std={advantages.std():.4f})") + +axes[1].hist(returns.flatten(), bins=50, edgecolor="black", alpha=0.7, color="orange") +axes[1].set_title("Returns distribution") + +axes[2].plot(advantages.mean(axis=1)) +axes[2].set_xlabel("Step") +axes[2].set_ylabel("Mean advantage") +axes[2].set_title("Mean advantage over time") + +axes[3].plot(done_buf.mean(axis=1), color="orange") +axes[3].set_xlabel("Step") +axes[3].set_ylabel("Mean done") +axes[3].set_title("Mean done over time") + +plt.tight_layout() +plt.show() + +# %% [markdown] +# ## PPO loss components + +# %% +MB = 16 +mb_obs = torch.FloatTensor(obs_buf[:MB].reshape(-1, obs_dim)).to(device) +mb_act = torch.LongTensor(act_buf[:MB].flatten()).to(device) +mb_old_logp = torch.FloatTensor(logp_buf[:MB].flatten()).to(device) +mb_adv = torch.FloatTensor(advantages[:MB].flatten()).to(device) +mb_ret = torch.FloatTensor(returns[:MB].flatten()).to(device) +mb_old_val = torch.FloatTensor(val_buf[:MB].flatten()).to(device) + +mb_adv = (mb_adv - mb_adv.mean()) / (mb_adv.std() + 1e-8) + +policy.train() +logits_list, newvalue = policy(mb_obs) +newvalue = newvalue.squeeze() +dist = torch.distributions.Categorical(logits=logits_list[0]) +new_logp = dist.log_prob(mb_act) +entropy = dist.entropy() + +ratio = torch.exp(new_logp - mb_old_logp) +print(f"Ratio: mean={ratio.mean():.4f}, std={ratio.std():.4f}, min={ratio.min():.4f}, max={ratio.max():.4f}") +if ratio.mean() < 0.5 or ratio.mean() > 2.0: + print("WARNING: ratio far from 1.0, policy may have diverged") + +clip_coef = 0.2 +pg_loss1 = -mb_adv * ratio +pg_loss2 = -mb_adv * torch.clamp(ratio, 1 - clip_coef, 1 + clip_coef) +pg_loss = torch.max(pg_loss1, pg_loss2).mean() + +vf_clip = 0.2 +v_clipped = mb_old_val + torch.clamp(newvalue - mb_old_val, -vf_clip, vf_clip) +v_loss_unclipped = (newvalue - mb_ret) ** 2 +v_loss_clipped = (v_clipped - mb_ret) ** 2 +v_loss = 0.5 * torch.max(v_loss_unclipped, v_loss_clipped).mean() + +entropy_loss = entropy.mean() + +print(f"\npg_loss: {pg_loss.item():.6f}") +print(f"v_loss: {v_loss.item():.6f}") +print(f"entropy: {entropy_loss.item():.6f} (max={np.log(env.single_action_space.nvec[0]):.4f})") +print(f"total: {(pg_loss + 0.5 * v_loss - 0.01 * entropy_loss).item():.6f}") + +# %% [markdown] +# ## 5-epoch sanity training + +# %% +optimizer = torch.optim.Adam(policy.parameters(), lr=3e-4) +all_obs = torch.FloatTensor(obs_buf.reshape(-1, obs_dim)).to(device) +all_act = torch.LongTensor(act_buf.flatten()).to(device) +all_old_logp = torch.FloatTensor(logp_buf.flatten()).to(device) +all_adv = torch.FloatTensor(advantages.flatten()).to(device) +all_ret = torch.FloatTensor(returns.flatten()).to(device) + +all_adv = (all_adv - all_adv.mean()) / (all_adv.std() + 1e-8) + +N_EPOCHS = 5 +history = {"pg_loss": [], "v_loss": [], "entropy": [], "kl": []} + +policy.train() +for epoch in range(N_EPOCHS): + logits_list, newval = policy(all_obs) + newval = newval.squeeze() + dist = torch.distributions.Categorical(logits=logits_list[0]) + new_logp = dist.log_prob(all_act) + ent = dist.entropy().mean() + + ratio = torch.exp(new_logp - all_old_logp) + approx_kl = (all_old_logp - new_logp).mean() + + pg1 = -all_adv * ratio + pg2 = -all_adv * torch.clamp(ratio, 0.8, 1.2) + pg = torch.max(pg1, pg2).mean() + vl = 0.5 * ((newval - all_ret) ** 2).mean() + loss = pg + 0.5 * vl - 0.01 * ent + + optimizer.zero_grad() + loss.backward() + torch.nn.utils.clip_grad_norm_(policy.parameters(), 0.5) + optimizer.step() + + history["pg_loss"].append(pg.item()) + history["v_loss"].append(vl.item()) + history["entropy"].append(ent.item()) + history["kl"].append(approx_kl.item()) + print(f"Epoch {epoch}: pg={pg.item():.5f}, v={vl.item():.5f}, ent={ent.item():.4f}, kl={approx_kl.item():.5f}") + +fig, axes = plt.subplots(1, 4, figsize=(16, 3)) +for i, (key, color) in enumerate(zip(history.keys(), ["red", "blue", "green", "orange"])): + axes[i].plot(history[key], "-o", color=color) + axes[i].set_title(key) + axes[i].set_xlabel("Epoch") +plt.tight_layout() +plt.show() + +# %% [markdown] +# ## Value accuracy: predicted vs actual returns + +# %% +policy.eval() +with torch.no_grad(): + _, pred_values = policy(all_obs) +pred_values = pred_values.squeeze().cpu().numpy() +actual_returns = returns.flatten() + +var_actual = np.var(actual_returns) +explained_var = 1 - np.var(actual_returns - pred_values) / (var_actual + 1e-8) if var_actual > 1e-8 else 0.0 + +fig, ax = plt.subplots(figsize=(7, 7)) +ax.scatter(actual_returns, pred_values, alpha=0.3, s=10) +lims = [min(actual_returns.min(), pred_values.min()), max(actual_returns.max(), pred_values.max())] +ax.plot(lims, lims, "r--", label="perfect") +ax.set_xlabel("Actual return") +ax.set_ylabel("Predicted value") +ax.set_title(f"Value accuracy (explained var: {explained_var:.4f})") +ax.legend() +ax.grid(True, alpha=0.3) +plt.tight_layout() +plt.show() + +print(f"Explained variance: {explained_var:.4f}") +print(f"Value MSE: {np.mean((actual_returns - pred_values) ** 2):.6f}") +if explained_var < 0: + print("WARNING: negative explained variance, value head worse than predicting mean") diff --git a/notebooks/05_inference.ipynb b/notebooks/05_inference.ipynb deleted file mode 100644 index c605aaffee..0000000000 --- a/notebooks/05_inference.ipynb +++ /dev/null @@ -1,1782 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "cell-0", - "metadata": {}, - "source": [ - "# 05 - Model Inference Debug\n", - "End-to-end inference pipeline: config loading, policy forward pass, rollouts (deterministic vs stochastic), observation/reward analysis, value accuracy, trajectories, LSTM state." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "cell-1", - "metadata": {}, - "outputs": [], - "source": [ - "import numpy as np\n", - "import matplotlib.pyplot as plt\n", - "import torch\n", - "import torch.nn.functional as F\n", - "from pufferlib.ocean.drive.drive import Drive\n", - "from pufferlib.ocean.drive import binding\n", - "from pufferlib.ocean.torch import Drive as DrivePolicy\n", - "from pufferlib.pytorch import sample_logits\n", - "from notebooks.notebook_utils import COEF_NAMES, EGO_LABELS, MAP_DIR, load_notebook_config, zero_actions\n", - "\n", - "CHECKPOINT_PATH = \"../weights/tomate/models/model_puffer_drive_013100.pt\"\n", - "ENV_NAME = \"puffer_drive\"\n", - "\n", - "config = load_notebook_config(CHECKPOINT_PATH, ENV_NAME)\n", - "config[\"env\"][\"num_agents\"] = 64\n", - "config[\"env\"][\"num_maps\"] = 8\n", - "config[\"env\"][\"eval_mode\"] = 1\n", - "config[\"env\"][\"map_dir\"] = MAP_DIR\n", - "\n", - "config[\"env\"][\"obs_slots_boundary_n\"] = 80\n", - "config[\"env\"][\"obs_slots_lane_n\"] = 80\n", - "config[\"env\"][\"obs_dropout_lane\"] = 0.0\n", - "config[\"env\"][\"obs_dropout_boundary\"] = 0.0\n", - "\n", - "env = Drive(**config[\"env\"])\n", - "obs, info = env.reset(seed=42)\n", - "N = env.num_agents\n", - "\n", - "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", - "policy = DrivePolicy(env, **config[\"policy\"]).to(device)\n", - "\n", - "if CHECKPOINT_PATH:\n", - " sd = torch.load(CHECKPOINT_PATH, map_location=device)\n", - " sd = {k.replace(\"module.\", \"\"): v for k, v in sd.items()}\n", - " policy.load_state_dict(sd)\n", - " print(f\"Loaded checkpoint: {CHECKPOINT_PATH}\")\n", - "\n", - "is_continuous = policy.is_continuous\n", - "ACT_SHAPE = (N, len(env.single_action_space.nvec)) if not is_continuous else (N, env.single_action_space.shape[0])\n", - "\n", - "print(f\"Policy on {device}, params: {sum(p.numel() for p in policy.parameters()):,}\")\n", - "print(f\"Obs shape: {obs.shape}, Action space: {env.single_action_space}\")\n", - "print(f\"Config: dynamics={config['env']['dynamics_model']}, action={config['env']['action_type']}\")" - ] - }, - { - "cell_type": "markdown", - "id": "cell-2", - "metadata": {}, - "source": [ - "## Single-step policy output" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "cell-3", - "metadata": {}, - "outputs": [], - "source": [ - "# Take one step to get fresh obs\n", - "actions = zero_actions(env)\n", - "obs, rew, term, trunc, info = env.step(actions)\n", - "\n", - "obs_tensor = torch.FloatTensor(obs).to(device)\n", - "policy.eval()\n", - "\n", - "with torch.no_grad():\n", - " logits_list, value = policy(obs_tensor)\n", - "\n", - "# Sample actions\n", - "action, logprob, ent = sample_logits(logits_list)\n", - "action_det, _, _ = sample_logits(logits_list, deterministic=True)\n", - "\n", - "print(f\"Value: mean={value.mean():.4f}, std={value.std():.4f}, range=[{value.min():.4f}, {value.max():.4f}]\")\n", - "print(f\"Entropy: mean={ent.mean():.4f}, std={ent.std():.4f}\")\n", - "print(f\"LogProb: mean={logprob.mean():.4f}, std={logprob.std():.4f}\")\n", - "print(f\"Stochastic action sample: {action[0].cpu().numpy()}\")\n", - "print(f\"Deterministic action: {action_det[0].cpu().numpy()}\")\n", - "\n", - "# Plot\n", - "fig, axes = plt.subplots(1, 2, figsize=(14, 4))\n", - "\n", - "# Action probs (first head for multi-discrete, or full logits)\n", - "if isinstance(logits_list, list) or isinstance(logits_list, tuple):\n", - " probs = F.softmax(logits_list[0], dim=-1)\n", - "else:\n", - " probs = F.softmax(logits_list, dim=-1)\n", - "mean_probs = probs.mean(dim=0).cpu().numpy()\n", - "axes[0].bar(range(len(mean_probs)), mean_probs, edgecolor=\"black\", alpha=0.7)\n", - "axes[0].axhline(1.0 / len(mean_probs), color=\"red\", ls=\"--\", label=\"uniform\")\n", - "axes[0].set_xlabel(\"Action\")\n", - "axes[0].set_ylabel(\"Probability\")\n", - "axes[0].set_title(\"Mean action probabilities (across agents)\")\n", - "axes[0].legend()\n", - "\n", - "axes[1].hist(value.cpu().numpy().flatten(), bins=30, edgecolor=\"black\", alpha=0.7, color=\"purple\")\n", - "axes[1].set_title(\"Value predictions across agents\")\n", - "axes[1].set_xlabel(\"Value\")\n", - "plt.tight_layout()\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "id": "cell-4", - "metadata": {}, - "source": [ - "## Full rollout: deterministic vs stochastic" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "cell-5", - "metadata": {}, - "outputs": [], - "source": [ - "HORIZON = 256\n", - "TRACKED_AGENT = 0 # agent index to track in detail\n", - "obs_dim = obs.shape[1]\n", - "\n", - "dyn_model = config[\"env\"][\"dynamics_model\"]\n", - "tgt_type = config[\"env\"][\"target_type\"]\n", - "rew_cond = config[\"env\"].get(\"reward_conditioning\", False)\n", - "n_tgt_wp = config[\"env\"].get(\"num_target_waypoints\", 3)\n", - "\n", - "\n", - "def run_rollout(env, policy, deterministic=False, horizon=HORIZON):\n", - " obs, _ = env.reset(seed=42)\n", - " N = env.num_agents\n", - "\n", - " buffers = {\n", - " \"obs\": np.zeros((horizon, N, obs_dim), dtype=np.float32),\n", - " \"actions\": np.zeros((horizon, N), dtype=np.int64),\n", - " \"rewards\": np.zeros((horizon, N), dtype=np.float32),\n", - " \"values\": np.zeros((horizon, N), dtype=np.float32),\n", - " \"logprobs\": np.zeros((horizon, N), dtype=np.float32),\n", - " \"entropy\": np.zeros((horizon, N), dtype=np.float32),\n", - " \"terminals\": np.zeros((horizon, N), dtype=np.float32),\n", - " \"truncations\": np.zeros((horizon, N), dtype=np.float32),\n", - " \"positions_x\": np.zeros((horizon, N), dtype=np.float32),\n", - " \"positions_y\": np.zeros((horizon, N), dtype=np.float32),\n", - " }\n", - "\n", - " policy.eval()\n", - " for t in range(horizon):\n", - " obs_t = torch.FloatTensor(obs).to(device)\n", - " with torch.no_grad():\n", - " logits_list, val = policy(obs_t)\n", - " act, logp, entr = sample_logits(logits_list, deterministic=deterministic)\n", - "\n", - " buffers[\"obs\"][t] = obs\n", - " buffers[\"actions\"][t] = act.cpu().numpy().reshape(N) if act.dim() > 1 else act.cpu().numpy()\n", - " buffers[\"values\"][t] = val.squeeze().cpu().numpy()\n", - " buffers[\"logprobs\"][t] = logp.cpu().numpy()\n", - " buffers[\"entropy\"][t] = entr.cpu().numpy()\n", - "\n", - " # Get positions\n", - " gstate = env.get_global_agent_state()\n", - " buffers[\"positions_x\"][t] = gstate[\"x\"]\n", - " buffers[\"positions_y\"][t] = gstate[\"y\"]\n", - "\n", - " # Step env\n", - " env_actions = act.cpu().numpy().reshape(ACT_SHAPE)\n", - " obs, rew, term, trunc, info = env.step(env_actions)\n", - " buffers[\"rewards\"][t] = rew\n", - " buffers[\"terminals\"][t] = term\n", - " buffers[\"truncations\"][t] = trunc\n", - "\n", - " return buffers\n", - "\n", - "\n", - "print(\"Running stochastic rollout...\")\n", - "buf_stoch = run_rollout(env, policy, deterministic=False)\n", - "print(\"Running deterministic rollout...\")\n", - "buf_det = run_rollout(env, policy, deterministic=True)\n", - "\n", - "for name, buf in [(\"Stochastic\", buf_stoch), (\"Deterministic\", buf_det)]:\n", - " print(f\"\\n--- {name} ---\")\n", - " print(f\" Reward: mean={buf['rewards'].mean():.5f}, std={buf['rewards'].std():.5f}\")\n", - " print(f\" Value: mean={buf['values'].mean():.5f}, std={buf['values'].std():.5f}\")\n", - " print(f\" Entropy: mean={buf['entropy'].mean():.4f}\")\n", - " print(f\" Terminals: {buf['terminals'].sum():.0f}, Truncations: {buf['truncations'].sum():.0f}\")" - ] - }, - { - "cell_type": "markdown", - "id": "cell-6", - "metadata": {}, - "source": [ - "## Observation analysis" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "cell-7", - "metadata": {}, - "outputs": [], - "source": [ - "from pufferlib.viz import unpack_obs, plot_observation, plot_simulator_state\n", - "\n", - "# Ego-centric observation at t=50 for tracked agent\n", - "sample_t = min(50, HORIZON - 1)\n", - "sample_obs = buf_stoch[\"obs\"][sample_t : sample_t + 1, TRACKED_AGENT : TRACKED_AGENT + 1][0]\n", - "print(dyn_model, tgt_type, rew_cond, n_tgt_wp)\n", - "img = plot_observation(\n", - " sample_obs,\n", - " target_type=tgt_type,\n", - " reward_conditioning=rew_cond,\n", - " num_target_waypoints=n_tgt_wp,\n", - " obs_slots_partners_n=env.obs_slots_partners_n,\n", - " obs_slots_lane_n=env.obs_slots_lane_kept,\n", - " obs_slots_boundary_n=env.obs_slots_boundary_kept,\n", - " obs_slots_traffic_controls_n=env.obs_slots_traffic_controls_n,\n", - " obs_norm_goal_offset_m=env.obs_norm_goal_offset_m,\n", - " obs_norm_xy_offset_m=env.obs_norm_xy_offset_m,\n", - " obs_norm_veh_width_m=env.obs_norm_veh_width_m,\n", - " obs_norm_veh_length_m=env.obs_norm_veh_length_m,\n", - ")\n", - "plt.figure(figsize=(10, 10))\n", - "plt.imshow(img)\n", - "plt.axis(\"off\")\n", - "plt.title(f\"Ego-centric obs | agent={TRACKED_AGENT}, t={sample_t}\")\n", - "plt.show()\n", - "\n", - "# BEV simulator state\n", - "scenarios = env.get_state()\n", - "if scenarios and len(scenarios) > 0:\n", - " img_bev = plot_simulator_state(scenarios[0], timestep=0)\n", - " plt.figure(figsize=(10, 10))\n", - " plt.imshow(img_bev)\n", - " plt.axis(\"off\")\n", - " plt.title(\"BEV Simulator State\")\n", - " plt.show()\n", - "\n", - "# Ego feature time series for tracked agent\n", - "ego_features_over_time = []\n", - "for t in range(HORIZON):\n", - " ego, *_ = unpack_obs(\n", - " buf_stoch[\"obs\"][t : t + 1, TRACKED_AGENT : TRACKED_AGENT + 1][0],\n", - " target_type=tgt_type,\n", - " reward_conditioning=rew_cond,\n", - " num_target_waypoints=n_tgt_wp,\n", - " obs_slots_partners_n=env.obs_slots_partners_n,\n", - " obs_slots_lane_n=env.obs_slots_lane_kept,\n", - " obs_slots_boundary_n=env.obs_slots_boundary_kept,\n", - " obs_slots_traffic_controls_n=env.obs_slots_traffic_controls_n,\n", - " )\n", - " ego_features_over_time.append(ego)\n", - "ego_ts = np.array(ego_features_over_time)\n", - "\n", - "if dyn_model == \"jerk\":\n", - " labels = [\"speed\", \"width\", \"length\", \"steering\", \"a_long\", \"a_lat\", \"lcenter\", \"lalign\", \"speed_limit\"]\n", - " plot_idxs = [0, 3, 4, 5] # speed, steering, a_long, a_lat\n", - "else:\n", - " labels = [\"speed\", \"width\", \"length\", \"lcenter\", \"lalign\", \"speed_limit\"]\n", - " plot_idxs = [0, 3, 4, 5] # speed, lcenter, lalign, speed_limit\n", - "\n", - "fig, axes = plt.subplots(len(plot_idxs), 1, figsize=(14, 3 * len(plot_idxs)), sharex=True)\n", - "for i, idx in enumerate(plot_idxs):\n", - " axes[i].plot(ego_ts[:, idx])\n", - " print(ego_ts[10:, idx].argmin())\n", - " axes[i].set_ylabel(labels[idx])\n", - " axes[i].grid(True, alpha=0.3)\n", - "axes[-1].set_xlabel(\"Step\")\n", - "fig.suptitle(f\"Ego features over time | agent={TRACKED_AGENT}\", fontsize=14)\n", - "plt.tight_layout()\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "id": "czs3aiuhgyo", - "metadata": {}, - "source": [ - "## Observation layer breakdown\n", - "\n", - "Obs layout (all ego-centric, normalized):\n", - "- **Ego**: speed, width, length, [jerk: steering, a_long, a_lat], lane_center_dist, lane_angle, speed_limit\n", - "- **Conditioning** (if enabled): 17 reward coefs (goal_radius, goal_speed, collision, offroad, comfort, lane_align, vel_align, lane_center, center_bias, velocity, reverse, stop_line, timestep, overspeed, throttle, steer, acc) + target waypoints\n", - "- **Target**: static=rel_x,rel_y,rel_z per waypoint; dynamic=rel_x,rel_y,rel_z,heading_cos,heading_sin per waypoint\n", - "- **Partners** (MAX_PARTNERS x 9): rel_x, rel_y, rel_z, length, width, heading_cos, heading_sin, rel_vx, rel_vy\n", - "- **Lanes** (MAX_LANES x 7): rel_x, rel_y, rel_z, seg_length, seg_width, dir_cos, dir_sin\n", - "- **Boundaries** (MAX_BOUNDS x 7): same as lanes\n", - "- **Traffic controls** (MAX_TRAFFIC x 7): rel_x1, rel_y1, rel_x2, rel_y2, rel_z, type, state" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "ytjdjnb471l", - "metadata": {}, - "outputs": [], - "source": [ - "from pufferlib.viz import unpack_obs\n", - "\n", - "sample_t = min(50, HORIZON - 1)\n", - "sample_obs = buf_stoch[\"obs\"][sample_t : sample_t + 1, TRACKED_AGENT : TRACKED_AGENT + 1][0]\n", - "ego, target, partners, lanes, boundaries, traffic_controls = unpack_obs(\n", - " sample_obs,\n", - " target_type=tgt_type,\n", - " reward_conditioning=rew_cond,\n", - " num_target_waypoints=n_tgt_wp,\n", - " obs_slots_partners_n=env.obs_slots_partners_n,\n", - " obs_slots_lane_n=env.obs_slots_lane_kept,\n", - " obs_slots_boundary_n=env.obs_slots_boundary_kept,\n", - " obs_slots_traffic_controls_n=env.obs_slots_traffic_controls_n,\n", - ")\n", - "\n", - "# Also unpack conditioning manually (unpack_obs doesn't return it separately)\n", - "ego_dim = binding.EGO_FEATURES\n", - "cond_dim = binding.NUM_REWARD_COEFS if rew_cond else 0\n", - "cond_obs = sample_obs[0, ego_dim : ego_dim + cond_dim] if cond_dim > 0 else None\n", - "\n", - "\n", - "# --- Print all layer shapes + stats ---\n", - "def layer_stats(name, arr):\n", - " flat = arr.flatten() if hasattr(arr, \"flatten\") else np.array(arr).flatten()\n", - " if flat.size == 0:\n", - " print(f\"{name:>14s}: shape={str(list(arr.shape)):>16s} (empty)\")\n", - " return\n", - " nonzero = np.count_nonzero(flat)\n", - " print(\n", - " f\"{name:>14s}: shape={str(list(arr.shape)):>16s} \"\n", - " f\"nonzero={nonzero:>5d}/{flat.size:<5d} \"\n", - " f\"range=[{flat.min():.4f}, {flat.max():.4f}] \"\n", - " f\"mean={flat.mean():.4f} std={flat.std():.4f}\"\n", - " )\n", - "\n", - "\n", - "print(f\"--- Observation breakdown at t={sample_t}, agent={TRACKED_AGENT} ---\")\n", - "print(f\"Total obs dim: {sample_obs.shape[-1]}\")\n", - "print()\n", - "layer_stats(\"Ego\", ego)\n", - "if cond_obs is not None:\n", - " layer_stats(\"Conditioning\", cond_obs)\n", - "layer_stats(\"Target\", target)\n", - "layer_stats(\"Partners\", partners)\n", - "layer_stats(\"Lanes\", lanes)\n", - "layer_stats(\"Boundaries\", boundaries)\n", - "layer_stats(\"TrafficControls\", traffic_controls)\n", - "\n", - "# --- Ego features detail ---\n", - "ego_labels = EGO_LABELS\n", - "\n", - "print(f\"\\n--- Ego features ---\")\n", - "for i, (label, val) in enumerate(zip(ego_labels, ego)):\n", - " print(f\" [{i}] {label:>14s} = {val:.4f}\")\n", - "\n", - "# --- Conditioning detail ---\n", - "if cond_obs is not None:\n", - " cond_labels = COEF_NAMES\n", - " print(f\"\\n--- Conditioning (reward coefs, normalized) ---\")\n", - " for i, (label, val) in enumerate(zip(cond_labels, cond_obs)):\n", - " print(f\" [{i:>2d}] {label:>16s} = {val:.4f}\")\n", - "\n", - "# --- Target waypoints ---\n", - "tgt_feat = binding.STATIC_TARGET_FEATURES if tgt_type == \"static\" else binding.DYNAMIC_TARGET_FEATURES\n", - "if tgt_type == \"static\":\n", - " tgt_labels = [\"rel_x\", \"rel_y\", \"rel_z\"]\n", - "else:\n", - " tgt_labels = [\"rel_x\", \"rel_y\", \"rel_z\", \"heading_cos\", \"heading_sin\"]\n", - "\n", - "print(f\"\\n--- Target waypoints (n={n_tgt_wp}, type={tgt_type}) ---\")\n", - "for wp in range(target.shape[0]):\n", - " vals = \", \".join(f\"{tgt_labels[j]}={target[wp, j]:.4f}\" for j in range(tgt_feat))\n", - " active = \"ACTIVE\" if not np.allclose(target[wp], 0) else \"zeroed\"\n", - " print(f\" wp[{wp}]: {vals} ({active})\")\n", - "\n", - "# --- Partner summary ---\n", - "n_visible = np.sum(np.any(partners != 0, axis=1))\n", - "print(f\"\\n--- Partners: {n_visible}/{partners.shape[0]} visible ---\")\n", - "partner_labels = [\"rel_x\", \"rel_y\", \"rel_z\", \"length\", \"width\", \"heading_cos\", \"heading_sin\", \"rel_vx\", \"rel_vy\"]\n", - "for p in range(min(int(n_visible), 5)):\n", - " vals = \", \".join(f\"{partner_labels[j]}={partners[p, j]:.3f}\" for j in range(env.partner_features))\n", - " print(f\" [{p}] {vals}\")\n", - "if n_visible > 5:\n", - " print(f\" ... ({n_visible - 5} more)\")\n", - "\n", - "# --- Lane/boundary occupancy ---\n", - "n_lanes = np.sum(np.any(lanes != 0, axis=1))\n", - "n_bounds = np.sum(np.any(boundaries != 0, axis=1))\n", - "print(f\"\\n--- Road: {n_lanes}/{lanes.shape[0]} lane segs, {n_bounds}/{boundaries.shape[0]} boundary segs ---\")\n", - "\n", - "# --- Traffic ---\n", - "n_traffic = np.sum(np.any(traffic_controls != 0, axis=1))\n", - "print(f\"\\n--- Traffic controls: {n_traffic}/{traffic_controls.shape[0]} visible ---\")\n", - "traffic_labels = [\"rel_x1\", \"rel_y1\", \"rel_x2\", \"rel_y2\", \"rel_z\", \"type\", \"state\"]\n", - "for t in range(min(int(n_traffic), 5)):\n", - " vals = \", \".join(\n", - " f\"{traffic_labels[j]}={traffic_controls[t, j]:.3f}\"\n", - " for j in range(min(len(traffic_labels), traffic_controls.shape[1]))\n", - " )\n", - " print(f\" [{t}] {vals}\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "65bg95pn7dp", - "metadata": {}, - "outputs": [], - "source": [ - "# --- Layer-level stats across ALL agents at sample_t ---\n", - "all_obs = buf_stoch[\"obs\"][sample_t] # (N, obs_dim)\n", - "\n", - "ego_dim = binding.EGO_FEATURES\n", - "cond_dim = binding.NUM_REWARD_COEFS if rew_cond else 0\n", - "tgt_feat = binding.STATIC_TARGET_FEATURES if tgt_type == \"static\" else binding.DYNAMIC_TARGET_FEATURES\n", - "tgt_dim = n_tgt_wp * tgt_feat\n", - "partner_dim = env.obs_slots_partners_n * env.partner_features\n", - "lane_dim = env.obs_slots_lane_kept * env.road_features\n", - "boundary_dim = env.obs_slots_boundary_kept * env.road_features\n", - "traffic_dim = env.obs_slots_traffic_controls_n * env.traffic_control_features\n", - "\n", - "# Slice indices\n", - "idx = 0\n", - "slices = {}\n", - "slices[\"ego\"] = (idx, idx + ego_dim)\n", - "idx += ego_dim\n", - "if cond_dim > 0:\n", - " slices[\"conditioning\"] = (idx, idx + cond_dim)\n", - " idx += cond_dim\n", - "slices[\"target\"] = (idx, idx + tgt_dim)\n", - "idx += tgt_dim\n", - "slices[\"partners\"] = (idx, idx + partner_dim)\n", - "idx += partner_dim\n", - "slices[\"lanes\"] = (idx, idx + lane_dim)\n", - "idx += lane_dim\n", - "slices[\"boundaries\"] = (idx, idx + boundary_dim)\n", - "idx += boundary_dim\n", - "slices[\"traffic\"] = (idx, idx + traffic_dim)\n", - "idx += traffic_dim\n", - "\n", - "print(f\"Obs dim used: {idx} / {all_obs.shape[1]}\")\n", - "print(\n", - " f\"\\n{'Layer':>14s} | {'start':>5s}-{'end':>5s} | {'dim':>5s} | {'mean':>8s} | {'std':>8s} | {'min':>8s} | {'max':>8s} | {'%nonzero':>8s}\"\n", - ")\n", - "print(\"-\" * 95)\n", - "for name, (s, e) in slices.items():\n", - " chunk = all_obs[:, s:e]\n", - " nz_pct = 100 * np.count_nonzero(chunk) / chunk.size\n", - " print(\n", - " f\"{name:>14s} | {s:>5d}-{e:>5d} | {e - s:>5d} | {chunk.mean():>8.4f} | {chunk.std():>8.4f} | \"\n", - " f\"{chunk.min():>8.4f} | {chunk.max():>8.4f} | {nz_pct:>7.1f}%\"\n", - " )\n", - "\n", - "# --- Plots ---\n", - "n_layers = len(slices)\n", - "fig, axes = plt.subplots(2, (n_layers + 1) // 2, figsize=(5 * ((n_layers + 1) // 2), 8))\n", - "axes = axes.flatten()\n", - "\n", - "for i, (name, (s, e)) in enumerate(slices.items()):\n", - " chunk = all_obs[:, s:e].flatten()\n", - " # Filter out exact zeros for histogram readability on sparse layers\n", - " nonzero_vals = chunk[chunk != 0]\n", - " if len(nonzero_vals) > 0:\n", - " axes[i].hist(nonzero_vals, bins=50, edgecolor=\"black\", alpha=0.7)\n", - " axes[i].set_title(f\"{name} (nonzero only, {len(nonzero_vals)}/{len(chunk)})\")\n", - " else:\n", - " axes[i].hist(chunk, bins=50, edgecolor=\"black\", alpha=0.7)\n", - " axes[i].set_title(f\"{name} (all zeros)\")\n", - " axes[i].set_xlabel(\"Value\")\n", - "\n", - "# Hide unused axes\n", - "for j in range(i + 1, len(axes)):\n", - " axes[j].set_visible(False)\n", - "\n", - "fig.suptitle(f\"Observation distributions across {N} agents at t={sample_t}\", fontsize=14)\n", - "plt.tight_layout()\n", - "plt.show()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "dfm7par7pmo", - "metadata": {}, - "outputs": [], - "source": [ - "# --- Per-feature detail for partners, lanes, boundaries over time (tracked agent) ---\n", - "\n", - "\n", - "def unpack_all_timesteps(bufs, agent_idx):\n", - " \"\"\"Unpack all obs layers across time for one agent.\"\"\"\n", - " H = bufs[\"obs\"].shape[0]\n", - " egos, targets, conds = [], [], []\n", - " n_partners, n_lanes, n_bounds, n_traffic = [], [], [], []\n", - "\n", - " for t in range(H):\n", - " ob = bufs[\"obs\"][t : t + 1, agent_idx : agent_idx + 1][0]\n", - " ego, tgt, part, lane, bnd, tfc = unpack_obs(\n", - " ob,\n", - " target_type=tgt_type,\n", - " reward_conditioning=rew_cond,\n", - " num_target_waypoints=n_tgt_wp,\n", - " obs_slots_partners_n=env.obs_slots_partners_n,\n", - " obs_slots_lane_n=env.obs_slots_lane_kept,\n", - " obs_slots_boundary_n=env.obs_slots_boundary_kept,\n", - " obs_slots_traffic_controls_n=env.obs_slots_traffic_controls_n,\n", - " )\n", - " egos.append(ego)\n", - " targets.append(tgt)\n", - " n_partners.append(np.sum(np.any(part != 0, axis=1)))\n", - " n_lanes.append(np.sum(np.any(lane != 0, axis=1)))\n", - " n_bounds.append(np.sum(np.any(bnd != 0, axis=1)))\n", - " n_traffic.append(np.sum(np.any(tfc != 0, axis=1)))\n", - "\n", - " if rew_cond:\n", - " ed = binding.EGO_FEATURES\n", - " conds.append(ob[0, ed : ed + binding.NUM_REWARD_COEFS])\n", - "\n", - " return {\n", - " \"ego\": np.array(egos),\n", - " \"target\": np.array(targets),\n", - " \"cond\": np.array(conds) if conds else None,\n", - " \"n_partners\": np.array(n_partners),\n", - " \"n_lanes\": np.array(n_lanes),\n", - " \"n_bounds\": np.array(n_bounds),\n", - " \"n_traffic\": np.array(n_traffic),\n", - " }\n", - "\n", - "\n", - "ts = unpack_all_timesteps(buf_stoch, TRACKED_AGENT)\n", - "\n", - "fig, axes = plt.subplots(2, 2, figsize=(16, 10))\n", - "\n", - "# Occupancy over time\n", - "axes[0, 0].plot(ts[\"n_partners\"], label=\"partners\", alpha=0.8)\n", - "axes[0, 0].plot(ts[\"n_lanes\"], label=\"lanes\", alpha=0.8)\n", - "axes[0, 0].plot(ts[\"n_bounds\"], label=\"boundaries\", alpha=0.8)\n", - "axes[0, 0].plot(ts[\"n_traffic\"], label=\"traffic\", alpha=0.8)\n", - "axes[0, 0].set_xlabel(\"Step\")\n", - "axes[0, 0].set_ylabel(\"Visible count\")\n", - "axes[0, 0].set_title(f\"Obs occupancy over time | agent={TRACKED_AGENT}\")\n", - "axes[0, 0].legend()\n", - "axes[0, 0].grid(True, alpha=0.3)\n", - "\n", - "# Target waypoint distances over time\n", - "tgt_x = ts[\"target\"][:, :, 0]\n", - "tgt_y = ts[\"target\"][:, :, 1]\n", - "tgt_dist = np.sqrt(tgt_x**2 + tgt_y**2)\n", - "for wp in range(n_tgt_wp):\n", - " axes[0, 1].plot(tgt_dist[:, wp], label=f\"wp[{wp}]\", alpha=0.8)\n", - "axes[0, 1].set_xlabel(\"Step\")\n", - "axes[0, 1].set_ylabel(\"Distance (normalized)\")\n", - "axes[0, 1].set_title(\"Target waypoint distance over time\")\n", - "axes[0, 1].legend()\n", - "axes[0, 1].grid(True, alpha=0.3)\n", - "\n", - "# Conditioning heatmap over time\n", - "if ts[\"cond\"] is not None:\n", - " cond_labels = COEF_NAMES\n", - " im = axes[1, 0].imshow(ts[\"cond\"].T, aspect=\"auto\", cmap=\"coolwarm\", interpolation=\"nearest\")\n", - " axes[1, 0].set_yticks(range(len(cond_labels)))\n", - " axes[1, 0].set_yticklabels(cond_labels, fontsize=8)\n", - " axes[1, 0].set_xlabel(\"Step\")\n", - " axes[1, 0].set_title(\"Conditioning coefs over time\")\n", - " plt.colorbar(im, ax=axes[1, 0])\n", - "else:\n", - " axes[1, 0].text(0.5, 0.5, \"No conditioning\", ha=\"center\", va=\"center\", transform=axes[1, 0].transAxes)\n", - " axes[1, 0].set_title(\"Conditioning (disabled)\")\n", - "\n", - "# Partner closest distance over time\n", - "partner_dists = []\n", - "for t in range(HORIZON):\n", - " ob = buf_stoch[\"obs\"][t : t + 1, TRACKED_AGENT : TRACKED_AGENT + 1][0]\n", - " _, _, part, _, _, _ = unpack_obs(\n", - " ob,\n", - " target_type=tgt_type,\n", - " reward_conditioning=rew_cond,\n", - " num_target_waypoints=n_tgt_wp,\n", - " obs_slots_partners_n=env.obs_slots_partners_n,\n", - " obs_slots_lane_n=env.obs_slots_lane_kept,\n", - " obs_slots_boundary_n=env.obs_slots_boundary_kept,\n", - " obs_slots_traffic_controls_n=env.obs_slots_traffic_controls_n,\n", - " )\n", - " dists = np.sqrt(part[:, 0] ** 2 + part[:, 1] ** 2)\n", - " visible = np.any(part != 0, axis=1)\n", - " partner_dists.append(dists[visible].min() if visible.any() else np.nan)\n", - "\n", - "axes[1, 1].plot(partner_dists, alpha=0.8, color=\"red\")\n", - "axes[1, 1].set_xlabel(\"Step\")\n", - "axes[1, 1].set_ylabel(\"Min partner dist (normalized)\")\n", - "axes[1, 1].set_title(\"Closest partner distance over time\")\n", - "axes[1, 1].grid(True, alpha=0.3)\n", - "\n", - "plt.tight_layout()\n", - "plt.show()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "j1h4h99tnve", - "metadata": {}, - "outputs": [], - "source": [ - "# --- Spatial scatter: all observed entities in ego frame at sample_t ---\n", - "sample_obs = buf_stoch[\"obs\"][sample_t : sample_t + 1, TRACKED_AGENT : TRACKED_AGENT + 1][0]\n", - "ego, target, partners, lanes, boundaries, traffic_controls = unpack_obs(\n", - " sample_obs,\n", - " target_type=tgt_type,\n", - " reward_conditioning=rew_cond,\n", - " num_target_waypoints=n_tgt_wp,\n", - " obs_slots_partners_n=env.obs_slots_partners_n,\n", - " obs_slots_lane_n=env.obs_slots_lane_kept,\n", - " obs_slots_boundary_n=env.obs_slots_boundary_kept,\n", - " obs_slots_traffic_controls_n=env.obs_slots_traffic_controls_n,\n", - ")\n", - "\n", - "fig, ax = plt.subplots(figsize=(10, 10))\n", - "\n", - "# Ego vehicle at origin\n", - "from matplotlib.patches import Rectangle\n", - "\n", - "ax.add_patch(\n", - " Rectangle((-ego[2] / 2, -ego[1] / 2), ego[2], ego[1], facecolor=\"blue\", edgecolor=\"black\", alpha=0.7, zorder=10)\n", - ")\n", - "ax.annotate(\"EGO\", (0, 0), fontsize=9, ha=\"center\", va=\"center\", color=\"white\", fontweight=\"bold\", zorder=11)\n", - "\n", - "# Lane segments\n", - "for i in range(lanes.shape[0]):\n", - " if np.allclose(lanes[i], 0):\n", - " continue\n", - " rx, ry, rz, length, _, dc, ds = lanes[i]\n", - " ax.plot(\n", - " [rx - dc * length / 2, rx + dc * length / 2],\n", - " [ry - ds * length / 2, ry + ds * length / 2],\n", - " color=\"lightgray\",\n", - " linewidth=1,\n", - " zorder=1,\n", - " )\n", - "ax.scatter(\n", - " lanes[np.any(lanes != 0, axis=1), 0],\n", - " lanes[np.any(lanes != 0, axis=1), 1],\n", - " s=5,\n", - " color=\"gray\",\n", - " alpha=0.5,\n", - " label=f\"lanes ({n_lanes})\",\n", - " zorder=2,\n", - ")\n", - "\n", - "# Boundary segments\n", - "for i in range(boundaries.shape[0]):\n", - " if np.allclose(boundaries[i], 0):\n", - " continue\n", - " rx, ry, rz, length, _, dc, ds = boundaries[i]\n", - " ax.plot(\n", - " [rx - dc * length / 2, rx + dc * length / 2],\n", - " [ry - ds * length / 2, ry + ds * length / 2],\n", - " color=\"black\",\n", - " linewidth=1,\n", - " zorder=1,\n", - " )\n", - "bnd_mask = np.any(boundaries != 0, axis=1)\n", - "if bnd_mask.any():\n", - " ax.scatter(\n", - " boundaries[bnd_mask, 0],\n", - " boundaries[bnd_mask, 1],\n", - " s=8,\n", - " color=\"black\",\n", - " alpha=0.6,\n", - " label=f\"boundaries ({n_bounds})\",\n", - " zorder=2,\n", - " )\n", - "\n", - "# Partners\n", - "for i in range(partners.shape[0]):\n", - " if np.allclose(partners[i], 0):\n", - " continue\n", - " rx, ry, rz, length, width, hc, hs, rel_vx, rel_vy = partners[i]\n", - " heading = np.arctan2(hs, hc)\n", - " rect = Rectangle(\n", - " (-length / 2, -width / 2), length, width, facecolor=\"orange\", edgecolor=\"black\", alpha=0.6, zorder=9\n", - " )\n", - " rect.set_transform(plt.matplotlib.transforms.Affine2D().rotate(heading).translate(rx, ry) + ax.transData)\n", - " ax.add_patch(rect)\n", - " ax.annotate(f\"{rel_vx:.2f},{rel_vy:.2f}\", (rx, ry), fontsize=7, ha=\"center\", color=\"darkred\", zorder=12)\n", - "part_mask = np.any(partners != 0, axis=1)\n", - "if part_mask.any():\n", - " ax.scatter(\n", - " partners[part_mask, 0],\n", - " partners[part_mask, 1],\n", - " s=40,\n", - " color=\"orange\",\n", - " edgecolors=\"black\",\n", - " label=f\"partners ({n_visible})\",\n", - " zorder=8,\n", - " )\n", - "\n", - "# Target waypoints\n", - "for wp in range(target.shape[0]):\n", - " if np.allclose(target[wp], 0):\n", - " continue\n", - " marker = \"*\" if wp == 0 else \"o\"\n", - " s = 200 if wp == 0 else 80\n", - " color = \"red\" if wp == 0 else \"salmon\"\n", - " ax.scatter(\n", - " target[wp, 0],\n", - " target[wp, 1],\n", - " color=color,\n", - " marker=marker,\n", - " s=s,\n", - " zorder=15,\n", - " label=f\"target wp[{wp}]\" if wp < 3 else None,\n", - " )\n", - "\n", - "# Traffic controls\n", - "for i in range(traffic_controls.shape[0]):\n", - " if np.allclose(traffic_controls[i], 0):\n", - " continue\n", - " x1, y1, x2, y2, _, control_type, state = traffic_controls[i]\n", - " if int(control_type) == binding.TRAFFIC_CONTROL_TYPE_TRAFFIC_LIGHT:\n", - " state_colors = {\n", - " binding.TRAFFIC_CONTROL_STATE_UNKNOWN: \"gray\",\n", - " binding.TRAFFIC_CONTROL_STATE_RED: \"red\",\n", - " binding.TRAFFIC_CONTROL_STATE_YELLOW: \"yellow\",\n", - " binding.TRAFFIC_CONTROL_STATE_GREEN: \"green\",\n", - " binding.TRAFFIC_CONTROL_STATE_OFF: \"gray\",\n", - " }\n", - " ax.plot([x1, x2], [y1, y2], color=state_colors.get(int(state), \"gray\"), linewidth=3, zorder=15)\n", - " else:\n", - " accent = \"red\" if int(control_type) == binding.TRAFFIC_CONTROL_TYPE_STOP_SIGN else \"gold\"\n", - " ax.plot([x1, x2], [y1, y2], color=\"black\", linewidth=4, zorder=14)\n", - " ax.plot([x1, x2], [y1, y2], color=accent, linewidth=2.5, linestyle=\"--\", zorder=15)\n", - "\n", - "ax.set_xlim(-1, 1)\n", - "ax.set_ylim(-1, 1)\n", - "ax.set_aspect(\"equal\")\n", - "ax.set_xlabel(\"X (ego frame, normalized)\")\n", - "ax.set_ylabel(\"Y (ego frame, normalized)\")\n", - "ax.set_title(f\"All observed entities | agent={TRACKED_AGENT}, t={sample_t}\")\n", - "ax.legend(loc=\"upper right\", fontsize=8)\n", - "ax.grid(True, alpha=0.3)\n", - "plt.tight_layout()\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "id": "gsyrbg02vc", - "metadata": {}, - "source": [ - "### Ego + conditioning distributions across all agents" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "wwisx2muaj", - "metadata": {}, - "outputs": [], - "source": [ - "# Ego feature distributions across all agents, pooled over full rollout\n", - "ego_dim = binding.EGO_FEATURES\n", - "all_ego = buf_stoch[\"obs\"][:, :, :ego_dim].reshape(-1, ego_dim) # (H*N, ego_dim)\n", - "\n", - "ego_labels = EGO_LABELS\n", - "\n", - "fig, axes = plt.subplots(2, len(ego_labels), figsize=(3.5 * len(ego_labels), 7))\n", - "\n", - "# Row 0: histograms\n", - "for i, label in enumerate(ego_labels):\n", - " vals = all_ego[:, i]\n", - " print(f\"{label}: mean={vals}\")\n", - " axes[0, i].hist(vals, bins=60, edgecolor=\"black\", alpha=0.7, color=\"steelblue\")\n", - " axes[0, i].set_title(label, fontsize=10)\n", - " axes[0, i].set_xlabel(\"\")\n", - " axes[0, i].tick_params(labelsize=7)\n", - " axes[0, i].axvline(vals.mean(), color=\"red\", ls=\"--\", lw=1)\n", - "\n", - "# Row 1: boxplots per-agent (distribution across timesteps for each agent)\n", - "ego_per_agent = buf_stoch[\"obs\"][:, :, :ego_dim] # (H, N, ego_dim)\n", - "for i, label in enumerate(ego_labels):\n", - " data = [ego_per_agent[:, a, i] for a in range(N)]\n", - " bp = axes[1, i].boxplot(\n", - " data,\n", - " showfliers=False,\n", - " patch_artist=True,\n", - " boxprops=dict(facecolor=\"steelblue\", alpha=0.5),\n", - " medianprops=dict(color=\"red\"),\n", - " )\n", - " axes[1, i].set_xlabel(\"Agent\")\n", - " axes[1, i].tick_params(labelsize=7)\n", - " axes[1, i].set_title(f\"{label} per agent\", fontsize=9)\n", - "\n", - "fig.suptitle(\"Ego features: full rollout distributions\", fontsize=13)\n", - "plt.tight_layout()\n", - "plt.show()\n", - "\n", - "# Conditioning distributions across all agents (if enabled)\n", - "if rew_cond:\n", - " cond_start = ego_dim\n", - " cond_end = cond_start + binding.NUM_REWARD_COEFS\n", - " all_cond = buf_stoch[\"obs\"][:, :, cond_start:cond_end].reshape(-1, binding.NUM_REWARD_COEFS)\n", - "\n", - " cond_labels = COEF_NAMES\n", - "\n", - " fig, ax = plt.subplots(figsize=(14, 5))\n", - " parts = ax.violinplot(\n", - " [all_cond[:, i] for i in range(binding.NUM_REWARD_COEFS)],\n", - " positions=range(binding.NUM_REWARD_COEFS),\n", - " showmeans=True,\n", - " showmedians=True,\n", - " )\n", - " ax.set_xticks(range(binding.NUM_REWARD_COEFS))\n", - " ax.set_xticklabels(cond_labels, rotation=45, ha=\"right\", fontsize=9)\n", - " ax.set_ylabel(\"Normalized value\")\n", - " ax.set_title(\"Conditioning coef distributions (all agents, full rollout)\")\n", - " ax.grid(True, alpha=0.3, axis=\"y\")\n", - " plt.tight_layout()\n", - " plt.show()" - ] - }, - { - "cell_type": "markdown", - "id": "fu6dbejpxdu", - "metadata": {}, - "source": [ - "### Partner per-feature distributions" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "6g55putjnwo", - "metadata": {}, - "outputs": [], - "source": [ - "# Partner per-feature distributions (pooled over all agents + timesteps, visible only)\n", - "partner_labels = [\"rel_x\", \"rel_y\", \"rel_z\", \"length\", \"width\", \"heading_cos\", \"heading_sin\", \"rel_vx\", \"rel_vy\"]\n", - "obs_slots_partners_n = env.obs_slots_partners_n\n", - "pf = env.partner_features\n", - "\n", - "# Compute slices\n", - "_ego_d = binding.EGO_FEATURES\n", - "_cond_d = binding.NUM_REWARD_COEFS if rew_cond else 0\n", - "_tgt_f = binding.STATIC_TARGET_FEATURES if tgt_type == \"static\" else binding.DYNAMIC_TARGET_FEATURES\n", - "_tgt_d = n_tgt_wp * _tgt_f\n", - "_p_start = _ego_d + _cond_d + _tgt_d\n", - "_p_end = _p_start + obs_slots_partners_n * pf\n", - "\n", - "all_partners = buf_stoch[\"obs\"][:, :, _p_start:_p_end].reshape(\n", - " -1, obs_slots_partners_n, pf\n", - ") # (H*N, obs_slots_partners_n, 9)\n", - "# Mask: partner is visible if any feature != 0\n", - "visible_mask = np.any(all_partners != 0, axis=2) # (H*N, 16)\n", - "visible_partners = all_partners[visible_mask] # (K, 9) — all visible partner observations\n", - "\n", - "print(\n", - " f\"Total partner obs: {all_partners.shape[0] * obs_slots_partners_n}, visible: {len(visible_partners)} \"\n", - " f\"({100 * len(visible_partners) / (all_partners.shape[0] * obs_slots_partners_n):.1f}%)\"\n", - ")\n", - "\n", - "fig, axes = plt.subplots(2, 5, figsize=(21, 8))\n", - "axes = axes.flatten()\n", - "\n", - "for i, label in enumerate(partner_labels):\n", - " vals = visible_partners[:, i]\n", - " axes[i].hist(vals, bins=80, edgecolor=\"black\", alpha=0.7, color=\"darkorange\")\n", - " axes[i].set_title(f\"{label} (n={len(vals)})\", fontsize=10)\n", - " axes[i].axvline(vals.mean(), color=\"red\", ls=\"--\", lw=1, label=f\"mean={vals.mean():.3f}\")\n", - " axes[i].legend(fontsize=7)\n", - " axes[i].tick_params(labelsize=7)\n", - "\n", - "# rel_x vs rel_y scatter in last panel\n", - "pos_ax = axes[len(partner_labels)]\n", - "pos_ax.scatter(visible_partners[:, 0], visible_partners[:, 1], s=1, alpha=0.15, color=\"darkorange\")\n", - "pos_ax.set_xlabel(\"rel_x\")\n", - "pos_ax.set_ylabel(\"rel_y\")\n", - "pos_ax.set_title(\"Partner positions (ego frame)\")\n", - "pos_ax.set_aspect(\"equal\")\n", - "pos_ax.grid(True, alpha=0.3)\n", - "\n", - "fig.suptitle(\"Partner features: all visible, full rollout\", fontsize=13)\n", - "plt.tight_layout()\n", - "plt.show()\n", - "\n", - "# Partner count distribution across (timestep, agent)\n", - "partner_counts = visible_mask.sum(axis=1) # (H*N,)\n", - "fig, axes = plt.subplots(1, 2, figsize=(12, 4))\n", - "axes[0].hist(partner_counts, bins=range(obs_slots_partners_n + 2), edgecolor=\"black\", alpha=0.7, color=\"darkorange\")\n", - "axes[0].set_xlabel(\"Visible partners\")\n", - "axes[0].set_ylabel(\"Count\")\n", - "axes[0].set_title(\"Partner count distribution (per agent per step)\")\n", - "\n", - "# Partner distance distribution\n", - "dists = np.sqrt(visible_partners[:, 0] ** 2 + visible_partners[:, 1] ** 2)\n", - "axes[1].hist(dists, bins=80, edgecolor=\"black\", alpha=0.7, color=\"coral\")\n", - "axes[1].set_xlabel(\"Distance (normalized)\")\n", - "axes[1].set_ylabel(\"Count\")\n", - "axes[1].set_title(f\"Partner distance distribution (mean={dists.mean():.3f})\")\n", - "axes[1].axvline(dists.mean(), color=\"red\", ls=\"--\", lw=1)\n", - "\n", - "plt.tight_layout()\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "id": "5n8a8dwmna3", - "metadata": {}, - "source": [ - "### Road (lanes + boundaries) and target distributions" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "4xgikpryeg", - "metadata": {}, - "outputs": [], - "source": [ - "# Road per-feature distributions (lanes + boundaries)\n", - "road_labels = [\"rel_x\", \"rel_y\", \"rel_z\", \"seg_length\", \"seg_width\", \"dir_cos\", \"dir_sin\"]\n", - "rf = env.road_features\n", - "max_lanes = env.obs_slots_lane_kept\n", - "max_bounds = env.obs_slots_boundary_kept\n", - "\n", - "_l_start = _p_end\n", - "_l_end = _l_start + max_lanes * rf\n", - "_b_start = _l_end\n", - "_b_end = _b_start + max_bounds * rf\n", - "\n", - "all_lanes = buf_stoch[\"obs\"][:, :, _l_start:_l_end].reshape(-1, max_lanes, rf)\n", - "all_bounds = buf_stoch[\"obs\"][:, :, _b_start:_b_end].reshape(-1, max_bounds, rf)\n", - "\n", - "vis_lanes = all_lanes[np.any(all_lanes != 0, axis=2)]\n", - "vis_bounds = all_bounds[np.any(all_bounds != 0, axis=2)]\n", - "\n", - "print(\n", - " f\"Lanes: {len(vis_lanes)} visible / {all_lanes.shape[0] * max_lanes} total \"\n", - " f\"({100 * len(vis_lanes) / (all_lanes.shape[0] * max_lanes):.1f}%)\"\n", - ")\n", - "print(\n", - " f\"Boundaries: {len(vis_bounds)} visible / {all_bounds.shape[0] * max_bounds} total \"\n", - " f\"({100 * len(vis_bounds) / (all_bounds.shape[0] * max_bounds):.1f}%)\"\n", - ")\n", - "\n", - "fig, axes = plt.subplots(2, 7, figsize=(28, 8))\n", - "for i, label in enumerate(road_labels):\n", - " # Lanes\n", - " axes[0, i].hist(vis_lanes[:, i], bins=80, edgecolor=\"black\", alpha=0.7, color=\"silver\")\n", - " axes[0, i].set_title(f\"lane {label}\", fontsize=9)\n", - " axes[0, i].axvline(vis_lanes[:, i].mean(), color=\"red\", ls=\"--\", lw=1)\n", - " axes[0, i].tick_params(labelsize=7)\n", - " # Boundaries\n", - " axes[1, i].hist(vis_bounds[:, i], bins=80, edgecolor=\"black\", alpha=0.7, color=\"dimgray\")\n", - " axes[1, i].set_title(f\"boundary {label}\", fontsize=9)\n", - " axes[1, i].axvline(vis_bounds[:, i].mean(), color=\"red\", ls=\"--\", lw=1)\n", - " axes[1, i].tick_params(labelsize=7)\n", - "\n", - "fig.suptitle(\"Road features: all visible, full rollout (top=lanes, bottom=boundaries)\", fontsize=13)\n", - "plt.tight_layout()\n", - "plt.show()\n", - "\n", - "# Spatial scatter: lane vs boundary positions (pooled)\n", - "fig, axes = plt.subplots(1, 3, figsize=(18, 5))\n", - "axes[0].scatter(vis_lanes[:, 0], vis_lanes[:, 1], s=0.5, alpha=0.05, color=\"gray\")\n", - "axes[0].set_xlabel(\"rel_x\")\n", - "axes[0].set_ylabel(\"rel_y\")\n", - "axes[0].set_title(f\"Lane segment positions (n={len(vis_lanes)})\")\n", - "axes[0].set_aspect(\"equal\")\n", - "axes[0].grid(True, alpha=0.3)\n", - "\n", - "axes[1].scatter(vis_bounds[:, 0], vis_bounds[:, 1], s=0.5, alpha=0.05, color=\"black\")\n", - "axes[1].set_xlabel(\"rel_x\")\n", - "axes[1].set_ylabel(\"rel_y\")\n", - "axes[1].set_title(f\"Boundary segment positions (n={len(vis_bounds)})\")\n", - "axes[1].set_aspect(\"equal\")\n", - "axes[1].grid(True, alpha=0.3)\n", - "\n", - "# Lane + boundary segment length comparison\n", - "axes[2].hist(vis_lanes[:, 2], bins=80, alpha=0.6, color=\"silver\", edgecolor=\"black\", label=\"lanes\")\n", - "axes[2].hist(vis_bounds[:, 2], bins=80, alpha=0.6, color=\"dimgray\", edgecolor=\"black\", label=\"boundaries\")\n", - "axes[2].set_xlabel(\"Segment length (normalized)\")\n", - "axes[2].set_ylabel(\"Count\")\n", - "axes[2].set_title(\"Segment length distribution\")\n", - "axes[2].legend()\n", - "\n", - "plt.tight_layout()\n", - "plt.show()\n", - "\n", - "# Target distributions across all agents, full rollout\n", - "_tgt_start = _ego_d + _cond_d\n", - "_tgt_end = _tgt_start + _tgt_d\n", - "all_target = buf_stoch[\"obs\"][:, :, _tgt_start:_tgt_end].reshape(-1, n_tgt_wp, _tgt_f)\n", - "\n", - "if tgt_type == \"static\":\n", - " tgt_flabels = [\"rel_x\", \"rel_y\", \"rel_z\"]\n", - "else:\n", - " tgt_flabels = [\"rel_x\", \"rel_y\", \"rel_z\", \"heading_cos\", \"heading_sin\"]\n", - "\n", - "fig, axes = plt.subplots(1, n_tgt_wp + 1, figsize=(5 * (n_tgt_wp + 1), 4))\n", - "\n", - "for wp in range(n_tgt_wp):\n", - " wp_data = all_target[:, wp, :]\n", - " active = np.any(wp_data != 0, axis=1)\n", - " wp_active = wp_data[active]\n", - " dist = np.sqrt(wp_active[:, 0] ** 2 + wp_active[:, 1] ** 2) if len(wp_active) > 0 else np.array([])\n", - " axes[wp].hist(dist, bins=60, edgecolor=\"black\", alpha=0.7, color=[\"red\", \"salmon\", \"lightsalmon\"][wp % 3])\n", - " axes[wp].set_title(f\"wp[{wp}] distance (n={len(wp_active)}/{len(wp_data)})\", fontsize=10)\n", - " axes[wp].set_xlabel(\"Distance (normalized)\")\n", - "\n", - "# All waypoints x-y scatter\n", - "for wp in range(n_tgt_wp):\n", - " wp_data = all_target[:, wp, :]\n", - " active = np.any(wp_data != 0, axis=1)\n", - " wp_active = wp_data[active]\n", - " if len(wp_active) > 0:\n", - " axes[n_tgt_wp].scatter(wp_active[:, 0], wp_active[:, 1], s=1, alpha=0.1, label=f\"wp[{wp}]\")\n", - "axes[n_tgt_wp].set_xlabel(\"rel_x\")\n", - "axes[n_tgt_wp].set_ylabel(\"rel_y\")\n", - "axes[n_tgt_wp].set_title(\"Target positions (ego frame)\")\n", - "axes[n_tgt_wp].set_aspect(\"equal\")\n", - "axes[n_tgt_wp].legend(fontsize=8)\n", - "axes[n_tgt_wp].grid(True, alpha=0.3)\n", - "\n", - "fig.suptitle(\"Target waypoint distributions (all agents, full rollout)\", fontsize=13)\n", - "plt.tight_layout()\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "id": "csj7ji1j6ed", - "metadata": {}, - "source": [ - "### Observation sparsity and layer occupancy heatmaps" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "3xbi2pbsfyv", - "metadata": {}, - "outputs": [], - "source": [ - "# Sparsity heatmap: fraction of nonzero per layer, per agent, over time\n", - "layer_names = [\"partners\", \"lanes\", \"boundaries\"]\n", - "layer_slices = [\n", - " (_p_start, _p_end, env.obs_slots_partners_n, env.partner_features),\n", - " (_l_start, _l_end, env.obs_slots_lane_kept, env.road_features),\n", - " (_b_start, _b_end, env.obs_slots_boundary_kept, env.road_features),\n", - "]\n", - "\n", - "fig, axes = plt.subplots(1, 3, figsize=(20, 5))\n", - "for ax, name, (s, e, n_obj, n_feat) in zip(axes, layer_names, layer_slices):\n", - " # (H, N) -> fraction of visible objects per (timestep, agent)\n", - " raw = buf_stoch[\"obs\"][:, :, s:e].reshape(HORIZON, env.num_agents, n_obj, n_feat)\n", - " occupancy = np.any(raw != 0, axis=3).sum(axis=2) / n_obj # (H, N)\n", - " im = ax.imshow(occupancy.T, aspect=\"auto\", cmap=\"YlOrRd\", interpolation=\"nearest\", vmin=0, vmax=1)\n", - " ax.set_xlabel(\"Step\")\n", - " ax.set_ylabel(\"Agent\")\n", - " ax.set_title(f\"{name} occupancy (frac visible)\")\n", - " plt.colorbar(im, ax=ax)\n", - "\n", - "plt.suptitle(\"Per-layer occupancy heatmaps (fraction of max slots filled)\", fontsize=13)\n", - "plt.tight_layout()\n", - "plt.show()\n", - "\n", - "# Per-layer mean occupancy over time\n", - "fig, axes = plt.subplots(1, 2, figsize=(16, 4))\n", - "\n", - "# Mean across agents\n", - "for name, (s, e, n_obj, n_feat) in zip(layer_names, layer_slices):\n", - " raw = buf_stoch[\"obs\"][:, :, s:e].reshape(HORIZON, env.num_agents, n_obj, n_feat)\n", - " occ_mean = np.any(raw != 0, axis=3).sum(axis=2).mean(axis=1) # (H,)\n", - " axes[0].plot(occ_mean, label=name, alpha=0.8)\n", - "axes[0].set_xlabel(\"Step\")\n", - "axes[0].set_ylabel(\"Mean visible count\")\n", - "axes[0].set_title(\"Mean occupancy over time (across agents)\")\n", - "axes[0].legend()\n", - "axes[0].grid(True, alpha=0.3)\n", - "\n", - "# Mean across timesteps (per agent)\n", - "for name, (s, e, n_obj, n_feat) in zip(layer_names, layer_slices):\n", - " raw = buf_stoch[\"obs\"][:, :, s:e].reshape(HORIZON, env.num_agents, n_obj, n_feat)\n", - " occ_per_agent = np.any(raw != 0, axis=3).sum(axis=2).mean(axis=0) # (N,)\n", - " axes[1].bar(range(N), occ_per_agent, alpha=0.5, label=name)\n", - "axes[1].set_xlabel(\"Agent\")\n", - "axes[1].set_ylabel(\"Mean visible count\")\n", - "axes[1].set_title(\"Mean occupancy per agent (across timesteps)\")\n", - "axes[1].legend()\n", - "axes[1].grid(True, alpha=0.3)\n", - "\n", - "plt.tight_layout()\n", - "plt.show()\n", - "\n", - "# Full obs sparsity: fraction of zero features per obs dimension, pooled\n", - "all_flat = buf_stoch[\"obs\"].reshape(-1, obs_dim) # (H*N, obs_dim)\n", - "zero_frac = (all_flat == 0).mean(axis=0) # (obs_dim,)\n", - "fig, ax = plt.subplots(figsize=(18, 3))\n", - "ax.bar(range(obs_dim), zero_frac, width=1.0, color=\"steelblue\", alpha=0.7)\n", - "# Annotate layer boundaries\n", - "prev_e = 0\n", - "for name, (s, e) in slices.items():\n", - " ax.axvline(s, color=\"red\", ls=\"--\", lw=0.5, alpha=0.7)\n", - " mid = (s + e) / 2\n", - " ax.text(mid, 1.02, name, ha=\"center\", va=\"bottom\", fontsize=7, rotation=0, color=\"red\")\n", - " prev_e = e\n", - "ax.set_xlim(0, obs_dim)\n", - "ax.set_ylim(0, 1.1)\n", - "ax.set_xlabel(\"Obs dimension index\")\n", - "ax.set_ylabel(\"Fraction zero\")\n", - "ax.set_title(\"Per-dimension sparsity (fraction zero across full rollout)\")\n", - "plt.tight_layout()\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "id": "cell-8", - "metadata": {}, - "source": [ - "## Policy outputs over time" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "cell-9", - "metadata": {}, - "outputs": [], - "source": [ - "# Compute action probs over time for tracked agent (stochastic rollout)\n", - "n_actions = env.single_action_space.nvec[0] if not is_continuous else 1\n", - "action_probs_time = np.zeros((HORIZON, n_actions))\n", - "for t in range(HORIZON):\n", - " obs_t = torch.FloatTensor(buf_stoch[\"obs\"][t : t + 1, TRACKED_AGENT : TRACKED_AGENT + 1][0]).to(device)\n", - " with torch.no_grad():\n", - " logits_list, _ = policy(obs_t)\n", - " logits = logits_list[0] if isinstance(logits_list, (list, tuple)) else logits_list\n", - " action_probs_time[t] = F.softmax(logits, dim=-1).cpu().numpy().flatten()\n", - "\n", - "fig, axes = plt.subplots(2, 2, figsize=(16, 10))\n", - "\n", - "# Action distribution heatmap\n", - "im = axes[0, 0].imshow(action_probs_time.T, aspect=\"auto\", cmap=\"viridis\", interpolation=\"nearest\")\n", - "axes[0, 0].set_xlabel(\"Step\")\n", - "axes[0, 0].set_ylabel(\"Action ID\")\n", - "axes[0, 0].set_title(f\"Action prob heatmap | agent={TRACKED_AGENT}\")\n", - "plt.colorbar(im, ax=axes[0, 0])\n", - "\n", - "# Entropy over time\n", - "axes[0, 1].plot(buf_stoch[\"entropy\"][:, TRACKED_AGENT], label=\"stochastic\", alpha=0.8)\n", - "axes[0, 1].set_xlabel(\"Step\")\n", - "axes[0, 1].set_ylabel(\"Entropy\")\n", - "axes[0, 1].set_title(\"Entropy over time\")\n", - "axes[0, 1].grid(True, alpha=0.3)\n", - "\n", - "# Value over time\n", - "axes[1, 0].plot(buf_stoch[\"values\"][:, TRACKED_AGENT], label=\"stochastic\", alpha=0.8)\n", - "axes[1, 0].plot(buf_det[\"values\"][:, TRACKED_AGENT], label=\"deterministic\", alpha=0.8)\n", - "axes[1, 0].set_xlabel(\"Step\")\n", - "axes[1, 0].set_ylabel(\"Value\")\n", - "axes[1, 0].set_title(\"Value predictions over time\")\n", - "axes[1, 0].legend()\n", - "axes[1, 0].grid(True, alpha=0.3)\n", - "\n", - "# Actions over time: deterministic vs stochastic\n", - "axes[1, 1].step(range(HORIZON), buf_stoch[\"actions\"][:, TRACKED_AGENT], label=\"stochastic\", alpha=0.7)\n", - "axes[1, 1].step(range(HORIZON), buf_det[\"actions\"][:, TRACKED_AGENT], label=\"deterministic\", alpha=0.7)\n", - "axes[1, 1].set_xlabel(\"Step\")\n", - "axes[1, 1].set_ylabel(\"Action\")\n", - "axes[1, 1].set_title(\"Selected action over time\")\n", - "axes[1, 1].legend()\n", - "axes[1, 1].grid(True, alpha=0.3)\n", - "\n", - "plt.tight_layout()\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "id": "cell-10", - "metadata": {}, - "source": [ - "## Rewards and returns" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "cell-11", - "metadata": {}, - "outputs": [], - "source": [ - "fig, axes = plt.subplots(2, 3, figsize=(18, 10))\n", - "\n", - "# Per-step mean reward\n", - "axes[0, 0].plot(buf_stoch[\"rewards\"].mean(axis=1), label=\"stochastic\", alpha=0.8)\n", - "axes[0, 0].plot(buf_det[\"rewards\"].mean(axis=1), label=\"deterministic\", alpha=0.8)\n", - "axes[0, 0].set_xlabel(\"Step\")\n", - "axes[0, 0].set_ylabel(\"Mean reward\")\n", - "axes[0, 0].set_title(\"Mean reward per step\")\n", - "axes[0, 0].legend()\n", - "axes[0, 0].grid(True, alpha=0.3)\n", - "\n", - "# Reward heatmap (stochastic)\n", - "im = axes[0, 1].imshow(buf_stoch[\"rewards\"].T, aspect=\"auto\", cmap=\"RdYlGn\", interpolation=\"nearest\")\n", - "axes[0, 1].set_xlabel(\"Step\")\n", - "axes[0, 1].set_ylabel(\"Agent\")\n", - "axes[0, 1].set_title(\"Reward heatmap (stochastic)\")\n", - "plt.colorbar(im, ax=axes[0, 1])\n", - "\n", - "# Cumulative return per agent\n", - "cum_ret_stoch = buf_stoch[\"rewards\"].sum(axis=0)\n", - "cum_ret_det = buf_det[\"rewards\"].sum(axis=0)\n", - "axes[0, 2].hist(cum_ret_stoch, bins=30, alpha=0.6, label=\"stochastic\", edgecolor=\"black\")\n", - "axes[0, 2].hist(cum_ret_det, bins=30, alpha=0.6, label=\"deterministic\", edgecolor=\"black\")\n", - "axes[0, 2].set_xlabel(\"Cumulative return\")\n", - "axes[0, 2].set_ylabel(\"Count\")\n", - "axes[0, 2].set_title(\"Return distribution across agents\")\n", - "axes[0, 2].legend()\n", - "\n", - "# Reward distribution histogram\n", - "axes[1, 0].hist(buf_stoch[\"rewards\"].flatten(), bins=50, alpha=0.7, edgecolor=\"black\")\n", - "axes[1, 0].set_xlabel(\"Reward\")\n", - "axes[1, 0].set_ylabel(\"Count\")\n", - "axes[1, 0].set_title(\"Per-step reward distribution (stochastic)\")\n", - "axes[1, 0].set_yscale(\"log\")\n", - "\n", - "# Terminal/truncation timeline\n", - "axes[1, 1].plot(buf_stoch[\"terminals\"].sum(axis=1), label=\"terminals\", alpha=0.8)\n", - "axes[1, 1].plot(buf_stoch[\"truncations\"].sum(axis=1), label=\"truncations\", alpha=0.8)\n", - "axes[1, 1].set_xlabel(\"Step\")\n", - "axes[1, 1].set_ylabel(\"Count\")\n", - "axes[1, 1].set_title(\"Terminals/Truncations per step (stochastic)\")\n", - "axes[1, 1].legend()\n", - "axes[1, 1].grid(True, alpha=0.3)\n", - "\n", - "# Tracked agent reward\n", - "axes[1, 2].plot(buf_stoch[\"rewards\"][:, TRACKED_AGENT], label=\"stochastic\", alpha=0.8)\n", - "axes[1, 2].plot(buf_det[\"rewards\"][:, TRACKED_AGENT], label=\"deterministic\", alpha=0.8)\n", - "axes[1, 2].set_xlabel(\"Step\")\n", - "axes[1, 2].set_ylabel(\"Reward\")\n", - "axes[1, 2].set_title(f\"Reward over time | agent={TRACKED_AGENT}\")\n", - "axes[1, 2].legend()\n", - "axes[1, 2].grid(True, alpha=0.3)\n", - "\n", - "plt.tight_layout()\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "id": "cell-12", - "metadata": {}, - "source": [ - "## Episode metrics" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "cell-13", - "metadata": {}, - "outputs": [], - "source": [ - "# Collect episode-level metrics from the C logging\n", - "log_stoch = binding.vec_log(env.c_envs, N)\n", - "\n", - "# eval_mode=1 returns a list of per-env dicts; aggregate by averaging\n", - "if isinstance(log_stoch, list) and log_stoch:\n", - " all_keys = set(k for d in log_stoch for k in d if isinstance(d[k], (int, float)))\n", - " log_stoch = {k: np.mean([d[k] for d in log_stoch if k in d]) for k in all_keys}\n", - "\n", - "if log_stoch:\n", - " print(\"Episode metrics (after stochastic rollout):\")\n", - " for k, v in sorted(log_stoch.items()):\n", - " if isinstance(v, (int, float)):\n", - " print(f\" {k}: {v:.4f}\")\n", - "\n", - " # Bar chart of key metrics\n", - " keys = [\"score\", \"collision_rate\", \"offroad_rate\", \"completion_rate\", \"dnf_rate\"]\n", - " vals = [log_stoch.get(k, 0) for k in keys]\n", - " fig, ax = plt.subplots(figsize=(10, 4))\n", - " bars = ax.bar(keys, vals, edgecolor=\"black\", alpha=0.7, color=[\"green\", \"red\", \"orange\", \"blue\", \"gray\"])\n", - " ax.set_ylabel(\"Rate\")\n", - " ax.set_title(\"Episode Metrics\")\n", - " for bar, v in zip(bars, vals):\n", - " ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.01, f\"{v:.3f}\", ha=\"center\", fontsize=10)\n", - " plt.tight_layout()\n", - " plt.show()\n", - "else:\n", - " print(\"No episode metrics available yet (not enough episodes completed)\")" - ] - }, - { - "cell_type": "markdown", - "id": "cell-14", - "metadata": {}, - "source": [ - "## Value predictions vs actual returns" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "cell-15", - "metadata": {}, - "outputs": [], - "source": [ - "gamma = config[\"train\"].get(\"gamma\", 0.98)\n", - "lam = config[\"train\"].get(\"gae_lambda\", 0.95)\n", - "\n", - "\n", - "def compute_gae(rewards, values, terminals, truncations, gamma, lam):\n", - " H, N = rewards.shape\n", - " advantages = np.zeros_like(rewards)\n", - " last_gae = np.zeros(N)\n", - " for t in reversed(range(H - 1)):\n", - " done = np.maximum(terminals[t + 1], truncations[t + 1])\n", - " next_non_terminal = 1.0 - done\n", - " delta = rewards[t + 1] + gamma * values[t + 1] * next_non_terminal - values[t]\n", - " last_gae = delta + gamma * lam * last_gae * next_non_terminal\n", - " advantages[t] = last_gae\n", - " return advantages\n", - "\n", - "\n", - "adv_stoch = compute_gae(\n", - " buf_stoch[\"rewards\"], buf_stoch[\"values\"], buf_stoch[\"terminals\"], buf_stoch[\"truncations\"], gamma, lam\n", - ")\n", - "returns_stoch = adv_stoch + buf_stoch[\"values\"]\n", - "\n", - "pred_v = buf_stoch[\"values\"].flatten()\n", - "actual_r = returns_stoch.flatten()\n", - "\n", - "var_actual = np.var(actual_r)\n", - "explained_var = 1 - np.var(actual_r - pred_v) / (var_actual + 1e-8) if var_actual > 1e-8 else 0.0\n", - "\n", - "fig, axes = plt.subplots(1, 3, figsize=(18, 5))\n", - "\n", - "# Scatter: predicted vs actual\n", - "axes[0].scatter(actual_r, pred_v, alpha=0.2, s=5)\n", - "lims = [min(actual_r.min(), pred_v.min()), max(actual_r.max(), pred_v.max())]\n", - "axes[0].plot(lims, lims, \"r--\", label=\"perfect\")\n", - "axes[0].set_xlabel(\"Actual return\")\n", - "axes[0].set_ylabel(\"Predicted value\")\n", - "axes[0].set_title(f\"Value accuracy (EV: {explained_var:.4f})\")\n", - "axes[0].legend()\n", - "axes[0].grid(True, alpha=0.3)\n", - "\n", - "# Value error over time\n", - "value_error = np.abs(returns_stoch - buf_stoch[\"values\"]).mean(axis=1)\n", - "axes[1].plot(value_error)\n", - "axes[1].set_xlabel(\"Step\")\n", - "axes[1].set_ylabel(\"Mean |error|\")\n", - "axes[1].set_title(\"Value prediction error over time\")\n", - "axes[1].grid(True, alpha=0.3)\n", - "\n", - "# Advantage distribution\n", - "axes[2].hist(adv_stoch.flatten(), bins=50, edgecolor=\"black\", alpha=0.7)\n", - "axes[2].set_xlabel(\"Advantage\")\n", - "axes[2].set_ylabel(\"Count\")\n", - "axes[2].set_title(f\"Advantage distribution (std={adv_stoch.std():.4f})\")\n", - "\n", - "plt.tight_layout()\n", - "plt.show()\n", - "\n", - "print(f\"Explained variance: {explained_var:.4f}\")\n", - "print(f\"Value MSE: {np.mean((actual_r - pred_v) ** 2):.6f}\")" - ] - }, - { - "cell_type": "markdown", - "id": "cell-16", - "metadata": {}, - "source": [ - "## Agent trajectories" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "cell-17", - "metadata": {}, - "outputs": [], - "source": [ - "N_TRAJ = min(16, N) # number of agents to plot\n", - "\n", - "fig, axes = plt.subplots(1, 2, figsize=(16, 7))\n", - "\n", - "for i in range(N_TRAJ):\n", - " color = plt.cm.tab20(i % 20)\n", - " # Stochastic\n", - " axes[0].plot(buf_stoch[\"positions_x\"][:, i], buf_stoch[\"positions_y\"][:, i], alpha=0.6, color=color, linewidth=1)\n", - " axes[0].scatter(\n", - " buf_stoch[\"positions_x\"][0, i], buf_stoch[\"positions_y\"][0, i], color=color, s=30, marker=\"o\", zorder=5\n", - " ) # start\n", - " # Deterministic\n", - " axes[1].plot(buf_det[\"positions_x\"][:, i], buf_det[\"positions_y\"][:, i], alpha=0.6, color=color, linewidth=1)\n", - " axes[1].scatter(buf_det[\"positions_x\"][0, i], buf_det[\"positions_y\"][0, i], color=color, s=30, marker=\"o\", zorder=5)\n", - "\n", - "axes[0].set_title(f\"Stochastic trajectories (N={N_TRAJ})\")\n", - "axes[1].set_title(f\"Deterministic trajectories (N={N_TRAJ})\")\n", - "for ax in axes:\n", - " ax.set_xlabel(\"X\")\n", - " ax.set_ylabel(\"Y\")\n", - " ax.set_aspect(\"equal\")\n", - " ax.grid(True, alpha=0.3)\n", - "\n", - "plt.tight_layout()\n", - "plt.show()\n", - "\n", - "# ADE vs ground truth if scenario_length is set\n", - "if config[\"env\"].get(\"scenario_length\"):\n", - " try:\n", - " gt = env.get_ground_truth_trajectories()\n", - " # gt['x'] shape: (N, 1, T), positions shape: (T, N)\n", - " gt_x = gt[\"x\"][:, 0, :].T # (T, N)\n", - " gt_y = gt[\"y\"][:, 0, :].T\n", - " gt_valid = gt[\"valid\"][:, 0, :].T\n", - " T_gt = gt_x.shape[0]\n", - " T_use = min(T_gt, HORIZON)\n", - "\n", - " disp = np.sqrt(\n", - " (buf_stoch[\"positions_x\"][:T_use] - gt_x[:T_use]) ** 2\n", - " + (buf_stoch[\"positions_y\"][:T_use] - gt_y[:T_use]) ** 2\n", - " )\n", - " valid_mask = gt_valid[:T_use] > 0\n", - " if valid_mask.sum() > 0:\n", - " ade = disp[valid_mask].mean()\n", - " print(f\"ADE (stochastic vs ground truth): {ade:.3f}m\")\n", - " ade_per_agent = np.array(\n", - " [disp[:, i][valid_mask[:, i]].mean() for i in range(N) if valid_mask[:, i].sum() > 0]\n", - " )\n", - " plt.figure(figsize=(8, 3))\n", - " plt.hist(ade_per_agent, bins=30, edgecolor=\"black\", alpha=0.7)\n", - " plt.xlabel(\"ADE (m)\")\n", - " plt.ylabel(\"Count\")\n", - " plt.title(f\"Per-agent ADE distribution (mean={ade:.3f}m)\")\n", - " plt.tight_layout()\n", - " plt.show()\n", - " else:\n", - " print(\"No valid ground truth timesteps to compute ADE\")\n", - " except Exception as e:\n", - " print(f\"Could not compute ADE: {e}\")" - ] - }, - { - "cell_type": "markdown", - "id": "ea90af09", - "metadata": {}, - "source": [ - "## Encoder analysis — what the policy encodes\n", - "\n", - "Each obs layer has its own encoder projecting raw features → embedding width:\n", - "- **ego** and **conditioning** (reward coefs + target): single vector, no pooling.\n", - "- **partners / lanes / boundaries / traffic**: per-slot encoder, padded slots masked to `-inf`, then **max-pooled** across slots → one embedding. Fully-padded layers are zeroed.\n", - "\n", - "The max-pool means each embedding dim is \"won\" by exactly one slot (object). Below we inspect:\n", - "1. Encoder inventory (in/out dims, params).\n", - "2. **What survives the max-pool**: which slot wins per dim, per-dim winner entropy (slot-specialized vs. spread), and where the dominant objects sit in ego frame.\n", - "3. **Embedding space**: per-encoder contribution (L2 norm), active/dead dims, silence rate." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "7eb06a8d", - "metadata": {}, - "outputs": [], - "source": [ - "# ── Setup: capture per-encoder embeddings + reconstruct the max-pool ──\n", - "bb = policy.actor_backbone\n", - "ego_dim = policy.ego_dim\n", - "PAD = -1.0 # PADDED_OBSERVATION_VALUE\n", - "\n", - "# Flat batch of observations from the stochastic rollout\n", - "obs_flat = buf_stoch[\"obs\"].reshape(-1, obs_dim)\n", - "rng = np.random.default_rng(0)\n", - "sel = rng.choice(obs_flat.shape[0], size=min(4096, obs_flat.shape[0]), replace=False)\n", - "obs_batch = torch.FloatTensor(obs_flat[sel]).to(device)\n", - "B = obs_batch.shape[0]\n", - "\n", - "# Encoder inventory: (name, module, raw_in_features, n_slots, is_set)\n", - "enc_inventory = [(\"ego\", bb.ego_encoder, ego_dim, 1, False)]\n", - "if bb.obs_slots_lane_kept > 0:\n", - " enc_inventory.append((\"lane\", bb.lane_encoder, bb.road_features_count, bb.obs_slots_lane_kept, True))\n", - "if bb.obs_slots_boundary_kept > 0:\n", - " enc_inventory.append((\"boundary\", bb.boundary_encoder, bb.road_features_count, bb.obs_slots_boundary_kept, True))\n", - "if bb.obs_slots_partners_n > 0:\n", - " enc_inventory.append((\"partner\", bb.partner_encoder, bb.partner_features_count, bb.obs_slots_partners_n, True))\n", - "if bb.obs_slots_traffic_controls_n > 0:\n", - " enc_inventory.append(\n", - " (\n", - " \"traffic\",\n", - " bb.traffic_control_encoder,\n", - " bb.traffic_control_features_after_onehot,\n", - " bb.obs_slots_traffic_controls_n,\n", - " True,\n", - " )\n", - " )\n", - "if bb.target_dim > 0:\n", - " enc_inventory.append((\"conditioning\", bb.target_encoder, bb.target_dim, 1, False))\n", - "\n", - "enc_names = [n for n, *_ in enc_inventory]\n", - "set_encs = [n for n, _, _, _, is_set in enc_inventory if is_set]\n", - "\n", - "print(f\"{'encoder':>13s} | {'raw_in':>6s} | {'emb_out':>7s} | {'slots':>5s} | {'pooled':>6s} | {'params':>9s}\")\n", - "print(\"-\" * 66)\n", - "for name, mod, rin, nslots, is_set in enc_inventory:\n", - " nparam = sum(p.numel() for p in mod.parameters())\n", - " print(\n", - " f\"{name:>13s} | {rin:>6d} | {mod[-1].out_features:>7d} | {nslots:>5d} | {('max' if is_set else '-'):>6s} | {nparam:>9,d}\"\n", - " )\n", - "print(\n", - " f\"\\nBackbone input = {sum(mod[-1].out_features for _, mod, _, _, _ in enc_inventory)} -> backbone -> {bb.out_dim}\"\n", - ")\n", - "\n", - "# Capture pre-pool encoder outputs via forward hooks\n", - "captured = {}\n", - "\n", - "\n", - "def _hook(name):\n", - " def fn(m, i, o):\n", - " captured[name] = o.detach()\n", - "\n", - " return fn\n", - "\n", - "\n", - "handles = [mod.register_forward_hook(_hook(name)) for name, mod, *_ in enc_inventory]\n", - "policy.eval()\n", - "with torch.no_grad():\n", - " policy(obs_batch)\n", - "for h in handles:\n", - " h.remove()\n", - "\n", - "# Reconstruct slot slices (same order as DriveBackbone.forward) + pad masks\n", - "partner_dim = bb.obs_slots_partners_n * bb.partner_features_count\n", - "lane_dim = bb.obs_slots_lane_kept * bb.road_features_count\n", - "boundary_dim = bb.obs_slots_boundary_kept * bb.road_features_count\n", - "traffic_dim = bb.obs_slots_traffic_controls_n * bb.traffic_control_features_count\n", - "_s = ego_dim + bb.target_dim\n", - "sl = {}\n", - "sl[\"partner\"] = (_s, _s + partner_dim, bb.obs_slots_partners_n, bb.partner_features_count)\n", - "_s += partner_dim\n", - "sl[\"lane\"] = (_s, _s + lane_dim, bb.obs_slots_lane_kept, bb.road_features_count)\n", - "_s += lane_dim\n", - "sl[\"boundary\"] = (_s, _s + boundary_dim, bb.obs_slots_boundary_kept, bb.road_features_count)\n", - "_s += boundary_dim\n", - "sl[\"traffic\"] = (_s, _s + traffic_dim, bb.obs_slots_traffic_controls_n, bb.traffic_control_features_count)\n", - "_s += traffic_dim\n", - "\n", - "raw, pad, pooled, winners, valid_sample = {}, {}, {}, {}, {}\n", - "for name in set_encs:\n", - " s, e, ns, nf = sl[name]\n", - " obj = obs_batch[:, s:e].view(B, ns, nf)\n", - " raw[name] = obj\n", - " if name == \"traffic\":\n", - " cont = obj[:, :, : bb.traffic_control_continuous_features]\n", - " typ = obj[:, :, bb.traffic_control_continuous_features]\n", - " st = obj[:, :, bb.traffic_control_continuous_features + 1]\n", - " pad[name] = (\n", - " (cont == PAD).all(dim=2)\n", - " & (typ == binding.TRAFFIC_CONTROL_TYPE_NONE)\n", - " & (st == binding.TRAFFIC_CONTROL_STATE_UNKNOWN)\n", - " )\n", - " else:\n", - " pad[name] = (obj == PAD).all(dim=2)\n", - " masked = captured[name].masked_fill(pad[name].unsqueeze(2), -torch.inf)\n", - " vm = (~pad[name]).any(dim=1)\n", - " valid_sample[name] = vm\n", - " winners[name] = masked.max(dim=1).indices # (B, embedding dim): winning slot per dim\n", - " pooled[name] = torch.where(vm.unsqueeze(1), masked.max(dim=1).values, torch.zeros_like(masked.max(dim=1).values))\n", - "\n", - "for name in (\"ego\", \"conditioning\"):\n", - " if name in enc_names:\n", - " pooled[name] = captured[name]\n", - "\n", - "print(\"\\nCaptured embeddings for:\", enc_names)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "01446b9c", - "metadata": {}, - "outputs": [], - "source": [ - "# ── What survives the max-pool: winning slots, specialization, spatial ──\n", - "n = len(set_encs)\n", - "fig, axes = plt.subplots(n, 3, figsize=(18, 4.2 * n))\n", - "if n == 1:\n", - " axes = axes[None, :]\n", - "\n", - "print(f\"{'encoder':>9s} | {'valid%':>6s} | {'mean active slots/dim':>21s} | {'%slot-specialized dims':>22s}\")\n", - "print(\"-\" * 70)\n", - "for r, name in enumerate(set_encs):\n", - " s, e, ns, nf = sl[name]\n", - " vm = valid_sample[name]\n", - " w = winners[name][vm] # (Bv, D)\n", - " D = w.shape[1]\n", - "\n", - " # (1) which slot wins, pooled over all dims+samples\n", - " slot_counts = torch.bincount(w.reshape(-1), minlength=ns).float().cpu().numpy()\n", - " slot_counts = slot_counts / max(slot_counts.sum(), 1)\n", - " axes[r, 0].bar(range(ns), slot_counts, color=\"teal\", alpha=0.85, edgecolor=\"black\")\n", - " axes[r, 0].set_title(f\"{name}: max-pool winner by slot\")\n", - " axes[r, 0].set_xlabel(\"slot index (0 = first/closest)\")\n", - " axes[r, 0].set_ylabel(\"frac of dims won\")\n", - "\n", - " # (2) per-dim winner entropy: slot-specialized (0) vs spread across slots (1)\n", - " onehot = F.one_hot(w, num_classes=ns).float() # (Bv, D, ns)\n", - " p = onehot.mean(dim=0) # (D, ns) winner distribution per dim\n", - " ent = (-(p * (p + 1e-9).log()).sum(dim=1) / np.log(ns)).cpu().numpy()\n", - " axes[r, 1].hist(ent, bins=30, color=\"indianred\", alpha=0.85, edgecolor=\"black\")\n", - " axes[r, 1].set_title(f\"{name}: per-dim winner entropy\")\n", - " axes[r, 1].set_xlabel(\"0 = slot-specialized → 1 = spread\")\n", - " axes[r, 1].set_xlim(0, 1)\n", - "\n", - " # (3) ego-frame position of the dominant object (mode winning slot per sample)\n", - " dom = torch.mode(w, dim=1).values # (Bv,)\n", - " rel = raw[name][vm]\n", - " dom_xy = rel[torch.arange(rel.shape[0]), dom][:, :2].cpu().numpy()\n", - " axes[r, 2].scatter(dom_xy[:, 0], dom_xy[:, 1], s=3, alpha=0.15, color=\"navy\")\n", - " axes[r, 2].scatter(0, 0, marker=\"*\", s=200, color=\"red\", zorder=5, label=\"ego\")\n", - " axes[r, 2].set_title(f\"{name}: dominant object position (ego frame)\")\n", - " axes[r, 2].set_xlabel(\"rel_x\")\n", - " axes[r, 2].set_ylabel(\"rel_y\")\n", - " axes[r, 2].set_aspect(\"equal\")\n", - " axes[r, 2].legend(fontsize=8)\n", - "\n", - " active_per_dim = np.exp(ent * np.log(ns)).mean()\n", - " print(\n", - " f\"{name:>9s} | {100 * vm.float().mean().item():>5.1f}% | {active_per_dim:>21.2f} | {100 * (ent < 0.2).mean():>21.1f}%\"\n", - " )\n", - "\n", - "plt.tight_layout()\n", - "plt.show()\n", - "\n", - "\n", - "# ── H1/H2/H3 check: boundary max-pool winner distance vs slot index ──\n", - "if \"boundary\" in set_encs:\n", - " vm = valid_sample[\"boundary\"]\n", - " w = winners[\"boundary\"][vm] # (Bv, D) winning slot per dim\n", - " rb = raw[\"boundary\"][vm] # (Bv, ns, nf) raw segments\n", - " nsb = rb.shape[1]\n", - " reldist = torch.hypot(rb[:, :, 0], rb[:, :, 1]) # (Bv, ns) normalized ego-frame dist\n", - " valid_seg = ~pad[\"boundary\"][vm] # (Bv, ns) slots holding a real segment\n", - "\n", - " win_reldist = torch.gather(reldist, 1, w) # (Bv, D) dist of each winning segment\n", - " slot_flat = w.reshape(-1)\n", - " wdist_flat = win_reldist.reshape(-1)\n", - " total_wins = slot_flat.numel()\n", - "\n", - " print(\"\\n=== Boundary winner distance vs slot index (H1/H2/H3) ===\")\n", - " print(f\"{'slot':>4s} | {'#wins':>8s} | {'win%':>6s} | {'rel_dist winners':>16s} | {'rel_dist occupied':>17s}\")\n", - " print(\"-\" * 64)\n", - " for s in range(nsb):\n", - " wm = slot_flat == s\n", - " nwin = int(wm.sum())\n", - " wmean = wdist_flat[wm].mean().item() if nwin > 0 else float(\"nan\")\n", - " occ = valid_seg[:, s]\n", - " omean = reldist[occ, s].mean().item() if int(occ.sum()) > 0 else float(\"nan\")\n", - " print(f\"{s:>4d} | {nwin:>8d} | {100 * nwin / total_wins:>5.1f}% | {wmean:>16.4f} | {omean:>17.4f}\")\n", - "\n", - " win_mean = wdist_flat.mean().item()\n", - " seg_mean = reldist[valid_seg].mean().item()\n", - " print(f\"\\nMean rel_dist of WINNING segments : {win_mean:.4f}\")\n", - " print(f\"Mean rel_dist of ALL valid segments: {seg_mean:.4f}\")\n", - " verdict = \"FARTHER than avg (H3 supported)\" if win_mean > seg_mean else \"nearer than avg\"\n", - " print(f\"-> winners are {verdict} by {win_mean - seg_mean:+.4f} (normalized units)\")\n", - "\n", - " fig, ax = plt.subplots(1, 2, figsize=(14, 4))\n", - " occ_means = [\n", - " reldist[valid_seg[:, s], s].mean().item() if int(valid_seg[:, s].sum()) > 0 else np.nan for s in range(nsb)\n", - " ]\n", - " win_means = [\n", - " wdist_flat[slot_flat == s].mean().item() if int((slot_flat == s).sum()) > 0 else np.nan for s in range(nsb)\n", - " ]\n", - " ax[0].plot(range(nsb), occ_means, \"o-\", label=\"occupied (any segment in slot)\")\n", - " ax[0].plot(range(nsb), win_means, \"s-\", label=\"winners only\")\n", - " ax[0].axhline(seg_mean, color=\"gray\", ls=\"--\", label=\"global valid mean\")\n", - " ax[0].set_xlabel(\"slot index\")\n", - " ax[0].set_ylabel(\"mean rel_dist (normalized)\")\n", - " ax[0].set_title(\"Boundary rel_dist vs slot index\\n(flat = not distance-sorted -> H1/H2)\")\n", - " ax[0].legend(fontsize=8)\n", - " ax[1].hist(reldist[valid_seg].cpu().numpy(), bins=50, alpha=0.6, density=True, label=\"all valid segs\", color=\"gray\")\n", - " ax[1].hist(wdist_flat.cpu().numpy(), bins=50, alpha=0.6, density=True, label=\"winners\", color=\"crimson\")\n", - " ax[1].axvline(seg_mean, color=\"gray\", ls=\"--\")\n", - " ax[1].axvline(win_mean, color=\"crimson\", ls=\"--\")\n", - " ax[1].set_xlabel(\"rel_dist (normalized)\")\n", - " ax[1].set_title(\"Winner vs all-segment distance (H3)\")\n", - " ax[1].legend(fontsize=8)\n", - " plt.tight_layout()\n", - " plt.show()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "44d421ad", - "metadata": {}, - "outputs": [], - "source": [ - "# ── Embedding space: per-encoder contribution, active/dead dims, silence ──\n", - "fig, axes = plt.subplots(1, 3, figsize=(20, 5))\n", - "\n", - "# (1) Mean L2 norm of each pooled embedding = relative weight in the concat fed to backbone\n", - "norms = [pooled[n].norm(dim=1).mean().item() for n in enc_names]\n", - "axes[0].bar(enc_names, norms, color=\"slateblue\", edgecolor=\"black\")\n", - "axes[0].set_title(\"Mean L2 norm of pooled embedding\\n(relative contribution to backbone input)\")\n", - "axes[0].tick_params(axis=\"x\", rotation=45)\n", - "axes[0].grid(True, axis=\"y\", alpha=0.3)\n", - "\n", - "# (2) Mean |activation| per embedding dim, per encoder\n", - "M = np.stack([pooled[n].abs().mean(0).cpu().numpy() for n in enc_names])\n", - "im = axes[1].imshow(M, aspect=\"auto\", cmap=\"magma\")\n", - "axes[1].set_yticks(range(len(enc_names)))\n", - "axes[1].set_yticklabels(enc_names)\n", - "axes[1].set_xlabel(\"embedding dim\")\n", - "axes[1].set_title(\"Mean |activation| per embedding dim\")\n", - "plt.colorbar(im, ax=axes[1])\n", - "\n", - "# (3) Dead dims (std<1e-4) — capacity the encoder never uses\n", - "dead = [(pooled[n].std(0) < 1e-4).float().mean().item() for n in enc_names]\n", - "axes[2].bar(enc_names, dead, color=\"gray\", edgecolor=\"black\")\n", - "axes[2].set_title(\"Fraction of dead embedding dims (std < 1e-4)\")\n", - "axes[2].tick_params(axis=\"x\", rotation=45)\n", - "axes[2].set_ylim(0, 1)\n", - "axes[2].grid(True, axis=\"y\", alpha=0.3)\n", - "\n", - "plt.tight_layout()\n", - "plt.show()\n", - "\n", - "print(f\"{'encoder':>13s} | {'mean|act|':>9s} | {'emb L2':>7s} | {'dead dims':>9s} | {'silence (fully padded)':>22s}\")\n", - "print(\"-\" * 80)\n", - "for name in enc_names:\n", - " silence = (1 - valid_sample[name].float().mean().item()) if name in valid_sample else 0.0\n", - " deadf = (pooled[name].std(0) < 1e-4).float().mean().item()\n", - " print(\n", - " f\"{name:>13s} | {pooled[name].abs().mean().item():>9.4f} | {pooled[name].norm(dim=1).mean().item():>7.3f} | \"\n", - " f\"{100 * deadf:>7.1f}% | {100 * silence:>21.1f}%\"\n", - " )" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": ".venv", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/notebooks/05_inference.py b/notebooks/05_inference.py new file mode 100644 index 0000000000..8a2ffc8013 --- /dev/null +++ b/notebooks/05_inference.py @@ -0,0 +1,1576 @@ +# --- +# jupyter: +# jupytext: +# text_representation: +# extension: .py +# format_name: percent +# format_version: '1.3' +# jupytext_version: 1.19.3 +# kernelspec: +# display_name: .venv +# language: python +# name: python3 +# --- + +# %% [markdown] +# # 05 - Model Inference Debug +# End-to-end inference pipeline: config loading, policy forward pass, rollouts (deterministic vs stochastic), observation/reward analysis, value accuracy, trajectories, LSTM state. + +# %% +import numpy as np +import matplotlib.pyplot as plt +import torch +import torch.nn.functional as F +from pufferlib.ocean.drive.drive import Drive +from pufferlib.ocean.drive import binding +from pufferlib.ocean.torch import Drive as DrivePolicy +from pufferlib.pytorch import sample_logits +from notebooks.notebook_utils import COEF_NAMES, EGO_LABELS, MAP_DIR, load_notebook_config, zero_actions + +CHECKPOINT_PATH = "" +ENV_NAME = "puffer_drive" + +config = load_notebook_config(CHECKPOINT_PATH, ENV_NAME) +config["env"]["num_agents"] = 64 +config["env"]["num_maps"] = 8 +config["env"]["eval_mode"] = 1 +config["env"]["map_dir"] = MAP_DIR + +config["env"]["obs_slots_boundary_n"] = 80 +config["env"]["obs_slots_lane_n"] = 80 +config["env"]["obs_dropout_lane"] = 0.0 +config["env"]["obs_dropout_boundary"] = 0.0 + +env = Drive(**config["env"]) +obs, info = env.reset(seed=42) +N = env.num_agents + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +policy = DrivePolicy(env, **config["policy"]).to(device) + +if CHECKPOINT_PATH: + sd = torch.load(CHECKPOINT_PATH, map_location=device) + sd = {k.replace("module.", ""): v for k, v in sd.items()} + policy.load_state_dict(sd) + print(f"Loaded checkpoint: {CHECKPOINT_PATH}") + +is_continuous = policy.is_continuous +ACT_SHAPE = (N, len(env.single_action_space.nvec)) if not is_continuous else (N, env.single_action_space.shape[0]) + +print(f"Policy on {device}, params: {sum(p.numel() for p in policy.parameters()):,}") +print(f"Obs shape: {obs.shape}, Action space: {env.single_action_space}") +print(f"Config: dynamics={config['env']['dynamics_model']}, action={config['env']['action_type']}") + +# %% [markdown] +# ## Single-step policy output + +# %% +# Take one step to get fresh obs +actions = zero_actions(env) +obs, rew, term, trunc, info = env.step(actions) + +obs_tensor = torch.FloatTensor(obs).to(device) +policy.eval() + +with torch.no_grad(): + logits_list, value = policy(obs_tensor) + +# Sample actions +action, logprob, ent = sample_logits(logits_list) +action_det, _, _ = sample_logits(logits_list, deterministic=True) + +print(f"Value: mean={value.mean():.4f}, std={value.std():.4f}, range=[{value.min():.4f}, {value.max():.4f}]") +print(f"Entropy: mean={ent.mean():.4f}, std={ent.std():.4f}") +print(f"LogProb: mean={logprob.mean():.4f}, std={logprob.std():.4f}") +print(f"Stochastic action sample: {action[0].cpu().numpy()}") +print(f"Deterministic action: {action_det[0].cpu().numpy()}") + +# Plot +fig, axes = plt.subplots(1, 2, figsize=(14, 4)) + +# Action probs (first head for multi-discrete, or full logits) +if isinstance(logits_list, list) or isinstance(logits_list, tuple): + probs = F.softmax(logits_list[0], dim=-1) +else: + probs = F.softmax(logits_list, dim=-1) +mean_probs = probs.mean(dim=0).cpu().numpy() +axes[0].bar(range(len(mean_probs)), mean_probs, edgecolor="black", alpha=0.7) +axes[0].axhline(1.0 / len(mean_probs), color="red", ls="--", label="uniform") +axes[0].set_xlabel("Action") +axes[0].set_ylabel("Probability") +axes[0].set_title("Mean action probabilities (across agents)") +axes[0].legend() + +axes[1].hist(value.cpu().numpy().flatten(), bins=30, edgecolor="black", alpha=0.7, color="purple") +axes[1].set_title("Value predictions across agents") +axes[1].set_xlabel("Value") +plt.tight_layout() +plt.show() + +# %% [markdown] +# ## Full rollout: deterministic vs stochastic + +# %% +HORIZON = 256 +TRACKED_AGENT = 0 # agent index to track in detail +obs_dim = obs.shape[1] + +dyn_model = config["env"]["dynamics_model"] +tgt_type = config["env"]["target_type"] +rew_cond = config["env"].get("reward_conditioning", False) +n_tgt_wp = config["env"].get("num_target_waypoints", 3) + + +def run_rollout(env, policy, deterministic=False, horizon=HORIZON): + obs, _ = env.reset(seed=42) + N = env.num_agents + + buffers = { + "obs": np.zeros((horizon, N, obs_dim), dtype=np.float32), + "actions": np.zeros((horizon, N), dtype=np.int64), + "rewards": np.zeros((horizon, N), dtype=np.float32), + "values": np.zeros((horizon, N), dtype=np.float32), + "logprobs": np.zeros((horizon, N), dtype=np.float32), + "entropy": np.zeros((horizon, N), dtype=np.float32), + "terminals": np.zeros((horizon, N), dtype=np.float32), + "truncations": np.zeros((horizon, N), dtype=np.float32), + "positions_x": np.zeros((horizon, N), dtype=np.float32), + "positions_y": np.zeros((horizon, N), dtype=np.float32), + } + + policy.eval() + for t in range(horizon): + obs_t = torch.FloatTensor(obs).to(device) + with torch.no_grad(): + logits_list, val = policy(obs_t) + act, logp, entr = sample_logits(logits_list, deterministic=deterministic) + + buffers["obs"][t] = obs + buffers["actions"][t] = act.cpu().numpy().reshape(N) if act.dim() > 1 else act.cpu().numpy() + buffers["values"][t] = val.squeeze().cpu().numpy() + buffers["logprobs"][t] = logp.cpu().numpy() + buffers["entropy"][t] = entr.cpu().numpy() + + # Get positions + gstate = env.get_global_agent_state() + buffers["positions_x"][t] = gstate["x"] + buffers["positions_y"][t] = gstate["y"] + + # Step env + env_actions = act.cpu().numpy().reshape(ACT_SHAPE) + obs, rew, term, trunc, info = env.step(env_actions) + buffers["rewards"][t] = rew + buffers["terminals"][t] = term + buffers["truncations"][t] = trunc + + return buffers + + +print("Running stochastic rollout...") +buf_stoch = run_rollout(env, policy, deterministic=False) +print("Running deterministic rollout...") +buf_det = run_rollout(env, policy, deterministic=True) + +for name, buf in [("Stochastic", buf_stoch), ("Deterministic", buf_det)]: + print(f"\n--- {name} ---") + print(f" Reward: mean={buf['rewards'].mean():.5f}, std={buf['rewards'].std():.5f}") + print(f" Value: mean={buf['values'].mean():.5f}, std={buf['values'].std():.5f}") + print(f" Entropy: mean={buf['entropy'].mean():.4f}") + print(f" Terminals: {buf['terminals'].sum():.0f}, Truncations: {buf['truncations'].sum():.0f}") + +# %% [markdown] +# ## Observation analysis + +# %% +from pufferlib.viz import unpack_obs, plot_observation, plot_simulator_state + +# Ego-centric observation at t=50 for tracked agent +sample_t = min(50, HORIZON - 1) +sample_obs = buf_stoch["obs"][sample_t : sample_t + 1, TRACKED_AGENT : TRACKED_AGENT + 1][0] +print(dyn_model, tgt_type, rew_cond, n_tgt_wp) +img = plot_observation( + sample_obs, + target_type=tgt_type, + reward_conditioning=rew_cond, + num_target_waypoints=n_tgt_wp, + obs_slots_partners_n=env.obs_slots_partners_n, + obs_slots_lane_n=env.obs_slots_lane_kept, + obs_slots_boundary_n=env.obs_slots_boundary_kept, + obs_slots_traffic_controls_n=env.obs_slots_traffic_controls_n, + obs_norm_goal_offset_m=env.obs_norm_goal_offset_m, + obs_norm_xy_offset_m=env.obs_norm_xy_offset_m, + obs_norm_veh_width_m=env.obs_norm_veh_width_m, + obs_norm_veh_length_m=env.obs_norm_veh_length_m, +) +plt.figure(figsize=(10, 10)) +plt.imshow(img) +plt.axis("off") +plt.title(f"Ego-centric obs | agent={TRACKED_AGENT}, t={sample_t}") +plt.show() + +# BEV simulator state +scenarios = env.get_state() +if scenarios and len(scenarios) > 0: + img_bev = plot_simulator_state(scenarios[0], timestep=0) + plt.figure(figsize=(10, 10)) + plt.imshow(img_bev) + plt.axis("off") + plt.title("BEV Simulator State") + plt.show() + +# Ego feature time series for tracked agent +ego_features_over_time = [] +for t in range(HORIZON): + ego, *_ = unpack_obs( + buf_stoch["obs"][t : t + 1, TRACKED_AGENT : TRACKED_AGENT + 1][0], + target_type=tgt_type, + reward_conditioning=rew_cond, + num_target_waypoints=n_tgt_wp, + obs_slots_partners_n=env.obs_slots_partners_n, + obs_slots_lane_n=env.obs_slots_lane_kept, + obs_slots_boundary_n=env.obs_slots_boundary_kept, + obs_slots_traffic_controls_n=env.obs_slots_traffic_controls_n, + ) + ego_features_over_time.append(ego) +ego_ts = np.array(ego_features_over_time) + +if dyn_model == "jerk": + labels = ["speed", "width", "length", "steering", "a_long", "a_lat", "lcenter", "lalign", "speed_limit"] + plot_idxs = [0, 3, 4, 5] # speed, steering, a_long, a_lat +else: + labels = ["speed", "width", "length", "lcenter", "lalign", "speed_limit"] + plot_idxs = [0, 3, 4, 5] # speed, lcenter, lalign, speed_limit + +fig, axes = plt.subplots(len(plot_idxs), 1, figsize=(14, 3 * len(plot_idxs)), sharex=True) +for i, idx in enumerate(plot_idxs): + axes[i].plot(ego_ts[:, idx]) + print(ego_ts[10:, idx].argmin()) + axes[i].set_ylabel(labels[idx]) + axes[i].grid(True, alpha=0.3) +axes[-1].set_xlabel("Step") +fig.suptitle(f"Ego features over time | agent={TRACKED_AGENT}", fontsize=14) +plt.tight_layout() +plt.show() + +# %% [markdown] +# ## Observation layer breakdown +# +# Obs layout (all ego-centric, normalized): +# - **Ego**: speed, width, length, [jerk: steering, a_long, a_lat], lane_center_dist, lane_angle, speed_limit +# - **Conditioning** (if enabled): 17 reward coefs (goal_radius, goal_speed, collision, offroad, comfort, lane_align, vel_align, lane_center, center_bias, velocity, reverse, stop_line, timestep, overspeed, throttle, steer, acc) + target waypoints +# - **Target**: static=rel_x,rel_y,rel_z per waypoint; dynamic=rel_x,rel_y,rel_z,heading_cos,heading_sin per waypoint +# - **Partners** (MAX_PARTNERS x 9): rel_x, rel_y, rel_z, length, width, heading_cos, heading_sin, rel_vx, rel_vy, seconds_stopped +# - **Lanes** (MAX_LANES x 7): rel_x, rel_y, rel_z, seg_length, seg_width, dir_cos, dir_sin +# - **Boundaries** (MAX_BOUNDS x 7): same as lanes +# - **Traffic controls** (MAX_TRAFFIC x 7): rel_x1, rel_y1, rel_x2, rel_y2, rel_z, type, state + +# %% +from pufferlib.viz import unpack_obs + +sample_t = min(50, HORIZON - 1) +sample_obs = buf_stoch["obs"][sample_t : sample_t + 1, TRACKED_AGENT : TRACKED_AGENT + 1][0] +ego, target, partners, lanes, boundaries, traffic_controls = unpack_obs( + sample_obs, + target_type=tgt_type, + reward_conditioning=rew_cond, + num_target_waypoints=n_tgt_wp, + obs_slots_partners_n=env.obs_slots_partners_n, + obs_slots_lane_n=env.obs_slots_lane_kept, + obs_slots_boundary_n=env.obs_slots_boundary_kept, + obs_slots_traffic_controls_n=env.obs_slots_traffic_controls_n, +) + +# Also unpack conditioning manually (unpack_obs doesn't return it separately) +ego_dim = binding.EGO_FEATURES +cond_dim = binding.NUM_REWARD_COEFS if rew_cond else 0 +cond_obs = sample_obs[0, ego_dim : ego_dim + cond_dim] if cond_dim > 0 else None + + +# --- Print all layer shapes + stats --- +def layer_stats(name, arr): + flat = arr.flatten() if hasattr(arr, "flatten") else np.array(arr).flatten() + if flat.size == 0: + print(f"{name:>14s}: shape={str(list(arr.shape)):>16s} (empty)") + return + nonzero = np.count_nonzero(flat) + print( + f"{name:>14s}: shape={str(list(arr.shape)):>16s} " + f"nonzero={nonzero:>5d}/{flat.size:<5d} " + f"range=[{flat.min():.4f}, {flat.max():.4f}] " + f"mean={flat.mean():.4f} std={flat.std():.4f}" + ) + + +print(f"--- Observation breakdown at t={sample_t}, agent={TRACKED_AGENT} ---") +print(f"Total obs dim: {sample_obs.shape[-1]}") +print() +layer_stats("Ego", ego) +if cond_obs is not None: + layer_stats("Conditioning", cond_obs) +layer_stats("Target", target) +layer_stats("Partners", partners) +layer_stats("Lanes", lanes) +layer_stats("Boundaries", boundaries) +layer_stats("TrafficControls", traffic_controls) + +# --- Ego features detail --- +ego_labels = EGO_LABELS + +print(f"\n--- Ego features ---") +for i, (label, val) in enumerate(zip(ego_labels, ego)): + print(f" [{i}] {label:>14s} = {val:.4f}") + +# --- Conditioning detail --- +if cond_obs is not None: + cond_labels = COEF_NAMES + print(f"\n--- Conditioning (reward coefs, normalized) ---") + for i, (label, val) in enumerate(zip(cond_labels, cond_obs)): + print(f" [{i:>2d}] {label:>16s} = {val:.4f}") + +# --- Target waypoints --- +tgt_feat = binding.STATIC_TARGET_FEATURES if tgt_type == "static" else binding.DYNAMIC_TARGET_FEATURES +if tgt_type == "static": + tgt_labels = ["rel_x", "rel_y", "rel_z"] +else: + tgt_labels = ["rel_x", "rel_y", "rel_z", "heading_cos", "heading_sin"] + +print(f"\n--- Target waypoints (n={n_tgt_wp}, type={tgt_type}) ---") +for wp in range(target.shape[0]): + vals = ", ".join(f"{tgt_labels[j]}={target[wp, j]:.4f}" for j in range(tgt_feat)) + active = "ACTIVE" if not np.allclose(target[wp], 0) else "zeroed" + print(f" wp[{wp}]: {vals} ({active})") + +# --- Partner summary --- +n_visible = np.sum(np.any(partners != 0, axis=1)) +print(f"\n--- Partners: {n_visible}/{partners.shape[0]} visible ---") +partner_labels = [ + "rel_x", + "rel_y", + "rel_z", + "length", + "width", + "heading_cos", + "heading_sin", + "rel_vx", + "rel_vy", + "seconds_stopped", +] +for p in range(min(int(n_visible), 5)): + vals = ", ".join(f"{partner_labels[j]}={partners[p, j]:.3f}" for j in range(env.partner_features)) + print(f" [{p}] {vals}") +if n_visible > 5: + print(f" ... ({n_visible - 5} more)") + +# --- Lane/boundary occupancy --- +n_lanes = np.sum(np.any(lanes != 0, axis=1)) +n_bounds = np.sum(np.any(boundaries != 0, axis=1)) +print(f"\n--- Road: {n_lanes}/{lanes.shape[0]} lane segs, {n_bounds}/{boundaries.shape[0]} boundary segs ---") + +# --- Traffic --- +n_traffic = np.sum(np.any(traffic_controls != 0, axis=1)) +print(f"\n--- Traffic controls: {n_traffic}/{traffic_controls.shape[0]} visible ---") +traffic_labels = ["rel_x1", "rel_y1", "rel_x2", "rel_y2", "rel_z", "type", "state"] +for t in range(min(int(n_traffic), 5)): + vals = ", ".join( + f"{traffic_labels[j]}={traffic_controls[t, j]:.3f}" + for j in range(min(len(traffic_labels), traffic_controls.shape[1])) + ) + print(f" [{t}] {vals}") + +# %% +# --- Layer-level stats across ALL agents at sample_t --- +all_obs = buf_stoch["obs"][sample_t] # (N, obs_dim) + +ego_dim = binding.EGO_FEATURES +cond_dim = binding.NUM_REWARD_COEFS if rew_cond else 0 +tgt_feat = binding.STATIC_TARGET_FEATURES if tgt_type == "static" else binding.DYNAMIC_TARGET_FEATURES +tgt_dim = n_tgt_wp * tgt_feat +partner_dim = env.obs_slots_partners_n * env.partner_features +lane_dim = env.obs_slots_lane_kept * env.road_features +boundary_dim = env.obs_slots_boundary_kept * env.road_features +traffic_dim = env.obs_slots_traffic_controls_n * env.traffic_control_features + +# Slice indices +idx = 0 +slices = {} +slices["ego"] = (idx, idx + ego_dim) +idx += ego_dim +if cond_dim > 0: + slices["conditioning"] = (idx, idx + cond_dim) + idx += cond_dim +slices["target"] = (idx, idx + tgt_dim) +idx += tgt_dim +slices["partners"] = (idx, idx + partner_dim) +idx += partner_dim +slices["lanes"] = (idx, idx + lane_dim) +idx += lane_dim +slices["boundaries"] = (idx, idx + boundary_dim) +idx += boundary_dim +slices["traffic"] = (idx, idx + traffic_dim) +idx += traffic_dim + +print(f"Obs dim used: {idx} / {all_obs.shape[1]}") +print( + f"\n{'Layer':>14s} | {'start':>5s}-{'end':>5s} | {'dim':>5s} | {'mean':>8s} | {'std':>8s} | {'min':>8s} | {'max':>8s} | {'%nonzero':>8s}" +) +print("-" * 95) +for name, (s, e) in slices.items(): + chunk = all_obs[:, s:e] + nz_pct = 100 * np.count_nonzero(chunk) / chunk.size + print( + f"{name:>14s} | {s:>5d}-{e:>5d} | {e - s:>5d} | {chunk.mean():>8.4f} | {chunk.std():>8.4f} | " + f"{chunk.min():>8.4f} | {chunk.max():>8.4f} | {nz_pct:>7.1f}%" + ) + +# --- Plots --- +n_layers = len(slices) +fig, axes = plt.subplots(2, (n_layers + 1) // 2, figsize=(5 * ((n_layers + 1) // 2), 8)) +axes = axes.flatten() + +for i, (name, (s, e)) in enumerate(slices.items()): + chunk = all_obs[:, s:e].flatten() + # Filter out exact zeros for histogram readability on sparse layers + nonzero_vals = chunk[chunk != 0] + if len(nonzero_vals) > 0: + axes[i].hist(nonzero_vals, bins=50, edgecolor="black", alpha=0.7) + axes[i].set_title(f"{name} (nonzero only, {len(nonzero_vals)}/{len(chunk)})") + else: + axes[i].hist(chunk, bins=50, edgecolor="black", alpha=0.7) + axes[i].set_title(f"{name} (all zeros)") + axes[i].set_xlabel("Value") + +# Hide unused axes +for j in range(i + 1, len(axes)): + axes[j].set_visible(False) + +fig.suptitle(f"Observation distributions across {N} agents at t={sample_t}", fontsize=14) +plt.tight_layout() +plt.show() + +# %% +# --- Per-feature detail for partners, lanes, boundaries over time (tracked agent) --- + + +def unpack_all_timesteps(bufs, agent_idx): + """Unpack all obs layers across time for one agent.""" + H = bufs["obs"].shape[0] + egos, targets, conds = [], [], [] + n_partners, n_lanes, n_bounds, n_traffic = [], [], [], [] + + for t in range(H): + ob = bufs["obs"][t : t + 1, agent_idx : agent_idx + 1][0] + ego, tgt, part, lane, bnd, tfc = unpack_obs( + ob, + target_type=tgt_type, + reward_conditioning=rew_cond, + num_target_waypoints=n_tgt_wp, + obs_slots_partners_n=env.obs_slots_partners_n, + obs_slots_lane_n=env.obs_slots_lane_kept, + obs_slots_boundary_n=env.obs_slots_boundary_kept, + obs_slots_traffic_controls_n=env.obs_slots_traffic_controls_n, + ) + egos.append(ego) + targets.append(tgt) + n_partners.append(np.sum(np.any(part != 0, axis=1))) + n_lanes.append(np.sum(np.any(lane != 0, axis=1))) + n_bounds.append(np.sum(np.any(bnd != 0, axis=1))) + n_traffic.append(np.sum(np.any(tfc != 0, axis=1))) + + if rew_cond: + ed = binding.EGO_FEATURES + conds.append(ob[0, ed : ed + binding.NUM_REWARD_COEFS]) + + return { + "ego": np.array(egos), + "target": np.array(targets), + "cond": np.array(conds) if conds else None, + "n_partners": np.array(n_partners), + "n_lanes": np.array(n_lanes), + "n_bounds": np.array(n_bounds), + "n_traffic": np.array(n_traffic), + } + + +ts = unpack_all_timesteps(buf_stoch, TRACKED_AGENT) + +fig, axes = plt.subplots(2, 2, figsize=(16, 10)) + +# Occupancy over time +axes[0, 0].plot(ts["n_partners"], label="partners", alpha=0.8) +axes[0, 0].plot(ts["n_lanes"], label="lanes", alpha=0.8) +axes[0, 0].plot(ts["n_bounds"], label="boundaries", alpha=0.8) +axes[0, 0].plot(ts["n_traffic"], label="traffic", alpha=0.8) +axes[0, 0].set_xlabel("Step") +axes[0, 0].set_ylabel("Visible count") +axes[0, 0].set_title(f"Obs occupancy over time | agent={TRACKED_AGENT}") +axes[0, 0].legend() +axes[0, 0].grid(True, alpha=0.3) + +# Target waypoint distances over time +tgt_x = ts["target"][:, :, 0] +tgt_y = ts["target"][:, :, 1] +tgt_dist = np.sqrt(tgt_x**2 + tgt_y**2) +for wp in range(n_tgt_wp): + axes[0, 1].plot(tgt_dist[:, wp], label=f"wp[{wp}]", alpha=0.8) +axes[0, 1].set_xlabel("Step") +axes[0, 1].set_ylabel("Distance (normalized)") +axes[0, 1].set_title("Target waypoint distance over time") +axes[0, 1].legend() +axes[0, 1].grid(True, alpha=0.3) + +# Conditioning heatmap over time +if ts["cond"] is not None: + cond_labels = COEF_NAMES + im = axes[1, 0].imshow(ts["cond"].T, aspect="auto", cmap="coolwarm", interpolation="nearest") + axes[1, 0].set_yticks(range(len(cond_labels))) + axes[1, 0].set_yticklabels(cond_labels, fontsize=8) + axes[1, 0].set_xlabel("Step") + axes[1, 0].set_title("Conditioning coefs over time") + plt.colorbar(im, ax=axes[1, 0]) +else: + axes[1, 0].text(0.5, 0.5, "No conditioning", ha="center", va="center", transform=axes[1, 0].transAxes) + axes[1, 0].set_title("Conditioning (disabled)") + +# Partner closest distance over time +partner_dists = [] +for t in range(HORIZON): + ob = buf_stoch["obs"][t : t + 1, TRACKED_AGENT : TRACKED_AGENT + 1][0] + _, _, part, _, _, _ = unpack_obs( + ob, + target_type=tgt_type, + reward_conditioning=rew_cond, + num_target_waypoints=n_tgt_wp, + obs_slots_partners_n=env.obs_slots_partners_n, + obs_slots_lane_n=env.obs_slots_lane_kept, + obs_slots_boundary_n=env.obs_slots_boundary_kept, + obs_slots_traffic_controls_n=env.obs_slots_traffic_controls_n, + ) + dists = np.sqrt(part[:, 0] ** 2 + part[:, 1] ** 2) + visible = np.any(part != 0, axis=1) + partner_dists.append(dists[visible].min() if visible.any() else np.nan) + +axes[1, 1].plot(partner_dists, alpha=0.8, color="red") +axes[1, 1].set_xlabel("Step") +axes[1, 1].set_ylabel("Min partner dist (normalized)") +axes[1, 1].set_title("Closest partner distance over time") +axes[1, 1].grid(True, alpha=0.3) + +plt.tight_layout() +plt.show() + +# %% +# --- Spatial scatter: all observed entities in ego frame at sample_t --- +sample_obs = buf_stoch["obs"][sample_t : sample_t + 1, TRACKED_AGENT : TRACKED_AGENT + 1][0] +ego, target, partners, lanes, boundaries, traffic_controls = unpack_obs( + sample_obs, + target_type=tgt_type, + reward_conditioning=rew_cond, + num_target_waypoints=n_tgt_wp, + obs_slots_partners_n=env.obs_slots_partners_n, + obs_slots_lane_n=env.obs_slots_lane_kept, + obs_slots_boundary_n=env.obs_slots_boundary_kept, + obs_slots_traffic_controls_n=env.obs_slots_traffic_controls_n, +) + +fig, ax = plt.subplots(figsize=(10, 10)) + +# Ego vehicle at origin +from matplotlib.patches import Rectangle + +ax.add_patch( + Rectangle((-ego[2] / 2, -ego[1] / 2), ego[2], ego[1], facecolor="blue", edgecolor="black", alpha=0.7, zorder=10) +) +ax.annotate("EGO", (0, 0), fontsize=9, ha="center", va="center", color="white", fontweight="bold", zorder=11) + +# Lane segments +for i in range(lanes.shape[0]): + if np.allclose(lanes[i], 0): + continue + rx, ry, rz, length, _, dc, ds = lanes[i] + ax.plot( + [rx - dc * length / 2, rx + dc * length / 2], + [ry - ds * length / 2, ry + ds * length / 2], + color="lightgray", + linewidth=1, + zorder=1, + ) +ax.scatter( + lanes[np.any(lanes != 0, axis=1), 0], + lanes[np.any(lanes != 0, axis=1), 1], + s=5, + color="gray", + alpha=0.5, + label=f"lanes ({n_lanes})", + zorder=2, +) + +# Boundary segments +for i in range(boundaries.shape[0]): + if np.allclose(boundaries[i], 0): + continue + rx, ry, rz, length, _, dc, ds = boundaries[i] + ax.plot( + [rx - dc * length / 2, rx + dc * length / 2], + [ry - ds * length / 2, ry + ds * length / 2], + color="black", + linewidth=1, + zorder=1, + ) +bnd_mask = np.any(boundaries != 0, axis=1) +if bnd_mask.any(): + ax.scatter( + boundaries[bnd_mask, 0], + boundaries[bnd_mask, 1], + s=8, + color="black", + alpha=0.6, + label=f"boundaries ({n_bounds})", + zorder=2, + ) + +# Partners +for i in range(partners.shape[0]): + if np.allclose(partners[i], 0): + continue + rx, ry, rz, w, l, hc, hs, vx, vy, _ = partners[i] + heading = np.arctan2(hs, hc) + rect = Rectangle((-l / 2, -w / 2), l, w, facecolor="orange", edgecolor="black", alpha=0.6, zorder=9) + rect.set_transform(plt.matplotlib.transforms.Affine2D().rotate(heading).translate(rx, ry) + ax.transData) + ax.add_patch(rect) + ax.annotate(f"{vx:.2f}, {vy:.2f}", (rx, ry), fontsize=7, ha="center", color="darkred", zorder=12) +part_mask = np.any(partners != 0, axis=1) +if part_mask.any(): + ax.scatter( + partners[part_mask, 0], + partners[part_mask, 1], + s=40, + color="orange", + edgecolors="black", + label=f"partners ({n_visible})", + zorder=8, + ) + +# Target waypoints +for wp in range(target.shape[0]): + if np.allclose(target[wp], 0): + continue + marker = "*" if wp == 0 else "o" + s = 200 if wp == 0 else 80 + color = "red" if wp == 0 else "salmon" + ax.scatter( + target[wp, 0], + target[wp, 1], + color=color, + marker=marker, + s=s, + zorder=15, + label=f"target wp[{wp}]" if wp < 3 else None, + ) + +# Traffic controls +for i in range(traffic_controls.shape[0]): + if np.allclose(traffic_controls[i], 0): + continue + x1, y1, x2, y2, _, control_type, state = traffic_controls[i] + if int(control_type) == binding.TRAFFIC_CONTROL_TYPE_TRAFFIC_LIGHT: + state_colors = { + binding.TRAFFIC_CONTROL_STATE_UNKNOWN: "gray", + binding.TRAFFIC_CONTROL_STATE_RED: "red", + binding.TRAFFIC_CONTROL_STATE_YELLOW: "yellow", + binding.TRAFFIC_CONTROL_STATE_GREEN: "green", + binding.TRAFFIC_CONTROL_STATE_OFF: "gray", + } + ax.plot([x1, x2], [y1, y2], color=state_colors.get(int(state), "gray"), linewidth=3, zorder=15) + else: + accent = "red" if int(control_type) == binding.TRAFFIC_CONTROL_TYPE_STOP_SIGN else "gold" + ax.plot([x1, x2], [y1, y2], color="black", linewidth=4, zorder=14) + ax.plot([x1, x2], [y1, y2], color=accent, linewidth=2.5, linestyle="--", zorder=15) + +ax.set_xlim(-1, 1) +ax.set_ylim(-1, 1) +ax.set_aspect("equal") +ax.set_xlabel("X (ego frame, normalized)") +ax.set_ylabel("Y (ego frame, normalized)") +ax.set_title(f"All observed entities | agent={TRACKED_AGENT}, t={sample_t}") +ax.legend(loc="upper right", fontsize=8) +ax.grid(True, alpha=0.3) +plt.tight_layout() +plt.show() + +# %% [markdown] +# ### Ego + conditioning distributions across all agents + +# %% +# Ego feature distributions across all agents, pooled over full rollout +ego_dim = binding.EGO_FEATURES +all_ego = buf_stoch["obs"][:, :, :ego_dim].reshape(-1, ego_dim) # (H*N, ego_dim) + +ego_labels = EGO_LABELS + +fig, axes = plt.subplots(2, len(ego_labels), figsize=(3.5 * len(ego_labels), 7)) + +# Row 0: histograms +for i, label in enumerate(ego_labels): + vals = all_ego[:, i] + print(f"{label}: mean={vals}") + axes[0, i].hist(vals, bins=60, edgecolor="black", alpha=0.7, color="steelblue") + axes[0, i].set_title(label, fontsize=10) + axes[0, i].set_xlabel("") + axes[0, i].tick_params(labelsize=7) + axes[0, i].axvline(vals.mean(), color="red", ls="--", lw=1) + +# Row 1: boxplots per-agent (distribution across timesteps for each agent) +ego_per_agent = buf_stoch["obs"][:, :, :ego_dim] # (H, N, ego_dim) +for i, label in enumerate(ego_labels): + data = [ego_per_agent[:, a, i] for a in range(N)] + bp = axes[1, i].boxplot( + data, + showfliers=False, + patch_artist=True, + boxprops=dict(facecolor="steelblue", alpha=0.5), + medianprops=dict(color="red"), + ) + axes[1, i].set_xlabel("Agent") + axes[1, i].tick_params(labelsize=7) + axes[1, i].set_title(f"{label} per agent", fontsize=9) + +fig.suptitle("Ego features: full rollout distributions", fontsize=13) +plt.tight_layout() +plt.show() + +# Conditioning distributions across all agents (if enabled) +if rew_cond: + cond_start = ego_dim + cond_end = cond_start + binding.NUM_REWARD_COEFS + all_cond = buf_stoch["obs"][:, :, cond_start:cond_end].reshape(-1, binding.NUM_REWARD_COEFS) + + cond_labels = COEF_NAMES + + fig, ax = plt.subplots(figsize=(14, 5)) + parts = ax.violinplot( + [all_cond[:, i] for i in range(binding.NUM_REWARD_COEFS)], + positions=range(binding.NUM_REWARD_COEFS), + showmeans=True, + showmedians=True, + ) + ax.set_xticks(range(binding.NUM_REWARD_COEFS)) + ax.set_xticklabels(cond_labels, rotation=45, ha="right", fontsize=9) + ax.set_ylabel("Normalized value") + ax.set_title("Conditioning coef distributions (all agents, full rollout)") + ax.grid(True, alpha=0.3, axis="y") + plt.tight_layout() + plt.show() + +# %% [markdown] +# ### Partner per-feature distributions + +# %% +# Partner per-feature distributions (pooled over all agents + timesteps, visible only) +partner_labels = [ + "rel_x", + "rel_y", + "rel_z", + "length", + "width", + "heading_cos", + "heading_sin", + "rel_vx", + "rel_vy", + "seconds_stopped", +] +obs_slots_partners_n = env.obs_slots_partners_n +pf = env.partner_features + +# Compute slices +_ego_d = binding.EGO_FEATURES +_cond_d = binding.NUM_REWARD_COEFS if rew_cond else 0 +_tgt_f = binding.STATIC_TARGET_FEATURES if tgt_type == "static" else binding.DYNAMIC_TARGET_FEATURES +_tgt_d = n_tgt_wp * _tgt_f +_p_start = _ego_d + _cond_d + _tgt_d +_p_end = _p_start + obs_slots_partners_n * pf + +all_partners = buf_stoch["obs"][:, :, _p_start:_p_end].reshape( + -1, obs_slots_partners_n, pf +) # (H*N, obs_slots_partners_n, 10) +# Mask: partner is visible if any feature != 0 +visible_mask = np.any(all_partners != 0, axis=2) # (H*N, 16) +visible_partners = all_partners[visible_mask] # (K, 10) — all visible partner observations + +print( + f"Total partner obs: {all_partners.shape[0] * obs_slots_partners_n}, visible: {len(visible_partners)} " + f"({100 * len(visible_partners) / (all_partners.shape[0] * obs_slots_partners_n):.1f}%)" +) + +fig, axes = plt.subplots(2, 5, figsize=(21, 8)) +axes = axes.flatten() + +for i, label in enumerate(partner_labels): + vals = visible_partners[:, i] + axes[i].hist(vals, bins=80, edgecolor="black", alpha=0.7, color="darkorange") + axes[i].set_title(f"{label} (n={len(vals)})", fontsize=10) + axes[i].axvline(vals.mean(), color="red", ls="--", lw=1, label=f"mean={vals.mean():.3f}") + axes[i].legend(fontsize=7) + axes[i].tick_params(labelsize=7) + +# rel_x vs rel_y scatter in last panel +pos_ax = axes[len(partner_labels)] +pos_ax.scatter(visible_partners[:, 0], visible_partners[:, 1], s=1, alpha=0.15, color="darkorange") +pos_ax.set_xlabel("rel_x") +pos_ax.set_ylabel("rel_y") +pos_ax.set_title("Partner positions (ego frame)") +pos_ax.set_aspect("equal") +pos_ax.grid(True, alpha=0.3) + +fig.suptitle("Partner features: all visible, full rollout", fontsize=13) +plt.tight_layout() +plt.show() + +# Partner count distribution across (timestep, agent) +partner_counts = visible_mask.sum(axis=1) # (H*N,) +fig, axes = plt.subplots(1, 2, figsize=(12, 4)) +axes[0].hist(partner_counts, bins=range(obs_slots_partners_n + 2), edgecolor="black", alpha=0.7, color="darkorange") +axes[0].set_xlabel("Visible partners") +axes[0].set_ylabel("Count") +axes[0].set_title("Partner count distribution (per agent per step)") + +# Partner distance distribution +dists = np.sqrt(visible_partners[:, 0] ** 2 + visible_partners[:, 1] ** 2) +axes[1].hist(dists, bins=80, edgecolor="black", alpha=0.7, color="coral") +axes[1].set_xlabel("Distance (normalized)") +axes[1].set_ylabel("Count") +axes[1].set_title(f"Partner distance distribution (mean={dists.mean():.3f})") +axes[1].axvline(dists.mean(), color="red", ls="--", lw=1) + +plt.tight_layout() +plt.show() + +# %% [markdown] +# ### Road (lanes + boundaries) and target distributions + +# %% +# Road per-feature distributions (lanes + boundaries) +road_labels = ["rel_x", "rel_y", "rel_z", "seg_length", "seg_width", "dir_cos", "dir_sin"] +rf = env.road_features +max_lanes = env.obs_slots_lane_kept +max_bounds = env.obs_slots_boundary_kept + +_l_start = _p_end +_l_end = _l_start + max_lanes * rf +_b_start = _l_end +_b_end = _b_start + max_bounds * rf + +all_lanes = buf_stoch["obs"][:, :, _l_start:_l_end].reshape(-1, max_lanes, rf) +all_bounds = buf_stoch["obs"][:, :, _b_start:_b_end].reshape(-1, max_bounds, rf) + +vis_lanes = all_lanes[np.any(all_lanes != 0, axis=2)] +vis_bounds = all_bounds[np.any(all_bounds != 0, axis=2)] + +print( + f"Lanes: {len(vis_lanes)} visible / {all_lanes.shape[0] * max_lanes} total " + f"({100 * len(vis_lanes) / (all_lanes.shape[0] * max_lanes):.1f}%)" +) +print( + f"Boundaries: {len(vis_bounds)} visible / {all_bounds.shape[0] * max_bounds} total " + f"({100 * len(vis_bounds) / (all_bounds.shape[0] * max_bounds):.1f}%)" +) + +fig, axes = plt.subplots(2, 7, figsize=(28, 8)) +for i, label in enumerate(road_labels): + # Lanes + axes[0, i].hist(vis_lanes[:, i], bins=80, edgecolor="black", alpha=0.7, color="silver") + axes[0, i].set_title(f"lane {label}", fontsize=9) + axes[0, i].axvline(vis_lanes[:, i].mean(), color="red", ls="--", lw=1) + axes[0, i].tick_params(labelsize=7) + # Boundaries + axes[1, i].hist(vis_bounds[:, i], bins=80, edgecolor="black", alpha=0.7, color="dimgray") + axes[1, i].set_title(f"boundary {label}", fontsize=9) + axes[1, i].axvline(vis_bounds[:, i].mean(), color="red", ls="--", lw=1) + axes[1, i].tick_params(labelsize=7) + +fig.suptitle("Road features: all visible, full rollout (top=lanes, bottom=boundaries)", fontsize=13) +plt.tight_layout() +plt.show() + +# Spatial scatter: lane vs boundary positions (pooled) +fig, axes = plt.subplots(1, 3, figsize=(18, 5)) +axes[0].scatter(vis_lanes[:, 0], vis_lanes[:, 1], s=0.5, alpha=0.05, color="gray") +axes[0].set_xlabel("rel_x") +axes[0].set_ylabel("rel_y") +axes[0].set_title(f"Lane segment positions (n={len(vis_lanes)})") +axes[0].set_aspect("equal") +axes[0].grid(True, alpha=0.3) + +axes[1].scatter(vis_bounds[:, 0], vis_bounds[:, 1], s=0.5, alpha=0.05, color="black") +axes[1].set_xlabel("rel_x") +axes[1].set_ylabel("rel_y") +axes[1].set_title(f"Boundary segment positions (n={len(vis_bounds)})") +axes[1].set_aspect("equal") +axes[1].grid(True, alpha=0.3) + +# Lane + boundary segment length comparison +axes[2].hist(vis_lanes[:, 2], bins=80, alpha=0.6, color="silver", edgecolor="black", label="lanes") +axes[2].hist(vis_bounds[:, 2], bins=80, alpha=0.6, color="dimgray", edgecolor="black", label="boundaries") +axes[2].set_xlabel("Segment length (normalized)") +axes[2].set_ylabel("Count") +axes[2].set_title("Segment length distribution") +axes[2].legend() + +plt.tight_layout() +plt.show() + +# Target distributions across all agents, full rollout +_tgt_start = _ego_d + _cond_d +_tgt_end = _tgt_start + _tgt_d +all_target = buf_stoch["obs"][:, :, _tgt_start:_tgt_end].reshape(-1, n_tgt_wp, _tgt_f) + +if tgt_type == "static": + tgt_flabels = ["rel_x", "rel_y", "rel_z"] +else: + tgt_flabels = ["rel_x", "rel_y", "rel_z", "heading_cos", "heading_sin"] + +fig, axes = plt.subplots(1, n_tgt_wp + 1, figsize=(5 * (n_tgt_wp + 1), 4)) + +for wp in range(n_tgt_wp): + wp_data = all_target[:, wp, :] + active = np.any(wp_data != 0, axis=1) + wp_active = wp_data[active] + dist = np.sqrt(wp_active[:, 0] ** 2 + wp_active[:, 1] ** 2) if len(wp_active) > 0 else np.array([]) + axes[wp].hist(dist, bins=60, edgecolor="black", alpha=0.7, color=["red", "salmon", "lightsalmon"][wp % 3]) + axes[wp].set_title(f"wp[{wp}] distance (n={len(wp_active)}/{len(wp_data)})", fontsize=10) + axes[wp].set_xlabel("Distance (normalized)") + +# All waypoints x-y scatter +for wp in range(n_tgt_wp): + wp_data = all_target[:, wp, :] + active = np.any(wp_data != 0, axis=1) + wp_active = wp_data[active] + if len(wp_active) > 0: + axes[n_tgt_wp].scatter(wp_active[:, 0], wp_active[:, 1], s=1, alpha=0.1, label=f"wp[{wp}]") +axes[n_tgt_wp].set_xlabel("rel_x") +axes[n_tgt_wp].set_ylabel("rel_y") +axes[n_tgt_wp].set_title("Target positions (ego frame)") +axes[n_tgt_wp].set_aspect("equal") +axes[n_tgt_wp].legend(fontsize=8) +axes[n_tgt_wp].grid(True, alpha=0.3) + +fig.suptitle("Target waypoint distributions (all agents, full rollout)", fontsize=13) +plt.tight_layout() +plt.show() + +# %% [markdown] +# ### Observation sparsity and layer occupancy heatmaps + +# %% +# Sparsity heatmap: fraction of nonzero per layer, per agent, over time +layer_names = ["partners", "lanes", "boundaries"] +layer_slices = [ + (_p_start, _p_end, env.obs_slots_partners_n, env.partner_features), + (_l_start, _l_end, env.obs_slots_lane_kept, env.road_features), + (_b_start, _b_end, env.obs_slots_boundary_kept, env.road_features), +] + +fig, axes = plt.subplots(1, 3, figsize=(20, 5)) +for ax, name, (s, e, n_obj, n_feat) in zip(axes, layer_names, layer_slices): + # (H, N) -> fraction of visible objects per (timestep, agent) + raw = buf_stoch["obs"][:, :, s:e].reshape(HORIZON, env.num_agents, n_obj, n_feat) + occupancy = np.any(raw != 0, axis=3).sum(axis=2) / n_obj # (H, N) + im = ax.imshow(occupancy.T, aspect="auto", cmap="YlOrRd", interpolation="nearest", vmin=0, vmax=1) + ax.set_xlabel("Step") + ax.set_ylabel("Agent") + ax.set_title(f"{name} occupancy (frac visible)") + plt.colorbar(im, ax=ax) + +plt.suptitle("Per-layer occupancy heatmaps (fraction of max slots filled)", fontsize=13) +plt.tight_layout() +plt.show() + +# Per-layer mean occupancy over time +fig, axes = plt.subplots(1, 2, figsize=(16, 4)) + +# Mean across agents +for name, (s, e, n_obj, n_feat) in zip(layer_names, layer_slices): + raw = buf_stoch["obs"][:, :, s:e].reshape(HORIZON, env.num_agents, n_obj, n_feat) + occ_mean = np.any(raw != 0, axis=3).sum(axis=2).mean(axis=1) # (H,) + axes[0].plot(occ_mean, label=name, alpha=0.8) +axes[0].set_xlabel("Step") +axes[0].set_ylabel("Mean visible count") +axes[0].set_title("Mean occupancy over time (across agents)") +axes[0].legend() +axes[0].grid(True, alpha=0.3) + +# Mean across timesteps (per agent) +for name, (s, e, n_obj, n_feat) in zip(layer_names, layer_slices): + raw = buf_stoch["obs"][:, :, s:e].reshape(HORIZON, env.num_agents, n_obj, n_feat) + occ_per_agent = np.any(raw != 0, axis=3).sum(axis=2).mean(axis=0) # (N,) + axes[1].bar(range(N), occ_per_agent, alpha=0.5, label=name) +axes[1].set_xlabel("Agent") +axes[1].set_ylabel("Mean visible count") +axes[1].set_title("Mean occupancy per agent (across timesteps)") +axes[1].legend() +axes[1].grid(True, alpha=0.3) + +plt.tight_layout() +plt.show() + +# Full obs sparsity: fraction of zero features per obs dimension, pooled +all_flat = buf_stoch["obs"].reshape(-1, obs_dim) # (H*N, obs_dim) +zero_frac = (all_flat == 0).mean(axis=0) # (obs_dim,) +fig, ax = plt.subplots(figsize=(18, 3)) +ax.bar(range(obs_dim), zero_frac, width=1.0, color="steelblue", alpha=0.7) +# Annotate layer boundaries +prev_e = 0 +for name, (s, e) in slices.items(): + ax.axvline(s, color="red", ls="--", lw=0.5, alpha=0.7) + mid = (s + e) / 2 + ax.text(mid, 1.02, name, ha="center", va="bottom", fontsize=7, rotation=0, color="red") + prev_e = e +ax.set_xlim(0, obs_dim) +ax.set_ylim(0, 1.1) +ax.set_xlabel("Obs dimension index") +ax.set_ylabel("Fraction zero") +ax.set_title("Per-dimension sparsity (fraction zero across full rollout)") +plt.tight_layout() +plt.show() + +# %% [markdown] +# ## Policy outputs over time + +# %% +# Compute action probs over time for tracked agent (stochastic rollout) +n_actions = env.single_action_space.nvec[0] if not is_continuous else 1 +action_probs_time = np.zeros((HORIZON, n_actions)) +for t in range(HORIZON): + obs_t = torch.FloatTensor(buf_stoch["obs"][t : t + 1, TRACKED_AGENT : TRACKED_AGENT + 1][0]).to(device) + with torch.no_grad(): + logits_list, _ = policy(obs_t) + logits = logits_list[0] if isinstance(logits_list, (list, tuple)) else logits_list + action_probs_time[t] = F.softmax(logits, dim=-1).cpu().numpy().flatten() + +fig, axes = plt.subplots(2, 2, figsize=(16, 10)) + +# Action distribution heatmap +im = axes[0, 0].imshow(action_probs_time.T, aspect="auto", cmap="viridis", interpolation="nearest") +axes[0, 0].set_xlabel("Step") +axes[0, 0].set_ylabel("Action ID") +axes[0, 0].set_title(f"Action prob heatmap | agent={TRACKED_AGENT}") +plt.colorbar(im, ax=axes[0, 0]) + +# Entropy over time +axes[0, 1].plot(buf_stoch["entropy"][:, TRACKED_AGENT], label="stochastic", alpha=0.8) +axes[0, 1].set_xlabel("Step") +axes[0, 1].set_ylabel("Entropy") +axes[0, 1].set_title("Entropy over time") +axes[0, 1].grid(True, alpha=0.3) + +# Value over time +axes[1, 0].plot(buf_stoch["values"][:, TRACKED_AGENT], label="stochastic", alpha=0.8) +axes[1, 0].plot(buf_det["values"][:, TRACKED_AGENT], label="deterministic", alpha=0.8) +axes[1, 0].set_xlabel("Step") +axes[1, 0].set_ylabel("Value") +axes[1, 0].set_title("Value predictions over time") +axes[1, 0].legend() +axes[1, 0].grid(True, alpha=0.3) + +# Actions over time: deterministic vs stochastic +axes[1, 1].step(range(HORIZON), buf_stoch["actions"][:, TRACKED_AGENT], label="stochastic", alpha=0.7) +axes[1, 1].step(range(HORIZON), buf_det["actions"][:, TRACKED_AGENT], label="deterministic", alpha=0.7) +axes[1, 1].set_xlabel("Step") +axes[1, 1].set_ylabel("Action") +axes[1, 1].set_title("Selected action over time") +axes[1, 1].legend() +axes[1, 1].grid(True, alpha=0.3) + +plt.tight_layout() +plt.show() + +# %% [markdown] +# ## Rewards and returns + +# %% +fig, axes = plt.subplots(2, 3, figsize=(18, 10)) + +# Per-step mean reward +axes[0, 0].plot(buf_stoch["rewards"].mean(axis=1), label="stochastic", alpha=0.8) +axes[0, 0].plot(buf_det["rewards"].mean(axis=1), label="deterministic", alpha=0.8) +axes[0, 0].set_xlabel("Step") +axes[0, 0].set_ylabel("Mean reward") +axes[0, 0].set_title("Mean reward per step") +axes[0, 0].legend() +axes[0, 0].grid(True, alpha=0.3) + +# Reward heatmap (stochastic) +im = axes[0, 1].imshow(buf_stoch["rewards"].T, aspect="auto", cmap="RdYlGn", interpolation="nearest") +axes[0, 1].set_xlabel("Step") +axes[0, 1].set_ylabel("Agent") +axes[0, 1].set_title("Reward heatmap (stochastic)") +plt.colorbar(im, ax=axes[0, 1]) + +# Cumulative return per agent +cum_ret_stoch = buf_stoch["rewards"].sum(axis=0) +cum_ret_det = buf_det["rewards"].sum(axis=0) +axes[0, 2].hist(cum_ret_stoch, bins=30, alpha=0.6, label="stochastic", edgecolor="black") +axes[0, 2].hist(cum_ret_det, bins=30, alpha=0.6, label="deterministic", edgecolor="black") +axes[0, 2].set_xlabel("Cumulative return") +axes[0, 2].set_ylabel("Count") +axes[0, 2].set_title("Return distribution across agents") +axes[0, 2].legend() + +# Reward distribution histogram +axes[1, 0].hist(buf_stoch["rewards"].flatten(), bins=50, alpha=0.7, edgecolor="black") +axes[1, 0].set_xlabel("Reward") +axes[1, 0].set_ylabel("Count") +axes[1, 0].set_title("Per-step reward distribution (stochastic)") +axes[1, 0].set_yscale("log") + +# Terminal/truncation timeline +axes[1, 1].plot(buf_stoch["terminals"].sum(axis=1), label="terminals", alpha=0.8) +axes[1, 1].plot(buf_stoch["truncations"].sum(axis=1), label="truncations", alpha=0.8) +axes[1, 1].set_xlabel("Step") +axes[1, 1].set_ylabel("Count") +axes[1, 1].set_title("Terminals/Truncations per step (stochastic)") +axes[1, 1].legend() +axes[1, 1].grid(True, alpha=0.3) + +# Tracked agent reward +axes[1, 2].plot(buf_stoch["rewards"][:, TRACKED_AGENT], label="stochastic", alpha=0.8) +axes[1, 2].plot(buf_det["rewards"][:, TRACKED_AGENT], label="deterministic", alpha=0.8) +axes[1, 2].set_xlabel("Step") +axes[1, 2].set_ylabel("Reward") +axes[1, 2].set_title(f"Reward over time | agent={TRACKED_AGENT}") +axes[1, 2].legend() +axes[1, 2].grid(True, alpha=0.3) + +plt.tight_layout() +plt.show() + +# %% [markdown] +# ## Episode metrics + +# %% +# Collect episode-level metrics from the C logging +log_stoch = binding.vec_log(env.c_envs, N) + +# eval_mode=1 returns a list of per-env dicts; aggregate by averaging +if isinstance(log_stoch, list) and log_stoch: + all_keys = set(k for d in log_stoch for k in d if isinstance(d[k], (int, float))) + log_stoch = {k: np.mean([d[k] for d in log_stoch if k in d]) for k in all_keys} + +if log_stoch: + print("Episode metrics (after stochastic rollout):") + for k, v in sorted(log_stoch.items()): + if isinstance(v, (int, float)): + print(f" {k}: {v:.4f}") + + # Bar chart of key metrics + keys = ["score", "collision_rate", "offroad_rate", "completion_rate", "dnf_rate"] + vals = [log_stoch.get(k, 0) for k in keys] + fig, ax = plt.subplots(figsize=(10, 4)) + bars = ax.bar(keys, vals, edgecolor="black", alpha=0.7, color=["green", "red", "orange", "blue", "gray"]) + ax.set_ylabel("Rate") + ax.set_title("Episode Metrics") + for bar, v in zip(bars, vals): + ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.01, f"{v:.3f}", ha="center", fontsize=10) + plt.tight_layout() + plt.show() +else: + print("No episode metrics available yet (not enough episodes completed)") + +# %% [markdown] +# ## Value predictions vs actual returns + +# %% +gamma = config["train"].get("gamma", 0.98) +lam = config["train"].get("gae_lambda", 0.95) + + +def compute_gae(rewards, values, terminals, truncations, gamma, lam): + H, N = rewards.shape + advantages = np.zeros_like(rewards) + last_gae = np.zeros(N) + for t in reversed(range(H - 1)): + done = np.maximum(terminals[t + 1], truncations[t + 1]) + next_non_terminal = 1.0 - done + delta = rewards[t + 1] + gamma * values[t + 1] * next_non_terminal - values[t] + last_gae = delta + gamma * lam * last_gae * next_non_terminal + advantages[t] = last_gae + return advantages + + +adv_stoch = compute_gae( + buf_stoch["rewards"], buf_stoch["values"], buf_stoch["terminals"], buf_stoch["truncations"], gamma, lam +) +returns_stoch = adv_stoch + buf_stoch["values"] + +pred_v = buf_stoch["values"].flatten() +actual_r = returns_stoch.flatten() + +var_actual = np.var(actual_r) +explained_var = 1 - np.var(actual_r - pred_v) / (var_actual + 1e-8) if var_actual > 1e-8 else 0.0 + +fig, axes = plt.subplots(1, 3, figsize=(18, 5)) + +# Scatter: predicted vs actual +axes[0].scatter(actual_r, pred_v, alpha=0.2, s=5) +lims = [min(actual_r.min(), pred_v.min()), max(actual_r.max(), pred_v.max())] +axes[0].plot(lims, lims, "r--", label="perfect") +axes[0].set_xlabel("Actual return") +axes[0].set_ylabel("Predicted value") +axes[0].set_title(f"Value accuracy (EV: {explained_var:.4f})") +axes[0].legend() +axes[0].grid(True, alpha=0.3) + +# Value error over time +value_error = np.abs(returns_stoch - buf_stoch["values"]).mean(axis=1) +axes[1].plot(value_error) +axes[1].set_xlabel("Step") +axes[1].set_ylabel("Mean |error|") +axes[1].set_title("Value prediction error over time") +axes[1].grid(True, alpha=0.3) + +# Advantage distribution +axes[2].hist(adv_stoch.flatten(), bins=50, edgecolor="black", alpha=0.7) +axes[2].set_xlabel("Advantage") +axes[2].set_ylabel("Count") +axes[2].set_title(f"Advantage distribution (std={adv_stoch.std():.4f})") + +plt.tight_layout() +plt.show() + +print(f"Explained variance: {explained_var:.4f}") +print(f"Value MSE: {np.mean((actual_r - pred_v) ** 2):.6f}") + +# %% [markdown] +# ## Agent trajectories + +# %% +N_TRAJ = min(16, N) # number of agents to plot + +fig, axes = plt.subplots(1, 2, figsize=(16, 7)) + +for i in range(N_TRAJ): + color = plt.cm.tab20(i % 20) + # Stochastic + axes[0].plot(buf_stoch["positions_x"][:, i], buf_stoch["positions_y"][:, i], alpha=0.6, color=color, linewidth=1) + axes[0].scatter( + buf_stoch["positions_x"][0, i], buf_stoch["positions_y"][0, i], color=color, s=30, marker="o", zorder=5 + ) # start + # Deterministic + axes[1].plot(buf_det["positions_x"][:, i], buf_det["positions_y"][:, i], alpha=0.6, color=color, linewidth=1) + axes[1].scatter(buf_det["positions_x"][0, i], buf_det["positions_y"][0, i], color=color, s=30, marker="o", zorder=5) + +axes[0].set_title(f"Stochastic trajectories (N={N_TRAJ})") +axes[1].set_title(f"Deterministic trajectories (N={N_TRAJ})") +for ax in axes: + ax.set_xlabel("X") + ax.set_ylabel("Y") + ax.set_aspect("equal") + ax.grid(True, alpha=0.3) + +plt.tight_layout() +plt.show() + +# ADE vs ground truth if scenario_length is set +if config["env"].get("scenario_length"): + try: + gt = env.get_ground_truth_trajectories() + # gt['x'] shape: (N, 1, T), positions shape: (T, N) + gt_x = gt["x"][:, 0, :].T # (T, N) + gt_y = gt["y"][:, 0, :].T + gt_valid = gt["valid"][:, 0, :].T + T_gt = gt_x.shape[0] + T_use = min(T_gt, HORIZON) + + disp = np.sqrt( + (buf_stoch["positions_x"][:T_use] - gt_x[:T_use]) ** 2 + + (buf_stoch["positions_y"][:T_use] - gt_y[:T_use]) ** 2 + ) + valid_mask = gt_valid[:T_use] > 0 + if valid_mask.sum() > 0: + ade = disp[valid_mask].mean() + print(f"ADE (stochastic vs ground truth): {ade:.3f}m") + ade_per_agent = np.array( + [disp[:, i][valid_mask[:, i]].mean() for i in range(N) if valid_mask[:, i].sum() > 0] + ) + plt.figure(figsize=(8, 3)) + plt.hist(ade_per_agent, bins=30, edgecolor="black", alpha=0.7) + plt.xlabel("ADE (m)") + plt.ylabel("Count") + plt.title(f"Per-agent ADE distribution (mean={ade:.3f}m)") + plt.tight_layout() + plt.show() + else: + print("No valid ground truth timesteps to compute ADE") + except Exception as e: + print(f"Could not compute ADE: {e}") + +# %% [markdown] +# ## Encoder analysis — what the policy encodes +# +# Each obs layer has its own encoder projecting raw features → embedding width: +# - **ego** and **conditioning** (reward coefs + target): single vector, no pooling. +# - **partners / lanes / boundaries / traffic**: per-slot encoder, padded slots masked to `-inf`, then **max-pooled** across slots → one embedding. Fully-padded layers are zeroed. +# +# The max-pool means each embedding dim is "won" by exactly one slot (object). Below we inspect: +# 1. Encoder inventory (in/out dims, params). +# 2. **What survives the max-pool**: which slot wins per dim, per-dim winner entropy (slot-specialized vs. spread), and where the dominant objects sit in ego frame. +# 3. **Embedding space**: per-encoder contribution (L2 norm), active/dead dims, silence rate. + +# %% +# ── Setup: capture per-encoder embeddings + reconstruct the max-pool ── +bb = policy.actor_backbone +ego_dim = policy.ego_dim +PAD = -1.0 # PADDED_OBSERVATION_VALUE + +# Flat batch of observations from the stochastic rollout +obs_flat = buf_stoch["obs"].reshape(-1, obs_dim) +rng = np.random.default_rng(0) +sel = rng.choice(obs_flat.shape[0], size=min(4096, obs_flat.shape[0]), replace=False) +obs_batch = torch.FloatTensor(obs_flat[sel]).to(device) +B = obs_batch.shape[0] + +# Encoder inventory: (name, module, raw_in_features, n_slots, is_set) +enc_inventory = [("ego", bb.ego_encoder, ego_dim, 1, False)] +if bb.obs_slots_lane_kept > 0: + enc_inventory.append(("lane", bb.lane_encoder, bb.road_features_count, bb.obs_slots_lane_kept, True)) +if bb.obs_slots_boundary_kept > 0: + enc_inventory.append(("boundary", bb.boundary_encoder, bb.road_features_count, bb.obs_slots_boundary_kept, True)) +if bb.obs_slots_partners_n > 0: + enc_inventory.append(("partner", bb.partner_encoder, bb.partner_features_count, bb.obs_slots_partners_n, True)) +if bb.obs_slots_traffic_controls_n > 0: + enc_inventory.append( + ( + "traffic", + bb.traffic_control_encoder, + bb.traffic_control_features_after_onehot, + bb.obs_slots_traffic_controls_n, + True, + ) + ) +if bb.target_dim > 0: + enc_inventory.append(("conditioning", bb.target_encoder, bb.target_dim, 1, False)) + +enc_names = [n for n, *_ in enc_inventory] +set_encs = [n for n, _, _, _, is_set in enc_inventory if is_set] + +print(f"{'encoder':>13s} | {'raw_in':>6s} | {'emb_out':>7s} | {'slots':>5s} | {'pooled':>6s} | {'params':>9s}") +print("-" * 66) +for name, mod, rin, nslots, is_set in enc_inventory: + nparam = sum(p.numel() for p in mod.parameters()) + print( + f"{name:>13s} | {rin:>6d} | {mod[-1].out_features:>7d} | {nslots:>5d} | {('max' if is_set else '-'):>6s} | {nparam:>9,d}" + ) +print( + f"\nBackbone input = {sum(mod[-1].out_features for _, mod, _, _, _ in enc_inventory)} -> backbone -> {bb.out_dim}" +) + +# Capture pre-pool encoder outputs via forward hooks +captured = {} + + +def _hook(name): + def fn(m, i, o): + captured[name] = o.detach() + + return fn + + +handles = [mod.register_forward_hook(_hook(name)) for name, mod, *_ in enc_inventory] +policy.eval() +with torch.no_grad(): + policy(obs_batch) +for h in handles: + h.remove() + +# Reconstruct slot slices (same order as DriveBackbone.forward) + pad masks +partner_dim = bb.obs_slots_partners_n * bb.partner_features_count +lane_dim = bb.obs_slots_lane_kept * bb.road_features_count +boundary_dim = bb.obs_slots_boundary_kept * bb.road_features_count +traffic_dim = bb.obs_slots_traffic_controls_n * bb.traffic_control_features_count +_s = ego_dim + bb.target_dim +sl = {} +sl["partner"] = (_s, _s + partner_dim, bb.obs_slots_partners_n, bb.partner_features_count) +_s += partner_dim +sl["lane"] = (_s, _s + lane_dim, bb.obs_slots_lane_kept, bb.road_features_count) +_s += lane_dim +sl["boundary"] = (_s, _s + boundary_dim, bb.obs_slots_boundary_kept, bb.road_features_count) +_s += boundary_dim +sl["traffic"] = (_s, _s + traffic_dim, bb.obs_slots_traffic_controls_n, bb.traffic_control_features_count) +_s += traffic_dim + +raw, pad, pooled, winners, valid_sample = {}, {}, {}, {}, {} +for name in set_encs: + s, e, ns, nf = sl[name] + obj = obs_batch[:, s:e].view(B, ns, nf) + raw[name] = obj + if name == "traffic": + cont = obj[:, :, : bb.traffic_control_continuous_features] + typ = obj[:, :, bb.traffic_control_continuous_features] + st = obj[:, :, bb.traffic_control_continuous_features + 1] + pad[name] = ( + (cont == PAD).all(dim=2) + & (typ == binding.TRAFFIC_CONTROL_TYPE_NONE) + & (st == binding.TRAFFIC_CONTROL_STATE_UNKNOWN) + ) + else: + pad[name] = (obj == PAD).all(dim=2) + masked = captured[name].masked_fill(pad[name].unsqueeze(2), -torch.inf) + vm = (~pad[name]).any(dim=1) + valid_sample[name] = vm + winners[name] = masked.max(dim=1).indices # (B, embedding dim): winning slot per dim + pooled[name] = torch.where(vm.unsqueeze(1), masked.max(dim=1).values, torch.zeros_like(masked.max(dim=1).values)) + +for name in ("ego", "conditioning"): + if name in enc_names: + pooled[name] = captured[name] + +print("\nCaptured embeddings for:", enc_names) + +# %% +# ── What survives the max-pool: winning slots, specialization, spatial ── +n = len(set_encs) +fig, axes = plt.subplots(n, 3, figsize=(18, 4.2 * n)) +if n == 1: + axes = axes[None, :] + +print(f"{'encoder':>9s} | {'valid%':>6s} | {'mean active slots/dim':>21s} | {'%slot-specialized dims':>22s}") +print("-" * 70) +for r, name in enumerate(set_encs): + s, e, ns, nf = sl[name] + vm = valid_sample[name] + w = winners[name][vm] # (Bv, D) + D = w.shape[1] + + # (1) which slot wins, pooled over all dims+samples + slot_counts = torch.bincount(w.reshape(-1), minlength=ns).float().cpu().numpy() + slot_counts = slot_counts / max(slot_counts.sum(), 1) + axes[r, 0].bar(range(ns), slot_counts, color="teal", alpha=0.85, edgecolor="black") + axes[r, 0].set_title(f"{name}: max-pool winner by slot") + axes[r, 0].set_xlabel("slot index (0 = first/closest)") + axes[r, 0].set_ylabel("frac of dims won") + + # (2) per-dim winner entropy: slot-specialized (0) vs spread across slots (1) + onehot = F.one_hot(w, num_classes=ns).float() # (Bv, D, ns) + p = onehot.mean(dim=0) # (D, ns) winner distribution per dim + ent = (-(p * (p + 1e-9).log()).sum(dim=1) / np.log(ns)).cpu().numpy() + axes[r, 1].hist(ent, bins=30, color="indianred", alpha=0.85, edgecolor="black") + axes[r, 1].set_title(f"{name}: per-dim winner entropy") + axes[r, 1].set_xlabel("0 = slot-specialized → 1 = spread") + axes[r, 1].set_xlim(0, 1) + + # (3) ego-frame position of the dominant object (mode winning slot per sample) + dom = torch.mode(w, dim=1).values # (Bv,) + rel = raw[name][vm] + dom_xy = rel[torch.arange(rel.shape[0]), dom][:, :2].cpu().numpy() + axes[r, 2].scatter(dom_xy[:, 0], dom_xy[:, 1], s=3, alpha=0.15, color="navy") + axes[r, 2].scatter(0, 0, marker="*", s=200, color="red", zorder=5, label="ego") + axes[r, 2].set_title(f"{name}: dominant object position (ego frame)") + axes[r, 2].set_xlabel("rel_x") + axes[r, 2].set_ylabel("rel_y") + axes[r, 2].set_aspect("equal") + axes[r, 2].legend(fontsize=8) + + active_per_dim = np.exp(ent * np.log(ns)).mean() + print( + f"{name:>9s} | {100 * vm.float().mean().item():>5.1f}% | {active_per_dim:>21.2f} | {100 * (ent < 0.2).mean():>21.1f}%" + ) + +plt.tight_layout() +plt.show() + + +# ── H1/H2/H3 check: boundary max-pool winner distance vs slot index ── +if "boundary" in set_encs: + vm = valid_sample["boundary"] + w = winners["boundary"][vm] # (Bv, D) winning slot per dim + rb = raw["boundary"][vm] # (Bv, ns, nf) raw segments + nsb = rb.shape[1] + reldist = torch.hypot(rb[:, :, 0], rb[:, :, 1]) # (Bv, ns) normalized ego-frame dist + valid_seg = ~pad["boundary"][vm] # (Bv, ns) slots holding a real segment + + win_reldist = torch.gather(reldist, 1, w) # (Bv, D) dist of each winning segment + slot_flat = w.reshape(-1) + wdist_flat = win_reldist.reshape(-1) + total_wins = slot_flat.numel() + + print("\n=== Boundary winner distance vs slot index (H1/H2/H3) ===") + print(f"{'slot':>4s} | {'#wins':>8s} | {'win%':>6s} | {'rel_dist winners':>16s} | {'rel_dist occupied':>17s}") + print("-" * 64) + for s in range(nsb): + wm = slot_flat == s + nwin = int(wm.sum()) + wmean = wdist_flat[wm].mean().item() if nwin > 0 else float("nan") + occ = valid_seg[:, s] + omean = reldist[occ, s].mean().item() if int(occ.sum()) > 0 else float("nan") + print(f"{s:>4d} | {nwin:>8d} | {100 * nwin / total_wins:>5.1f}% | {wmean:>16.4f} | {omean:>17.4f}") + + win_mean = wdist_flat.mean().item() + seg_mean = reldist[valid_seg].mean().item() + print(f"\nMean rel_dist of WINNING segments : {win_mean:.4f}") + print(f"Mean rel_dist of ALL valid segments: {seg_mean:.4f}") + verdict = "FARTHER than avg (H3 supported)" if win_mean > seg_mean else "nearer than avg" + print(f"-> winners are {verdict} by {win_mean - seg_mean:+.4f} (normalized units)") + + fig, ax = plt.subplots(1, 2, figsize=(14, 4)) + occ_means = [ + reldist[valid_seg[:, s], s].mean().item() if int(valid_seg[:, s].sum()) > 0 else np.nan for s in range(nsb) + ] + win_means = [ + wdist_flat[slot_flat == s].mean().item() if int((slot_flat == s).sum()) > 0 else np.nan for s in range(nsb) + ] + ax[0].plot(range(nsb), occ_means, "o-", label="occupied (any segment in slot)") + ax[0].plot(range(nsb), win_means, "s-", label="winners only") + ax[0].axhline(seg_mean, color="gray", ls="--", label="global valid mean") + ax[0].set_xlabel("slot index") + ax[0].set_ylabel("mean rel_dist (normalized)") + ax[0].set_title("Boundary rel_dist vs slot index\n(flat = not distance-sorted -> H1/H2)") + ax[0].legend(fontsize=8) + ax[1].hist(reldist[valid_seg].cpu().numpy(), bins=50, alpha=0.6, density=True, label="all valid segs", color="gray") + ax[1].hist(wdist_flat.cpu().numpy(), bins=50, alpha=0.6, density=True, label="winners", color="crimson") + ax[1].axvline(seg_mean, color="gray", ls="--") + ax[1].axvline(win_mean, color="crimson", ls="--") + ax[1].set_xlabel("rel_dist (normalized)") + ax[1].set_title("Winner vs all-segment distance (H3)") + ax[1].legend(fontsize=8) + plt.tight_layout() + plt.show() + +# %% +# ── Embedding space: per-encoder contribution, active/dead dims, silence ── +fig, axes = plt.subplots(1, 3, figsize=(20, 5)) + +# (1) Mean L2 norm of each pooled embedding = relative weight in the concat fed to backbone +norms = [pooled[n].norm(dim=1).mean().item() for n in enc_names] +axes[0].bar(enc_names, norms, color="slateblue", edgecolor="black") +axes[0].set_title("Mean L2 norm of pooled embedding\n(relative contribution to backbone input)") +axes[0].tick_params(axis="x", rotation=45) +axes[0].grid(True, axis="y", alpha=0.3) + +# (2) Mean |activation| per embedding dim, per encoder +M = np.stack([pooled[n].abs().mean(0).cpu().numpy() for n in enc_names]) +im = axes[1].imshow(M, aspect="auto", cmap="magma") +axes[1].set_yticks(range(len(enc_names))) +axes[1].set_yticklabels(enc_names) +axes[1].set_xlabel("embedding dim") +axes[1].set_title("Mean |activation| per embedding dim") +plt.colorbar(im, ax=axes[1]) + +# (3) Dead dims (std<1e-4) — capacity the encoder never uses +dead = [(pooled[n].std(0) < 1e-4).float().mean().item() for n in enc_names] +axes[2].bar(enc_names, dead, color="gray", edgecolor="black") +axes[2].set_title("Fraction of dead embedding dims (std < 1e-4)") +axes[2].tick_params(axis="x", rotation=45) +axes[2].set_ylim(0, 1) +axes[2].grid(True, axis="y", alpha=0.3) + +plt.tight_layout() +plt.show() + +print(f"{'encoder':>13s} | {'mean|act|':>9s} | {'emb L2':>7s} | {'dead dims':>9s} | {'silence (fully padded)':>22s}") +print("-" * 80) +for name in enc_names: + silence = (1 - valid_sample[name].float().mean().item()) if name in valid_sample else 0.0 + deadf = (pooled[name].std(0) < 1e-4).float().mean().item() + print( + f"{name:>13s} | {pooled[name].abs().mean().item():>9.4f} | {pooled[name].norm(dim=1).mean().item():>7.3f} | " + f"{100 * deadf:>7.1f}% | {100 * silence:>21.1f}%" + ) diff --git a/notebooks/06_architecture.ipynb b/notebooks/06_architecture.ipynb deleted file mode 100644 index 7a799a2abf..0000000000 --- a/notebooks/06_architecture.ipynb +++ /dev/null @@ -1,813 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# 06 - Neural Network Architecture\n", - "Visualize, analyze, and iterate on the DrivePolicy architecture. Covers model summary, per-encoder breakdown, forward pass shape tracing, weight distributions, and architecture comparison." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import numpy as np\n", - "import matplotlib.pyplot as plt\n", - "import torch\n", - "import torch.nn.functional as F\n", - "from torchinfo import summary\n", - "from pufferlib.ocean.drive import binding\n", - "from pufferlib.ocean.torch import Drive as DrivePolicy\n", - "from notebooks.notebook_utils import make_drive_env, zero_actions\n", - "\n", - "# --- Policy architecture ---\n", - "INPUT_SIZE = 64\n", - "BACKBONE_HIDDEN_SIZE = 1024\n", - "BACKBONE_NUM_LAYERS = 3\n", - "ACTOR_HIDDEN_SIZE = 128\n", - "ACTOR_NUM_LAYERS = 3\n", - "CRITIC_HIDDEN_SIZE = 64\n", - "CRITIC_NUM_LAYERS = 2\n", - "SHARED_NETWORK = True\n", - "ENCODER_ACTIVATION = \"tanh\"\n", - "ENCODER_LAYER_NORM = True\n", - "BACKBONE_ACTIVATION = \"gelu\"\n", - "BACKBONE_LAYER_NORM = False\n", - "\n", - "env, obs, info = make_drive_env()\n", - "\n", - "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", - "policy = DrivePolicy(\n", - " env,\n", - " ego_input_size=INPUT_SIZE,\n", - " partner_input_size=INPUT_SIZE,\n", - " lane_input_size=INPUT_SIZE,\n", - " boundary_input_size=INPUT_SIZE,\n", - " traffic_control_input_size=INPUT_SIZE,\n", - " target_input_size=INPUT_SIZE,\n", - " backbone_hidden_size=BACKBONE_HIDDEN_SIZE,\n", - " backbone_num_layers=BACKBONE_NUM_LAYERS,\n", - " actor_hidden_size=ACTOR_HIDDEN_SIZE,\n", - " actor_num_layers=ACTOR_NUM_LAYERS,\n", - " critic_hidden_size=CRITIC_HIDDEN_SIZE,\n", - " critic_num_layers=CRITIC_NUM_LAYERS,\n", - " encoder_activation=ENCODER_ACTIVATION,\n", - " encoder_layer_norm=ENCODER_LAYER_NORM,\n", - " backbone_activation=BACKBONE_ACTIVATION,\n", - " backbone_layer_norm=BACKBONE_LAYER_NORM,\n", - " shared_network=SHARED_NETWORK,\n", - ").to(device)\n", - "\n", - "print(f\"Device: {device}\")\n", - "print(f\"Obs dim: {obs.shape[1]}\")\n", - "print(f\"Action dim: {policy.atn_dim}\")\n", - "print(f\"Shared network: {SHARED_NETWORK}\")\n", - "print(f\"Backbone: {BACKBONE_HIDDEN_SIZE} x {BACKBONE_NUM_LAYERS}L\")\n", - "print(f\"Actor: {ACTOR_HIDDEN_SIZE} x {ACTOR_NUM_LAYERS}L\")\n", - "print(f\"Critic: {CRITIC_HIDDEN_SIZE} x {CRITIC_NUM_LAYERS}L\")\n", - "print(f\"Encoder: {ENCODER_ACTIVATION}, LayerNorm: {ENCODER_LAYER_NORM}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Model Summary (torchinfo)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "obs_tensor = torch.FloatTensor(obs).to(device)\n", - "summary(policy, input_data=obs_tensor, depth=4, col_names=[\"input_size\", \"output_size\", \"num_params\", \"mult_adds\"])" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Architecture Diagram" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "backbone = policy.actor_backbone\n", - "cond_dim = backbone.target_dim\n", - "\n", - "# Collect encoder info\n", - "encoders = [\n", - " (\"ego\", env.ego_features, 1, \"direct\", INPUT_SIZE),\n", - " (\"conditioning\", cond_dim, 1, \"direct\", INPUT_SIZE) if cond_dim > 0 else None,\n", - " (\"partner\", env.partner_features, env.obs_slots_partners_n, \"max-pool\", INPUT_SIZE),\n", - " (\"lane\", env.road_features, env.obs_slots_lane_kept, \"max-pool\", INPUT_SIZE),\n", - " (\"boundary\", env.road_features, env.obs_slots_boundary_kept, \"max-pool\", INPUT_SIZE),\n", - " (\n", - " \"traffic_ctrl\",\n", - " env.traffic_control_features - 2 + binding.NUM_TRAFFIC_CONTROL_TYPES + binding.NUM_TRAFFIC_CONTROL_STATES,\n", - " env.obs_slots_traffic_controls_n,\n", - " \"max-pool (onehot)\",\n", - " INPUT_SIZE,\n", - " ),\n", - "]\n", - "encoders = [e for e in encoders if e is not None]\n", - "\n", - "fig, ax = plt.subplots(figsize=(14, 8))\n", - "ax.set_xlim(0, 10)\n", - "ax.set_ylim(0, 10)\n", - "ax.axis(\"off\")\n", - "\n", - "n_enc = len(encoders)\n", - "y_positions = np.linspace(9, 1, n_enc)\n", - "colors = plt.cm.Set2(np.linspace(0, 1, n_enc))\n", - "\n", - "# Draw encoders\n", - "for i, ((name, in_f, n_obj, agg, out_size), y, c) in enumerate(zip(encoders, y_positions, colors)):\n", - " # Input box\n", - " label = f\"{name}\\n{n_obj}x{in_f}\" if n_obj > 1 else f\"{name}\\n{in_f}\"\n", - " ax.add_patch(plt.Rectangle((0.2, y - 0.3), 1.6, 0.6, facecolor=c, edgecolor=\"black\", lw=1.2, alpha=0.8))\n", - " ax.text(1.0, y, label, ha=\"center\", va=\"center\", fontsize=8, fontweight=\"bold\")\n", - "\n", - " # Encoder box\n", - " ax.add_patch(plt.Rectangle((2.5, y - 0.25), 2.0, 0.5, facecolor=\"lightyellow\", edgecolor=\"black\", lw=1))\n", - " ax.text(3.5, y + 0.05, f\"Linear({in_f},{out_size})\", ha=\"center\", va=\"center\", fontsize=7)\n", - " ln_label = \"LN+\" if ENCODER_LAYER_NORM else \"\"\n", - " ax.text(\n", - " 3.5,\n", - " y - 0.12,\n", - " f\"{ln_label}{ENCODER_ACTIVATION}+Linear({out_size},{out_size})\",\n", - " ha=\"center\",\n", - " va=\"center\",\n", - " fontsize=6,\n", - " color=\"gray\",\n", - " )\n", - "\n", - " # Aggregation\n", - " if n_obj > 1:\n", - " ax.text(5.0, y, agg, ha=\"center\", va=\"center\", fontsize=7, style=\"italic\", color=\"darkblue\")\n", - " arrow_start = 5.5\n", - " else:\n", - " arrow_start = 4.6\n", - "\n", - " # Arrows\n", - " ax.annotate(\"\", xy=(2.5, y), xytext=(1.8, y), arrowprops=dict(arrowstyle=\"->\", lw=1))\n", - " ax.annotate(\"\", xy=(6.0, 5.0), xytext=(arrow_start, y), arrowprops=dict(arrowstyle=\"->\", lw=0.8, color=\"gray\"))\n", - "\n", - "# Concat box\n", - "ax.add_patch(plt.Rectangle((5.8, 4.5), 1.4, 1.0, facecolor=\"lightsalmon\", edgecolor=\"black\", lw=1.5))\n", - "ax.text(6.5, 5.2, \"Concat\", ha=\"center\", va=\"center\", fontsize=9, fontweight=\"bold\")\n", - "ax.text(6.5, 4.85, f\"{n_enc}x{INPUT_SIZE}={n_enc * INPUT_SIZE}\", ha=\"center\", va=\"center\", fontsize=7)\n", - "\n", - "# Backbone\n", - "ax.add_patch(plt.Rectangle((7.5, 4.5), 1.3, 1.0, facecolor=\"lightblue\", edgecolor=\"black\", lw=1.5))\n", - "ax.text(8.15, 5.15, f\"Backbone ({BACKBONE_NUM_LAYERS}L)\", ha=\"center\", va=\"center\", fontsize=8, fontweight=\"bold\")\n", - "ax.text(8.15, 4.85, f\"GELU+Linear\\n({n_enc * INPUT_SIZE},{BACKBONE_HIDDEN_SIZE})\", ha=\"center\", va=\"center\", fontsize=6)\n", - "ax.annotate(\"\", xy=(7.5, 5.0), xytext=(7.2, 5.0), arrowprops=dict(arrowstyle=\"->\", lw=1.5))\n", - "\n", - "# Actor / Critic heads\n", - "ax.add_patch(plt.Rectangle((9.0, 5.7), 0.9, 0.6, facecolor=\"lightgreen\", edgecolor=\"black\", lw=1.2))\n", - "actor_label = f\"Actor ({ACTOR_NUM_LAYERS}L)\\n{BACKBONE_HIDDEN_SIZE}->{sum(policy.atn_dim)}\"\n", - "if ACTOR_NUM_LAYERS > 1:\n", - " actor_label = (\n", - " f\"Actor ({ACTOR_NUM_LAYERS}L)\\n{BACKBONE_HIDDEN_SIZE}->{ACTOR_HIDDEN_SIZE}->...->{sum(policy.atn_dim)}\"\n", - " )\n", - "ax.text(9.45, 6.0, actor_label, ha=\"center\", va=\"center\", fontsize=6, fontweight=\"bold\")\n", - "\n", - "ax.add_patch(plt.Rectangle((9.0, 3.7), 0.9, 0.6, facecolor=\"plum\", edgecolor=\"black\", lw=1.2))\n", - "critic_label = f\"Critic ({CRITIC_NUM_LAYERS}L)\\n{BACKBONE_HIDDEN_SIZE}->1\"\n", - "if CRITIC_NUM_LAYERS > 1:\n", - " critic_label = f\"Critic ({CRITIC_NUM_LAYERS}L)\\n{BACKBONE_HIDDEN_SIZE}->{CRITIC_HIDDEN_SIZE}->...->1\"\n", - "ax.text(9.45, 4.0, critic_label, ha=\"center\", va=\"center\", fontsize=6, fontweight=\"bold\")\n", - "\n", - "ax.annotate(\"\", xy=(9.0, 6.0), xytext=(8.8, 5.3), arrowprops=dict(arrowstyle=\"->\", lw=1.2))\n", - "ax.annotate(\"\", xy=(9.0, 4.0), xytext=(8.8, 4.7), arrowprops=dict(arrowstyle=\"->\", lw=1.2))\n", - "\n", - "split_label = \"SHARED\" if SHARED_NETWORK else \"SPLIT\"\n", - "ax.text(8.9, 4.55, split_label, ha=\"center\", va=\"center\", fontsize=7, color=\"red\", fontweight=\"bold\")\n", - "\n", - "ax.text(\n", - " 5.0,\n", - " 0.3,\n", - " f\"Encoder: {ENCODER_ACTIVATION} | LayerNorm: {ENCODER_LAYER_NORM}\",\n", - " ha=\"center\",\n", - " va=\"center\",\n", - " fontsize=8,\n", - " color=\"darkgreen\",\n", - " fontweight=\"bold\",\n", - ")\n", - "\n", - "ax.set_title(\n", - " f\"DrivePolicy Architecture (encoder_size={INPUT_SIZE}, backbone={BACKBONE_HIDDEN_SIZE})\",\n", - " fontsize=12,\n", - " fontweight=\"bold\",\n", - ")\n", - "plt.tight_layout()\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Per-Encoder Parameter Breakdown" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def count_params(module):\n", - " return sum(p.numel() for p in module.parameters())\n", - "\n", - "\n", - "backbone = policy.actor_backbone\n", - "components = {\n", - " \"ego_encoder\": backbone.ego_encoder,\n", - " \"lane_encoder\": backbone.lane_encoder,\n", - " \"boundary_encoder\": backbone.boundary_encoder,\n", - " \"partner_encoder\": backbone.partner_encoder,\n", - " \"traffic_ctrl_encoder\": backbone.traffic_control_encoder,\n", - "}\n", - "if backbone.target_dim > 0:\n", - " components[\"target_encoder\"] = backbone.target_encoder\n", - "components[\"backbone_mlp\"] = backbone.backbone\n", - "components[\"actor_head\"] = policy.actor_head\n", - "components[\"critic_head\"] = policy.critic_head\n", - "\n", - "names, counts = zip(*[(k, count_params(v)) for k, v in components.items()])\n", - "total = sum(counts)\n", - "\n", - "print(f\"{'Component':>25s} | {'Params':>10s} | {'%':>6s}\")\n", - "print(\"-\" * 48)\n", - "for n, c in zip(names, counts):\n", - " print(f\"{n:>25s} | {c:>10,d} | {c / total:>5.1%}\")\n", - "print(\"-\" * 48)\n", - "print(f\"{'TOTAL':>25s} | {total:>10,d}\")\n", - "if not SHARED_NETWORK:\n", - " critic_bb = count_params(policy.critic_backbone)\n", - " print(f\"{'+ critic_backbone':>25s} | {critic_bb:>10,d}\")\n", - " print(f\"{'GRAND TOTAL':>25s} | {total + critic_bb:>10,d}\")\n", - "\n", - "fig, ax = plt.subplots(figsize=(8, 5))\n", - "colors = plt.cm.Set3(np.linspace(0, 1, len(names)))\n", - "bars = ax.barh(names, counts, color=colors, edgecolor=\"black\")\n", - "for bar, c in zip(bars, counts):\n", - " ax.text(bar.get_width() + total * 0.01, bar.get_y() + bar.get_height() / 2, f\"{c:,}\", va=\"center\", fontsize=8)\n", - "ax.set_xlabel(\"Parameters\")\n", - "ax.set_title(f\"Parameter Distribution ({total:,} total)\")\n", - "plt.tight_layout()\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Forward Pass Shape Trace" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "x = obs_tensor\n", - "backbone = policy.actor_backbone\n", - "\n", - "slide_idx = env.ego_features\n", - "cond_dim = backbone.target_dim\n", - "partner_dim = env.obs_slots_partners_n * env.partner_features\n", - "lane_dim = env.obs_slots_lane_kept * env.road_features\n", - "boundary_dim = env.obs_slots_boundary_kept * env.road_features\n", - "traffic_dim = env.obs_slots_traffic_controls_n * env.traffic_control_features\n", - "\n", - "# Slicing\n", - "ego_obs = x[:, :slide_idx]\n", - "slices = [(\"ego\", 0, slide_idx, ego_obs.shape)]\n", - "\n", - "if cond_dim > 0:\n", - " cond_obs = x[:, slide_idx : slide_idx + cond_dim]\n", - " slices.append((\"conditioning\", slide_idx, slide_idx + cond_dim, cond_obs.shape))\n", - " slide_idx += cond_dim\n", - "\n", - "partner_obs = x[:, slide_idx : slide_idx + partner_dim]\n", - "slices.append((\"partners\", slide_idx, slide_idx + partner_dim, partner_obs.shape))\n", - "slide_idx += partner_dim\n", - "\n", - "lane_obs = x[:, slide_idx : slide_idx + lane_dim]\n", - "slices.append((\"lanes\", slide_idx, slide_idx + lane_dim, lane_obs.shape))\n", - "slide_idx += lane_dim\n", - "\n", - "boundary_obs = x[:, slide_idx : slide_idx + boundary_dim]\n", - "slices.append((\"boundaries\", slide_idx, slide_idx + boundary_dim, boundary_obs.shape))\n", - "slide_idx += boundary_dim\n", - "\n", - "traffic_obs = x[:, slide_idx : slide_idx + traffic_dim]\n", - "slices.append((\"traffic_ctrl\", slide_idx, slide_idx + traffic_dim, traffic_obs.shape))\n", - "\n", - "print(f\"Obs buffer layout (total={x.shape[1]}):\")\n", - "print(f\"{'Name':>15s} | {'Start':>5s} | {'End':>5s} | {'Width':>5s} | Shape\")\n", - "print(\"-\" * 65)\n", - "for name, start, end, shape in slices:\n", - " print(f\"{name:>15s} | {start:>5d} | {end:>5d} | {end - start:>5d} | {shape}\")\n", - "\n", - "# Forward through encoders\n", - "print(\"\\nEncoder outputs:\")\n", - "with torch.no_grad():\n", - " ego_enc = backbone.ego_encoder(ego_obs)\n", - " print(f\" ego_encoder: {ego_obs.shape} -> {ego_enc.shape}\")\n", - "\n", - " if cond_dim > 0:\n", - " cond_enc = backbone.target_encoder(cond_obs)\n", - " print(f\" cond_encoder: {cond_obs.shape} -> {cond_enc.shape}\")\n", - "\n", - " p_reshaped = partner_obs.view(-1, env.obs_slots_partners_n, env.partner_features)\n", - " p_enc, _ = backbone.partner_encoder(p_reshaped).max(dim=1)\n", - " print(f\" partner_encoder: {partner_obs.shape} -> view {p_reshaped.shape} -> encode -> max-pool -> {p_enc.shape}\")\n", - "\n", - " l_reshaped = lane_obs.view(-1, env.obs_slots_lane_kept, env.road_features)\n", - " l_enc, _ = backbone.lane_encoder(l_reshaped).max(dim=1)\n", - " print(f\" lane_encoder: {lane_obs.shape} -> view {l_reshaped.shape} -> encode -> max-pool -> {l_enc.shape}\")\n", - "\n", - " b_reshaped = boundary_obs.view(-1, env.obs_slots_boundary_kept, env.road_features)\n", - " b_enc, _ = backbone.boundary_encoder(b_reshaped).max(dim=1)\n", - " print(f\" bound_encoder: {boundary_obs.shape} -> view {b_reshaped.shape} -> encode -> max-pool -> {b_enc.shape}\")\n", - "\n", - " t_reshaped = traffic_obs.view(-1, env.obs_slots_traffic_controls_n, env.traffic_control_features)\n", - " t_cont = t_reshaped[:, :, : env.traffic_control_features - 2]\n", - " t_type = t_reshaped[:, :, env.traffic_control_features - 2]\n", - " t_state = t_reshaped[:, :, env.traffic_control_features - 1]\n", - " t_type_onehot = F.one_hot(t_type.long(), num_classes=binding.NUM_TRAFFIC_CONTROL_TYPES).float()\n", - " t_state_onehot = F.one_hot(t_state.long(), num_classes=binding.NUM_TRAFFIC_CONTROL_STATES).float()\n", - " t_input = torch.cat([t_cont, t_type_onehot, t_state_onehot], dim=2)\n", - " t_enc, _ = backbone.traffic_control_encoder(t_input).max(dim=1)\n", - " print(\n", - " f\" traffic_encoder: {traffic_obs.shape} -> view {t_reshaped.shape} -> onehot {t_input.shape} -> encode -> max-pool -> {t_enc.shape}\"\n", - " )\n", - "\n", - " # Concat + backbone\n", - " features = [ego_enc, l_enc, b_enc, p_enc, t_enc]\n", - " if cond_dim > 0:\n", - " features.append(cond_enc)\n", - " concat = torch.cat(features, dim=1)\n", - " hidden = backbone.backbone(concat)\n", - " print(f\"\\n concat: {concat.shape}\")\n", - " print(f\" backbone_mlp: {concat.shape} -> {hidden.shape}\")\n", - "\n", - " # Heads\n", - " actor_out = policy.actor_head(hidden)\n", - " critic_out = policy.critic_head(hidden)\n", - " print(f\" actor_head: {hidden.shape} -> {actor_out.shape} (split into {policy.atn_dim})\")\n", - " print(f\" critic_head: {hidden.shape} -> {critic_out.shape}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Weight Distributions by Layer" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "weight_data = [\n", - " (n, p.data.cpu().numpy().flatten()) for n, p in policy.named_parameters() if \"weight\" in n and p.dim() >= 2\n", - "]\n", - "\n", - "n_weights = len(weight_data)\n", - "cols = 4\n", - "rows = (n_weights + cols - 1) // cols\n", - "fig, axes = plt.subplots(rows, cols, figsize=(4 * cols, 3 * rows))\n", - "axes = axes.flatten()\n", - "\n", - "for i, (name, w) in enumerate(weight_data):\n", - " ax = axes[i]\n", - " ax.hist(w, bins=50, edgecolor=\"black\", alpha=0.7, density=True)\n", - " ax.set_title(name.replace(\"actor_backbone.\", \"\"), fontsize=7)\n", - " ax.axvline(0, color=\"red\", ls=\"--\", lw=0.5)\n", - " ax.text(0.95, 0.95, f\"std={w.std():.3f}\", transform=ax.transAxes, fontsize=6, ha=\"right\", va=\"top\")\n", - "\n", - "for j in range(i + 1, len(axes)):\n", - " axes[j].axis(\"off\")\n", - "\n", - "fig.suptitle(\"Weight Distributions (init)\", fontsize=12, fontweight=\"bold\")\n", - "plt.tight_layout()\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Activation Analysis (per encoder)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "policy.eval()\n", - "with torch.no_grad():\n", - " hidden = policy.actor_backbone(obs_tensor, env.ego_features)\n", - " action_logits, value = policy.decode_actions(hidden)\n", - "\n", - "# Collect per-encoder activations\n", - "activations = {}\n", - "with torch.no_grad():\n", - " slide = env.ego_features\n", - " activations[\"ego\"] = backbone.ego_encoder(obs_tensor[:, : env.ego_features])\n", - "\n", - " if cond_dim > 0:\n", - " activations[\"conditioning\"] = backbone.target_encoder(obs_tensor[:, slide : slide + cond_dim])\n", - " slide += cond_dim\n", - "\n", - " p_obs = obs_tensor[:, slide : slide + partner_dim].view(-1, env.obs_slots_partners_n, env.partner_features)\n", - " activations[\"partner\"], _ = backbone.partner_encoder(p_obs).max(dim=1)\n", - " slide += partner_dim\n", - "\n", - " l_obs = obs_tensor[:, slide : slide + lane_dim].view(-1, env.obs_slots_lane_kept, env.road_features)\n", - " activations[\"lane\"], _ = backbone.lane_encoder(l_obs).max(dim=1)\n", - " slide += lane_dim\n", - "\n", - " b_obs = obs_tensor[:, slide : slide + boundary_dim].view(-1, env.obs_slots_boundary_kept, env.road_features)\n", - " activations[\"boundary\"], _ = backbone.boundary_encoder(b_obs).max(dim=1)\n", - " slide += boundary_dim\n", - "\n", - " t_obs = obs_tensor[:, slide : slide + traffic_dim].view(\n", - " -1, env.obs_slots_traffic_controls_n, env.traffic_control_features\n", - " )\n", - " t_cont = t_obs[:, :, : env.traffic_control_features - 2]\n", - " t_type = t_obs[:, :, env.traffic_control_features - 2]\n", - " t_state = t_obs[:, :, env.traffic_control_features - 1]\n", - " t_type_onehot = F.one_hot(t_type.long(), num_classes=binding.NUM_TRAFFIC_CONTROL_TYPES).float()\n", - " t_state_onehot = F.one_hot(t_state.long(), num_classes=binding.NUM_TRAFFIC_CONTROL_STATES).float()\n", - " t_input = torch.cat([t_cont, t_type_onehot, t_state_onehot], dim=2)\n", - " activations[\"traffic_ctrl\"], _ = backbone.traffic_control_encoder(t_input).max(dim=1)\n", - "\n", - " activations[\"hidden\"] = hidden\n", - "\n", - "fig, axes = plt.subplots(2, 4, figsize=(16, 6))\n", - "axes = axes.flatten()\n", - "for i, (name, act) in enumerate(activations.items()):\n", - " if i >= len(axes):\n", - " break\n", - " vals = act.cpu().numpy().flatten()\n", - " ax = axes[i]\n", - " ax.hist(vals, bins=50, edgecolor=\"black\", alpha=0.7)\n", - " dead = (act.abs().sum(dim=0) == 0).sum().item()\n", - " ax.set_title(f\"{name} (dead={dead}/{act.shape[1]})\", fontsize=9)\n", - " ax.text(\n", - " 0.95,\n", - " 0.95,\n", - " f\"mean={vals.mean():.3f}\\nstd={vals.std():.3f}\",\n", - " transform=ax.transAxes,\n", - " fontsize=7,\n", - " ha=\"right\",\n", - " va=\"top\",\n", - " )\n", - "\n", - "for j in range(i + 1, len(axes)):\n", - " axes[j].axis(\"off\")\n", - "\n", - "fig.suptitle(\"Per-Encoder Activation Distributions\", fontsize=12, fontweight=\"bold\")\n", - "plt.tight_layout()\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Encoder Embedding Similarity (cosine)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Mean embedding per encoder (exclude hidden — different dim)\n", - "emb_names = [k for k in activations.keys() if k != \"hidden\"]\n", - "emb_means = torch.stack([activations[k].mean(dim=0) for k in emb_names])\n", - "emb_norm = F.normalize(emb_means, dim=1)\n", - "sim_matrix = (emb_norm @ emb_norm.T).cpu().numpy()\n", - "\n", - "fig, ax = plt.subplots(figsize=(7, 6))\n", - "im = ax.imshow(sim_matrix, cmap=\"RdBu_r\", vmin=-1, vmax=1)\n", - "ax.set_xticks(range(len(emb_names)))\n", - "ax.set_yticks(range(len(emb_names)))\n", - "ax.set_xticklabels(emb_names, rotation=45, ha=\"right\", fontsize=8)\n", - "ax.set_yticklabels(emb_names, fontsize=8)\n", - "for i in range(len(emb_names)):\n", - " for j in range(len(emb_names)):\n", - " ax.text(j, i, f\"{sim_matrix[i, j]:.2f}\", ha=\"center\", va=\"center\", fontsize=7)\n", - "fig.colorbar(im, ax=ax)\n", - "ax.set_title(\"Cosine Similarity Between Encoder Mean Embeddings\")\n", - "plt.tight_layout()\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Architecture Comparison\n", - "Compare different architecture configs side-by-side without training." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "configs = [\n", - " {\"name\": \"tiny\", \"encoder_size\": 32, \"backbone_hidden_size\": 64},\n", - " {\"name\": \"small\", \"encoder_size\": 64, \"backbone_hidden_size\": 128},\n", - " {\"name\": \"medium\", \"encoder_size\": 128, \"backbone_hidden_size\": 256, \"backbone_num_layers\": 2},\n", - " {\n", - " \"name\": \"large\",\n", - " \"encoder_size\": 128,\n", - " \"backbone_hidden_size\": 512,\n", - " \"backbone_num_layers\": 2,\n", - " \"actor_num_layers\": 2,\n", - " \"actor_hidden_size\": 256,\n", - " \"critic_num_layers\": 2,\n", - " \"critic_hidden_size\": 256,\n", - " },\n", - " {\n", - " \"name\": \"xlarge\",\n", - " \"encoder_size\": 256,\n", - " \"backbone_hidden_size\": 1024,\n", - " \"backbone_num_layers\": 3,\n", - " \"actor_num_layers\": 2,\n", - " \"actor_hidden_size\": 512,\n", - " \"critic_num_layers\": 2,\n", - " \"critic_hidden_size\": 512,\n", - " },\n", - " {\"name\": \"small+tanh\", \"encoder_size\": 64, \"backbone_hidden_size\": 128, \"encoder_activation\": \"tanh\"},\n", - " {\n", - " \"name\": \"medium+tanh\",\n", - " \"encoder_size\": 128,\n", - " \"backbone_hidden_size\": 256,\n", - " \"backbone_num_layers\": 2,\n", - " \"encoder_activation\": \"tanh\",\n", - " },\n", - "]\n", - "\n", - "POLICY_DEFAULTS = {\n", - " \"ego_input_size\": 64,\n", - " \"partner_input_size\": 64,\n", - " \"lane_input_size\": 64,\n", - " \"boundary_input_size\": 64,\n", - " \"traffic_control_input_size\": 64,\n", - " \"target_input_size\": 64,\n", - " \"backbone_num_layers\": 1,\n", - " \"actor_hidden_size\": 128,\n", - " \"actor_num_layers\": 0,\n", - " \"critic_hidden_size\": 128,\n", - " \"critic_num_layers\": 0,\n", - " \"encoder_activation\": \"relu\",\n", - " \"encoder_layer_norm\": True,\n", - " \"backbone_activation\": \"gelu\",\n", - " \"backbone_layer_norm\": False,\n", - " \"shared_network\": True,\n", - "}\n", - "\n", - "results = []\n", - "for cfg in configs:\n", - " name = cfg[\"name\"]\n", - " encoder_size = cfg.get(\"encoder_size\", POLICY_DEFAULTS[\"ego_input_size\"])\n", - " full_cfg = {**POLICY_DEFAULTS, **{k: v for k, v in cfg.items() if k not in (\"name\", \"encoder_size\")}}\n", - " full_cfg.update(\n", - " {\n", - " \"ego_input_size\": encoder_size,\n", - " \"partner_input_size\": encoder_size,\n", - " \"lane_input_size\": encoder_size,\n", - " \"boundary_input_size\": encoder_size,\n", - " \"traffic_control_input_size\": encoder_size,\n", - " \"target_input_size\": encoder_size,\n", - " }\n", - " )\n", - " p = DrivePolicy(env, **full_cfg).to(device)\n", - " n_params = sum(pp.numel() for pp in p.parameters())\n", - "\n", - " with torch.no_grad():\n", - " import time\n", - "\n", - " t0 = time.time()\n", - " for _ in range(100):\n", - " p(obs_tensor)\n", - " if device.type == \"cuda\":\n", - " torch.cuda.synchronize()\n", - " ms_per_fwd = (time.time() - t0) / 100 * 1000\n", - "\n", - " results.append({\"name\": name, \"encoder_size\": encoder_size, \"params\": n_params, \"ms/fwd\": ms_per_fwd, **full_cfg})\n", - " del p\n", - "\n", - "print(\n", - " f\"{'Config':>12s} | {'enc':>5s} | {'bb_h':>5s} | {'bb_L':>4s} | {'act_h':>5s} | {'act_L':>5s} | {'crt_h':>5s} | {'crt_L':>5s} | {'enc_act':>7s} | {'Params':>10s} | {'ms/fwd':>8s}\"\n", - ")\n", - "print(\"-\" * 105)\n", - "for r in results:\n", - " print(\n", - " f\"{r['name']:>12s} | {r['encoder_size']:>5d} | {r['backbone_hidden_size']:>5d} | {r['backbone_num_layers']:>4d} | {r['actor_hidden_size']:>5d} | {r['actor_num_layers']:>5d} | {r['critic_hidden_size']:>5d} | {r['critic_num_layers']:>5d} | {r['encoder_activation']:>7s} | {r['params']:>10,d} | {r['ms/fwd']:>7.2f}ms\"\n", - " )\n", - "\n", - "fig, axes = plt.subplots(1, 2, figsize=(14, 4))\n", - "names = [r[\"name\"] for r in results]\n", - "params = [r[\"params\"] for r in results]\n", - "times = [r[\"ms/fwd\"] for r in results]\n", - "\n", - "bar_colors = [\"coral\" if r[\"encoder_activation\"] == \"tanh\" else \"steelblue\" for r in results]\n", - "\n", - "axes[0].bar(names, params, color=bar_colors, edgecolor=\"black\")\n", - "axes[0].set_ylabel(\"Parameters\")\n", - "axes[0].set_title(\"Parameter Count (orange=tanh encoder)\")\n", - "axes[0].tick_params(axis=\"x\", rotation=30)\n", - "for i, v in enumerate(params):\n", - " axes[0].text(i, v, f\"{v:,}\", ha=\"center\", va=\"bottom\", fontsize=7)\n", - "\n", - "axes[1].bar(names, times, color=bar_colors, edgecolor=\"black\")\n", - "axes[1].set_ylabel(\"ms / forward\")\n", - "axes[1].set_title(f\"Forward Pass Latency ({env.num_agents} agents)\")\n", - "axes[1].tick_params(axis=\"x\", rotation=30)\n", - "for i, v in enumerate(times):\n", - " axes[1].text(i, v, f\"{v:.2f}\", ha=\"center\", va=\"bottom\", fontsize=7)\n", - "\n", - "plt.tight_layout()\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Observation Buffer Utilization\n", - "How much of each observation slot is actually filled (non-zero)?" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Run a few steps to get diverse observations\n", - "actions = zero_actions(env)\n", - "all_obs = [obs]\n", - "for _ in range(20):\n", - " o, _, _, _, _ = env.step(actions)\n", - " all_obs.append(o)\n", - "stacked = np.concatenate(all_obs, axis=0)\n", - "\n", - "slide = env.ego_features\n", - "segments = [(\"ego\", 0, env.ego_features, 1, env.ego_features)]\n", - "if cond_dim > 0:\n", - " segments.append((\"conditioning\", slide, slide + cond_dim, 1, cond_dim))\n", - " slide += cond_dim\n", - "segments.append((\"partners\", slide, slide + partner_dim, env.obs_slots_partners_n, env.partner_features))\n", - "slide += partner_dim\n", - "segments.append((\"lanes\", slide, slide + lane_dim, env.obs_slots_lane_kept, env.road_features))\n", - "slide += lane_dim\n", - "segments.append((\"boundaries\", slide, slide + boundary_dim, env.obs_slots_boundary_kept, env.road_features))\n", - "slide += boundary_dim\n", - "segments.append((\"traffic\", slide, slide + traffic_dim, env.obs_slots_traffic_controls_n, env.traffic_control_features))\n", - "\n", - "print(f\"{'Segment':>15s} | {'Slots':>5s} | {'Features':>8s} | {'Fill %':>7s} | {'Mean':>8s} | {'Std':>8s}\")\n", - "print(\"-\" * 65)\n", - "fill_rates = []\n", - "seg_names = []\n", - "for name, start, end, n_slots, n_feat in segments:\n", - " chunk = stacked[:, start:end]\n", - " if n_slots > 1:\n", - " reshaped = chunk.reshape(-1, n_slots, n_feat)\n", - " # A slot is \"filled\" if any feature is non-zero\n", - " filled = (np.abs(reshaped).sum(axis=2) > 1e-8).mean()\n", - " else:\n", - " filled = (np.abs(chunk) > 1e-8).mean()\n", - " fill_rates.append(filled * 100)\n", - " seg_names.append(name)\n", - " print(f\"{name:>15s} | {n_slots:>5d} | {n_feat:>8d} | {filled:>6.1%} | {chunk.mean():>8.4f} | {chunk.std():>8.4f}\")\n", - "\n", - "fig, ax = plt.subplots(figsize=(8, 4))\n", - "colors = [\"#2ecc71\" if f > 50 else \"#e74c3c\" if f < 10 else \"#f39c12\" for f in fill_rates]\n", - "ax.barh(seg_names, fill_rates, color=colors, edgecolor=\"black\")\n", - "ax.set_xlabel(\"Fill Rate (%)\")\n", - "ax.set_title(\"Observation Slot Utilization\")\n", - "ax.axvline(50, color=\"gray\", ls=\"--\", alpha=0.5)\n", - "for i, v in enumerate(fill_rates):\n", - " ax.text(v + 1, i, f\"{v:.1f}%\", va=\"center\", fontsize=8)\n", - "plt.tight_layout()\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Effective Receptive Field\n", - "Which input features have the most influence on the hidden representation?" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Jacobian-based sensitivity: d(hidden) / d(obs) magnitude\n", - "sample = obs_tensor[:1].clone().requires_grad_(True)\n", - "hidden = policy.actor_backbone(sample, env.ego_features)\n", - "# Sum hidden to scalar for backward\n", - "hidden.sum().backward()\n", - "sensitivity = sample.grad.abs().squeeze().cpu().numpy()\n", - "\n", - "fig, axes = plt.subplots(2, 1, figsize=(14, 6), gridspec_kw={\"height_ratios\": [2, 1]})\n", - "\n", - "# Full sensitivity\n", - "axes[0].plot(sensitivity, lw=0.5, color=\"steelblue\")\n", - "axes[0].set_ylabel(\"|grad|\")\n", - "axes[0].set_title(\"Input Feature Sensitivity (|d hidden / d obs|)\")\n", - "\n", - "# Mark segments\n", - "seg_boundaries = [0, env.ego_features]\n", - "seg_labels = [\"ego\"]\n", - "s = env.ego_features\n", - "if cond_dim > 0:\n", - " s += cond_dim\n", - " seg_boundaries.append(s)\n", - " seg_labels.append(\"cond\")\n", - "for name, dim in [\n", - " (\"partners\", partner_dim),\n", - " (\"lanes\", lane_dim),\n", - " (\"boundaries\", boundary_dim),\n", - " (\"traffic\", traffic_dim),\n", - "]:\n", - " s += dim\n", - " seg_boundaries.append(s)\n", - " seg_labels.append(name)\n", - "\n", - "seg_colors = plt.cm.Set2(np.linspace(0, 1, len(seg_labels)))\n", - "for i, (label, c) in enumerate(zip(seg_labels, seg_colors)):\n", - " start, end = seg_boundaries[i], seg_boundaries[i + 1]\n", - " axes[0].axvspan(start, end, alpha=0.15, color=c)\n", - " axes[0].text((start + end) / 2, axes[0].get_ylim()[1] * 0.9, label, ha=\"center\", fontsize=7, color=\"black\")\n", - "\n", - "# Per-segment mean sensitivity\n", - "seg_means = []\n", - "for i in range(len(seg_labels)):\n", - " start, end = seg_boundaries[i], seg_boundaries[i + 1]\n", - " seg_means.append(sensitivity[start:end].mean())\n", - "\n", - "axes[1].bar(seg_labels, seg_means, color=seg_colors, edgecolor=\"black\")\n", - "axes[1].set_ylabel(\"Mean |grad|\")\n", - "axes[1].set_title(\"Mean Sensitivity per Observation Segment\")\n", - "\n", - "plt.tight_layout()\n", - "plt.show()\n", - "\n", - "policy.zero_grad()" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": ".venv", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} diff --git a/notebooks/06_architecture.py b/notebooks/06_architecture.py new file mode 100644 index 0000000000..f601fc7af2 --- /dev/null +++ b/notebooks/06_architecture.py @@ -0,0 +1,694 @@ +# --- +# jupyter: +# jupytext: +# text_representation: +# extension: .py +# format_name: percent +# format_version: '1.3' +# jupytext_version: 1.19.3 +# kernelspec: +# display_name: .venv +# language: python +# name: python3 +# --- + +# %% [markdown] +# # 06 - Neural Network Architecture +# Visualize, analyze, and iterate on the DrivePolicy architecture. Covers model summary, per-encoder breakdown, forward pass shape tracing, weight distributions, and architecture comparison. + +# %% +import numpy as np +import matplotlib.pyplot as plt +import torch +import torch.nn.functional as F +from torchinfo import summary +from pufferlib.ocean.drive import binding +from pufferlib.ocean.torch import Drive as DrivePolicy +from notebooks.notebook_utils import make_drive_env, zero_actions + +# --- Policy architecture --- +INPUT_SIZE = 64 +BACKBONE_HIDDEN_SIZE = 1024 +BACKBONE_NUM_LAYERS = 3 +ACTOR_HIDDEN_SIZE = 128 +ACTOR_NUM_LAYERS = 3 +CRITIC_HIDDEN_SIZE = 64 +CRITIC_NUM_LAYERS = 2 +SHARED_NETWORK = True +ENCODER_ACTIVATION = "tanh" +ENCODER_LAYER_NORM = True +BACKBONE_ACTIVATION = "gelu" +BACKBONE_LAYER_NORM = False + +env, obs, info = make_drive_env() + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +policy = DrivePolicy( + env, + ego_input_size=INPUT_SIZE, + partner_input_size=INPUT_SIZE, + lane_input_size=INPUT_SIZE, + boundary_input_size=INPUT_SIZE, + traffic_control_input_size=INPUT_SIZE, + target_input_size=INPUT_SIZE, + backbone_hidden_size=BACKBONE_HIDDEN_SIZE, + backbone_num_layers=BACKBONE_NUM_LAYERS, + actor_hidden_size=ACTOR_HIDDEN_SIZE, + actor_num_layers=ACTOR_NUM_LAYERS, + critic_hidden_size=CRITIC_HIDDEN_SIZE, + critic_num_layers=CRITIC_NUM_LAYERS, + encoder_activation=ENCODER_ACTIVATION, + encoder_layer_norm=ENCODER_LAYER_NORM, + backbone_activation=BACKBONE_ACTIVATION, + backbone_layer_norm=BACKBONE_LAYER_NORM, + shared_network=SHARED_NETWORK, +).to(device) + +print(f"Device: {device}") +print(f"Obs dim: {obs.shape[1]}") +print(f"Action dim: {policy.atn_dim}") +print(f"Shared network: {SHARED_NETWORK}") +print(f"Backbone: {BACKBONE_HIDDEN_SIZE} x {BACKBONE_NUM_LAYERS}L") +print(f"Actor: {ACTOR_HIDDEN_SIZE} x {ACTOR_NUM_LAYERS}L") +print(f"Critic: {CRITIC_HIDDEN_SIZE} x {CRITIC_NUM_LAYERS}L") +print(f"Encoder: {ENCODER_ACTIVATION}, LayerNorm: {ENCODER_LAYER_NORM}") + +# %% [markdown] +# ## Model Summary (torchinfo) + +# %% +obs_tensor = torch.FloatTensor(obs).to(device) +summary(policy, input_data=obs_tensor, depth=4, col_names=["input_size", "output_size", "num_params", "mult_adds"]) + +# %% [markdown] +# ## Architecture Diagram + +# %% +backbone = policy.actor_backbone +cond_dim = backbone.target_dim + +# Collect encoder info +encoders = [ + ("ego", env.ego_features, 1, "direct", INPUT_SIZE), + ("conditioning", cond_dim, 1, "direct", INPUT_SIZE) if cond_dim > 0 else None, + ("partner", env.partner_features, env.obs_slots_partners_n, "max-pool", INPUT_SIZE), + ("lane", env.road_features, env.obs_slots_lane_kept, "max-pool", INPUT_SIZE), + ("boundary", env.road_features, env.obs_slots_boundary_kept, "max-pool", INPUT_SIZE), + ( + "traffic_ctrl", + env.traffic_control_features - 2 + binding.NUM_TRAFFIC_CONTROL_TYPES + binding.NUM_TRAFFIC_CONTROL_STATES, + env.obs_slots_traffic_controls_n, + "max-pool (onehot)", + INPUT_SIZE, + ), +] +encoders = [e for e in encoders if e is not None] + +fig, ax = plt.subplots(figsize=(14, 8)) +ax.set_xlim(0, 10) +ax.set_ylim(0, 10) +ax.axis("off") + +n_enc = len(encoders) +y_positions = np.linspace(9, 1, n_enc) +colors = plt.cm.Set2(np.linspace(0, 1, n_enc)) + +# Draw encoders +for i, ((name, in_f, n_obj, agg, out_size), y, c) in enumerate(zip(encoders, y_positions, colors)): + # Input box + label = f"{name}\n{n_obj}x{in_f}" if n_obj > 1 else f"{name}\n{in_f}" + ax.add_patch(plt.Rectangle((0.2, y - 0.3), 1.6, 0.6, facecolor=c, edgecolor="black", lw=1.2, alpha=0.8)) + ax.text(1.0, y, label, ha="center", va="center", fontsize=8, fontweight="bold") + + # Encoder box + ax.add_patch(plt.Rectangle((2.5, y - 0.25), 2.0, 0.5, facecolor="lightyellow", edgecolor="black", lw=1)) + ax.text(3.5, y + 0.05, f"Linear({in_f},{out_size})", ha="center", va="center", fontsize=7) + ln_label = "LN+" if ENCODER_LAYER_NORM else "" + ax.text( + 3.5, + y - 0.12, + f"{ln_label}{ENCODER_ACTIVATION}+Linear({out_size},{out_size})", + ha="center", + va="center", + fontsize=6, + color="gray", + ) + + # Aggregation + if n_obj > 1: + ax.text(5.0, y, agg, ha="center", va="center", fontsize=7, style="italic", color="darkblue") + arrow_start = 5.5 + else: + arrow_start = 4.6 + + # Arrows + ax.annotate("", xy=(2.5, y), xytext=(1.8, y), arrowprops=dict(arrowstyle="->", lw=1)) + ax.annotate("", xy=(6.0, 5.0), xytext=(arrow_start, y), arrowprops=dict(arrowstyle="->", lw=0.8, color="gray")) + +# Concat box +ax.add_patch(plt.Rectangle((5.8, 4.5), 1.4, 1.0, facecolor="lightsalmon", edgecolor="black", lw=1.5)) +ax.text(6.5, 5.2, "Concat", ha="center", va="center", fontsize=9, fontweight="bold") +ax.text(6.5, 4.85, f"{n_enc}x{INPUT_SIZE}={n_enc * INPUT_SIZE}", ha="center", va="center", fontsize=7) + +# Backbone +ax.add_patch(plt.Rectangle((7.5, 4.5), 1.3, 1.0, facecolor="lightblue", edgecolor="black", lw=1.5)) +ax.text(8.15, 5.15, f"Backbone ({BACKBONE_NUM_LAYERS}L)", ha="center", va="center", fontsize=8, fontweight="bold") +ax.text(8.15, 4.85, f"GELU+Linear\n({n_enc * INPUT_SIZE},{BACKBONE_HIDDEN_SIZE})", ha="center", va="center", fontsize=6) +ax.annotate("", xy=(7.5, 5.0), xytext=(7.2, 5.0), arrowprops=dict(arrowstyle="->", lw=1.5)) + +# Actor / Critic heads +ax.add_patch(plt.Rectangle((9.0, 5.7), 0.9, 0.6, facecolor="lightgreen", edgecolor="black", lw=1.2)) +actor_label = f"Actor ({ACTOR_NUM_LAYERS}L)\n{BACKBONE_HIDDEN_SIZE}->{sum(policy.atn_dim)}" +if ACTOR_NUM_LAYERS > 1: + actor_label = ( + f"Actor ({ACTOR_NUM_LAYERS}L)\n{BACKBONE_HIDDEN_SIZE}->{ACTOR_HIDDEN_SIZE}->...->{sum(policy.atn_dim)}" + ) +ax.text(9.45, 6.0, actor_label, ha="center", va="center", fontsize=6, fontweight="bold") + +ax.add_patch(plt.Rectangle((9.0, 3.7), 0.9, 0.6, facecolor="plum", edgecolor="black", lw=1.2)) +critic_label = f"Critic ({CRITIC_NUM_LAYERS}L)\n{BACKBONE_HIDDEN_SIZE}->1" +if CRITIC_NUM_LAYERS > 1: + critic_label = f"Critic ({CRITIC_NUM_LAYERS}L)\n{BACKBONE_HIDDEN_SIZE}->{CRITIC_HIDDEN_SIZE}->...->1" +ax.text(9.45, 4.0, critic_label, ha="center", va="center", fontsize=6, fontweight="bold") + +ax.annotate("", xy=(9.0, 6.0), xytext=(8.8, 5.3), arrowprops=dict(arrowstyle="->", lw=1.2)) +ax.annotate("", xy=(9.0, 4.0), xytext=(8.8, 4.7), arrowprops=dict(arrowstyle="->", lw=1.2)) + +split_label = "SHARED" if SHARED_NETWORK else "SPLIT" +ax.text(8.9, 4.55, split_label, ha="center", va="center", fontsize=7, color="red", fontweight="bold") + +ax.text( + 5.0, + 0.3, + f"Encoder: {ENCODER_ACTIVATION} | LayerNorm: {ENCODER_LAYER_NORM}", + ha="center", + va="center", + fontsize=8, + color="darkgreen", + fontweight="bold", +) + +ax.set_title( + f"DrivePolicy Architecture (encoder_size={INPUT_SIZE}, backbone={BACKBONE_HIDDEN_SIZE})", + fontsize=12, + fontweight="bold", +) +plt.tight_layout() +plt.show() + + +# %% [markdown] +# ## Per-Encoder Parameter Breakdown + + +# %% +def count_params(module): + return sum(p.numel() for p in module.parameters()) + + +backbone = policy.actor_backbone +components = { + "ego_encoder": backbone.ego_encoder, + "lane_encoder": backbone.lane_encoder, + "boundary_encoder": backbone.boundary_encoder, + "partner_encoder": backbone.partner_encoder, + "traffic_ctrl_encoder": backbone.traffic_control_encoder, +} +if backbone.target_dim > 0: + components["target_encoder"] = backbone.target_encoder +components["backbone_mlp"] = backbone.backbone +components["actor_head"] = policy.actor_head +components["critic_head"] = policy.critic_head + +names, counts = zip(*[(k, count_params(v)) for k, v in components.items()]) +total = sum(counts) + +print(f"{'Component':>25s} | {'Params':>10s} | {'%':>6s}") +print("-" * 48) +for n, c in zip(names, counts): + print(f"{n:>25s} | {c:>10,d} | {c / total:>5.1%}") +print("-" * 48) +print(f"{'TOTAL':>25s} | {total:>10,d}") +if not SHARED_NETWORK: + critic_bb = count_params(policy.critic_backbone) + print(f"{'+ critic_backbone':>25s} | {critic_bb:>10,d}") + print(f"{'GRAND TOTAL':>25s} | {total + critic_bb:>10,d}") + +fig, ax = plt.subplots(figsize=(8, 5)) +colors = plt.cm.Set3(np.linspace(0, 1, len(names))) +bars = ax.barh(names, counts, color=colors, edgecolor="black") +for bar, c in zip(bars, counts): + ax.text(bar.get_width() + total * 0.01, bar.get_y() + bar.get_height() / 2, f"{c:,}", va="center", fontsize=8) +ax.set_xlabel("Parameters") +ax.set_title(f"Parameter Distribution ({total:,} total)") +plt.tight_layout() +plt.show() + +# %% [markdown] +# ## Forward Pass Shape Trace + +# %% +x = obs_tensor +backbone = policy.actor_backbone + +slide_idx = env.ego_features +cond_dim = backbone.target_dim +partner_dim = env.obs_slots_partners_n * env.partner_features +lane_dim = env.obs_slots_lane_kept * env.road_features +boundary_dim = env.obs_slots_boundary_kept * env.road_features +traffic_dim = env.obs_slots_traffic_controls_n * env.traffic_control_features + +# Slicing +ego_obs = x[:, :slide_idx] +slices = [("ego", 0, slide_idx, ego_obs.shape)] + +if cond_dim > 0: + cond_obs = x[:, slide_idx : slide_idx + cond_dim] + slices.append(("conditioning", slide_idx, slide_idx + cond_dim, cond_obs.shape)) + slide_idx += cond_dim + +partner_obs = x[:, slide_idx : slide_idx + partner_dim] +slices.append(("partners", slide_idx, slide_idx + partner_dim, partner_obs.shape)) +slide_idx += partner_dim + +lane_obs = x[:, slide_idx : slide_idx + lane_dim] +slices.append(("lanes", slide_idx, slide_idx + lane_dim, lane_obs.shape)) +slide_idx += lane_dim + +boundary_obs = x[:, slide_idx : slide_idx + boundary_dim] +slices.append(("boundaries", slide_idx, slide_idx + boundary_dim, boundary_obs.shape)) +slide_idx += boundary_dim + +traffic_obs = x[:, slide_idx : slide_idx + traffic_dim] +slices.append(("traffic_ctrl", slide_idx, slide_idx + traffic_dim, traffic_obs.shape)) + +print(f"Obs buffer layout (total={x.shape[1]}):") +print(f"{'Name':>15s} | {'Start':>5s} | {'End':>5s} | {'Width':>5s} | Shape") +print("-" * 65) +for name, start, end, shape in slices: + print(f"{name:>15s} | {start:>5d} | {end:>5d} | {end - start:>5d} | {shape}") + +# Forward through encoders +print("\nEncoder outputs:") +with torch.no_grad(): + ego_enc = backbone.ego_encoder(ego_obs) + print(f" ego_encoder: {ego_obs.shape} -> {ego_enc.shape}") + + if cond_dim > 0: + cond_enc = backbone.target_encoder(cond_obs) + print(f" cond_encoder: {cond_obs.shape} -> {cond_enc.shape}") + + p_reshaped = partner_obs.view(-1, env.obs_slots_partners_n, env.partner_features) + p_enc, _ = backbone.partner_encoder(p_reshaped).max(dim=1) + print(f" partner_encoder: {partner_obs.shape} -> view {p_reshaped.shape} -> encode -> max-pool -> {p_enc.shape}") + + l_reshaped = lane_obs.view(-1, env.obs_slots_lane_kept, env.road_features) + l_enc, _ = backbone.lane_encoder(l_reshaped).max(dim=1) + print(f" lane_encoder: {lane_obs.shape} -> view {l_reshaped.shape} -> encode -> max-pool -> {l_enc.shape}") + + b_reshaped = boundary_obs.view(-1, env.obs_slots_boundary_kept, env.road_features) + b_enc, _ = backbone.boundary_encoder(b_reshaped).max(dim=1) + print(f" bound_encoder: {boundary_obs.shape} -> view {b_reshaped.shape} -> encode -> max-pool -> {b_enc.shape}") + + t_reshaped = traffic_obs.view(-1, env.obs_slots_traffic_controls_n, env.traffic_control_features) + t_cont = t_reshaped[:, :, : env.traffic_control_features - 2] + t_type = t_reshaped[:, :, env.traffic_control_features - 2] + t_state = t_reshaped[:, :, env.traffic_control_features - 1] + t_type_onehot = F.one_hot(t_type.long(), num_classes=binding.NUM_TRAFFIC_CONTROL_TYPES).float() + t_state_onehot = F.one_hot(t_state.long(), num_classes=binding.NUM_TRAFFIC_CONTROL_STATES).float() + t_input = torch.cat([t_cont, t_type_onehot, t_state_onehot], dim=2) + t_enc, _ = backbone.traffic_control_encoder(t_input).max(dim=1) + print( + f" traffic_encoder: {traffic_obs.shape} -> view {t_reshaped.shape} -> onehot {t_input.shape} -> encode -> max-pool -> {t_enc.shape}" + ) + + # Concat + backbone + features = [ego_enc, l_enc, b_enc, p_enc, t_enc] + if cond_dim > 0: + features.append(cond_enc) + concat = torch.cat(features, dim=1) + hidden = backbone.backbone(concat) + print(f"\n concat: {concat.shape}") + print(f" backbone_mlp: {concat.shape} -> {hidden.shape}") + + # Heads + actor_out = policy.actor_head(hidden) + critic_out = policy.critic_head(hidden) + print(f" actor_head: {hidden.shape} -> {actor_out.shape} (split into {policy.atn_dim})") + print(f" critic_head: {hidden.shape} -> {critic_out.shape}") + +# %% [markdown] +# ## Weight Distributions by Layer + +# %% +weight_data = [ + (n, p.data.cpu().numpy().flatten()) for n, p in policy.named_parameters() if "weight" in n and p.dim() >= 2 +] + +n_weights = len(weight_data) +cols = 4 +rows = (n_weights + cols - 1) // cols +fig, axes = plt.subplots(rows, cols, figsize=(4 * cols, 3 * rows)) +axes = axes.flatten() + +for i, (name, w) in enumerate(weight_data): + ax = axes[i] + ax.hist(w, bins=50, edgecolor="black", alpha=0.7, density=True) + ax.set_title(name.replace("actor_backbone.", ""), fontsize=7) + ax.axvline(0, color="red", ls="--", lw=0.5) + ax.text(0.95, 0.95, f"std={w.std():.3f}", transform=ax.transAxes, fontsize=6, ha="right", va="top") + +for j in range(i + 1, len(axes)): + axes[j].axis("off") + +fig.suptitle("Weight Distributions (init)", fontsize=12, fontweight="bold") +plt.tight_layout() +plt.show() + +# %% [markdown] +# ## Activation Analysis (per encoder) + +# %% +policy.eval() +with torch.no_grad(): + hidden = policy.actor_backbone(obs_tensor, env.ego_features) + action_logits, value = policy.decode_actions(hidden) + +# Collect per-encoder activations +activations = {} +with torch.no_grad(): + slide = env.ego_features + activations["ego"] = backbone.ego_encoder(obs_tensor[:, : env.ego_features]) + + if cond_dim > 0: + activations["conditioning"] = backbone.target_encoder(obs_tensor[:, slide : slide + cond_dim]) + slide += cond_dim + + p_obs = obs_tensor[:, slide : slide + partner_dim].view(-1, env.obs_slots_partners_n, env.partner_features) + activations["partner"], _ = backbone.partner_encoder(p_obs).max(dim=1) + slide += partner_dim + + l_obs = obs_tensor[:, slide : slide + lane_dim].view(-1, env.obs_slots_lane_kept, env.road_features) + activations["lane"], _ = backbone.lane_encoder(l_obs).max(dim=1) + slide += lane_dim + + b_obs = obs_tensor[:, slide : slide + boundary_dim].view(-1, env.obs_slots_boundary_kept, env.road_features) + activations["boundary"], _ = backbone.boundary_encoder(b_obs).max(dim=1) + slide += boundary_dim + + t_obs = obs_tensor[:, slide : slide + traffic_dim].view( + -1, env.obs_slots_traffic_controls_n, env.traffic_control_features + ) + t_cont = t_obs[:, :, : env.traffic_control_features - 2] + t_type = t_obs[:, :, env.traffic_control_features - 2] + t_state = t_obs[:, :, env.traffic_control_features - 1] + t_type_onehot = F.one_hot(t_type.long(), num_classes=binding.NUM_TRAFFIC_CONTROL_TYPES).float() + t_state_onehot = F.one_hot(t_state.long(), num_classes=binding.NUM_TRAFFIC_CONTROL_STATES).float() + t_input = torch.cat([t_cont, t_type_onehot, t_state_onehot], dim=2) + activations["traffic_ctrl"], _ = backbone.traffic_control_encoder(t_input).max(dim=1) + + activations["hidden"] = hidden + +fig, axes = plt.subplots(2, 4, figsize=(16, 6)) +axes = axes.flatten() +for i, (name, act) in enumerate(activations.items()): + if i >= len(axes): + break + vals = act.cpu().numpy().flatten() + ax = axes[i] + ax.hist(vals, bins=50, edgecolor="black", alpha=0.7) + dead = (act.abs().sum(dim=0) == 0).sum().item() + ax.set_title(f"{name} (dead={dead}/{act.shape[1]})", fontsize=9) + ax.text( + 0.95, + 0.95, + f"mean={vals.mean():.3f}\nstd={vals.std():.3f}", + transform=ax.transAxes, + fontsize=7, + ha="right", + va="top", + ) + +for j in range(i + 1, len(axes)): + axes[j].axis("off") + +fig.suptitle("Per-Encoder Activation Distributions", fontsize=12, fontweight="bold") +plt.tight_layout() +plt.show() + +# %% [markdown] +# ## Encoder Embedding Similarity (cosine) + +# %% +# Mean embedding per encoder (exclude hidden — different dim) +emb_names = [k for k in activations.keys() if k != "hidden"] +emb_means = torch.stack([activations[k].mean(dim=0) for k in emb_names]) +emb_norm = F.normalize(emb_means, dim=1) +sim_matrix = (emb_norm @ emb_norm.T).cpu().numpy() + +fig, ax = plt.subplots(figsize=(7, 6)) +im = ax.imshow(sim_matrix, cmap="RdBu_r", vmin=-1, vmax=1) +ax.set_xticks(range(len(emb_names))) +ax.set_yticks(range(len(emb_names))) +ax.set_xticklabels(emb_names, rotation=45, ha="right", fontsize=8) +ax.set_yticklabels(emb_names, fontsize=8) +for i in range(len(emb_names)): + for j in range(len(emb_names)): + ax.text(j, i, f"{sim_matrix[i, j]:.2f}", ha="center", va="center", fontsize=7) +fig.colorbar(im, ax=ax) +ax.set_title("Cosine Similarity Between Encoder Mean Embeddings") +plt.tight_layout() +plt.show() + +# %% [markdown] +# ## Architecture Comparison +# Compare different architecture configs side-by-side without training. + +# %% +configs = [ + {"name": "tiny", "encoder_size": 32, "backbone_hidden_size": 64}, + {"name": "small", "encoder_size": 64, "backbone_hidden_size": 128}, + {"name": "medium", "encoder_size": 128, "backbone_hidden_size": 256, "backbone_num_layers": 2}, + { + "name": "large", + "encoder_size": 128, + "backbone_hidden_size": 512, + "backbone_num_layers": 2, + "actor_num_layers": 2, + "actor_hidden_size": 256, + "critic_num_layers": 2, + "critic_hidden_size": 256, + }, + { + "name": "xlarge", + "encoder_size": 256, + "backbone_hidden_size": 1024, + "backbone_num_layers": 3, + "actor_num_layers": 2, + "actor_hidden_size": 512, + "critic_num_layers": 2, + "critic_hidden_size": 512, + }, + {"name": "small+tanh", "encoder_size": 64, "backbone_hidden_size": 128, "encoder_activation": "tanh"}, + { + "name": "medium+tanh", + "encoder_size": 128, + "backbone_hidden_size": 256, + "backbone_num_layers": 2, + "encoder_activation": "tanh", + }, +] + +POLICY_DEFAULTS = { + "ego_input_size": 64, + "partner_input_size": 64, + "lane_input_size": 64, + "boundary_input_size": 64, + "traffic_control_input_size": 64, + "target_input_size": 64, + "backbone_num_layers": 1, + "actor_hidden_size": 128, + "actor_num_layers": 0, + "critic_hidden_size": 128, + "critic_num_layers": 0, + "encoder_activation": "relu", + "encoder_layer_norm": True, + "backbone_activation": "gelu", + "backbone_layer_norm": False, + "shared_network": True, +} + +results = [] +for cfg in configs: + name = cfg["name"] + encoder_size = cfg.get("encoder_size", POLICY_DEFAULTS["ego_input_size"]) + full_cfg = {**POLICY_DEFAULTS, **{k: v for k, v in cfg.items() if k not in ("name", "encoder_size")}} + full_cfg.update( + { + "ego_input_size": encoder_size, + "partner_input_size": encoder_size, + "lane_input_size": encoder_size, + "boundary_input_size": encoder_size, + "traffic_control_input_size": encoder_size, + "target_input_size": encoder_size, + } + ) + p = DrivePolicy(env, **full_cfg).to(device) + n_params = sum(pp.numel() for pp in p.parameters()) + + with torch.no_grad(): + import time + + t0 = time.time() + for _ in range(100): + p(obs_tensor) + if device.type == "cuda": + torch.cuda.synchronize() + ms_per_fwd = (time.time() - t0) / 100 * 1000 + + results.append({"name": name, "encoder_size": encoder_size, "params": n_params, "ms/fwd": ms_per_fwd, **full_cfg}) + del p + +print( + f"{'Config':>12s} | {'enc':>5s} | {'bb_h':>5s} | {'bb_L':>4s} | {'act_h':>5s} | {'act_L':>5s} | {'crt_h':>5s} | {'crt_L':>5s} | {'enc_act':>7s} | {'Params':>10s} | {'ms/fwd':>8s}" +) +print("-" * 105) +for r in results: + print( + f"{r['name']:>12s} | {r['encoder_size']:>5d} | {r['backbone_hidden_size']:>5d} | {r['backbone_num_layers']:>4d} | {r['actor_hidden_size']:>5d} | {r['actor_num_layers']:>5d} | {r['critic_hidden_size']:>5d} | {r['critic_num_layers']:>5d} | {r['encoder_activation']:>7s} | {r['params']:>10,d} | {r['ms/fwd']:>7.2f}ms" + ) + +fig, axes = plt.subplots(1, 2, figsize=(14, 4)) +names = [r["name"] for r in results] +params = [r["params"] for r in results] +times = [r["ms/fwd"] for r in results] + +bar_colors = ["coral" if r["encoder_activation"] == "tanh" else "steelblue" for r in results] + +axes[0].bar(names, params, color=bar_colors, edgecolor="black") +axes[0].set_ylabel("Parameters") +axes[0].set_title("Parameter Count (orange=tanh encoder)") +axes[0].tick_params(axis="x", rotation=30) +for i, v in enumerate(params): + axes[0].text(i, v, f"{v:,}", ha="center", va="bottom", fontsize=7) + +axes[1].bar(names, times, color=bar_colors, edgecolor="black") +axes[1].set_ylabel("ms / forward") +axes[1].set_title(f"Forward Pass Latency ({env.num_agents} agents)") +axes[1].tick_params(axis="x", rotation=30) +for i, v in enumerate(times): + axes[1].text(i, v, f"{v:.2f}", ha="center", va="bottom", fontsize=7) + +plt.tight_layout() +plt.show() + +# %% [markdown] +# ## Observation Buffer Utilization +# How much of each observation slot is actually filled (non-zero)? + +# %% +# Run a few steps to get diverse observations +actions = zero_actions(env) +all_obs = [obs] +for _ in range(20): + o, _, _, _, _ = env.step(actions) + all_obs.append(o) +stacked = np.concatenate(all_obs, axis=0) + +slide = env.ego_features +segments = [("ego", 0, env.ego_features, 1, env.ego_features)] +if cond_dim > 0: + segments.append(("conditioning", slide, slide + cond_dim, 1, cond_dim)) + slide += cond_dim +segments.append(("partners", slide, slide + partner_dim, env.obs_slots_partners_n, env.partner_features)) +slide += partner_dim +segments.append(("lanes", slide, slide + lane_dim, env.obs_slots_lane_kept, env.road_features)) +slide += lane_dim +segments.append(("boundaries", slide, slide + boundary_dim, env.obs_slots_boundary_kept, env.road_features)) +slide += boundary_dim +segments.append(("traffic", slide, slide + traffic_dim, env.obs_slots_traffic_controls_n, env.traffic_control_features)) + +print(f"{'Segment':>15s} | {'Slots':>5s} | {'Features':>8s} | {'Fill %':>7s} | {'Mean':>8s} | {'Std':>8s}") +print("-" * 65) +fill_rates = [] +seg_names = [] +for name, start, end, n_slots, n_feat in segments: + chunk = stacked[:, start:end] + if n_slots > 1: + reshaped = chunk.reshape(-1, n_slots, n_feat) + # A slot is "filled" if any feature is non-zero + filled = (np.abs(reshaped).sum(axis=2) > 1e-8).mean() + else: + filled = (np.abs(chunk) > 1e-8).mean() + fill_rates.append(filled * 100) + seg_names.append(name) + print(f"{name:>15s} | {n_slots:>5d} | {n_feat:>8d} | {filled:>6.1%} | {chunk.mean():>8.4f} | {chunk.std():>8.4f}") + +fig, ax = plt.subplots(figsize=(8, 4)) +colors = ["#2ecc71" if f > 50 else "#e74c3c" if f < 10 else "#f39c12" for f in fill_rates] +ax.barh(seg_names, fill_rates, color=colors, edgecolor="black") +ax.set_xlabel("Fill Rate (%)") +ax.set_title("Observation Slot Utilization") +ax.axvline(50, color="gray", ls="--", alpha=0.5) +for i, v in enumerate(fill_rates): + ax.text(v + 1, i, f"{v:.1f}%", va="center", fontsize=8) +plt.tight_layout() +plt.show() + +# %% [markdown] +# ## Effective Receptive Field +# Which input features have the most influence on the hidden representation? + +# %% +# Jacobian-based sensitivity: d(hidden) / d(obs) magnitude +sample = obs_tensor[:1].clone().requires_grad_(True) +hidden = policy.actor_backbone(sample, env.ego_features) +# Sum hidden to scalar for backward +hidden.sum().backward() +sensitivity = sample.grad.abs().squeeze().cpu().numpy() + +fig, axes = plt.subplots(2, 1, figsize=(14, 6), gridspec_kw={"height_ratios": [2, 1]}) + +# Full sensitivity +axes[0].plot(sensitivity, lw=0.5, color="steelblue") +axes[0].set_ylabel("|grad|") +axes[0].set_title("Input Feature Sensitivity (|d hidden / d obs|)") + +# Mark segments +seg_boundaries = [0, env.ego_features] +seg_labels = ["ego"] +s = env.ego_features +if cond_dim > 0: + s += cond_dim + seg_boundaries.append(s) + seg_labels.append("cond") +for name, dim in [ + ("partners", partner_dim), + ("lanes", lane_dim), + ("boundaries", boundary_dim), + ("traffic", traffic_dim), +]: + s += dim + seg_boundaries.append(s) + seg_labels.append(name) + +seg_colors = plt.cm.Set2(np.linspace(0, 1, len(seg_labels))) +for i, (label, c) in enumerate(zip(seg_labels, seg_colors)): + start, end = seg_boundaries[i], seg_boundaries[i + 1] + axes[0].axvspan(start, end, alpha=0.15, color=c) + axes[0].text((start + end) / 2, axes[0].get_ylim()[1] * 0.9, label, ha="center", fontsize=7, color="black") + +# Per-segment mean sensitivity +seg_means = [] +for i in range(len(seg_labels)): + start, end = seg_boundaries[i], seg_boundaries[i + 1] + seg_means.append(sensitivity[start:end].mean()) + +axes[1].bar(seg_labels, seg_means, color=seg_colors, edgecolor="black") +axes[1].set_ylabel("Mean |grad|") +axes[1].set_title("Mean Sensitivity per Observation Segment") + +plt.tight_layout() +plt.show() + +policy.zero_grad() From 0961e3af6c6d3d24cf2c3726a5149501e138930d Mon Sep 17 00:00:00 2001 From: Valentin Charraut Date: Wed, 3 Jun 2026 17:39:04 +0200 Subject: [PATCH 08/10] Rename target_dim to goal_dim for clarity in Drive class --- pufferlib/ocean/drive/drive.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pufferlib/ocean/drive/drive.py b/pufferlib/ocean/drive/drive.py index 370425ed29..214935550d 100644 --- a/pufferlib/ocean/drive/drive.py +++ b/pufferlib/ocean/drive/drive.py @@ -220,12 +220,12 @@ def __init__( self.target_features = binding.STATIC_TARGET_FEATURES else: self.target_features = binding.DYNAMIC_TARGET_FEATURES - self.target_dim = self.num_target_waypoints * self.target_features + self.goal_dim = self.num_target_waypoints * self.target_features self.num_obs = ( self.ego_features + self.num_reward_coefs - + self.target_dim + + self.goal_dim + self.obs_slots_partners_n * self.partner_features + self.obs_slots_lane_kept * self.road_features + self.obs_slots_boundary_kept * self.road_features From 03cc5384f0ec7c8ac2d80e7d971ded81a0aa2bd6 Mon Sep 17 00:00:00 2001 From: Valentin Charraut Date: Wed, 3 Jun 2026 18:09:42 +0200 Subject: [PATCH 09/10] Update inference and architecture notebooks to include mask_padded_features parameter --- notebooks/05_inference.py | 5 ++++- notebooks/06_architecture.py | 3 +++ notebooks/notebook_utils.py | 1 + 3 files changed, 8 insertions(+), 1 deletion(-) diff --git a/notebooks/05_inference.py b/notebooks/05_inference.py index 8a2ffc8013..eddcc1ad8e 100644 --- a/notebooks/05_inference.py +++ b/notebooks/05_inference.py @@ -801,7 +801,7 @@ def unpack_all_timesteps(bufs, agent_idx): f"({100 * len(visible_partners) / (all_partners.shape[0] * obs_slots_partners_n):.1f}%)" ) -fig, axes = plt.subplots(2, 5, figsize=(21, 8)) +fig, axes = plt.subplots(3, 4, figsize=(21, 11)) axes = axes.flatten() for i, label in enumerate(partner_labels): @@ -821,6 +821,9 @@ def unpack_all_timesteps(bufs, agent_idx): pos_ax.set_aspect("equal") pos_ax.grid(True, alpha=0.3) +for ax in axes[len(partner_labels) + 1 :]: + ax.axis("off") + fig.suptitle("Partner features: all visible, full rollout", fontsize=13) plt.tight_layout() plt.show() diff --git a/notebooks/06_architecture.py b/notebooks/06_architecture.py index f601fc7af2..ddcd8eac41 100644 --- a/notebooks/06_architecture.py +++ b/notebooks/06_architecture.py @@ -39,6 +39,7 @@ ENCODER_LAYER_NORM = True BACKBONE_ACTIVATION = "gelu" BACKBONE_LAYER_NORM = False +MASK_PADDED_FEATURES = False env, obs, info = make_drive_env() @@ -62,6 +63,7 @@ backbone_activation=BACKBONE_ACTIVATION, backbone_layer_norm=BACKBONE_LAYER_NORM, shared_network=SHARED_NETWORK, + mask_padded_features=MASK_PADDED_FEATURES, ).to(device) print(f"Device: {device}") @@ -516,6 +518,7 @@ def count_params(module): "backbone_activation": "gelu", "backbone_layer_norm": False, "shared_network": True, + "mask_padded_features": False, } results = [] diff --git a/notebooks/notebook_utils.py b/notebooks/notebook_utils.py index cbd13bfcc1..815cc58ee2 100644 --- a/notebooks/notebook_utils.py +++ b/notebooks/notebook_utils.py @@ -109,6 +109,7 @@ "backbone_activation": "gelu", "backbone_layer_norm": False, "shared_network": True, + "mask_padded_features": False, } From 4e58499c3da510d92b4a83585156ccd500324946 Mon Sep 17 00:00:00 2001 From: Valentin Charraut Date: Wed, 3 Jun 2026 18:17:25 +0200 Subject: [PATCH 10/10] Update CI configuration and test data for improved performance and consistency --- .github/workflows/ci.yml | 2 +- setup.py | 1 + .../smoke_tests/data/drive_smoke_golden.json | 60 +++++++++---------- 3 files changed, 32 insertions(+), 31 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 4d239b9a4a..5927bf2f08 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -66,7 +66,7 @@ jobs: PIP_NO_CACHE_DIR: 1 run: | sudo apt-get update && sudo apt-get install -y build-essential cmake - python -m pip install -U pip pytest jupytext nbclient ipykernel ipywidgets + python -m pip install -U pip pytest jupytext nbclient ipykernel pip install -e . --no-cache-dir python setup.py build_ext --inplace --force diff --git a/setup.py b/setup.py index 7504107dd2..52b5283cdc 100644 --- a/setup.py +++ b/setup.py @@ -277,6 +277,7 @@ def run(self): "tensorboard", "jupytext", "torchinfo", + "ipywidgets", ] setup( diff --git a/tests/smoke_tests/data/drive_smoke_golden.json b/tests/smoke_tests/data/drive_smoke_golden.json index b3e2c6865e..01734a546f 100644 --- a/tests/smoke_tests/data/drive_smoke_golden.json +++ b/tests/smoke_tests/data/drive_smoke_golden.json @@ -1,48 +1,48 @@ { "env": { - "avg_distance_per_infraction": 13.18341121673584, - "avg_speed_per_agent": 1.3711345672607422, - "collision_rate": 0.0375, - "comfort_violation_count": 0.7263497829437255, + "avg_distance_per_infraction": 14.85188889503479, + "avg_speed_per_agent": 1.4011926352977753, + "collision_rate": 0.0, + "comfort_violation_count": 0.7380953431129456, "dnf_rate": 0.5625, - "episode_length": 41.8, - "episode_return": -2.334346318244934, - "lane_center_rate": 0.7175199031829834, + "episode_length": 46.0, + "episode_return": -2.523432493209839, + "lane_center_rate": 0.705732524394989, "n": 16.0, "num_goals_reached": 0.0, "obs/max": 80.0, - "obs/mean": 0.2281341243069619, - "obs/min": -1.236907229758799, - "offroad_rate": 0.375, - "red_light_violation_rate": 0.025, + "obs/mean": 0.22964997438248247, + "obs/min": -1.1160968635231256, + "offroad_rate": 0.40625, + "red_light_violation_rate": 0.03125, "reward_components/ade": 0.0, - "reward_components/collision": -0.06195625364780426, - "reward_components/comfort": -1.5356246471405028, + "reward_components/collision": 0.0, + "reward_components/comfort": -1.712499350309372, "reward_components/goal": 0.0, - "reward_components/lane_align": -0.13004764765501023, - "reward_components/lane_center": -0.002508045267313719, - "reward_components/offroad": -0.5625, + "reward_components/lane_align": -0.14824828505516052, + "reward_components/lane_center": -0.003005934282555245, + "reward_components/offroad": -0.609375, "reward_components/overspeed": 0.0, - "reward_components/red_light": -0.025, - "reward_components/reverse": -0.01661874633282423, - "reward_components/timestep": -9.165622614091263e-05, + "reward_components/red_light": -0.03125, + "reward_components/reverse": -0.018953118938952684, + "reward_components/timestep": -0.00010140621816390194, "reward_components/velocity": 0.0, "score": 0.0, "velocity_progress_sum": 0.0 }, "losses": { - "approx_kl": 0.0002155543354872082, + "approx_kl": 0.00022559609663273607, "clipfrac": 0.0, - "ema_max": 1.2262159883975983, - "entropy": 2.484058209827968, - "explained_variance": 0.16326427459716797, - "filter_threshold": 0.012262159883975983, - "filtered_fraction": 0.024684585847504104, - "kept_fraction": 0.9753154141524959, - "masked_fraction": 0.10986328125, - "old_approx_kl": -0.0011627360114029475, - "policy_loss": -0.0004513516489948545, - "value_loss": 0.08172488159367017 + "ema_max": 1.2698969841003418, + "entropy": 2.4838708468845914, + "explained_variance": 0.07069230079650879, + "filter_threshold": 0.012698969841003419, + "filtered_fraction": 0.014412416851441234, + "kept_fraction": 0.9855875831485588, + "masked_fraction": 0.119140625, + "old_approx_kl": 0.0003389947648559298, + "policy_loss": -9.120744653046131e-05, + "value_loss": 0.0982145725616387 }, "meta": { "bptt_horizon": 64,