Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
2539833
Add pq_len=8
enp1s0 May 29, 2026
19d5a0a
Update cagra-q test
enp1s0 Jun 4, 2026
09deae5
Update the compute distance kernel
enp1s0 Jun 4, 2026
c1a2ce6
Merge branch 'main' into cagra-q-pq_len-8-alpha
enp1s0 Jun 4, 2026
5fa5321
Add FP8 support
enp1s0 Jun 4, 2026
c323fa1
Update EnableFP8
enp1s0 Jun 4, 2026
a577563
Update vpq test
enp1s0 Jun 4, 2026
d05f552
Remove internal_dtype::AUTO
enp1s0 Jun 5, 2026
9020739
Update fp8xN to used SW emulated FP8 when FP8 is not natively supported
enp1s0 Jun 5, 2026
07d29d2
Merge branch 'main' into cagra-q-pq_len-8
enp1s0 Jun 5, 2026
627ee0d
Fix VPQ test
enp1s0 Jun 5, 2026
e7e4205
Fix compilation error
enp1s0 Jun 5, 2026
b788dbb
Merge branch 'main' into cagra-q-pq_len-8
enp1s0 Jun 7, 2026
1032ffb
Update VPQ test to use VpqMathT
enp1s0 Jun 8, 2026
02e3726
Add pq_bits assert
enp1s0 Jun 8, 2026
d8c8844
Merge branch 'main' into cagra-q-pq_len-8
enp1s0 Jun 9, 2026
c608bd1
Remove SW emulated FP8
enp1s0 Jun 9, 2026
f706baa
Update dispatch funcs
enp1s0 Jun 10, 2026
0eef38f
Fix ldg_cg use
enp1s0 Jun 10, 2026
f777809
Merge branch 'main' into cagra-q-pq_len-8
enp1s0 Jun 10, 2026
0a74ac6
Merge branch 'cagra-q-pq_len-8' of github.com:enp1s0/cuvs into cagra-…
enp1s0 Jun 10, 2026
dd2500a
Merge branch 'main' into cagra-q-pq_len-8
enp1s0 Jun 11, 2026
ba1a5cf
Merge branch 'main' into cagra-q-pq_len-8
enp1s0 Jun 12, 2026
025659a
Merge branch 'main' into cagra-q-pq_len-8
enp1s0 Jun 16, 2026
9fca67c
Merge branch 'main' into cagra-q-pq_len-8
enp1s0 Jun 23, 2026
7a21676
Merge branch 'main' into cagra-q-pq_len-8
enp1s0 Jun 26, 2026
59951d1
Merge branch 'main' into cagra-q-pq_len-8
enp1s0 Jun 29, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -268,12 +268,12 @@ if(NOT BUILD_CPU_ONLY)
INPUT_FILE
"${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/detail/cagra/compute_distance_vpq_inst.cu.in"
OUTPUT_FILE_FORMAT
"${CMAKE_CURRENT_BINARY_DIR}/src/neighbors/detail/cagra/compute_distance_vpq_inst_data_@data_abbrev@_index_@index_abbrev@_distance_@distance_abbrev@_codebook_@codebook_abbrev@_metric_@metric@_team_@team_size@_dim_@dim@_pq_bits_@pq_bits@_pq_len_@pq_len@.cu"
"${CMAKE_CURRENT_BINARY_DIR}/src/neighbors/detail/cagra/compute_distance_vpq_inst_data_@data_abbrev@_index_@index_abbrev@_distance_@distance_abbrev@_codebook_@codebook_abbrev@_metric_@metric@_team_@team_size@_dim_@dim@_pq_bits_@pq_bits@_pq_len_@pq_len@_smem_@smem_abbrev@.cu"
)
generate_string_matrix(
cagra_compute_distance_vpq_selector_template_params
ITEM_FORMAT
"\nvpq_descriptor_spec<DistanceType::@metric@, @team_size@, @dim@, @pq_bits@, @pq_len@, @codebook_type@, @data_type@, @index_type@, @distance_type@>"
"\nvpq_descriptor_spec<DistanceType::@metric@, @team_size@, @dim@, @pq_bits@, @pq_len@, @codebook_type@, @data_type@, @index_type@, @distance_type@, @smem_dtype@>"
GLUE
","
MATRIX_JSON_FILE
Expand All @@ -282,7 +282,7 @@ if(NOT BUILD_CPU_ONLY)
generate_string_matrix(
cagra_compute_distance_vpq_template_inst
ITEM_FORMAT
"extern template struct vpq_descriptor_spec<DistanceType::@metric@, @team_size@, @dim@, @pq_bits@, @pq_len@, @codebook_type@, @data_type@, @index_type@, @distance_type@>@semicolon@"
"extern template struct vpq_descriptor_spec<DistanceType::@metric@, @team_size@, @dim@, @pq_bits@, @pq_len@, @codebook_type@, @data_type@, @index_type@, @distance_type@, @smem_dtype@>@semicolon@"
GLUE
"\n"
MATRIX_JSON_FILE
Expand Down Expand Up @@ -874,13 +874,13 @@ if(NOT BUILD_CPU_ONLY)
generate_jit_lto_kernels(
jit_lto_files
NAME_FORMAT
"cagra_setup_workspace@pq_prefix@_team_size_@team_size@_dataset_block_dim_@dataset_block_dim@_@pq_bits@pq_@pq_len@subd_data_@data_abbrev@_query_@query_abbrev@"
"cagra_setup_workspace@pq_prefix@_team_size_@team_size@_dataset_block_dim_@dataset_block_dim@_@pq_bits@pq_@pq_len@subd_data_@data_abbrev@_query_@query_abbrev@_smem_@smem_abbrev@"
MATRIX_JSON_FILE
"${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/detail/cagra/jit_lto_kernels/setup_workspace_matrix.json"
KERNEL_INPUT_FILE
"${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/detail/cagra/jit_lto_kernels/setup_workspace_kernel.cu.in"
FRAGMENT_TAG_FORMAT
"${cagra_ns}::fragment_tag_setup_workspace<${neighbors_ns}::tag_@data_abbrev@, ${neighbors_ns}::tag_index_@index_abbrev@, ${cagra_ns}::tag_dist_@distance_abbrev@, ${neighbors_ns}::tag_@query_abbrev@, ${cagra_ns}::tag_codebook_@codebook_abbrev@, @team_size@, @dataset_block_dim@, @pq_bits@, @pq_len@>"
"${cagra_ns}::fragment_tag_setup_workspace<${neighbors_ns}::tag_@data_abbrev@, ${neighbors_ns}::tag_index_@index_abbrev@, ${cagra_ns}::tag_dist_@distance_abbrev@, ${neighbors_ns}::tag_@query_abbrev@, ${cagra_ns}::tag_codebook_@codebook_abbrev@, @team_size@, @dataset_block_dim@, @pq_bits@, @pq_len@, ${cagra_ns}::tag_smem_@smem_abbrev@>"
FRAGMENT_TAG_HEADER_FILES "<cuvs/detail/jit_lto/cagra/cagra_fragments.hpp>"
"<cuvs/detail/jit_lto/common_fragments.hpp>"
OUTPUT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/generated_kernels/cagra/setup_workspace"
Expand All @@ -890,13 +890,13 @@ if(NOT BUILD_CPU_ONLY)
generate_jit_lto_kernels(
jit_lto_files
NAME_FORMAT
"cagra_compute_distance@pq_prefix@_team_size_@team_size@_dataset_block_dim_@dataset_block_dim@_@pq_bits@pq_@pq_len@subd_data_@data_abbrev@_query_@query_abbrev@"
"cagra_compute_distance@pq_prefix@_team_size_@team_size@_dataset_block_dim_@dataset_block_dim@_@pq_bits@pq_@pq_len@subd_data_@data_abbrev@_query_@query_abbrev@_smem_@smem_abbrev@"
MATRIX_JSON_FILE
"${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_matrix.json"
KERNEL_INPUT_FILE
"${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_kernel.cu.in"
FRAGMENT_TAG_FORMAT
"${cagra_ns}::fragment_tag_compute_distance<${neighbors_ns}::tag_@data_abbrev@, ${neighbors_ns}::tag_index_@index_abbrev@, ${cagra_ns}::tag_dist_@distance_abbrev@, ${neighbors_ns}::tag_@query_abbrev@, ${cagra_ns}::tag_codebook_@codebook_abbrev@, @team_size@, @dataset_block_dim@, @pq_bits@, @pq_len@>"
"${cagra_ns}::fragment_tag_compute_distance<${neighbors_ns}::tag_@data_abbrev@, ${neighbors_ns}::tag_index_@index_abbrev@, ${cagra_ns}::tag_dist_@distance_abbrev@, ${neighbors_ns}::tag_@query_abbrev@, ${cagra_ns}::tag_codebook_@codebook_abbrev@, @team_size@, @dataset_block_dim@, @pq_bits@, @pq_len@, ${cagra_ns}::tag_smem_@smem_abbrev@>"
FRAGMENT_TAG_HEADER_FILES "<cuvs/detail/jit_lto/cagra/cagra_fragments.hpp>"
"<cuvs/detail/jit_lto/common_fragments.hpp>"
OUTPUT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/generated_kernels/cagra/compute_distance"
Expand Down
8 changes: 6 additions & 2 deletions cpp/include/cuvs/detail/jit_lto/cagra/cagra_fragments.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ struct tag_metric_cosine {};
struct tag_metric_hamming {};
struct tag_codebook_none {};
struct tag_codebook_half {};
// Tags naming the shared-memory element type of a CAGRA JIT LTO fragment
// (passed as the trailing `SmemTag` argument of the fragment tag templates).
// NOTE(review): presumably f16 <-> internal_dtype::F16 and e5m2 <-> internal_dtype::E5M2 — confirm
// against the kernel matrix JSON files.
struct tag_smem_f16 {};
struct tag_smem_e5m2 {};
struct tag_metric_l1 {};
struct tag_norm_noop {};
struct tag_norm_cosine {};
Expand All @@ -33,7 +35,8 @@ template <typename DataTag,
uint32_t TeamSize,
uint32_t DatasetBlockDim,
uint32_t PqBits,
uint32_t PqLen>
uint32_t PqLen,
typename SmemTag>
struct fragment_tag_setup_workspace {};

template <typename DataTag,
Expand All @@ -44,7 +47,8 @@ template <typename DataTag,
uint32_t TeamSize,
uint32_t DatasetBlockDim,
uint32_t PqBits,
uint32_t PqLen>
uint32_t PqLen,
typename SmemTag>
struct fragment_tag_compute_distance {};

template <typename QueryTag, typename DistanceTag, typename MetricTag>
Expand Down
6 changes: 6 additions & 0 deletions cpp/include/cuvs/neighbors/cagra.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,8 @@ enum class search_algo {

enum class hash_mode { HASH = 0, SMALL = 1, AUTO = 100 };

/**
 * @brief Element data types selectable through `search_params::smem_dtype`.
 *
 * The numeric values are spelled out explicitly so that they stay stable if enumerators are
 * reordered or new ones are added.
 */
enum class internal_dtype {
  /** 16-bit floating point; the default value of `search_params::smem_dtype`. */
  F16 = 0,
  /** 8-bit floating point (FP8) in the E5M2 format. */
  E5M2 = 1
};

struct search_params : cuvs::neighbors::search_params {
/** Maximum number of queries to search at the same time (batch size). Auto select when 0.*/
size_t max_queries = 0;
Expand Down Expand Up @@ -353,6 +355,10 @@ struct search_params : cuvs::neighbors::search_params {
* inferred from the source string.
*/
float filtering_rate = -1.0;

/** Data type used to store the query vector and the codebook table in shared memory.
* Currently, only VPQ datasets support FP8 (internal_dtype::E5M2). */
internal_dtype smem_dtype = internal_dtype::F16;
};

/**
Expand Down
10 changes: 10 additions & 0 deletions cpp/src/neighbors/detail/cagra/cagra_search.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,10 @@ void search_main(raft::resources const& res,
// Dispatch search parameters based on the dataset kind.
if (auto* strided_dset = dynamic_cast<const strided_dataset<T, ds_idx_type>*>(&index.data());
strided_dset != nullptr) {
if (params.smem_dtype != cuvs::neighbors::cagra::internal_dtype::F16) {
RAFT_LOG_WARN("In this search mode, smem_dtype supports only F16. Set it to F16.");
params.smem_dtype = cuvs::neighbors::cagra::internal_dtype::F16;
}
// Search using a plain (strided) row-major dataset
RAFT_EXPECTS(index.metric() != cuvs::distance::DistanceType::CosineExpanded ||
index.dataset_norms().has_value(),
Expand Down Expand Up @@ -180,6 +184,12 @@ void search_main(raft::resources const& res,
RAFT_FAIL("FP32 VPQ dataset support is coming soon");
} else if (auto* vpq_dset = dynamic_cast<const vpq_dataset<half, ds_idx_type>*>(&index.data());
vpq_dset != nullptr) {
if (params.smem_dtype == cuvs::neighbors::cagra::internal_dtype::E5M2 &&
raft::getComputeCapability().first < 9) {
RAFT_LOG_WARN(
"CAGRA VPQ E5M2 smem_dtype requires native FP8 support on SM90+. Falling back to F16.");
params.smem_dtype = cuvs::neighbors::cagra::internal_dtype::F16;
}
auto desc = dataset_descriptor_init_with_cache<T, graph_idx_type, DistanceT>(
res, params, *vpq_dset, index.metric(), nullptr);
search_main_core<T, graph_idx_type, DistanceT, CagraSampleFilterT, IdxT, OutputIdxT>(
Expand Down
18 changes: 11 additions & 7 deletions cpp/src/neighbors/detail/cagra/compute_distance.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -202,11 +202,12 @@ struct dataset_descriptor_host {
uint32_t team_size = 0;

// JIT LTO metadata - stored when descriptor is created
cuvs::distance::DistanceType metric = cuvs::distance::DistanceType::L2Expanded;
uint32_t dataset_block_dim = 0;
bool is_vpq = false;
uint32_t pq_bits = 0;
uint32_t pq_len = 0;
cuvs::distance::DistanceType metric = cuvs::distance::DistanceType::L2Expanded;
uint32_t dataset_block_dim = 0;
bool is_vpq = false;
uint32_t pq_bits = 0;
uint32_t pq_len = 0;
cuvs::neighbors::cagra::internal_dtype smem_dtype = cuvs::neighbors::cagra::internal_dtype::F16;
// Codebook type is determined by DataT for VPQ (always half for now)

struct state {
Expand Down Expand Up @@ -258,15 +259,18 @@ struct dataset_descriptor_host {
uint32_t dataset_block_dim_val,
bool is_vpq_val = false,
uint32_t pq_bits_val = 0,
uint32_t pq_len_val = 0)
uint32_t pq_len_val = 0,
cuvs::neighbors::cagra::internal_dtype smem_dtype_val =
cuvs::neighbors::cagra::internal_dtype::F16)
: value_{std::make_shared<state>(init, sizeof(DescriptorImpl))},
smem_ws_size_in_bytes{dd_host.smem_ws_size_in_bytes()},
team_size{dd_host.team_size()},
metric{metric_val},
dataset_block_dim{dataset_block_dim_val},
is_vpq{is_vpq_val},
pq_bits{pq_bits_val},
pq_len{pq_len_val}
pq_len{pq_len_val},
smem_dtype{smem_dtype_val}
{
}

Expand Down
63 changes: 51 additions & 12 deletions cpp/src/neighbors/detail/cagra/compute_distance_vpq-impl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#pragma once

#include "compute_distance_vpq.hpp"
#include "packed_type.hpp"

#include <cuvs/distance/distance.hpp>
#include <raft/util/pow2_utils.cuh>
Expand All @@ -14,6 +15,30 @@

namespace cuvs::neighbors::cagra::detail {

template <uint32_t PQ_LEN, cuvs::neighbors::cagra::internal_dtype SmemDType, class Enable = void>
struct vpq_smem_value_config;

template <uint32_t PQ_LEN, cuvs::neighbors::cagra::internal_dtype SmemDType>
struct vpq_smem_value_config<
PQ_LEN,
SmemDType,
std::enable_if_t<PQ_LEN == 2 || SmemDType == cuvs::neighbors::cagra::internal_dtype::F16>> {
using smem_val_pack_t = half2;
using smem_val_t = half;
using smem_val_pack_uint_t = uint32_t;
static constexpr uint32_t num_packed_elements = 2;
};

template <uint32_t PQ_LEN>
struct vpq_smem_value_config<PQ_LEN,
cuvs::neighbors::cagra::internal_dtype::E5M2,
std::enable_if_t<PQ_LEN == 4 || PQ_LEN == 8>> {
using smem_val_pack_t = device::fp8xN<PQ_LEN, 5>;
using smem_val_t = typename smem_val_pack_t::unit_t;
using smem_val_pack_uint_t = typename smem_val_pack_t::uint_t;
static constexpr uint32_t num_packed_elements = smem_val_pack_t::num_elements;
};

template <uint32_t TeamSize,
uint32_t DatasetBlockDim,
uint32_t PQ_BITS,
Expand All @@ -22,7 +47,8 @@ template <uint32_t TeamSize,
typename DataT,
typename IndexT,
typename DistanceT,
typename QueryT>
typename QueryT,
cuvs::neighbors::cagra::internal_dtype SmemDType>
struct cagra_q_dataset_descriptor_t : public dataset_descriptor_base_t<DataT, IndexT, DistanceT> {
using base_type = dataset_descriptor_base_t<DataT, IndexT, DistanceT>;
using CODE_BOOK_T = CodebookT;
Expand All @@ -38,6 +64,7 @@ struct cagra_q_dataset_descriptor_t : public dataset_descriptor_base_t<DataT, In
constexpr static inline auto kDatasetBlockDim = DatasetBlockDim;
constexpr static inline auto kPqBits = PQ_BITS;
constexpr static inline auto kPqLen = PQ_LEN;
constexpr static inline auto kSmemDType = SmemDType;

static_assert(std::is_same_v<CODE_BOOK_T, half>, "Only CODE_BOOK_T = `half` is supported now");

Expand Down Expand Up @@ -80,8 +107,11 @@ struct cagra_q_dataset_descriptor_t : public dataset_descriptor_base_t<DataT, In
return args.extra_word1;
}

using smem_val_config = vpq_smem_value_config<PQ_LEN, SmemDType>;

static constexpr std::uint32_t kSMemCodeBookSizeInBytes =
(1 << PQ_BITS) * PQ_LEN * utils::size_of<CODE_BOOK_T>();
(1 << PQ_BITS) * PQ_LEN * utils::size_of<typename smem_val_config::smem_val_pack_uint_t>() /
smem_val_config::num_packed_elements;

_RAFT_HOST_DEVICE cagra_q_dataset_descriptor_t(const std::uint8_t* encoded_dataset_ptr,
std::uint32_t encoded_dataset_dim,
Expand All @@ -108,7 +138,9 @@ struct cagra_q_dataset_descriptor_t : public dataset_descriptor_base_t<DataT, In
3. Queries (smem_query_buffer_length elems)
*/
return sizeof(cagra_q_dataset_descriptor_t) + kSMemCodeBookSizeInBytes +
raft::round_up_safe<uint32_t>(dim, DatasetBlockDim) * sizeof(QUERY_T);
raft::round_up_safe<uint32_t>(dim, DatasetBlockDim) *
utils::size_of<typename smem_val_config::smem_val_pack_uint_t>() /
smem_val_config::num_packed_elements;
}

private:
Expand All @@ -122,7 +154,8 @@ template <cuvs::distance::DistanceType Metric,
typename CodebookT,
typename DataT,
typename IndexT,
typename DistanceT>
typename DistanceT,
cuvs::neighbors::cagra::internal_dtype SmemDType>
RAFT_KERNEL __launch_bounds__(1, 1)
vpq_dataset_descriptor_init_kernel(dataset_descriptor_base_t<DataT, IndexT, DistanceT>* out,
const std::uint8_t* encoded_dataset_ptr,
Expand All @@ -140,7 +173,8 @@ RAFT_KERNEL __launch_bounds__(1, 1)
DataT,
IndexT,
DistanceT,
half>;
half,
SmemDType>;
new (out) desc_type(
encoded_dataset_ptr, encoded_dataset_dim, vq_code_book_ptr, pq_code_book_ptr, size, dim);
}
Expand All @@ -153,7 +187,8 @@ template <cuvs::distance::DistanceType Metric,
typename CodebookT,
typename DataT,
typename IndexT,
typename DistanceT>
typename DistanceT,
cuvs::neighbors::cagra::internal_dtype SmemDType>
dataset_descriptor_host<DataT, IndexT, DistanceT>
vpq_descriptor_spec<Metric,
TeamSize,
Expand All @@ -163,7 +198,8 @@ vpq_descriptor_spec<Metric,
CodebookT,
DataT,
IndexT,
DistanceT>::init_(const cagra::search_params& params,
DistanceT,
SmemDType>::init_(const cagra::search_params& params,
const std::uint8_t* encoded_dataset_ptr,
uint32_t encoded_dataset_dim,
const CodebookT* vq_code_book_ptr,
Expand All @@ -179,7 +215,8 @@ vpq_descriptor_spec<Metric,
DataT,
IndexT,
DistanceT,
half>;
half,
SmemDType>;

return host_type{
desc_type{
Expand All @@ -194,7 +231,8 @@ vpq_descriptor_spec<Metric,
CodebookT,
DataT,
IndexT,
DistanceT><<<1, 1, 0, stream>>>(dev_ptr,
DistanceT,
SmemDType><<<1, 1, 0, stream>>>(dev_ptr,
encoded_dataset_ptr,
encoded_dataset_dim,
vq_code_book_ptr,
Expand All @@ -205,9 +243,10 @@ vpq_descriptor_spec<Metric,
},
Metric,
DatasetBlockDim,
true, // is_vpq
PqBits, // pq_bits
PqLen}; // pq_len
true, // is_vpq
PqBits, // pq_bits
PqLen, // pq_len
SmemDType}; // smem_dtype
}

} // namespace cuvs::neighbors::cagra::detail
23 changes: 21 additions & 2 deletions cpp/src/neighbors/detail/cagra/compute_distance_vpq.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2023-2025, NVIDIA CORPORATION.
* SPDX-FileCopyrightText: Copyright (c) 2023-2026, NVIDIA CORPORATION.
* SPDX-License-Identifier: Apache-2.0
*/

Expand All @@ -8,11 +8,22 @@
#include "compute_distance.hpp"

#include <cuvs/distance/distance.hpp>
#include <raft/util/cudart_utils.hpp>

#include <type_traits>

namespace cuvs::neighbors::cagra::detail {

/**
 * @brief Resolve the smem dtype that is effectively used for the given search parameters.
 *
 * `params.smem_dtype` is returned unchanged unless E5M2 is requested and the first component
 * reported by `raft::getComputeCapability()` (presumably the compute-capability major version)
 * is below 9; in that case F16 is returned instead.
 *
 * The compute capability is only queried when E5M2 is requested.
 *
 * @param params search parameters carrying the requested `smem_dtype`
 * @return the requested dtype, or F16 when the E5M2 request cannot be honoured
 */
inline auto select_supported_vpq_smem_dtype(const cagra::search_params& params)
  -> cuvs::neighbors::cagra::internal_dtype
{
  const auto requested = params.smem_dtype;

  // Anything other than E5M2 needs no hardware check.
  if (requested != cuvs::neighbors::cagra::internal_dtype::E5M2) { return requested; }

  const bool below_required_capability = raft::getComputeCapability().first < 9;
  return below_required_capability ? cuvs::neighbors::cagra::internal_dtype::F16 : requested;
}

template <cuvs::distance::DistanceType Metric,
uint32_t TeamSize,
uint32_t DatasetBlockDim,
Expand All @@ -21,7 +32,8 @@ template <cuvs::distance::DistanceType Metric,
typename CodebookT,
typename DataT,
typename IndexT,
typename DistanceT>
typename DistanceT,
cuvs::neighbors::cagra::internal_dtype SmemDType>
struct vpq_descriptor_spec : public instance_spec<DataT, IndexT, DistanceT> {
using base_type = instance_spec<DataT, IndexT, DistanceT>;
using typename base_type::data_type;
Expand Down Expand Up @@ -69,6 +81,13 @@ struct vpq_descriptor_spec : public instance_spec<DataT, IndexT, DistanceT> {
// Match codebook params
if (dataset.pq_bits() != PqBits) { return -1.0; }
if (dataset.pq_len() != PqLen) { return -1.0; }
if (select_supported_vpq_smem_dtype(params) != SmemDType) { return -1.0; }
// Keep auto-selection on the tuned VPQ diagonal while allowing explicit team_size requests to
// use the expanded team_size / dataset_block_dim grid.
constexpr std::uint32_t auto_dataset_block_dim_per_team = PqLen == 8 ? 32 : 16;
if (params.team_size == 0 && DatasetBlockDim != TeamSize * auto_dataset_block_dim_per_team) {
return -1.0;
}
// Otherwise, favor the closest dataset dimensionality.
constexpr std::uint32_t preferred_load_elmes_per_thread =
16; /*magic number that is good based on experiments.*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ constexpr uint32_t team_size = @team_size@;
constexpr uint32_t dim = @dim@;
constexpr uint32_t pq_bits = @pq_bits@;
constexpr uint32_t pq_len = @pq_len@;
constexpr auto smem_dtype = @smem_dtype@;
using codebook_t = @codebook_type@;
using data_t = @data_type@;
using index_t = @index_type@;
Expand All @@ -30,6 +31,7 @@ template struct vpq_descriptor_spec<metric,
codebook_t,
data_t,
index_t,
distance_t>;
distance_t,
smem_dtype>;

} // namespace cuvs::neighbors::cagra::detail
Loading
Loading