
[Question] Compared with Pi3, Pi3X's PointLoss adds depth_loss and rays_loss. What is the motivation for this? #142

Description

@buzzcut2190

The code in PointLoss:

    # Per-point loss on camera-frame (local) points, weighted per valid pixel
    local_pts_loss = self.criteria_local(aligned_pred_local_pts[valid_masks].float(), gt_local_pts[valid_masks].float()) * weights[valid_masks].float()[..., None]
    details['local_pts_loss'] = local_pts_loss.mean()
    # Homogeneous pixel grid, broadcast to (B, N, H, W, 3)
    pix = torch.from_numpy(get_pixel(H, W).T.reshape(H, W, 3)).to(gt_local_pts.device).float()[:, :].repeat(B, N, 1, 1, 1)
    # GT ray directions K^{-1} (u, v, 1); only (x, y) are kept (z is 1 for a standard pinhole K)
    gt_rays = torch.einsum('bnij, bnhwj -> bnhwi', torch.inverse(gt['camera_intrinsics']), pix)[..., :2]
    rays_loss = F.l1_loss(pred['xy'], gt_rays)
    final_loss += rays_loss
    details['rays_loss'] = rays_loss

    # Depth term: the z component of the per-point local loss computed above
    depth_loss = local_pts_loss[..., 2].mean()
    details['depth_loss'] = depth_loss
    final_loss += depth_loss
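To make the two new terms concrete, here is a minimal NumPy sketch (the torch calls map one-to-one to np.einsum and np.linalg.inv). It assumes a standard pinhole intrinsic matrix K with last row (0, 0, 1) and that get_pixel(H, W) yields homogeneous pixel coordinates (u, v, 1); the helpers pixel_grid and rays_from_intrinsics are hypothetical stand-ins, not functions from the repo. It builds rays the same way as the snippet above and checks that a camera-frame point map factors as (r_x * d, r_y * d, d), i.e. into the quantity compared by rays_loss (pred['xy'] vs. gt_rays) and the z component that depth_loss picks out.

```python
import numpy as np


def pixel_grid(H, W):
    """Hypothetical stand-in for get_pixel: homogeneous coords (u, v, 1), shape (H, W, 3).
    Assumes u indexes columns and v indexes rows."""
    v, u = np.meshgrid(np.arange(H, dtype=np.float64),
                       np.arange(W, dtype=np.float64), indexing="ij")
    return np.stack([u, v, np.ones_like(u)], axis=-1)


def rays_from_intrinsics(K, H, W):
    """Hypothetical helper: per-pixel ray (x, y) = first two entries of K^{-1} (u, v, 1)."""
    rays = np.einsum("ij,hwj->hwi", np.linalg.inv(K), pixel_grid(H, W))
    return rays[..., :2]  # z is 1 when the last row of K is (0, 0, 1)


H, W = 4, 5
K = np.array([[100.0, 0.0, 2.0],
              [0.0, 120.0, 1.5],
              [0.0, 0.0, 1.0]])
rays = rays_from_intrinsics(K, H, W)

# Build a camera-frame point map from a random positive depth map
rng = np.random.default_rng(0)
depth = rng.uniform(1.0, 5.0, size=(H, W, 1))
local_pts = np.concatenate([rays * depth, depth], axis=-1)  # (H, W, 3)

# The point map projects back onto the pixel grid: K (X/Z, Y/Z, 1) == (u, v, 1)
proj = np.einsum("ij,hwj->hwi", K, local_pts / local_pts[..., 2:3])
print(np.allclose(proj, pixel_grid(H, W)))  # True

# Dividing x, y by z recovers exactly the rays supervised by rays_loss
print(np.allclose(local_pts[..., :2] / local_pts[..., 2:3], rays))  # True
```

Because the z entry of K^{-1} (u, v, 1) is 1 for such a K, keeping only the (x, y) ray components loses no information, which is consistent with the [..., :2] slice in the snippet.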

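Also, judging from the code, is result['local_points'] returned by the teacher model actually depth? It doesn't look like points yet: it is assigned to new_depth and only lifted to camera-frame points (pred_pts_cam) in the last line.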

    # Run the teacher without gradients under bf16 autocast
    with torch.no_grad(), torch.amp.autocast('cuda', dtype=torch.bfloat16):
        result = self.teacher(imgs, gt['camera_poses'][teach_batch_id], gt['camera_intrinsics'][teach_batch_id])
    new_depth = result['local_points']
    # Pixel grid and GT rays, built the same way as in PointLoss
    pix = torch.from_numpy(get_pixel(H, W).T.reshape(H, W, 3)).to(imgs.device).float()[:, :].repeat(B, N, 1, 1, 1)
    gt_rays = torch.einsum('bnij, bnhwj -> bnhwi', torch.inverse(gt['camera_intrinsics'][teach_batch_id]), pix)[..., :2]
    # Lift to camera-frame points: (x, y, z) = (r_x * d, r_y * d, d)
    pred_pts_cam = torch.cat([gt_rays * new_depth, new_depth], dim=-1)
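The shape arithmetic in the last line bears on the second question. gt_rays has two channels, so torch.cat([gt_rays * new_depth, new_depth], dim=-1) can only yield a 3-channel point map if new_depth has a single trailing channel, i.e. if it is a depth map. Below is a minimal NumPy sketch of that broadcasting check, assuming result['local_points'] is a plain tensor sharing the leading (B, N, H, W) dims; the sizes and the lift helper are illustrative only, not taken from the repo.

```python
import numpy as np

B, N, H, W = 1, 2, 3, 4  # illustrative sizes only
gt_rays = np.zeros((B, N, H, W, 2))  # (x, y) ray components, as in the snippet


def lift(rays, depth):
    """Hypothetical mirror of pred_pts_cam = cat([rays * depth, depth], dim=-1)."""
    return np.concatenate([rays * depth, depth], axis=-1)


# Case 1: teacher output is a depth map with one trailing channel -> valid point map
depth_map = np.ones((B, N, H, W, 1))
print(lift(gt_rays, depth_map).shape)  # (1, 2, 3, 4, 3)

# Case 2: teacher output is already a 3-channel point map -> (..., 2) * (..., 3) cannot broadcast
point_map = np.ones((B, N, H, W, 3))
try:
    lift(gt_rays, point_map)
except ValueError as err:
    print("broadcast error:", err)
```

In PyTorch the second case raises a RuntimeError rather than a ValueError, but either way only a single-channel depth input makes this line work.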
