Code in PointLoss:
local_pts_loss = self.criteria_local(aligned_pred_local_pts[valid_masks].float(), gt_local_pts[valid_masks].float()) * weights[valid_masks].float()[..., None]
details['local_pts_loss'] = local_pts_loss.mean()
pix = torch.from_numpy(get_pixel(H, W).T.reshape(H, W, 3)).to(gt_local_pts.device).float()[:, :].repeat(B, N, 1, 1, 1)
gt_rays = torch.einsum('bnij, bnhwj -> bnhwi', torch.inverse(gt['camera_intrinsics']), pix)[..., :2]
rays_loss = F.l1_loss(pred['xy'], gt_rays)
final_loss += rays_loss
details['rays_loss'] = rays_loss
depth_loss = local_pts_loss[..., 2].mean()
details['depth_loss'] = depth_loss
final_loss += depth_loss
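The ray target used by rays_loss above can be sketched in isolation as follows. This is a minimal sketch, not the repository's code: pixel_grid is a hypothetical stand-in for get_pixel, assumed to return homogeneous pixel coordinates (u, v, 1) at integer positions (the real helper may use a half-pixel offset), and rays_from_intrinsics is a name introduced only for illustration.

```python
import torch

def pixel_grid(H, W):
    # Hypothetical stand-in for get_pixel: homogeneous pixel coordinates
    # (u, v, 1) at integer pixel positions, shape (H, W, 3).
    v, u = torch.meshgrid(
        torch.arange(H, dtype=torch.float32),
        torch.arange(W, dtype=torch.float32),
        indexing="ij",
    )
    return torch.stack([u, v, torch.ones_like(u)], dim=-1)

def rays_from_intrinsics(K, H, W):
    # K: (B, N, 3, 3) camera intrinsics.
    # Returns (B, N, H, W, 2): the per-pixel ray direction as (x/z, y/z).
    B, N = K.shape[:2]
    pix = pixel_grid(H, W).to(K.device).expand(B, N, H, W, 3)
    rays = torch.einsum("bnij,bnhwj->bnhwi", torch.inverse(K), pix)
    return rays[..., :2]  # drop the constant third component (always 1)

# Toy intrinsics: focal length 100, principal point (4, 3), two views.
K = torch.tensor([[100.0, 0.0, 4.0],
                  [0.0, 100.0, 3.0],
                  [0.0, 0.0, 1.0]]).repeat(1, 2, 1, 1)
rays = rays_from_intrinsics(K, H=6, W=8)
```

With these toy intrinsics the pixel at the principal point maps to the ray (0, 0), and the top-left pixel maps to (-0.04, -0.03), i.e. ((u - cx) / fx, (v - cy) / fy).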
Also, judging from the code below, is the result['local_points'] returned by the teacher model actually depth? It does not seem to be points yet?
with torch.no_grad(), torch.amp.autocast('cuda', dtype=torch.bfloat16):
    result = self.teacher(imgs, gt['camera_poses'][teach_batch_id], gt['camera_intrinsics'][teach_batch_id])
new_depth = result['local_points']
pix = torch.from_numpy(get_pixel(H, W).T.reshape(H, W, 3)).to(imgs.device).float()[:, :].repeat(B, N, 1, 1, 1)
gt_rays = torch.einsum('bnij, bnhwj -> bnhwi', torch.inverse(gt['camera_intrinsics'][teach_batch_id]), pix)[..., :2]
pred_pts_cam = torch.cat([gt_rays * new_depth, new_depth], dim=-1)
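The torch.cat line above only works shape-wise if new_depth has a trailing dimension of 1: gt_rays has a trailing dimension of 2, so multiplying it by a 3-channel tensor would fail to broadcast. That is consistent with reading result['local_points'] as a per-pixel z-depth here, although the snippet alone does not confirm what the teacher returns. The sketch below uses made-up random data (not the teacher output) to show that (x/z, y/z) rays times a z-depth, concatenated with that depth, recovers camera-frame points.

```python
import torch

torch.manual_seed(0)
B, N, H, W = 1, 2, 4, 5

# Made-up camera-frame points with strictly positive z.
pts = torch.rand(B, N, H, W, 3) + 0.5

depth = pts[..., 2:3]        # (B, N, H, W, 1): z-depth, trailing dim kept
rays = pts[..., :2] / depth  # (B, N, H, W, 2): (x/z, y/z), same form as gt_rays

# Same construction as pred_pts_cam: (x/z * z, y/z * z, z) = (x, y, z).
recon = torch.cat([rays * depth, depth], dim=-1)
```

Here recon has shape (B, N, H, W, 3) and matches pts up to floating-point error, which is the depth-to-points unprojection the snippet appears to perform.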