diff --git a/packages/dali_pipeline_framework/tests/processing_steps/bev_bboxes_transformer_3d_test.py b/packages/dali_pipeline_framework/tests/processing_steps/bev_bboxes_transformer_3d_test.py index 0822a90..e497fbe 100644 --- a/packages/dali_pipeline_framework/tests/processing_steps/bev_bboxes_transformer_3d_test.py +++ b/packages/dali_pipeline_framework/tests/processing_steps/bev_bboxes_transformer_3d_test.py @@ -34,6 +34,10 @@ ) from accvlab.dali_pipeline_framework.processing_steps import BEVBBoxesTransformer3D +# Float32 transform matrices accumulate error across trigonometry, inversions, and matrix products. +# 5e-6 is a small error budget for values with translation components in the tens. +FLOAT32_TRANSFORM_MATRIX_TOLERANCE = 5e-6 + def set_dali_uniform_generator_and_get_orig_and_replacement(sequences): generator = DaliFakeRandomGenerator(sequences) @@ -292,7 +296,9 @@ def apply_transformation_to_points(points, rotation_matrix, scaling_matrix, tran return res -def verify_matrix_inverse_relationship(ego_to_world, world_to_ego, tolerance=1e-6): +def verify_matrix_inverse_relationship( + ego_to_world, world_to_ego, tolerance=FLOAT32_TRANSFORM_MATRIX_TOLERANCE +): """Verify that world_to_ego is the inverse of ego_to_world.""" expected_world_to_ego = torch.linalg.inv(ego_to_world) @@ -301,14 +307,6 @@ def verify_matrix_inverse_relationship(ego_to_world, world_to_ego, tolerance=1e- f"Max difference: {torch.max(torch.abs(world_to_ego - expected_world_to_ego)).item()}" ) - # Check that world_to_ego @ ego_to_world = identity - identity_check = world_to_ego @ ego_to_world - expected_identity = torch.eye(4, dtype=torch.float32) - - assert torch.allclose( - identity_check, expected_identity, atol=tolerance - ), f"world_to_ego is not the inverse of ego_to_world. Max difference: {torch.max(torch.abs(identity_check - expected_identity)).item()}" - def verify_transformation_consistency( original_points, original_transformation, transformed_points, transformed_transformation, tolerance=1e-6 @@ -367,6 +365,7 @@ def verify( scaling, translation, tolerance=1e-6, + transform_matrix_tolerance=FLOAT32_TRANSFORM_MATRIX_TOLERANCE, ): """Verify that the transformation is consistent.""" @@ -422,18 +421,20 @@ def verify( ), dtype=torch.float32, ) - assert torch.allclose(ego_to_world_ref, ego_to_world_out, atol=tolerance), ( + assert torch.allclose(ego_to_world_ref, ego_to_world_out, atol=transform_matrix_tolerance), ( "ego_to_world transformation does not match reference implementation. " f"Max difference: {torch.max(torch.abs(ego_to_world_ref - ego_to_world_out)).item()}" ) - assert torch.allclose(world_to_ego_ref, world_to_ego_out, atol=tolerance), ( + assert torch.allclose(world_to_ego_ref, world_to_ego_out, atol=transform_matrix_tolerance), ( "world_to_ego transformation does not match reference implementation. " f"Max difference: {torch.max(torch.abs(world_to_ego_ref - world_to_ego_out)).item()}" ) verify_transformation_consistency(points_in, proj_matrix_in, points_out, proj_matrix_out, tolerance) verify_transformation_consistency(points_in, ego_to_world_in, points_out, ego_to_world_out, tolerance) - verify_matrix_inverse_relationship(ego_to_world_out, world_to_ego_out, tolerance) + verify_matrix_inverse_relationship( + ego_to_world_out, world_to_ego_out, tolerance=transform_matrix_tolerance + ) def run_transformation_test_with_reference_comparison(