From 25398339b0e235d3fa252602b987a13167a83ae6 Mon Sep 17 00:00:00 2001 From: enp1s0 Date: Fri, 29 May 2026 17:50:22 +0900 Subject: [PATCH 01/15] Add pq_len=8 --- .../cagra/compute_distance_vpq_matrix.json | 53 ++++++++++++-- .../jit_lto_kernels/cagra_planner_base.hpp | 23 ++++-- .../compute_distance_matrix.json | 70 ++++++++++++++++++ .../setup_workspace_matrix.json | 70 ++++++++++++++++++ cpp/tests/neighbors/ann_cagra.cuh | 46 +++++++++++- cpp/tests/neighbors/vpq_utils.cuh | 73 +++++++++++++++++++ 6 files changed, 321 insertions(+), 14 deletions(-) create mode 100644 cpp/tests/neighbors/vpq_utils.cuh diff --git a/cpp/src/neighbors/detail/cagra/compute_distance_vpq_matrix.json b/cpp/src/neighbors/detail/cagra/compute_distance_vpq_matrix.json index cf6e060d33..c6e2ae319c 100644 --- a/cpp/src/neighbors/detail/cagra/compute_distance_vpq_matrix.json +++ b/cpp/src/neighbors/detail/cagra/compute_distance_vpq_matrix.json @@ -36,15 +36,53 @@ "_mxdim_team": [ { "dim": "128", - "team_size": "8" + "team_size": "8", + "pq_len": "2" }, { "dim": "256", - "team_size": "16" + "team_size": "16", + "pq_len": "2" }, { "dim": "512", - "team_size": "32" + "team_size": "32", + "pq_len": "2" + }, + { + "dim": "128", + "team_size": "8", + "pq_len": "4" + }, + { + "dim": "256", + "team_size": "16", + "pq_len": "4" + }, + { + "dim": "512", + "team_size": "32", + "pq_len": "4" + }, + { + "dim": "128", + "team_size": "4", + "pq_len": "8" + }, + { + "dim": "256", + "team_size": "8", + "pq_len": "8" + }, + { + "dim": "512", + "team_size": "16", + "pq_len": "8" + }, + { + "dim": "1024", + "team_size": "32", + "pq_len": "8" } ], "_codebook": [ @@ -53,7 +91,10 @@ "codebook_abbrev": "h" } ], - "pq_bits": ["8"], - "pq_len": ["2", "4"], - "metric": ["L2Expanded"] + "pq_bits": [ + "8" + ], + "metric": [ + "L2Expanded" + ] } diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_planner_base.hpp b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_planner_base.hpp index 317ca1a1b6..b9e7891723 100644 --- a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_planner_base.hpp +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_planner_base.hpp @@ -64,8 +64,8 @@ struct CagraPlannerBase : AlgorithmPlanner { uint32_t dataset_block_dim, uint32_t pq_len) { - if (pq_len != 2 && pq_len != 4) { - RAFT_FAIL("CAGRA JIT VPQ setup_workspace expects pq_len in {2,4} (matrix uses pq_bits=8)"); + if (pq_len != 2 && pq_len != 4 && pq_len != 8) { + RAFT_FAIL("CAGRA JIT VPQ setup_workspace expects pq_len in {2,4,8} (matrix uses pq_bits=8)"); } auto add = [&]() { this->add_static_fragment() { if (pq_len == 2) { add.template operator()(); - } else { + } else if (pq_len == 4) { add.template operator()(); + } else { + add.template operator()(); } }); } @@ -120,8 +122,8 @@ struct CagraPlannerBase : AlgorithmPlanner { uint32_t dataset_block_dim, uint32_t pq_len) { - if (pq_len != 2 && pq_len != 4) { - RAFT_FAIL("CAGRA JIT VPQ compute_distance expects pq_len in {2,4} (matrix uses pq_bits=8)"); + if (pq_len != 2 && pq_len != 4 && pq_len != 8) { + RAFT_FAIL("CAGRA JIT VPQ compute_distance expects pq_len in {2,4,8} (matrix uses pq_bits=8)"); } auto add = [&]() { this->add_static_fragment() { if (pq_len == 2) { add.template operator()(); - } else { + } else if (pq_len == 4) { add.template operator()(); + } else { + add.template operator()(); } }); } @@ -219,6 +223,12 @@ struct CagraPlannerBase : AlgorithmPlanner { static void dispatch_cagra_team_dim(uint32_t team_size, uint32_t dataset_block_dim, Lambda&& l) { switch (team_size) { + case 4: + switch (dataset_block_dim) { + case 128: std::forward(l).template operator()<4u, 128u>(); return; + default: break; + } + break; case 8: switch (dataset_block_dim) { case 128: std::forward(l).template operator()<8u, 128u>(); return; @@ -240,6 +250,7 @@ struct CagraPlannerBase : AlgorithmPlanner { case 128: std::forward(l).template operator()<32u, 128u>(); return; case 256: std::forward(l).template operator()<32u, 256u>(); return; case 512: std::forward(l).template operator()<32u, 512u>(); return; + case 1024: std::forward(l).template operator()<32u, 1024u>(); return; default: break; } break; diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_matrix.json b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_matrix.json index 82b8dbdf4e..2e64ee2ce1 100644 --- a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_matrix.json +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_matrix.json @@ -150,5 +150,75 @@ "codebook_abbrev": "half" } ] + }, + { + "_data": [ + { + "data_type": "float", + "data_abbrev": "f" + }, + { + "data_type": "__half", + "data_abbrev": "h" + }, + { + "data_type": "uint8_t", + "data_abbrev": "u8" + }, + { + "data_type": "int8_t", + "data_abbrev": "i8" + } + ], + "_query": [ + { + "query_type": "half", + "query_abbrev": "h" + } + ], + "_index": [ + { + "index_type": "uint32_t", + "index_abbrev": "u32" + } + ], + "_distance": [ + { + "distance_type": "float", + "distance_abbrev": "f" + } + ], + "_pq": [ + { + "pq_len": "8", + "pq_bits": "8", + "pq_prefix": "_vpq", + "pq_suffix": "_8pq_8subd" + } + ], + "_codebook": [ + { + "codebook_type": "half", + "codebook_abbrev": "half" + } + ], + "_mxdim_team": [ + { + "dataset_block_dim": "128", + "team_size": "4" + }, + { + "dataset_block_dim": "256", + "team_size": "8" + }, + { + "dataset_block_dim": "512", + "team_size": "16" + }, + { + "dataset_block_dim": "1024", + "team_size": "32" + } + ] } ] diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/setup_workspace_matrix.json b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/setup_workspace_matrix.json index 83aa8764bc..64c82ce13a 100644 --- a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/setup_workspace_matrix.json +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/setup_workspace_matrix.json @@ -150,5 +150,75 @@ "codebook_abbrev": "half" } ] + }, + { + "_data": [ + { + "data_type": "float", + "data_abbrev": "f" + }, + { + "data_type": "__half", + "data_abbrev": "h" + }, + { + "data_type": "uint8_t", + "data_abbrev": "u8" + }, + { + "data_type": "int8_t", + "data_abbrev": "i8" + } + ], + "_query": [ + { + "query_type": "half", + "query_abbrev": "h" + } + ], + "_index": [ + { + "index_type": "uint32_t", + "index_abbrev": "u32" + } + ], + "_distance": [ + { + "distance_type": "float", + "distance_abbrev": "f" + } + ], + "_pq": [ + { + "pq_len": "8", + "pq_bits": "8", + "pq_prefix": "_vpq", + "pq_suffix": "_8pq_8subd" + } + ], + "_codebook": [ + { + "codebook_type": "half", + "codebook_abbrev": "half" + } + ], + "_mxdim_team": [ + { + "dataset_block_dim": "128", + "team_size": "4" + }, + { + "dataset_block_dim": "256", + "team_size": "8" + }, + { + "dataset_block_dim": "512", + "team_size": "16" + }, + { + "dataset_block_dim": "1024", + "team_size": "32" + } + ] } ] diff --git a/cpp/tests/neighbors/ann_cagra.cuh b/cpp/tests/neighbors/ann_cagra.cuh index a6704f892a..826b8d1a3a 100644 --- a/cpp/tests/neighbors/ann_cagra.cuh +++ b/cpp/tests/neighbors/ann_cagra.cuh @@ -6,6 +6,7 @@ #include "../test_utils.cuh" #include "ann_utils.cuh" +#include "vpq_utils.cuh" #include #include "naive_knn.cuh" @@ -461,6 +462,46 @@ class AnnCagraTest : public ::testing::TestWithParam { raft::update_host(indices_Cagra.data(), indices_dev.data(), queries_size, stream_); raft::resource::sync_stream(handle_); + + reference_recall = 1; + if (ps.compression.has_value()) { + auto decoded_dataset = + raft::make_device_matrix(handle_, ps.n_rows, ps.dim); + cuvs::neighbors::decode_vpq_dataset( + decoded_dataset.view(), + dynamic_cast&>(index.data()), + raft::resource::get_cuda_stream(handle_)); + auto indices_out_view = raft::make_device_matrix_view( + indices_dev.data(), ps.n_queries, ps.k); + auto dists_out_view = raft::make_device_matrix_view( + distances_dev.data(), ps.n_queries, ps.k); + + cuvs::neighbors::naive_knn(handle_, + dists_out_view.data_handle(), + indices_out_view.data_handle(), + search_queries.data(), + decoded_dataset.data_handle(), + ps.n_queries, + ps.n_rows, + ps.dim, + ps.k, + ps.metric); + std::vector indices_vpq_dataset(queries_size); + std::vector distances_vpq_dataset(queries_size); + raft::update_host( + distances_vpq_dataset.data(), dists_out_view.data_handle(), queries_size, stream_); + raft::update_host( + indices_vpq_dataset.data(), indices_out_view.data_handle(), queries_size, stream_); + + reference_recall = std::get<1>(calc_recall(indices_naive, + indices_vpq_dataset, + distances_naive, + distances_vpq_dataset, + ps.n_queries, + ps.k, + 0)); + printf("reference_recall = %e\n", reference_recall); + } } // for (int i = 0; i < min(ps.n_queries, 10); i++) { @@ -470,7 +511,7 @@ class AnnCagraTest : public ::testing::TestWithParam { // print_vector("T", distances_naive.data() + i * ps.k, ps.k, std::cout); // print_vector("C", distances_Cagra.data() + i * ps.k, ps.k, std::cout); // } - double min_recall = ps.min_recall; + double min_recall = ps.min_recall * reference_recall; EXPECT_TRUE(eval_neighbours(indices_naive, indices_Cagra, distances_naive, @@ -519,6 +560,7 @@ class AnnCagraTest : public ::testing::TestWithParam { AnnCagraInputs ps; rmm::device_uvector database; rmm::device_uvector search_queries; + double reference_recall; }; template @@ -1652,7 +1694,7 @@ inline std::vector generate_inputs() {cuvs::neighbors::MergeStrategy::MERGE_STRATEGY_PHYSICAL, cuvs::neighbors::MergeStrategy::MERGE_STRATEGY_LOGICAL}); // don't demand high recall // without refinement - for (uint32_t pq_len : {2}) { // for now, only pq_len = 2 is supported, more options coming soon + for (uint32_t pq_len : {2, 4, 8}) { for (uint32_t vq_n_centers : {100, 1000}) { for (auto input : inputs2) { vpq_params ps{}; diff --git a/cpp/tests/neighbors/vpq_utils.cuh b/cpp/tests/neighbors/vpq_utils.cuh new file mode 100644 index 0000000000..613b5fe1d5 --- /dev/null +++ b/cpp/tests/neighbors/vpq_utils.cuh @@ -0,0 +1,73 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ +#pragma once + +#include +#include +#include + +#include +#include + +namespace cuvs::neighbors { +template +__global__ void decode_vpq_dataset_kernel(data_t* const decoded_dataset_ptr, + const uint32_t ldd, + const math_t* const vq_codebook_ptr, + const uint32_t ldv, + const math_t* const pq_codebook_ptr, + const uint32_t pq_subspace_dim, + const uint32_t pq_table_size, + const uint32_t dataset_dim, + const size_t dataset_size, + const uint8_t* const data_ptr, + const uint32_t ldi) +{ + constexpr uint32_t warp_size = 32; + const size_t batch_id = (blockIdx.x * blockDim.x + threadIdx.x) / warp_size; + if (batch_id >= dataset_size) { return; } + + const auto local_data_ptr = data_ptr + ldi * batch_id; + const auto vq_code = *reinterpret_cast(local_data_ptr); + const auto pq_code_ptr = local_data_ptr + sizeof(uint32_t); + const auto vq_vec_ptr = vq_codebook_ptr + vq_code * ldv; + auto local_dst_ptr = decoded_dataset_ptr + batch_id * ldd; + + const auto lane_id = threadIdx.x % warp_size; + for (uint32_t i = lane_id; i < dataset_dim; i += warp_size) { + const auto pq_code = pq_code_ptr[i / pq_subspace_dim]; + const auto pq_v = pq_codebook_ptr[pq_code * pq_subspace_dim + (i % pq_subspace_dim)]; + + local_dst_ptr[i] = static_cast(vq_vec_ptr[i]) + static_cast(pq_v); + } +} + +template +void decode_vpq_dataset(raft::device_matrix_view decoded_dataset, + const cuvs::neighbors::vpq_dataset& vpq_dataset, + cudaStream_t cuda_stream) +{ + const auto dataset_size = decoded_dataset.extent(0); + RAFT_EXPECTS(vpq_dataset.data.extent(0) == dataset_size, "Dataset sizes mismatch"); + + constexpr uint32_t block_size = 256; + constexpr uint32_t warp_size = 32; + constexpr int64_t vecs_per_cta = block_size / warp_size; + const auto grid_size = raft::div_rounding_up_safe(decoded_dataset.extent(0), vecs_per_cta); + + decode_vpq_dataset_kernel + <<>>(decoded_dataset.data_handle(), + decoded_dataset.stride(0), + vpq_dataset.vq_code_book.data_handle(), + vpq_dataset.vq_code_book.stride(0), + vpq_dataset.pq_code_book.data_handle(), + vpq_dataset.pq_len(), + 1u << vpq_dataset.pq_bits(), + vpq_dataset.dim(), + dataset_size, + vpq_dataset.data.data_handle(), + vpq_dataset.data.stride(0)); +} +} // namespace cuvs::neighbors From 19d5a0a788119c40f5e2041db50a90f947dd4c3b Mon Sep 17 00:00:00 2001 From: enp1s0 Date: Thu, 4 Jun 2026 12:01:19 +0900 Subject: [PATCH 02/15] Update cagra-q test --- cpp/tests/neighbors/ann_utils.cuh | 31 ++++++++++++++++++++++++++----- 1 file changed, 26 insertions(+), 5 deletions(-) diff --git a/cpp/tests/neighbors/ann_utils.cuh b/cpp/tests/neighbors/ann_utils.cuh index cbc95d7bb7..64732a8a3a 100644 --- a/cpp/tests/neighbors/ann_utils.cuh +++ b/cpp/tests/neighbors/ann_utils.cuh @@ -227,8 +227,9 @@ auto calc_recall(const std::vector& expected_idx, size_t cols, double eps) { - size_t match_count = 0; - size_t total_count = static_cast(rows) * static_cast(cols); + size_t match_count = 0; + size_t index_match_count = 0; + size_t total_count = static_cast(rows) * static_cast(cols); for (size_t i = 0; i < rows; ++i) { for (size_t k = 0; k < cols; ++k) { size_t idx_k = i * cols + k; // row major assumption! @@ -247,8 +248,28 @@ auto calc_recall(const std::vector& expected_idx, } } } - return std::make_tuple( - static_cast(match_count) / static_cast(total_count), match_count, total_count); + + // Index based recall + for (size_t i = 0; i < rows; ++i) { + for (size_t k = 0; k < cols; ++k) { + size_t idx_k = i * cols + k; // row major assumption! + auto act_idx = actual_idx[idx_k]; + for (size_t j = 0; j < cols; ++j) { + size_t idx = i * cols + j; // row major assumption! + auto exp_idx = expected_idx[idx]; + + if (act_idx == exp_idx) { + index_match_count++; + break; + } + } + } + } + + return std::make_tuple(static_cast(match_count) / static_cast(total_count), + static_cast(index_match_count) / static_cast(total_count), + match_count, + total_count); } /** same as eval_recall, but in case indices do not match, @@ -265,7 +286,7 @@ auto eval_neighbours(const std::vector& expected_idx, bool test_unique = true, size_t max_duplicates = 0) -> testing::AssertionResult { - auto [actual_recall, match_count, total_count] = + auto [actual_recall, index_based_actual_recall, match_count, total_count] = calc_recall(expected_idx, actual_idx, expected_dist, actual_dist, rows, cols, eps); double error_margin = (actual_recall - min_recall) / std::max(1.0 - min_recall, eps); From 09deae59dfa21902bd23a83f671c405c6f9fc759 Mon Sep 17 00:00:00 2001 From: enp1s0 Date: Thu, 4 Jun 2026 12:05:17 +0900 Subject: [PATCH 03/15] Update the compute distance kernel --- .../cagra/compute_distance_vpq-impl.cuh | 17 ++++- .../detail/cagra/compute_distance_vpq.hpp | 8 ++- .../jit_lto_kernels/compute_distance_impl.cuh | 46 +++++++++---- .../jit_lto_kernels/setup_workspace_impl.cuh | 68 +++++++++++-------- 4 files changed, 93 insertions(+), 46 deletions(-) diff --git a/cpp/src/neighbors/detail/cagra/compute_distance_vpq-impl.cuh b/cpp/src/neighbors/detail/cagra/compute_distance_vpq-impl.cuh index 6992ae979a..d0f12a20fd 100644 --- a/cpp/src/neighbors/detail/cagra/compute_distance_vpq-impl.cuh +++ b/cpp/src/neighbors/detail/cagra/compute_distance_vpq-impl.cuh @@ -14,6 +14,14 @@ namespace cuvs::neighbors::cagra::detail { +template +struct vpq_smem_value_config { + using smem_val_pack_t = half2; + using smem_val_t = half; + using smem_val_pack_uint_t = uint32_t; + static constexpr uint32_t num_packed_elements = 2; +}; + template ; + static constexpr std::uint32_t kSMemCodeBookSizeInBytes = - (1 << PQ_BITS) * PQ_LEN * utils::size_of(); + (1 << PQ_BITS) * PQ_LEN * utils::size_of() / + smem_val_config::num_packed_elements; _RAFT_HOST_DEVICE cagra_q_dataset_descriptor_t(const std::uint8_t* encoded_dataset_ptr, std::uint32_t encoded_dataset_dim, @@ -108,7 +119,9 @@ struct cagra_q_dataset_descriptor_t : public dataset_descriptor_base_t(dim, DatasetBlockDim) * sizeof(QUERY_T); + raft::round_up_safe(dim, DatasetBlockDim) * + utils::size_of() / + smem_val_config::num_packed_elements; } private: diff --git a/cpp/src/neighbors/detail/cagra/compute_distance_vpq.hpp b/cpp/src/neighbors/detail/cagra/compute_distance_vpq.hpp index 2b69a1cef4..299916c6c7 100644 --- a/cpp/src/neighbors/detail/cagra/compute_distance_vpq.hpp +++ b/cpp/src/neighbors/detail/cagra/compute_distance_vpq.hpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2023-2025, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2023-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ @@ -69,6 +69,12 @@ struct vpq_descriptor_spec : public instance_spec { // Match codebook params if (dataset.pq_bits() != PqBits) { return -1.0; } if (dataset.pq_len() != PqLen) { return -1.0; } + // Keep auto-selection on the tuned VPQ diagonal while allowing explicit team_size requests to + // use the expanded team_size / dataset_block_dim grid. + constexpr std::uint32_t auto_dataset_block_dim_per_team = PqLen == 8 ? 32 : 16; + if (params.team_size == 0 && DatasetBlockDim != TeamSize * auto_dataset_block_dim_per_team) { + return -1.0; + } // Otherwise, favor the closest dataset dimensionality. constexpr std::uint32_t preferred_load_elmes_per_thread = 16; /*magic number that is good based on experiments.*/ diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_impl.cuh b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_impl.cuh index 92a014bd2f..08f44c171e 100644 --- a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_impl.cuh +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_impl.cuh @@ -100,10 +100,16 @@ _RAFT_DEVICE RAFT_DEVICE_INLINE_FUNCTION auto compute_distance_vpq_worker_impl( constexpr auto DatasetBlockDim = DescriptorT::kDatasetBlockDim; constexpr auto PQ_BITS = DescriptorT::kPqBits; constexpr auto PQ_LEN = DescriptorT::kPqLen; + using PQ_CODEBOOK_LOAD_T = uint32_t; + + using smem_val_config = vpq_smem_value_config; + using smem_val_pack_t = typename smem_val_config::smem_val_pack_t; + using smem_val_pack_uint_t = typename smem_val_config::smem_val_pack_uint_t; + constexpr uint32_t num_packed_elements = smem_val_config::num_packed_elements; const uint32_t query_ptr = pq_codebook_ptr + DescriptorT::kSMemCodeBookSizeInBytes; static_assert(PQ_BITS == 8, "Only pq_bits == 8 is supported at the moment."); - constexpr uint32_t vlen = 4; // **** DO NOT CHANGE **** + constexpr uint32_t vlen = utils::size_of() / utils::size_of(); constexpr uint32_t nelem = raft::div_rounding_up_unsafe(DatasetBlockDim / PQ_LEN, TeamSize * vlen); @@ -115,12 +121,17 @@ _RAFT_DEVICE RAFT_DEVICE_INLINE_FUNCTION auto compute_distance_vpq_worker_impl( DISTANCE_T norm = 0; for (uint32_t elem_offset = 0; elem_offset * PQ_LEN < dim; elem_offset += DatasetBlockDim / PQ_LEN) { - uint32_t pq_codes[nelem]; + PQ_CODEBOOK_LOAD_T pq_codes[nelem]; #pragma unroll for (std::uint32_t e = 0; e < nelem; e++) { const std::uint32_t k = e * kTeamVLen + elem_offset + laneId * vlen; if (k >= n_subspace) break; - device::ldg_cg(pq_codes[e], reinterpret_cast(dataset_ptr + 4 + k)); + if constexpr (std::is_same_v) { + device::ldg_cg(pq_codes[e], + reinterpret_cast(dataset_ptr + 4 + k)); + } else { + pq_codes[e] = *reinterpret_cast(dataset_ptr + 4 + k); + } } // if constexpr (PQ_LEN % 2 == 0) { @@ -135,23 +146,30 @@ _RAFT_DEVICE RAFT_DEVICE_INLINE_FUNCTION auto compute_distance_vpq_worker_impl( if (d >= dim) break; device::ldg_ca(vq_vals[m], vq_code_book_ptr + d); } - std::uint32_t pq_code = pq_codes[e]; + PQ_CODEBOOK_LOAD_T pq_code = pq_codes[e]; #pragma unroll for (std::uint32_t v = 0; v < vlen; v++) { if (PQ_LEN * (v + k) >= dim) break; #pragma unroll - for (std::uint32_t m = 0; m < PQ_LEN / 2; m++) { - constexpr auto kQueryBlock = DatasetBlockDim / (vlen * PQ_LEN); - const std::uint32_t d1 = m + (PQ_LEN / 2) * v; - const std::uint32_t d = - d1 * kQueryBlock + elem_offset * (PQ_LEN / 2) + e * TeamSize + laneId; - half2 q2, c2; - device::lds(q2, query_ptr + sizeof(half2) * d); + for (std::uint32_t m = 0; m < PQ_LEN / num_packed_elements; m++) { + constexpr uint32_t vq_val_pack_num_elements = 2; + constexpr auto kQueryBlock = DatasetBlockDim / (vlen * PQ_LEN); + const std::uint32_t vq_half2_index = + m * (num_packed_elements / vq_val_pack_num_elements) + (PQ_LEN / 2) * v; + + static_assert(num_packed_elements == 2, + "CAGRA JIT VPQ currently stores pq_len=8 in half2 shared-memory packs"); + const uint32_t query_val_index = + vq_half2_index * kQueryBlock + elem_offset * (PQ_LEN / 2) + e * TeamSize + laneId; + + smem_val_pack_t q2, c2; + device::lds(q2, query_ptr + sizeof(smem_val_pack_t) * query_val_index); device::lds(c2, pq_codebook_ptr + - sizeof(CODE_BOOK_T) * ((1 << PQ_BITS) * 2 * m + (2 * (pq_code & 0xff)))); - auto dist = q2 - c2 - reinterpret_cast(vq_vals)[d1]; - dist = dist * dist; + sizeof(smem_val_pack_uint_t) * ((1 << PQ_BITS) * m + (pq_code & 0xff))); + auto dist = + q2 - c2 - reinterpret_cast(vq_vals)[vq_half2_index]; + dist = dist * dist; norm += static_cast(dist.x + dist.y); } pq_code >>= 8; diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/setup_workspace_impl.cuh b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/setup_workspace_impl.cuh index 8cdd7febd5..ed83c181fe 100644 --- a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/setup_workspace_impl.cuh +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/setup_workspace_impl.cuh @@ -79,12 +79,16 @@ _RAFT_DEVICE __noinline__ auto setup_workspace_vpq_impl( const typename DescriptorT::DATA_T* queries_ptr, uint32_t query_id) -> const DescriptorT* { - using QUERY_T = typename DescriptorT::QUERY_T; - using CODE_BOOK_T = typename DescriptorT::CODE_BOOK_T; - using word_type = uint32_t; - constexpr auto kDatasetBlockDim = DescriptorT::kDatasetBlockDim; - constexpr auto PQ_BITS = DescriptorT::kPqBits; - constexpr auto PQ_LEN = DescriptorT::kPqLen; + using QUERY_T = typename DescriptorT::QUERY_T; + using word_type = uint32_t; + constexpr auto kDatasetBlockDim = DescriptorT::kDatasetBlockDim; + constexpr auto PQ_BITS = DescriptorT::kPqBits; + constexpr auto PQ_LEN = DescriptorT::kPqLen; + using smem_val_config = vpq_smem_value_config; + using smem_val_t = typename smem_val_config::smem_val_t; + using smem_val_pack_t = typename smem_val_config::smem_val_pack_t; + using smem_val_pack_uint_t = typename smem_val_config::smem_val_pack_uint_t; + constexpr auto num_packed_elements = smem_val_config::num_packed_elements; auto* r = reinterpret_cast(smem_ptr); @@ -105,18 +109,22 @@ _RAFT_DEVICE __noinline__ auto setup_workspace_vpq_impl( } __syncthreads(); - for (unsigned i = threadIdx.x * 2; i < (1 << PQ_BITS) * PQ_LEN; i += blockDim.x * 2) { - half2 buf2; - buf2.x = r->pq_code_book_ptr()[i]; - buf2.y = r->pq_code_book_ptr()[i + 1]; - - constexpr auto num_elements_per_bank = 4 / utils::size_of(); - constexpr auto num_banks_per_subspace = PQ_LEN / num_elements_per_bank; - const auto j = i / num_elements_per_bank; - const auto smem_index = - (j / num_banks_per_subspace) + (j % num_banks_per_subspace) * (1 << PQ_BITS); - - device::sts(codebook_buf + smem_index * sizeof(half2), buf2); + for (unsigned i = threadIdx.x * num_packed_elements; i < (1 << PQ_BITS) * PQ_LEN; + i += blockDim.x * num_packed_elements) { + constexpr auto num_elements_per_bank = + num_packed_elements / (utils::size_of() / utils::size_of()); + + if constexpr (PQ_LEN >= num_elements_per_bank) { + constexpr auto num_banks_per_subspace = PQ_LEN / num_elements_per_bank; + const auto j = i / num_elements_per_bank; + const auto smem_index = + (j / num_banks_per_subspace) + (j % num_banks_per_subspace) * (1 << PQ_BITS); + + smem_val_pack_t buf; + buf.x = r->pq_code_book_ptr()[i]; + buf.y = r->pq_code_book_ptr()[i + 1]; + device::sts(codebook_buf + smem_index * sizeof(smem_val_pack_t), buf); + } } } @@ -125,19 +133,21 @@ _RAFT_DEVICE __noinline__ auto setup_workspace_vpq_impl( constexpr cuvs::spatial::knn::detail::utils::mapping mapping{}; auto smem_query_ptr = - reinterpret_cast(reinterpret_cast(smem_ptr) + sizeof(DescriptorT) + - DescriptorT::kSMemCodeBookSizeInBytes); - for (unsigned i = threadIdx.x * 2; i < dim; i += blockDim.x * 2) { - half2 buf2{0, 0}; - if (i < dim) { buf2.x = mapping(queries_ptr[i]); } - if (i + 1 < dim) { buf2.y = mapping(queries_ptr[i + 1]); } - if constexpr ((PQ_BITS == 8) && (PQ_LEN % 2 == 0)) { + reinterpret_cast(reinterpret_cast(smem_ptr) + sizeof(DescriptorT) + + DescriptorT::kSMemCodeBookSizeInBytes); + for (unsigned i = threadIdx.x * num_packed_elements; i < dim; + i += blockDim.x * num_packed_elements) { + smem_val_pack_t buf{0, 0}; + if (i < dim) { buf.x = mapping(queries_ptr[i]); } + if (i + 1 < dim) { buf.y = mapping(queries_ptr[i + 1]); } + if constexpr ((PQ_BITS == 8) && (PQ_LEN % num_packed_elements == 0)) { constexpr uint32_t vlen = 4; // **** DO NOT CHANGE **** - constexpr auto kStride = vlen * PQ_LEN / 2; - reinterpret_cast(smem_query_ptr)[transpose(i / 2)] = - buf2; + constexpr auto kStride = vlen * PQ_LEN / num_packed_elements; + reinterpret_cast( + smem_query_ptr)[transpose( + i / num_packed_elements)] = buf; } else { - (reinterpret_cast(smem_query_ptr + i))[0] = buf2; + (reinterpret_cast(smem_query_ptr + i))[0] = buf; } } From 5fa53216cf33996f772ce3e10de4e5a545be3528 Mon Sep 17 00:00:00 2001 From: enp1s0 Date: Thu, 4 Jun 2026 18:10:24 +0900 Subject: [PATCH 04/15] Add FP8 support --- .../neighbors/detail/cagra/cagra_search.cuh | 5 + .../detail/cagra/compute_distance.hpp | 7 +- .../cagra/compute_distance_vpq-impl.cuh | 49 +++- .../detail/cagra/compute_distance_vpq.hpp | 10 +- .../cagra/compute_distance_vpq_inst.cu.in | 4 +- .../cagra/compute_distance_vpq_matrix.json | 4 + .../detail/cagra/device_memory_ops.hpp | 15 + cpp/src/neighbors/detail/cagra/factory.cuh | 12 +- .../cagra_jit_launcher_factory.hpp | 36 ++- .../jit_lto_kernels/cagra_planner_base.hpp | 274 ++++++++++++++---- .../jit_lto_kernels/compute_distance_impl.cuh | 69 +++-- .../compute_distance_kernel.cu.in | 7 +- .../compute_distance_matrix.json | 26 ++ .../jit_lto_kernels/setup_workspace_impl.cuh | 45 ++- .../setup_workspace_kernel.cu.in | 4 +- .../setup_workspace_matrix.json | 26 ++ .../neighbors/detail/cagra/packed_type.hpp | 49 ++++ 17 files changed, 520 insertions(+), 122 deletions(-) create mode 100644 cpp/src/neighbors/detail/cagra/packed_type.hpp diff --git a/cpp/src/neighbors/detail/cagra/cagra_search.cuh b/cpp/src/neighbors/detail/cagra/cagra_search.cuh index bca8d3314d..f199cf7882 100644 --- a/cpp/src/neighbors/detail/cagra/cagra_search.cuh +++ b/cpp/src/neighbors/detail/cagra/cagra_search.cuh @@ -153,6 +153,11 @@ void search_main(raft::resources const& res, // Dispatch search parameters based on the dataset kind. if (auto* strided_dset = dynamic_cast*>(&index.data()); strided_dset != nullptr) { + if (params.smem_dtype != cuvs::neighbors::cagra::internal_dtype::AUTO && + params.smem_dtype != cuvs::neighbors::cagra::internal_dtype::F16) { + RAFT_LOG_WARN("In this search mode, smem_dtype supports only AUTO or F16. Set it to AUTO."); + params.smem_dtype = cuvs::neighbors::cagra::internal_dtype::AUTO; + } // Search using a plain (strided) row-major dataset RAFT_EXPECTS(index.metric() != cuvs::distance::DistanceType::CosineExpanded || index.dataset_norms().has_value(), diff --git a/cpp/src/neighbors/detail/cagra/compute_distance.hpp b/cpp/src/neighbors/detail/cagra/compute_distance.hpp index 75b56860bb..7f921ce948 100644 --- a/cpp/src/neighbors/detail/cagra/compute_distance.hpp +++ b/cpp/src/neighbors/detail/cagra/compute_distance.hpp @@ -207,6 +207,7 @@ struct dataset_descriptor_host { bool is_vpq = false; uint32_t pq_bits = 0; uint32_t pq_len = 0; + bool enable_fp8 = false; // Codebook type is determined by DataT for VPQ (always half for now) struct state { @@ -258,7 +259,8 @@ struct dataset_descriptor_host { uint32_t dataset_block_dim_val, bool is_vpq_val = false, uint32_t pq_bits_val = 0, - uint32_t pq_len_val = 0) + uint32_t pq_len_val = 0, + bool enable_fp8_val = false) : value_{std::make_shared(init, sizeof(DescriptorImpl))}, smem_ws_size_in_bytes{dd_host.smem_ws_size_in_bytes()}, team_size{dd_host.team_size()}, @@ -266,7 +268,8 @@ struct dataset_descriptor_host { dataset_block_dim{dataset_block_dim_val}, is_vpq{is_vpq_val}, pq_bits{pq_bits_val}, - pq_len{pq_len_val} + pq_len{pq_len_val}, + enable_fp8{enable_fp8_val} { } diff --git a/cpp/src/neighbors/detail/cagra/compute_distance_vpq-impl.cuh b/cpp/src/neighbors/detail/cagra/compute_distance_vpq-impl.cuh index d0f12a20fd..ea994c450a 100644 --- a/cpp/src/neighbors/detail/cagra/compute_distance_vpq-impl.cuh +++ b/cpp/src/neighbors/detail/cagra/compute_distance_vpq-impl.cuh @@ -6,6 +6,7 @@ #pragma once #include "compute_distance_vpq.hpp" +#include "packed_type.hpp" #include #include @@ -14,14 +15,27 @@ namespace cuvs::neighbors::cagra::detail { -template -struct vpq_smem_value_config { +template +struct vpq_smem_value_config; + +template +struct vpq_smem_value_config> { using smem_val_pack_t = half2; using smem_val_t = half; using smem_val_pack_uint_t = uint32_t; static constexpr uint32_t num_packed_elements = 2; }; +template +struct vpq_smem_value_config> { + using smem_val_pack_t = device::fp8xN; + using smem_val_t = typename smem_val_pack_t::unit_t; + using smem_val_pack_uint_t = typename smem_val_pack_t::uint_t; + static constexpr uint32_t num_packed_elements = smem_val_pack_t::num_elements; +}; + template + typename QueryT, + bool EnableFP8> struct cagra_q_dataset_descriptor_t : public dataset_descriptor_base_t { using base_type = dataset_descriptor_base_t; using CODE_BOOK_T = CodebookT; @@ -46,6 +61,7 @@ struct cagra_q_dataset_descriptor_t : public dataset_descriptor_base_t, "Only CODE_BOOK_T = `half` is supported now"); @@ -88,7 +104,7 @@ struct cagra_q_dataset_descriptor_t : public dataset_descriptor_base_t; + using smem_val_config = vpq_smem_value_config; static constexpr std::uint32_t kSMemCodeBookSizeInBytes = (1 << PQ_BITS) * PQ_LEN * utils::size_of() / @@ -135,7 +151,8 @@ template + typename DistanceT, + bool EnableFP8> RAFT_KERNEL __launch_bounds__(1, 1) vpq_dataset_descriptor_init_kernel(dataset_descriptor_base_t* out, const std::uint8_t* encoded_dataset_ptr, @@ -153,7 +170,8 @@ RAFT_KERNEL __launch_bounds__(1, 1) DataT, IndexT, DistanceT, - half>; + half, + EnableFP8>; new (out) desc_type( encoded_dataset_ptr, encoded_dataset_dim, vq_code_book_ptr, pq_code_book_ptr, size, dim); } @@ -166,7 +184,8 @@ template + typename DistanceT, + bool EnableFP8> dataset_descriptor_host vpq_descriptor_spec::init_(const cagra::search_params& params, + DistanceT, + EnableFP8>::init_(const cagra::search_params& params, const std::uint8_t* encoded_dataset_ptr, uint32_t encoded_dataset_dim, const CodebookT* vq_code_book_ptr, @@ -192,7 +212,8 @@ vpq_descriptor_spec; + half, + EnableFP8>; return host_type{ desc_type{ @@ -207,7 +228,8 @@ vpq_descriptor_spec<<<1, 1, 0, stream>>>(dev_ptr, + DistanceT, + EnableFP8><<<1, 1, 0, stream>>>(dev_ptr, encoded_dataset_ptr, encoded_dataset_dim, vq_code_book_ptr, @@ -218,9 +240,10 @@ vpq_descriptor_spec +#include #include @@ -21,7 +22,8 @@ template + typename DistanceT, + bool EnableFP8> struct vpq_descriptor_spec : public instance_spec { using base_type = instance_spec; using typename base_type::data_type; @@ -63,12 +65,18 @@ struct vpq_descriptor_spec : public instance_spec { const DatasetT& dataset, cuvs::distance::DistanceType metric) -> double { + const auto fp8_natively_supported = raft::getComputeCapability().first >= 9; + const auto use_fp8 = + params.smem_dtype == cuvs::neighbors::cagra::internal_dtype::E5M2 || + (params.smem_dtype == cuvs::neighbors::cagra::internal_dtype::AUTO && fp8_natively_supported); + // If explicit team_size is specified and doesn't match the instance, discard it if (params.team_size != 0 && TeamSize != params.team_size) { return -1.0; } if (cuvs::distance::DistanceType::L2Expanded != metric) { return -1.0; } // Match codebook params if (dataset.pq_bits() != PqBits) { return -1.0; } if (dataset.pq_len() != PqLen) { return -1.0; } + if (use_fp8 != EnableFP8) { return -1.0; } // Keep auto-selection on the tuned VPQ diagonal while allowing explicit team_size requests to // use the expanded team_size / dataset_block_dim grid. constexpr std::uint32_t auto_dataset_block_dim_per_team = PqLen == 8 ? 32 : 16; diff --git a/cpp/src/neighbors/detail/cagra/compute_distance_vpq_inst.cu.in b/cpp/src/neighbors/detail/cagra/compute_distance_vpq_inst.cu.in index c159da3229..676f25c9fd 100644 --- a/cpp/src/neighbors/detail/cagra/compute_distance_vpq_inst.cu.in +++ b/cpp/src/neighbors/detail/cagra/compute_distance_vpq_inst.cu.in @@ -13,6 +13,7 @@ constexpr uint32_t team_size = @team_size@; constexpr uint32_t dim = @dim@; constexpr uint32_t pq_bits = @pq_bits@; constexpr uint32_t pq_len = @pq_len@; +constexpr bool enable_fp8 = @enable_fp8@; using codebook_t = @codebook_type@; using data_t = @data_type@; using index_t = @index_type@; @@ -30,6 +31,7 @@ template struct vpq_descriptor_spec; + distance_t, + enable_fp8>; } // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/cagra/compute_distance_vpq_matrix.json b/cpp/src/neighbors/detail/cagra/compute_distance_vpq_matrix.json index c6e2ae319c..7dac07c2a4 100644 --- a/cpp/src/neighbors/detail/cagra/compute_distance_vpq_matrix.json +++ b/cpp/src/neighbors/detail/cagra/compute_distance_vpq_matrix.json @@ -96,5 +96,9 @@ ], "metric": [ "L2Expanded" + ], + "enable_fp8": [ + "true", + "false" ] } diff --git a/cpp/src/neighbors/detail/cagra/device_memory_ops.hpp b/cpp/src/neighbors/detail/cagra/device_memory_ops.hpp index cc164994ea..1bcf6f8fbd 100644 --- a/cpp/src/neighbors/detail/cagra/device_memory_ops.hpp +++ b/cpp/src/neighbors/detail/cagra/device_memory_ops.hpp @@ -54,6 +54,11 @@ RAFT_DEVICE_INLINE_FUNCTION void lds(uint32_t& x, uint32_t addr) asm volatile("ld.shared.u32 {%0}, [%1];" : "=r"(x) : "r"(addr)); } +RAFT_DEVICE_INLINE_FUNCTION void lds(uint64_t& x, uint32_t addr) +{ + asm volatile("ld.shared.u64 {%0}, [%1];" : "=l"(x) : "r"(addr)); +} + RAFT_DEVICE_INLINE_FUNCTION void lds(uint32_t& x, const uint32_t* addr) { lds(x, uint32_t(__cvta_generic_to_shared(addr))); @@ -71,6 +76,16 @@ RAFT_DEVICE_INLINE_FUNCTION void lds(uint4& x, const uint4* addr) lds(x, uint32_t(__cvta_generic_to_shared(addr))); } +RAFT_DEVICE_INLINE_FUNCTION void sts(uint32_t addr, const uint32_t& x) +{ + asm volatile("st.shared.u32 [%0], %1;" : : "r"(addr), "r"(reinterpret_cast(x))); +} + +RAFT_DEVICE_INLINE_FUNCTION void sts(uint32_t addr, const uint64_t& x) +{ + asm volatile("st.shared.u64 [%0], %1;" : : "r"(addr), "l"(reinterpret_cast(x))); +} + RAFT_DEVICE_INLINE_FUNCTION void sts(uint32_t addr, const half2& x) { asm volatile("st.shared.v2.u16 [%0], {%1, %2};" diff --git a/cpp/src/neighbors/detail/cagra/factory.cuh b/cpp/src/neighbors/detail/cagra/factory.cuh index 26cd13bab8..a1e2f6be9c 100644 --- a/cpp/src/neighbors/detail/cagra/factory.cuh +++ b/cpp/src/neighbors/detail/cagra/factory.cuh @@ -87,6 +87,7 @@ struct key { uint32_t extra_val; // this one has different meanings for different descriptor types uint32_t team_size; uint32_t metric; + uint32_t smem_dtype; }; template @@ -100,7 +101,8 @@ auto make_key(const cagra::search_params& params, dataset.dim(), dataset.stride(), uint32_t(params.team_size), - uint32_t(metric)}; + uint32_t(metric), + uint32_t(params.smem_dtype)}; } template @@ -114,20 +116,22 @@ auto make_key(const cagra::search_params& params, dataset.dim(), uint32_t(reinterpret_cast(dataset.pq_code_book.data_handle()) >> 6), uint32_t(params.team_size), - uint32_t(metric)}; + uint32_t(metric), + uint32_t(params.smem_dtype)}; } inline auto operator==(const key& a, const key& b) -> bool { return a.data_ptr == b.data_ptr && a.n_rows == b.n_rows && a.dim == b.dim && - a.extra_val == b.extra_val && a.team_size == b.team_size && a.metric == b.metric; + a.extra_val == b.extra_val && a.team_size == b.team_size && a.metric == b.metric && + a.smem_dtype == b.smem_dtype; } struct key_hash { inline auto operator()(const key& x) const noexcept -> std::size_t { return size_t{x.data_ptr} + size_t{x.n_rows} * size_t{x.dim} * size_t{x.extra_val} + - (size_t{x.team_size} ^ size_t{x.metric}); + (size_t{x.team_size} ^ size_t{x.metric}) + size_t{x.smem_dtype}; } }; diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_jit_launcher_factory.hpp b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_jit_launcher_factory.hpp index 60d17c5128..60e965796c 100644 --- a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_jit_launcher_factory.hpp +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_jit_launcher_factory.hpp @@ -57,10 +57,14 @@ std::shared_ptr build_single_cta_launcher( persistent); if constexpr (std::is_same_v) { - planner.add_setup_workspace_device_function( - dataset_desc.team_size, dataset_desc.dataset_block_dim, dataset_desc.pq_len); - planner.add_compute_distance_device_function( - dataset_desc.team_size, dataset_desc.dataset_block_dim, dataset_desc.pq_len); + planner.add_setup_workspace_device_function(dataset_desc.team_size, + dataset_desc.dataset_block_dim, + dataset_desc.pq_len, + dataset_desc.enable_fp8); + planner.add_compute_distance_device_function(dataset_desc.team_size, + dataset_desc.dataset_block_dim, + dataset_desc.pq_len, + dataset_desc.enable_fp8); } else { planner.add_setup_workspace_device_function(dataset_desc.team_size, dataset_desc.dataset_block_dim); @@ -102,10 +106,14 @@ std::shared_ptr build_multi_cta_launcher( dataset_desc.pq_len); if constexpr (std::is_same_v) { - planner.add_setup_workspace_device_function( - dataset_desc.team_size, dataset_desc.dataset_block_dim, dataset_desc.pq_len); - planner.add_compute_distance_device_function( - dataset_desc.team_size, dataset_desc.dataset_block_dim, dataset_desc.pq_len); + planner.add_setup_workspace_device_function(dataset_desc.team_size, + dataset_desc.dataset_block_dim, + dataset_desc.pq_len, + dataset_desc.enable_fp8); + planner.add_compute_distance_device_function(dataset_desc.team_size, + dataset_desc.dataset_block_dim, + dataset_desc.pq_len, + dataset_desc.enable_fp8); } else { planner.add_setup_workspace_device_function(dataset_desc.team_size, dataset_desc.dataset_block_dim); @@ -147,10 +155,14 @@ std::shared_ptr build_multi_kernel_launcher( dataset_desc.pq_bits, dataset_desc.pq_len); if constexpr (std::is_same_v) { - planner.add_setup_workspace_device_function( - dataset_desc.team_size, dataset_desc.dataset_block_dim, dataset_desc.pq_len); - planner.add_compute_distance_device_function( - dataset_desc.team_size, dataset_desc.dataset_block_dim, dataset_desc.pq_len); + planner.add_setup_workspace_device_function(dataset_desc.team_size, + dataset_desc.dataset_block_dim, + dataset_desc.pq_len, + dataset_desc.enable_fp8); + planner.add_compute_distance_device_function(dataset_desc.team_size, + dataset_desc.dataset_block_dim, + dataset_desc.pq_len, + dataset_desc.enable_fp8); } else { planner.add_setup_workspace_device_function(dataset_desc.team_size, dataset_desc.dataset_block_dim); diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_planner_base.hpp b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_planner_base.hpp index b9e7891723..14ef271c2a 100644 --- a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_planner_base.hpp +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_planner_base.hpp @@ -50,11 +50,13 @@ struct CagraPlannerBase : AlgorithmPlanner { TeamSz, Dim, PqBitsV, - PqLenV>>(); + PqLenV, + tag_smem_f16>>(); }; - dispatch_cagra_team_dim(team_size, dataset_block_dim, [&add]() { - add.template operator()(); - }); + dispatch_cagra_standard_team_dim( + team_size, dataset_block_dim, [&add]() { + add.template operator()(); + }); } /// VPQ (`tag_codebook_half`): JIT matrix fixes `pq_bits=8`; only `pq_len` is selected at runtime. @@ -62,32 +64,39 @@ struct CagraPlannerBase : AlgorithmPlanner { std::enable_if_t, int> = 0> void add_setup_workspace_device_function(uint32_t team_size, uint32_t dataset_block_dim, - uint32_t pq_len) + uint32_t pq_len, + bool enable_fp8) { if (pq_len != 2 && pq_len != 4 && pq_len != 8) { RAFT_FAIL("CAGRA JIT VPQ setup_workspace expects pq_len in {2,4,8} (matrix uses pq_bits=8)"); } - auto add = [&]() { - this->add_static_fragment>(); + auto add = + [&]() { + this->add_static_fragment>(); + }; + auto dispatch_smem = [&]() { + dispatch_cagra_vpq_team_dim( + team_size, + dataset_block_dim, + pq_len, + [&add]() { + add.template operator()(); + }); }; - dispatch_cagra_team_dim( - team_size, dataset_block_dim, [&add, pq_len]() { - if (pq_len == 2) { - add.template operator()(); - } else if (pq_len == 4) { - add.template operator()(); - } else { - add.template operator()(); - } - }); + if (enable_fp8) { + dispatch_smem.template operator()(); + } else { + dispatch_smem.template operator()(); + } } /// Registers dist_op + normalization + `compute_distance` for standard layout. @@ -108,11 +117,13 @@ struct CagraPlannerBase : AlgorithmPlanner { TeamSz, Dim, PqBitsV, - PqLenV>>(); + PqLenV, + tag_smem_f16>>(); }; - dispatch_cagra_team_dim(team_size, dataset_block_dim, [&add]() { - add.template operator()(); - }); + dispatch_cagra_standard_team_dim( + team_size, dataset_block_dim, [&add]() { + add.template operator()(); + }); } /// VPQ: only the `compute_distance` fragment (no standard dist_op / normalization in this path). @@ -120,35 +131,179 @@ struct CagraPlannerBase : AlgorithmPlanner { std::enable_if_t, int> = 0> void add_compute_distance_device_function(uint32_t team_size, uint32_t dataset_block_dim, - uint32_t pq_len) + uint32_t pq_len, + bool enable_fp8) { if (pq_len != 2 && pq_len != 4 && pq_len != 8) { RAFT_FAIL("CAGRA JIT VPQ compute_distance expects pq_len in {2,4,8} (matrix uses pq_bits=8)"); } - auto add = [&]() { - this->add_static_fragment>(); + auto add = + [&]() { + this->add_static_fragment>(); + }; + auto dispatch_smem = [&]() { + dispatch_cagra_vpq_team_dim( + team_size, + dataset_block_dim, + pq_len, + [&add]() { + add.template operator()(); + }); }; - dispatch_cagra_team_dim( - team_size, dataset_block_dim, [&add, pq_len]() { - if (pq_len == 2) { - add.template operator()(); - } else if (pq_len == 4) { - add.template operator()(); - } else { - add.template operator()(); - } - }); + if (enable_fp8) { + dispatch_smem.template operator()(); + } else { + dispatch_smem.template operator()(); + } } private: + template + static void dispatch_cagra_standard_team_dim(uint32_t team_size, + uint32_t dataset_block_dim, + Lambda&& l) + { + switch (team_size) { + case 8: + switch (dataset_block_dim) { + case 128: std::forward(l).template operator()<8u, 128u>(); return; + case 256: std::forward(l).template operator()<8u, 256u>(); return; + case 512: std::forward(l).template operator()<8u, 512u>(); return; + default: break; + } + break; + case 16: + switch (dataset_block_dim) { + case 128: std::forward(l).template operator()<16u, 128u>(); return; + case 256: std::forward(l).template operator()<16u, 256u>(); return; + case 512: std::forward(l).template operator()<16u, 512u>(); return; + default: break; + } + break; + case 32: + switch (dataset_block_dim) { + case 128: std::forward(l).template operator()<32u, 128u>(); return; + case 256: std::forward(l).template operator()<32u, 256u>(); return; + case 512: std::forward(l).template operator()<32u, 512u>(); return; + default: break; + } + break; + default: break; + } + RAFT_FAIL("Unsupported standard team_size / dataset_block_dim for CAGRA JIT: team=%u dim=%u", + static_cast(team_size), + static_cast(dataset_block_dim)); + } + + template + static void dispatch_cagra_vpq_pq2_4_team_dim(uint32_t team_size, + uint32_t dataset_block_dim, + Lambda&& l) + { + switch (team_size) { + case 8: + switch (dataset_block_dim) { + case 128: std::forward(l).template operator()<8u, 128u, 8u, PqLenV>(); return; + case 256: std::forward(l).template operator()<8u, 256u, 8u, PqLenV>(); return; + case 512: std::forward(l).template operator()<8u, 512u, 8u, PqLenV>(); return; + default: break; + } + break; + case 16: + switch (dataset_block_dim) { + case 128: std::forward(l).template operator()<16u, 128u, 8u, PqLenV>(); return; + case 256: std::forward(l).template operator()<16u, 256u, 8u, PqLenV>(); return; + case 512: std::forward(l).template operator()<16u, 512u, 8u, PqLenV>(); return; + default: break; + } + break; + case 32: + switch (dataset_block_dim) { + case 128: std::forward(l).template operator()<32u, 128u, 8u, PqLenV>(); return; + case 256: std::forward(l).template operator()<32u, 256u, 8u, PqLenV>(); return; + case 512: std::forward(l).template operator()<32u, 512u, 8u, PqLenV>(); return; + default: break; + } + break; + default: break; + } + RAFT_FAIL( + "Unsupported VPQ pq_len=%u team_size / dataset_block_dim for CAGRA JIT: team=%u dim=%u", + static_cast(PqLenV), + static_cast(team_size), + static_cast(dataset_block_dim)); + } + + template + static void dispatch_cagra_vpq_pq8_team_dim(uint32_t team_size, + uint32_t dataset_block_dim, + Lambda&& l) + { + switch (team_size) { + case 4: + switch (dataset_block_dim) { + case 128: std::forward(l).template operator()<4u, 128u, 8u, 8u>(); return; + default: break; + } + break; + case 8: + switch (dataset_block_dim) { + case 256: std::forward(l).template operator()<8u, 256u, 8u, 8u>(); return; + default: break; + } + break; + case 16: + switch (dataset_block_dim) { + case 512: std::forward(l).template operator()<16u, 512u, 8u, 8u>(); return; + default: break; + } + break; + case 32: + switch (dataset_block_dim) { + case 1024: std::forward(l).template operator()<32u, 1024u, 8u, 8u>(); return; + default: break; + } + break; + default: break; + } + RAFT_FAIL( + "Unsupported VPQ pq_len=8 team_size / dataset_block_dim for CAGRA JIT: team=%u dim=%u", + static_cast(team_size), + static_cast(dataset_block_dim)); + } + + template + static void dispatch_cagra_vpq_team_dim(uint32_t team_size, + uint32_t dataset_block_dim, + uint32_t pq_len, + Lambda&& l) + { + switch (pq_len) { + case 2: + dispatch_cagra_vpq_pq2_4_team_dim<2u>( + team_size, dataset_block_dim, std::forward(l)); + return; + case 4: + dispatch_cagra_vpq_pq2_4_team_dim<4u>( + team_size, dataset_block_dim, std::forward(l)); + return; + case 8: + dispatch_cagra_vpq_pq8_team_dim(team_size, dataset_block_dim, std::forward(l)); + return; + default: break; + } + RAFT_FAIL("CAGRA JIT VPQ expects pq_len in {2,4,8}; got %u", static_cast(pq_len)); + } + void add_dist_op_device_function(cuvs::distance::DistanceType metric) { // dist_op_matrix.json pairs tag_metric_hamming with uint8 query (tag_u8) only; L2/IP/L1 use @@ -193,15 +348,16 @@ struct CagraPlannerBase : AlgorithmPlanner { uint32_t dataset_block_dim) { auto go = [&]() { - dispatch_cagra_team_dim(team_size, dataset_block_dim, [&]() { - this->add_static_fragment>(); - }); + dispatch_cagra_standard_team_dim( + team_size, dataset_block_dim, [&]() { + this->add_static_fragment>(); + }); }; // tag_u8 is only used for BitwiseHamming query layout; cosine norm fragments are built for // float query tag. Use if constexpr so we do not instantiate tag_norm_cosine with tag_u8 diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_impl.cuh b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_impl.cuh index 08f44c171e..59b43bea64 100644 --- a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_impl.cuh +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_impl.cuh @@ -100,9 +100,10 @@ _RAFT_DEVICE RAFT_DEVICE_INLINE_FUNCTION auto compute_distance_vpq_worker_impl( constexpr auto DatasetBlockDim = DescriptorT::kDatasetBlockDim; constexpr auto PQ_BITS = DescriptorT::kPqBits; constexpr auto PQ_LEN = DescriptorT::kPqLen; + constexpr auto EnableFP8 = DescriptorT::kEnableFP8; using PQ_CODEBOOK_LOAD_T = uint32_t; - using smem_val_config = vpq_smem_value_config; + using smem_val_config = vpq_smem_value_config; using smem_val_pack_t = typename smem_val_config::smem_val_pack_t; using smem_val_pack_uint_t = typename smem_val_config::smem_val_pack_uint_t; constexpr uint32_t num_packed_elements = smem_val_config::num_packed_elements; @@ -154,23 +155,55 @@ _RAFT_DEVICE RAFT_DEVICE_INLINE_FUNCTION auto compute_distance_vpq_worker_impl( for (std::uint32_t m = 0; m < PQ_LEN / num_packed_elements; m++) { constexpr uint32_t vq_val_pack_num_elements = 2; constexpr auto kQueryBlock = DatasetBlockDim / (vlen * PQ_LEN); - const std::uint32_t vq_half2_index = + std::uint32_t vq_half2_index = m * (num_packed_elements / vq_val_pack_num_elements) + (PQ_LEN / 2) * v; - static_assert(num_packed_elements == 2, - "CAGRA JIT VPQ currently stores pq_len=8 in half2 shared-memory packs"); - const uint32_t query_val_index = - vq_half2_index * kQueryBlock + elem_offset * (PQ_LEN / 2) + e * TeamSize + laneId; + uint32_t query_val_index; + if constexpr (num_packed_elements == 2) { + query_val_index = + vq_half2_index * kQueryBlock + elem_offset * (PQ_LEN / 2) + e * TeamSize + laneId; + } else if constexpr (PQ_LEN == num_packed_elements) { + query_val_index = elem_offset + v * (DatasetBlockDim / (num_packed_elements * vlen)) + + e * TeamSize + laneId; + } else { + const uint32_t query_vec_element_id = + (elem_offset + e * vlen * TeamSize + v + laneId * vlen) * PQ_LEN / + num_packed_elements; + constexpr auto kStride = vlen * PQ_LEN / num_packed_elements; + query_val_index = + transpose(query_vec_element_id); + } - smem_val_pack_t q2, c2; - device::lds(q2, query_ptr + sizeof(smem_val_pack_t) * query_val_index); - device::lds(c2, - pq_codebook_ptr + - sizeof(smem_val_pack_uint_t) * ((1 << PQ_BITS) * m + (pq_code & 0xff))); - auto dist = - q2 - c2 - reinterpret_cast(vq_vals)[vq_half2_index]; - dist = dist * dist; - norm += static_cast(dist.x + dist.y); + if constexpr (num_packed_elements == 2) { + smem_val_pack_t q2, c2; + device::lds(q2, query_ptr + sizeof(smem_val_pack_t) * query_val_index); + device::lds(c2, + pq_codebook_ptr + + sizeof(smem_val_pack_uint_t) * ((1 << PQ_BITS) * m + (pq_code & 0xff))); + auto dist = + q2 - c2 - reinterpret_cast(vq_vals)[vq_half2_index]; + dist = dist * dist; + norm += static_cast(dist.x + dist.y); + } else if constexpr (num_packed_elements == 4 || num_packed_elements == 8) { + smem_val_pack_t q_vec, c_vec; + device::lds(q_vec.as_uint(), + query_ptr + sizeof(smem_val_pack_uint_t) * query_val_index); + device::lds(c_vec.as_uint(), + pq_codebook_ptr + + sizeof(smem_val_pack_uint_t) * ((1 << PQ_BITS) * m + (pq_code & 0xff))); + + half2 q2, c2; +#pragma unroll + for (uint32_t bi = 0; bi < num_packed_elements / 2; bi++) { + q2 = q_vec.as_half2(bi); + c2 = c_vec.as_half2(bi); + auto dist = + q2 - c2 - reinterpret_cast(vq_vals)[vq_half2_index]; + dist = dist * dist; + norm += static_cast(dist.x + dist.y); + vq_half2_index += 1; + } + } } pq_code >>= 8; } @@ -237,7 +270,8 @@ template + typename QueryT, + bool EnableFP8> __device__ DistanceT compute_distance_impl( const typename dataset_descriptor_base_t::args_t args, IndexT dataset_index) @@ -256,7 +290,8 @@ __device__ DistanceT compute_distance_impl( DataT, IndexT, DistanceT, - QueryT>; + QueryT, + EnableFP8>; return compute_distance_vpq_impl(args, dataset_index); } else { static_assert(sizeof(TeamSize) == 0, diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_kernel.cu.in b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_kernel.cu.in index 13cd022918..130cbf502f 100644 --- a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_kernel.cu.in +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_kernel.cu.in @@ -11,6 +11,7 @@ constexpr uint32_t k_team_size = @team_size@u; constexpr uint32_t k_dataset_block_dim = @dataset_block_dim@u; constexpr uint32_t k_pq_bits = @pq_bits@u; constexpr uint32_t k_pq_len = @pq_len@u; +constexpr bool k_enable_fp8 = @enable_fp8@; using data_t = @data_type@; using index_t = @index_type@; @@ -38,7 +39,8 @@ __device__ distance_t compute_distance(const args_t data_t, index_t, distance_t, - query_t>(args, dataset_index) + query_t, + k_enable_fp8>(args, dataset_index) : distance_t{}; return device::team_sum(per_thread, team_size_bits); } @@ -55,7 +57,8 @@ compute_distance_per_thread(const args_t args, inde data_t, index_t, distance_t, - query_t>(args, dataset_index); + query_t, + k_enable_fp8>(args, dataset_index); } } // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_matrix.json b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_matrix.json index 2e64ee2ce1..f1ce0daaab 100644 --- a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_matrix.json +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_matrix.json @@ -81,6 +81,12 @@ "codebook_type": "void", "codebook_abbrev": "none" } + ], + "_smem": [ + { + "enable_fp8": "false", + "smem_abbrev": "f16" + } ] }, { @@ -149,6 +155,16 @@ "codebook_type": "half", "codebook_abbrev": "half" } + ], + "_smem": [ + { + "enable_fp8": "false", + "smem_abbrev": "f16" + }, + { + "enable_fp8": "true", + "smem_abbrev": "e5m2" + } ] }, { @@ -219,6 +235,16 @@ "dataset_block_dim": "1024", "team_size": "32" } + ], + "_smem": [ + { + "enable_fp8": "false", + "smem_abbrev": "f16" + }, + { + "enable_fp8": "true", + "smem_abbrev": "e5m2" + } ] } ] diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/setup_workspace_impl.cuh b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/setup_workspace_impl.cuh index ed83c181fe..220c76ac96 100644 --- a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/setup_workspace_impl.cuh +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/setup_workspace_impl.cuh @@ -84,7 +84,8 @@ _RAFT_DEVICE __noinline__ auto setup_workspace_vpq_impl( constexpr auto kDatasetBlockDim = DescriptorT::kDatasetBlockDim; constexpr auto PQ_BITS = DescriptorT::kPqBits; constexpr auto PQ_LEN = DescriptorT::kPqLen; - using smem_val_config = vpq_smem_value_config; + constexpr auto EnableFP8 = DescriptorT::kEnableFP8; + using smem_val_config = vpq_smem_value_config; using smem_val_t = typename smem_val_config::smem_val_t; using smem_val_pack_t = typename smem_val_config::smem_val_pack_t; using smem_val_pack_uint_t = typename smem_val_config::smem_val_pack_uint_t; @@ -120,10 +121,20 @@ _RAFT_DEVICE __noinline__ auto setup_workspace_vpq_impl( const auto smem_index = (j / num_banks_per_subspace) + (j % num_banks_per_subspace) * (1 << PQ_BITS); - smem_val_pack_t buf; - buf.x = r->pq_code_book_ptr()[i]; - buf.y = r->pq_code_book_ptr()[i + 1]; - device::sts(codebook_buf + smem_index * sizeof(smem_val_pack_t), buf); + if constexpr (num_packed_elements == 2) { + smem_val_pack_t buf; + buf.x = r->pq_code_book_ptr()[i]; + buf.y = r->pq_code_book_ptr()[i + 1]; + device::sts(codebook_buf + smem_index * sizeof(smem_val_pack_t), buf); + } else if constexpr (num_packed_elements == 4 || num_packed_elements == 8) { + smem_val_pack_t buf; +#pragma unroll + for (uint32_t k = 0; k < num_packed_elements; k++) { + buf.data.x1[k] = + static_cast(static_cast(r->pq_code_book_ptr()[i + k])); + } + device::sts(codebook_buf + smem_index * sizeof(smem_val_pack_uint_t), buf.as_uint()); + } } } } @@ -137,9 +148,21 @@ _RAFT_DEVICE __noinline__ auto setup_workspace_vpq_impl( DescriptorT::kSMemCodeBookSizeInBytes); for (unsigned i = threadIdx.x * num_packed_elements; i < dim; i += blockDim.x * num_packed_elements) { - smem_val_pack_t buf{0, 0}; - if (i < dim) { buf.x = mapping(queries_ptr[i]); } - if (i + 1 < dim) { buf.y = mapping(queries_ptr[i + 1]); } + smem_val_pack_t buf; + if constexpr (num_packed_elements == 2) { + buf.x = 0; + buf.y = 0; + if (i < dim) { buf.x = mapping(queries_ptr[i]); } + if (i + 1 < dim) { buf.y = mapping(queries_ptr[i + 1]); } + } else if constexpr (num_packed_elements == 4 || num_packed_elements == 8) { +#pragma unroll + for (uint32_t k = 0; k < num_packed_elements; k++) { + buf.data.x1[k] = static_cast(0.0f); + if (i + k < dim) { + buf.data.x1[k] = static_cast(static_cast(mapping(queries_ptr[i + k]))); + } + } + } if constexpr ((PQ_BITS == 8) && (PQ_LEN % num_packed_elements == 0)) { constexpr uint32_t vlen = 4; // **** DO NOT CHANGE **** constexpr auto kStride = vlen * PQ_LEN / num_packed_elements; @@ -162,7 +185,8 @@ template + typename QueryT, + bool EnableFP8> __device__ const dataset_descriptor_base_t* setup_workspace_impl( const dataset_descriptor_base_t* desc_ptr, void* smem, @@ -186,7 +210,8 @@ __device__ const dataset_descriptor_base_t* setup_work DataT, IndexT, DistanceT, - QueryT>; + QueryT, + EnableFP8>; const desc_t* desc = static_cast(desc_ptr); const desc_t* result = setup_workspace_vpq_impl(desc, smem, queries, query_id); diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/setup_workspace_kernel.cu.in b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/setup_workspace_kernel.cu.in index fa17705250..2177212e36 100644 --- a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/setup_workspace_kernel.cu.in +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/setup_workspace_kernel.cu.in @@ -12,6 +12,7 @@ constexpr uint32_t k_team_size = @team_size@u; constexpr uint32_t k_dataset_block_dim = @dataset_block_dim@u; constexpr uint32_t k_pq_bits = @pq_bits@u; constexpr uint32_t k_pq_len = @pq_len@u; +constexpr bool k_enable_fp8 = @enable_fp8@; using data_t = @data_type@; using index_t = @index_type@; @@ -39,7 +40,8 @@ setup_workspace( data_t, index_t, distance_t, - query_t>(desc, smem, queries, query_id); + query_t, + k_enable_fp8>(desc, smem, queries, query_id); } } // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/setup_workspace_matrix.json b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/setup_workspace_matrix.json index 64c82ce13a..7ee92494e6 100644 --- a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/setup_workspace_matrix.json +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/setup_workspace_matrix.json @@ -81,6 +81,12 @@ "codebook_type": "void", "codebook_abbrev": "none" } + ], + "_smem": [ + { + "enable_fp8": "false", + "smem_abbrev": "f16" + } ] }, { @@ -149,6 +155,16 @@ "codebook_type": "half", "codebook_abbrev": "half" } + ], + "_smem": [ + { + "enable_fp8": "false", + "smem_abbrev": "f16" + }, + { + "enable_fp8": "true", + "smem_abbrev": "e5m2" + } ] }, { @@ -219,6 +235,16 @@ "dataset_block_dim": "1024", "team_size": "32" } + ], + "_smem": [ + { + "enable_fp8": "false", + "smem_abbrev": "f16" + }, + { + "enable_fp8": "true", + "smem_abbrev": "e5m2" + } ] } ] diff --git a/cpp/src/neighbors/detail/cagra/packed_type.hpp b/cpp/src/neighbors/detail/cagra/packed_type.hpp new file mode 100644 index 0000000000..f52edc126b --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/packed_type.hpp @@ -0,0 +1,49 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ +#pragma once +#include +#include + +#include +#include + +namespace cuvs::neighbors::cagra::detail::device { +template +struct uintN_t {}; +template <> +struct uintN_t<32> { + using type = uint32_t; +}; +template <> +struct uintN_t<64> { + using type = uint64_t; +}; + +template +struct fp8xN {}; + +template +struct fp8xN { + using uint_t = typename uintN_t<8 * NumPacked>::type; + using unit_t = __nv_fp8_e5m2; + using x2_t = __nv_fp8x2_storage_t; + static constexpr uint32_t num_elements = NumPacked; + + union { + unit_t x1[num_elements]; + x2_t x2[num_elements / 2]; + uint_t u; + } data; + + HDI fp8xN() { data.u = 0; } + + HDI uint_t& as_uint() { return data.u; } + HDI uint_t as_uint() const { return data.u; } + HDI half2 as_half2(const uint32_t i) const + { + return __nv_cvt_fp8x2_to_halfraw2(data.x2[i], __NV_E5M2); + } +}; +} // namespace cuvs::neighbors::cagra::detail::device From c323fa17eb8c603bb45a47e7c36c0fe3a6cc9d32 Mon Sep 17 00:00:00 2001 From: enp1s0 Date: Thu, 4 Jun 2026 18:40:30 +0900 Subject: [PATCH 05/15] Update EnableFP8 --- cpp/CMakeLists.txt | 14 ++++---- .../detail/jit_lto/cagra/cagra_fragments.hpp | 8 +++-- cpp/include/cuvs/neighbors/cagra.hpp | 6 ++++ .../detail/cagra/compute_distance.hpp | 17 ++++----- .../cagra/compute_distance_vpq-impl.cuh | 35 ++++++++++--------- .../detail/cagra/compute_distance_vpq.hpp | 12 ++++--- .../cagra/compute_distance_vpq_inst.cu.in | 4 +-- .../cagra/compute_distance_vpq_matrix.json | 12 +++++-- .../cagra_jit_launcher_factory.hpp | 12 +++---- .../jit_lto_kernels/cagra_planner_base.hpp | 32 ++++++++++------- .../jit_lto_kernels/compute_distance_impl.cuh | 8 ++--- .../compute_distance_kernel.cu.in | 6 ++-- .../compute_distance_matrix.json | 10 +++--- .../jit_lto_kernels/setup_workspace_impl.cuh | 8 ++--- .../setup_workspace_kernel.cu.in | 4 +-- .../setup_workspace_matrix.json | 10 +++--- cpp/tests/neighbors/ann_cagra.cuh | 22 +++++++++++- 17 files changed, 135 insertions(+), 85 deletions(-) diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 7f9f88695c..a49df49812 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -268,12 +268,12 @@ if(NOT BUILD_CPU_ONLY) INPUT_FILE "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/detail/cagra/compute_distance_vpq_inst.cu.in" OUTPUT_FILE_FORMAT - "${CMAKE_CURRENT_BINARY_DIR}/src/neighbors/detail/cagra/compute_distance_vpq_inst_data_@data_abbrev@_index_@index_abbrev@_distance_@distance_abbrev@_codebook_@codebook_abbrev@_metric_@metric@_team_@team_size@_dim_@dim@_pq_bits_@pq_bits@_pq_len_@pq_len@.cu" + "${CMAKE_CURRENT_BINARY_DIR}/src/neighbors/detail/cagra/compute_distance_vpq_inst_data_@data_abbrev@_index_@index_abbrev@_distance_@distance_abbrev@_codebook_@codebook_abbrev@_metric_@metric@_team_@team_size@_dim_@dim@_pq_bits_@pq_bits@_pq_len_@pq_len@_smem_@smem_abbrev@.cu" ) generate_string_matrix( cagra_compute_distance_vpq_selector_template_params ITEM_FORMAT - "\nvpq_descriptor_spec" + "\nvpq_descriptor_spec" GLUE "," MATRIX_JSON_FILE @@ -282,7 +282,7 @@ if(NOT BUILD_CPU_ONLY) generate_string_matrix( cagra_compute_distance_vpq_template_inst ITEM_FORMAT - "extern template struct vpq_descriptor_spec@semicolon@" + "extern template struct vpq_descriptor_spec@semicolon@" GLUE "\n" MATRIX_JSON_FILE @@ -688,13 +688,13 @@ if(NOT BUILD_CPU_ONLY) generate_jit_lto_kernels( jit_lto_files NAME_FORMAT - "cagra_setup_workspace@pq_prefix@_team_size_@team_size@_dataset_block_dim_@dataset_block_dim@_@pq_bits@pq_@pq_len@subd_data_@data_abbrev@_query_@query_abbrev@" + "cagra_setup_workspace@pq_prefix@_team_size_@team_size@_dataset_block_dim_@dataset_block_dim@_@pq_bits@pq_@pq_len@subd_data_@data_abbrev@_query_@query_abbrev@_smem_@smem_abbrev@" MATRIX_JSON_FILE "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/detail/cagra/jit_lto_kernels/setup_workspace_matrix.json" KERNEL_INPUT_FILE "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/detail/cagra/jit_lto_kernels/setup_workspace_kernel.cu.in" FRAGMENT_TAG_FORMAT - "${cagra_ns}::fragment_tag_setup_workspace<${neighbors_ns}::tag_@data_abbrev@, ${neighbors_ns}::tag_index_@index_abbrev@, ${cagra_ns}::tag_dist_@distance_abbrev@, ${neighbors_ns}::tag_@query_abbrev@, ${cagra_ns}::tag_codebook_@codebook_abbrev@, @team_size@, @dataset_block_dim@, @pq_bits@, @pq_len@>" + "${cagra_ns}::fragment_tag_setup_workspace<${neighbors_ns}::tag_@data_abbrev@, ${neighbors_ns}::tag_index_@index_abbrev@, ${cagra_ns}::tag_dist_@distance_abbrev@, ${neighbors_ns}::tag_@query_abbrev@, ${cagra_ns}::tag_codebook_@codebook_abbrev@, @team_size@, @dataset_block_dim@, @pq_bits@, @pq_len@, ${cagra_ns}::tag_smem_@smem_abbrev@>" FRAGMENT_TAG_HEADER_FILES "" "" OUTPUT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/generated_kernels/cagra/setup_workspace" @@ -704,13 +704,13 @@ if(NOT BUILD_CPU_ONLY) generate_jit_lto_kernels( jit_lto_files NAME_FORMAT - "cagra_compute_distance@pq_prefix@_team_size_@team_size@_dataset_block_dim_@dataset_block_dim@_@pq_bits@pq_@pq_len@subd_data_@data_abbrev@_query_@query_abbrev@" + "cagra_compute_distance@pq_prefix@_team_size_@team_size@_dataset_block_dim_@dataset_block_dim@_@pq_bits@pq_@pq_len@subd_data_@data_abbrev@_query_@query_abbrev@_smem_@smem_abbrev@" MATRIX_JSON_FILE "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_matrix.json" KERNEL_INPUT_FILE "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_kernel.cu.in" FRAGMENT_TAG_FORMAT - "${cagra_ns}::fragment_tag_compute_distance<${neighbors_ns}::tag_@data_abbrev@, ${neighbors_ns}::tag_index_@index_abbrev@, ${cagra_ns}::tag_dist_@distance_abbrev@, ${neighbors_ns}::tag_@query_abbrev@, ${cagra_ns}::tag_codebook_@codebook_abbrev@, @team_size@, @dataset_block_dim@, @pq_bits@, @pq_len@>" + "${cagra_ns}::fragment_tag_compute_distance<${neighbors_ns}::tag_@data_abbrev@, ${neighbors_ns}::tag_index_@index_abbrev@, ${cagra_ns}::tag_dist_@distance_abbrev@, ${neighbors_ns}::tag_@query_abbrev@, ${cagra_ns}::tag_codebook_@codebook_abbrev@, @team_size@, @dataset_block_dim@, @pq_bits@, @pq_len@, ${cagra_ns}::tag_smem_@smem_abbrev@>" FRAGMENT_TAG_HEADER_FILES "" "" OUTPUT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/generated_kernels/cagra/compute_distance" diff --git a/cpp/include/cuvs/detail/jit_lto/cagra/cagra_fragments.hpp b/cpp/include/cuvs/detail/jit_lto/cagra/cagra_fragments.hpp index 0b42d79379..67cdd38783 100644 --- a/cpp/include/cuvs/detail/jit_lto/cagra/cagra_fragments.hpp +++ b/cpp/include/cuvs/detail/jit_lto/cagra/cagra_fragments.hpp @@ -16,6 +16,8 @@ struct tag_metric_cosine {}; struct tag_metric_hamming {}; struct tag_codebook_none {}; struct tag_codebook_half {}; +struct tag_smem_f16 {}; +struct tag_smem_e5m2 {}; struct tag_metric_l1 {}; struct tag_norm_noop {}; struct tag_norm_cosine {}; @@ -33,7 +35,8 @@ template + uint32_t PqLen, + typename SmemTag> struct fragment_tag_setup_workspace {}; template + uint32_t PqLen, + typename SmemTag> struct fragment_tag_compute_distance {}; template diff --git a/cpp/include/cuvs/neighbors/cagra.hpp b/cpp/include/cuvs/neighbors/cagra.hpp index 8edbcab8fa..d2a55ce406 100644 --- a/cpp/include/cuvs/neighbors/cagra.hpp +++ b/cpp/include/cuvs/neighbors/cagra.hpp @@ -276,6 +276,8 @@ enum class search_algo { enum class hash_mode { HASH = 0, SMALL = 1, AUTO = 100 }; +enum class internal_dtype { F16 = 0, E5M2 = 1, AUTO = 100 }; + struct search_params : cuvs::neighbors::search_params { /** Maximum number of queries to search at the same time (batch size). Auto select when 0.*/ size_t max_queries = 0; @@ -349,6 +351,10 @@ struct search_params : cuvs::neighbors::search_params { * negative, in which case the filtering rate is automatically calculated. */ float filtering_rate = -1.0; + + /** Data type of the query vector and codebook table on shared memory. Currently, only VPQ + * supports FP8. **/ + internal_dtype smem_dtype = internal_dtype::AUTO; }; /** diff --git a/cpp/src/neighbors/detail/cagra/compute_distance.hpp b/cpp/src/neighbors/detail/cagra/compute_distance.hpp index 7f921ce948..45997a62a3 100644 --- a/cpp/src/neighbors/detail/cagra/compute_distance.hpp +++ b/cpp/src/neighbors/detail/cagra/compute_distance.hpp @@ -202,12 +202,12 @@ struct dataset_descriptor_host { uint32_t team_size = 0; // JIT LTO metadata - stored when descriptor is created - cuvs::distance::DistanceType metric = cuvs::distance::DistanceType::L2Expanded; - uint32_t dataset_block_dim = 0; - bool is_vpq = false; - uint32_t pq_bits = 0; - uint32_t pq_len = 0; - bool enable_fp8 = false; + cuvs::distance::DistanceType metric = cuvs::distance::DistanceType::L2Expanded; + uint32_t dataset_block_dim = 0; + bool is_vpq = false; + uint32_t pq_bits = 0; + uint32_t pq_len = 0; + cuvs::neighbors::cagra::internal_dtype smem_dtype = cuvs::neighbors::cagra::internal_dtype::F16; // Codebook type is determined by DataT for VPQ (always half for now) struct state { @@ -260,7 +260,8 @@ struct dataset_descriptor_host { bool is_vpq_val = false, uint32_t pq_bits_val = 0, uint32_t pq_len_val = 0, - bool enable_fp8_val = false) + cuvs::neighbors::cagra::internal_dtype smem_dtype_val = + cuvs::neighbors::cagra::internal_dtype::F16) : value_{std::make_shared(init, sizeof(DescriptorImpl))}, smem_ws_size_in_bytes{dd_host.smem_ws_size_in_bytes()}, team_size{dd_host.team_size()}, @@ -269,7 +270,7 @@ struct dataset_descriptor_host { is_vpq{is_vpq_val}, pq_bits{pq_bits_val}, pq_len{pq_len_val}, - enable_fp8{enable_fp8_val} + smem_dtype{smem_dtype_val} { } diff --git a/cpp/src/neighbors/detail/cagra/compute_distance_vpq-impl.cuh b/cpp/src/neighbors/detail/cagra/compute_distance_vpq-impl.cuh index ea994c450a..f73f901f95 100644 --- a/cpp/src/neighbors/detail/cagra/compute_distance_vpq-impl.cuh +++ b/cpp/src/neighbors/detail/cagra/compute_distance_vpq-impl.cuh @@ -15,21 +15,24 @@ namespace cuvs::neighbors::cagra::detail { -template +template struct vpq_smem_value_config; -template -struct vpq_smem_value_config> { +template +struct vpq_smem_value_config< + PQ_LEN, + SmemDType, + std::enable_if_t> { using smem_val_pack_t = half2; using smem_val_t = half; using smem_val_pack_uint_t = uint32_t; static constexpr uint32_t num_packed_elements = 2; }; -template +template struct vpq_smem_value_config> { + cuvs::neighbors::cagra::internal_dtype::E5M2, + std::enable_if_t> { using smem_val_pack_t = device::fp8xN; using smem_val_t = typename smem_val_pack_t::unit_t; using smem_val_pack_uint_t = typename smem_val_pack_t::uint_t; @@ -45,7 +48,7 @@ template + cuvs::neighbors::cagra::internal_dtype SmemDType> struct cagra_q_dataset_descriptor_t : public dataset_descriptor_base_t { using base_type = dataset_descriptor_base_t; using CODE_BOOK_T = CodebookT; @@ -61,7 +64,7 @@ struct cagra_q_dataset_descriptor_t : public dataset_descriptor_base_t, "Only CODE_BOOK_T = `half` is supported now"); @@ -104,7 +107,7 @@ struct cagra_q_dataset_descriptor_t : public dataset_descriptor_base_t; + using smem_val_config = vpq_smem_value_config; static constexpr std::uint32_t kSMemCodeBookSizeInBytes = (1 << PQ_BITS) * PQ_LEN * utils::size_of() / @@ -152,7 +155,7 @@ template + cuvs::neighbors::cagra::internal_dtype SmemDType> RAFT_KERNEL __launch_bounds__(1, 1) vpq_dataset_descriptor_init_kernel(dataset_descriptor_base_t* out, const std::uint8_t* encoded_dataset_ptr, @@ -171,7 +174,7 @@ RAFT_KERNEL __launch_bounds__(1, 1) IndexT, DistanceT, half, - EnableFP8>; + SmemDType>; new (out) desc_type( encoded_dataset_ptr, encoded_dataset_dim, vq_code_book_ptr, pq_code_book_ptr, size, dim); } @@ -185,7 +188,7 @@ template + cuvs::neighbors::cagra::internal_dtype SmemDType> dataset_descriptor_host vpq_descriptor_spec::init_(const cagra::search_params& params, + SmemDType>::init_(const cagra::search_params& params, const std::uint8_t* encoded_dataset_ptr, uint32_t encoded_dataset_dim, const CodebookT* vq_code_book_ptr, @@ -213,7 +216,7 @@ vpq_descriptor_spec; + SmemDType>; return host_type{ desc_type{ @@ -229,7 +232,7 @@ vpq_descriptor_spec<<<1, 1, 0, stream>>>(dev_ptr, + SmemDType><<<1, 1, 0, stream>>>(dev_ptr, encoded_dataset_ptr, encoded_dataset_dim, vq_code_book_ptr, @@ -243,7 +246,7 @@ vpq_descriptor_spec + cuvs::neighbors::cagra::internal_dtype SmemDType> struct vpq_descriptor_spec : public instance_spec { using base_type = instance_spec; using typename base_type::data_type; @@ -66,9 +66,11 @@ struct vpq_descriptor_spec : public instance_spec { cuvs::distance::DistanceType metric) -> double { const auto fp8_natively_supported = raft::getComputeCapability().first >= 9; - const auto use_fp8 = - params.smem_dtype == cuvs::neighbors::cagra::internal_dtype::E5M2 || - (params.smem_dtype == cuvs::neighbors::cagra::internal_dtype::AUTO && fp8_natively_supported); + const auto selected_smem_dtype = + params.smem_dtype == cuvs::neighbors::cagra::internal_dtype::AUTO + ? (fp8_natively_supported ? cuvs::neighbors::cagra::internal_dtype::E5M2 + : cuvs::neighbors::cagra::internal_dtype::F16) + : params.smem_dtype; // If explicit team_size is specified and doesn't match the instance, discard it if (params.team_size != 0 && TeamSize != params.team_size) { return -1.0; } @@ -76,7 +78,7 @@ struct vpq_descriptor_spec : public instance_spec { // Match codebook params if (dataset.pq_bits() != PqBits) { return -1.0; } if (dataset.pq_len() != PqLen) { return -1.0; } - if (use_fp8 != EnableFP8) { return -1.0; } + if (selected_smem_dtype != SmemDType) { return -1.0; } // Keep auto-selection on the tuned VPQ diagonal while allowing explicit team_size requests to // use the expanded team_size / dataset_block_dim grid. constexpr std::uint32_t auto_dataset_block_dim_per_team = PqLen == 8 ? 32 : 16; diff --git a/cpp/src/neighbors/detail/cagra/compute_distance_vpq_inst.cu.in b/cpp/src/neighbors/detail/cagra/compute_distance_vpq_inst.cu.in index 676f25c9fd..25d4732a34 100644 --- a/cpp/src/neighbors/detail/cagra/compute_distance_vpq_inst.cu.in +++ b/cpp/src/neighbors/detail/cagra/compute_distance_vpq_inst.cu.in @@ -13,7 +13,7 @@ constexpr uint32_t team_size = @team_size@; constexpr uint32_t dim = @dim@; constexpr uint32_t pq_bits = @pq_bits@; constexpr uint32_t pq_len = @pq_len@; -constexpr bool enable_fp8 = @enable_fp8@; +constexpr auto smem_dtype = @smem_dtype@; using codebook_t = @codebook_type@; using data_t = @data_type@; using index_t = @index_type@; @@ -32,6 +32,6 @@ template struct vpq_descriptor_spec; + smem_dtype>; } // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/cagra/compute_distance_vpq_matrix.json b/cpp/src/neighbors/detail/cagra/compute_distance_vpq_matrix.json index 7dac07c2a4..1241b2346c 100644 --- a/cpp/src/neighbors/detail/cagra/compute_distance_vpq_matrix.json +++ b/cpp/src/neighbors/detail/cagra/compute_distance_vpq_matrix.json @@ -97,8 +97,14 @@ "metric": [ "L2Expanded" ], - "enable_fp8": [ - "true", - "false" + "_smem": [ + { + "smem_dtype": "cuvs::neighbors::cagra::internal_dtype::E5M2", + "smem_abbrev": "e5m2" + }, + { + "smem_dtype": "cuvs::neighbors::cagra::internal_dtype::F16", + "smem_abbrev": "f16" + } ] } diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_jit_launcher_factory.hpp b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_jit_launcher_factory.hpp index 60e965796c..973a5a1176 100644 --- a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_jit_launcher_factory.hpp +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_jit_launcher_factory.hpp @@ -60,11 +60,11 @@ std::shared_ptr build_single_cta_launcher( planner.add_setup_workspace_device_function(dataset_desc.team_size, dataset_desc.dataset_block_dim, dataset_desc.pq_len, - dataset_desc.enable_fp8); + dataset_desc.smem_dtype); planner.add_compute_distance_device_function(dataset_desc.team_size, dataset_desc.dataset_block_dim, dataset_desc.pq_len, - dataset_desc.enable_fp8); + dataset_desc.smem_dtype); } else { planner.add_setup_workspace_device_function(dataset_desc.team_size, dataset_desc.dataset_block_dim); @@ -109,11 +109,11 @@ std::shared_ptr build_multi_cta_launcher( planner.add_setup_workspace_device_function(dataset_desc.team_size, dataset_desc.dataset_block_dim, dataset_desc.pq_len, - dataset_desc.enable_fp8); + dataset_desc.smem_dtype); planner.add_compute_distance_device_function(dataset_desc.team_size, dataset_desc.dataset_block_dim, dataset_desc.pq_len, - dataset_desc.enable_fp8); + dataset_desc.smem_dtype); } else { planner.add_setup_workspace_device_function(dataset_desc.team_size, dataset_desc.dataset_block_dim); @@ -158,11 +158,11 @@ std::shared_ptr build_multi_kernel_launcher( planner.add_setup_workspace_device_function(dataset_desc.team_size, dataset_desc.dataset_block_dim, dataset_desc.pq_len, - dataset_desc.enable_fp8); + dataset_desc.smem_dtype); planner.add_compute_distance_device_function(dataset_desc.team_size, dataset_desc.dataset_block_dim, dataset_desc.pq_len, - dataset_desc.enable_fp8); + dataset_desc.smem_dtype); } else { planner.add_setup_workspace_device_function(dataset_desc.team_size, dataset_desc.dataset_block_dim); diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_planner_base.hpp b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_planner_base.hpp index 14ef271c2a..cce18a0216 100644 --- a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_planner_base.hpp +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_planner_base.hpp @@ -65,7 +65,7 @@ struct CagraPlannerBase : AlgorithmPlanner { void add_setup_workspace_device_function(uint32_t team_size, uint32_t dataset_block_dim, uint32_t pq_len, - bool enable_fp8) + cuvs::neighbors::cagra::internal_dtype smem_dtype) { if (pq_len != 2 && pq_len != 4 && pq_len != 8) { RAFT_FAIL("CAGRA JIT VPQ setup_workspace expects pq_len in {2,4,8} (matrix uses pq_bits=8)"); @@ -92,11 +92,7 @@ struct CagraPlannerBase : AlgorithmPlanner { add.template operator()(); }); }; - if (enable_fp8) { - dispatch_smem.template operator()(); - } else { - dispatch_smem.template operator()(); - } + dispatch_cagra_smem_dtype(smem_dtype, dispatch_smem); } /// Registers dist_op + normalization + `compute_distance` for standard layout. @@ -132,7 +128,7 @@ struct CagraPlannerBase : AlgorithmPlanner { void add_compute_distance_device_function(uint32_t team_size, uint32_t dataset_block_dim, uint32_t pq_len, - bool enable_fp8) + cuvs::neighbors::cagra::internal_dtype smem_dtype) { if (pq_len != 2 && pq_len != 4 && pq_len != 8) { RAFT_FAIL("CAGRA JIT VPQ compute_distance expects pq_len in {2,4,8} (matrix uses pq_bits=8)"); @@ -159,14 +155,26 @@ struct CagraPlannerBase : AlgorithmPlanner { add.template operator()(); }); }; - if (enable_fp8) { - dispatch_smem.template operator()(); - } else { - dispatch_smem.template operator()(); - } + dispatch_cagra_smem_dtype(smem_dtype, dispatch_smem); } private: + template + static void dispatch_cagra_smem_dtype(cuvs::neighbors::cagra::internal_dtype smem_dtype, + Lambda&& l) + { + switch (smem_dtype) { + case cuvs::neighbors::cagra::internal_dtype::F16: + std::forward(l).template operator()(); + return; + case cuvs::neighbors::cagra::internal_dtype::E5M2: + std::forward(l).template operator()(); + return; + default: break; + } + RAFT_FAIL("Unsupported CAGRA JIT smem_dtype: %u", static_cast(smem_dtype)); + } + template static void dispatch_cagra_standard_team_dim(uint32_t team_size, uint32_t dataset_block_dim, diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_impl.cuh b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_impl.cuh index 59b43bea64..ea817110e7 100644 --- a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_impl.cuh +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_impl.cuh @@ -100,10 +100,10 @@ _RAFT_DEVICE RAFT_DEVICE_INLINE_FUNCTION auto compute_distance_vpq_worker_impl( constexpr auto DatasetBlockDim = DescriptorT::kDatasetBlockDim; constexpr auto PQ_BITS = DescriptorT::kPqBits; constexpr auto PQ_LEN = DescriptorT::kPqLen; - constexpr auto EnableFP8 = DescriptorT::kEnableFP8; + constexpr auto SmemDType = DescriptorT::kSmemDType; using PQ_CODEBOOK_LOAD_T = uint32_t; - using smem_val_config = vpq_smem_value_config; + using smem_val_config = vpq_smem_value_config; using smem_val_pack_t = typename smem_val_config::smem_val_pack_t; using smem_val_pack_uint_t = typename smem_val_config::smem_val_pack_uint_t; constexpr uint32_t num_packed_elements = smem_val_config::num_packed_elements; @@ -271,7 +271,7 @@ template + cuvs::neighbors::cagra::internal_dtype SmemDType> __device__ DistanceT compute_distance_impl( const typename dataset_descriptor_base_t::args_t args, IndexT dataset_index) @@ -291,7 +291,7 @@ __device__ DistanceT compute_distance_impl( IndexT, DistanceT, QueryT, - EnableFP8>; + SmemDType>; return compute_distance_vpq_impl(args, dataset_index); } else { static_assert(sizeof(TeamSize) == 0, diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_kernel.cu.in b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_kernel.cu.in index 130cbf502f..1856781391 100644 --- a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_kernel.cu.in +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_kernel.cu.in @@ -11,7 +11,7 @@ constexpr uint32_t k_team_size = @team_size@u; constexpr uint32_t k_dataset_block_dim = @dataset_block_dim@u; constexpr uint32_t k_pq_bits = @pq_bits@u; constexpr uint32_t k_pq_len = @pq_len@u; -constexpr bool k_enable_fp8 = @enable_fp8@; +constexpr auto k_smem_dtype = @smem_dtype@; using data_t = @data_type@; using index_t = @index_type@; @@ -40,7 +40,7 @@ __device__ distance_t compute_distance(const args_t index_t, distance_t, query_t, - k_enable_fp8>(args, dataset_index) + k_smem_dtype>(args, dataset_index) : distance_t{}; return device::team_sum(per_thread, team_size_bits); } @@ -58,7 +58,7 @@ compute_distance_per_thread(const args_t args, inde index_t, distance_t, query_t, - k_enable_fp8>(args, dataset_index); + k_smem_dtype>(args, dataset_index); } } // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_matrix.json b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_matrix.json index f1ce0daaab..4d260c5507 100644 --- a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_matrix.json +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_matrix.json @@ -84,7 +84,7 @@ ], "_smem": [ { - "enable_fp8": "false", + "smem_dtype": "cuvs::neighbors::cagra::internal_dtype::F16", "smem_abbrev": "f16" } ] @@ -158,11 +158,11 @@ ], "_smem": [ { - "enable_fp8": "false", + "smem_dtype": "cuvs::neighbors::cagra::internal_dtype::F16", "smem_abbrev": "f16" }, { - "enable_fp8": "true", + "smem_dtype": "cuvs::neighbors::cagra::internal_dtype::E5M2", "smem_abbrev": "e5m2" } ] @@ -238,11 +238,11 @@ ], "_smem": [ { - "enable_fp8": "false", + "smem_dtype": "cuvs::neighbors::cagra::internal_dtype::F16", "smem_abbrev": "f16" }, { - "enable_fp8": "true", + "smem_dtype": "cuvs::neighbors::cagra::internal_dtype::E5M2", "smem_abbrev": "e5m2" } ] diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/setup_workspace_impl.cuh b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/setup_workspace_impl.cuh index 220c76ac96..494e0973fe 100644 --- a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/setup_workspace_impl.cuh +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/setup_workspace_impl.cuh @@ -84,8 +84,8 @@ _RAFT_DEVICE __noinline__ auto setup_workspace_vpq_impl( constexpr auto kDatasetBlockDim = DescriptorT::kDatasetBlockDim; constexpr auto PQ_BITS = DescriptorT::kPqBits; constexpr auto PQ_LEN = DescriptorT::kPqLen; - constexpr auto EnableFP8 = DescriptorT::kEnableFP8; - using smem_val_config = vpq_smem_value_config; + constexpr auto SmemDType = DescriptorT::kSmemDType; + using smem_val_config = vpq_smem_value_config; using smem_val_t = typename smem_val_config::smem_val_t; using smem_val_pack_t = typename smem_val_config::smem_val_pack_t; using smem_val_pack_uint_t = typename smem_val_config::smem_val_pack_uint_t; @@ -186,7 +186,7 @@ template + cuvs::neighbors::cagra::internal_dtype SmemDType> __device__ const dataset_descriptor_base_t* setup_workspace_impl( const dataset_descriptor_base_t* desc_ptr, void* smem, @@ -211,7 +211,7 @@ __device__ const dataset_descriptor_base_t* setup_work IndexT, DistanceT, QueryT, - EnableFP8>; + SmemDType>; const desc_t* desc = static_cast(desc_ptr); const desc_t* result = setup_workspace_vpq_impl(desc, smem, queries, query_id); diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/setup_workspace_kernel.cu.in b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/setup_workspace_kernel.cu.in index 2177212e36..6a54c9f956 100644 --- a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/setup_workspace_kernel.cu.in +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/setup_workspace_kernel.cu.in @@ -12,7 +12,7 @@ constexpr uint32_t k_team_size = @team_size@u; constexpr uint32_t k_dataset_block_dim = @dataset_block_dim@u; constexpr uint32_t k_pq_bits = @pq_bits@u; constexpr uint32_t k_pq_len = @pq_len@u; -constexpr bool k_enable_fp8 = @enable_fp8@; +constexpr auto k_smem_dtype = @smem_dtype@; using data_t = @data_type@; using index_t = @index_type@; @@ -41,7 +41,7 @@ setup_workspace( index_t, distance_t, query_t, - k_enable_fp8>(desc, smem, queries, query_id); + k_smem_dtype>(desc, smem, queries, query_id); } } // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/setup_workspace_matrix.json b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/setup_workspace_matrix.json index 7ee92494e6..567fe3e5a1 100644 --- a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/setup_workspace_matrix.json +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/setup_workspace_matrix.json @@ -84,7 +84,7 @@ ], "_smem": [ { - "enable_fp8": "false", + "smem_dtype": "cuvs::neighbors::cagra::internal_dtype::F16", "smem_abbrev": "f16" } ] @@ -158,11 +158,11 @@ ], "_smem": [ { - "enable_fp8": "false", + "smem_dtype": "cuvs::neighbors::cagra::internal_dtype::F16", "smem_abbrev": "f16" }, { - "enable_fp8": "true", + "smem_dtype": "cuvs::neighbors::cagra::internal_dtype::E5M2", "smem_abbrev": "e5m2" } ] @@ -238,11 +238,11 @@ ], "_smem": [ { - "enable_fp8": "false", + "smem_dtype": "cuvs::neighbors::cagra::internal_dtype::F16", "smem_abbrev": "f16" }, { - "enable_fp8": "true", + "smem_dtype": "cuvs::neighbors::cagra::internal_dtype::E5M2", "smem_abbrev": "e5m2" } ] diff --git a/cpp/tests/neighbors/ann_cagra.cuh b/cpp/tests/neighbors/ann_cagra.cuh index 826b8d1a3a..c8322d29fd 100644 --- a/cpp/tests/neighbors/ann_cagra.cuh +++ b/cpp/tests/neighbors/ann_cagra.cuh @@ -27,6 +27,7 @@ #include #include #include +#include #include #include @@ -276,6 +277,7 @@ struct AnnCagraInputs { std::optional non_owning_memory_buffer_flag = std::nullopt; cuvs::neighbors::MergeStrategy merge_strategy = cuvs::neighbors::MergeStrategy::MERGE_STRATEGY_PHYSICAL; + cuvs::neighbors::cagra::internal_dtype smem_dtype = cuvs::neighbors::cagra::internal_dtype::AUTO; }; inline ::std::ostream& operator<<(::std::ostream& os, const AnnCagraInputs& p) @@ -299,6 +301,14 @@ inline ::std::ostream& operator<<(::std::ostream& os, const AnnCagraInputs& p) {search_algo::AUTO, "auto"} // }; std::vector build_algo = {"IVF_PQ", "NN_DESCENT", "ITERATIVE_CAGRA_SEARCH", "AUTO"}; + const auto smem_dtype_str = [](cuvs::neighbors::cagra::internal_dtype dtype) { + switch (dtype) { + case cuvs::neighbors::cagra::internal_dtype::F16: return "F16"; + case cuvs::neighbors::cagra::internal_dtype::E5M2: return "E5M2"; + case cuvs::neighbors::cagra::internal_dtype::AUTO: return "AUTO"; + } + return "Unknown"; + }; std::vector merge_strategy = {"PHYSICAL", "LOGICAL"}; os << "{n_queries=" << p.n_queries << ", dataset shape=" << p.n_rows << "x" << p.dim << ", k=" << p.k << ", " << algo_name[p.algo] << ", max_queries=" << p.max_queries @@ -312,7 +322,7 @@ inline ::std::ostream& operator<<(::std::ostream& os, const AnnCagraInputs& p) if (p.compression.has_value()) { auto vpq = p.compression.value(); os << ", pq_bits=" << vpq.pq_bits << ", pq_dim=" << vpq.pq_dim - << ", vq_n_centers=" << vpq.vq_n_centers; + << ", vq_n_centers=" << vpq.vq_n_centers << ", smem_dtype=" << smem_dtype_str(p.smem_dtype); } os << '}' << std::endl; return os; @@ -346,6 +356,10 @@ class AnnCagraTest : public ::testing::TestWithParam { if (ps.metric == cuvs::distance::DistanceType::L1 && ps.build_algo != graph_build_algo::ITERATIVE_CAGRA_SEARCH) GTEST_SKIP(); + if (ps.smem_dtype == cuvs::neighbors::cagra::internal_dtype::E5M2 && + raft::getComputeCapability().first < 9) { + GTEST_SKIP() << "CAGRA VPQ E5M2 smem dtype requires native FP8 support on SM90+"; + } if (ps.metric == cuvs::distance::DistanceType::CosineExpanded) { if (ps.compression.has_value()) { GTEST_SKIP(); } if (ps.build_algo == graph_build_algo::ITERATIVE_CAGRA_SEARCH || ps.dim == 1) { @@ -415,6 +429,7 @@ class AnnCagraTest : public ::testing::TestWithParam { search_params.algo = ps.algo; search_params.max_queries = ps.max_queries; search_params.team_size = ps.team_size; + search_params.smem_dtype = ps.smem_dtype; auto database_view = raft::make_device_matrix_view( (const DataT*)database.data(), ps.n_rows, ps.dim); @@ -1701,7 +1716,12 @@ inline std::vector generate_inputs() ps.pq_dim = input.dim / pq_len; ps.vq_n_centers = vq_n_centers; input.compression.emplace(ps); + input.smem_dtype = cuvs::neighbors::cagra::internal_dtype::AUTO; inputs.push_back(input); + if (pq_len >= 4 && vq_n_centers == 100) { + input.smem_dtype = cuvs::neighbors::cagra::internal_dtype::E5M2; + inputs.push_back(input); + } } } } From a577563fccababf13e5eba40d44240038b9e6d46 Mon Sep 17 00:00:00 2001 From: enp1s0 Date: Fri, 5 Jun 2026 00:55:48 +0900 Subject: [PATCH 06/15] Update vpq test --- cpp/tests/neighbors/ann_cagra.cuh | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/cpp/tests/neighbors/ann_cagra.cuh b/cpp/tests/neighbors/ann_cagra.cuh index c8322d29fd..5c0da34dd0 100644 --- a/cpp/tests/neighbors/ann_cagra.cuh +++ b/cpp/tests/neighbors/ann_cagra.cuh @@ -1711,15 +1711,15 @@ inline std::vector generate_inputs() // without refinement for (uint32_t pq_len : {2, 4, 8}) { for (uint32_t vq_n_centers : {100, 1000}) { - for (auto input : inputs2) { - vpq_params ps{}; - ps.pq_dim = input.dim / pq_len; - ps.vq_n_centers = vq_n_centers; - input.compression.emplace(ps); - input.smem_dtype = cuvs::neighbors::cagra::internal_dtype::AUTO; - inputs.push_back(input); - if (pq_len >= 4 && vq_n_centers == 100) { - input.smem_dtype = cuvs::neighbors::cagra::internal_dtype::E5M2; + for (auto internal_smem_dtype : {cuvs::neighbors::cagra::internal_dtype::E5M2, + cuvs::neighbors::cagra::internal_dtype::F16, + cuvs::neighbors::cagra::internal_dtype::AUTO}) { + for (auto input : inputs2) { + vpq_params ps{}; + ps.pq_dim = input.dim / pq_len; + ps.vq_n_centers = vq_n_centers; + input.compression.emplace(ps); + input.smem_dtype = internal_smem_dtype; inputs.push_back(input); } } From d05f5524eee4868dd62713e53e18e4afeaa0fce3 Mon Sep 17 00:00:00 2001 From: enp1s0 Date: Fri, 5 Jun 2026 14:55:22 +0900 Subject: [PATCH 07/15] Remove internal_dtype::AUTO --- cpp/include/cuvs/neighbors/cagra.hpp | 4 ++-- cpp/src/neighbors/detail/cagra/cagra_search.cuh | 7 +++---- .../neighbors/detail/cagra/compute_distance_vpq.hpp | 10 +--------- cpp/tests/neighbors/ann_cagra.cuh | 6 ++---- 4 files changed, 8 insertions(+), 19 deletions(-) diff --git a/cpp/include/cuvs/neighbors/cagra.hpp b/cpp/include/cuvs/neighbors/cagra.hpp index d2a55ce406..9a906687e3 100644 --- a/cpp/include/cuvs/neighbors/cagra.hpp +++ b/cpp/include/cuvs/neighbors/cagra.hpp @@ -276,7 +276,7 @@ enum class search_algo { enum class hash_mode { HASH = 0, SMALL = 1, AUTO = 100 }; -enum class internal_dtype { F16 = 0, E5M2 = 1, AUTO = 100 }; +enum class internal_dtype { F16 = 0, E5M2 = 1 }; struct search_params : cuvs::neighbors::search_params { /** Maximum number of queries to search at the same time (batch size). Auto select when 0.*/ @@ -354,7 +354,7 @@ struct search_params : cuvs::neighbors::search_params { /** Data type of the query vector and codebook table on shared memory. Currently, only VPQ * supports FP8. **/ - internal_dtype smem_dtype = internal_dtype::AUTO; + internal_dtype smem_dtype = internal_dtype::F16; }; /** diff --git a/cpp/src/neighbors/detail/cagra/cagra_search.cuh b/cpp/src/neighbors/detail/cagra/cagra_search.cuh index f199cf7882..6a64ad7a85 100644 --- a/cpp/src/neighbors/detail/cagra/cagra_search.cuh +++ b/cpp/src/neighbors/detail/cagra/cagra_search.cuh @@ -153,10 +153,9 @@ void search_main(raft::resources const& res, // Dispatch search parameters based on the dataset kind. if (auto* strided_dset = dynamic_cast*>(&index.data()); strided_dset != nullptr) { - if (params.smem_dtype != cuvs::neighbors::cagra::internal_dtype::AUTO && - params.smem_dtype != cuvs::neighbors::cagra::internal_dtype::F16) { - RAFT_LOG_WARN("In this search mode, smem_dtype supports only AUTO or F16. Set it to AUTO."); - params.smem_dtype = cuvs::neighbors::cagra::internal_dtype::AUTO; + if (params.smem_dtype != cuvs::neighbors::cagra::internal_dtype::F16) { + RAFT_LOG_WARN("In this search mode, smem_dtype supports only F16. Set it to F16."); + params.smem_dtype = cuvs::neighbors::cagra::internal_dtype::F16; } // Search using a plain (strided) row-major dataset RAFT_EXPECTS(index.metric() != cuvs::distance::DistanceType::CosineExpanded || diff --git a/cpp/src/neighbors/detail/cagra/compute_distance_vpq.hpp b/cpp/src/neighbors/detail/cagra/compute_distance_vpq.hpp index b32a0f17c6..83954491bf 100644 --- a/cpp/src/neighbors/detail/cagra/compute_distance_vpq.hpp +++ b/cpp/src/neighbors/detail/cagra/compute_distance_vpq.hpp @@ -8,7 +8,6 @@ #include "compute_distance.hpp" #include -#include #include @@ -65,20 +64,13 @@ struct vpq_descriptor_spec : public instance_spec { const DatasetT& dataset, cuvs::distance::DistanceType metric) -> double { - const auto fp8_natively_supported = raft::getComputeCapability().first >= 9; - const auto selected_smem_dtype = - params.smem_dtype == cuvs::neighbors::cagra::internal_dtype::AUTO - ? (fp8_natively_supported ? cuvs::neighbors::cagra::internal_dtype::E5M2 - : cuvs::neighbors::cagra::internal_dtype::F16) - : params.smem_dtype; - // If explicit team_size is specified and doesn't match the instance, discard it if (params.team_size != 0 && TeamSize != params.team_size) { return -1.0; } if (cuvs::distance::DistanceType::L2Expanded != metric) { return -1.0; } // Match codebook params if (dataset.pq_bits() != PqBits) { return -1.0; } if (dataset.pq_len() != PqLen) { return -1.0; } - if (selected_smem_dtype != SmemDType) { return -1.0; } + if (params.smem_dtype != SmemDType) { return -1.0; } // Keep auto-selection on the tuned VPQ diagonal while allowing explicit team_size requests to // use the expanded team_size / dataset_block_dim grid. constexpr std::uint32_t auto_dataset_block_dim_per_team = PqLen == 8 ? 32 : 16; diff --git a/cpp/tests/neighbors/ann_cagra.cuh b/cpp/tests/neighbors/ann_cagra.cuh index 5c0da34dd0..5f83d2eb8d 100644 --- a/cpp/tests/neighbors/ann_cagra.cuh +++ b/cpp/tests/neighbors/ann_cagra.cuh @@ -277,7 +277,7 @@ struct AnnCagraInputs { std::optional non_owning_memory_buffer_flag = std::nullopt; cuvs::neighbors::MergeStrategy merge_strategy = cuvs::neighbors::MergeStrategy::MERGE_STRATEGY_PHYSICAL; - cuvs::neighbors::cagra::internal_dtype smem_dtype = cuvs::neighbors::cagra::internal_dtype::AUTO; + cuvs::neighbors::cagra::internal_dtype smem_dtype = cuvs::neighbors::cagra::internal_dtype::F16; }; inline ::std::ostream& operator<<(::std::ostream& os, const AnnCagraInputs& p) @@ -305,7 +305,6 @@ inline ::std::ostream& operator<<(::std::ostream& os, const AnnCagraInputs& p) switch (dtype) { case cuvs::neighbors::cagra::internal_dtype::F16: return "F16"; case cuvs::neighbors::cagra::internal_dtype::E5M2: return "E5M2"; - case cuvs::neighbors::cagra::internal_dtype::AUTO: return "AUTO"; } return "Unknown"; }; @@ -1712,8 +1711,7 @@ inline std::vector generate_inputs() for (uint32_t pq_len : {2, 4, 8}) { for (uint32_t vq_n_centers : {100, 1000}) { for (auto internal_smem_dtype : {cuvs::neighbors::cagra::internal_dtype::E5M2, - cuvs::neighbors::cagra::internal_dtype::F16, - cuvs::neighbors::cagra::internal_dtype::AUTO}) { + cuvs::neighbors::cagra::internal_dtype::F16}) { for (auto input : inputs2) { vpq_params ps{}; ps.pq_dim = input.dim / pq_len; From 902073970e238dbede59697b2992bf066a8b75a7 Mon Sep 17 00:00:00 2001 From: enp1s0 Date: Fri, 5 Jun 2026 15:15:36 +0900 Subject: [PATCH 08/15] Update fp8xN to used SW emulated FP8 when FP8 is not natively supported --- .../neighbors/detail/cagra/packed_type.hpp | 27 +++++++++++++++---- 1 file changed, 22 insertions(+), 5 deletions(-) diff --git a/cpp/src/neighbors/detail/cagra/packed_type.hpp b/cpp/src/neighbors/detail/cagra/packed_type.hpp index f52edc126b..4e67fd50e6 100644 --- a/cpp/src/neighbors/detail/cagra/packed_type.hpp +++ b/cpp/src/neighbors/detail/cagra/packed_type.hpp @@ -3,6 +3,8 @@ * SPDX-License-Identifier: Apache-2.0 */ #pragma once +#include "../../ivf_pq/ivf_pq_fp_8bit.cuh" + #include #include @@ -26,24 +28,39 @@ struct fp8xN {}; template struct fp8xN { - using uint_t = typename uintN_t<8 * NumPacked>::type; - using unit_t = __nv_fp8_e5m2; - using x2_t = __nv_fp8x2_storage_t; + using uint_t = typename uintN_t<8 * NumPacked>::type; +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 + using unit_t = __nv_fp8_e5m2; + using x2_t = __nv_fp8x2_storage_t; +#else + using unit_t = cuvs::neighbors::ivf_pq::detail::fp_8bit<5u, true>; +#endif static constexpr uint32_t num_elements = NumPacked; - union { + union storage_t { unit_t x1[num_elements]; +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 x2_t x2[num_elements / 2]; +#endif uint_t u; + + HDI storage_t() : u{0} {} } data; - HDI fp8xN() { data.u = 0; } + HDI fp8xN() = default; HDI uint_t& as_uint() { return data.u; } HDI uint_t as_uint() const { return data.u; } HDI half2 as_half2(const uint32_t i) const { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 return __nv_cvt_fp8x2_to_halfraw2(data.x2[i], __NV_E5M2); +#else + half2 r; + r.x = static_cast(data.x1[2 * i]); + r.y = static_cast(data.x1[2 * i + 1]); + return r; +#endif } }; } // namespace cuvs::neighbors::cagra::detail::device From 627ee0dc1026b5d267c665e144e12ebb58eec6e7 Mon Sep 17 00:00:00 2001 From: enp1s0 Date: Fri, 5 Jun 2026 15:26:36 +0900 Subject: [PATCH 09/15] Fix VPQ test --- cpp/tests/neighbors/ann_cagra.cuh | 4 ---- 1 file changed, 4 deletions(-) diff --git a/cpp/tests/neighbors/ann_cagra.cuh b/cpp/tests/neighbors/ann_cagra.cuh index 5f83d2eb8d..46b427e241 100644 --- a/cpp/tests/neighbors/ann_cagra.cuh +++ b/cpp/tests/neighbors/ann_cagra.cuh @@ -355,10 +355,6 @@ class AnnCagraTest : public ::testing::TestWithParam { if (ps.metric == cuvs::distance::DistanceType::L1 && ps.build_algo != graph_build_algo::ITERATIVE_CAGRA_SEARCH) GTEST_SKIP(); - if (ps.smem_dtype == cuvs::neighbors::cagra::internal_dtype::E5M2 && - raft::getComputeCapability().first < 9) { - GTEST_SKIP() << "CAGRA VPQ E5M2 smem dtype requires native FP8 support on SM90+"; - } if (ps.metric == cuvs::distance::DistanceType::CosineExpanded) { if (ps.compression.has_value()) { GTEST_SKIP(); } if (ps.build_algo == graph_build_algo::ITERATIVE_CAGRA_SEARCH || ps.dim == 1) { From e7e4205c21ca87d9ebe83e810a8c001e31f6d26f Mon Sep 17 00:00:00 2001 From: enp1s0 Date: Fri, 5 Jun 2026 15:46:35 +0900 Subject: [PATCH 10/15] Fix compilation error --- cpp/src/neighbors/detail/cagra/packed_type.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/src/neighbors/detail/cagra/packed_type.hpp b/cpp/src/neighbors/detail/cagra/packed_type.hpp index 4e67fd50e6..cfcc68be09 100644 --- a/cpp/src/neighbors/detail/cagra/packed_type.hpp +++ b/cpp/src/neighbors/detail/cagra/packed_type.hpp @@ -47,7 +47,7 @@ struct fp8xN { HDI storage_t() : u{0} {} } data; - HDI fp8xN() = default; + HDI fp8xN() : data{} {} HDI uint_t& as_uint() { return data.u; } HDI uint_t as_uint() const { return data.u; } From 1032ffb2f4a8ee0cc938c6c11bec5260d758dbaf Mon Sep 17 00:00:00 2001 From: enp1s0 Date: Mon, 8 Jun 2026 13:34:46 +0900 Subject: [PATCH 11/15] Update VPQ test to use VpqMathT --- cpp/tests/neighbors/ann_cagra.cuh | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/cpp/tests/neighbors/ann_cagra.cuh b/cpp/tests/neighbors/ann_cagra.cuh index 46b427e241..38f77a4346 100644 --- a/cpp/tests/neighbors/ann_cagra.cuh +++ b/cpp/tests/neighbors/ann_cagra.cuh @@ -477,9 +477,11 @@ class AnnCagraTest : public ::testing::TestWithParam { if (ps.compression.has_value()) { auto decoded_dataset = raft::make_device_matrix(handle_, ps.n_rows, ps.dim); - cuvs::neighbors::decode_vpq_dataset( + + using VpqMathT = half; + cuvs::neighbors::decode_vpq_dataset( decoded_dataset.view(), - dynamic_cast&>(index.data()), + dynamic_cast&>(index.data()), raft::resource::get_cuda_stream(handle_)); auto indices_out_view = raft::make_device_matrix_view( indices_dev.data(), ps.n_queries, ps.k); From 02e372639a21db8f75f2798067db139631333157 Mon Sep 17 00:00:00 2001 From: enp1s0 Date: Mon, 8 Jun 2026 15:14:05 +0900 Subject: [PATCH 12/15] Add pq_bits assert --- cpp/tests/neighbors/vpq_utils.cuh | 3 +++ 1 file changed, 3 insertions(+) diff --git a/cpp/tests/neighbors/vpq_utils.cuh b/cpp/tests/neighbors/vpq_utils.cuh index 613b5fe1d5..44b4d188ee 100644 --- a/cpp/tests/neighbors/vpq_utils.cuh +++ b/cpp/tests/neighbors/vpq_utils.cuh @@ -51,6 +51,9 @@ void decode_vpq_dataset(raft::device_matrix_view decoded_datase { const auto dataset_size = decoded_dataset.extent(0); RAFT_EXPECTS(vpq_dataset.data.extent(0) == dataset_size, "Dataset sizes mismatch"); + RAFT_EXPECTS(vpq_dataset.pq_bits() == 8, + "decode_vpq_dataset currently only supports pq_bits == 8 (got %u)", + vpq_dataset.pq_bits()); constexpr uint32_t block_size = 256; constexpr uint32_t warp_size = 32; From c608bd16a0beb25683b428d05cecd6b9dda027d7 Mon Sep 17 00:00:00 2001 From: enp1s0 Date: Wed, 10 Jun 2026 00:18:42 +0900 Subject: [PATCH 13/15] Remove SW emulated FP8 --- .../neighbors/detail/cagra/cagra_search.cuh | 6 +++++ .../detail/cagra/compute_distance_vpq.hpp | 13 ++++++++- .../neighbors/detail/cagra/packed_type.hpp | 27 ++++--------------- cpp/tests/neighbors/ann_cagra.cuh | 1 - 4 files changed, 23 insertions(+), 24 deletions(-) diff --git a/cpp/src/neighbors/detail/cagra/cagra_search.cuh b/cpp/src/neighbors/detail/cagra/cagra_search.cuh index 6a64ad7a85..a5925b16d2 100644 --- a/cpp/src/neighbors/detail/cagra/cagra_search.cuh +++ b/cpp/src/neighbors/detail/cagra/cagra_search.cuh @@ -184,6 +184,12 @@ void search_main(raft::resources const& res, RAFT_FAIL("FP32 VPQ dataset support is coming soon"); } else if (auto* vpq_dset = dynamic_cast*>(&index.data()); vpq_dset != nullptr) { + if (params.smem_dtype == cuvs::neighbors::cagra::internal_dtype::E5M2 && + raft::getComputeCapability().first < 9) { + RAFT_LOG_WARN( + "CAGRA VPQ E5M2 smem_dtype requires native FP8 support on SM90+. Falling back to F16."); + params.smem_dtype = cuvs::neighbors::cagra::internal_dtype::F16; + } auto desc = dataset_descriptor_init_with_cache( res, params, *vpq_dset, index.metric(), nullptr); search_main_core( diff --git a/cpp/src/neighbors/detail/cagra/compute_distance_vpq.hpp b/cpp/src/neighbors/detail/cagra/compute_distance_vpq.hpp index 83954491bf..6781eb6abc 100644 --- a/cpp/src/neighbors/detail/cagra/compute_distance_vpq.hpp +++ b/cpp/src/neighbors/detail/cagra/compute_distance_vpq.hpp @@ -8,11 +8,22 @@ #include "compute_distance.hpp" #include +#include #include namespace cuvs::neighbors::cagra::detail { +inline auto select_supported_vpq_smem_dtype(const cagra::search_params& params) + -> cuvs::neighbors::cagra::internal_dtype +{ + if (params.smem_dtype == cuvs::neighbors::cagra::internal_dtype::E5M2 && + raft::getComputeCapability().first < 9) { + return cuvs::neighbors::cagra::internal_dtype::F16; + } + return params.smem_dtype; +} + template { // Match codebook params if (dataset.pq_bits() != PqBits) { return -1.0; } if (dataset.pq_len() != PqLen) { return -1.0; } - if (params.smem_dtype != SmemDType) { return -1.0; } + if (select_supported_vpq_smem_dtype(params) != SmemDType) { return -1.0; } // Keep auto-selection on the tuned VPQ diagonal while allowing explicit team_size requests to // use the expanded team_size / dataset_block_dim grid. constexpr std::uint32_t auto_dataset_block_dim_per_team = PqLen == 8 ? 32 : 16; diff --git a/cpp/src/neighbors/detail/cagra/packed_type.hpp b/cpp/src/neighbors/detail/cagra/packed_type.hpp index cfcc68be09..f52edc126b 100644 --- a/cpp/src/neighbors/detail/cagra/packed_type.hpp +++ b/cpp/src/neighbors/detail/cagra/packed_type.hpp @@ -3,8 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ #pragma once -#include "../../ivf_pq/ivf_pq_fp_8bit.cuh" - #include #include @@ -28,39 +26,24 @@ struct fp8xN {}; template struct fp8xN { - using uint_t = typename uintN_t<8 * NumPacked>::type; -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 - using unit_t = __nv_fp8_e5m2; - using x2_t = __nv_fp8x2_storage_t; -#else - using unit_t = cuvs::neighbors::ivf_pq::detail::fp_8bit<5u, true>; -#endif + using uint_t = typename uintN_t<8 * NumPacked>::type; + using unit_t = __nv_fp8_e5m2; + using x2_t = __nv_fp8x2_storage_t; static constexpr uint32_t num_elements = NumPacked; - union storage_t { + union { unit_t x1[num_elements]; -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 x2_t x2[num_elements / 2]; -#endif uint_t u; - - HDI storage_t() : u{0} {} } data; - HDI fp8xN() : data{} {} + HDI fp8xN() { data.u = 0; } HDI uint_t& as_uint() { return data.u; } HDI uint_t as_uint() const { return data.u; } HDI half2 as_half2(const uint32_t i) const { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 return __nv_cvt_fp8x2_to_halfraw2(data.x2[i], __NV_E5M2); -#else - half2 r; - r.x = static_cast(data.x1[2 * i]); - r.y = static_cast(data.x1[2 * i + 1]); - return r; -#endif } }; } // namespace cuvs::neighbors::cagra::detail::device diff --git a/cpp/tests/neighbors/ann_cagra.cuh b/cpp/tests/neighbors/ann_cagra.cuh index 38f77a4346..93e933cc94 100644 --- a/cpp/tests/neighbors/ann_cagra.cuh +++ b/cpp/tests/neighbors/ann_cagra.cuh @@ -27,7 +27,6 @@ #include #include #include -#include #include #include From f706baae2cb9e559460e531e84ddff17b0bbbff4 Mon Sep 17 00:00:00 2001 From: enp1s0 Date: Wed, 10 Jun 2026 16:57:58 +0900 Subject: [PATCH 14/15] Update dispatch funcs --- .../jit_lto_kernels/cagra_planner_base.hpp | 204 +++++++----------- 1 file changed, 81 insertions(+), 123 deletions(-) diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_planner_base.hpp b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_planner_base.hpp index cce18a0216..0666ef815a 100644 --- a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_planner_base.hpp +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_planner_base.hpp @@ -175,120 +175,6 @@ struct CagraPlannerBase : AlgorithmPlanner { RAFT_FAIL("Unsupported CAGRA JIT smem_dtype: %u", static_cast(smem_dtype)); } - template - static void dispatch_cagra_standard_team_dim(uint32_t team_size, - uint32_t dataset_block_dim, - Lambda&& l) - { - switch (team_size) { - case 8: - switch (dataset_block_dim) { - case 128: std::forward(l).template operator()<8u, 128u>(); return; - case 256: std::forward(l).template operator()<8u, 256u>(); return; - case 512: std::forward(l).template operator()<8u, 512u>(); return; - default: break; - } - break; - case 16: - switch (dataset_block_dim) { - case 128: std::forward(l).template operator()<16u, 128u>(); return; - case 256: std::forward(l).template operator()<16u, 256u>(); return; - case 512: std::forward(l).template operator()<16u, 512u>(); return; - default: break; - } - break; - case 32: - switch (dataset_block_dim) { - case 128: std::forward(l).template operator()<32u, 128u>(); return; - case 256: std::forward(l).template operator()<32u, 256u>(); return; - case 512: std::forward(l).template operator()<32u, 512u>(); return; - default: break; - } - break; - default: break; - } - RAFT_FAIL("Unsupported standard team_size / dataset_block_dim for CAGRA JIT: team=%u dim=%u", - static_cast(team_size), - static_cast(dataset_block_dim)); - } - - template - static void dispatch_cagra_vpq_pq2_4_team_dim(uint32_t team_size, - uint32_t dataset_block_dim, - Lambda&& l) - { - switch (team_size) { - case 8: - switch (dataset_block_dim) { - case 128: std::forward(l).template operator()<8u, 128u, 8u, PqLenV>(); return; - case 256: std::forward(l).template operator()<8u, 256u, 8u, PqLenV>(); return; - case 512: std::forward(l).template operator()<8u, 512u, 8u, PqLenV>(); return; - default: break; - } - break; - case 16: - switch (dataset_block_dim) { - case 128: std::forward(l).template operator()<16u, 128u, 8u, PqLenV>(); return; - case 256: std::forward(l).template operator()<16u, 256u, 8u, PqLenV>(); return; - case 512: std::forward(l).template operator()<16u, 512u, 8u, PqLenV>(); return; - default: break; - } - break; - case 32: - switch (dataset_block_dim) { - case 128: std::forward(l).template operator()<32u, 128u, 8u, PqLenV>(); return; - case 256: std::forward(l).template operator()<32u, 256u, 8u, PqLenV>(); return; - case 512: std::forward(l).template operator()<32u, 512u, 8u, PqLenV>(); return; - default: break; - } - break; - default: break; - } - RAFT_FAIL( - "Unsupported VPQ pq_len=%u team_size / dataset_block_dim for CAGRA JIT: team=%u dim=%u", - static_cast(PqLenV), - static_cast(team_size), - static_cast(dataset_block_dim)); - } - - template - static void dispatch_cagra_vpq_pq8_team_dim(uint32_t team_size, - uint32_t dataset_block_dim, - Lambda&& l) - { - switch (team_size) { - case 4: - switch (dataset_block_dim) { - case 128: std::forward(l).template operator()<4u, 128u, 8u, 8u>(); return; - default: break; - } - break; - case 8: - switch (dataset_block_dim) { - case 256: std::forward(l).template operator()<8u, 256u, 8u, 8u>(); return; - default: break; - } - break; - case 16: - switch (dataset_block_dim) { - case 512: std::forward(l).template operator()<16u, 512u, 8u, 8u>(); return; - default: break; - } - break; - case 32: - switch (dataset_block_dim) { - case 1024: std::forward(l).template operator()<32u, 1024u, 8u, 8u>(); return; - default: break; - } - break; - default: break; - } - RAFT_FAIL( - "Unsupported VPQ pq_len=8 team_size / dataset_block_dim for CAGRA JIT: team=%u dim=%u", - static_cast(team_size), - static_cast(dataset_block_dim)); - } - template static void dispatch_cagra_vpq_team_dim(uint32_t team_size, uint32_t dataset_block_dim, @@ -384,15 +270,11 @@ struct CagraPlannerBase : AlgorithmPlanner { // template parameters; CAGRA reads team_size / dataset_block_dim from the host descriptor at // planning time. template - static void dispatch_cagra_team_dim(uint32_t team_size, uint32_t dataset_block_dim, Lambda&& l) + static void dispatch_cagra_standard_team_dim(uint32_t team_size, + uint32_t dataset_block_dim, + Lambda&& l) { switch (team_size) { - case 4: - switch (dataset_block_dim) { - case 128: std::forward(l).template operator()<4u, 128u>(); return; - default: break; - } - break; case 8: switch (dataset_block_dim) { case 128: std::forward(l).template operator()<8u, 128u>(); return; @@ -414,17 +296,93 @@ struct CagraPlannerBase : AlgorithmPlanner { case 128: std::forward(l).template operator()<32u, 128u>(); return; case 256: std::forward(l).template operator()<32u, 256u>(); return; case 512: std::forward(l).template operator()<32u, 512u>(); return; - case 1024: std::forward(l).template operator()<32u, 1024u>(); return; default: break; } break; default: break; } - RAFT_FAIL("Unsupported team_size / dataset_block_dim for CAGRA JIT: team=%u dim=%u", + RAFT_FAIL("Unsupported standard team_size / dataset_block_dim for CAGRA JIT: team=%u dim=%u", static_cast(team_size), static_cast(dataset_block_dim)); } + template + static void dispatch_cagra_vpq_pq2_4_team_dim(uint32_t team_size, + uint32_t dataset_block_dim, + Lambda&& l) + { + switch (team_size) { + case 8: + switch (dataset_block_dim) { + case 128: std::forward(l).template operator()<8u, 128u, 8u, PqLenV>(); return; + case 256: std::forward(l).template operator()<8u, 256u, 8u, PqLenV>(); return; + case 512: std::forward(l).template operator()<8u, 512u, 8u, PqLenV>(); return; + default: break; + } + break; + case 16: + switch (dataset_block_dim) { + case 128: std::forward(l).template operator()<16u, 128u, 8u, PqLenV>(); return; + case 256: std::forward(l).template operator()<16u, 256u, 8u, PqLenV>(); return; + case 512: std::forward(l).template operator()<16u, 512u, 8u, PqLenV>(); return; + default: break; + } + break; + case 32: + switch (dataset_block_dim) { + case 128: std::forward(l).template operator()<32u, 128u, 8u, PqLenV>(); return; + case 256: std::forward(l).template operator()<32u, 256u, 8u, PqLenV>(); return; + case 512: std::forward(l).template operator()<32u, 512u, 8u, PqLenV>(); return; + default: break; + } + break; + default: break; + } + RAFT_FAIL( + "Unsupported VPQ pq_len=%u team_size / dataset_block_dim for CAGRA JIT: team=%u dim=%u", + static_cast(PqLenV), + static_cast(team_size), + static_cast(dataset_block_dim)); + } + + template + static void dispatch_cagra_vpq_pq8_team_dim(uint32_t team_size, + uint32_t dataset_block_dim, + Lambda&& l) + { + switch (team_size) { + case 4: + switch (dataset_block_dim) { + case 128: std::forward(l).template operator()<4u, 128u, 8u, 8u>(); return; + default: break; + } + break; + case 8: + switch (dataset_block_dim) { + case 256: std::forward(l).template operator()<8u, 256u, 8u, 8u>(); return; + default: break; + } + break; + case 16: + switch (dataset_block_dim) { + case 512: std::forward(l).template operator()<16u, 512u, 8u, 8u>(); return; + default: break; + } + break; + case 32: + switch (dataset_block_dim) { + case 1024: std::forward(l).template operator()<32u, 1024u, 8u, 8u>(); return; + default: break; + } + break; + default: break; + } + RAFT_FAIL( + "Unsupported VPQ pq_len=8 team_size / dataset_block_dim for CAGRA JIT: team=%u dim=%u", + static_cast(team_size), + static_cast(dataset_block_dim)); + } + void add_sample_filter_device_function() { if constexpr (!std::is_same_v) { From 0eef38fe66032618b645f452373c1c426eabcdbb Mon Sep 17 00:00:00 2001 From: enp1s0 Date: Wed, 10 Jun 2026 16:58:20 +0900 Subject: [PATCH 15/15] Fix ldg_cg use --- .../detail/cagra/jit_lto_kernels/compute_distance_impl.cuh | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_impl.cuh b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_impl.cuh index ea817110e7..a6dd0495fd 100644 --- a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_impl.cuh +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_impl.cuh @@ -127,12 +127,7 @@ _RAFT_DEVICE RAFT_DEVICE_INLINE_FUNCTION auto compute_distance_vpq_worker_impl( for (std::uint32_t e = 0; e < nelem; e++) { const std::uint32_t k = e * kTeamVLen + elem_offset + laneId * vlen; if (k >= n_subspace) break; - if constexpr (std::is_same_v) { - device::ldg_cg(pq_codes[e], - reinterpret_cast(dataset_ptr + 4 + k)); - } else { - pq_codes[e] = *reinterpret_cast(dataset_ptr + 4 + k); - } + device::ldg_cg(pq_codes[e], reinterpret_cast(dataset_ptr + 4 + k)); } // if constexpr (PQ_LEN % 2 == 0) {