Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions configs/dataset/liberoLerobot.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,14 @@ repo_id: sriramsk/libero_lerobot_singleTask_heatmapGoal
use_subgoals: False

max_depth: 2.0 # Filter out PCD beyond this depth
val_episode_ratio: 0.1 # fraction of episodes held out for the val set (0.1 = 10%)

# Camera keys in the dataset
cameras:
- name: cam_libero
color_key: "observation.images.cam_libero.color"
depth_key: "observation.images.cam_libero.transformed_depth"
intrinsics: "libero_franka_calibration/intrinsics.txt"
extrinsics: "libero_franka_calibration/agentview_cam_to_world.txt"


gripper_pcd_key: "observation.points.gripper_pcds"

rgb_feat: False # If true, compute DINOv2 features, else just return RGB
10 changes: 4 additions & 6 deletions configs/dataset/rpadLerobot.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,16 @@ data_dir: null # don't want to override automatic computation, unless you do...
repo_id: beisner/aloha_plate_placement_goal
use_subgoals: False

val_episode_ratio: 0.1 # fraction of episodes held out for the val set (0.1 = 10%)

# Multi-camera configuration (first camera is primary)
cameras:
- name: front
- name: cam_azure_kinect_front
color_key: "observation.images.cam_azure_kinect_front.color"
depth_key: "observation.images.cam_azure_kinect_front.transformed_depth"
intrinsics: "aloha_calibration/intrinsics_000259921812.txt"
extrinsics: "aloha_calibration/T_world_from_camera_front_v1_1020.txt"
- name: back
- name: cam_azure_kinect_back
color_key: "observation.images.cam_azure_kinect_back.color"
depth_key: "observation.images.cam_azure_kinect_back.transformed_depth"
intrinsics: "aloha_calibration/intrinsics_000003493812.txt"
extrinsics: "aloha_calibration/T_world_from_camera_back_v1_1020.txt"

gripper_pcd_key: "observation.points.gripper_pcds"

Expand Down
1 change: 1 addition & 0 deletions configs/inference/base_inference.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ batch_size: 16
val_batch_size: 16
n_samples_wta: 5 # number of samples for evaluating coverage
save_wta_to_disk: False # save to disk for further examination
n_eval_episode: 1

# Augmentation settings
augment_train: null # Options: "pcd", "image", or null
Expand Down
1 change: 0 additions & 1 deletion configs/inference/liberoLerobot_dino_heatmap.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
defaults:
- base_inference
loss_type: "cross_entropy" # rkl_div, mkl_div, fkl_div, mse, cross_entropy
n_eval_episode: 5
12 changes: 2 additions & 10 deletions scripts/eval_lerobot_episode.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,18 +37,10 @@ def predict_dataloader(self):
return self.dataloaders


def random_episode(episode_idx, n):
    """Pick up to ``n`` entries from ``episode_idx`` at random.

    Args:
        episode_idx: Sequence of candidate episode indices.
        n: Number of entries requested.

    Returns:
        list: A new list. When ``n`` exceeds the number of candidates, every
        candidate is returned in its original order; otherwise ``n`` distinct
        entries drawn with ``random.sample``.
    """
    if n > len(episode_idx):
        # Hand back a copy rather than the input itself: random.sample below
        # always builds a fresh list, so both branches now give the caller an
        # object that can be mutated without touching ``episode_idx``.
        return list(episode_idx)
    return random.sample(episode_idx, n)


def get_eval_datamodule_episode(datamodule, inference_cfg):
tags = datamodule.val_tags
eval_dataloaders, eval_tags, episode_idx = [], [], []
episode_num = (
inference_cfg.n_eval_episode if "n_eval_episode" in inference_cfg.keys() else 1
)
episode_num = inference_cfg.n_eval_episode

for i, (tag, loader) in enumerate(datamodule.test_dataloader().items()):
eval_dataloaders.extend(loader.values())
Expand All @@ -60,7 +52,7 @@ def get_eval_datamodule_episode(datamodule, inference_cfg):
# eval_dataloaders.append(loader)
# eval_tags.append(f"val_{tag}")

random_id = random_episode(list(range(0, len(episode_idx))), episode_num)
random_id = random.sample(list(range(0, len(episode_idx))), episode_num)
eval_dataloaders = [eval_dataloaders[i] for i in random_id]
eval_tags = [eval_tags[i] for i in random_id]
episode_idx = [episode_idx[i] for i in random_id]
Expand Down
50 changes: 14 additions & 36 deletions src/lfd3d/datasets/lerobot/lerobot_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,33 +183,6 @@ def load_transition(self, idx):
cam_names,
)

