diff --git a/configs/dataset/droidLerobot.yaml b/configs/dataset/droidLerobot.yaml new file mode 100644 index 0000000..1cbd27d --- /dev/null +++ b/configs/dataset/droidLerobot.yaml @@ -0,0 +1,24 @@ +defaults: + - base_dataset # Inherit from base +name: rpadLerobot + +# The LeRobot dataset repo containing post-processed goals. +# # Overriding this means you have to do LEROBOT_HOME / repo_id directly in data_dir. +data_dir: null # don't want to override automatic computation, unless you do... +repo_id: sriramsk/droid_lerobot +use_subgoals: False + +val_episode_ratio: 0.05 # fraction of episodes held out for the val set (0.05 = 5%) + +# Multi-camera configuration (first camera is primary) +cameras: + - name: cam_1 + color_key: "observation.images.cam_1.color" + depth_key: "observation.images.cam_1.depth" + - name: cam_2 + color_key: "observation.images.cam_2.color" + depth_key: "observation.images.cam_2.depth" + +gripper_pcd_key: "observation.points.gripper_pcds" + +rgb_feat: False # If true, compute DINOv2 features, else just return RGB diff --git a/configs/model/dino_3dgp.yaml b/configs/model/dino_3dgp.yaml index f9e2e5e..4f2a4b1 100644 --- a/configs/model/dino_3dgp.yaml +++ b/configs/model/dino_3dgp.yaml @@ -19,6 +19,13 @@ is_gmm: True # Train a GMM and minimize negative log likelihood instead fixed_variance: [0.01, 0.05, 0.1, 0.25, 0.5] # for gmm uniform_weights_coeff: 0.1 # coefficient for nll loss term when we use uniform mixing weights instead of pred +# Optimal Transport loss settings (for domain adaptation) +use_ot_loss: True # Enable optimal transport loss for aligning human/robot latent distributions +ot_alpha: 0.05 # Weight for combining OT loss with main loss +ot_lambda: 0.1 # Discount factor for matching latents (lower = stronger discount); cross-task pairs are penalized by 1/ot_lambda +ot_epsilon: 0.1 # Regularization parameter for Sinkhorn algorithm +ot_percentile: 0.1 # Quantile in [0, 1] of pairwise goal distances; same-task pairs below it count as best matches (0.1 = closest 10%) + # Model-specific training augmentations image_token_dropout: True # Enable 
image token dropout during training gripper_noise_prob: 0.4 # Probability of applying gripper noise augmentation diff --git a/configs/training/droidLerobot_dino_3dgp.yaml b/configs/training/droidLerobot_dino_3dgp.yaml new file mode 100644 index 0000000..6a98f8c --- /dev/null +++ b/configs/training/droidLerobot_dino_3dgp.yaml @@ -0,0 +1,18 @@ +defaults: + - base_train + +# Override augment_train to image_color_only for dino 3dgp (safe for multiview) +augment_train: "image_color_only" + +epochs: 100 +batch_size: 128 +val_batch_size: 128 + +# ModelCheckpoint configurations +checkpoints: + rmse: + monitor: val/rmse + mode: min + rmse_and_std_combi: + monitor: val/rmse_and_std_combi + mode: min diff --git a/pixi.lock b/pixi.lock index 6a99162..239e5f3 100644 --- a/pixi.lock +++ b/pixi.lock @@ -393,6 +393,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/6d/45/59578566b3275b8fd9157885918fcd0c4d74162928a5310926887b856a51/platformdirs-4.3.7-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/88/5f/e351af9a41f866ac3f1fac4ca0613908d9a41741cfcf2228f4ad853b697d/pluggy-1.5.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/f7/60/1974cfdd5bb770568ddc6f89f3e0df4cfdd1acffd5a609dff5e95f48c6e2/portalocker-3.1.1-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/5b/7d/1529014aebb9d5fd54538115886d005d371a624b1ecaf5c2525b45ad0f77/pot-0.9.6.post1-cp310-cp310-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl - pypi: https://files.pythonhosted.org/packages/e3/b7/1d145c985d8be9729672a45b8b8113030ad60dff45dec592efc4e5f5897a/pre_commit-3.3.3-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/6d/7a/dcb10ad171dbffb6dd2122672f69e5b34e9859d9bcc6e7119c3cb2986ca2/proglog-0.1.11-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/ff/c2/ab7d37426c179ceb9aeb109a85cda8948bb269b7561a0be870cc656eefe4/prometheus_client-0.21.1-py3-none-any.whl @@ -3521,7 +3522,7 @@ packages: - pypi: . 
name: lfd3d version: 0.1.0 - sha256: dfca697c602714453dd8cc25600627077dd3034ee8abf63d0dd1f316f2634399 + sha256: d9faf9747d41f86679e4f91b3643d0cc518b6c6c24fd3ebae2a9702b139230c2 requires_dist: - diffusers>=0.26.3 - gif @@ -3549,6 +3550,7 @@ packages: - peft>=0.17,<0.18 - decord>=0.6.0,<0.7 - mink>=0.0.11,<0.0.12 + - pot>=0.9.6.post1,<0.10 - ruff==0.3.6 ; extra == 'develop' - mypy==1.3.0 ; extra == 'develop' - pytest==7.3.2 ; extra == 'develop' @@ -5375,6 +5377,35 @@ packages: - pytest-rerunfailures>=15.0 ; extra == 'tests' - redis ; extra == 'redis' requires_python: '>=3.9' +- pypi: https://files.pythonhosted.org/packages/5b/7d/1529014aebb9d5fd54538115886d005d371a624b1ecaf5c2525b45ad0f77/pot-0.9.6.post1-cp310-cp310-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl + name: pot + version: 0.9.6.post1 + sha256: 48924f34d61b909e68651f3fe9fc1a892c69ae38d3c52bc832f95a28569c0e0e + requires_dist: + - numpy>=1.16 + - scipy>=1.6 + - jax ; extra == 'backend-jax' + - jaxlib ; extra == 'backend-jax' + - tensorflow ; extra == 'backend-tf' + - torch ; extra == 'backend-torch' + - cvxopt ; extra == 'cvxopt' + - scikit-learn ; extra == 'dr' + - pymanopt ; extra == 'dr' + - autograd ; extra == 'dr' + - torch ; extra == 'gnn' + - torch-geometric ; extra == 'gnn' + - matplotlib ; extra == 'plot' + - jax ; extra == 'all' + - jaxlib ; extra == 'all' + - tensorflow ; extra == 'all' + - torch ; extra == 'all' + - cvxopt ; extra == 'all' + - scikit-learn ; extra == 'all' + - pymanopt ; extra == 'all' + - autograd ; extra == 'all' + - torch-geometric ; extra == 'all' + - matplotlib ; extra == 'all' + requires_python: '>=3.7' - pypi: https://files.pythonhosted.org/packages/e3/b7/1d145c985d8be9729672a45b8b8113030ad60dff45dec592efc4e5f5897a/pre_commit-3.3.3-py2.py3-none-any.whl name: pre-commit version: 3.3.3 diff --git a/pyproject.toml b/pyproject.toml index 73ac84f..8dcea03 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,6 +36,7 @@ dependencies = [ "peft>=0.17,<0.18", 
"decord>=0.6.0,<0.7", "mink>=0.0.11,<0.0.12", + "pot>=0.9.6.post1,<0.10", ] [project.optional-dependencies] diff --git a/scripts/analyze_latents.py b/scripts/analyze_latents.py new file mode 100644 index 0000000..76ab2fa --- /dev/null +++ b/scripts/analyze_latents.py @@ -0,0 +1,160 @@ +#!/usr/bin/env python3 +""" +Analyze and visualize latent representations from two datasets using t-SNE. +Also computes Wasserstein-2 distance between distributions. +""" + +import argparse +from pathlib import Path +from typing import List, Tuple + +import matplotlib.pyplot as plt +import numpy as np +import ot +import torch +from sklearn.manifold import TSNE + + +def load_latents_from_path(dataset_path: Path) -> Tuple[np.ndarray, List[str]]: + """ + Load latent tensors from episode*.pt files in the given path. + + Args: + dataset_path: Path to the dataset directory + + Returns: + Tuple of (concatenated latents array of shape [total_frames, latent_dim], list of episode names) + """ + episode_files = sorted(dataset_path.glob("episode*.pt")) + all_latents, episode_names = [], [] + + for episode_file in episode_files: + latent_tensor = torch.load(episode_file)["latents"] + latent_np = latent_tensor.cpu().numpy() + all_latents.append(latent_np) + episode_names.append(episode_file.name) + + concatenated_latents = np.concatenate(all_latents, axis=0) + return concatenated_latents, episode_names + + +def compute_wasserstein2_distance(X: np.ndarray, Y: np.ndarray) -> float: + """ + Compute the Wasserstein-2 distance between two point clouds. 
+ + Args: + X: First point cloud of shape [n_samples, n_features] + Y: Second point cloud of shape [m_samples, n_features] + + Returns: + Wasserstein-2 distance + """ + # Uniform weights for both distributions + a = np.ones(len(X)) / len(X) + b = np.ones(len(Y)) / len(Y) + + # Compute cost matrix (squared Euclidean distance) + M = ot.dist(X, Y, metric="sqeuclidean") + + # Compute Wasserstein distance squared using EMD + w2_squared = ot.emd2(a, b, M) + + # Return Wasserstein-2 distance (square root) + return np.sqrt(w2_squared) + + +def main(): + parser = argparse.ArgumentParser( + description="Analyze latents using t-SNE visualization" + ) + parser.add_argument( + "--dset1_path", type=str, required=True, help="Path to dataset 1" + ) + parser.add_argument( + "--dset2_path", type=str, required=True, help="Path to dataset 2" + ) + parser.add_argument( + "--perplexity", + type=float, + default=30.0, + help="t-SNE perplexity parameter (default: 30)", + ) + parser.add_argument( + "--n_iter", + type=int, + default=1000, + help="Number of t-SNE iterations (default: 1000)", + ) + parser.add_argument( + "--output", + type=str, + default="latent_tsne.png", + help="Output figure path (default: latent_tsne.png)", + ) + + args = parser.parse_args() + + dset1_path = Path(args.dset1_path) + dset2_path = Path(args.dset2_path) + + dset1_latents, _ = load_latents_from_path(dset1_path) + dset2_latents, _ = load_latents_from_path(dset2_path) + + # Compute Wasserstein-2 distance + w2_dist = compute_wasserstein2_distance(dset1_latents, dset2_latents) + print(f"\nWasserstein-2 distance: {w2_dist:.6f}\n") + + # Combine all latents + all_latents = np.concatenate([dset1_latents, dset2_latents], axis=0) + dset1_labels = np.zeros(len(dset1_latents)) + dset2_labels = np.ones(len(dset2_latents)) + all_labels = np.concatenate([dset1_labels, dset2_labels]) + + tsne = TSNE( + n_components=2, + perplexity=args.perplexity, + n_iter=args.n_iter, + random_state=42, + verbose=1, + ) + embeddings = 
tsne.fit_transform(all_latents) + + dset1_embeddings = embeddings[all_labels == 0] + dset2_embeddings = embeddings[all_labels == 1] + + plt.figure(figsize=(12, 8)) + plt.scatter( + dset1_embeddings[:, 0], + dset1_embeddings[:, 1], + c="blue", + alpha=0.6, + s=10, + label=f"{args.dset1_path} (n={len(dset1_latents)})", + ) + plt.scatter( + dset2_embeddings[:, 0], + dset2_embeddings[:, 1], + c="red", + alpha=0.6, + s=10, + label=f"{args.dset2_path} (n={len(dset2_latents)})", + ) + plt.xlabel("t-SNE Dimension 1", fontsize=12) + plt.ylabel("t-SNE Dimension 2", fontsize=12) + plt.title( + f"t-SNE Visualization of Latents (W2: {w2_dist:.4f})", + fontsize=14, + fontweight="bold", + ) + plt.legend(fontsize=11) + plt.grid(True, alpha=0.3) + plt.tight_layout() + + # Save figure + plt.savefig(args.output, dpi=300, bbox_inches="tight") + print(f"Figure saved to: {args.output}") + print("\nDone!") + + +if __name__ == "__main__": + main() diff --git a/scripts/eval_lerobot_episode.py b/scripts/eval_lerobot_episode.py index b7decdb..4445507 100644 --- a/scripts/eval_lerobot_episode.py +++ b/scripts/eval_lerobot_episode.py @@ -1,5 +1,7 @@ +import csv import json import random +from pathlib import Path import cv2 import hydra @@ -166,16 +168,21 @@ def main(cfg): preds = trainer.predict(model, datamodule=eval_datamodule) preds_dict = {tag: {} for tag in eval_datamodule.eval_tags} + exp_dir = Path(cfg.log_dir) / f"{cfg.checkpoint.run_id}_{cfg.dataset.repo_id}" + exp_dir.mkdir(parents=True, exist_ok=True) # exist_ok: re-running eval for the same run/dataset must not raise FileExistsError + all_episode_metrics = [] + loader = eval_datamodule.predict_dataloader() for i, episode_id in enumerate(episode_idx): - heatmaps = [] - raw_heatmaps = [] + heatmaps, raw_heatmaps, episode_latents = [], [], [] metrics = {"pix_dist": [], "rmse": []} if len(episode_idx) == 1: preds = [preds] - for pred, batch in tqdm(zip(preds[i], loader[i]), total=len(loader[i])): + for pred, batch in tqdm( + zip(preds[i], loader[i]), total=len(loader[i]), desc=f"Episode {episode_id}" + ): rgb = batch["rgbs"][:, 
0].cpu().numpy() # B, H, W, 3 batch_size = rgb.shape[0] @@ -197,6 +204,7 @@ def main(cfg): pred_coord = pred["pred_coord"].cpu().numpy().astype(int) # B, 2 metrics["pix_dist"].append(pred["pix_dist"]) metrics["rmse"].append(pred["rmse"]) + episode_latents.append(pred["z"].cpu()) for j in range(batch_size): heatmap_ = generate_heatmap_from_points( @@ -215,19 +223,52 @@ def main(cfg): if len(raw_heatmaps) > 0: save_video( - f"{cfg.log_dir}/episode_{episode_id}_raw_heatmaps_{cfg.model.name}.mp4", + str(exp_dir / f"episode_{episode_id}_raw_heatmaps.mp4"), frames=raw_heatmaps, ) save_video( - f"{cfg.log_dir}/episode_{episode_id}_heatmap_{cfg.model.name}.mp4", + str(exp_dir / f"episode_{episode_id}_heatmap.mp4"), frames=heatmaps, ) + if len(episode_latents) > 0: + episode_latents_tensor = torch.cat(episode_latents, dim=0) # (N_frames, D) + latent_file = exp_dir / f"episode_{episode_id}.pt" + torch.save( + { + "latents": episode_latents_tensor, + "n_frames": episode_latents_tensor.shape[0], + "latent_dim": episode_latents_tensor.shape[1], + }, + latent_file, + ) + + # Compute and store metrics for this episode + episode_metric_dict = {"episode_id": episode_id} for key in metrics: if len(metrics[key]) != 0: metric_val = torch.cat(metrics[key]) - print(f"Mean {key}:", metric_val.mean().item()) - print(f"Std. 
{key}:", metric_val.std().item()) + mean_val = metric_val.mean().item() + std_val = metric_val.std().item() + episode_metric_dict[f"{key}_mean"] = mean_val + episode_metric_dict[f"{key}_std"] = std_val + print( + f"Episode {episode_id} - {key}: mean={mean_val:.4f}, std={std_val:.4f}" + ) + + all_episode_metrics.append(episode_metric_dict) + + # Save metrics to CSV sorted by episode number + if all_episode_metrics: + all_episode_metrics.sort(key=lambda x: x["episode_id"]) + csv_file = exp_dir / "metrics.csv" + fieldnames = list(all_episode_metrics[0].keys()) + with open(csv_file, "w", newline="") as f: + writer = csv.DictWriter(f, fieldnames=fieldnames) + writer.writeheader() + writer.writerows(all_episode_metrics) + print(f"\nSaved metrics to {csv_file}") + print(f"Experiment results saved to: {exp_dir}") if __name__ == "__main__": diff --git a/src/lfd3d/datasets/lerobot/lerobot_dataset.py b/src/lfd3d/datasets/lerobot/lerobot_dataset.py index a66babd..562ba1d 100644 --- a/src/lfd3d/datasets/lerobot/lerobot_dataset.py +++ b/src/lfd3d/datasets/lerobot/lerobot_dataset.py @@ -104,6 +104,7 @@ def __init__( self.GRIPPER_IDX = { "aloha": np.array([6, 197, 174]), "human": np.array([343, 763, 60]), + "droid": np.array([356, 232, 16]), "libero_franka": np.array( [0, 1, 2] ), # gripper pcd in dataset: [left right top grasp-center] in agentview; (right gripper, left gripper, top, grasp-center) diff --git a/src/lfd3d/datasets/rgb_text_feature_gen.py b/src/lfd3d/datasets/rgb_text_feature_gen.py index 054d4a5..2cd56e8 100644 --- a/src/lfd3d/datasets/rgb_text_feature_gen.py +++ b/src/lfd3d/datasets/rgb_text_feature_gen.py @@ -7,6 +7,7 @@ Example: python rgb_text_feature_gen.py --dataset hoi4d --input_dir /path/to/hoi4d """ + import argparse import json import os @@ -126,7 +127,13 @@ def get_siglip_text_embedding( ) # Process text input - inputs = siglip_processor(text=[caption], return_tensors="pt", padding=True) + inputs = siglip_processor( + text=[caption], + 
return_tensors="pt", + padding=True, + truncation=True, + max_length=64, + ) inputs = {k: v.to(device) for k, v in inputs.items()} # Generate embeddings diff --git a/src/lfd3d/models/dino_3dgp.py b/src/lfd3d/models/dino_3dgp.py index f67776a..ba315f6 100644 --- a/src/lfd3d/models/dino_3dgp.py +++ b/src/lfd3d/models/dino_3dgp.py @@ -4,6 +4,7 @@ import cv2 import numpy as np +import ot import pytorch_lightning as pl import torch import torch.nn.functional as F @@ -156,8 +157,15 @@ def __init__(self, model_cfg): self.use_source_token = model_cfg.use_source_token if self.use_source_token: # Learnable embeddings: 0 = human, 1 = robot - self.source_to_idx = {"human": 0, "aloha": 1, "libero_franka": 2} - self.source_embeddings = nn.Embedding(3, self.hidden_dim) + self.source_to_idx = { + "human": 0, + "aloha": 1, + "libero_franka": 2, + "droid": 3, + } + self.source_embeddings = nn.Embedding( + len(self.source_to_idx), self.hidden_dim + ) # Transformer blocks (self-attention only) self.num_layers = model_cfg.num_transformer_layers @@ -468,7 +476,7 @@ def forward( # Predict GMM parameters outputs = self.output_head(tokens) # (B, T, 13) - return outputs, patch_coords + return outputs, patch_coords, tokens class Dino3DGPGoalRegressionModule(pl.LightningModule): @@ -494,6 +502,13 @@ def __init__(self, network, cfg) -> None: self.uniform_weights_coeff = cfg.model.uniform_weights_coeff self.is_gmm = cfg.model.is_gmm + # Optimal Transport loss parameters + self.use_ot_loss = cfg.model.use_ot_loss + self.ot_alpha = cfg.model.ot_alpha + self.ot_lambda = cfg.model.ot_lambda + self.ot_epsilon = cfg.model.ot_epsilon + self.ot_percentile = cfg.model.ot_percentile + # Gripper noise augmentation parameters self.gripper_noise_prob = cfg.model.gripper_noise_prob self.gripper_noise_translation = cfg.model.gripper_noise_translation @@ -810,7 +825,7 @@ def forward(self, batch): ) # (B, N, 4, 4) # Forward through network - outputs, patch_coords = self.network( + outputs, patch_coords, tokens 
= self.network( rgb, depth, all_intrinsics, @@ -842,6 +857,7 @@ def forward(self, batch): B, num_components, device=outputs.device, dtype=torch.bool ) + loss_dict = {} # Compute GMM loss if self.is_gmm: loss = 0 @@ -862,6 +878,7 @@ def forward(self, batch): var, use_weights=False, ) + loss_dict["gmm_loss"] = loss else: # Simple MSE loss (if not using GMM) # Get weighted prediction @@ -870,8 +887,70 @@ def forward(self, batch): dim=1 ) loss = F.mse_loss(pred_points, gt) + loss_dict["mse_loss"] = loss + + if self.use_ot_loss: + ot_loss = self.ot_loss( + tokens, + embodiment=batch["data_source"], + caption=batch["caption"], + goal_vec=(gt - init)[:, 0, :], + ) + loss_dict["ot_loss"] = ot_loss + loss = loss + (self.ot_alpha * ot_loss) + + return None, loss, loss_dict + + def ot_loss(self, tokens, embodiment, caption, goal_vec): + """ + Optimal Transport-based loss for domain adaptation based on EgoBridge. + Aligns distributions of the latent representations of human and robot data. + + Similar latents are expected when we have similar tasks (pick-place, fold) + and the goal vectors (goal_pos - current_pos) are similar. + + For this to work, batch size needs to be reasonably large and contain similar amounts + and types of human and robot data, careful! + """ + # Only compute OT loss if the minibatch contains exactly aloha and human data; otherwise return a zero tensor (not a float) so the logged "ot_loss" still supports tensor ops like .mean(). + if set(embodiment) != {"aloha", "human"}: + return tokens.new_zeros(()) + + human_mask = [i == "human" for i in embodiment] + robot_mask = [i == "aloha" for i in embodiment] + n_h, n_r = sum(human_mask), sum(robot_mask) + + # Group the captions by the first word + # [Fold the onesie, Fold the shirt] -> Fold + # Somewhat hacky, should probably do semantic similarity? 
+ task = np.array([c.split(" ")[0] for c in caption]) + task_h, task_r = task[human_mask], task[robot_mask] + task_match = torch.tensor( + task_h[:, None] == task_r[None, :], device=tokens.device + ) + + # Similarity matrix of residual vectors (goal - current) + # Considered a match if the distance is less than the percentile threshold + res_h, res_r = goal_vec[human_mask], goal_vec[robot_mask] + R = torch.cdist(res_h, res_r) ** 2 + best_match = (R < R.quantile(self.ot_percentile)) & task_match + + z = tokens.mean(dim=1) # (B, T, D) -> (B, D) + z_h, z_r = z[human_mask], z[robot_mask] + + # Compute cost matrix of latents + C = torch.cdist(z_h, z_r) ** 2 + C = C / (C.max() + 1e-8) + C[best_match] *= self.ot_lambda # Discount latents which should align + C[~task_match] /= self.ot_lambda # Penalize cross-task - return None, loss + # Optimal Transport loss + a = torch.ones(n_h, device=tokens.device) / n_h + b = torch.ones(n_r, device=tokens.device) / n_r + T = ot.sinkhorn(a, b, C, reg=self.ot_epsilon) + loss = (T * C).sum() + + return loss def training_step(self, batch, batch_idx): """Training step with 3D GMM prediction""" @@ -882,8 +961,9 @@ def training_step(self, batch, batch_idx): self.train() batch_size = batch[self.label_key].points_padded().shape[0] - _, loss = self(batch) + _, loss, loss_dict = self(batch) train_metrics = {"loss": loss} + train_metrics.update(loss_dict) # Additional logging do_additional_logging = ( @@ -902,7 +982,7 @@ def training_step(self, batch, batch_idx): else: all_pred_dict = [self.predict(batch)] - pred_dict, weighted_displacement = all_pred_dict[0] + pred_dict, weighted_displacement, _ = all_pred_dict[0] pred_dict[self.prediction_type]["all_pred"] = [ i[0][self.prediction_type]["pred"] for i in all_pred_dict ] @@ -992,7 +1072,7 @@ def predict(self, batch, progress=False): ) # Forward - outputs, patch_coords = self.network( + outputs, patch_coords, tokens = self.network( rgb, depth, all_intrinsics, @@ -1002,13 +1082,15 @@ def 
predict(self, batch, progress=False): source=batch["data_source"], ) + z = tokens.mean(dim=1) # (B, T, D) -> (B, D) + if self.is_gmm: pred = self.sample_from_gmm(outputs, patch_coords) else: pred = self.get_weighted_prediction(outputs, patch_coords) pred_displacement = pred - init - return {self.prediction_type: {"pred": pred_displacement}}, outputs + return {self.prediction_type: {"pred": pred_displacement}}, outputs, z def sample_from_gmm(self, outputs, patch_coords): """ @@ -1203,6 +1285,14 @@ def mean_metric(metric_name): [x[metric_name].mean() for x in self.train_outputs if metric_name in x] ).mean() + # Log loss_dict components + if any("gmm_loss" in x for x in self.train_outputs): + log_dictionary["train/gmm_loss"] = mean_metric("gmm_loss") + if any("mse_loss" in x for x in self.train_outputs): + log_dictionary["train/mse_loss"] = mean_metric("mse_loss") + if any("ot_loss" in x for x in self.train_outputs): + log_dictionary["train/ot_loss"] = mean_metric("ot_loss") + if any("rmse" in x for x in self.train_outputs): log_dictionary["train/rmse"] = mean_metric("rmse") log_dictionary["train/wta_rmse"] = mean_metric("wta_rmse") @@ -1244,7 +1334,7 @@ def validation_step(self, batch, batch_idx, dataloader_idx=0): all_pred_dict.append(self.predict(batch)) else: all_pred_dict = [self.predict(batch)] - pred_dict, weighted_displacement = all_pred_dict[0] + pred_dict, weighted_displacement, _ = all_pred_dict[0] pred_dict[self.prediction_type]["all_pred"] = [ i[0][self.prediction_type]["pred"] for i in all_pred_dict @@ -1360,7 +1450,7 @@ def predict_step(self, batch, batch_idx, dataloader_idx=0): else: all_pred_dict = [self.predict(batch)] - pred_dict, weighted_displacement = all_pred_dict[0] + pred_dict, weighted_displacement, z = all_pred_dict[0] pred_dict[self.prediction_type]["all_pred"] = [ i[0][self.prediction_type]["pred"] for i in all_pred_dict ] @@ -1431,6 +1521,7 @@ def predict_step(self, batch, batch_idx, dataloader_idx=0): "wta_pix_dist": 
pred_dict["wta_pix_dist"], "vid_name": batch["vid_name"], "caption": batch["caption"], + "z": z, # (B, D) mean-pooled token representation } def on_predict_epoch_end(self):