diff --git a/rust/lance/src/index/vector/ivf/v2.rs b/rust/lance/src/index/vector/ivf/v2.rs index 8b81b032037..e85c362b686 100644 --- a/rust/lance/src/index/vector/ivf/v2.rs +++ b/rust/lance/src/index/vector/ivf/v2.rs @@ -4318,17 +4318,41 @@ mod tests { // don't verify the number of results and row ids for hamming distance, // because there are many vectors with the same distance if dist_type != DistanceType::Hamming { - assert_eq!(left_res.num_rows(), part_idx); - assert_eq!(right_res.num_rows(), k - part_idx); + // Tolerate a single tied pair at the partition boundary. When + // dists[part_idx - 1] == part_dist, the strict-less left filter + // excludes both tied vectors and the inclusive right filter + // includes both, shifting one row from left to right and dropping + // row_ids[k - 1] off right_res's limit. Observed for Dot distance + // on ARM where SIMD FMA yields tied float32 dot products that x86 + // does not. The distance-value assertions below still cover + // partition correctness in both cases. + let boundary_tie = part_idx > 0 && dists[part_idx - 1] == part_dist; let left_row_ids = left_res[ROW_ID].as_primitive::().values(); let right_row_ids = right_res[ROW_ID].as_primitive::().values(); - row_ids.iter().enumerate().for_each(|(i, id)| { - if i < part_idx { - assert_eq!(left_row_ids[i], *id,); - } else { - assert_eq!(right_row_ids[i - part_idx], *id,); + if boundary_tie { + assert_eq!(left_res.num_rows(), part_idx - 1); + for i in 0..(part_idx - 1) { + assert_eq!(left_row_ids[i], row_ids[i]); } - }); + assert_eq!(right_res.num_rows(), k - part_idx); + // right_row_ids[0..2] are the two tied vectors in tiebreaker + // order; their identity is not pinned. right_row_ids[i] for + // i >= 2 aligns with row_ids[part_idx + i - 1] because the + // tie shifts one vector from left to right. + for i in 2..(k - part_idx) { + assert_eq!(right_row_ids[i], row_ids[part_idx + i - 1]); + } + } else { + assert_eq!(left_res.num_rows(), part_idx); + assert_eq!(right_res.num_rows(), k - part_idx); + row_ids.iter().enumerate().for_each(|(i, id)| { + if i < part_idx { + assert_eq!(left_row_ids[i], *id,); + } else { + assert_eq!(right_row_ids[i - part_idx], *id,); + } + }); + } } let left_dists = left_res[DIST_COL].as_primitive::().values(); let right_dists = right_res[DIST_COL].as_primitive::().values();