From 85b07d6b64b56e62b16f7f8f1c99df5b97a0990f Mon Sep 17 00:00:00 2001 From: Sriram Krishna Date: Sat, 3 Jan 2026 12:38:10 -0500 Subject: [PATCH 1/5] return and save token representations --- pixi.lock | 33 +++++++++++++++++++- pyproject.toml | 1 + scripts/eval_lerobot_episode.py | 55 ++++++++++++++++++++++++++++----- src/lfd3d/models/dino_3dgp.py | 17 +++++----- 4 files changed, 91 insertions(+), 15 deletions(-) diff --git a/pixi.lock b/pixi.lock index 6a99162..239e5f3 100644 --- a/pixi.lock +++ b/pixi.lock @@ -393,6 +393,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/6d/45/59578566b3275b8fd9157885918fcd0c4d74162928a5310926887b856a51/platformdirs-4.3.7-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/88/5f/e351af9a41f866ac3f1fac4ca0613908d9a41741cfcf2228f4ad853b697d/pluggy-1.5.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/f7/60/1974cfdd5bb770568ddc6f89f3e0df4cfdd1acffd5a609dff5e95f48c6e2/portalocker-3.1.1-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/5b/7d/1529014aebb9d5fd54538115886d005d371a624b1ecaf5c2525b45ad0f77/pot-0.9.6.post1-cp310-cp310-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl - pypi: https://files.pythonhosted.org/packages/e3/b7/1d145c985d8be9729672a45b8b8113030ad60dff45dec592efc4e5f5897a/pre_commit-3.3.3-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/6d/7a/dcb10ad171dbffb6dd2122672f69e5b34e9859d9bcc6e7119c3cb2986ca2/proglog-0.1.11-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/ff/c2/ab7d37426c179ceb9aeb109a85cda8948bb269b7561a0be870cc656eefe4/prometheus_client-0.21.1-py3-none-any.whl @@ -3521,7 +3522,7 @@ packages: - pypi: . name: lfd3d version: 0.1.0 - sha256: dfca697c602714453dd8cc25600627077dd3034ee8abf63d0dd1f316f2634399 + sha256: d9faf9747d41f86679e4f91b3643d0cc518b6c6c24fd3ebae2a9702b139230c2 requires_dist: - diffusers>=0.26.3 - gif @@ -3549,6 +3550,7 @@ packages: - peft>=0.17,<0.18 - decord>=0.6.0,<0.7 - mink>=0.0.11,<0.0.12 + - pot>=0.9.6.post1,<0.10 - ruff==0.3.6 ; extra == 'develop' - mypy==1.3.0 ; extra == 'develop' - pytest==7.3.2 ; extra == 'develop' @@ -5375,6 +5377,35 @@ packages: - pytest-rerunfailures>=15.0 ; extra == 'tests' - redis ; extra == 'redis' requires_python: '>=3.9' +- pypi: https://files.pythonhosted.org/packages/5b/7d/1529014aebb9d5fd54538115886d005d371a624b1ecaf5c2525b45ad0f77/pot-0.9.6.post1-cp310-cp310-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl + name: pot + version: 0.9.6.post1 + sha256: 48924f34d61b909e68651f3fe9fc1a892c69ae38d3c52bc832f95a28569c0e0e + requires_dist: + - numpy>=1.16 + - scipy>=1.6 + - jax ; extra == 'backend-jax' + - jaxlib ; extra == 'backend-jax' + - tensorflow ; extra == 'backend-tf' + - torch ; extra == 'backend-torch' + - cvxopt ; extra == 'cvxopt' + - scikit-learn ; extra == 'dr' + - pymanopt ; extra == 'dr' + - autograd ; extra == 'dr' + - torch ; extra == 'gnn' + - torch-geometric ; extra == 'gnn' + - matplotlib ; extra == 'plot' + - jax ; extra == 'all' + - jaxlib ; extra == 'all' + - tensorflow ; extra == 'all' + - torch ; extra == 'all' + - cvxopt ; extra == 'all' + - scikit-learn ; extra == 'all' + - pymanopt ; extra == 'all' + - autograd ; extra == 'all' + - torch-geometric ; extra == 'all' + - matplotlib ; extra == 'all' + requires_python: '>=3.7' - pypi: https://files.pythonhosted.org/packages/e3/b7/1d145c985d8be9729672a45b8b8113030ad60dff45dec592efc4e5f5897a/pre_commit-3.3.3-py2.py3-none-any.whl name: pre-commit version: 3.3.3 diff --git a/pyproject.toml b/pyproject.toml index 73ac84f..8dcea03 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,6 +36,7 @@ dependencies = [ "peft>=0.17,<0.18", "decord>=0.6.0,<0.7", "mink>=0.0.11,<0.0.12", + "pot>=0.9.6.post1,<0.10", ] [project.optional-dependencies] diff --git a/scripts/eval_lerobot_episode.py b/scripts/eval_lerobot_episode.py index b7decdb..f6bd9b5 100644 --- a/scripts/eval_lerobot_episode.py +++ b/scripts/eval_lerobot_episode.py @@ -1,5 +1,7 @@ +import csv import json import random +from pathlib import Path import cv2 import hydra @@ -166,16 +168,21 @@ def main(cfg): preds = trainer.predict(model, datamodule=eval_datamodule) preds_dict = {tag: {} for tag in eval_datamodule.eval_tags} + exp_dir = Path(cfg.log_dir) / f"{cfg.checkpoint.run_id}_{cfg.dataset.repo_id}" + exp_dir.mkdir(parents=True) + all_episode_metrics = [] + loader = eval_datamodule.predict_dataloader() for i, episode_id in enumerate(episode_idx): - heatmaps = [] - raw_heatmaps = [] + heatmaps, raw_heatmaps, episode_latents = [], [], [] metrics = {"pix_dist": [], "rmse": []} if len(episode_idx) == 1: preds = [preds] - for pred, batch in tqdm(zip(preds[i], loader[i]), total=len(loader[i])): + for pred, batch in tqdm( + zip(preds[i], loader[i]), total=len(loader[i]), desc=f"Episode {episode_id}" + ): rgb = batch["rgbs"][:, 0].cpu().numpy() # B, H, W, 3 batch_size = rgb.shape[0] @@ -197,6 +204,7 @@ def main(cfg): pred_coord = pred["pred_coord"].cpu().numpy().astype(int) # B, 2 metrics["pix_dist"].append(pred["pix_dist"]) metrics["rmse"].append(pred["rmse"]) + episode_latents.append(pred["latent_repr"].cpu()) for j in range(batch_size): heatmap_ = generate_heatmap_from_points( @@ -215,19 +223,52 @@ def main(cfg): if len(raw_heatmaps) > 0: save_video( - f"{cfg.log_dir}/episode_{episode_id}_raw_heatmaps_{cfg.model.name}.mp4", + str(exp_dir / f"episode_{episode_id}_raw_heatmaps.mp4"), frames=raw_heatmaps, ) save_video( - f"{cfg.log_dir}/episode_{episode_id}_heatmap_{cfg.model.name}.mp4", + str(exp_dir / f"episode_{episode_id}_heatmap.mp4"), frames=heatmaps, ) + if len(episode_latents) > 0: + episode_latents_tensor = torch.cat(episode_latents, dim=0) # (N_frames, D) + latent_file = exp_dir / f"episode_{episode_id}.pt" + torch.save( + { + "latents": episode_latents_tensor, + "n_frames": episode_latents_tensor.shape[0], + "latent_dim": episode_latents_tensor.shape[1], + }, + latent_file, + ) + + # Compute and store metrics for this episode + episode_metric_dict = {"episode_id": episode_id} for key in metrics: if len(metrics[key]) != 0: metric_val = torch.cat(metrics[key]) - print(f"Mean {key}:", metric_val.mean().item()) - print(f"Std. {key}:", metric_val.std().item()) + mean_val = metric_val.mean().item() + std_val = metric_val.std().item() + episode_metric_dict[f"{key}_mean"] = mean_val + episode_metric_dict[f"{key}_std"] = std_val + print( + f"Episode {episode_id} - {key}: mean={mean_val:.4f}, std={std_val:.4f}" + ) + + all_episode_metrics.append(episode_metric_dict) + + # Save metrics to CSV sorted by episode number + if all_episode_metrics: + all_episode_metrics.sort(key=lambda x: x["episode_id"]) + csv_file = exp_dir / "metrics.csv" + fieldnames = list(all_episode_metrics[0].keys()) + with open(csv_file, "w", newline="") as f: + writer = csv.DictWriter(f, fieldnames=fieldnames) + writer.writeheader() + writer.writerows(all_episode_metrics) + print(f"\nSaved metrics to {csv_file}") + print(f"Experiment results saved to: {exp_dir}") if __name__ == "__main__": diff --git a/src/lfd3d/models/dino_3dgp.py b/src/lfd3d/models/dino_3dgp.py index f67776a..ae50ec0 100644 --- a/src/lfd3d/models/dino_3dgp.py +++ b/src/lfd3d/models/dino_3dgp.py @@ -468,7 +468,7 @@ def forward( # Predict GMM parameters outputs = self.output_head(tokens) # (B, T, 13) - return outputs, patch_coords + return outputs, patch_coords, tokens class Dino3DGPGoalRegressionModule(pl.LightningModule): @@ -810,7 +810,7 @@ def forward(self, batch): ) # (B, N, 4, 4) # Forward through network - outputs, patch_coords = self.network( + outputs, patch_coords, tokens = self.network( rgb, depth, all_intrinsics, @@ -902,7 +902,7 @@ def training_step(self, batch, batch_idx): else: all_pred_dict = [self.predict(batch)] - pred_dict, weighted_displacement = all_pred_dict[0] + pred_dict, weighted_displacement, _ = all_pred_dict[0] pred_dict[self.prediction_type]["all_pred"] = [ i[0][self.prediction_type]["pred"] for i in all_pred_dict ] @@ -992,7 +992,7 @@ def predict(self, batch, progress=False): ) # Forward - outputs, patch_coords = self.network( + outputs, patch_coords, tokens = self.network( rgb, depth, all_intrinsics, @@ -1002,13 +1002,15 @@ def predict(self, batch, progress=False): source=batch["data_source"], ) + latent_repr = tokens.mean(dim=1) # (B, T, D) -> (B, D) + if self.is_gmm: pred = self.sample_from_gmm(outputs, patch_coords) else: pred = self.get_weighted_prediction(outputs, patch_coords) pred_displacement = pred - init - return {self.prediction_type: {"pred": pred_displacement}}, outputs + return {self.prediction_type: {"pred": pred_displacement}}, outputs, latent_repr def sample_from_gmm(self, outputs, patch_coords): """ @@ -1244,7 +1246,7 @@ def validation_step(self, batch, batch_idx, dataloader_idx=0): all_pred_dict.append(self.predict(batch)) else: all_pred_dict = [self.predict(batch)] - pred_dict, weighted_displacement = all_pred_dict[0] + pred_dict, weighted_displacement, _ = all_pred_dict[0] pred_dict[self.prediction_type]["all_pred"] = [ i[0][self.prediction_type]["pred"] for i in all_pred_dict @@ -1360,7 +1362,7 @@ def predict_step(self, batch, batch_idx, dataloader_idx=0): else: all_pred_dict = [self.predict(batch)] - pred_dict, weighted_displacement = all_pred_dict[0] + pred_dict, weighted_displacement, latent_repr = all_pred_dict[0] pred_dict[self.prediction_type]["all_pred"] = [ i[0][self.prediction_type]["pred"] for i in all_pred_dict ] @@ -1431,6 +1433,7 @@ def predict_step(self, batch, batch_idx, dataloader_idx=0): "wta_pix_dist": pred_dict["wta_pix_dist"], "vid_name": batch["vid_name"], "caption": batch["caption"], + "latent_repr": latent_repr, # (B, D) mean-pooled token representation } def on_predict_epoch_end(self): From 24b901d6442b41b867c53123509f4c23bbef34bd Mon Sep 17 00:00:00 2001 From: Sriram Krishna Date: Sat, 3 Jan 2026 13:26:34 -0500 Subject: [PATCH 2/5] add script to analyze latents --- scripts/analyze_latents.py | 160 +++++++++++++++++++++++++++++++++++++ 1 file changed, 160 insertions(+) create mode 100644 scripts/analyze_latents.py diff --git a/scripts/analyze_latents.py b/scripts/analyze_latents.py new file mode 100644 index 0000000..76ab2fa --- /dev/null +++ b/scripts/analyze_latents.py @@ -0,0 +1,160 @@ +#!/usr/bin/env python3 +""" +Analyze and visualize latent representations from two datasets using t-SNE. +Also computes Wasserstein-2 distance between distributions. +""" + +import argparse +from pathlib import Path +from typing import List, Tuple + +import matplotlib.pyplot as plt +import numpy as np +import ot +import torch +from sklearn.manifold import TSNE + + +def load_latents_from_path(dataset_path: Path) -> Tuple[np.ndarray, List[str]]: + """ + Load latent tensors from episode*.pt files in the given path. + + Args: + dataset_path: Path to the dataset directory + + Returns: + Tuple of (concatenated latents array of shape [total_frames, latent_dim], list of episode names) + """ + episode_files = sorted(dataset_path.glob("episode*.pt")) + all_latents, episode_names = [], [] + + for episode_file in episode_files: + latent_tensor = torch.load(episode_file)["latents"] + latent_np = latent_tensor.cpu().numpy() + all_latents.append(latent_np) + episode_names.append(episode_file.name) + + concatenated_latents = np.concatenate(all_latents, axis=0) + return concatenated_latents, episode_names + + +def compute_wasserstein2_distance(X: np.ndarray, Y: np.ndarray) -> float: + """ + Compute the Wasserstein-2 distance between two point clouds. + + Args: + X: First point cloud of shape [n_samples, n_features] + Y: Second point cloud of shape [m_samples, n_features] + + Returns: + Wasserstein-2 distance + """ + # Uniform weights for both distributions + a = np.ones(len(X)) / len(X) + b = np.ones(len(Y)) / len(Y) + + # Compute cost matrix (squared Euclidean distance) + M = ot.dist(X, Y, metric="sqeuclidean") + + # Compute Wasserstein distance squared using EMD + w2_squared = ot.emd2(a, b, M) + + # Return Wasserstein-2 distance (square root) + return np.sqrt(w2_squared) + + +def main(): + parser = argparse.ArgumentParser( + description="Analyze latents using t-SNE visualization" + ) + parser.add_argument( + "--dset1_path", type=str, required=True, help="Path to dataset 1" + ) + parser.add_argument( + "--dset2_path", type=str, required=True, help="Path to dataset 2" + ) + parser.add_argument( + "--perplexity", + type=float, + default=30.0, + help="t-SNE perplexity parameter (default: 30)", + ) + parser.add_argument( + "--n_iter", + type=int, + default=1000, + help="Number of t-SNE iterations (default: 1000)", + ) + parser.add_argument( + "--output", + type=str, + default="latent_tsne.png", + help="Output figure path (default: latent_tsne.png)", + ) + + args = parser.parse_args() + + dset1_path = Path(args.dset1_path) + dset2_path = Path(args.dset2_path) + + dset1_latents, _ = load_latents_from_path(dset1_path) + dset2_latents, _ = load_latents_from_path(dset2_path) + + # Compute Wasserstein-2 distance + w2_dist = compute_wasserstein2_distance(dset1_latents, dset2_latents) + print(f"\nWasserstein-2 distance: {w2_dist:.6f}\n") + + # Combine all latents + all_latents = np.concatenate([dset1_latents, dset2_latents], axis=0) + dset1_labels = np.zeros(len(dset1_latents)) + dset2_labels = np.ones(len(dset2_latents)) + all_labels = np.concatenate([dset1_labels, dset2_labels]) + + tsne = TSNE( + n_components=2, + perplexity=args.perplexity, + n_iter=args.n_iter, + random_state=42, + verbose=1, + ) + embeddings = tsne.fit_transform(all_latents) + + dset1_embeddings = embeddings[all_labels == 0] + dset2_embeddings = embeddings[all_labels == 1] + + plt.figure(figsize=(12, 8)) + plt.scatter( + dset1_embeddings[:, 0], + dset1_embeddings[:, 1], + c="blue", + alpha=0.6, + s=10, + label=f"{args.dset1_path} (n={len(dset1_latents)})", + ) + plt.scatter( + dset2_embeddings[:, 0], + dset2_embeddings[:, 1], + c="red", + alpha=0.6, + s=10, + label=f"{args.dset2_path} (n={len(dset2_latents)})", + ) + plt.xlabel("t-SNE Dimension 1", fontsize=12) + plt.ylabel("t-SNE Dimension 2", fontsize=12) + plt.title( + f"t-SNE Visualization of Latents (W2: {w2_dist:.4f})", + fontsize=14, + fontweight="bold", + ) + plt.legend(fontsize=11) + plt.grid(True, alpha=0.3) + plt.tight_layout() + + # Save figure + plt.savefig(args.output, dpi=300, bbox_inches="tight") + print(f"Figure saved to: {args.output}") + print("\nDone!") + + +if __name__ == "__main__": + main() From 3ffc821123f9eb938e055c0693231f386e33f84e Mon Sep 17 00:00:00 2001 From: Sriram Krishna Date: Sun, 4 Jan 2026 02:18:39 -0500 Subject: [PATCH 3/5] changes for working with droid --- configs/dataset/droidLerobot.yaml | 24 +++++++++++++++++++ configs/training/droidLerobot_dino_3dgp.yaml | 18 ++++++++++++++ src/lfd3d/datasets/lerobot/lerobot_dataset.py | 1 + src/lfd3d/datasets/rgb_text_feature_gen.py | 9 ++++++- src/lfd3d/models/dino_3dgp.py | 11 +++++++-- 5 files changed, 60 insertions(+), 3 deletions(-) create mode 100644 configs/dataset/droidLerobot.yaml create mode 100644 configs/training/droidLerobot_dino_3dgp.yaml diff --git a/configs/dataset/droidLerobot.yaml b/configs/dataset/droidLerobot.yaml new file mode 100644 index 0000000..1cbd27d --- /dev/null +++ b/configs/dataset/droidLerobot.yaml @@ -0,0 +1,24 @@ +defaults: + - base_dataset # Inherit from base +name: rpadLerobot + +# The LeRobot dataset repo containing post-processed goals. +# # Overriding this means you have to do LEROBOT_HOME / repo_id directly in data_dir. +data_dir: null # don't want to override automatic computation, unless you do... +repo_id: sriramsk/droid_lerobot +use_subgoals: False + +val_episode_ratio: 0.05 # percent of episodes in val set + +# Multi-camera configuration (first camera is primary) +cameras: + - name: cam_1 + color_key: "observation.images.cam_1.color" + depth_key: "observation.images.cam_1.depth" + - name: cam_2 + color_key: "observation.images.cam_2.color" + depth_key: "observation.images.cam_2.depth" + +gripper_pcd_key: "observation.points.gripper_pcds" + +rgb_feat: False # If true, compute DINOv2 features, else just return RGB diff --git a/configs/training/droidLerobot_dino_3dgp.yaml b/configs/training/droidLerobot_dino_3dgp.yaml new file mode 100644 index 0000000..6a98f8c --- /dev/null +++ b/configs/training/droidLerobot_dino_3dgp.yaml @@ -0,0 +1,18 @@ +defaults: + - base_train + +# Override augment_train to image_color_only for dino 3dgp (safe for multiview) +augment_train: "image_color_only" + +epochs: 100 +batch_size: 128 +val_batch_size: 128 + +# ModelCheckpoint configurations +checkpoints: + rmse: + monitor: val/rmse + mode: min + rmse_and_std_combi: + monitor: val/rmse_and_std_combi + mode: min diff --git a/src/lfd3d/datasets/lerobot/lerobot_dataset.py b/src/lfd3d/datasets/lerobot/lerobot_dataset.py index a66babd..562ba1d 100644 --- a/src/lfd3d/datasets/lerobot/lerobot_dataset.py +++ b/src/lfd3d/datasets/lerobot/lerobot_dataset.py @@ -104,6 +104,7 @@ def __init__( self.GRIPPER_IDX = { "aloha": np.array([6, 197, 174]), "human": np.array([343, 763, 60]), + "droid": np.array([356, 232, 16]), "libero_franka": np.array( [0, 1, 2] ), # gripper pcd in dataset: [left right top grasp-center] in agentview; (right gripper, left gripper, top, grasp-center) diff --git a/src/lfd3d/datasets/rgb_text_feature_gen.py b/src/lfd3d/datasets/rgb_text_feature_gen.py index 054d4a5..2cd56e8 100644 --- a/src/lfd3d/datasets/rgb_text_feature_gen.py +++ b/src/lfd3d/datasets/rgb_text_feature_gen.py @@ -7,6 +7,7 @@ Example: python rgb_text_feature_gen.py --dataset hoi4d --input_dir /path/to/hoi4d """ + import argparse import json import os @@ -126,7 +127,13 @@ def get_siglip_text_embedding( ) # Process text input - inputs = siglip_processor(text=[caption], return_tensors="pt", padding=True) + inputs = siglip_processor( + text=[caption], + return_tensors="pt", + padding=True, + truncation=True, + max_length=64, + ) inputs = {k: v.to(device) for k, v in inputs.items()} # Generate embeddings diff --git a/src/lfd3d/models/dino_3dgp.py b/src/lfd3d/models/dino_3dgp.py index ae50ec0..0bc3e47 100644 --- a/src/lfd3d/models/dino_3dgp.py +++ b/src/lfd3d/models/dino_3dgp.py @@ -156,8 +156,15 @@ def __init__(self, model_cfg): self.use_source_token = model_cfg.use_source_token if self.use_source_token: # Learnable embeddings: 0 = human, 1 = robot - self.source_to_idx = {"human": 0, "aloha": 1, "libero_franka": 2} - self.source_embeddings = nn.Embedding(3, self.hidden_dim) + self.source_to_idx = { + "human": 0, + "aloha": 1, + "libero_franka": 2, + "droid": 3, + } + self.source_embeddings = nn.Embedding( + len(self.source_to_idx), self.hidden_dim + ) # Transformer blocks (self-attention only) self.num_layers = model_cfg.num_transformer_layers From 996dc3f17df80ca9752b14da29f989a2d3d4f3e0 Mon Sep 17 00:00:00 2001 From: Sriram Krishna Date: Sun, 4 Jan 2026 04:24:49 -0500 Subject: [PATCH 4/5] add optimal transport loss --- scripts/eval_lerobot_episode.py | 2 +- src/lfd3d/models/dino_3dgp.py | 81 ++++++++++++++++++++++++++++++--- 2 files changed, 76 insertions(+), 7 deletions(-) diff --git a/scripts/eval_lerobot_episode.py b/scripts/eval_lerobot_episode.py index f6bd9b5..4445507 100644 --- a/scripts/eval_lerobot_episode.py +++ b/scripts/eval_lerobot_episode.py @@ -204,7 +204,7 @@ def main(cfg): pred_coord = pred["pred_coord"].cpu().numpy().astype(int) # B, 2 metrics["pix_dist"].append(pred["pix_dist"]) metrics["rmse"].append(pred["rmse"]) - episode_latents.append(pred["latent_repr"].cpu()) + episode_latents.append(pred["z"].cpu()) for j in range(batch_size): heatmap_ = generate_heatmap_from_points( diff --git a/src/lfd3d/models/dino_3dgp.py b/src/lfd3d/models/dino_3dgp.py index 0bc3e47..f49d9c8 100644 --- a/src/lfd3d/models/dino_3dgp.py +++ b/src/lfd3d/models/dino_3dgp.py @@ -4,6 +4,7 @@ import cv2 import numpy as np +import ot import pytorch_lightning as pl import torch import torch.nn.functional as F @@ -849,6 +850,7 @@ def forward(self, batch): B, num_components, device=outputs.device, dtype=torch.bool ) + loss_dict = {} # Compute GMM loss if self.is_gmm: loss = 0 @@ -869,6 +871,7 @@ def forward(self, batch): var, use_weights=False, ) + loss_dict["gmm_loss"] = loss else: # Simple MSE loss (if not using GMM) # Get weighted prediction @@ -877,8 +880,73 @@ def forward(self, batch): dim=1 ) loss = F.mse_loss(pred_points, gt) + loss_dict["mse_loss"] = loss - return None, loss + ot_loss = self.ot_loss( + tokens, + embodiment=batch["data_source"], + caption=batch["caption"], + goal_vec=(gt - init)[:, 0, :], + ) + loss_dict["ot_loss"] = ot_loss + alpha = 0.5 + + loss = loss + (alpha * ot_loss) + return None, loss, loss_dict + + def ot_loss(self, tokens, embodiment, caption, goal_vec, lambda_=0.1, epsilon=0.1): + """ + Optimal Transport-based loss for domain adaptation based on EgoBridge. + Aligns distributions of the latent representations of human and robot data. + + Similar latents are expected when we have similar tasks (pick-place, fold) + and the goal vectors (goal_pos - current_pos) are similar. + + For this to work, batch size needs to be reasonably large and contain similar amounts + and types of human and robot data, careful! + """ + # Only compute OT loss if minibatch contains aloha and human data. + if set(embodiment) != {"aloha", "human"}: + return 0.0 + + human_mask = [i == "human" for i in embodiment] + robot_mask = [i == "aloha" for i in embodiment] + n_h, n_r = sum(human_mask), sum(robot_mask) + + # Group the captions by the first word + # [Fold the onesie, Fold the shirt] -> Fold + # Somewhat hacky, should probably do semantic similarity? + task = np.array([c.split(" ")[0] for c in caption]) + task_h, task_r = task[human_mask], task[robot_mask] + task_match = torch.tensor( + task_h[:, None] == task_r[None, :], device=tokens.device + ) + + # Similarity matrix of residual vectors (goal - current) + # Considered a match if the distance is less than the median + res_h, res_r = goal_vec[human_mask], goal_vec[robot_mask] + R = torch.cdist(res_h, res_r) ** 2 + percentile = 0.2 + best_match = (R < R.quantile(percentile)) & task_match + + z = tokens.mean(dim=1) # (B, T, D) -> (B, D) + z_h, z_r = z[human_mask], z[robot_mask] + + # Compute cost matrix of latents + # Normalize cost, discount latents which should align + # Penalize cross-task latents + C = torch.cdist(z_h, z_r) ** 2 + C = C / (C.max() + 1e-8) + C[best_match] *= lambda_ + C[~task_match] = 1.0 + + # Optimal Transport loss + a = torch.ones(n_h, device=tokens.device) / n_h + b = torch.ones(n_r, device=tokens.device) / n_r + T = ot.sinkhorn(a, b, C, reg=epsilon) + loss = (T * C).sum() + + return loss def training_step(self, batch, batch_idx): """Training step with 3D GMM prediction""" @@ -889,8 +957,9 @@ def training_step(self, batch, batch_idx): self.train() batch_size = batch[self.label_key].points_padded().shape[0] - _, loss = self(batch) + _, loss, loss_dict = self(batch) train_metrics = {"loss": loss} + train_metrics.update(loss_dict) # Additional logging do_additional_logging = ( @@ -1009,7 +1078,7 @@ def predict(self, batch, progress=False): source=batch["data_source"], ) - latent_repr = tokens.mean(dim=1) # (B, T, D) -> (B, D) + z = tokens.mean(dim=1) # (B, T, D) -> (B, D) if self.is_gmm: pred = self.sample_from_gmm(outputs, patch_coords) @@ -1017,7 +1086,7 @@ def predict(self, batch, progress=False): pred = self.get_weighted_prediction(outputs, patch_coords) pred_displacement = pred - init - return {self.prediction_type: {"pred": pred_displacement}}, outputs, latent_repr + return {self.prediction_type: {"pred": pred_displacement}}, outputs, z def sample_from_gmm(self, outputs, patch_coords): """ @@ -1369,7 +1438,7 @@ def predict_step(self, batch, batch_idx, dataloader_idx=0): else: all_pred_dict = [self.predict(batch)] - pred_dict, weighted_displacement, latent_repr = all_pred_dict[0] + pred_dict, weighted_displacement, z = all_pred_dict[0] pred_dict[self.prediction_type]["all_pred"] = [ i[0][self.prediction_type]["pred"] for i in all_pred_dict ] @@ -1440,7 +1509,7 @@ def predict_step(self, batch, batch_idx, dataloader_idx=0): "wta_pix_dist": pred_dict["wta_pix_dist"], "vid_name": batch["vid_name"], "caption": batch["caption"], - "latent_repr": latent_repr, # (B, D) mean-pooled token representation + "z": z, # (B, D) mean-pooled token representation } def on_predict_epoch_end(self): From 91b7af05debcb8a162d49b6461f4c7833820270e Mon Sep 17 00:00:00 2001 From: Sriram Krishna Date: Tue, 6 Jan 2026 12:38:21 -0500 Subject: [PATCH 5/5] fix logging, update hparams and make configurable --- configs/model/dino_3dgp.yaml | 7 +++++ src/lfd3d/models/dino_3dgp.py | 48 ++++++++++++++++++++++------------- 2 files changed, 37 insertions(+), 18 deletions(-) diff --git a/configs/model/dino_3dgp.yaml b/configs/model/dino_3dgp.yaml index f9e2e5e..4f2a4b1 100644 --- a/configs/model/dino_3dgp.yaml +++ b/configs/model/dino_3dgp.yaml @@ -19,6 +19,13 @@ is_gmm: True # Train a GMM and minimize negative log likelihood instead fixed_variance: [0.01, 0.05, 0.1, 0.25, 0.5] # for gmm uniform_weights_coeff: 0.1 # coefficient for nll loss term when we use uniform mixing weights instead of pred +# Optimal Transport loss settings (for domain adaptation) +use_ot_loss: True # Enable optimal transport loss for aligning human/robot latent distributions +ot_alpha: 0.05 # Weight for combining OT loss with main loss +ot_lambda: 0.1 # Discount factor for matching latents (lower = stronger discount) +ot_epsilon: 0.1 # Regularization parameter for Sinkhorn algorithm +ot_percentile: 0.1 # Percentile threshold for determining best matches based on goal similarity + # Model-specific training augmentations image_token_dropout: True # Enable image token dropout during training gripper_noise_prob: 0.4 # Probability of applying gripper noise augmentation diff --git a/src/lfd3d/models/dino_3dgp.py b/src/lfd3d/models/dino_3dgp.py index f49d9c8..ba315f6 100644 --- a/src/lfd3d/models/dino_3dgp.py +++ b/src/lfd3d/models/dino_3dgp.py @@ -502,6 +502,13 @@ def __init__(self, network, cfg) -> None: self.uniform_weights_coeff = cfg.model.uniform_weights_coeff self.is_gmm = cfg.model.is_gmm + # Optimal Transport loss parameters + self.use_ot_loss = cfg.model.use_ot_loss + self.ot_alpha = cfg.model.ot_alpha + self.ot_lambda = cfg.model.ot_lambda + self.ot_epsilon = cfg.model.ot_epsilon + self.ot_percentile = cfg.model.ot_percentile + # Gripper noise augmentation parameters self.gripper_noise_prob = cfg.model.gripper_noise_prob self.gripper_noise_translation = cfg.model.gripper_noise_translation @@ -882,19 +889,19 @@ def forward(self, batch): loss = F.mse_loss(pred_points, gt) loss_dict["mse_loss"] = loss - ot_loss = self.ot_loss( - tokens, - embodiment=batch["data_source"], - caption=batch["caption"], - goal_vec=(gt - init)[:, 0, :], - ) - loss_dict["ot_loss"] = ot_loss - alpha = 0.5 + if self.use_ot_loss: + ot_loss = self.ot_loss( + tokens, + embodiment=batch["data_source"], + caption=batch["caption"], + goal_vec=(gt - init)[:, 0, :], + ) + loss_dict["ot_loss"] = ot_loss + loss = loss + (self.ot_alpha * ot_loss) - loss = loss + (alpha * ot_loss) return None, loss, loss_dict - def ot_loss(self, tokens, embodiment, caption, goal_vec, lambda_=0.1, epsilon=0.1): + def ot_loss(self, tokens, embodiment, caption, goal_vec): """ Optimal Transport-based loss for domain adaptation based on EgoBridge. Aligns distributions of the latent representations of human and robot data. @@ -923,27 +930,24 @@ def ot_loss(self, tokens, embodiment, caption, goal_vec, lambda_=0.1, epsilon=0. ) # Similarity matrix of residual vectors (goal - current) - # Considered a match if the distance is less than the median + # Considered a match if the distance is less than the percentile threshold res_h, res_r = goal_vec[human_mask], goal_vec[robot_mask] R = torch.cdist(res_h, res_r) ** 2 - percentile = 0.2 - best_match = (R < R.quantile(percentile)) & task_match + best_match = (R < R.quantile(self.ot_percentile)) & task_match z = tokens.mean(dim=1) # (B, T, D) -> (B, D) z_h, z_r = z[human_mask], z[robot_mask] # Compute cost matrix of latents - # Normalize cost, discount latents which should align - # Penalize cross-task latents C = torch.cdist(z_h, z_r) ** 2 C = C / (C.max() + 1e-8) - C[best_match] *= lambda_ - C[~task_match] = 1.0 + C[best_match] *= self.ot_lambda # Discount latents which should align + C[~task_match] /= self.ot_lambda # Penalize cross-task # Optimal Transport loss a = torch.ones(n_h, device=tokens.device) / n_h b = torch.ones(n_r, device=tokens.device) / n_r - T = ot.sinkhorn(a, b, C, reg=epsilon) + T = ot.sinkhorn(a, b, C, reg=self.ot_epsilon) loss = (T * C).sum() return loss @@ -1281,6 +1285,14 @@ def mean_metric(metric_name): [x[metric_name].mean() for x in self.train_outputs if metric_name in x] ).mean() + # Log loss_dict components + if any("gmm_loss" in x for x in self.train_outputs): + log_dictionary["train/gmm_loss"] = mean_metric("gmm_loss") + if any("mse_loss" in x for x in self.train_outputs): + log_dictionary["train/mse_loss"] = mean_metric("mse_loss") + if any("ot_loss" in x for x in self.train_outputs): + log_dictionary["train/ot_loss"] = mean_metric("ot_loss") + if any("rmse" in x for x in self.train_outputs): log_dictionary["train/rmse"] = mean_metric("rmse") log_dictionary["train/wta_rmse"] = mean_metric("wta_rmse")