From 353846166744e9c0114c043f922da6e1b650f220 Mon Sep 17 00:00:00 2001 From: "etienne.perot@gehealthcare.com" Date: Wed, 1 Apr 2026 10:12:03 +0200 Subject: [PATCH] fix warnings --- keymorph/keypoint_aligners.py | 2 +- keymorph/model.py | 2 +- keymorph/utils.py | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/keymorph/keypoint_aligners.py b/keymorph/keypoint_aligners.py index 6c45fa1..57ec372 100644 --- a/keymorph/keypoint_aligners.py +++ b/keymorph/keypoint_aligners.py @@ -268,7 +268,7 @@ def __init__( # Note we flip the order of the points here if use_checkpoint: self.inverse_theta = checkpoint.checkpoint( - self.fit, points_f, points_m, lmbda, w + self.fit, points_f, points_m, lmbda, w, use_reentrant=False ) else: self.inverse_theta = self.fit(points_f, points_m, lmbda, weights=w) diff --git a/keymorph/model.py b/keymorph/model.py index 5c6a03a..fac7df8 100644 --- a/keymorph/model.py +++ b/keymorph/model.py @@ -183,7 +183,7 @@ def forward(self, img_f, img_m, transform_type="affine", **kwargs): if self.weight_keypoints == "power": if self.use_checkpoint: weights = checkpoint.checkpoint( - self.weight_by_power, feat_f, feat_m + self.weight_by_power, feat_f, feat_m, use_reentrant=False ) else: weights = self.weight_by_power(feat_f, feat_m) diff --git a/keymorph/utils.py b/keymorph/utils.py index e82bfac..5c2b521 100644 --- a/keymorph/utils.py +++ b/keymorph/utils.py @@ -251,7 +251,7 @@ def convert_points_norm2voxel(points, grid_sizes): Returns: Array of points in voxel space. """ - grid_sizes = torch.tensor(grid_sizes).to(points.device) + grid_sizes = torch.as_tensor(grid_sizes).to(points.device) assert grid_sizes.shape[-1] == points.shape[-1], "Dimensions don't match" translated_points = points + 1 scaled_points = (translated_points * grid_sizes) / 2 @@ -270,7 +270,7 @@ def convert_points_voxel2norm(points, grid_sizes): Returns: Array of points in the normalized space [-1, 1]. """ - grid_sizes = torch.tensor(grid_sizes).to(points.device) + grid_sizes = torch.as_tensor(grid_sizes).to(points.device) assert grid_sizes.shape[-1] == points.shape[-1], "Dimensions don't match" rescaled_points_shifted = points + 0.5 normalized_points = (2 * rescaled_points_shifted / grid_sizes) - 1