
Issue with Updating viewpoint_cam.cam_trans_delta and viewpoint_cam.cam_rot_delta in Optimizer #23

Description

@WillLan

I am experiencing an issue where the viewpoint_cam.cam_trans_delta and viewpoint_cam.cam_rot_delta parameters are not updated as expected during pose optimization. After reviewing the code, I believe the cause is that update_pose(viewpoint_cam) is called inside the with torch.no_grad() block, which I suspect prevents gradients from being computed for these parameters, so they are not updated.

Here is the relevant code snippet:

```python
# Adam optimizer over the camera's pose deltas (translation and rotation)
pose_optimizer = torch.optim.Adam([{"params": [viewpoint_cam.cam_trans_delta], "lr": opt.translation_lr_init},
                                   {"params": [viewpoint_cam.cam_rot_delta], "lr": opt.rotation_lr_init}])

gt_image = viewpoint_cam.original_image.cuda()

progress_bar = tqdm(range(0, pose_iteration), desc="Pose estimation progress")
for iteration in range(pose_iteration):
    # Render the current view and compute the L1 loss on non-occluded pixels
    voxel_visible_mask = prefilter_voxel(viewpoint_cam, gaussians, pipe, background)
    render_pkg = render(viewpoint_cam, gaussians, pipe, background, visible_mask=voxel_visible_mask, retain_grad=True)
    image = render_pkg["render"]
    rendered_depth = render_pkg["depth"][0]
    occ_mask = get_occlusion_mask(viewpoint_cam=pre_viewpoint_cam1, viewpoint_cam2=viewpoint_cam,
                                  depth=pre_rendered_depth, device=pre_rendered_depth.device, thresh=0.001).detach()

    Ll1 = l1_loss(image[:, occ_mask], gt_image[:, occ_mask])
    loss = Ll1

    # 2D correspondence loss
    if opt.loss_2d_correspondence_weight > 0 and viewpoint_cam.uid > 0:
        view1 = scene.getTrainCameras()[viewpoint_cam.uid - 1]
        view2 = viewpoint_cam
        kp0, kp1, conf = view2.kp0.cuda(), view2.kp1.cuda(), view2.conf.cuda()
        loss_2d = correspondence_2d_loss(kp0, kp1, conf, rendered_depth,
                                         view2.view_world_transform, view1.world_view_transform, view2.intrinsic)
        loss += loss_2d * opt.loss_2d_correspondence_weight

    # Backward pass (runs outside torch.no_grad())
    loss.backward()

    with torch.no_grad():
        pose_optimizer.step()
        pose_optimizer.zero_grad(set_to_none=True)
        gaussians.optimizer.zero_grad(set_to_none=True)
        gaussians.pose_optimizer.zero_grad(set_to_none=True)
        # The call in question: update_pose runs inside the no_grad block
        update_pose(viewpoint_cam)

        if iteration % 10 == 0:
            progress_bar.set_postfix({"Loss": f"{loss.item():.7f}"})
            progress_bar.update(10)
```
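To narrow this down, it may help to log the deltas at two points in each iteration: right after pose_optimizer.step() and again after update_pose(viewpoint_cam). The sketch below is a pure-Python toy, not the repository's code. It models a single translation value, replaces Adam with plain gradient descent, and assumes a hypothetical toy_update_pose that folds the delta into the stored pose and then resets the delta to zero. ToyCamera, loss_grad, toy_update_pose and refine are all illustrative names. Under that assumption, the delta reads as zero after every update_pose call even though the optimizer moved it and the pose itself converges.

```python
class ToyCamera:
    """Hypothetical 1-D stand-in for viewpoint_cam: a stored translation plus a learnable delta."""

    def __init__(self, T):
        self.T = T                  # stored pose (translation only)
        self.cam_trans_delta = 0.0  # correction the optimizer updates


def loss_grad(cam, target):
    # Toy loss (T + delta - target)**2, so d(loss)/d(delta) = 2 * (T + delta - target).
    return 2.0 * (cam.T + cam.cam_trans_delta - target)


def toy_update_pose(cam):
    # ASSUMPTION (hypothetical): fold the delta into the stored pose, then reset it to zero.
    cam.T += cam.cam_trans_delta
    cam.cam_trans_delta = 0.0


def refine(cam, target, lr=0.1, iters=50):
    log = []
    for _ in range(iters):
        grad = loss_grad(cam, target)       # stands in for loss.backward()
        cam.cam_trans_delta -= lr * grad    # stands in for pose_optimizer.step()
        delta_after_step = cam.cam_trans_delta
        toy_update_pose(cam)                # stands in for update_pose(viewpoint_cam)
        # Record: delta after step, delta after update_pose, stored pose
        log.append((delta_after_step, cam.cam_trans_delta, cam.T))
    return log


cam = ToyCamera(T=0.0)
log = refine(cam, target=1.0)
print("first iteration (delta after step, delta after update_pose, pose):", log[0])
print("final pose:", round(cam.T, 4))
```

If the real update_pose behaves like this toy, checking the camera's stored pose (rather than the deltas) would show whether pose estimation is actually making progress; if it does not reset the deltas, the values logged right after step() would show whether the optimizer moved them at all.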

Metadata

Assignees

No one assigned

    Labels

    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests