diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 89cbadfcfc..e3f4a105c1 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -268,12 +268,12 @@ if(NOT BUILD_CPU_ONLY) INPUT_FILE "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/detail/cagra/compute_distance_vpq_inst.cu.in" OUTPUT_FILE_FORMAT - "${CMAKE_CURRENT_BINARY_DIR}/src/neighbors/detail/cagra/compute_distance_vpq_inst_data_@data_abbrev@_index_@index_abbrev@_distance_@distance_abbrev@_codebook_@codebook_abbrev@_metric_@metric@_team_@team_size@_dim_@dim@_pq_bits_@pq_bits@_pq_len_@pq_len@.cu" + "${CMAKE_CURRENT_BINARY_DIR}/src/neighbors/detail/cagra/compute_distance_vpq_inst_data_@data_abbrev@_index_@index_abbrev@_distance_@distance_abbrev@_codebook_@codebook_abbrev@_metric_@metric@_team_@team_size@_dim_@dim@_pq_bits_@pq_bits@_pq_len_@pq_len@_smem_@smem_abbrev@.cu" ) generate_string_matrix( cagra_compute_distance_vpq_selector_template_params ITEM_FORMAT - "\nvpq_descriptor_spec" + "\nvpq_descriptor_spec" GLUE "," MATRIX_JSON_FILE @@ -282,7 +282,7 @@ if(NOT BUILD_CPU_ONLY) generate_string_matrix( cagra_compute_distance_vpq_template_inst ITEM_FORMAT - "extern template struct vpq_descriptor_spec@semicolon@" + "extern template struct vpq_descriptor_spec@semicolon@" GLUE "\n" MATRIX_JSON_FILE @@ -874,13 +874,13 @@ if(NOT BUILD_CPU_ONLY) generate_jit_lto_kernels( jit_lto_files NAME_FORMAT - "cagra_setup_workspace@pq_prefix@_team_size_@team_size@_dataset_block_dim_@dataset_block_dim@_@pq_bits@pq_@pq_len@subd_data_@data_abbrev@_query_@query_abbrev@" + "cagra_setup_workspace@pq_prefix@_team_size_@team_size@_dataset_block_dim_@dataset_block_dim@_@pq_bits@pq_@pq_len@subd_data_@data_abbrev@_query_@query_abbrev@_smem_@smem_abbrev@" MATRIX_JSON_FILE "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/detail/cagra/jit_lto_kernels/setup_workspace_matrix.json" KERNEL_INPUT_FILE "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/detail/cagra/jit_lto_kernels/setup_workspace_kernel.cu.in" FRAGMENT_TAG_FORMAT - "${cagra_ns}::fragment_tag_setup_workspace<${neighbors_ns}::tag_@data_abbrev@, ${neighbors_ns}::tag_index_@index_abbrev@, ${cagra_ns}::tag_dist_@distance_abbrev@, ${neighbors_ns}::tag_@query_abbrev@, ${cagra_ns}::tag_codebook_@codebook_abbrev@, @team_size@, @dataset_block_dim@, @pq_bits@, @pq_len@>" + "${cagra_ns}::fragment_tag_setup_workspace<${neighbors_ns}::tag_@data_abbrev@, ${neighbors_ns}::tag_index_@index_abbrev@, ${cagra_ns}::tag_dist_@distance_abbrev@, ${neighbors_ns}::tag_@query_abbrev@, ${cagra_ns}::tag_codebook_@codebook_abbrev@, @team_size@, @dataset_block_dim@, @pq_bits@, @pq_len@, ${cagra_ns}::tag_smem_@smem_abbrev@>" FRAGMENT_TAG_HEADER_FILES "" "" OUTPUT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/generated_kernels/cagra/setup_workspace" @@ -890,13 +890,13 @@ if(NOT BUILD_CPU_ONLY) generate_jit_lto_kernels( jit_lto_files NAME_FORMAT - "cagra_compute_distance@pq_prefix@_team_size_@team_size@_dataset_block_dim_@dataset_block_dim@_@pq_bits@pq_@pq_len@subd_data_@data_abbrev@_query_@query_abbrev@" + "cagra_compute_distance@pq_prefix@_team_size_@team_size@_dataset_block_dim_@dataset_block_dim@_@pq_bits@pq_@pq_len@subd_data_@data_abbrev@_query_@query_abbrev@_smem_@smem_abbrev@" MATRIX_JSON_FILE "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_matrix.json" KERNEL_INPUT_FILE "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_kernel.cu.in" FRAGMENT_TAG_FORMAT - "${cagra_ns}::fragment_tag_compute_distance<${neighbors_ns}::tag_@data_abbrev@, ${neighbors_ns}::tag_index_@index_abbrev@, ${cagra_ns}::tag_dist_@distance_abbrev@, ${neighbors_ns}::tag_@query_abbrev@, ${cagra_ns}::tag_codebook_@codebook_abbrev@, @team_size@, @dataset_block_dim@, @pq_bits@, @pq_len@>" + "${cagra_ns}::fragment_tag_compute_distance<${neighbors_ns}::tag_@data_abbrev@, ${neighbors_ns}::tag_index_@index_abbrev@, ${cagra_ns}::tag_dist_@distance_abbrev@, ${neighbors_ns}::tag_@query_abbrev@, ${cagra_ns}::tag_codebook_@codebook_abbrev@, @team_size@, @dataset_block_dim@, @pq_bits@, @pq_len@, ${cagra_ns}::tag_smem_@smem_abbrev@>" FRAGMENT_TAG_HEADER_FILES "" "" OUTPUT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/generated_kernels/cagra/compute_distance" diff --git a/cpp/include/cuvs/detail/jit_lto/cagra/cagra_fragments.hpp b/cpp/include/cuvs/detail/jit_lto/cagra/cagra_fragments.hpp index 0b42d79379..67cdd38783 100644 --- a/cpp/include/cuvs/detail/jit_lto/cagra/cagra_fragments.hpp +++ b/cpp/include/cuvs/detail/jit_lto/cagra/cagra_fragments.hpp @@ -16,6 +16,8 @@ struct tag_metric_cosine {}; struct tag_metric_hamming {}; struct tag_codebook_none {}; struct tag_codebook_half {}; +struct tag_smem_f16 {}; +struct tag_smem_e5m2 {}; struct tag_metric_l1 {}; struct tag_norm_noop {}; struct tag_norm_cosine {}; @@ -33,7 +35,8 @@ template + uint32_t PqLen, + typename SmemTag> struct fragment_tag_setup_workspace {}; template + uint32_t PqLen, + typename SmemTag> struct fragment_tag_compute_distance {}; template diff --git a/cpp/include/cuvs/neighbors/cagra.hpp b/cpp/include/cuvs/neighbors/cagra.hpp index 637e40c340..aa43046fb5 100644 --- a/cpp/include/cuvs/neighbors/cagra.hpp +++ b/cpp/include/cuvs/neighbors/cagra.hpp @@ -277,6 +277,8 @@ enum class search_algo { enum class hash_mode { HASH = 0, SMALL = 1, AUTO = 100 }; +enum class internal_dtype { F16 = 0, E5M2 = 1 }; + struct search_params : cuvs::neighbors::search_params { /** Maximum number of queries to search at the same time (batch size). Auto select when 0.*/ size_t max_queries = 0; @@ -353,6 +355,10 @@ struct search_params : cuvs::neighbors::search_params { * inferred from the source string. */ float filtering_rate = -1.0; + + /** Data type of the query vector and codebook table on shared memory. Currently, only VPQ + * supports FP8. **/ + internal_dtype smem_dtype = internal_dtype::F16; }; /** diff --git a/cpp/src/neighbors/detail/cagra/cagra_search.cuh b/cpp/src/neighbors/detail/cagra/cagra_search.cuh index bca8d3314d..a5925b16d2 100644 --- a/cpp/src/neighbors/detail/cagra/cagra_search.cuh +++ b/cpp/src/neighbors/detail/cagra/cagra_search.cuh @@ -153,6 +153,10 @@ void search_main(raft::resources const& res, // Dispatch search parameters based on the dataset kind. if (auto* strided_dset = dynamic_cast*>(&index.data()); strided_dset != nullptr) { + if (params.smem_dtype != cuvs::neighbors::cagra::internal_dtype::F16) { + RAFT_LOG_WARN("In this search mode, smem_dtype supports only F16. Set it to F16."); + params.smem_dtype = cuvs::neighbors::cagra::internal_dtype::F16; + } // Search using a plain (strided) row-major dataset RAFT_EXPECTS(index.metric() != cuvs::distance::DistanceType::CosineExpanded || index.dataset_norms().has_value(), @@ -180,6 +184,12 @@ void search_main(raft::resources const& res, RAFT_FAIL("FP32 VPQ dataset support is coming soon"); } else if (auto* vpq_dset = dynamic_cast*>(&index.data()); vpq_dset != nullptr) { + if (params.smem_dtype == cuvs::neighbors::cagra::internal_dtype::E5M2 && + raft::getComputeCapability().first < 9) { + RAFT_LOG_WARN( + "CAGRA VPQ E5M2 smem_dtype requires native FP8 support on SM90+. Falling back to F16."); + params.smem_dtype = cuvs::neighbors::cagra::internal_dtype::F16; + } auto desc = dataset_descriptor_init_with_cache( res, params, *vpq_dset, index.metric(), nullptr); search_main_core( diff --git a/cpp/src/neighbors/detail/cagra/compute_distance.hpp b/cpp/src/neighbors/detail/cagra/compute_distance.hpp index 75b56860bb..45997a62a3 100644 --- a/cpp/src/neighbors/detail/cagra/compute_distance.hpp +++ b/cpp/src/neighbors/detail/cagra/compute_distance.hpp @@ -202,11 +202,12 @@ struct dataset_descriptor_host { uint32_t team_size = 0; // JIT LTO metadata - stored when descriptor is created - cuvs::distance::DistanceType metric = cuvs::distance::DistanceType::L2Expanded; - uint32_t dataset_block_dim = 0; - bool is_vpq = false; - uint32_t pq_bits = 0; - uint32_t pq_len = 0; + cuvs::distance::DistanceType metric = cuvs::distance::DistanceType::L2Expanded; + uint32_t dataset_block_dim = 0; + bool is_vpq = false; + uint32_t pq_bits = 0; + uint32_t pq_len = 0; + cuvs::neighbors::cagra::internal_dtype smem_dtype = cuvs::neighbors::cagra::internal_dtype::F16; // Codebook type is determined by DataT for VPQ (always half for now) struct state { @@ -258,7 +259,9 @@ struct dataset_descriptor_host { uint32_t dataset_block_dim_val, bool is_vpq_val = false, uint32_t pq_bits_val = 0, - uint32_t pq_len_val = 0) + uint32_t pq_len_val = 0, + cuvs::neighbors::cagra::internal_dtype smem_dtype_val = + cuvs::neighbors::cagra::internal_dtype::F16) : value_{std::make_shared(init, sizeof(DescriptorImpl))}, smem_ws_size_in_bytes{dd_host.smem_ws_size_in_bytes()}, team_size{dd_host.team_size()}, @@ -266,7 +269,8 @@ struct dataset_descriptor_host { dataset_block_dim{dataset_block_dim_val}, is_vpq{is_vpq_val}, pq_bits{pq_bits_val}, - pq_len{pq_len_val} + pq_len{pq_len_val}, + smem_dtype{smem_dtype_val} { } diff --git a/cpp/src/neighbors/detail/cagra/compute_distance_vpq-impl.cuh b/cpp/src/neighbors/detail/cagra/compute_distance_vpq-impl.cuh index 6992ae979a..f73f901f95 100644 --- a/cpp/src/neighbors/detail/cagra/compute_distance_vpq-impl.cuh +++ b/cpp/src/neighbors/detail/cagra/compute_distance_vpq-impl.cuh @@ -6,6 +6,7 @@ #pragma once #include "compute_distance_vpq.hpp" +#include "packed_type.hpp" #include #include @@ -14,6 +15,30 @@ namespace cuvs::neighbors::cagra::detail { +template +struct vpq_smem_value_config; + +template +struct vpq_smem_value_config< + PQ_LEN, + SmemDType, + std::enable_if_t> { + using smem_val_pack_t = half2; + using smem_val_t = half; + using smem_val_pack_uint_t = uint32_t; + static constexpr uint32_t num_packed_elements = 2; +}; + +template +struct vpq_smem_value_config> { + using smem_val_pack_t = device::fp8xN; + using smem_val_t = typename smem_val_pack_t::unit_t; + using smem_val_pack_uint_t = typename smem_val_pack_t::uint_t; + static constexpr uint32_t num_packed_elements = smem_val_pack_t::num_elements; +}; + template + typename QueryT, + cuvs::neighbors::cagra::internal_dtype SmemDType> struct cagra_q_dataset_descriptor_t : public dataset_descriptor_base_t { using base_type = dataset_descriptor_base_t; using CODE_BOOK_T = CodebookT; @@ -38,6 +64,7 @@ struct cagra_q_dataset_descriptor_t : public dataset_descriptor_base_t, "Only CODE_BOOK_T = `half` is supported now"); @@ -80,8 +107,11 @@ struct cagra_q_dataset_descriptor_t : public dataset_descriptor_base_t; + static constexpr std::uint32_t kSMemCodeBookSizeInBytes = - (1 << PQ_BITS) * PQ_LEN * utils::size_of(); + (1 << PQ_BITS) * PQ_LEN * utils::size_of() / + smem_val_config::num_packed_elements; _RAFT_HOST_DEVICE cagra_q_dataset_descriptor_t(const std::uint8_t* encoded_dataset_ptr, std::uint32_t encoded_dataset_dim, @@ -108,7 +138,9 @@ struct cagra_q_dataset_descriptor_t : public dataset_descriptor_base_t(dim, DatasetBlockDim) * sizeof(QUERY_T); + raft::round_up_safe(dim, DatasetBlockDim) * + utils::size_of() / + smem_val_config::num_packed_elements; } private: @@ -122,7 +154,8 @@ template + typename DistanceT, + cuvs::neighbors::cagra::internal_dtype SmemDType> RAFT_KERNEL __launch_bounds__(1, 1) vpq_dataset_descriptor_init_kernel(dataset_descriptor_base_t* out, const std::uint8_t* encoded_dataset_ptr, @@ -140,7 +173,8 @@ RAFT_KERNEL __launch_bounds__(1, 1) DataT, IndexT, DistanceT, - half>; + half, + SmemDType>; new (out) desc_type( encoded_dataset_ptr, encoded_dataset_dim, vq_code_book_ptr, pq_code_book_ptr, size, dim); } @@ -153,7 +187,8 @@ template + typename DistanceT, + cuvs::neighbors::cagra::internal_dtype SmemDType> dataset_descriptor_host vpq_descriptor_spec::init_(const cagra::search_params& params, + DistanceT, + SmemDType>::init_(const cagra::search_params& params, const std::uint8_t* encoded_dataset_ptr, uint32_t encoded_dataset_dim, const CodebookT* vq_code_book_ptr, @@ -179,7 +215,8 @@ vpq_descriptor_spec; + half, + SmemDType>; return host_type{ desc_type{ @@ -194,7 +231,8 @@ vpq_descriptor_spec<<<1, 1, 0, stream>>>(dev_ptr, + DistanceT, + SmemDType><<<1, 1, 0, stream>>>(dev_ptr, encoded_dataset_ptr, encoded_dataset_dim, vq_code_book_ptr, @@ -205,9 +243,10 @@ vpq_descriptor_spec +#include #include namespace cuvs::neighbors::cagra::detail { +inline auto select_supported_vpq_smem_dtype(const cagra::search_params& params) + -> cuvs::neighbors::cagra::internal_dtype +{ + if (params.smem_dtype == cuvs::neighbors::cagra::internal_dtype::E5M2 && + raft::getComputeCapability().first < 9) { + return cuvs::neighbors::cagra::internal_dtype::F16; + } + return params.smem_dtype; +} + template + typename DistanceT, + cuvs::neighbors::cagra::internal_dtype SmemDType> struct vpq_descriptor_spec : public instance_spec { using base_type = instance_spec; using typename base_type::data_type; @@ -69,6 +81,13 @@ struct vpq_descriptor_spec : public instance_spec { // Match codebook params if (dataset.pq_bits() != PqBits) { return -1.0; } if (dataset.pq_len() != PqLen) { return -1.0; } + if (select_supported_vpq_smem_dtype(params) != SmemDType) { return -1.0; } + // Keep auto-selection on the tuned VPQ diagonal while allowing explicit team_size requests to + // use the expanded team_size / dataset_block_dim grid. + constexpr std::uint32_t auto_dataset_block_dim_per_team = PqLen == 8 ? 32 : 16; + if (params.team_size == 0 && DatasetBlockDim != TeamSize * auto_dataset_block_dim_per_team) { + return -1.0; + } // Otherwise, favor the closest dataset dimensionality. constexpr std::uint32_t preferred_load_elmes_per_thread = 16; /*magic number that is good based on experiments.*/ diff --git a/cpp/src/neighbors/detail/cagra/compute_distance_vpq_inst.cu.in b/cpp/src/neighbors/detail/cagra/compute_distance_vpq_inst.cu.in index c159da3229..25d4732a34 100644 --- a/cpp/src/neighbors/detail/cagra/compute_distance_vpq_inst.cu.in +++ b/cpp/src/neighbors/detail/cagra/compute_distance_vpq_inst.cu.in @@ -13,6 +13,7 @@ constexpr uint32_t team_size = @team_size@; constexpr uint32_t dim = @dim@; constexpr uint32_t pq_bits = @pq_bits@; constexpr uint32_t pq_len = @pq_len@; +constexpr auto smem_dtype = @smem_dtype@; using codebook_t = @codebook_type@; using data_t = @data_type@; using index_t = @index_type@; @@ -30,6 +31,7 @@ template struct vpq_descriptor_spec; + distance_t, + smem_dtype>; } // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/cagra/compute_distance_vpq_matrix.json b/cpp/src/neighbors/detail/cagra/compute_distance_vpq_matrix.json index cf6e060d33..1241b2346c 100644 --- a/cpp/src/neighbors/detail/cagra/compute_distance_vpq_matrix.json +++ b/cpp/src/neighbors/detail/cagra/compute_distance_vpq_matrix.json @@ -36,15 +36,53 @@ "_mxdim_team": [ { "dim": "128", - "team_size": "8" + "team_size": "8", + "pq_len": "2" }, { "dim": "256", - "team_size": "16" + "team_size": "16", + "pq_len": "2" }, { "dim": "512", - "team_size": "32" + "team_size": "32", + "pq_len": "2" + }, + { + "dim": "128", + "team_size": "8", + "pq_len": "4" + }, + { + "dim": "256", + "team_size": "16", + "pq_len": "4" + }, + { + "dim": "512", + "team_size": "32", + "pq_len": "4" + }, + { + "dim": "128", + "team_size": "4", + "pq_len": "8" + }, + { + "dim": "256", + "team_size": "8", + "pq_len": "8" + }, + { + "dim": "512", + "team_size": "16", + "pq_len": "8" + }, + { + "dim": "1024", + "team_size": "32", + "pq_len": "8" } ], "_codebook": [ @@ -53,7 +91,20 @@ "codebook_abbrev": "h" } ], - "pq_bits": ["8"], - "pq_len": ["2", "4"], - "metric": ["L2Expanded"] + "pq_bits": [ + "8" + ], + "metric": [ + "L2Expanded" + ], + "_smem": [ + { + "smem_dtype": "cuvs::neighbors::cagra::internal_dtype::E5M2", + "smem_abbrev": "e5m2" + }, + { + "smem_dtype": "cuvs::neighbors::cagra::internal_dtype::F16", + "smem_abbrev": "f16" + } + ] } diff --git a/cpp/src/neighbors/detail/cagra/device_memory_ops.hpp b/cpp/src/neighbors/detail/cagra/device_memory_ops.hpp index cc164994ea..1bcf6f8fbd 100644 --- a/cpp/src/neighbors/detail/cagra/device_memory_ops.hpp +++ b/cpp/src/neighbors/detail/cagra/device_memory_ops.hpp @@ -54,6 +54,11 @@ RAFT_DEVICE_INLINE_FUNCTION void lds(uint32_t& x, uint32_t addr) asm volatile("ld.shared.u32 {%0}, [%1];" : "=r"(x) : "r"(addr)); } +RAFT_DEVICE_INLINE_FUNCTION void lds(uint64_t& x, uint32_t addr) +{ + asm volatile("ld.shared.u64 {%0}, [%1];" : "=l"(x) : "r"(addr)); +} + RAFT_DEVICE_INLINE_FUNCTION void lds(uint32_t& x, const uint32_t* addr) { lds(x, uint32_t(__cvta_generic_to_shared(addr))); @@ -71,6 +76,16 @@ RAFT_DEVICE_INLINE_FUNCTION void lds(uint4& x, const uint4* addr) lds(x, uint32_t(__cvta_generic_to_shared(addr))); } +RAFT_DEVICE_INLINE_FUNCTION void sts(uint32_t addr, const uint32_t& x) +{ + asm volatile("st.shared.u32 [%0], %1;" : : "r"(addr), "r"(reinterpret_cast(x))); +} + +RAFT_DEVICE_INLINE_FUNCTION void sts(uint32_t addr, const uint64_t& x) +{ + asm volatile("st.shared.u64 [%0], %1;" : : "r"(addr), "l"(reinterpret_cast(x))); +} + RAFT_DEVICE_INLINE_FUNCTION void sts(uint32_t addr, const half2& x) { asm volatile("st.shared.v2.u16 [%0], {%1, %2};" diff --git a/cpp/src/neighbors/detail/cagra/factory.cuh b/cpp/src/neighbors/detail/cagra/factory.cuh index 26cd13bab8..a1e2f6be9c 100644 --- a/cpp/src/neighbors/detail/cagra/factory.cuh +++ b/cpp/src/neighbors/detail/cagra/factory.cuh @@ -87,6 +87,7 @@ struct key { uint32_t extra_val; // this one has different meanings for different descriptor types uint32_t team_size; uint32_t metric; + uint32_t smem_dtype; }; template @@ -100,7 +101,8 @@ auto make_key(const cagra::search_params& params, dataset.dim(), dataset.stride(), uint32_t(params.team_size), - uint32_t(metric)}; + uint32_t(metric), + uint32_t(params.smem_dtype)}; } template @@ -114,20 +116,22 @@ auto make_key(const cagra::search_params& params, dataset.dim(), uint32_t(reinterpret_cast(dataset.pq_code_book.data_handle()) >> 6), uint32_t(params.team_size), - uint32_t(metric)}; + uint32_t(metric), + uint32_t(params.smem_dtype)}; } inline auto operator==(const key& a, const key& b) -> bool { return a.data_ptr == b.data_ptr && a.n_rows == b.n_rows && a.dim == b.dim && - a.extra_val == b.extra_val && a.team_size == b.team_size && a.metric == b.metric; + a.extra_val == b.extra_val && a.team_size == b.team_size && a.metric == b.metric && + a.smem_dtype == b.smem_dtype; } struct key_hash { inline auto operator()(const key& x) const noexcept -> std::size_t { return size_t{x.data_ptr} + size_t{x.n_rows} * size_t{x.dim} * size_t{x.extra_val} + - (size_t{x.team_size} ^ size_t{x.metric}); + (size_t{x.team_size} ^ size_t{x.metric}) + size_t{x.smem_dtype}; } }; diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_jit_launcher_factory.hpp b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_jit_launcher_factory.hpp index 896817d869..9b7e44bd8e 100644 --- a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_jit_launcher_factory.hpp +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_jit_launcher_factory.hpp @@ -59,10 +59,14 @@ std::shared_ptr build_single_cta_launcher( persistent); if constexpr (std::is_same_v) { - planner.add_setup_workspace_device_function( - dataset_desc.team_size, dataset_desc.dataset_block_dim, dataset_desc.pq_len); - planner.add_compute_distance_device_function( - dataset_desc.team_size, dataset_desc.dataset_block_dim, dataset_desc.pq_len); + planner.add_setup_workspace_device_function(dataset_desc.team_size, + dataset_desc.dataset_block_dim, + dataset_desc.pq_len, + dataset_desc.smem_dtype); + planner.add_compute_distance_device_function(dataset_desc.team_size, + dataset_desc.dataset_block_dim, + dataset_desc.pq_len, + dataset_desc.smem_dtype); } else { planner.add_setup_workspace_device_function(dataset_desc.team_size, dataset_desc.dataset_block_dim); @@ -105,10 +109,14 @@ std::shared_ptr build_multi_cta_launcher( dataset_desc.pq_len); if constexpr (std::is_same_v) { - planner.add_setup_workspace_device_function( - dataset_desc.team_size, dataset_desc.dataset_block_dim, dataset_desc.pq_len); - planner.add_compute_distance_device_function( - dataset_desc.team_size, dataset_desc.dataset_block_dim, dataset_desc.pq_len); + planner.add_setup_workspace_device_function(dataset_desc.team_size, + dataset_desc.dataset_block_dim, + dataset_desc.pq_len, + dataset_desc.smem_dtype); + planner.add_compute_distance_device_function(dataset_desc.team_size, + dataset_desc.dataset_block_dim, + dataset_desc.pq_len, + dataset_desc.smem_dtype); } else { planner.add_setup_workspace_device_function(dataset_desc.team_size, dataset_desc.dataset_block_dim); @@ -151,10 +159,14 @@ std::shared_ptr build_multi_kernel_launcher( dataset_desc.pq_bits, dataset_desc.pq_len); if constexpr (std::is_same_v) { - planner.add_setup_workspace_device_function( - dataset_desc.team_size, dataset_desc.dataset_block_dim, dataset_desc.pq_len); - planner.add_compute_distance_device_function( - dataset_desc.team_size, dataset_desc.dataset_block_dim, dataset_desc.pq_len); + planner.add_setup_workspace_device_function(dataset_desc.team_size, + dataset_desc.dataset_block_dim, + dataset_desc.pq_len, + dataset_desc.smem_dtype); + planner.add_compute_distance_device_function(dataset_desc.team_size, + dataset_desc.dataset_block_dim, + dataset_desc.pq_len, + dataset_desc.smem_dtype); } else { planner.add_setup_workspace_device_function(dataset_desc.team_size, dataset_desc.dataset_block_dim); diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_planner_base.hpp b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_planner_base.hpp index 0c3ed64d13..8539efd004 100644 --- a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_planner_base.hpp +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_planner_base.hpp @@ -52,11 +52,13 @@ struct CagraPlannerBase : AlgorithmPlanner { TeamSz, Dim, PqBitsV, - PqLenV>>(); + PqLenV, + tag_smem_f16>>(); }; - dispatch_cagra_team_dim(team_size, dataset_block_dim, [&add]() { - add.template operator()(); - }); + dispatch_cagra_standard_team_dim( + team_size, dataset_block_dim, [&add]() { + add.template operator()(); + }); } /// VPQ (`tag_codebook_half`): JIT matrix fixes `pq_bits=8`; only `pq_len` is selected at runtime. @@ -64,30 +66,35 @@ struct CagraPlannerBase : AlgorithmPlanner { std::enable_if_t, int> = 0> void add_setup_workspace_device_function(uint32_t team_size, uint32_t dataset_block_dim, - uint32_t pq_len) + uint32_t pq_len, + cuvs::neighbors::cagra::internal_dtype smem_dtype) { - if (pq_len != 2 && pq_len != 4) { - RAFT_FAIL("CAGRA JIT VPQ setup_workspace expects pq_len in {2,4} (matrix uses pq_bits=8)"); + if (pq_len != 2 && pq_len != 4 && pq_len != 8) { + RAFT_FAIL("CAGRA JIT VPQ setup_workspace expects pq_len in {2,4,8} (matrix uses pq_bits=8)"); } - auto add = [&]() { - this->add_static_fragment>(); + auto add = + [&]() { + this->add_static_fragment>(); + }; + auto dispatch_smem = [&]() { + dispatch_cagra_vpq_team_dim( + team_size, + dataset_block_dim, + pq_len, + [&add]() { + add.template operator()(); + }); }; - dispatch_cagra_team_dim( - team_size, dataset_block_dim, [&add, pq_len]() { - if (pq_len == 2) { - add.template operator()(); - } else { - add.template operator()(); - } - }); + dispatch_cagra_smem_dtype(smem_dtype, dispatch_smem); } /// Registers dist_op + normalization + `compute_distance` for standard layout. @@ -108,11 +115,13 @@ struct CagraPlannerBase : AlgorithmPlanner { TeamSz, Dim, PqBitsV, - PqLenV>>(); + PqLenV, + tag_smem_f16>>(); }; - dispatch_cagra_team_dim(team_size, dataset_block_dim, [&add]() { - add.template operator()(); - }); + dispatch_cagra_standard_team_dim( + team_size, dataset_block_dim, [&add]() { + add.template operator()(); + }); } /// VPQ: only the `compute_distance` fragment (no standard dist_op / normalization in this path). @@ -120,33 +129,77 @@ struct CagraPlannerBase : AlgorithmPlanner { std::enable_if_t, int> = 0> void add_compute_distance_device_function(uint32_t team_size, uint32_t dataset_block_dim, - uint32_t pq_len) + uint32_t pq_len, + cuvs::neighbors::cagra::internal_dtype smem_dtype) { - if (pq_len != 2 && pq_len != 4) { - RAFT_FAIL("CAGRA JIT VPQ compute_distance expects pq_len in {2,4} (matrix uses pq_bits=8)"); + if (pq_len != 2 && pq_len != 4 && pq_len != 8) { + RAFT_FAIL("CAGRA JIT VPQ compute_distance expects pq_len in {2,4,8} (matrix uses pq_bits=8)"); } - auto add = [&]() { - this->add_static_fragment>(); + auto add = + [&]() { + this->add_static_fragment>(); + }; + auto dispatch_smem = [&]() { + dispatch_cagra_vpq_team_dim( + team_size, + dataset_block_dim, + pq_len, + [&add]() { + add.template operator()(); + }); }; - dispatch_cagra_team_dim( - team_size, dataset_block_dim, [&add, pq_len]() { - if (pq_len == 2) { - add.template operator()(); - } else { - add.template operator()(); - } - }); + dispatch_cagra_smem_dtype(smem_dtype, dispatch_smem); } private: + template + static void dispatch_cagra_smem_dtype(cuvs::neighbors::cagra::internal_dtype smem_dtype, + Lambda&& l) + { + switch (smem_dtype) { + case cuvs::neighbors::cagra::internal_dtype::F16: + std::forward(l).template operator()(); + return; + case cuvs::neighbors::cagra::internal_dtype::E5M2: + std::forward(l).template operator()(); + return; + default: break; + } + RAFT_FAIL("Unsupported CAGRA JIT smem_dtype: %u", static_cast(smem_dtype)); + } + + template + static void dispatch_cagra_vpq_team_dim(uint32_t team_size, + uint32_t dataset_block_dim, + uint32_t pq_len, + Lambda&& l) + { + switch (pq_len) { + case 2: + dispatch_cagra_vpq_pq2_4_team_dim<2u>( + team_size, dataset_block_dim, std::forward(l)); + return; + case 4: + dispatch_cagra_vpq_pq2_4_team_dim<4u>( + team_size, dataset_block_dim, std::forward(l)); + return; + case 8: + dispatch_cagra_vpq_pq8_team_dim(team_size, dataset_block_dim, std::forward(l)); + return; + default: break; + } + RAFT_FAIL("CAGRA JIT VPQ expects pq_len in {2,4,8}; got %u", static_cast(pq_len)); + } + void add_dist_op_device_function(cuvs::distance::DistanceType metric) { // dist_op_matrix.json pairs tag_metric_hamming with uint8 query (tag_u8) only; L2/IP/L1 use @@ -191,15 +244,16 @@ struct CagraPlannerBase : AlgorithmPlanner { uint32_t dataset_block_dim) { auto go = [&]() { - dispatch_cagra_team_dim(team_size, dataset_block_dim, [&]() { - this->add_static_fragment>(); - }); + dispatch_cagra_standard_team_dim( + team_size, dataset_block_dim, [&]() { + this->add_static_fragment>(); + }); }; // tag_u8 is only used for BitwiseHamming query layout; cosine norm fragments are built for // float query tag. Use if constexpr so we do not instantiate tag_norm_cosine with tag_u8 @@ -218,7 +272,9 @@ struct CagraPlannerBase : AlgorithmPlanner { // template parameters; CAGRA reads team_size / dataset_block_dim from the host descriptor at // planning time. template - static void dispatch_cagra_team_dim(uint32_t team_size, uint32_t dataset_block_dim, Lambda&& l) + static void dispatch_cagra_standard_team_dim(uint32_t team_size, + uint32_t dataset_block_dim, + Lambda&& l) { switch (team_size) { case 8: @@ -247,11 +303,88 @@ struct CagraPlannerBase : AlgorithmPlanner { break; default: break; } - RAFT_FAIL("Unsupported team_size / dataset_block_dim for CAGRA JIT: team=%u dim=%u", + RAFT_FAIL("Unsupported standard team_size / dataset_block_dim for CAGRA JIT: team=%u dim=%u", static_cast(team_size), static_cast(dataset_block_dim)); } + template + static void dispatch_cagra_vpq_pq2_4_team_dim(uint32_t team_size, + uint32_t dataset_block_dim, + Lambda&& l) + { + switch (team_size) { + case 8: + switch (dataset_block_dim) { + case 128: std::forward(l).template operator()<8u, 128u, 8u, PqLenV>(); return; + case 256: std::forward(l).template operator()<8u, 256u, 8u, PqLenV>(); return; + case 512: std::forward(l).template operator()<8u, 512u, 8u, PqLenV>(); return; + default: break; + } + break; + case 16: + switch (dataset_block_dim) { + case 128: std::forward(l).template operator()<16u, 128u, 8u, PqLenV>(); return; + case 256: std::forward(l).template operator()<16u, 256u, 8u, PqLenV>(); return; + case 512: std::forward(l).template operator()<16u, 512u, 8u, PqLenV>(); return; + default: break; + } + break; + case 32: + switch (dataset_block_dim) { + case 128: std::forward(l).template operator()<32u, 128u, 8u, PqLenV>(); return; + case 256: std::forward(l).template operator()<32u, 256u, 8u, PqLenV>(); return; + case 512: std::forward(l).template operator()<32u, 512u, 8u, PqLenV>(); return; + default: break; + } + break; + default: break; + } + RAFT_FAIL( + "Unsupported VPQ pq_len=%u team_size / dataset_block_dim for CAGRA JIT: team=%u dim=%u", + static_cast(PqLenV), + static_cast(team_size), + static_cast(dataset_block_dim)); + } + + template + static void dispatch_cagra_vpq_pq8_team_dim(uint32_t team_size, + uint32_t dataset_block_dim, + Lambda&& l) + { + switch (team_size) { + case 4: + switch (dataset_block_dim) { + case 128: std::forward(l).template operator()<4u, 128u, 8u, 8u>(); return; + default: break; + } + break; + case 8: + switch (dataset_block_dim) { + case 256: std::forward(l).template operator()<8u, 256u, 8u, 8u>(); return; + default: break; + } + break; + case 16: + switch (dataset_block_dim) { + case 512: std::forward(l).template operator()<16u, 512u, 8u, 8u>(); return; + default: break; + } + break; + case 32: + switch (dataset_block_dim) { + case 1024: std::forward(l).template operator()<32u, 1024u, 8u, 8u>(); return; + default: break; + } + break; + default: break; + } + RAFT_FAIL( + "Unsupported VPQ pq_len=8 team_size / dataset_block_dim for CAGRA JIT: team=%u dim=%u", + static_cast(team_size), + static_cast(dataset_block_dim)); + } + void add_sample_filter_device_function(std::unique_ptr udf_fragment = nullptr) { if constexpr (std::is_same_v) { diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_impl.cuh b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_impl.cuh index 92a014bd2f..a6dd0495fd 100644 --- a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_impl.cuh +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_impl.cuh @@ -100,10 +100,17 @@ _RAFT_DEVICE RAFT_DEVICE_INLINE_FUNCTION auto compute_distance_vpq_worker_impl( constexpr auto DatasetBlockDim = DescriptorT::kDatasetBlockDim; constexpr auto PQ_BITS = DescriptorT::kPqBits; constexpr auto PQ_LEN = DescriptorT::kPqLen; + constexpr auto SmemDType = DescriptorT::kSmemDType; + using PQ_CODEBOOK_LOAD_T = uint32_t; + + using smem_val_config = vpq_smem_value_config; + using smem_val_pack_t = typename smem_val_config::smem_val_pack_t; + using smem_val_pack_uint_t = typename smem_val_config::smem_val_pack_uint_t; + constexpr uint32_t num_packed_elements = smem_val_config::num_packed_elements; const uint32_t query_ptr = pq_codebook_ptr + DescriptorT::kSMemCodeBookSizeInBytes; static_assert(PQ_BITS == 8, "Only pq_bits == 8 is supported at the moment."); - constexpr uint32_t vlen = 4; // **** DO NOT CHANGE **** + constexpr uint32_t vlen = utils::size_of() / utils::size_of(); constexpr uint32_t nelem = raft::div_rounding_up_unsafe(DatasetBlockDim / PQ_LEN, TeamSize * vlen); @@ -115,12 +122,12 @@ _RAFT_DEVICE RAFT_DEVICE_INLINE_FUNCTION auto compute_distance_vpq_worker_impl( DISTANCE_T norm = 0; for (uint32_t elem_offset = 0; elem_offset * PQ_LEN < dim; elem_offset += DatasetBlockDim / PQ_LEN) { - uint32_t pq_codes[nelem]; + PQ_CODEBOOK_LOAD_T pq_codes[nelem]; #pragma unroll for (std::uint32_t e = 0; e < nelem; e++) { const std::uint32_t k = e * kTeamVLen + elem_offset + laneId * vlen; if (k >= n_subspace) break; - device::ldg_cg(pq_codes[e], reinterpret_cast(dataset_ptr + 4 + k)); + device::ldg_cg(pq_codes[e], reinterpret_cast(dataset_ptr + 4 + k)); } // if constexpr (PQ_LEN % 2 == 0) { @@ -135,24 +142,63 @@ _RAFT_DEVICE RAFT_DEVICE_INLINE_FUNCTION auto compute_distance_vpq_worker_impl( if (d >= dim) break; device::ldg_ca(vq_vals[m], vq_code_book_ptr + d); } - std::uint32_t pq_code = pq_codes[e]; + PQ_CODEBOOK_LOAD_T pq_code = pq_codes[e]; #pragma unroll for (std::uint32_t v = 0; v < vlen; v++) { if (PQ_LEN * (v + k) >= dim) break; #pragma unroll - for (std::uint32_t m = 0; m < PQ_LEN / 2; m++) { - constexpr auto kQueryBlock = DatasetBlockDim / (vlen * PQ_LEN); - const std::uint32_t d1 = m + (PQ_LEN / 2) * v; - const std::uint32_t d = - d1 * kQueryBlock + elem_offset * (PQ_LEN / 2) + e * TeamSize + laneId; - half2 q2, c2; - device::lds(q2, query_ptr + sizeof(half2) * d); - device::lds(c2, - pq_codebook_ptr + - sizeof(CODE_BOOK_T) * ((1 << PQ_BITS) * 2 * m + (2 * (pq_code & 0xff)))); - auto dist = q2 - c2 - reinterpret_cast(vq_vals)[d1]; - dist = dist * dist; - norm += static_cast(dist.x + dist.y); + for (std::uint32_t m = 0; m < PQ_LEN / num_packed_elements; m++) { + constexpr uint32_t vq_val_pack_num_elements = 2; + constexpr auto kQueryBlock = DatasetBlockDim / (vlen * PQ_LEN); + std::uint32_t vq_half2_index = + m * (num_packed_elements / vq_val_pack_num_elements) + (PQ_LEN / 2) * v; + + uint32_t query_val_index; + if constexpr (num_packed_elements == 2) { + query_val_index = + vq_half2_index * kQueryBlock + elem_offset * (PQ_LEN / 2) + e * TeamSize + laneId; + } else if constexpr (PQ_LEN == num_packed_elements) { + query_val_index = elem_offset + v * (DatasetBlockDim / (num_packed_elements * vlen)) + + e * TeamSize + laneId; + } else { + const uint32_t query_vec_element_id = + (elem_offset + e * vlen * TeamSize + v + laneId * vlen) * PQ_LEN / + num_packed_elements; + constexpr auto kStride = vlen * PQ_LEN / num_packed_elements; + query_val_index = + transpose(query_vec_element_id); + } + + if constexpr (num_packed_elements == 2) { + smem_val_pack_t q2, c2; + device::lds(q2, query_ptr + sizeof(smem_val_pack_t) * query_val_index); + device::lds(c2, + pq_codebook_ptr + + sizeof(smem_val_pack_uint_t) * ((1 << PQ_BITS) * m + (pq_code & 0xff))); + auto dist = + q2 - c2 - reinterpret_cast(vq_vals)[vq_half2_index]; + dist = dist * dist; + norm += static_cast(dist.x + dist.y); + } else if constexpr (num_packed_elements == 4 || num_packed_elements == 8) { + smem_val_pack_t q_vec, c_vec; + device::lds(q_vec.as_uint(), + query_ptr + sizeof(smem_val_pack_uint_t) * query_val_index); + device::lds(c_vec.as_uint(), + pq_codebook_ptr + + sizeof(smem_val_pack_uint_t) * ((1 << PQ_BITS) * m + (pq_code & 0xff))); + + half2 q2, c2; +#pragma unroll + for (uint32_t bi = 0; bi < num_packed_elements / 2; bi++) { + q2 = q_vec.as_half2(bi); + c2 = c_vec.as_half2(bi); + auto dist = + q2 - c2 - reinterpret_cast(vq_vals)[vq_half2_index]; + dist = dist * dist; + norm += static_cast(dist.x + dist.y); + vq_half2_index += 1; + } + } } pq_code >>= 8; } @@ -219,7 +265,8 @@ template + typename QueryT, + cuvs::neighbors::cagra::internal_dtype SmemDType> __device__ DistanceT compute_distance_impl( const typename dataset_descriptor_base_t::args_t args, IndexT dataset_index) @@ -238,7 +285,8 @@ __device__ DistanceT compute_distance_impl( DataT, IndexT, DistanceT, - QueryT>; + QueryT, + SmemDType>; return compute_distance_vpq_impl(args, dataset_index); } else { static_assert(sizeof(TeamSize) == 0, diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_kernel.cu.in b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_kernel.cu.in index 13cd022918..1856781391 100644 --- a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_kernel.cu.in +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_kernel.cu.in @@ -11,6 +11,7 @@ constexpr uint32_t k_team_size = @team_size@u; constexpr uint32_t k_dataset_block_dim = @dataset_block_dim@u; constexpr uint32_t k_pq_bits = @pq_bits@u; constexpr uint32_t k_pq_len = @pq_len@u; +constexpr auto k_smem_dtype = @smem_dtype@; using data_t = @data_type@; using index_t = @index_type@; @@ -38,7 +39,8 @@ __device__ distance_t compute_distance(const args_t data_t, index_t, distance_t, - query_t>(args, dataset_index) + query_t, + k_smem_dtype>(args, dataset_index) : distance_t{}; return device::team_sum(per_thread, team_size_bits); } @@ -55,7 +57,8 @@ compute_distance_per_thread(const args_t args, inde data_t, index_t, distance_t, - query_t>(args, dataset_index); + query_t, + k_smem_dtype>(args, dataset_index); } } // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_matrix.json b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_matrix.json index 82b8dbdf4e..4d260c5507 100644 --- a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_matrix.json +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_matrix.json @@ -81,6 +81,12 @@ "codebook_type": "void", "codebook_abbrev": "none" } + ], + "_smem": [ + { + "smem_dtype": "cuvs::neighbors::cagra::internal_dtype::F16", + "smem_abbrev": "f16" + } ] }, { @@ -149,6 +155,96 @@ "codebook_type": "half", "codebook_abbrev": "half" } + ], + "_smem": [ + { + "smem_dtype": "cuvs::neighbors::cagra::internal_dtype::F16", + "smem_abbrev": "f16" + }, + { + "smem_dtype": "cuvs::neighbors::cagra::internal_dtype::E5M2", + "smem_abbrev": "e5m2" + } + ] + }, + { + "_data": [ + { + "data_type": "float", + "data_abbrev": "f" + }, + { + "data_type": "__half", + "data_abbrev": "h" + }, + { + "data_type": "uint8_t", + "data_abbrev": "u8" + }, + { + "data_type": "int8_t", + "data_abbrev": "i8" + } + ], + "_query": [ + { + "query_type": "half", + "query_abbrev": "h" + } + ], + "_index": [ + { + "index_type": "uint32_t", + "index_abbrev": "u32" + } + ], + "_distance": [ + { + "distance_type": "float", + "distance_abbrev": "f" + } + ], + "_pq": [ + { + "pq_len": "8", + "pq_bits": "8", + "pq_prefix": "_vpq", + "pq_suffix": "_8pq_8subd" + } + ], + "_codebook": [ + { + "codebook_type": "half", + "codebook_abbrev": "half" + } + ], + "_mxdim_team": [ + { + "dataset_block_dim": "128", + "team_size": "4" + }, + { + "dataset_block_dim": "256", + "team_size": "8" + }, + { + "dataset_block_dim": "512", + "team_size": "16" + }, + { + "dataset_block_dim": "1024", + "team_size": "32" + } + ], + "_smem": [ + { + "smem_dtype": "cuvs::neighbors::cagra::internal_dtype::F16", + "smem_abbrev": "f16" + }, + { + "smem_dtype": "cuvs::neighbors::cagra::internal_dtype::E5M2", + "smem_abbrev": "e5m2" + } ] } ] diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/setup_workspace_impl.cuh b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/setup_workspace_impl.cuh index 8cdd7febd5..494e0973fe 100644 --- a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/setup_workspace_impl.cuh +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/setup_workspace_impl.cuh @@ -79,12 +79,17 @@ _RAFT_DEVICE __noinline__ auto setup_workspace_vpq_impl( const typename DescriptorT::DATA_T* queries_ptr, uint32_t query_id) -> const DescriptorT* { - using QUERY_T = typename DescriptorT::QUERY_T; - using CODE_BOOK_T = typename DescriptorT::CODE_BOOK_T; - using word_type = uint32_t; - constexpr auto kDatasetBlockDim = DescriptorT::kDatasetBlockDim; - constexpr auto PQ_BITS = DescriptorT::kPqBits; - constexpr auto PQ_LEN = DescriptorT::kPqLen; + using QUERY_T = typename DescriptorT::QUERY_T; + using word_type = uint32_t; + constexpr auto kDatasetBlockDim = DescriptorT::kDatasetBlockDim; + constexpr auto PQ_BITS = DescriptorT::kPqBits; + constexpr auto PQ_LEN = DescriptorT::kPqLen; + constexpr auto SmemDType = DescriptorT::kSmemDType; + using smem_val_config = vpq_smem_value_config; + using smem_val_t = typename smem_val_config::smem_val_t; + using smem_val_pack_t = typename smem_val_config::smem_val_pack_t; + using smem_val_pack_uint_t = typename smem_val_config::smem_val_pack_uint_t; + constexpr auto num_packed_elements = smem_val_config::num_packed_elements; auto* r = reinterpret_cast(smem_ptr); @@ -105,18 +110,32 @@ _RAFT_DEVICE __noinline__ auto setup_workspace_vpq_impl( } __syncthreads(); - for (unsigned i = threadIdx.x * 2; i < (1 << PQ_BITS) * PQ_LEN; i += blockDim.x * 2) { - half2 buf2; - buf2.x = r->pq_code_book_ptr()[i]; - buf2.y = r->pq_code_book_ptr()[i + 1]; - - constexpr auto num_elements_per_bank = 4 / utils::size_of(); - constexpr auto num_banks_per_subspace = PQ_LEN / num_elements_per_bank; - const auto j = i / num_elements_per_bank; - const auto smem_index = - (j / num_banks_per_subspace) + (j % num_banks_per_subspace) * (1 << PQ_BITS); - - device::sts(codebook_buf + smem_index * sizeof(half2), buf2); + for (unsigned i = threadIdx.x * num_packed_elements; i < (1 << PQ_BITS) * PQ_LEN; + i += blockDim.x * num_packed_elements) { + constexpr auto num_elements_per_bank = + num_packed_elements / (utils::size_of() / utils::size_of()); + + if constexpr (PQ_LEN >= num_elements_per_bank) { + constexpr auto num_banks_per_subspace = PQ_LEN / num_elements_per_bank; + const auto j = i / num_elements_per_bank; + const auto smem_index = + (j / num_banks_per_subspace) + (j % num_banks_per_subspace) * (1 << PQ_BITS); + + if constexpr (num_packed_elements == 2) { + smem_val_pack_t buf; + buf.x = r->pq_code_book_ptr()[i]; + buf.y = r->pq_code_book_ptr()[i + 1]; + device::sts(codebook_buf + smem_index * sizeof(smem_val_pack_t), buf); + } else if constexpr (num_packed_elements == 4 || num_packed_elements == 8) { + smem_val_pack_t buf; +#pragma unroll + for (uint32_t k = 0; k < num_packed_elements; k++) { + buf.data.x1[k] = + static_cast(static_cast(r->pq_code_book_ptr()[i + k])); + } + device::sts(codebook_buf + smem_index * sizeof(smem_val_pack_uint_t), buf.as_uint()); + } + } } } @@ -125,19 +144,33 @@ _RAFT_DEVICE __noinline__ auto setup_workspace_vpq_impl( constexpr cuvs::spatial::knn::detail::utils::mapping mapping{}; auto smem_query_ptr = - reinterpret_cast(reinterpret_cast(smem_ptr) + sizeof(DescriptorT) + - DescriptorT::kSMemCodeBookSizeInBytes); - for (unsigned i = threadIdx.x * 2; i < dim; i += blockDim.x * 2) { - half2 buf2{0, 0}; - if (i < dim) { buf2.x = mapping(queries_ptr[i]); } - if (i + 1 < dim) { buf2.y = mapping(queries_ptr[i + 1]); } - if constexpr ((PQ_BITS == 8) && (PQ_LEN % 2 == 0)) { + reinterpret_cast(reinterpret_cast(smem_ptr) + sizeof(DescriptorT) + + DescriptorT::kSMemCodeBookSizeInBytes); + for (unsigned i = threadIdx.x * num_packed_elements; i < dim; + i += blockDim.x * num_packed_elements) { + smem_val_pack_t buf; + if constexpr (num_packed_elements == 2) { + buf.x = 0; + buf.y = 0; + if (i < dim) { buf.x = mapping(queries_ptr[i]); } + if (i + 1 < dim) { buf.y = mapping(queries_ptr[i + 1]); } + } else if constexpr (num_packed_elements == 4 || num_packed_elements == 8) { +#pragma unroll + for (uint32_t k = 0; k < num_packed_elements; k++) { + buf.data.x1[k] = static_cast(0.0f); + if (i + k < dim) { + buf.data.x1[k] = static_cast(static_cast(mapping(queries_ptr[i + k]))); + } + } + } + if constexpr ((PQ_BITS == 8) && (PQ_LEN % num_packed_elements == 0)) { constexpr uint32_t vlen = 4; // **** DO NOT CHANGE **** - constexpr auto kStride = vlen * PQ_LEN / 2; - reinterpret_cast(smem_query_ptr)[transpose(i / 2)] = - buf2; + constexpr auto kStride = vlen * PQ_LEN / num_packed_elements; + reinterpret_cast( + smem_query_ptr)[transpose( + i / num_packed_elements)] = buf; } else { - (reinterpret_cast(smem_query_ptr + i))[0] = buf2; + (reinterpret_cast(smem_query_ptr + i))[0] = buf; } } @@ -152,7 +185,8 @@ template + typename QueryT, + cuvs::neighbors::cagra::internal_dtype SmemDType> __device__ const dataset_descriptor_base_t* setup_workspace_impl( const dataset_descriptor_base_t* desc_ptr, void* smem, @@ -176,7 +210,8 @@ __device__ const dataset_descriptor_base_t* setup_work DataT, IndexT, DistanceT, - QueryT>; + QueryT, + SmemDType>; const desc_t* desc = static_cast(desc_ptr); const desc_t* result = setup_workspace_vpq_impl(desc, smem, queries, query_id); diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/setup_workspace_kernel.cu.in b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/setup_workspace_kernel.cu.in index fa17705250..6a54c9f956 100644 --- a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/setup_workspace_kernel.cu.in +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/setup_workspace_kernel.cu.in @@ -12,6 +12,7 @@ constexpr uint32_t k_team_size = @team_size@u; constexpr uint32_t k_dataset_block_dim = @dataset_block_dim@u; constexpr uint32_t k_pq_bits = @pq_bits@u; constexpr uint32_t k_pq_len = @pq_len@u; +constexpr auto k_smem_dtype = @smem_dtype@; using data_t = @data_type@; using index_t = @index_type@; @@ -39,7 +40,8 @@ setup_workspace( data_t, index_t, distance_t, - query_t>(desc, smem, queries, query_id); + query_t, + k_smem_dtype>(desc, smem, queries, query_id); } } // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/setup_workspace_matrix.json b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/setup_workspace_matrix.json index 83aa8764bc..567fe3e5a1 100644 --- a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/setup_workspace_matrix.json +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/setup_workspace_matrix.json @@ -81,6 +81,12 @@ "codebook_type": "void", "codebook_abbrev": "none" } + ], + "_smem": [ + { + "smem_dtype": "cuvs::neighbors::cagra::internal_dtype::F16", + "smem_abbrev": "f16" + } ] }, { @@ -149,6 +155,96 @@ "codebook_type": "half", "codebook_abbrev": "half" } + ], + "_smem": [ + { + "smem_dtype": "cuvs::neighbors::cagra::internal_dtype::F16", + "smem_abbrev": "f16" + }, + { + "smem_dtype": "cuvs::neighbors::cagra::internal_dtype::E5M2", + "smem_abbrev": "e5m2" + } + ] + }, + { + "_data": [ + { + "data_type": "float", + "data_abbrev": "f" + }, + { + "data_type": "__half", + "data_abbrev": "h" + }, + { + "data_type": "uint8_t", + "data_abbrev": "u8" + }, + { + "data_type": "int8_t", + "data_abbrev": "i8" + } + ], + "_query": [ + { + "query_type": "half", + "query_abbrev": "h" + } + ], + "_index": [ + { + "index_type": "uint32_t", + "index_abbrev": "u32" + } + ], + "_distance": [ + { + "distance_type": "float", + "distance_abbrev": "f" + } + ], + "_pq": [ + { + "pq_len": "8", + "pq_bits": "8", + "pq_prefix": "_vpq", + "pq_suffix": "_8pq_8subd" + } + ], + "_codebook": [ + { + "codebook_type": "half", + "codebook_abbrev": "half" + } + ], + "_mxdim_team": [ + { + "dataset_block_dim": "128", + "team_size": "4" + }, + { + "dataset_block_dim": "256", + "team_size": "8" + }, + { + "dataset_block_dim": "512", + "team_size": "16" + }, + { + "dataset_block_dim": "1024", + "team_size": "32" + } + ], + "_smem": [ + { + "smem_dtype": "cuvs::neighbors::cagra::internal_dtype::F16", + "smem_abbrev": "f16" + }, + { + "smem_dtype": "cuvs::neighbors::cagra::internal_dtype::E5M2", + "smem_abbrev": "e5m2" + } ] } ] diff --git a/cpp/src/neighbors/detail/cagra/packed_type.hpp b/cpp/src/neighbors/detail/cagra/packed_type.hpp new file mode 100644 index 0000000000..f52edc126b --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/packed_type.hpp @@ -0,0 +1,49 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ +#pragma once +#include +#include + +#include +#include + +namespace cuvs::neighbors::cagra::detail::device { +template +struct uintN_t {}; +template <> +struct uintN_t<32> { + using type = uint32_t; +}; +template <> +struct uintN_t<64> { + using type = uint64_t; +}; + +template +struct fp8xN {}; + +template +struct fp8xN { + using uint_t = typename uintN_t<8 * NumPacked>::type; + using unit_t = __nv_fp8_e5m2; + using x2_t = __nv_fp8x2_storage_t; + static constexpr uint32_t num_elements = NumPacked; + + union { + unit_t x1[num_elements]; + x2_t x2[num_elements / 2]; + uint_t u; + } data; + + HDI fp8xN() { data.u = 0; } + + HDI uint_t& as_uint() { return data.u; } + HDI uint_t as_uint() const { return data.u; } + HDI half2 as_half2(const uint32_t i) const + { + return __nv_cvt_fp8x2_to_halfraw2(data.x2[i], __NV_E5M2); + } +}; +} // namespace cuvs::neighbors::cagra::detail::device diff --git a/cpp/tests/neighbors/ann_cagra.cuh b/cpp/tests/neighbors/ann_cagra.cuh index a6704f892a..93e933cc94 100644 --- a/cpp/tests/neighbors/ann_cagra.cuh +++ b/cpp/tests/neighbors/ann_cagra.cuh @@ -6,6 +6,7 @@ #include "../test_utils.cuh" #include "ann_utils.cuh" +#include "vpq_utils.cuh" #include #include "naive_knn.cuh" @@ -275,6 +276,7 @@ struct AnnCagraInputs { std::optional non_owning_memory_buffer_flag = std::nullopt; cuvs::neighbors::MergeStrategy merge_strategy = cuvs::neighbors::MergeStrategy::MERGE_STRATEGY_PHYSICAL; + cuvs::neighbors::cagra::internal_dtype smem_dtype = cuvs::neighbors::cagra::internal_dtype::F16; }; inline ::std::ostream& operator<<(::std::ostream& os, const AnnCagraInputs& p) @@ -298,6 +300,13 @@ inline ::std::ostream& operator<<(::std::ostream& os, const AnnCagraInputs& p) {search_algo::AUTO, "auto"} // }; std::vector build_algo = {"IVF_PQ", "NN_DESCENT", "ITERATIVE_CAGRA_SEARCH", "AUTO"}; + const auto smem_dtype_str = [](cuvs::neighbors::cagra::internal_dtype dtype) { + switch (dtype) { + case cuvs::neighbors::cagra::internal_dtype::F16: return "F16"; + case cuvs::neighbors::cagra::internal_dtype::E5M2: return "E5M2"; + } + return "Unknown"; + }; std::vector merge_strategy = {"PHYSICAL", "LOGICAL"}; os << "{n_queries=" << p.n_queries << ", dataset shape=" << p.n_rows << "x" << p.dim << ", k=" << p.k << ", " << algo_name[p.algo] << ", max_queries=" << p.max_queries @@ -311,7 +320,7 @@ inline ::std::ostream& operator<<(::std::ostream& os, const AnnCagraInputs& p) if (p.compression.has_value()) { auto vpq = p.compression.value(); os << ", pq_bits=" << vpq.pq_bits << ", pq_dim=" << vpq.pq_dim - << ", vq_n_centers=" << vpq.vq_n_centers; + << ", vq_n_centers=" << vpq.vq_n_centers << ", smem_dtype=" << smem_dtype_str(p.smem_dtype); } os << '}' << std::endl; return os; @@ -414,6 +423,7 @@ class AnnCagraTest : public ::testing::TestWithParam { search_params.algo = ps.algo; search_params.max_queries = ps.max_queries; search_params.team_size = ps.team_size; + search_params.smem_dtype = ps.smem_dtype; auto database_view = raft::make_device_matrix_view( (const DataT*)database.data(), ps.n_rows, ps.dim); @@ -461,6 +471,48 @@ class AnnCagraTest : public ::testing::TestWithParam { raft::update_host(indices_Cagra.data(), indices_dev.data(), queries_size, stream_); raft::resource::sync_stream(handle_); + + reference_recall = 1; + if (ps.compression.has_value()) { + auto decoded_dataset = + raft::make_device_matrix(handle_, ps.n_rows, ps.dim); + + using VpqMathT = half; + cuvs::neighbors::decode_vpq_dataset( + decoded_dataset.view(), + dynamic_cast&>(index.data()), + raft::resource::get_cuda_stream(handle_)); + auto indices_out_view = raft::make_device_matrix_view( + indices_dev.data(), ps.n_queries, ps.k); + auto dists_out_view = raft::make_device_matrix_view( + distances_dev.data(), ps.n_queries, ps.k); + + cuvs::neighbors::naive_knn(handle_, + dists_out_view.data_handle(), + indices_out_view.data_handle(), + search_queries.data(), + decoded_dataset.data_handle(), + ps.n_queries, + ps.n_rows, + ps.dim, + ps.k, + ps.metric); + std::vector indices_vpq_dataset(queries_size); + std::vector distances_vpq_dataset(queries_size); + raft::update_host( + distances_vpq_dataset.data(), dists_out_view.data_handle(), queries_size, stream_); + raft::update_host( + indices_vpq_dataset.data(), indices_out_view.data_handle(), queries_size, stream_); + + reference_recall = std::get<1>(calc_recall(indices_naive, + indices_vpq_dataset, + distances_naive, + distances_vpq_dataset, + ps.n_queries, + ps.k, + 0)); + printf("reference_recall = %e\n", reference_recall); + } } // for (int i = 0; i < min(ps.n_queries, 10); i++) { @@ -470,7 +522,7 @@ class AnnCagraTest : public ::testing::TestWithParam { // print_vector("T", distances_naive.data() + i * ps.k, ps.k, std::cout); // print_vector("C", distances_Cagra.data() + i * ps.k, ps.k, std::cout); // } - double min_recall = ps.min_recall; + double min_recall = ps.min_recall * reference_recall; EXPECT_TRUE(eval_neighbours(indices_naive, indices_Cagra, distances_naive, @@ -519,6 +571,7 @@ class AnnCagraTest : public ::testing::TestWithParam { AnnCagraInputs ps; rmm::device_uvector database; rmm::device_uvector search_queries; + double reference_recall; }; template @@ -1652,14 +1705,18 @@ inline std::vector generate_inputs() {cuvs::neighbors::MergeStrategy::MERGE_STRATEGY_PHYSICAL, cuvs::neighbors::MergeStrategy::MERGE_STRATEGY_LOGICAL}); // don't demand high recall // without refinement - for (uint32_t pq_len : {2}) { // for now, only pq_len = 2 is supported, more options coming soon + for (uint32_t pq_len : {2, 4, 8}) { for (uint32_t vq_n_centers : {100, 1000}) { - for (auto input : inputs2) { - vpq_params ps{}; - ps.pq_dim = input.dim / pq_len; - ps.vq_n_centers = vq_n_centers; - input.compression.emplace(ps); - inputs.push_back(input); + for (auto internal_smem_dtype : {cuvs::neighbors::cagra::internal_dtype::E5M2, + cuvs::neighbors::cagra::internal_dtype::F16}) { + for (auto input : inputs2) { + vpq_params ps{}; + ps.pq_dim = input.dim / pq_len; + ps.vq_n_centers = vq_n_centers; + input.compression.emplace(ps); + input.smem_dtype = internal_smem_dtype; + inputs.push_back(input); + } } } } diff --git a/cpp/tests/neighbors/ann_utils.cuh b/cpp/tests/neighbors/ann_utils.cuh index cbc95d7bb7..64732a8a3a 100644 --- a/cpp/tests/neighbors/ann_utils.cuh +++ b/cpp/tests/neighbors/ann_utils.cuh @@ -227,8 +227,9 @@ auto calc_recall(const std::vector& expected_idx, size_t cols, double eps) { - size_t match_count = 0; - size_t total_count = static_cast(rows) * static_cast(cols); + size_t match_count = 0; + size_t index_match_count = 0; + size_t total_count = static_cast(rows) * static_cast(cols); for (size_t i = 0; i < rows; ++i) { for (size_t k = 0; k < cols; ++k) { size_t idx_k = i * cols + k; // row major assumption! @@ -247,8 +248,28 @@ auto calc_recall(const std::vector& expected_idx, } } } - return std::make_tuple( - static_cast(match_count) / static_cast(total_count), match_count, total_count); + + // Index based recall + for (size_t i = 0; i < rows; ++i) { + for (size_t k = 0; k < cols; ++k) { + size_t idx_k = i * cols + k; // row major assumption! + auto act_idx = actual_idx[idx_k]; + for (size_t j = 0; j < cols; ++j) { + size_t idx = i * cols + j; // row major assumption! + auto exp_idx = expected_idx[idx]; + + if (act_idx == exp_idx) { + index_match_count++; + break; + } + } + } + } + + return std::make_tuple(static_cast(match_count) / static_cast(total_count), + static_cast(index_match_count) / static_cast(total_count), + match_count, + total_count); } /** same as eval_recall, but in case indices do not match, @@ -265,7 +286,7 @@ auto eval_neighbours(const std::vector& expected_idx, bool test_unique = true, size_t max_duplicates = 0) -> testing::AssertionResult { - auto [actual_recall, match_count, total_count] = + auto [actual_recall, index_based_actual_recall, match_count, total_count] = calc_recall(expected_idx, actual_idx, expected_dist, actual_dist, rows, cols, eps); double error_margin = (actual_recall - min_recall) / std::max(1.0 - min_recall, eps); diff --git a/cpp/tests/neighbors/vpq_utils.cuh b/cpp/tests/neighbors/vpq_utils.cuh new file mode 100644 index 0000000000..44b4d188ee --- /dev/null +++ b/cpp/tests/neighbors/vpq_utils.cuh @@ -0,0 +1,76 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ +#pragma once + +#include +#include +#include + +#include +#include + +namespace cuvs::neighbors { +template +__global__ void decode_vpq_dataset_kernel(data_t* const decoded_dataset_ptr, + const uint32_t ldd, + const math_t* const vq_codebook_ptr, + const uint32_t ldv, + const math_t* const pq_codebook_ptr, + const uint32_t pq_subspace_dim, + const uint32_t pq_table_size, + const uint32_t dataset_dim, + const size_t dataset_size, + const uint8_t* const data_ptr, + const uint32_t ldi) +{ + constexpr uint32_t warp_size = 32; + const size_t batch_id = (blockIdx.x * blockDim.x + threadIdx.x) / warp_size; + if (batch_id >= dataset_size) { return; } + + const auto local_data_ptr = data_ptr + ldi * batch_id; + const auto vq_code = *reinterpret_cast(local_data_ptr); + const auto pq_code_ptr = local_data_ptr + sizeof(uint32_t); + const auto vq_vec_ptr = vq_codebook_ptr + vq_code * ldv; + auto local_dst_ptr = decoded_dataset_ptr + batch_id * ldd; + + const auto lane_id = threadIdx.x % warp_size; + for (uint32_t i = lane_id; i < dataset_dim; i += warp_size) { + const auto pq_code = pq_code_ptr[i / pq_subspace_dim]; + const auto pq_v = pq_codebook_ptr[pq_code * pq_subspace_dim + (i % pq_subspace_dim)]; + + local_dst_ptr[i] = static_cast(vq_vec_ptr[i]) + static_cast(pq_v); + } +} + +template +void decode_vpq_dataset(raft::device_matrix_view decoded_dataset, + const cuvs::neighbors::vpq_dataset& vpq_dataset, + cudaStream_t cuda_stream) +{ + const auto dataset_size = decoded_dataset.extent(0); + RAFT_EXPECTS(vpq_dataset.data.extent(0) == dataset_size, "Dataset sizes mismatch"); + RAFT_EXPECTS(vpq_dataset.pq_bits() == 8, + "decode_vpq_dataset currently only supports pq_bits == 8 (got %u)", + vpq_dataset.pq_bits()); + + constexpr uint32_t block_size = 256; + constexpr uint32_t warp_size = 32; + constexpr int64_t vecs_per_cta = block_size / warp_size; + const auto grid_size = raft::div_rounding_up_safe(decoded_dataset.extent(0), vecs_per_cta); + + decode_vpq_dataset_kernel + <<>>(decoded_dataset.data_handle(), + decoded_dataset.stride(0), + vpq_dataset.vq_code_book.data_handle(), + vpq_dataset.vq_code_book.stride(0), + vpq_dataset.pq_code_book.data_handle(), + vpq_dataset.pq_len(), + 1u << vpq_dataset.pq_bits(), + vpq_dataset.dim(), + dataset_size, + vpq_dataset.data.data_handle(), + vpq_dataset.data.stride(0)); +} +} // namespace cuvs::neighbors