diff --git a/configs/model/dino_3dgp.yaml b/configs/model/dino_3dgp.yaml index 4f2a4b1..0417aa5 100644 --- a/configs/model/dino_3dgp.yaml +++ b/configs/model/dino_3dgp.yaml @@ -2,7 +2,7 @@ name: dino_3dgp # Model settings type: cross_displacement # Type of model -use_text_embedding: True # If true, expects (siglip) text embedding to be provided as input and uses as another token +use_text_embedding: True # If true, expects text to be provided as input use_gripper_token: True # Adds an additional gripper token use_source_token: True # If true, adds a learnable token for human/robot data source num_transformer_layers: 4 # num blocks for self atttention @@ -20,7 +20,7 @@ fixed_variance: [0.01, 0.05, 0.1, 0.25, 0.5] # for gmm uniform_weights_coeff: 0.1 # coefficient for nll loss term when we use uniform mixing weights instead of pred # Optimal Transport loss settings (for domain adaptation) -use_ot_loss: True # Enable optimal transport loss for aligning human/robot latent distributions +use_ot_loss: False # Enable optimal transport loss for aligning human/robot latent distributions ot_alpha: 0.05 # Weight for combining OT loss with main loss ot_lambda: 0.1 # Discount factor for matching latents (lower = stronger discount) ot_epsilon: 0.1 # Regularization parameter for Sinkhorn algorithm diff --git a/configs/model/mimicplay.yaml b/configs/model/mimicplay.yaml new file mode 100644 index 0000000..6c43c8c --- /dev/null +++ b/configs/model/mimicplay.yaml @@ -0,0 +1,7 @@ +defaults: + - dino_3dgp # Mimicplay baseline is a modified version of Dino3DGP + +name: mimicplay + +kl_lambda: 1.0 # Weight for KL-div loss +gmm_min_std: 0.0001 diff --git a/configs/training/rpadLerobot_mimicplay.yaml b/configs/training/rpadLerobot_mimicplay.yaml new file mode 100644 index 0000000..806785c --- /dev/null +++ b/configs/training/rpadLerobot_mimicplay.yaml @@ -0,0 +1,15 @@ +defaults: + - base_train + +# Override augment_train to image_color_only for dino 3dgp (safe for multiview) +augment_train: "image_color_only" + +epochs: 100 +batch_size: 128 +val_batch_size: 128 + +# ModelCheckpoint configurations +checkpoints: + rmse: + monitor: val/rmse + mode: min diff --git a/scripts/latent_analyze_all.sh b/scripts/latent_analyze_all.sh new file mode 100755 index 0000000..71b8cbd --- /dev/null +++ b/scripts/latent_analyze_all.sh @@ -0,0 +1,20 @@ +python eval_lerobot_episode.py checkpoint.run_id=xa8xefu6 dataset=rpadLerobot dataset.repo_id="sriramsk/fold_onesie_MV_20251025_ss_hg_mini" model=dino_3dgp checkpoint.type=rmse resources.num_workers=32 dataset.val_episode_ratio=1 inference.n_eval_episode=10 + +python eval_lerobot_episode.py checkpoint.run_id=xa8xefu6 dataset=rpadLerobot dataset.repo_id="sriramsk/fold_onesie_MV_20251210_ss_hg" model=dino_3dgp checkpoint.type=rmse resources.num_workers=32 dataset.val_episode_ratio=1 inference.n_eval_episode=10 + +python eval_lerobot_episode.py checkpoint.run_id=xa8xefu6 dataset=rpadLerobot dataset.repo_id="sriramsk/fold_onesie_MVHuman_20251210_ss_hg" model=dino_3dgp checkpoint.type=rmse resources.num_workers=32 dataset.val_episode_ratio=1 inference.n_eval_episode=10 + +python eval_lerobot_episode.py checkpoint.run_id=xa8xefu6 dataset=rpadLerobot dataset.repo_id="sriramsk/fold_shirt_MV_20251210_ss_hg" model=dino_3dgp checkpoint.type=rmse resources.num_workers=32 dataset.val_episode_ratio=1 inference.n_eval_episode=10 + +python eval_lerobot_episode.py checkpoint.run_id=xa8xefu6 dataset=rpadLerobot dataset.repo_id="sriramsk/fold_towel_forEval_MV_20251210_ss_hg" model=dino_3dgp checkpoint.type=rmse resources.num_workers=32 dataset.val_episode_ratio=1 inference.n_eval_episode=10 + +python eval_lerobot_episode.py checkpoint.run_id=xa8xefu6 dataset=rpadLerobot dataset.repo_id="sriramsk/fold_towel_MVHuman_20251210_ss_hg" model=dino_3dgp checkpoint.type=rmse resources.num_workers=32 dataset.val_episode_ratio=1 inference.n_eval_episode=10 + + +python analyze_latents.py --dset1_path logs/xa8xefu6_sriramsk/fold_onesie_MV_20251210_ss_hg --dset2_path logs/xa8xefu6_sriramsk/fold_onesie_MVHuman_20251210_ss_hg --output robotOnesie_humanOnesie_withOTv2.png + +python analyze_latents.py --dset1_path logs/xa8xefu6_sriramsk/fold_onesie_MV_20251210_ss_hg --dset2_path logs/xa8xefu6_sriramsk/fold_onesie_MV_20251025_ss_hg_mini --output robotOnesie_robotOnesie_withOTv2.png + +python analyze_latents.py --dset1_path logs/xa8xefu6_sriramsk/fold_onesie_MV_20251210_ss_hg --dset2_path logs/xa8xefu6_sriramsk/fold_shirt_MV_20251210_ss_hg --output robotOnesie_robotShirt_withOTv2.png + +python analyze_latents.py --dset1_path logs/xa8xefu6_sriramsk/fold_towel_forEval_MV_20251210_ss_hg --dset2_path logs/xa8xefu6_sriramsk/fold_towel_MVHuman_20251210_ss_hg --output humanTowel_robotTowel_withOTv2.png diff --git a/src/lfd3d/datasets/lerobot/lerobot_dataset.py b/src/lfd3d/datasets/lerobot/lerobot_dataset.py index 562ba1d..8ccbc6c 100644 --- a/src/lfd3d/datasets/lerobot/lerobot_dataset.py +++ b/src/lfd3d/datasets/lerobot/lerobot_dataset.py @@ -211,6 +211,39 @@ def _transform_to_world_frame(self, points_cam, T_world_from_cam): return points_world + def load_gripper_trajectory(self, idx, data_source): + episode_index = self.lerobot_dataset[idx]["episode_index"] + gripper_idx = self.GRIPPER_IDX[data_source] + + GRIPPER_PCD_KEY = self.gripper_pcd_key + + NUM_FRAMES = 70 # Close to 5 secs into the future at 15fps + NUM_TIMESTEPS = 10 + + gripper_pcds = [] + end_idx = idx + (NUM_FRAMES // NUM_TIMESTEPS) + while ( + (end_idx < len(self.lerobot_dataset)) + and (self.lerobot_dataset[end_idx]["episode_index"] == episode_index) + and (len(gripper_pcds) < NUM_TIMESTEPS) + ): + gripper_pcds.append( + self.lerobot_dataset[end_idx][GRIPPER_PCD_KEY][gripper_idx] + ) + end_idx += NUM_FRAMES // NUM_TIMESTEPS + + if len(gripper_pcds) < NUM_TIMESTEPS: + last_pcd = ( + gripper_pcds[-1] + if len(gripper_pcds) > 0 + else self.lerobot_dataset[idx][GRIPPER_PCD_KEY][gripper_idx] + ) + while len(gripper_pcds) < NUM_TIMESTEPS: + gripper_pcds.append(last_pcd) + + gripper_trajectory = np.stack(gripper_pcds, axis=0) + return gripper_trajectory + def __getitem__(self, index): # Map the dataset index to the actual LeRobot dataset index using split_indices actual_index = self.split_indices[index] @@ -229,6 +262,8 @@ def __getitem__(self, index): cam_names, ) = self.load_transition(actual_index) + gripper_trajectory = self.load_gripper_trajectory(actual_index, data_source) + # Load intrinsics and extrinsics for all cameras all_intrinsics = [] all_extrinsics = [] @@ -358,6 +393,7 @@ def __getitem__(self, index): "action_pcd": start_tracks, "anchor_pcd": start_scene_pcd_world, "anchor_feat_pcd": start_scene_feat_pcd, + "gripper_trajectory": gripper_trajectory, # Labels "cross_displacement": end_tracks - start_tracks, # Text/metadata diff --git a/src/lfd3d/models/dino_3dgp.py b/src/lfd3d/models/dino_3dgp.py index 729b675..231a3e4 100644 --- a/src/lfd3d/models/dino_3dgp.py +++ b/src/lfd3d/models/dino_3dgp.py @@ -79,9 +79,9 @@ def forward(self, coords): class Dino3DGPNetwork(nn.Module): """ - DINOv2 + 3D positional encoding + Transformer for 3D goal prediction + DINOv3 + 3D positional encoding + Transformer for 3D goal prediction Architecture: - - Image tokens: DINOv2 patches with 3D PE (x,y,z from depth) + - Image tokens: DINOv3 patches with 3D PE (x,y,z from depth) - Language tokens: Flan-T5 (optional) - Gripper token: 6DoF pose + gripper width (optional) - Source token: learnable embedding for human/robot (optional) @@ -92,7 +92,7 @@ class Dino3DGPNetwork(nn.Module): def __init__(self, model_cfg): super(Dino3DGPNetwork, self).__init__() - # DINOv2 backbone + # DINOv3 backbone self.backbone_processor = AutoImageProcessor.from_pretrained( model_cfg.dino_model ) @@ -382,7 +382,7 @@ def forward( """ B, N, C, H, W = image.shape - # Extract DINOv2 features for each camera + # Extract DINOv3 features for each camera all_patch_features = [] for cam_idx in range(N): with torch.no_grad(): @@ -482,7 +482,7 @@ def forward( class Dino3DGPGoalRegressionModule(pl.LightningModule): """ A goal generation module for 3D goal prediction with RGB+Depth. - Similar to articubot.py but uses DINOv2 with RGB+depth instead of PointNet++. + Similar to articubot.py but uses DINOv3 with RGB+depth instead of PointNet++. """ def __init__(self, network, cfg) -> None: @@ -730,6 +730,148 @@ def get_gripper_token(self, gripper_points): gripper_token = torch.cat([gripper_pos, gripper_rot6d, gripper_width], dim=-1) return gripper_token + def combine_camera_data(self, batch): + """ + Combine primary and auxiliary camera data. + + Returns: + rgb: (B, N, 3, H, W) + depth: (B, N, H, W) + all_intrinsics: (B, N, 3, 3) + all_extrinsics: (B, N, 4, 4) + """ + primary_rgb = batch["rgbs"][:, 0] # (B, H, W, 3) + primary_depth = batch["depths"][:, 0] # (B, H, W) + aux_rgbs = batch["aux_rgbs"][:, :, 0, :, :, :] # (B, N_aux, H, W, 3) + aux_depths = batch["aux_depths"][:, :, 0, :, :] # (B, N_aux, H, W) + + # Stack along camera dimension + all_rgbs = torch.cat( + [primary_rgb.unsqueeze(1), aux_rgbs], dim=1 + ) # (B, N, H, W, 3) + all_depths = torch.cat( + [primary_depth.unsqueeze(1), aux_depths], dim=1 + ) # (B, N, H, W) + + # Clip depths + all_depths[all_depths > self.max_depth] = 0 + + # Permute RGB to (B, N, 3, H, W) + rgb = all_rgbs.permute(0, 1, 4, 2, 3) + depth = all_depths + + # Combine intrinsics and extrinsics + all_intrinsics = torch.cat( + [ + batch["intrinsics"].unsqueeze(1), # (B, 1, 3, 3) + batch["aux_intrinsics"], # (B, N_aux, 3, 3) + ], + dim=1, + ) # (B, N, 3, 3) + + all_extrinsics = torch.cat( + [ + batch["extrinsics"].unsqueeze(1), # (B, 1, 4, 4) + batch["aux_extrinsics"], # (B, N_aux, 4, 4) + ], + dim=1, + ) # (B, N, 4, 4) + + return rgb, depth, all_intrinsics, all_extrinsics + + def collect_and_stack_predictions(self, batch, n_samples): + """ + Collect multiple predictions and stack them. + + Returns: + pred_dict: Dictionary with stacked predictions in "all_pred" key + weighted_displacement: Weighted displacement from first prediction + z: Mean-pooled token representation + """ + all_pred_dict = [] + if self.is_gmm: + for i in range(n_samples): + all_pred_dict.append(self.predict(batch)) + else: + all_pred_dict = [self.predict(batch)] + + pred_dict, weighted_displacement, z = all_pred_dict[0] + pred_dict[self.prediction_type]["all_pred"] = [ + i[0][self.prediction_type]["pred"] for i in all_pred_dict + ] + pred_dict[self.prediction_type]["all_pred"] = torch.stack( + pred_dict[self.prediction_type]["all_pred"] + ).permute(1, 0, 2, 3) + + return pred_dict, weighted_displacement, z + + def calculate_pixel_metrics(self, pred_dict, batch, init, gt): + """ + Calculate pixel-based metrics by projecting 3D predictions to 2D. + + Args: + pred_dict: Prediction dictionary to update + batch: Batch data + init: Initial gripper positions (B, 4, 3) + gt: Ground truth gripper positions (B, 4, 3) + + Returns: + Updated pred_dict with pixel metrics + """ + intrinsics = batch["intrinsics"] + extrinsics = batch["extrinsics"] + H, W = batch["rgbs"].shape[2:4] + + # Project GT to 2D (take first point only) + gt_2d = ( + self.project_3d_to_2d(gt[:, :1, :], intrinsics, extrinsics, (H, W)) + .squeeze(1) + .long() + ) # (B, 2) + + # Project all predictions to 2D (take first point only) + all_pred_3d = ( + init[:, None, :, :] + pred_dict[self.prediction_type]["all_pred"] + ) # (B, N, 4, 3) + all_pred_2d = ( + self.project_3d_to_2d( + all_pred_3d[:, :, :1, :], intrinsics, extrinsics, (H, W) + ) + .squeeze(2) + .long() + ) # (B, N, 2) + + pred_dict = calc_pix_metrics(pred_dict, gt_2d, all_pred_2d, (H, W)) + return pred_dict + + def transform_points_homogeneous(self, points, transform_matrix): + """ + Transform 3D points using a 4x4 homogeneous transformation matrix. + + Args: + points: (N, 3) or (M, N, 3) array of 3D points + transform_matrix: (4, 4) transformation matrix + + Returns: + Transformed points with same shape as input + """ + original_shape = points.shape + if points.ndim == 2: + points = points[np.newaxis, ...] + + # Add homogeneous coordinate + points_hom = np.hstack( + (points.reshape(-1, 3), np.ones((points.reshape(-1, 3).shape[0], 1))) + ) + # Transform + points_transformed = (transform_matrix @ points_hom.T).T[:, :3] + + # Reshape back + if len(original_shape) == 2: + return points_transformed + else: + return points_transformed.reshape(original_shape) + def nll_loss( self, pred_displacement, @@ -784,45 +926,7 @@ def forward(self, batch): gripper_token = self.apply_gripper_noise_to_token(gripper_token) # Combine primary + auxiliary cameras - # Primary: batch["rgbs"][:, 0] is (B, H, W, 3), batch["depths"][:, 0] is (B, H, W) - # Auxiliary: batch["aux_rgbs"] is (B, N_aux, 2, H, W, 3), batch["aux_depths"] is (B, N_aux, 2, H, W) - - primary_rgb = batch["rgbs"][:, 0] # (B, H, W, 3) - primary_depth = batch["depths"][:, 0] # (B, H, W) - aux_rgbs = batch["aux_rgbs"][:, :, 0, :, :, :] # (B, N_aux, H, W, 3) - aux_depths = batch["aux_depths"][:, :, 0, :, :] # (B, N_aux, H, W) - - # Stack along camera dimension - all_rgbs = torch.cat( - [primary_rgb.unsqueeze(1), aux_rgbs], dim=1 - ) # (B, N, H, W, 3) - all_depths = torch.cat( - [primary_depth.unsqueeze(1), aux_depths], dim=1 - ) # (B, N, H, W) - - # Clip depths - all_depths[all_depths > self.max_depth] = 0 - - # Permute RGB to (B, N, 3, H, W) - rgb = all_rgbs.permute(0, 1, 4, 2, 3) - depth = all_depths - - # Combine intrinsics and extrinsics - all_intrinsics = torch.cat( - [ - batch["intrinsics"].unsqueeze(1), # (B, 1, 3, 3) - batch["aux_intrinsics"], # (B, N_aux, 3, 3) - ], - dim=1, - ) # (B, N, 3, 3) - - all_extrinsics = torch.cat( - [ - batch["extrinsics"].unsqueeze(1), # (B, 1, 4, 4) - batch["aux_extrinsics"], # (B, N_aux, 4, 4) - ], - dim=1, - ) # (B, N, 4, 4) + rgb, depth, all_intrinsics, all_extrinsics = self.combine_camera_data(batch) # Forward through network outputs, patch_coords, tokens = self.network( @@ -914,7 +1018,7 @@ def ot_loss(self, tokens, embodiment, caption, goal_vec): """ # Only compute OT loss if minibatch contains aloha and human data. if set(embodiment) != {"aloha", "human"}: - return 0.0 + return torch.tensor(0.0, device=tokens.device) human_mask = [i == "human" for i in embodiment] robot_mask = [i == "aloha" for i in embodiment] @@ -975,20 +1079,9 @@ def training_step(self, batch, batch_idx): n_samples_wta = self.run_cfg.n_samples_wta self.eval() with torch.no_grad(): - all_pred_dict = [] - if self.is_gmm: - for i in range(n_samples_wta): - all_pred_dict.append(self.predict(batch)) - else: - all_pred_dict = [self.predict(batch)] - - pred_dict, weighted_displacement, _ = all_pred_dict[0] - pred_dict[self.prediction_type]["all_pred"] = [ - i[0][self.prediction_type]["pred"] for i in all_pred_dict - ] - pred_dict[self.prediction_type]["all_pred"] = torch.stack( - pred_dict[self.prediction_type]["all_pred"] - ).permute(1, 0, 2, 3) + pred_dict, weighted_displacement, _ = ( + self.collect_and_stack_predictions(batch, n_samples_wta) + ) self.train() init, gt = self.extract_gt_4_points(batch) @@ -1005,31 +1098,7 @@ def training_step(self, batch, batch_idx): padding_mask, ) - # Calculate pixel metrics - intrinsics = batch["intrinsics"] - extrinsics = batch["extrinsics"] - H, W = batch["rgbs"].shape[2:4] - - # Project GT to 2D (take first point only) - gt_2d = ( - self.project_3d_to_2d(gt[:, :1, :], intrinsics, extrinsics, (H, W)) - .squeeze(1) - .long() - ) # (B, 2) - - # Project all predictions to 2D (take first point only) - all_pred_3d = ( - init[:, None, :, :] + pred_dict[self.prediction_type]["all_pred"] - ) # (B, N, 4, 3) - all_pred_2d = ( - self.project_3d_to_2d( - all_pred_3d[:, :, :1, :], intrinsics, extrinsics, (H, W) - ) - .squeeze(2) - .long() - ) # (B, N, 2) - - pred_dict = calc_pix_metrics(pred_dict, gt_2d, all_pred_2d, (H, W)) + pred_dict = self.calculate_pixel_metrics(pred_dict, batch, init, gt) train_metrics.update(pred_dict) if self.trainer.is_global_zero: @@ -1047,29 +1116,8 @@ def predict(self, batch, progress=False): init, gt = self.extract_gt_4_points(batch) gripper_token = self.get_gripper_token(init) - # Combine primary + auxiliary cameras (same as forward) - primary_rgb = batch["rgbs"][:, 0] # (B, H, W, 3) - primary_depth = batch["depths"][:, 0] # (B, H, W) - aux_rgbs = batch["aux_rgbs"][:, :, 0, :, :, :] # (B, N_aux, H, W, 3) - aux_depths = batch["aux_depths"][:, :, 0, :, :] # (B, N_aux, H, W) - - all_rgbs = torch.cat( - [primary_rgb.unsqueeze(1), aux_rgbs], dim=1 - ) # (B, N, H, W, 3) - all_depths = torch.cat( - [primary_depth.unsqueeze(1), aux_depths], dim=1 - ) # (B, N, H, W) - - rgb = all_rgbs.permute(0, 1, 4, 2, 3) # (B, N, 3, H, W) - depth = all_depths - - all_intrinsics = torch.cat( - [batch["intrinsics"].unsqueeze(1), batch["aux_intrinsics"]], dim=1 - ) - - all_extrinsics = torch.cat( - [batch["extrinsics"].unsqueeze(1), batch["aux_extrinsics"]], dim=1 - ) + # Combine primary + auxiliary cameras + rgb, depth, all_intrinsics, all_extrinsics = self.combine_camera_data(batch) # Forward outputs, patch_coords, tokens = self.network( @@ -1188,16 +1236,14 @@ def log_viz_to_wandb(self, batch, pred_dict, weighted_displacement, tag): ) # Transform to end frame - pcd_endframe = np.hstack((pcd, np.ones((pcd.shape[0], 1)))) - pcd_endframe = (end2start @ pcd_endframe.T).T[:, :3] - all_pred_pcd_tmp = [] - for i in range(N): - tmp_pcd = np.hstack((all_pred_pcd[i], np.ones((all_pred_pcd.shape[1], 1)))) - tmp_pcd = (end2start @ tmp_pcd.T).T[:, :3] - all_pred_pcd_tmp.append(tmp_pcd) - all_pred_pcd = np.stack(all_pred_pcd_tmp) - gt_pcd = np.hstack((gt_pcd, np.ones((gt_pcd.shape[0], 1)))) - gt_pcd = (end2start @ gt_pcd.T).T[:, :3] + pcd_endframe = self.transform_points_homogeneous(pcd, end2start) + all_pred_pcd = np.stack( + [ + self.transform_points_homogeneous(all_pred_pcd[i], end2start) + for i in range(N) + ] + ) + gt_pcd = self.transform_points_homogeneous(gt_pcd, end2start) # Transform from world frame to primary camera frame for projection # Primary camera extrinsics: T_world_from_cam, we need T_cam_from_world @@ -1206,21 +1252,15 @@ def log_viz_to_wandb(self, batch, pred_dict, weighted_displacement, tag): # Transform points to primary camera frame # Transform initial pcd (for init_rgb_proj) - pcd_cam = np.hstack((pcd, np.ones((pcd.shape[0], 1)))) - pcd_cam = (T_cam_from_world @ pcd_cam.T).T[:, :3] - - pcd_endframe = np.hstack((pcd_endframe, np.ones((pcd_endframe.shape[0], 1)))) - pcd_endframe = (T_cam_from_world @ pcd_endframe.T).T[:, :3] - - all_pred_pcd_tmp = [] - for i in range(N): - tmp_pcd = np.hstack((all_pred_pcd[i], np.ones((all_pred_pcd.shape[1], 1)))) - tmp_pcd = (T_cam_from_world @ tmp_pcd.T).T[:, :3] - all_pred_pcd_tmp.append(tmp_pcd) - all_pred_pcd = np.stack(all_pred_pcd_tmp) - - gt_pcd = np.hstack((gt_pcd, np.ones((gt_pcd.shape[0], 1)))) - gt_pcd = (T_cam_from_world @ gt_pcd.T).T[:, :3] + pcd_cam = self.transform_points_homogeneous(pcd, T_cam_from_world) + pcd_endframe = self.transform_points_homogeneous(pcd_endframe, T_cam_from_world) + all_pred_pcd = np.stack( + [ + self.transform_points_homogeneous(all_pred_pcd[i], T_cam_from_world) + for i in range(N) + ] + ) + gt_pcd = self.transform_points_homogeneous(gt_pcd, T_cam_from_world) K = batch["intrinsics"][viz_idx].cpu().numpy() @@ -1328,20 +1368,9 @@ def validation_step(self, batch, batch_idx, dataloader_idx=0): n_samples_wta = self.run_cfg.n_samples_wta self.eval() with torch.no_grad(): - all_pred_dict = [] - if self.is_gmm: - for i in range(n_samples_wta): - all_pred_dict.append(self.predict(batch)) - else: - all_pred_dict = [self.predict(batch)] - pred_dict, weighted_displacement, _ = all_pred_dict[0] - - pred_dict[self.prediction_type]["all_pred"] = [ - i[0][self.prediction_type]["pred"] for i in all_pred_dict - ] - pred_dict[self.prediction_type]["all_pred"] = torch.stack( - pred_dict[self.prediction_type]["all_pred"] - ).permute(1, 0, 2, 3) + pred_dict, weighted_displacement, _ = self.collect_and_stack_predictions( + batch, n_samples_wta + ) init, gt = self.extract_gt_4_points(batch) gt_displacement = gt - init @@ -1357,31 +1386,7 @@ def validation_step(self, batch, batch_idx, dataloader_idx=0): padding_mask, ) - # Calculate pixel metrics - intrinsics = batch["intrinsics"] - extrinsics = batch["extrinsics"] - H, W = batch["rgbs"].shape[2:4] - - # Project GT to 2D (take first point only) - gt_2d = ( - self.project_3d_to_2d(gt[:, :1, :], intrinsics, extrinsics, (H, W)) - .squeeze(1) - .long() - ) # (B, 2) - - # Project all predictions to 2D (take first point only) - all_pred_3d = ( - init[:, None, :, :] + pred_dict[self.prediction_type]["all_pred"] - ) # (B, N, 4, 3) - all_pred_2d = ( - self.project_3d_to_2d( - all_pred_3d[:, :, :1, :], intrinsics, extrinsics, (H, W) - ) - .squeeze(2) - .long() - ) # (B, N, 2) - - pred_dict = calc_pix_metrics(pred_dict, gt_2d, all_pred_2d, (H, W)) + pred_dict = self.calculate_pixel_metrics(pred_dict, batch, init, gt) self.val_outputs[val_tag].append(pred_dict) if ( @@ -1443,20 +1448,9 @@ def predict_step(self, batch, batch_idx, dataloader_idx=0): eval_tag = self.trainer.datamodule.eval_tags[dataloader_idx] n_samples_wta = self.trainer.datamodule.n_samples_wta - all_pred_dict = [] - if self.is_gmm: - for i in range(n_samples_wta): - all_pred_dict.append(self.predict(batch)) - else: - all_pred_dict = [self.predict(batch)] - - pred_dict, weighted_displacement, z = all_pred_dict[0] - pred_dict[self.prediction_type]["all_pred"] = [ - i[0][self.prediction_type]["pred"] for i in all_pred_dict - ] - pred_dict[self.prediction_type]["all_pred"] = torch.stack( - pred_dict[self.prediction_type]["all_pred"] - ).permute(1, 0, 2, 3) + pred_dict, weighted_displacement, z = self.collect_and_stack_predictions( + batch, n_samples_wta + ) init, gt = self.extract_gt_4_points(batch) gt_displacement = gt - init @@ -1472,31 +1466,11 @@ def predict_step(self, batch, batch_idx, dataloader_idx=0): padding_mask, ) - # Calculate pixel metrics intrinsics = batch["intrinsics"] extrinsics = batch["extrinsics"] H, W = batch["rgbs"].shape[2:4] - # Project GT to 2D (take first point only) - gt_2d = ( - self.project_3d_to_2d(gt[:, :1, :], intrinsics, extrinsics, (H, W)) - .squeeze(1) - .long() - ) # (B, 2) - - # Project all predictions to 2D (take first point only) - all_pred_3d = ( - init[:, None, :, :] + pred_dict[self.prediction_type]["all_pred"] - ) # (B, N, 4, 3) - all_pred_2d = ( - self.project_3d_to_2d( - all_pred_3d[:, :, :1, :], intrinsics, extrinsics, (H, W) - ) - .squeeze(2) - .long() - ) # (B, N, 2) - - pred_dict = calc_pix_metrics(pred_dict, gt_2d, all_pred_2d, (H, W)) + pred_dict = self.calculate_pixel_metrics(pred_dict, batch, init, gt) self.predict_outputs[eval_tag].append(pred_dict) self.predict_weighted_displacements[eval_tag].append( weighted_displacement.cpu() diff --git a/src/lfd3d/models/mimicplay.py b/src/lfd3d/models/mimicplay.py new file mode 100644 index 0000000..07c8b75 --- /dev/null +++ b/src/lfd3d/models/mimicplay.py @@ -0,0 +1,1074 @@ +import random +import types +from collections import defaultdict +from typing import Dict, List + +import cv2 +import numpy as np +import pytorch_lightning as pl +import torch +import torch.distributions as D +import torch.nn.functional as F +import wandb +from diffusers import get_cosine_schedule_with_warmup +from pytorch3d.transforms import ( + euler_angles_to_matrix, + matrix_to_rotation_6d, + rotation_6d_to_matrix, +) +from torch import nn, optim + +from lfd3d.utils.viz_utils import ( + get_img_and_track_pcd, + interpolate_colors, + invert_augmentation_and_normalization, + project_pcd_on_image, +) + + +def calc_traj_metrics(pred_dict, all_pred, gt_trajectory): + """ + Calculate trajectory metrics and update pred_dict with the keys. + + pred_dict: Dictionary with keys to be updated + all_pred: Predicted trajectories multiple samples + gt_trajectory: GT trajectory + """ + all_rmse = [] + for i in range(all_pred.shape[1]): + pred = all_pred[:, i] + all_rmse.append(((pred - gt_trajectory) ** 2).mean((1, 2)) ** 0.5) + pred_dict["rmse"] = all_rmse[0] + pred_dict["wta_rmse"] = torch.stack(all_rmse).min(0)[0] + return pred_dict + + +def calc_traj_pix_metrics(pred_dict, gt_traj, all_pred_traj, img_shape): + """ + Calculate pixel distance metrics and update pred_dict with the keys. + + pred_dict: Dictionary with keys to be updated + gt_idx: GT pixel trajectory (B, 10, 2) [x, y] + all_pred_idx: Predicted goal pixel trajectories of multiple samples (B, N, 10, 2) + img_shape: (H, W) image shape for normalization + """ + H, W = img_shape + + # Calculate L1 distances for all samples + # gt_idx: (B, 2), all_pred_idx: (B, N, 2) + gt_expanded = gt_traj.unsqueeze(1) # (B, 1, 2) + + # L1 distance in pixel space + pix_distances = torch.sum( + torch.abs(all_pred_traj.float() - gt_expanded.float()), dim=(2, 3) + ) # (B, N) + + # Normalize by max Manhattan distance for 0-1 range + max_dist = H + W + normalized_distances = pix_distances / max_dist + + # Use first sample for single prediction metrics + pred_dict["pix_dist"] = pix_distances[:, 0] # (B,) + pred_dict["normalized_pix_dist"] = normalized_distances[:, 0] # (B,) + + # Winner-takes-all: best sample across all predictions + pred_dict["wta_pix_dist"] = pix_distances.min(dim=1)[0] # (B,) + pred_dict["wta_normalized_pix_dist"] = normalized_distances.min(dim=1)[0] # (B,) + + return pred_dict + + +def monkey_patch_mimicplay(network): + """ + Monkey-patch in alternate functionality to train Mimicplay baseline. + """ + + def mimicplay_forward( + self, + image, + depth, + intrinsics, + extrinsics, + gripper_token=None, + text=None, + source=None, + ): + """ + Modified version of forward() for Dino3DGP + """ + B, N, C, H, W = image.shape + + # Extract DINOv3 features for each camera + all_patch_features = [] + for cam_idx in range(N): + with torch.no_grad(): + cam_image = image[:, cam_idx, :, :, :] # (B, 3, H, W) + inputs = self.backbone_processor(images=cam_image, return_tensors="pt") + inputs = {k: v.to(self.backbone.device) for k, v in inputs.items()} + dino_outputs = self.backbone(**inputs) + + # Get patch features (skip CLS and register tokens) + patch_features = dino_outputs.last_hidden_state[ + :, 5: + ] # (B, 196, dino_hidden_dim) + all_patch_features.append(patch_features) + + # Concatenate features from all cameras: (B, N*196, dino_hidden_dim) + patch_features = torch.cat(all_patch_features, dim=1) + + # Get 3D positional encoding for patches (in world frame) + patch_coords = self.get_patch_centers( + H, W, intrinsics, depth, extrinsics + ) # (B, N*196, 3) + pos_encoding = self.pos_encoder(patch_coords) # (B, N*196, 128) + + # Combine patch features with positional encoding + tokens = torch.cat( + [patch_features, pos_encoding], dim=-1 + ) # (B, N*196, hidden_dim) + + # Apply image token dropout (training only) + tokens, patch_coords = self.apply_image_token_dropout(tokens, patch_coords, N) + + # Number of tokens T <= N*196 + num_patch_tokens = tokens.shape[1] + mask = torch.zeros(B, num_patch_tokens, dtype=torch.bool, device=tokens.device) + + # Add language tokens + if self.use_text_embedding: + text_tokens = self.text_tokenizer( + text, return_tensors="pt", padding=True, truncation=True + ) + text_tokens = { + k: v.to(self.text_encoder.device) for k, v in text_tokens.items() + } + text_embedding = self.text_encoder(**text_tokens).last_hidden_state + + lang_tokens = self.text_proj(text_embedding) # (B, J, hidden_dim) + tokens = torch.cat([tokens, lang_tokens], dim=1) # (B, T+J, hidden_dim) + mask = torch.cat([mask, text_tokens["attention_mask"] == 0], dim=1) + + # Add gripper token + if self.use_gripper_token: + grip_token = self.gripper_encoder(gripper_token).unsqueeze( + 1 + ) # (B, 1, hidden_dim) + tokens = torch.cat([tokens, grip_token], dim=1) # (B, T+J+1, hidden_dim) + mask = torch.cat( + [mask, torch.zeros(B, 1, dtype=torch.bool, device=tokens.device)], dim=1 + ) + + # Add source token + if self.use_source_token: + source_indices = torch.tensor( + [self.source_to_idx[s] for s in source], device=tokens.device + ) + source_token = self.source_embeddings(source_indices).unsqueeze( + 1 + ) # (B, 1, hidden_dim) + tokens = torch.cat([tokens, source_token], dim=1) # (B, T+J+2, hidden_dim) + mask = torch.cat( + [mask, torch.zeros(B, 1, dtype=torch.bool, device=tokens.device)], dim=1 + ) + + tokens = torch.cat([tokens, self.registers.expand(B, -1, -1)], dim=1) + mask = torch.cat( + [ + mask, + torch.zeros( + B, self.num_registers, dtype=torch.bool, device=tokens.device + ), + ], + dim=1, + ) + + # NEW ### + # Add CLS token to hold latent plan + tokens = torch.cat([tokens, self.cls_token.expand(B, -1, -1)], dim=1) + mask = torch.cat( + [ + mask, + torch.zeros(B, 1, dtype=torch.bool, device=tokens.device), + ], + dim=1, + ) + + # Apply transformer blocks + for block in self.transformer_blocks: + tokens = block(tokens, src_key_padding_mask=mask) + + # NEW ### + # Take only the CLS token + latent_plan = tokens[:, -1, :] # (B, hidden_dim) + # Predict GMM parameters + outputs = self.gmm_decoder(latent_plan) + + means = outputs[:, :150].reshape(-1, 5, 30) # (B, 5, 30) + raw_scales = outputs[:, 150:300].reshape(-1, 5, 30) # (B, 5, 30) + logits = outputs[:, 300:305].reshape(-1, 5) # (B, 5) + + scales = F.softplus(raw_scales) + 0.0001 + component_dist = D.Independent(D.Normal(means, scales), 1) + mixture_dist = D.Categorical(logits=logits) + gmm_dist = D.MixtureSameFamily(mixture_dist, component_dist) + + return latent_plan, gmm_dist + + # Add extra params for latent plan and GMM decoder + network.cls_token = nn.Parameter(torch.randn(1, 1, network.hidden_dim) * 0.02) + network.num_modes = 5 + network.pred_timesteps = 10 + # Predict 10-step trajectory of 4th gripper point (xyz only) + # Output: means (5, 30) + scales (5, 30) + logits (5) = 305 total + network.gmm_decoder = nn.Sequential( + nn.Linear(network.hidden_dim, 400), + nn.Softplus(), + nn.Linear(400, 400), + nn.Softplus(), + nn.Linear( + 400, + network.num_modes * (network.pred_timesteps * 3 * 2) + network.num_modes, + ), + # 5 modes × (10 timesteps × 3 coords × 2 [mean+scale]) + 5 logits = 305 + ) + del network.output_head + + # Add in a separate forward pass for MimicPlay + network.mimicplay_forward = types.MethodType(mimicplay_forward, network) + + return network + + +class MimicplayModule(pl.LightningModule): + """ + MimicPlay baseline building on top of Dino3DGP + Monkey-patches a different forward pass into the model + and uses a different loss function. + """ + + def __init__(self, network, cfg) -> None: + super().__init__() + self.network = monkey_patch_mimicplay(network) + self.model_cfg = cfg.model + self.prediction_type = self.model_cfg.type + self.mode = cfg.mode # train or eval + self.val_outputs: defaultdict[str, List[Dict]] = defaultdict(list) + self.train_outputs: List[Dict] = [] + self.predict_outputs: defaultdict[str, List[Dict]] = defaultdict(list) + + # Gripper noise augmentation parameters + self.gripper_noise_prob = cfg.model.gripper_noise_prob + self.gripper_noise_translation = cfg.model.gripper_noise_translation + self.gripper_noise_rotation = cfg.model.gripper_noise_rotation + self.gripper_noise_width = cfg.model.gripper_noise_width + + # KL divergence loss parameters (MimicPlay-style) + self.kl_lambda = cfg.model.kl_lambda + self.min_std = cfg.model.gmm_min_std + + if self.prediction_type != "cross_displacement": + raise ValueError(f"Invalid prediction type: {self.prediction_type}") + self.label_key = "cross_displacement" + + # mode-specific processing + if self.mode == "train": + self.run_cfg = cfg.training + # training-specific params + self.lr = self.run_cfg.lr + self.weight_decay = self.run_cfg.weight_decay + self.num_training_steps = self.run_cfg.num_training_steps + self.lr_warmup_steps = self.run_cfg.lr_warmup_steps + self.additional_train_logging_period = ( + self.run_cfg.additional_train_logging_period + ) + elif self.mode == "eval": + self.run_cfg = cfg.inference + else: + raise ValueError(f"Invalid mode: {self.mode}") + + # data params + self.batch_size = self.run_cfg.batch_size + self.val_batch_size = self.run_cfg.val_batch_size + self.max_depth = cfg.dataset.max_depth + + def configure_optimizers(self): + assert self.mode == "train", "Can only configure optimizers in training mode." + optimizer = optim.AdamW( + self.parameters(), lr=self.lr, weight_decay=self.weight_decay + ) + lr_scheduler = get_cosine_schedule_with_warmup( + optimizer=optimizer, + num_warmup_steps=self.lr_warmup_steps, + num_training_steps=self.num_training_steps, + ) + return { + "optimizer": optimizer, + "lr_scheduler": { + "scheduler": lr_scheduler, + "interval": "step", # Step after every batch + }, + } + + def apply_gripper_noise_to_token(self, gripper_token): + """ + Apply noise to gripper token during training. + + Args: + gripper_token: (B, 10) [pos (3) + rot6d (6) + width (1)] + + Returns: + noisy_token: (B, 10) gripper token with noise applied + """ + if not self.training or random.random() > self.gripper_noise_prob: + return gripper_token + + B = gripper_token.shape[0] + device = gripper_token.device + + # Parse gripper token + pos = gripper_token[:, :3] # (B, 3) + rot6d = gripper_token[:, 3:9] # (B, 6) + width = gripper_token[:, 9:10] # (B, 1) + + # 1. Add translation noise + translation_noise = torch.empty_like(pos).uniform_( + -self.gripper_noise_translation, self.gripper_noise_translation + ) + noisy_pos = pos + translation_noise + + # 2. Add rotation noise + # Convert rot6d to rotation matrix + rot_matrix = rotation_6d_to_matrix(rot6d) # (B, 3, 3) + + # Generate random euler angles (XYZ convention) + euler_noise = torch.empty(B, 3, device=device).uniform_( + -np.deg2rad(self.gripper_noise_rotation), + np.deg2rad(self.gripper_noise_rotation), + ) # (B, 3) in radians + + # Convert euler angles to rotation matrix + R_noise = euler_angles_to_matrix(euler_noise, convention="XYZ") # (B, 3, 3) + + # Apply noise: R_new = R_noise @ R_original + noisy_rot_matrix = torch.bmm(R_noise, rot_matrix) # (B, 3, 3) + + # Convert back to rot6d + noisy_rot6d = matrix_to_rotation_6d(noisy_rot_matrix) # (B, 6) + + # 3. Add gripper width noise + width_noise = torch.empty_like(width).uniform_( + -self.gripper_noise_width, self.gripper_noise_width + ) + noisy_width = width + width_noise + + # Reconstruct token + noisy_token = torch.cat([noisy_pos, noisy_rot6d, noisy_width], dim=1) + return noisy_token + + def extract_gt_4_points(self, batch): + """Extract ground truth goal points (4 gripper points)""" + cross_displacement = batch[self.label_key].points_padded() + initial_gripper = batch["action_pcd"].points_padded() + ground_truth_gripper = initial_gripper + cross_displacement + batch_indices = torch.arange( + ground_truth_gripper.shape[0], device=ground_truth_gripper.device + ).unsqueeze(1) + + # Select specific idxs to compute the loss over + gt_primary_points = ground_truth_gripper[batch_indices, batch["gripper_idx"], :] + # Assumes 0/1 are tips to be averaged + gt_extra_point = (gt_primary_points[:, 0, :] + gt_primary_points[:, 1, :]) / 2 + gt = torch.cat([gt_primary_points, gt_extra_point[:, None, :]], dim=1) + + init_primary_points = initial_gripper[batch_indices, batch["gripper_idx"], :] + init_extra_point = ( + init_primary_points[:, 0, :] + init_primary_points[:, 1, :] + ) / 2 + init = torch.cat([init_primary_points, init_extra_point[:, None, :]], dim=1) + return init, gt + + def extract_gt_trajectory(self, batch): + """Extract ground truth goal points (4 gripper points)""" + gripper_trajectory = batch["gripper_trajectory"] + # Assumes 0/1 are tips to be averaged + gt_trajectory = ( + gripper_trajectory[:, :, 0, :] + gripper_trajectory[:, :, 1, :] + ) / 2 + return gt_trajectory + + def project_3d_to_2d( + self, points_3d_world, intrinsics, extrinsics, img_shape=(224, 224) + ): + """ + Project 3D points in world frame to 2D pixel coordinates. + + Args: + points_3d_world: (B, N, 3) or (B, N, M, 3) 3D points in WORLD frame + intrinsics: (B, 3, 3) camera intrinsics + extrinsics: (B, 4, 4) camera-to-world transformation (T_world_from_cam) + img_shape: (H, W) image shape for clamping + + Returns: + pixel_coords: (B, N, 2) or (B, N, M, 2) pixel coordinates [x, y] + """ + H, W = img_shape + original_shape = points_3d_world.shape + + # Reshape to (B, -1, 3) for batch processing + if len(original_shape) == 4: + B, N, M, _ = original_shape + points_3d_world = points_3d_world.reshape(B, N * M, 3) + else: + B, N, _ = original_shape + + # Transform from world frame to camera frame + # extrinsics is T_world_from_cam, we need T_cam_from_world = inv(T_world_from_cam) + T_cam_from_world = torch.inverse(extrinsics) # (B, 4, 4) + + ones = torch.ones(B, points_3d_world.shape[1], 1, device=points_3d_world.device) + points_world_hom = torch.cat([points_3d_world, ones], dim=-1) # (B, N*M, 4) + + # Apply transformation: (B, 4, 4) @ (B, N*M, 4) -> (B, N*M, 4) -> (B, N*M, 3) + points_3d_cam = torch.einsum( + "bij,bnj->bni", T_cam_from_world, points_world_hom + )[:, :, :3] + + fx = intrinsics[:, 0, 0].unsqueeze(1) # (B, 1) + fy = intrinsics[:, 1, 1].unsqueeze(1) + cx = intrinsics[:, 0, 2].unsqueeze(1) + cy = intrinsics[:, 1, 2].unsqueeze(1) + + # Project: [x, y, z] -> [u, v] + # Add epsilon to avoid division by zero + z = points_3d_cam[:, :, 2].clamp(min=1e-6) + u = (points_3d_cam[:, :, 0] * fx / z + cx).clamp(0, W - 1) # (B, N*M) + v = (points_3d_cam[:, :, 1] * fy / z + cy).clamp(0, H - 1) # (B, N*M) + + pixel_coords = torch.stack([u, v], dim=2) # (B, N*M, 2) + + # Reshape back to original shape + if len(original_shape) == 4: + pixel_coords = pixel_coords.reshape(B, N, M, 2) + + return pixel_coords + + def gripper_points_to_rotation(self, gripper_center, palm_point, finger_point): + # Always use palm->gripper as primary axis (more stable) + forward = gripper_center - palm_point + x_axis = forward / torch.linalg.norm(forward, dim=1, keepdim=True) + + # Use finger relative to the forward direction for secondary axis + finger_vec = gripper_center - finger_point + + # Project finger vector onto plane perpendicular to forward + finger_projected = ( + finger_vec - torch.sum(finger_vec * x_axis, dim=1, keepdim=True) * x_axis + ) + y_axis = finger_projected / torch.linalg.norm( + finger_projected, dim=1, keepdim=True + ) + + # Z completes the frame + z_axis = torch.cross(x_axis, y_axis) + + return torch.stack([x_axis, y_axis, z_axis], dim=-1) + + def get_gripper_token(self, gripper_points): + """ + Extract gripper state as a token (6DoF pose + gripper width). + """ + gripper_pos = (gripper_points[:, 0, :] + gripper_points[:, 1, :]) / 2 + gripper_width = torch.linalg.norm( + gripper_points[:, 0, :] - gripper_points[:, 1, :], dim=1 + )[:, None] + # eef pose, base, right finger + gripper_rot = self.gripper_points_to_rotation( + gripper_pos, gripper_points[:, 2, :], gripper_points[:, 0, :] + ) + gripper_rot6d = matrix_to_rotation_6d(gripper_rot) + + gripper_token = torch.cat([gripper_pos, gripper_rot6d, gripper_width], dim=-1) + return gripper_token + + def combine_camera_data(self, batch): + """ + Combine primary and auxiliary camera data. + + Returns: + rgb: (B, N, 3, H, W) + depth: (B, N, H, W) + all_intrinsics: (B, N, 3, 3) + all_extrinsics: (B, N, 4, 4) + """ + primary_rgb = batch["rgbs"][:, 0] # (B, H, W, 3) + primary_depth = batch["depths"][:, 0] # (B, H, W) + aux_rgbs = batch["aux_rgbs"][:, :, 0, :, :, :] # (B, N_aux, H, W, 3) + aux_depths = batch["aux_depths"][:, :, 0, :, :] # (B, N_aux, H, W) + + # Stack along camera dimension + all_rgbs = torch.cat( + [primary_rgb.unsqueeze(1), aux_rgbs], dim=1 + ) # (B, N, H, W, 3) + all_depths = torch.cat( + [primary_depth.unsqueeze(1), aux_depths], dim=1 + ) # (B, N, H, W) + + # Clip depths + all_depths[all_depths > self.max_depth] = 0 + + # Permute RGB to (B, N, 3, H, W) + rgb = all_rgbs.permute(0, 1, 4, 2, 3) + depth = all_depths + + # Combine intrinsics and extrinsics + all_intrinsics = torch.cat( + [ + batch["intrinsics"].unsqueeze(1), # (B, 1, 3, 3) + batch["aux_intrinsics"], # (B, N_aux, 3, 3) + ], + dim=1, + ) # (B, N, 3, 3) + + all_extrinsics = torch.cat( + [ + batch["extrinsics"].unsqueeze(1), # (B, 1, 4, 4) + batch["aux_extrinsics"], # (B, N_aux, 4, 4) + ], + dim=1, + ) # (B, N, 4, 4) + + return rgb, depth, all_intrinsics, all_extrinsics + + def collect_and_stack_predictions(self, batch, n_samples): + """ + Collect multiple predictions and stack them. + + Returns: + pred_dict: Dictionary with stacked predictions in "all_pred" key + pred_gmm: GMM distribution from first prediction + """ + all_pred_dict = [] + for i in range(n_samples): + all_pred_dict.append(self.predict(batch)) + + pred_dict, pred_gmm = all_pred_dict[0] + pred_dict[self.prediction_type]["all_pred"] = [ + i[0][self.prediction_type]["pred"] for i in all_pred_dict + ] + pred_dict[self.prediction_type]["all_pred"] = torch.stack( + pred_dict[self.prediction_type]["all_pred"] + ).permute(1, 0, 2, 3) + + return pred_dict, pred_gmm + + def calculate_pixel_metrics(self, pred_dict, batch, gt_trajectory): + """ + Calculate pixel-based metrics by projecting 3D predictions to 2D. + + Args: + pred_dict: Prediction dictionary to update + batch: Batch data + gt_trajectory: Ground truth trajectory (B, 10, 3) + + Returns: + Updated pred_dict with pixel metrics + """ + intrinsics = batch["intrinsics"] + extrinsics = batch["extrinsics"] + H, W = batch["rgbs"].shape[2:4] + + # Project GT to 2D + gt_2d = self.project_3d_to_2d( + gt_trajectory, intrinsics, extrinsics, (H, W) + ).long() # (B, 10, 2) + + # Project all predictions to 2D + all_pred_3d = pred_dict[self.prediction_type]["all_pred"] # (B, N, 10, 3) + all_pred_2d = self.project_3d_to_2d( + all_pred_3d, intrinsics, extrinsics, (H, W) + ).long() # (B, N, 10, 2) + + pred_dict = calc_traj_pix_metrics(pred_dict, gt_2d, all_pred_2d, (H, W)) + return pred_dict + + def transform_points_homogeneous(self, points, transform_matrix): + """ + Transform 3D points using a 4x4 homogeneous transformation matrix. + + Args: + points: (N, 3) or (M, N, 3) array of 3D points + transform_matrix: (4, 4) transformation matrix + + Returns: + Transformed points with same shape as input + """ + original_shape = points.shape + if points.ndim == 2: + points = points[np.newaxis, ...] + + # Add homogeneous coordinate + points_hom = np.hstack( + (points.reshape(-1, 3), np.ones((points.reshape(-1, 3).shape[0], 1))) + ) + # Transform + points_transformed = (transform_matrix @ points_hom.T).T[:, :3] + + # Reshape back + if len(original_shape) == 2: + return points_transformed + else: + return points_transformed.reshape(original_shape) + + def kl_loss(self, latent_plan, embodiment): + """ + KL divergence loss between robot and human latent plan distributions. + Domain adaptation loss that encourages human and robot latent plans to align. + Args: + latent_plan: (B, hidden_dim) latent representations from CLS token + embodiment: list of strings ["human", "aloha", ...] indicating data source + """ + # Separate by embodiment + human_mask = torch.tensor( + [e == "human" for e in embodiment], device=latent_plan.device + ) + robot_mask = torch.tensor( + [e == "aloha" for e in embodiment], device=latent_plan.device + ) + + # Need both human and robot samples in the batch + if not (human_mask.any() and robot_mask.any()): + return torch.tensor(0.0, device=latent_plan.device) + + human_latents = latent_plan[human_mask] # (N_h, D) + robot_latents = latent_plan[robot_mask] # (N_r, D) + + # Compute distribution statistics for each embodiment + mu_h = human_latents.mean(dim=0) # (D,) + mu_r = robot_latents.mean(dim=0) # (D,) + + sigma_h = human_latents.std(dim=0) + 1e-6 + sigma_r = robot_latents.std(dim=0) + 1e-6 + + # DKL(Qr || Qh) + kl = ( + 0.5 + * ( + 2 * torch.log(sigma_h / sigma_r) + + (sigma_r**2 + (mu_r - mu_h) ** 2) / (sigma_h**2) + - 1.0 + ).sum() + ) + + return kl + + def forward(self, batch): + """Forward pass with GMM loss""" + init, gt = self.extract_gt_4_points(batch) + + # Get gripper token (6DoF pose + gripper width) + gripper_token = self.get_gripper_token(init) + + # Apply gripper noise augmentation (training only) + gripper_token = self.apply_gripper_noise_to_token(gripper_token) + + # Combine primary + auxiliary cameras + rgb, depth, all_intrinsics, all_extrinsics = self.combine_camera_data(batch) + + # Forward through network + latent_plan, gmm_dist = self.network.mimicplay_forward( + rgb, + depth, + all_intrinsics, + all_extrinsics, + gripper_token=gripper_token, + text=batch["caption"], + source=batch["data_source"], + ) + + gt_trajectory = self.extract_gt_trajectory(batch).reshape(-1, 30) + gmm_loss = -gmm_dist.log_prob(gt_trajectory).mean() + kl_div = self.kl_loss(latent_plan, batch["data_source"]) + + loss = gmm_loss + self.kl_lambda * kl_div + loss_dict = { + "gmm_loss": gmm_loss, + "kl_div": kl_div, + } + + return None, loss, loss_dict + + def training_step(self, batch, batch_idx): + """Training step with 3D GMM prediction""" + assert ( + batch["augment_t"].mean().item() == 0.0 + ), "Disable pcd augmentations when training image model!" + + self.train() + batch_size = batch[self.label_key].points_padded().shape[0] + + _, loss, loss_dict = self(batch) + train_metrics = {"loss": loss} + train_metrics.update(loss_dict) + + # Additional logging + do_additional_logging = ( + self.global_step % self.additional_train_logging_period == 0 + and self.global_step != 0 + ) + + if do_additional_logging: + n_samples_wta = self.run_cfg.n_samples_wta + self.eval() + with torch.no_grad(): + pred_dict, pred_gmm = self.collect_and_stack_predictions( + batch, n_samples_wta + ) + self.train() + + gt_trajectory = self.extract_gt_trajectory(batch) + pred_dict = calc_traj_metrics( + pred_dict, + pred_dict[self.prediction_type]["all_pred"], + gt_trajectory, + ) + + pred_dict = self.calculate_pixel_metrics(pred_dict, batch, gt_trajectory) + train_metrics.update(pred_dict) + + if self.trainer.is_global_zero: + self.log_viz_to_wandb(batch, pred_dict, "train") + + self.train_outputs.append(train_metrics) + return loss + + @torch.no_grad() + def predict(self, batch, progress=False): + """ + Predict 3D goal points using GMM sampling. + Returns displacement from initial gripper position. + """ + init, gt = self.extract_gt_4_points(batch) + gripper_token = self.get_gripper_token(init) + + # Combine primary + auxiliary cameras + rgb, depth, all_intrinsics, all_extrinsics = self.combine_camera_data(batch) + + # Forward + latent_plan, gmm_dist = self.network.mimicplay_forward( + rgb, + depth, + all_intrinsics, + all_extrinsics, + gripper_token=gripper_token, + text=batch["caption"], + source=batch["data_source"], + ) + + pred_traj = gmm_dist.sample().reshape(-1, self.network.pred_timesteps, 3) + return {self.prediction_type: {"pred": pred_traj}}, gmm_dist + + def log_viz_to_wandb(self, batch, pred_dict, tag): + """ + Log 3D visualizations to wandb (similar to articubot.py). + """ + batch_size = batch[self.label_key].points_padded().shape[0] + viz_idx = np.random.randint(0, batch_size) + RED, GREEN, BLUE = (255, 0, 0), (0, 255, 0), (0, 0, 255) + max_depth = self.max_depth + + all_pred = pred_dict[self.prediction_type]["all_pred"][viz_idx].cpu().numpy() + N = all_pred.shape[0] + end2start = np.linalg.inv(batch["start2end"][viz_idx].cpu().numpy()) + + goal_text = batch["caption"][viz_idx] + vid_name = batch["vid_name"][viz_idx] + rmse = pred_dict["rmse"][viz_idx] + + gt_trajectory = self.extract_gt_trajectory(batch) + + pcd, gt = self.extract_gt_4_points(batch) + pcd, gt = pcd.cpu().numpy()[viz_idx], gt.cpu().numpy()[viz_idx] + all_pred_pcd = all_pred + gt_pcd = self.extract_gt_trajectory(batch)[viz_idx].cpu().numpy() + padding_mask = torch.ones(gt_pcd.shape[0]).bool().numpy() + + # Invert augmentation transforms before viz + pcd_mean = batch["pcd_mean"][viz_idx].cpu().numpy() + pcd_std = batch["pcd_std"][viz_idx].cpu().numpy() + R = batch["augment_R"][viz_idx].cpu().numpy() + t = batch["augment_t"][viz_idx].cpu().numpy() + scene_centroid = batch["augment_C"][viz_idx].cpu().numpy() + + pcd = invert_augmentation_and_normalization( + pcd, pcd_mean, pcd_std, R, t, scene_centroid + ) + all_pred_pcd = invert_augmentation_and_normalization( + all_pred_pcd, pcd_mean, pcd_std, R, t, scene_centroid + ) + gt_pcd = invert_augmentation_and_normalization( + gt_pcd, pcd_mean, pcd_std, R, t, scene_centroid + ) + + # Transform to end frame + pcd_endframe = self.transform_points_homogeneous(pcd, end2start) + all_pred_pcd = np.stack( + [ + self.transform_points_homogeneous(all_pred_pcd[i], end2start) + for i in range(N) + ] + ) + gt_pcd = self.transform_points_homogeneous(gt_pcd, end2start) + + # Transform from world frame to primary camera frame for projection + # Primary camera extrinsics: T_world_from_cam, we need T_cam_from_world + primary_extrinsics = batch["extrinsics"][viz_idx].cpu().numpy() # (4, 4) + T_cam_from_world = np.linalg.inv(primary_extrinsics) + + # Transform points to primary camera frame + pcd_endframe = self.transform_points_homogeneous(pcd_endframe, T_cam_from_world) + all_pred_pcd = np.stack( + [ + self.transform_points_homogeneous(all_pred_pcd[i], T_cam_from_world) + for i in range(N) + ] + ) + gt_pcd = self.transform_points_homogeneous(gt_pcd, T_cam_from_world) + + K = batch["intrinsics"][viz_idx].cpu().numpy() + + rgb_init, rgb_end = ( + batch["rgbs"][viz_idx, 0].cpu().numpy(), + batch["rgbs"][viz_idx, 1].cpu().numpy(), + ) + depth_init, depth_end = ( + batch["depths"][viz_idx, 0].cpu().numpy(), + batch["depths"][viz_idx, 1].cpu().numpy(), + ) + + # Project tracks to image with color interpolation + YELLOW = (255, 255, 0) + GREEN = (0, 255, 0) + + # GT trajectory: RED to YELLOW gradient + RED2YELLOW = interpolate_colors(RED, YELLOW, gt_pcd.shape[0]) + init_rgb_proj = project_pcd_on_image( + gt_pcd, padding_mask, rgb_init, K, RED2YELLOW, radius=3 + ) + end_rgb_proj = project_pcd_on_image( + gt_pcd, padding_mask, rgb_end, K, RED2YELLOW, radius=3 + ) + + # Predicted trajectory: BLUE to GREEN gradient + BLUE2GREEN = interpolate_colors(BLUE, GREEN, all_pred_pcd[-1].shape[0]) + pred_rgb_proj = project_pcd_on_image( + all_pred_pcd[-1], padding_mask, rgb_end, K, BLUE2GREEN, radius=3 + ) + rgb_proj_viz = cv2.hconcat([init_rgb_proj, end_rgb_proj, pred_rgb_proj]) + + wandb_proj_img = wandb.Image( + rgb_proj_viz, + caption=f"Left: Initial Frame (GT Track)\n; Middle: Final Frame (GT Track)\n\ + ; Right: Final Frame (Pred Track)\n; Goal Description : {goal_text};\n\ + rmse={rmse};\nvideo path = {vid_name}; ", + ) + + # Create BLUE to GREEN gradients for all predictions + # Each prediction gets a gradient from a shade of blue to corresponding shade of green + BLUES2GREENS = [] + for i in range(N): + start_blue = ( + (int(200 * (1 - i / (N - 1))), int(220 * (1 - i / (N - 1))), 255) + if N > 1 + else (200, 220, 255) + ) + end_green = ( + (int(200 * (1 - i / (N - 1))), 255, int(220 * (1 - i / (N - 1)))) + if N > 1 + else (200, 255, 220) + ) + gradient = interpolate_colors( + start_blue, end_green, all_pred_pcd[i].shape[0] + ) + BLUES2GREENS.append(gradient) + + # Visualize point cloud + viz_pcd, _ = get_img_and_track_pcd( + rgb_end, + depth_end, + K, + padding_mask, + gt_pcd, # repeating twice + gt_pcd, + all_pred_pcd, + GREEN, + RED2YELLOW, + BLUES2GREENS, + max_depth, + 4096, + ) + + viz_dict = { + f"{tag}/track_projected_to_rgb": wandb_proj_img, + f"{tag}/image_and_tracks_pcd": wandb.Object3D(viz_pcd), + "trainer/global_step": self.global_step, + } + + wandb.log(viz_dict) + + def on_train_epoch_end(self): + if len(self.train_outputs) == 0: + return + + log_dictionary = {} + loss = torch.stack([x["loss"] for x in self.train_outputs]).mean() + log_dictionary["train/loss"] = loss + + def mean_metric(metric_name): + return torch.stack( + [x[metric_name].mean() for x in self.train_outputs if metric_name in x] + ).mean() + + # Log loss_dict components + if any("gmm_loss" in x for x in self.train_outputs): + log_dictionary["train/gmm_loss"] = mean_metric("gmm_loss") + if any("kl_div" in x for x in self.train_outputs): + log_dictionary["train/kl_div"] = mean_metric("kl_div") + if any("mse_loss" in x for x in self.train_outputs): + log_dictionary["train/mse_loss"] = mean_metric("mse_loss") + if any("ot_loss" in x for x in self.train_outputs): + log_dictionary["train/ot_loss"] = mean_metric("ot_loss") + + if any("rmse" in x for x in self.train_outputs): + log_dictionary["train/rmse"] = mean_metric("rmse") + log_dictionary["train/wta_rmse"] = mean_metric("wta_rmse") + log_dictionary["train/pix_dist"] = mean_metric("pix_dist") + log_dictionary["train/wta_pix_dist"] = mean_metric("wta_pix_dist") + log_dictionary["train/normalized_pix_dist"] = mean_metric( + "normalized_pix_dist" + ) + log_dictionary["train/wta_normalized_pix_dist"] = mean_metric( + "wta_normalized_pix_dist" + ) + + self.log_dict( + log_dictionary, + add_dataloader_idx=False, + prog_bar=True, + sync_dist=True, + ) + self.train_outputs.clear() + + def on_validation_epoch_start(self): + self.random_val_viz_idx = { + k: random.randint(0, len(v) - 1) + for k, v in self.trainer.val_dataloaders.items() + } + + def validation_step(self, batch, batch_idx, dataloader_idx=0): + """Validation step for 3D goal prediction""" + val_tag = self.trainer.datamodule.val_tags[dataloader_idx] + n_samples_wta = self.run_cfg.n_samples_wta + self.eval() + with torch.no_grad(): + pred_dict, pred_gmm = self.collect_and_stack_predictions( + batch, n_samples_wta + ) + + gt_trajectory = self.extract_gt_trajectory(batch) + pred_dict = calc_traj_metrics( + pred_dict, + pred_dict[self.prediction_type]["all_pred"], + gt_trajectory, + ) + + pred_dict = self.calculate_pixel_metrics(pred_dict, batch, gt_trajectory) + self.val_outputs[val_tag].append(pred_dict) + + if ( + batch_idx == self.random_val_viz_idx[val_tag] + and self.trainer.is_global_zero + ): + self.log_viz_to_wandb(batch, pred_dict, f"val_{val_tag}") + return pred_dict + + def on_validation_epoch_end(self): + log_dict = {} + all_metrics = { + "rmse": [], + "wta_rmse": [], + "pix_dist": [], + "wta_pix_dist": [], + "normalized_pix_dist": [], + "wta_normalized_pix_dist": [], + } + + for val_tag in self.trainer.datamodule.val_tags: + val_outputs = self.val_outputs[val_tag] + tag_metrics = {} + + if len(val_outputs) == 0: + continue + + for metric in all_metrics.keys(): + values = torch.stack([x[metric].mean() for x in val_outputs]).mean() + tag_metrics[metric] = values + all_metrics[metric].append(values) + + for metric, value in tag_metrics.items(): + log_dict[f"val_{val_tag}/{metric}"] = value + + for metric, values in all_metrics.items(): + log_dict[f"val/{metric}"] = torch.stack(values).mean() + + self.log_dict( + log_dict, + add_dataloader_idx=False, + sync_dist=True, + ) + self.val_outputs.clear() + + def predict_step(self, batch, batch_idx, dataloader_idx=0): + """Prediction step for model evaluation""" + eval_tag = self.trainer.datamodule.eval_tags[dataloader_idx] + n_samples_wta = self.trainer.datamodule.n_samples_wta + + pred_dict, pred_gmm = self.collect_and_stack_predictions(batch, n_samples_wta) + + gt_trajectory = self.extract_gt_trajectory(batch) + pred_dict = calc_traj_metrics( + pred_dict, + pred_dict[self.prediction_type]["all_pred"], + gt_trajectory, + ) + + pred_dict = self.calculate_pixel_metrics(pred_dict, batch, gt_trajectory) + self.predict_outputs[eval_tag].append(pred_dict) + + # Get pred_coord for visualization (take first point of trajectory) + intrinsics = batch["intrinsics"] + extrinsics = batch["extrinsics"] + H, W = batch["rgbs"].shape[2:4] + + # First sample prediction, first timestep + pred_first_timestep = pred_dict[self.prediction_type]["pred"][ + :, :1, : + ] # (B, 1, 3) + pred_coord_viz = ( + self.project_3d_to_2d(pred_first_timestep, intrinsics, extrinsics, (H, W)) + .squeeze(1) + .long() + ) # (B, 2) + + return { + "pred_coord": pred_coord_viz, + "rmse": pred_dict["rmse"], + "wta_rmse": pred_dict["wta_rmse"], + "pix_dist": pred_dict["pix_dist"], + "wta_pix_dist": pred_dict["wta_pix_dist"], + "vid_name": batch["vid_name"], + "caption": batch["caption"], + } + + def on_predict_epoch_end(self): + """Stub - implement if needed""" + pass diff --git a/src/lfd3d/utils/script_utils.py b/src/lfd3d/utils/script_utils.py index 7c064e6..1b04549 100644 --- a/src/lfd3d/utils/script_utils.py +++ b/src/lfd3d/utils/script_utils.py @@ -22,6 +22,7 @@ from lfd3d.models.dino_3dgp import Dino3DGPGoalRegressionModule, Dino3DGPNetwork from lfd3d.models.dino_heatmap import DinoHeatmapNetwork, HeatmapSamplerModule from lfd3d.models.diptv3 import DiPTv3, DiPTv3Adapter +from lfd3d.models.mimicplay import MimicplayModule from lfd3d.models.tax3d import ( CrossDisplacementModule, DiffusionTransformerNetwork, @@ -73,6 +74,10 @@ def create_model(cfg): elif cfg.model.name == "dino_3dgp": network_fn = Dino3DGPNetwork module_fn = Dino3DGPGoalRegressionModule + elif cfg.model.name == "mimicplay": + # The MimicPlay baseline is a modified version of Dino3DGP + network_fn = Dino3DGPNetwork + module_fn = MimicplayModule else: raise NotImplementedError(cfg.model.name) diff --git a/src/lfd3d/utils/viz_utils.py b/src/lfd3d/utils/viz_utils.py index 8b7b3dc..31ac57f 100644 --- a/src/lfd3d/utils/viz_utils.py +++ b/src/lfd3d/utils/viz_utils.py @@ -3,18 +3,56 @@ import cv2 import imageio.v3 as iio +import matplotlib.pyplot as plt import numpy as np import torch import trimesh from matplotlib import cm from pytorch3d.ops import sample_farthest_points -import matplotlib.pyplot as plt -def project_pcd_on_image(pcd, mask, image, K, color, return_coords=False): +def interpolate_colors(color_start, color_end, n_points): + """ + Interpolate between two colors. + + Args: + color_start: (R, G, B) starting color + color_end: (R, G, B) ending color + n_points: Number of interpolation points + + Returns: + Array of shape (n_points, 3) with interpolated colors + """ + if n_points == 1: + return np.array([color_start]) + + color_start = np.array(color_start) + color_end = np.array(color_end) + + # Linear interpolation + alphas = np.linspace(0, 1, n_points)[:, None] # (n_points, 1) + colors = (1 - alphas) * color_start + alphas * color_end + + return colors.astype(int) + + +def project_pcd_on_image(pcd, mask, image, K, color, return_coords=False, radius=1): """ Project point cloud onto image, overwrite projected points with the provided colour. + + Args: + pcd: Point cloud array (N, 3) + mask: Boolean mask for points to project + image: RGB image (H, W, 3) + K: Camera intrinsics (3, 3) + color: Either a single color tuple (R, G, B) or a sequence of colors [(R, G, B), ...] + If sequence, length must match number of masked points + return_coords: Whether to return projected coordinates + radius: Circle radius for visualization + + Returns: + viz_image or (coords, viz_image) if return_coords=True """ height, width, ch = image.shape viz_image = image.copy() @@ -26,8 +64,35 @@ def project_pcd_on_image(pcd, mask, image, K, color, return_coords=False): projected_image_coords = projected_points.T.round().astype(int) coords = np.clip(projected_image_coords, 0, [width - 1, height - 1]) - for point in coords: - cv2.circle(viz_image, point, color=color, thickness=-1, radius=1) + + # Check if color is a sequence of colors or a single color + # Single color: tuple/list with 3 elements (R, G, B) + # Sequence of colors: list/array where each element is a color + is_single_color = ( + isinstance(color, (tuple, list)) + and len(color) == 3 + and all(isinstance(c, (int, np.integer)) for c in color) + ) + + if is_single_color: + # Use same color for all points (backward compatible) + for point in coords: + cv2.circle(viz_image, point, color=color, thickness=-1, radius=radius) + else: + # Use different color for each point + colors = np.array(color) + if len(colors) != len(coords): + raise ValueError( + f"Number of colors ({len(colors)}) must match number of points ({len(coords)})" + ) + for point, pt_color in zip(coords, colors): + cv2.circle( + viz_image, + point, + color=tuple(pt_color.tolist()), + thickness=-1, + radius=radius, + ) if return_coords: return coords, viz_image @@ -91,28 +156,76 @@ def get_img_and_track_pcd( max_depth, num_points, ): - init_pcd_color, all_pred_color, gt_color = ( - np.array(init_pcd_color), - np.array(all_pred_color), - np.array(gt_color), + """ + Create a combined point cloud visualization with image points and trajectory points. + + Args: + image: RGB image + depth: Depth map + K: Camera intrinsics + mask: Boolean mask for trajectory points + init_pcd: Initial trajectory points + gt_pcd: Ground truth trajectory points + all_pred_pcd: List of predicted trajectories + init_pcd_color: Single color (R,G,B) or sequence of colors (N_init, 3) + gt_color: Single color (R,G,B) or sequence of colors (N_gt, 3) + all_pred_color: Either: + - Array (N_preds, 3): one color per prediction (backward compatible) + - List of arrays [(N_points, 3), ...]: color sequence per prediction + max_depth: Maximum depth for filtering + num_points: Number of points to sample from image + + Returns: + viz_pcd: Combined point cloud with colors + num_pred_points: Number of prediction points + """ + init_pcd_color = np.array(init_pcd_color) + gt_color = np.array(gt_color) + all_pred_color = ( + np.array(all_pred_color) + if not isinstance(all_pred_color, list) + else all_pred_color ) image_pcd_pts, image_pcd_colors = get_img_pcd( image, depth, K, max_depth, num_points ) + # Process init_pcd colors init_pcd_pts = init_pcd[mask] - init_pcd_color = np.repeat(init_pcd_color[None, :], init_pcd_pts.shape[0], axis=0) + if init_pcd_color.ndim == 1: + # Single color: repeat for all points + init_pcd_color = np.repeat( + init_pcd_color[None, :], init_pcd_pts.shape[0], axis=0 + ) + else: + # Color sequence: apply mask + init_pcd_color = init_pcd_color[mask] + # Process gt_pcd colors gt_pcd_pts = gt_pcd[mask] - gt_color = np.repeat(gt_color[None, :], gt_pcd_pts.shape[0], axis=0) + if gt_color.ndim == 1: + # Single color: repeat for all points + gt_color = np.repeat(gt_color[None, :], gt_pcd_pts.shape[0], axis=0) + else: + # Color sequence: apply mask + gt_color = gt_color[mask] + # Process all_pred_pcd colors all_pred_pcd_pts, all_pred_colors = [], [] for i, pred_pcd in enumerate(all_pred_pcd): pred_pcd_pts = pred_pcd[mask] - pred_color = np.repeat( - all_pred_color[None, i], all_pred_pcd[0][mask].shape[0], axis=0 - ) + + # Check if all_pred_color is list of color sequences or array of single colors + if isinstance(all_pred_color, list): + # List of color sequences + pred_color = all_pred_color[i][mask] + else: + # Array of single colors (backward compatible) + pred_color = np.repeat( + all_pred_color[None, i], pred_pcd_pts.shape[0], axis=0 + ) + all_pred_pcd_pts.append(pred_pcd_pts) all_pred_colors.append(pred_color) @@ -313,12 +426,13 @@ def save_video( print(f"saving video of size {(n, h, w)} to {save_path}") iio.imwrite(save_path, frames, fps=fps, extension=".webm", codec="vp9") + def plot_seq_data(data, title, xlabel, ylabel, path): """ Plot a sequence of numeric data with its mean and variance statistics. - The function creates a line plot of the input data sequence, adds a - horizontal line representing the mean value, and displays the variance + The function creates a line plot of the input data sequence, adds a + horizontal line representing the mean value, and displays the variance in a text box on the plot. The figure is then saved to the given file path. Args: @@ -335,13 +449,18 @@ def plot_seq_data(data, title, xlabel, ylabel, path): avg = float(np.nanmean(data)) std = float(np.nanstd(data)) - ax.axhline(avg, color="red", linestyle="--", linewidth=1.2, label=f"Mean = {avg:.4f}") - ax.text(0.02, 0.95, + ax.axhline( + avg, color="red", linestyle="--", linewidth=1.2, label=f"Mean = {avg:.4f}" + ) + ax.text( + 0.02, + 0.95, f"Std = {std:.4f}", transform=ax.transAxes, fontsize=9, verticalalignment="top", - bbox=dict(boxstyle="round", facecolor="white", alpha=0.6)) + bbox=dict(boxstyle="round", facecolor="white", alpha=0.6), + ) ax.set_title(title) ax.set_xlabel(xlabel) @@ -353,16 +472,26 @@ def plot_seq_data(data, title, xlabel, ylabel, path): fig.savefig(path, dpi=150, bbox_inches="tight") plt.close(fig) + def plot_barchart_with_error(data, error, title, xlabel, ylabel, path): """ Plot data as a bar with an error bar. """ data = np.asarray(data) - error = np.asarray(error) + error = np.asarray(error) x = np.arange(len(data)) fig, ax = plt.subplots() - ax.bar(x, data, yerr=error, capsize=4, width=1.0, align="edge", color="skyblue", edgecolor="black") + ax.bar( + x, + data, + yerr=error, + capsize=4, + width=1.0, + align="edge", + color="skyblue", + edgecolor="black", + ) ax.set_xticks(x + 0.5) ax.set_xticklabels([str(i) for i in x]) @@ -377,7 +506,7 @@ def plot_barchart_with_error(data, error, title, xlabel, ylabel, path): plt.close(fig) -def annotate_video(frames, annotation=None, path = None, fps=30): +def annotate_video(frames, annotation=None, path=None, fps=30): """ Args: frames: np.ndarray of shape (N, H, W, 3) dtype=uint8 @@ -390,20 +519,31 @@ def annotate_video(frames, annotation=None, path = None, fps=30): font = cv2.FONT_HERSHEY_SIMPLEX annotated_frames = [] - pos = [(10, 30), (W - 300, 30), (10, H - 30), (W - 300, H - 30),] + pos = [ + (10, 30), + (W - 300, 30), + (10, H - 30), + (W - 300, H - 30), + ] for i, frame in enumerate(frames): img = frame.copy() if annotation is not None: for j, (key, value) in enumerate(annotation.items()): cv2.putText( - img, f"{key}: {value[i]:.4f}", pos[j], - font, 1, (255, 0, 0), 2, cv2.LINE_AA + img, + f"{key}: {value[i]:.4f}", + pos[j], + font, + 1, + (255, 0, 0), + 2, + cv2.LINE_AA, ) - + annotated_frames.append(img) annotated_frames = np.stack(annotated_frames) if path: iio.imwrite(path, annotated_frames, fps=fps) else: - return annotated_frames \ No newline at end of file + return annotated_frames