def _load_camera_intrinsics(self, intrinsics_path, data_source):
    """Load camera intrinsics from a text file.

    Args:
        intrinsics_path: Path to the intrinsics file, resolved against the
            parent of this module's directory
            (e.g., "aloha_calibration/intrinsics_xxx.txt").
        data_source: Data source name (e.g., 'aloha', 'human', 'libero_franka').
            Not read by this method; kept so the signature stays unchanged.

    Returns:
        np.ndarray: 3x3 intrinsics matrix.

    Raises:
        ValueError: If the file does not hold exactly 9 values.
    """
    file_path = Path(__file__).parent.parent / intrinsics_path
    # Enforce the documented 3x3 shape: a file that stores the nine values on
    # a single row or column would otherwise come back as a flat (9,) array.
    return np.loadtxt(file_path).reshape(3, 3)

def _load_camera_extrinsics(self, extrinsics_path, data_source):
    """Read a camera-to-world transform (T_world_from_camera) from a text file.

    Args:
        extrinsics_path: Path to the extrinsics file, resolved against the
            parent of this module's directory
            (e.g., "aloha_calibration/T_world_from_camera_xxx.txt").
        data_source: Data source name (e.g., 'aloha', 'human', 'libero_franka').
            Not read by this method.

    Returns:
        np.ndarray: float32 array of shape (4, 4) holding T_world_from_camera.
    """
    calibration_root = Path(__file__).parent.parent
    raw_values = np.loadtxt(calibration_root / extrinsics_path)
    # Cast first, then force the 4x4 layout regardless of how the file is laid out.
    return raw_values.astype(np.float32).reshape(4, 4)

