diff --git a/opentad/models/detectors/single_stage.py b/opentad/models/detectors/single_stage.py index 3ba72941..4e871193 100644 --- a/opentad/models/detectors/single_stage.py +++ b/opentad/models/detectors/single_stage.py @@ -109,7 +109,7 @@ def post_processing(self, predictions, metas, post_cfg, ext_cls, **kwargs): if num_classes == 1: scores = scores.squeeze(-1) - labels = torch.zeros(scores.shape[0]).contiguous() + labels = torch.zeros(scores.shape[0], dtype=torch.long).contiguous() else: pred_prob = scores.flatten() # [N*class]