diff --git a/src/nrdk/metrics/pointcloud.py b/src/nrdk/metrics/pointcloud.py index 2145a16..d081f36 100644 --- a/src/nrdk/metrics/pointcloud.py +++ b/src/nrdk/metrics/pointcloud.py @@ -5,7 +5,7 @@ import numpy as np import torch from beartype.typing import Literal -from jaxtyping import Float, Shaped +from jaxtyping import Bool, Float, Shaped from torch import Tensor @@ -46,7 +46,7 @@ def __init__( self.max_points = max_points @abstractmethod - def as_points(self, data: Shaped[Tensor, "..."]) -> Float[Tensor, "n d"]: + def as_points(self, data: Bool[Tensor, "..."]) -> Float[Tensor, "n d"]: """Convert a model's native representation to cartesian points. Note that since the output number of points may vary across a batch, @@ -54,16 +54,18 @@ def as_points(self, data: Shaped[Tensor, "..."]) -> Float[Tensor, "n d"]: """ ... - def _limit(self, data: Tensor) -> Tensor: + def _limit(self, data: Tensor) -> Bool[Tensor, "..."]: """Subsample occupied cells to at most max_points.""" + occupied = data if data.dtype == torch.bool else data > 0 + # not set if self.max_points is None: - return data + return occupied # below limit - occ_idx = (data > 0).nonzero(as_tuple=False) # [n, ndim] + occ_idx = occupied.nonzero(as_tuple=False) # [n, ndim] if len(occ_idx) <= self.max_points: - return data + return occupied # sampling required if data.dtype == torch.bool: @@ -72,11 +74,8 @@ def _limit(self, data: Tensor) -> Tensor: weights = torch.sigmoid(data[tuple(occ_idx.T)]) sampled = torch.multinomial(weights, self.max_points, replacement=False) keep = occ_idx[sampled] - result = torch.zeros_like(data) - if data.dtype == torch.bool: - result[tuple(keep.T)] = True - else: - result[tuple(keep.T)] = data[tuple(keep.T)] + result = torch.zeros_like(occupied) + result[tuple(keep.T)] = True return result @@ -148,7 +147,7 @@ def __init__( self.az_max = az_max def as_points( - self, data: Shaped[Tensor, "azimuth range"] + self, data: Bool[Tensor, "azimuth range"] ) -> Float[Tensor, "n 2"]: """Convert (azimuth, range) occupancy grid to points.""" Na, Nr = data.shape @@ -200,7 +199,7 @@ def __init__( self.el_max = el_max def as_points( - self, data: Shaped[Tensor, "elevation azimuth range"] + self, data: Bool[Tensor, "elevation azimuth range"] ) -> Float[Tensor, "n 3"]: """Convert (elevation, azimuth, range) occupancy grid to points.""" Ne, Na, _Nr = data.shape diff --git a/tests/test_metrics.py b/tests/test_metrics.py index f205eb3..0ae4f5c 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py @@ -292,6 +292,20 @@ def test_polar_chamfer_3d_basic(): assert torch.all(loss >= 0) +def test_polar_chamfer_3d_float_logits_below_limit(): + """Test PolarChamfer3D ignores negative logits below max_points.""" + chamfer = PolarChamfer3D(max_points=10000) + data = -torch.ones(4, 8, 16) + data[1, 2, 5] = 0.5 + data[2, 3, 10] = 1.0 + + limited = chamfer._limit(data) + points = chamfer.as_points(limited) + + assert limited.dtype == torch.bool + assert points.shape == (2, 3) + + def test_polar_chamfer_3d_modes(): """Test PolarChamfer3D with different modes.""" # Shape: [batch, elevation, azimuth, range]