def _transform_to_world_frame(self, points_cam, T_world_from_cam):
"""Transform points from camera frame to world frame.

Expand Down Expand Up @@ -253,16 +226,18 @@ def __getitem__(self, index):
# Load intrinsics and extrinsics for all cameras
all_intrinsics = []
all_extrinsics = []
for cam_cfg in self.cameras:
# Load intrinsics
K = self._load_camera_intrinsics(cam_cfg.intrinsics, data_source)
for cam_name in cam_names:
K = self.lerobot_dataset[actual_index][
f"observation.{cam_name}.intrinsics"
].numpy()
K_scaled = BaseDataset.get_scaled_intrinsics(
K, orig_shape, self.target_shape
)
all_intrinsics.append(K_scaled)

# Load extrinsics
T = self._load_camera_extrinsics(cam_cfg.extrinsics, data_source)
T = self.lerobot_dataset[actual_index][
f"observation.{cam_name}.extrinsics"
].numpy()
all_extrinsics.append(T)

# Gripper tracks
Expand Down Expand Up @@ -331,8 +306,12 @@ def __getitem__(self, index):
aux_extrinsics = np.stack(all_extrinsics[1:], axis=0) # (num_aux, 4, 4)
else:
# No auxiliary cameras
aux_rgbs = np.zeros((0, 2, self.target_shape, self.target_shape, 3), dtype=np.uint8)
aux_depths = np.zeros((0, 2, self.target_shape, self.target_shape), dtype=np.float32)
aux_rgbs = np.zeros(
(0, 2, self.target_shape, self.target_shape, 3), dtype=np.uint8
)
aux_depths = np.zeros(
(0, 2, self.target_shape, self.target_shape), dtype=np.float32
)
aux_intrinsics = np.zeros((0, 3, 3), dtype=np.float32)
aux_extrinsics = np.zeros((0, 4, 4), dtype=np.float32)

Expand Down Expand Up @@ -384,7 +363,6 @@ def __init__(
num_workers,
dataset_cfg,
seed,
val_episode_ratio=0.1,
augment_train="image",
augment_cfg=None,
):
Expand All @@ -400,7 +378,7 @@ def __init__(
self.val_tags = [] # populated in _generate_episode_splits
# Subset of train to use for eval
self.TRAIN_SUBSET_SIZE = 20
self.val_episode_ratio = val_episode_ratio
self.val_episode_ratio = dataset_cfg.val_episode_ratio
self.train_indices = None
self.val_indices = None

Expand Down
55 changes: 35 additions & 20 deletions src/lfd3d/models/dino_3dgp.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
rotation_6d_to_matrix,
)
from torch import nn, optim
from transformers import AutoImageProcessor, AutoModel
from transformers import AutoImageProcessor, AutoModel, T5EncoderModel, T5Tokenizer

from lfd3d.models.dino_heatmap import calc_pix_metrics
from lfd3d.models.tax3d import calc_pcd_metrics
Expand Down Expand Up @@ -81,7 +81,7 @@ class Dino3DGPNetwork(nn.Module):
DINOv2 + 3D positional encoding + Transformer for 3D goal prediction
Architecture:
- Image tokens: DINOv2 patches with 3D PE (x,y,z from depth)
- Language token: SigLIP embedding (optional)
- Language tokens: Flan-T5 (optional)
- Gripper token: 6DoF pose + gripper width (optional)
- Source token: learnable embedding for human/robot (optional)
- Transformer: self-attention blocks
Expand Down Expand Up @@ -131,11 +131,14 @@ def __init__(self, model_cfg):
nn.Linear(128, self.pos_encoding_dim),
)

# Language token encoder
# Language encoder
self.use_text_embedding = model_cfg.use_text_embedding
if self.use_text_embedding:
self.text_encoder = nn.Sequential(
nn.Linear(1152, 256), # SIGLIP input dim
self.text_tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-base")
self.text_encoder = T5EncoderModel.from_pretrained("google/flan-t5-base")
self.text_encoder.requires_grad_(False) # Freeze
self.text_proj = nn.Sequential(
nn.Linear(768, 256), # Flan-T5 output dim
nn.ReLU(),
nn.Linear(256, self.hidden_dim),
)
Expand Down Expand Up @@ -344,7 +347,7 @@ def forward(
intrinsics,
extrinsics,
gripper_token=None,
text_embedding=None,
text=None,
source=None,
):
"""
Expand All @@ -356,8 +359,8 @@ def forward(
intrinsics: (B, N, 3, 3) camera intrinsics
extrinsics: (B, N, 4, 4) camera-to-world transforms
gripper_token: (B, 10) [6DoF pose (3 pos + 6 rot6d) + gripper width]
text_embedding: (B, 1152) SigLIP embedding
source: list of strings ["human" or "aloha"]
text: (B, ) Text captions
source: (B, ) ["human" or "aloha"]

Returns:
outputs: (B, T, 13) GMM parameters for all cameras
Expand Down Expand Up @@ -399,20 +402,31 @@ def forward(

# Number of tokens T <= N*256
num_patch_tokens = tokens.shape[1]
mask = torch.zeros(B, num_patch_tokens, dtype=torch.bool, device=tokens.device)

# Add language token
# Add language tokens
if self.use_text_embedding:
lang_token = self.text_encoder(text_embedding).unsqueeze(
1
) # (B, 1, hidden_dim)
tokens = torch.cat([tokens, lang_token], dim=1) # (B, T+1, hidden_dim)
text_tokens = self.text_tokenizer(
text, return_tensors="pt", padding=True, truncation=True
)
text_tokens = {
k: v.to(self.text_encoder.device) for k, v in text_tokens.items()
}
text_embedding = self.text_encoder(**text_tokens).last_hidden_state

lang_tokens = self.text_proj(text_embedding) # (B, J, hidden_dim)
tokens = torch.cat([tokens, lang_tokens], dim=1) # (B, T+J, hidden_dim)
mask = torch.cat([mask, text_tokens["attention_mask"] == 0], dim=1)

# Add gripper token
if self.use_gripper_token:
grip_token = self.gripper_encoder(gripper_token).unsqueeze(
1
) # (B, 1, hidden_dim)
tokens = torch.cat([tokens, grip_token], dim=1) # (B, T+2, hidden_dim)
tokens = torch.cat([tokens, grip_token], dim=1) # (B, T+J+1, hidden_dim)
mask = torch.cat(
[mask, torch.zeros(B, 1, dtype=torch.bool, device=tokens.device)], dim=1
)

# Add source token
if self.use_source_token:
Expand All @@ -422,11 +436,14 @@ def forward(
source_token = self.source_embeddings(source_indices).unsqueeze(
1
) # (B, 1, hidden_dim)
tokens = torch.cat([tokens, source_token], dim=1) # (B, T+3, hidden_dim)
tokens = torch.cat([tokens, source_token], dim=1) # (B, T+J+2, hidden_dim)
mask = torch.cat(
[mask, torch.zeros(B, 1, dtype=torch.bool, device=tokens.device)], dim=1
)

# Apply transformer blocks
for block in self.transformer_blocks:
tokens = block(tokens)
tokens = block(tokens, src_key_padding_mask=mask)

# Take only the patch tokens (throw away language, gripper, source tokens)
tokens = tokens[:, :num_patch_tokens] # (B, T, hidden_dim)
Expand Down Expand Up @@ -726,7 +743,6 @@ def nll_loss(

def forward(self, batch):
"""Forward pass with GMM loss"""
text_embedding = batch["text_embed"]
init, gt = self.extract_gt_4_points(batch)

# Get gripper token (6DoF pose + gripper width)
Expand Down Expand Up @@ -783,7 +799,7 @@ def forward(self, batch):
all_intrinsics,
all_extrinsics,
gripper_token=gripper_token,
text_embedding=text_embedding,
text=batch["caption"],
source=batch["data_source"],
)

Expand Down Expand Up @@ -931,7 +947,6 @@ def predict(self, batch, progress=False):
Predict 3D goal points using GMM sampling.
Returns displacement from initial gripper position.
"""
text_embedding = batch["text_embed"]
init, gt = self.extract_gt_4_points(batch)
gripper_token = self.get_gripper_token(init)

Expand Down Expand Up @@ -966,7 +981,7 @@ def predict(self, batch, progress=False):
all_intrinsics,
all_extrinsics,
gripper_token=gripper_token,
text_embedding=text_embedding,
text=batch["caption"],
source=batch["data_source"],
)

Expand Down
Loading