Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 12 additions & 13 deletions src/nrdk/metrics/pointcloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import numpy as np
import torch
from beartype.typing import Literal
from jaxtyping import Float, Shaped
from jaxtyping import Bool, Float, Shaped
from torch import Tensor


Expand Down Expand Up @@ -46,24 +46,26 @@ def __init__(
self.max_points = max_points

@abstractmethod
def as_points(self, data: Shaped[Tensor, "..."]) -> Float[Tensor, "n d"]:
def as_points(self, data: Bool[Tensor, "..."]) -> Float[Tensor, "n d"]:
"""Convert a model's native representation to cartesian points.

Note that since the output number of points may vary across a batch,
this method is not batched.
"""
...

def _limit(self, data: Tensor) -> Tensor:
def _limit(self, data: Tensor) -> Bool[Tensor, "..."]:
"""Subsample occupied cells to at most max_points."""
occupied = data if data.dtype == torch.bool else data > 0

# not set
if self.max_points is None:
return data
return occupied

# below limit
occ_idx = (data > 0).nonzero(as_tuple=False) # [n, ndim]
occ_idx = occupied.nonzero(as_tuple=False) # [n, ndim]
if len(occ_idx) <= self.max_points:
return data
return occupied

# sampling required
if data.dtype == torch.bool:
Expand All @@ -72,11 +74,8 @@ def _limit(self, data: Tensor) -> Tensor:
weights = torch.sigmoid(data[tuple(occ_idx.T)])
sampled = torch.multinomial(weights, self.max_points, replacement=False)
keep = occ_idx[sampled]
result = torch.zeros_like(data)
if data.dtype == torch.bool:
result[tuple(keep.T)] = True
else:
result[tuple(keep.T)] = data[tuple(keep.T)]
result = torch.zeros_like(occupied)
result[tuple(keep.T)] = True

return result

Expand Down Expand Up @@ -148,7 +147,7 @@ def __init__(
self.az_max = az_max

def as_points(
self, data: Shaped[Tensor, "azimuth range"]
self, data: Bool[Tensor, "azimuth range"]
) -> Float[Tensor, "n 2"]:
"""Convert (azimuth, range) occupancy grid to points."""
Na, Nr = data.shape
Expand Down Expand Up @@ -200,7 +199,7 @@ def __init__(
self.el_max = el_max

def as_points(
self, data: Shaped[Tensor, "elevation azimuth range"]
self, data: Bool[Tensor, "elevation azimuth range"]
) -> Float[Tensor, "n 3"]:
"""Convert (elevation, azimuth, range) occupancy grid to points."""
Ne, Na, _Nr = data.shape
Expand Down
14 changes: 14 additions & 0 deletions tests/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,20 @@ def test_polar_chamfer_3d_basic():
assert torch.all(loss >= 0)


def test_polar_chamfer_3d_float_logits_below_limit():
"""Test PolarChamfer3D ignores negative logits below max_points."""
chamfer = PolarChamfer3D(max_points=10000)
data = -torch.ones(4, 8, 16)
data[1, 2, 5] = 0.5
data[2, 3, 10] = 1.0

limited = chamfer._limit(data)
points = chamfer.as_points(limited)

assert limited.dtype == torch.bool
assert points.shape == (2, 3)


def test_polar_chamfer_3d_modes():
"""Test PolarChamfer3D with different modes."""
# Shape: [batch, elevation, azimuth, range]
Expand Down
Loading