Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions configs/model/dino_3dgp.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ name: dino_3dgp

# Model settings
type: cross_displacement # Type of model
use_text_embedding: True # If true, expects (siglip) text embedding to be provided as input and uses as another token
use_text_embedding: True # If true, expects text to be provided as input
use_gripper_token: True # Adds an additional gripper token
use_source_token: True # If true, adds a learnable token for human/robot data source
num_transformer_layers: 4 # num blocks for self atttention
Expand All @@ -20,7 +20,7 @@ fixed_variance: [0.01, 0.05, 0.1, 0.25, 0.5] # for gmm
uniform_weights_coeff: 0.1 # coefficient for nll loss term when we use uniform mixing weights instead of pred

# Optimal Transport loss settings (for domain adaptation)
use_ot_loss: True # Enable optimal transport loss for aligning human/robot latent distributions
use_ot_loss: False # Enable optimal transport loss for aligning human/robot latent distributions
ot_alpha: 0.05 # Weight for combining OT loss with main loss
ot_lambda: 0.1 # Discount factor for matching latents (lower = stronger discount)
ot_epsilon: 0.1 # Regularization parameter for Sinkhorn algorithm
Expand Down
7 changes: 7 additions & 0 deletions configs/model/mimicplay.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
defaults:
- dino_3dgp # Mimicplay baseline is a modified version of Dino3DGP

name: mimicplay

kl_lambda: 1.0 # Weight for KL-div loss
gmm_min_std: 0.0001
15 changes: 15 additions & 0 deletions configs/training/rpadLerobot_mimicplay.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
defaults:
- base_train

# Override augment_train to image_color_only for dino 3dgp (safe for multiview)
augment_train: "image_color_only"

epochs: 100
batch_size: 128
val_batch_size: 128

# ModelCheckpoint configurations
checkpoints:
rmse:
monitor: val/rmse
mode: min
20 changes: 20 additions & 0 deletions scripts/latent_analyze_all.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
python eval_lerobot_episode.py checkpoint.run_id=xa8xefu6 dataset=rpadLerobot dataset.repo_id="sriramsk/fold_onesie_MV_20251025_ss_hg_mini" model=dino_3dgp checkpoint.type=rmse resources.num_workers=32 dataset.val_episode_ratio=1 inference.n_eval_episode=10

python eval_lerobot_episode.py checkpoint.run_id=xa8xefu6 dataset=rpadLerobot dataset.repo_id="sriramsk/fold_onesie_MV_20251210_ss_hg" model=dino_3dgp checkpoint.type=rmse resources.num_workers=32 dataset.val_episode_ratio=1 inference.n_eval_episode=10

python eval_lerobot_episode.py checkpoint.run_id=xa8xefu6 dataset=rpadLerobot dataset.repo_id="sriramsk/fold_onesie_MVHuman_20251210_ss_hg" model=dino_3dgp checkpoint.type=rmse resources.num_workers=32 dataset.val_episode_ratio=1 inference.n_eval_episode=10

python eval_lerobot_episode.py checkpoint.run_id=xa8xefu6 dataset=rpadLerobot dataset.repo_id="sriramsk/fold_shirt_MV_20251210_ss_hg" model=dino_3dgp checkpoint.type=rmse resources.num_workers=32 dataset.val_episode_ratio=1 inference.n_eval_episode=10

python eval_lerobot_episode.py checkpoint.run_id=xa8xefu6 dataset=rpadLerobot dataset.repo_id="sriramsk/fold_towel_forEval_MV_20251210_ss_hg" model=dino_3dgp checkpoint.type=rmse resources.num_workers=32 dataset.val_episode_ratio=1 inference.n_eval_episode=10

python eval_lerobot_episode.py checkpoint.run_id=xa8xefu6 dataset=rpadLerobot dataset.repo_id="sriramsk/fold_towel_MVHuman_20251210_ss_hg" model=dino_3dgp checkpoint.type=rmse resources.num_workers=32 dataset.val_episode_ratio=1 inference.n_eval_episode=10


python analyze_latents.py --dset1_path logs/xa8xefu6_sriramsk/fold_onesie_MV_20251210_ss_hg --dset2_path logs/xa8xefu6_sriramsk/fold_onesie_MVHuman_20251210_ss_hg --output robotOnesie_humanOnesie_withOTv2.png

python analyze_latents.py --dset1_path logs/xa8xefu6_sriramsk/fold_onesie_MV_20251210_ss_hg --dset2_path logs/xa8xefu6_sriramsk/fold_onesie_MV_20251025_ss_hg_mini --output robotOnesie_robotOnesie_withOTv2.png

python analyze_latents.py --dset1_path logs/xa8xefu6_sriramsk/fold_onesie_MV_20251210_ss_hg --dset2_path logs/xa8xefu6_sriramsk/fold_shirt_MV_20251210_ss_hg --output robotOnesie_robotShirt_withOTv2.png

python analyze_latents.py --dset1_path logs/xa8xefu6_sriramsk/fold_towel_forEval_MV_20251210_ss_hg --dset2_path logs/xa8xefu6_sriramsk/fold_towel_MVHuman_20251210_ss_hg --output humanTowel_robotTowel_withOTv2.png
36 changes: 36 additions & 0 deletions src/lfd3d/datasets/lerobot/lerobot_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,39 @@ def _transform_to_world_frame(self, points_cam, T_world_from_cam):

return points_world

def load_gripper_trajectory(self, idx, data_source):
episode_index = self.lerobot_dataset[idx]["episode_index"]
gripper_idx = self.GRIPPER_IDX[data_source]

GRIPPER_PCD_KEY = self.gripper_pcd_key

NUM_FRAMES = 70 # Close to 5 secs into the future at 15fps
NUM_TIMESTEPS = 10

gripper_pcds = []
end_idx = idx + (NUM_FRAMES // NUM_TIMESTEPS)
while (
(end_idx < len(self.lerobot_dataset))
and (self.lerobot_dataset[end_idx]["episode_index"] == episode_index)
and (len(gripper_pcds) < NUM_TIMESTEPS)
):
gripper_pcds.append(
self.lerobot_dataset[end_idx][GRIPPER_PCD_KEY][gripper_idx]
)
end_idx += NUM_FRAMES // NUM_TIMESTEPS

if len(gripper_pcds) < NUM_TIMESTEPS:
last_pcd = (
gripper_pcds[-1]
if len(gripper_pcds) > 0
else self.lerobot_dataset[idx][GRIPPER_PCD_KEY][gripper_idx]
)
while len(gripper_pcds) < NUM_TIMESTEPS:
gripper_pcds.append(last_pcd)

gripper_trajectory = np.stack(gripper_pcds, axis=0)
return gripper_trajectory

def __getitem__(self, index):
# Map the dataset index to the actual LeRobot dataset index using split_indices
actual_index = self.split_indices[index]
Expand All @@ -229,6 +262,8 @@ def __getitem__(self, index):
cam_names,
) = self.load_transition(actual_index)

gripper_trajectory = self.load_gripper_trajectory(actual_index, data_source)

# Load intrinsics and extrinsics for all cameras
all_intrinsics = []
all_extrinsics = []
Expand Down Expand Up @@ -358,6 +393,7 @@ def __getitem__(self, index):
"action_pcd": start_tracks,
"anchor_pcd": start_scene_pcd_world,
"anchor_feat_pcd": start_scene_feat_pcd,
"gripper_trajectory": gripper_trajectory,
# Labels
"cross_displacement": end_tracks - start_tracks,
# Text/metadata
Expand Down
Loading
Loading