Improve CAGRA-Q performance and add support for PQ_LEN=8 #1533
📝 Walkthrough
Add a configurable shared-memory dtype (F16/E5M2) for CAGRA VPQ: new enum and search param, packed-fp8 utilities, descriptor and cache updates, JIT-LTO planner dispatch and fragment-tagging, threaded SmemDType through setup_workspace and compute_distance device kernels, build/template instantiation updates, kernel matrix metadata, and tests.
Changes: CAGRA VPQ Shared-Memory Dtype Support
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~75 minutes
🚥 Pre-merge checks | ✅ 4 | ❌ 1
❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
Actionable comments posted: 8
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (3)
cpp/src/neighbors/detail/cagra/search_multi_cta_kernel-inl.cuh (1)
489-507: ⚠️ Potential issue | 🟡 Minor
Match the `_CLK_BREAKDOWN` placeholders with arguments.
The format string now prints both `distance` and `hash`, but this call only passes one `uint64_t` after `clk_pickup_parents`. With `_CLK_BREAKDOWN` enabled, that is undefined behavior and will emit garbage timing data.
💡 Suggested fix
      clk_init,
      clk_compute_1st_distance,
      clk_topk,
      clk_pickup_parents,
    + clk_compute_actual_distance,
      clk_compute_distance - clk_compute_actual_distance);
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@cpp/src/neighbors/detail/cagra/search_multi_cta_kernel-inl.cuh` around lines 489 - 507, The printf call in search_multi_cta_kernel-inl.cuh prints both "distance" and "hash" but only passes one uint64_t after clk_pickup_parents, causing mismatched arguments when _CLK_BREAKDOWN is enabled; update the printf argument list used in the printf near the debug block (the printf that references __FILE__, __LINE__, query_id, threadIdx.x and clk_* variables) to pass both timing values (e.g., clk_compute_distance and clk_compute_actual_distance or the intended clk_hash value) so the number and order of format specifiers match the provided arguments; ensure the variables clk_init, clk_compute_1st_distance, clk_topk, clk_pickup_parents, clk_compute_distance, and clk_compute_actual_distance (or the correct hash timing variable) are all supplied in the same order as the format string.
cpp/include/cuvs/neighbors/cagra.hpp (1)
273-351: ⚠️ Potential issue | 🟡 Minor
Document the new `internal_dtype` API surface.
`internal_dtype` and `search_params::smem_dtype` are now public API, but this header does not document the enum values or the key constraints users need to know: FP8 is VPQ-only, `E5M2` is ignored for strided datasets (`cpp/src/neighbors/detail/cagra/cagra_search.cuh` Lines 155-159), and `AUTO` is device-dependent for VPQ selection. Please add Doxygen on the enum/field and flag this new knob in the user-facing CAGRA docs.
As per coding guidelines, "For public C++ API headers, additionally check: Doxygen documentation for all public functions/classes" and "API changes flagged for docs/ updates".
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@cpp/include/cuvs/neighbors/cagra.hpp` around lines 273 - 351, The public enum internal_dtype and search_params::smem_dtype lack Doxygen and user-facing guidance; add Doxygen comments to the internal_dtype enum (document each value: F16, E5M2, AUTO) and to search_params::smem_dtype explaining constraints: FP8 (if present) is VPQ-only, E5M2 is ignored for strided datasets, AUTO behavior is device-dependent and selects VPQ when appropriate, and valid value ranges/compatibility with other params (e.g., smem usage, VPQ). Also add a short API-note in the CAGRA user docs indicating this new knob, its VPQ-only/strided dataset caveats, and guidance on when to choose AUTO vs explicit types. Reference symbols: internal_dtype, search_params::smem_dtype, and VPQ behavior in cagra_search logic.
cpp/src/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh (1)
1086-1103: ⚠️ Potential issue | 🟡 Minor
Minor: misleading column ordering in `_CLK_BREAKDOWN` printf.
The header strings are emitted in the order `..., distance, hash` but the corresponding values are `clk_compute_actual_distance, clk_compute_distance - clk_compute_actual_distance`. Since `clk_compute_distance` accumulates the wall time of `compute_distance_to_child_nodes` (which includes both the hashmap inserts and the actual distance computation), calling the residual "hash" is approximately right, but the column label "distance" now refers to the actual distance kernel time rather than the previously reported total. This is a behavioral break for anyone parsing the existing `_CLK_BREAKDOWN` output. Consider renaming the labels (e.g. `actual_distance`, `non_distance`) so log consumers don't silently misinterpret the numbers.
This is debug-only instrumentation gated by `#ifdef _CLK_BREAKDOWN`, so it does not affect release builds.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@cpp/src/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh` around lines 1086 - 1103, The `_CLK_BREAKDOWN` printf currently labels the last two columns as "distance, hash" but passes clk_compute_actual_distance and (clk_compute_distance - clk_compute_actual_distance), which is misleading; update the header labels to match the values (for example change "distance, hash" to "actual_distance, non_distance" or similar) so the printed column names align with the passed variables, and ensure the printf string near _CLK_BREAKDOWN and the associated argument list (using clk_compute_actual_distance and clk_compute_distance) are kept consistent.
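One way to keep column labels and values from drifting apart, as in the two `_CLK_BREAKDOWN` findings above, is to emit each label together with its value. The following is a minimal host-side sketch, not the kernel code: `clk_column`, `format_clk_breakdown`, and `format_distance_split` are hypothetical names, and the real instrumentation uses device-side `printf`.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <string>
#include <vector>

// Hypothetical helper type: one entry per timing column, so a label can never
// be emitted without its value (or the other way round).
struct clk_column {
  const char* label;
  std::uint64_t cycles;
};

// Formats "label=value" pairs separated by ", ".
inline std::string format_clk_breakdown(const std::vector<clk_column>& columns)
{
  std::string out;
  for (std::size_t i = 0; i < columns.size(); ++i) {
    char buf[64];
    // %llu is paired with an explicit cast so specifier and argument always agree.
    std::snprintf(buf, sizeof(buf), "%s=%llu", columns[i].label,
                  static_cast<unsigned long long>(columns[i].cycles));
    if (i != 0) { out += ", "; }
    out += buf;
  }
  return out;
}

// Splits the total compute_distance time into the two reported columns,
// using the label names suggested in the review comment.
inline std::string format_distance_split(std::uint64_t clk_compute_distance,
                                         std::uint64_t clk_compute_actual_distance)
{
  return format_clk_breakdown({
    {"actual_distance", clk_compute_actual_distance},
    {"non_distance", clk_compute_distance - clk_compute_actual_distance},
  });
}
```

Calling `format_distance_split(total, actual)` produces both columns from a single list, so adding or renaming a column touches exactly one line.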
🧹 Nitpick comments (13)
cpp/tests/neighbors/ann_utils.cuh (2)
289-290: `index_based_actual_recall` is destructured but never used here.
`eval_neighbours` only checks `actual_recall`; the new index-only recall is consumed by callers (e.g., the VPQ path in `ann_cagra.cuh` via `std::get<1>`). To keep this internal binding intentional and avoid a potential `-Wunused-variable` on some toolchains, mark it `[[maybe_unused]]` or use a discard.
Proposed change
    - auto [actual_recall, index_based_actual_recall, match_count, total_count] =
    -   calc_recall(expected_idx, actual_idx, expected_dist, actual_dist, rows, cols, eps);
    + auto [actual_recall, index_based_actual_recall, match_count, total_count] =
    +   calc_recall(expected_idx, actual_idx, expected_dist, actual_dist, rows, cols, eps);
    + (void)index_based_actual_recall;  // currently unused in eval_neighbours
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@cpp/tests/neighbors/ann_utils.cuh` around lines 289 - 290, The destructured variable index_based_actual_recall from the calc_recall call in eval_neighbours is not used and can trigger -Wunused-variable; update the destructuring to indicate it's intentionally unused (e.g., mark index_based_actual_recall with [[maybe_unused]] or replace it with a discard/ignore) so the binding remains for callers that consume the second return (calc_recall) but avoids unused-variable warnings in eval_neighbours.
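For reference, the `[[maybe_unused]]` alternative mentioned above could look roughly like this. It is a sketch with a hypothetical `calc_recall_stub` standing in for the real `calc_recall`; in C++17 the attribute is written on the structured binding declaration as a whole, not on an individual name.

```cpp
#include <cassert>
#include <cstddef>
#include <tuple>

// Hypothetical stand-in for calc_recall: returns
// (recall, index-only recall, match count, total count).
inline std::tuple<double, double, std::size_t, std::size_t> calc_recall_stub(
  std::size_t matches, std::size_t index_matches, std::size_t total)
{
  return std::make_tuple(static_cast<double>(matches) / static_cast<double>(total),
                         static_cast<double>(index_matches) / static_cast<double>(total),
                         matches,
                         total);
}

// Only the first binding is consumed here; the attribute documents that the
// remaining names are intentionally unused in this function.
inline bool recall_at_least(std::size_t matches,
                            std::size_t index_matches,
                            std::size_t total,
                            double min_recall)
{
  [[maybe_unused]] auto [actual_recall, index_based_actual_recall, match_count, total_count] =
    calc_recall_stub(matches, index_matches, total);
  return actual_recall >= min_recall;
}
```

Compared with the `(void)` discard, the attribute keeps the declaration on one statement, at the cost of silencing the warning for every name in the binding.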
252-273: Avoid the duplicate O(rows·cols²) pass for index-only recall.
The new "Index based recall" loop reproduces the structure of the loop above, just without the distance check. Folding the index-only counter into the first loop saves one full O(rows·cols²) traversal and keeps the two recall metrics consistent by construction.
Proposed merge
      for (size_t i = 0; i < rows; ++i) {
        for (size_t k = 0; k < cols; ++k) {
          size_t idx_k  = i * cols + k;  // row major assumption!
          auto act_idx  = actual_idx[idx_k];
          auto act_dist = actual_dist[idx_k];
    +     bool idx_matched = false;
          for (size_t j = 0; j < cols; ++j) {
            size_t idx    = i * cols + j;  // row major assumption!
            auto exp_idx  = expected_idx[idx];
            auto exp_dist = expected_dist[idx];
    +       if (!idx_matched && act_idx == exp_idx) {
    +         index_match_count++;
    +         idx_matched = true;
    +       }
            idx_dist_pair exp_kvp(exp_idx, exp_dist, cuvs::CompareApprox<DistT>(eps));
            idx_dist_pair act_kvp(act_idx, act_dist, cuvs::CompareApprox<DistT>(eps));
            if (exp_kvp == act_kvp) {
              match_count++;
              break;
            }
          }
        }
      }
    -
    - // Index based recall
    - for (size_t i = 0; i < rows; ++i) {
    -   ...
    - }
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@cpp/tests/neighbors/ann_utils.cuh` around lines 252 - 273, The second "Index based recall" triple loop duplicates the earlier O(rows·cols²) traversal; instead, inside the original nested loops where you compute match_count (the first double loop that iterates i, k and inner j over expected_idx and checks distances), also check for index equality (if act_idx == exp_idx) and increment index_match_count and break just like the separate loop did, ensuring you only count once per (i,k); then remove the duplicate loop entirely so index_match_count is updated in the same pass as match_count (use the same variables actual_idx, expected_idx, act_idx, exp_idx, index_match_count, match_count, total_count to locate and modify the code).
cpp/tests/neighbors/ann_cagra.cuh (2)
563-563: Initialize `reference_recall` at declaration.
Currently the member is left uninitialized; it is assigned inside `testCagra()` before being read at line 514, but a default initializer makes the invariant explicit and avoids latent UB if any future code path reads it earlier.
Proposed change
    - double reference_recall;
    + double reference_recall = 1.0;
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@cpp/tests/neighbors/ann_cagra.cuh` at line 563, The variable reference_recall is declared uninitialized; initialize it at declaration (e.g., double reference_recall = 0.0;) to make the invariant explicit and avoid potential UB if read before assignment; update the declaration of reference_recall near the top of the file so testCagra() can still assign it later but the variable has a safe default value.
503-503: Replace raw `printf` with the RAFT logger.
Other test diagnostics use `RAFT_LOG_INFO`. A bare `printf` is harder to silence and bypasses the logger configuration.
Proposed change
    - printf("reference_recall = %e\n", reference_recall);
    + RAFT_LOG_INFO("reference_recall = %e", reference_recall);
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@cpp/tests/neighbors/ann_cagra.cuh` at line 503, Replace the raw printf call in ann_cagra.cuh that prints reference_recall with the RAFT logger: change printf("reference_recall = %e\n", reference_recall); to a RAFT_LOG_INFO call (e.g., RAFT_LOG_INFO("reference_recall = %e", reference_recall)); ensure the RAFT logging header is included where needed so the logger symbol RAFT_LOG_INFO is available.
cpp/tests/neighbors/vpq_utils.cuh (2)
25-28: Verify alignment of the 4-byte `vq_code` read.
`reinterpret_cast<const uint32_t*>(local_data_ptr)` requires `data_ptr + ldi * batch_id` to be 4-byte aligned. As long as the encoded row stride `ldi` is a multiple of 4 and the base pointer is aligned (which RMM/raft allocations typically are), this is fine, but nothing in this file enforces it. A `RAFT_EXPECTS(vpq_dataset.data.stride(0) % 4 == 0, ...)` in the host wrapper would make the requirement explicit and fail loudly instead of silently returning misaligned-load garbage on platforms that don't tolerate it.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@cpp/tests/neighbors/vpq_utils.cuh` around lines 25 - 28, The code reads a 4-byte vq_code via reinterpret_cast<const uint32_t*>(local_data_ptr) (symbols: local_data_ptr, vq_code, data_ptr, ldi, batch_id), but there is no guarantee that local_data_ptr is 4-byte aligned; add an explicit runtime check in the host wrapper that prepares vpq data (e.g., validate vpq_dataset.data.stride(0) / ldi) using RAFT_EXPECTS(vpq_dataset.data.stride(0) % 4 == 0, "stride must be 4-byte aligned") so misaligned strides fail loudly; ensure this check runs before any device/kernel launch that uses local_data_ptr and document the alignment requirement in the wrapper's API comment.
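The alignment requirement can be illustrated on the host with plain C++. This is a sketch under the assumption that each encoded row starts with a 4-byte VQ code; `rows_are_u32_aligned` and `read_vq_code` are hypothetical helper names and not part of the test utilities.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>

// True when every row start is 4-byte aligned: the base pointer must be
// aligned and the row stride (in bytes) must be a multiple of 4.
inline bool rows_are_u32_aligned(const std::uint8_t* base, std::size_t row_stride_bytes)
{
  return reinterpret_cast<std::uintptr_t>(base) % alignof(std::uint32_t) == 0 &&
         row_stride_bytes % sizeof(std::uint32_t) == 0;
}

// Alignment-agnostic read of the leading 4-byte code of a row. memcpy avoids
// the aligned-pointer requirement of reinterpret_cast<const uint32_t*>.
inline std::uint32_t read_vq_code(const std::uint8_t* row)
{
  std::uint32_t code;
  std::memcpy(&code, row, sizeof(code));
  return code;
}
```

A host wrapper could call the check once before launching the kernel and keep the fast aligned load on the device side.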
14-20: `pq_table_size` is computed and passed but never used inside the kernel.
The host wrapper computes `1u << vpq_dataset.pq_bits()` and forwards it as `pq_table_size`, but `decode_vpq_dataset_kernel` doesn't reference that parameter anywhere. Either drop it from the signature or use it (e.g., to bounds-check `pq_code` against the codebook size).
Proposed cleanup
      __global__ void decode_vpq_dataset_kernel(data_t* const decoded_dataset_ptr,
                                                const uint32_t ldd,
                                                const math_t* const vq_codebook_ptr,
                                                const uint32_t ldv,
                                                const math_t* const pq_codebook_ptr,
                                                const uint32_t pq_subspace_dim,
    -                                           const uint32_t pq_table_size,
                                                const uint32_t dataset_dim,
                                                const size_t dataset_size,
                                                const uint8_t* const data_ptr,
                                                const uint32_t ldi)
…and remove the corresponding argument at the launch site (line 60).
Also applies to: 53-64
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@cpp/tests/neighbors/vpq_utils.cuh` around lines 14 - 20, The kernel decode_vpq_dataset_kernel currently accepts pq_table_size but never uses it; either remove pq_table_size from the kernel signature and from any launch-site argument lists, or use it to validate decoded PQ indices (e.g., check pq_code < pq_table_size before indexing the codebook) to prevent out-of-bounds access. Locate the kernel function decode_vpq_dataset_kernel and all call sites that pass the computed 1u << vpq_dataset.pq_bits(), then either delete the pq_table_size parameter from the function signature and corresponding launches, or add a bounds-check against pq_table_size where pq_code (or similar PQ index) is used to index the codebook. Ensure consistency across all occurrences (also apply same change to the analogous kernel referenced at lines 53-64).
cpp/src/neighbors/detail/cagra/device_common.hpp (3)
254-289: `fp8xN` only safely supports even `NumPacked`; document or constrain it.
`uintN_t` is only specialized for 32 and 64, so `fp8xN<NumPacked, 5>` is only instantiable for `NumPacked ∈ {4, 8}` (1 → `uintN_t<8>`, 2 → `uintN_t<16>` aren't defined). Additionally:
- `data.x2[num_elements / 2]` silently truncates when `num_elements` is odd.
- `as_half2(i)` indexes the `x2` member, so it implicitly assumes pairs, i.e. even `num_elements`; there's no guard.
Today the only callers (in `compute_distance_vpq-impl.cuh`) pass `PQ_LEN ∈ {4, 8}`, so this is safe. To prevent surprises if someone later instantiates with an odd `NumPacked`, please add a `static_assert(NumPacked % 2 == 0 && (NumPacked == 4 || NumPacked == 8), ...)` (or specialize `uintN_t` for the supported widths only) in this struct.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@cpp/src/neighbors/detail/cagra/device_common.hpp` around lines 254 - 289, The fp8xN<NumPacked, 5> specialization currently assumes an even NumPacked and relies on uintN_t being defined only for 32/64-bit widths; add a compile-time guard to prevent accidental odd or unsupported instantiations by inserting a static_assert in the fp8xN<NumPacked,5> body (near the union/data and methods) that enforces NumPacked % 2 == 0 and restricts NumPacked to the supported sizes (e.g., NumPacked == 4 || NumPacked == 8), and mention uintN_t and as_half2/data.x2 in the assertion message so users know the reason; alternatively, explicitly specialize uintN_t for the required widths and document the even-element requirement in fp8xN.
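The shape of such a guard can be sketched with a host-only analogue. `packed_u8` below is a hypothetical stand-in for a packed 8-bit container and does not reproduce the actual `fp8xN` layout; it only shows the `static_assert` rejecting unsupported element counts before any storage type is selected.

```cpp
#include <cassert>
#include <cstdint>
#include <type_traits>

// Hypothetical host-side analogue of a packed 8-bit container.
template <unsigned NumPacked>
struct packed_u8 {
  // Unsupported counts fail here with a readable message instead of failing
  // deep inside a missing storage-type specialization.
  static_assert(NumPacked == 4 || NumPacked == 8,
                "packed_u8 needs 32- or 64-bit storage and an even element count");

  using storage_t = std::conditional_t<NumPacked == 4, std::uint32_t, std::uint64_t>;
  static constexpr unsigned num_pairs = NumPacked / 2;

  storage_t bits = 0;

  // Overwrites the i-th byte lane (lane 0 is the least significant byte).
  void set(unsigned i, std::uint8_t v)
  {
    const unsigned shift = 8u * i;
    bits = (bits & ~(storage_t{0xFF} << shift)) | (storage_t{v} << shift);
  }

  // Reads the i-th byte lane.
  std::uint8_t get(unsigned i) const { return static_cast<std::uint8_t>(bits >> (8u * i)); }
};
```

With this in place, `packed_u8<3>` or `packed_u8<2>` is a compile error carrying the stated message.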
357-365: Redundant `reinterpret_cast` in `sts` overloads.
`x` already has type `const uint32_t&` / `const uint64_t&`, so the `reinterpret_cast<const uint32_t&>(x)` / `reinterpret_cast<const uint64_t&>(x)` are no-ops; they can simply be `x`. Not a correctness issue, just dead casting that obscures the intent of these helpers and differs in style from the templated overloads above.
♻️ Suggested cleanup
      RAFT_DEVICE_INLINE_FUNCTION void sts(uint32_t addr, const uint32_t& x)
      {
    -   asm volatile("st.shared.u32 [%0], %1;" : : "r"(addr), "r"(reinterpret_cast<const uint32_t&>(x)));
    +   asm volatile("st.shared.u32 [%0], %1;" : : "r"(addr), "r"(x));
      }
      RAFT_DEVICE_INLINE_FUNCTION void sts(uint32_t addr, const uint64_t& x)
      {
    -   asm volatile("st.shared.u64 [%0], %1;" : : "r"(addr), "l"(reinterpret_cast<const uint64_t&>(x)));
    +   asm volatile("st.shared.u64 [%0], %1;" : : "r"(addr), "l"(x));
      }
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@cpp/src/neighbors/detail/cagra/device_common.hpp` around lines 357 - 365, The two sts overloads use redundant reinterpret_casts; change the asm operand expressions to use x directly (i.e., replace reinterpret_cast<const uint32_t&>(x) and reinterpret_cast<const uint64_t&>(x) with x) in the functions sts(uint32_t, const uint32_t&) and sts(uint32_t, const uint64_t&) while keeping the existing asm volatile constraints ("r" for u32 and "l" for u64) and signatures unchanged.
323-328: Pre-existing bug surfaced by adjacent diff: `lds(uint8_t&)` cast is wrong.
This isn't a line you changed, but it sits one block above the new `lds(uint64_t&)` and `sts` additions, so flagging while the surrounding code is under review:
    RAFT_DEVICE_INLINE_FUNCTION void lds(uint8_t& x, uint32_t addr)
    {
      uint32_t res;
      asm volatile("ld.shared.u8 {%0}, [%1];" : "=r"(res) : "r"(addr));
      x = static_cast<uint32_t>(res);  // <- assigning uint32_t to uint8_t&; should narrow
    }
The final assignment uses `static_cast<uint32_t>` instead of `static_cast<uint8_t>`. It works only because narrowing is implicit for fundamental types, but it's misleading and contradicts the function's declared semantics. Consider fixing while you're touching this region:
    - x = static_cast<uint32_t>(res);
    + x = static_cast<uint8_t>(res);
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@cpp/src/neighbors/detail/cagra/device_common.hpp` around lines 323 - 328, The lds(uint8_t& x, uint32_t addr) function casts the loaded 32-bit register to uint32_t before assigning to the uint8_t reference, which is misleading; change the assignment to narrow explicitly by using static_cast<uint8_t>(res) (or declare a uint8_t temp and assign that) so the function's semantics match its signature and the narrowing is explicit.
cpp/src/neighbors/detail/cagra/compute_distance_vpq-impl.cuh (4)
211-231: Dead branch when `PQ_LEN < num_elements_per_bank` is silent.
The `if constexpr (PQ_LEN >= num_elements_per_bank)` guard on line 211 elides the codebook copy when the assumption is violated, but with no `else` and no `static_assert`, an inadvertent future combination of `PQ_LEN`/`EnableFP8` would just silently skip writing the PQ codebook to SMEM, which would corrupt search results without any compile-time signal. All currently instantiated combinations satisfy the guard (`{2,4,8}` PQ_LEN with `num_elements_per_bank ∈ {2,4}`), so this is fine today.
Please add a `static_assert` instead of silent `if constexpr` to fail loud if someone adds a new `PQ_LEN`/FP8 combination later:
🛡️ Suggested hardening
    - if constexpr (PQ_LEN >= num_elements_per_bank) {  // safety
    + static_assert(PQ_LEN >= num_elements_per_bank,
    +               "PQ_LEN must be >= number of FP8 elements per 32-bit bank");
    + {
        constexpr auto num_banks_per_subspace = PQ_LEN / num_elements_per_bank;
        ...
      }
The same comment applies to the analogous `if constexpr (PQ_LEN >= num_packed_elements)` guard at line 329 of `compute_distance_vpq_worker`.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@cpp/src/neighbors/detail/cagra/compute_distance_vpq-impl.cuh` around lines 211 - 231, Replace the silent conditional guard with a compile-time check: add a static_assert that PQ_LEN >= num_elements_per_bank (and similarly static_assert(PQ_LEN >= num_packed_elements) in the analogous guard inside compute_distance_vpq_worker) so any unsupported PQ_LEN / FP8 combinations fail at compile time instead of silently eliding the SMEM codebook copy; keep the existing copy logic (the body that uses PQ_LEN, num_elements_per_bank, num_packed_elements, codebook_buf, smem_val_pack_t, device::sts, and r->pq_code_book_ptr()) unchanged, but remove or tighten the surrounding if constexpr to rely on the static_assert to enforce the invariant.
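The difference between the silent guard and the hard check can be sketched in isolation. The function below is a hypothetical reduction of the bank computation; the divisibility check is an extra assumption added for this illustration and is not part of the suggested hardening.

```cpp
#include <cassert>

// Number of banks covered by one PQ subspace. Unlike `if constexpr`, an
// unsupported combination does not compile at all.
template <unsigned PQ_LEN, unsigned NumElementsPerBank>
constexpr unsigned num_banks_per_subspace()
{
  static_assert(NumElementsPerBank != 0, "a bank must hold at least one element");
  static_assert(PQ_LEN >= NumElementsPerBank,
                "PQ_LEN must be >= number of elements per bank");
  // Assumption for this sketch only: a subspace fills whole banks.
  static_assert(PQ_LEN % NumElementsPerBank == 0,
                "PQ_LEN must be a multiple of the elements per bank");
  return PQ_LEN / NumElementsPerBank;
}
```

For example, `num_banks_per_subspace<8, 4>()` evaluates to 2, while `num_banks_per_subspace<2, 4>()` is rejected at compile time rather than skipping the copy.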
39-67: Adding `EnableFP8` as the trailing template parameter is a non-breaking, clean extension.
Putting the new `bool EnableFP8` template parameter at the end and exposing it as a `kEnableFP8` constexpr keeps the descriptor type discoverable from kernels (e.g. `DescriptorT::kEnableFP8` is what `setup_workspace_vpq` and `compute_distance_vpq_worker` rely on). The propagation through `vpq_dataset_descriptor_init_kernel` and `vpq_descriptor_spec::init_` is consistent.
One follow-up worth considering: add a `static_assert(!EnableFP8 || (PQ_LEN == 4 || PQ_LEN == 8), "FP8 SMEM is only supported for PQ_LEN in {4, 8}")` here, so an accidental instantiation with `EnableFP8=true, PQ_LEN=2` fails loudly at compile time instead of silently falling through to the half2 specialization in `smem_val_type_t`.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@cpp/src/neighbors/detail/cagra/compute_distance_vpq-impl.cuh` around lines 39 - 67, Add a compile-time guard to cagra_q_dataset_descriptor_t to prevent invalid FP8 configurations: inside the template struct cagra_q_dataset_descriptor_t (the type that defines kEnableFP8), add a static_assert that checks !EnableFP8 || (PQ_LEN == 4 || PQ_LEN == 8) with a clear message like "FP8 SMEM is only supported for PQ_LEN in {4, 8}" so instantiations with EnableFP8=true and unsupported PQ_LEN (e.g., 2) fail at compile time; reference cagra_q_dataset_descriptor_t, kEnableFP8, EnableFP8, and PQ_LEN when locating where to add the assertion.
300-326: `PQ_CODEBOOK_LOAD_T` is hard-coded to `uint32_t`; the `else` branch on Line 323-325 is dead.
`PQ_CODEBOOK_LOAD_T` is locally `using PQ_CODEBOOK_LOAD_T = uint32_t;` (line 290) and never aliased, so the `if constexpr (std::is_same_v<PQ_CODEBOOK_LOAD_T, uint32_t>)` always takes the first branch and the fallback `*reinterpret_cast<const PQ_CODEBOOK_LOAD_T*>(...)` is unreachable. Either:
- promote `PQ_CODEBOOK_LOAD_T` to a real template parameter / trait if the intent is to support other widths in the future, or
- drop the dead `else` and the surrounding `if constexpr` until that flexibility is actually needed (KISS / YAGNI).
This is purely a maintainability nit; no functional impact.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@cpp/src/neighbors/detail/cagra/compute_distance_vpq-impl.cuh` around lines 300 - 326, The code has a dead conditional because PQ_CODEBOOK_LOAD_T is locally typedef'd to uint32_t so the if constexpr(std::is_same_v<PQ_CODEBOOK_LOAD_T, uint32_t>) always selects the device::ldg_cg path and the else branch is unreachable; fix by either (A) making PQ_CODEBOOK_LOAD_T a template parameter or trait used by compute_distance_vpq (so the reinterpret_cast fallback can be meaningful for other widths), or (B) remove the if constexpr and else branch and always use device::ldg_cg/pq_codes assignment for the current uint32_t type to keep the code simple; update the code around PQ_CODEBOOK_LOAD_T, the loading loop that writes pq_codes[e], and the device::ldg_cg/dataset_ptr usage accordingly.
222-230: Remove unnecessary intermediate `float` conversion in codebook setup.
The code currently performs a wasteful double conversion:
    buf.data.x1[k] = static_cast<smem_val_t>(static_cast<float>(r->pq_code_book_ptr()[i + k]));
Since `pq_code_book_ptr()[i + k]` is `half` and `smem_val_t` is `__nv_fp8_e5m2`, the intermediate `float` conversion adds unnecessary overhead. `__nv_fp8_e5m2` provides a direct constructor from `__half`, supported in CUDA 11.8+ (cuVS targets CUDA 12.9+).
♻️ Suggested simplification
    - buf.data.x1[k] =
    -   static_cast<smem_val_t>(static_cast<float>(r->pq_code_book_ptr()[i + k]));
    + buf.data.x1[k] = static_cast<smem_val_t>(r->pq_code_book_ptr()[i + k]);
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@cpp/src/neighbors/detail/cagra/compute_distance_vpq-impl.cuh` around lines 222 - 230, The assignment in the codebook setup uses an unnecessary intermediate float cast; change the line inside the loop that sets buf.data.x1[k] so it converts directly from the source half value to smem_val_t (i.e., use a direct static_cast or constructor from r->pq_code_book_ptr()[i + k] to smem_val_t) instead of static_cast<smem_val_t>(static_cast<float>(...)); update the assignment in the loop that writes into buf (within the branch handling num_packed_elements == 4 || 8) so buf.data.x1, smem_val_t, r->pq_code_book_ptr(), and device::sts remain used but the extra float conversion is removed.
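To make the precision cost of storing codebook values as E5M2 concrete, the format can be decoded on the host. E5M2 uses 1 sign bit, 5 exponent bits (bias 15, the same exponent layout as fp16) and 2 mantissa bits, so the top byte of an fp16 bit pattern is already a valid E5M2 value. The sketch below is plain C++ and does not use the CUDA `__nv_fp8_e5m2` type; the truncating conversion is a simplification (round toward zero, no special NaN handling) and the function names are hypothetical.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <limits>

// Decodes an E5M2 byte: 1 sign bit, 5 exponent bits (bias 15), 2 mantissa bits.
inline float e5m2_to_float(std::uint8_t bits)
{
  const int sign     = (bits >> 7) & 0x1;
  const int exponent = (bits >> 2) & 0x1F;
  const int mantissa = bits & 0x3;
  float magnitude;
  if (exponent == 0x1F) {
    // All-ones exponent encodes infinity (mantissa 0) or NaN.
    magnitude = (mantissa == 0) ? std::numeric_limits<float>::infinity()
                                : std::numeric_limits<float>::quiet_NaN();
  } else if (exponent == 0) {
    // Subnormal: (mantissa / 4) * 2^-14 == mantissa * 2^-16.
    magnitude = std::ldexp(static_cast<float>(mantissa), -16);
  } else {
    // Normal: (1 + mantissa / 4) * 2^(exponent - 15).
    magnitude = std::ldexp(1.0f + static_cast<float>(mantissa) / 4.0f, exponent - 15);
  }
  return sign ? -magnitude : magnitude;
}

// Simplified fp16 -> E5M2 conversion: keep the top byte of the half bit
// pattern. This truncates the low 8 mantissa bits instead of rounding.
inline std::uint8_t half_bits_to_e5m2_truncate(std::uint16_t half_bits)
{
  return static_cast<std::uint8_t>(half_bits >> 8);
}
```

With only two mantissa bits, neighbouring representable values near 1.0 are 0.25 apart, which is the kind of quantization the recall measurements requested in this review would need to quantify.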
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@cpp/src/neighbors/detail/cagra/compute_distance_vpq-impl.cuh`:
- Around line 245-249: The branch for num_packed_elements == 2 contains dead
outer casts and a redundant inner if: the expressions
static_cast<smem_val_t>(static_cast<float>(buf.x = mapping(queries_ptr[i])))
discard the cast result and only perform buf.x = mapping(...); fix by applying
the intended cast to the stored value (e.g. buf.x =
static_cast<smem_val_t>(static_cast<float>(mapping(queries_ptr[i])))) or simply
remove the casts entirely if unnecessary, and drop the inner if (i < dim)
because the surrounding loop already ensures i < dim; update both buf.x and
buf.y assignments accordingly (references: num_packed_elements, smem_val_t,
buf.x, buf.y, mapping, queries_ptr, i, dim).
- Around line 242-269: The half2 branch leaves parts of local buf
(smem_val_pack_t) uninitialized when dim % num_packed_elements != 0; initialize
the unused lane(s) before writing buf to shared memory to match the fp8 path's
zeroed lanes. In the num_packed_elements == 2 path (where buf is a half2), after
assigning buf.x and buf.y conditionally based on i and i+1, explicitly set the
unused lane to zero (using smem_val_t/zero cast or equivalent) when i or i+1 >=
dim so that the stored buf (written via
reinterpret_cast<smem_val_pack_t*>(smem_query_ptr)[transpose<...>(...)] or the
fallback write) contains no garbage; keep existing mapping(queries_ptr[...] )
assignments and do not change the transpose, compute_distance_vpq_worker,
PQ_BITS or PQ_LEN logic.
- Around line 386-411: Add unit tests covering the FP8 VQ-PQ path to validate
indexing and recall: create tests that run the kernel exercising the branch
where num_packed_elements == 4 and == 8 (with PQ_LEN=4 and PQ_LEN=8), feed known
vq_vals and pq_codebook inputs, and assert correct consumption of FP8 elements
by checking final vq_half2_index and end-to-end recall against a dense baseline;
also add a short comment near the loop that mentions the E5M2 precision tradeoff
and include measured recall deltas (e.g., SIFT-1M or DEEP-100M results for
PQ_LEN=4 and 8) in the PR/body so reviewers can see empirical impact.
- Around line 18-37: The template specialization for smem_val_type_t (the
partial specialization keyed by "PQ_LEN == 2 || !EnableFP8") causes
smem_val_type_t<2,true> to compile the same half2 path, so EnableFP8 is
effectively ignored for PQ_LEN==2 and doubles compile artifacts; fix by either
removing EnableFP8=true entries for PQ_LEN==2 from the test/matrix generator
(compute_distance_vpq_matrix.json) so those instantiations are not emitted, or
add a concise explanatory comment immediately above the smem_val_type_t
specializations documenting that the PQ_LEN==2 branch intentionally covers both
EnableFP8 values and that compute_distance_vpq.hpp’s priority function (the if
(use_fp8 != EnableFP8) check) filters the unwanted runtime case—choose one of
these two actions to avoid redundant compilation.
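The redundancy described in the last item can be reproduced with a small trait. This is a sketch with hypothetical stand-in types (`half2_like`, `fp8_pack_like`, `smem_val_sketch`), not the actual `smem_val_type_t` definition; it only shows that a partial specialization keyed on `PQ_LEN == 2 || !EnableFP8` yields the same type for both `EnableFP8` values when `PQ_LEN == 2`.

```cpp
#include <cassert>
#include <cstdint>
#include <type_traits>

// Hypothetical stand-ins for the device types.
struct half2_like {
  std::uint16_t x, y;
};
template <unsigned N>
struct fp8_pack_like {
  std::uint8_t v[N];
};

// Primary template: packed FP8 storage.
template <unsigned PQ_LEN, bool EnableFP8, typename = void>
struct smem_val_sketch {
  using pack_t = fp8_pack_like<PQ_LEN>;
};

// Partial specialization: half2-like storage whenever PQ_LEN == 2 or FP8 is
// disabled, so EnableFP8 has no effect for PQ_LEN == 2.
template <unsigned PQ_LEN, bool EnableFP8>
struct smem_val_sketch<PQ_LEN, EnableFP8, std::enable_if_t<PQ_LEN == 2 || !EnableFP8>> {
  using pack_t = half2_like;
};

// Compile-time check that the two PQ_LEN == 2 instantiations coincide.
static_assert(std::is_same_v<smem_val_sketch<2, true>::pack_t, smem_val_sketch<2, false>::pack_t>,
              "EnableFP8 is ignored for PQ_LEN == 2");
```

Because the two instantiations share one storage type, emitting both from the kernel matrix compiles equivalent code twice, which is the cost the comment asks to remove or document.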
In `@cpp/tests/neighbors/ann_cagra.cuh`:
- Around line 1697-1698: The inline comment next to the pq_len loop in
ann_cagra.cuh is stale — it says "only pq_len = 2 is supported" while the loop
now iterates {2,4,8}; update or remove that comment and similarly revise/remove
the identical stale remarks in the generate_addnode_inputs and
generate_filtering_inputs blocks so comments reflect actual supported PQ lengths
(or note why those functions still only use {2} if intentional); locate the
loops by searching for the pq_len variable and the function names
generate_addnode_inputs and generate_filtering_inputs to apply the fixes.
- Around line 466-504: The code assumes vpq datasets use half by doing
dynamic_cast to vpq_dataset<half, int64_t>& before decode_vpq_dataset; change
this to a type-safe check: replace the throwing reference dynamic_cast with a
pointer dynamic_cast to vpq_dataset<half, int64_t>* and verify it is non-null,
otherwise query the index's actual math type (or a provided math_type()/type()
accessor on index.data()) and call decode_vpq_dataset with the correct template
specialization (or emit a clear fatal error message stating the unexpected
codebook math type). Update the block around
decode_vpq_dataset/dynamic_cast/naive_knn (symbols: decode_vpq_dataset,
vpq_dataset, dynamic_cast, vpq_build, reference_recall) so the cast is validated
and the decode is chosen based on the runtime math type.
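The validated cast requested in the last item might look like the sketch below. The class names are hypothetical stand-ins (`dataset_base`, `vpq_dataset_like`, `strided_dataset_like`), and `float`/`double` stand in for the codebook math types; the point is the pointer form of `dynamic_cast`, which returns `nullptr` on mismatch and lets the caller report a specific error.

```cpp
#include <cassert>
#include <cstdint>
#include <stdexcept>

// Hypothetical polymorphic dataset hierarchy.
struct dataset_base {
  virtual ~dataset_base() = default;
};
template <typename MathT, typename IdxT>
struct vpq_dataset_like : dataset_base {};
struct strided_dataset_like : dataset_base {};

// Non-throwing query: does `data` hold a VPQ dataset with this math type?
template <typename MathT, typename IdxT>
bool is_vpq(const dataset_base& data)
{
  return dynamic_cast<const vpq_dataset_like<MathT, IdxT>*>(&data) != nullptr;
}

// Checked access: throws with a descriptive message instead of the generic
// std::bad_cast raised by a failed reference dynamic_cast.
template <typename MathT, typename IdxT>
const vpq_dataset_like<MathT, IdxT>& require_vpq(const dataset_base& data)
{
  const auto* vpq = dynamic_cast<const vpq_dataset_like<MathT, IdxT>*>(&data);
  if (vpq == nullptr) {
    throw std::invalid_argument("expected a VPQ dataset with the requested codebook math type");
  }
  return *vpq;
}
```

A test could try `is_vpq` for each supported math type in turn and dispatch the decode accordingly, falling back to `require_vpq` for the error path.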
In `@cpp/tests/neighbors/vpq_utils.cuh`:
- Around line 1-7: This header file is missing an include guard: add a
top-of-file header guard (preferably a single-line `#pragma once`) to prevent
multiple inclusion of the test header (vpq_utils.cuh) which defines
kernel/function templates in namespace cuvs::neighbors; simply insert `#pragma
once` at the very beginning of vpq_utils.cuh so the kernel and template
definitions in the cuvs::neighbors namespace are not redefined when included by
multiple translation units.
- Around line 9-38: The decoder kernel decode_vpq_dataset_kernel wrongly assumes
pq_bits==8 by reading PQ codes with pq_code_ptr[i / pq_subspace_dim]; either
enforce pq_bits==8 up front or implement proper bit-packed reads: add a
precondition/assertion that the vpq_params.pq_bits == 8 (and log/error if not)
or replace the single-byte read with bitfield unpacking using the existing
bitfield_view_t (or equivalent) to extract the pq_code for each subspace index
before indexing pq_codebook_ptr; update references in decode_vpq_dataset_kernel
(pq_code_ptr usage, loop that computes pq_code) to use the chosen fix.
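For the bit-packed alternative in the last item, a generic reader could be sketched as follows. The LSB-first layout is an assumption made for this illustration and may not match the layout used by cuVS or `bitfield_view_t`; `read_packed_code` is a hypothetical helper.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// Reads the idx-th code of width `bits` (1..32) from a byte stream in which
// codes are packed back to back, least significant bit first (assumed layout).
inline std::uint32_t read_packed_code(const std::uint8_t* data, std::size_t idx, unsigned bits)
{
  std::uint32_t code          = 0;
  const std::size_t first_bit = idx * bits;
  for (unsigned b = 0; b < bits; ++b) {
    const std::size_t bit     = first_bit + b;
    const std::uint32_t value = (data[bit / 8] >> (bit % 8)) & 0x1u;
    code |= value << b;
  }
  return code;
}
```

With `bits == 8` this reduces to the single-byte read the kernel currently performs, so the same helper covers both the enforced-precondition and the general case.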
---
Nitpick comments:
In `@cpp/src/neighbors/detail/cagra/compute_distance_vpq-impl.cuh`:
- Around line 211-231: Replace the silent conditional guard with a compile-time
check: add a static_assert that PQ_LEN >= num_elements_per_bank (and similarly
static_assert(PQ_LEN >= num_packed_elements) in the analogous guard inside
compute_distance_vpq_worker) so any unsupported PQ_LEN / FP8 combinations fail
at compile time instead of silently eliding the SMEM codebook copy; keep the
existing copy logic (the body that uses PQ_LEN, num_elements_per_bank,
num_packed_elements, codebook_buf, smem_val_pack_t, device::sts, and
r->pq_code_book_ptr()) unchanged, but remove or tighten the surrounding if
constexpr to rely on the static_assert to enforce the invariant.
- Around line 39-67: Add a compile-time guard to cagra_q_dataset_descriptor_t to
prevent invalid FP8 configurations: inside the template struct
cagra_q_dataset_descriptor_t (the type that defines kEnableFP8), add a
static_assert that checks !EnableFP8 || (PQ_LEN == 4 || PQ_LEN == 8) with a
clear message like "FP8 SMEM is only supported for PQ_LEN in {4, 8}" so
instantiations with EnableFP8=true and unsupported PQ_LEN (e.g., 2) fail at
compile time; reference cagra_q_dataset_descriptor_t, kEnableFP8, EnableFP8, and
PQ_LEN when locating where to add the assertion.
- Around line 300-326: The code has a dead conditional because
PQ_CODEBOOK_LOAD_T is locally typedef'd to uint32_t so the if
constexpr(std::is_same_v<PQ_CODEBOOK_LOAD_T, uint32_t>) always selects the
device::ldg_cg path and the else branch is unreachable; fix by either (A) making
PQ_CODEBOOK_LOAD_T a template parameter or trait used by compute_distance_vpq
(so the reinterpret_cast fallback can be meaningful for other widths), or (B)
remove the if constexpr and else branch and always use device::ldg_cg/pq_codes
assignment for the current uint32_t type to keep the code simple—update the code
around PQ_CODEBOOK_LOAD_T, the loading loop that writes pq_codes[e], and the
device::ldg_cg/dataset_ptr usage accordingly.
- Around line 222-230: The assignment in the codebook setup uses an unnecessary
intermediate float cast; change the line inside the loop that sets
buf.data.x1[k] so it converts directly from the source half value to smem_val_t
(i.e., use a direct static_cast or constructor from r->pq_code_book_ptr()[i + k]
to smem_val_t) instead of static_cast<smem_val_t>(static_cast<float>(...));
update the assignment in the loop that writes into buf (within the branch
handling num_packed_elements == 4 || 8) so buf.data.x1, smem_val_t,
r->pq_code_book_ptr(), and device::sts remain used but the extra float
conversion is removed.
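The compile-time guard proposed for the descriptor can be sketched as follows. `descriptor_sketch` is a hypothetical stand-in, not the real `cagra_q_dataset_descriptor_t`; only the `static_assert` pattern is illustrated, and the member names other than `kEnableFP8` are assumptions.

```cpp
#include <cstdint>

// Hypothetical stand-in for the dataset descriptor. An instantiation with
// EnableFP8 == true and an unsupported PQ_LEN fails to compile instead of
// silently skipping work at run time.
template <std::uint32_t PQ_LEN, bool EnableFP8>
struct descriptor_sketch {
  static_assert(!EnableFP8 || PQ_LEN == 4 || PQ_LEN == 8,
                "FP8 SMEM is only supported for PQ_LEN in {4, 8}");
  static constexpr bool kEnableFP8      = EnableFP8;
  static constexpr std::uint32_t kPqLen = PQ_LEN;
};

// Valid combinations compile; descriptor_sketch<2, true> would not.
using fp8_pq8_sketch = descriptor_sketch<8, true>;
using f16_pq2_sketch = descriptor_sketch<2, false>;
```

Placing the check in the type itself means every kernel that names the descriptor inherits the constraint without repeating it.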
In `@cpp/src/neighbors/detail/cagra/device_common.hpp`:
- Around line 254-289: The fp8xN<NumPacked, 5> specialization currently assumes
an even NumPacked and relies on uintN_t being defined only for 32/64-bit widths;
add a compile-time guard to prevent accidental odd or unsupported instantiations
by inserting a static_assert in the fp8xN<NumPacked,5> body (near the union/data
and methods) that enforces NumPacked % 2 == 0 and restricts NumPacked to the
supported sizes (e.g., NumPacked == 4 || NumPacked == 8), and mention uintN_t
and as_half2/data.x2 in the assertion message so users know the reason;
alternatively, explicitly specialize uintN_t for the required widths and
document the even-element requirement in fp8xN.
- Around line 357-365: The two sts overloads use redundant
reinterpret_casts—change the asm operand expressions to use x directly (i.e.,
replace reinterpret_cast<const uint32_t&>(x) and reinterpret_cast<const
uint64_t&>(x) with x) in the functions sts(uint32_t, const uint32_t&) and
sts(uint32_t, const uint64_t&) while keeping the existing asm volatile
constraints ("r" for u32 and "l" for u64) and signatures unchanged.
- Around line 323-328: The lds(uint8_t& x, uint32_t addr) function casts the
loaded 32-bit register to uint32_t before assigning to the uint8_t reference,
which is misleading; change the assignment to narrow explicitly by using
static_cast<uint8_t>(res) (or declare a uint8_t temp and assign that) so the
function's semantics match its signature and the narrowing is explicit.
In `@cpp/tests/neighbors/ann_cagra.cuh`:
- Line 563: The variable reference_recall is declared uninitialized; initialize
it at declaration (e.g., double reference_recall = 0.0;) to make the invariant
explicit and avoid potential UB if read before assignment—update the declaration
of reference_recall near the top of the file so testCagra() can still assign it
later but the variable has a safe default value.
- Line 503: Replace the raw printf call in ann_cagra.cuh that prints
reference_recall with the RAFT logger: change printf("reference_recall = %e\n",
reference_recall); to a RAFT_LOG_INFO call (e.g.,
RAFT_LOG_INFO("reference_recall = %e", reference_recall)); ensure the RAFT
logging header is included where needed so the logger symbol RAFT_LOG_INFO is
available.
In `@cpp/tests/neighbors/ann_utils.cuh`:
- Around line 289-290: The destructured variable index_based_actual_recall from
the calc_recall call in eval_neighbours is not used and can trigger
-Wunused-variable; update the destructuring to indicate it's intentionally
unused (e.g., mark index_based_actual_recall with [[maybe_unused]] or replace it
with a discard/ignore) so the binding remains for callers that consume the
second return (calc_recall) but avoids unused-variable warnings in
eval_neighbours.
- Around line 252-273: The second "Index based recall" triple loop duplicates
the earlier O(rows·cols²) traversal; instead, inside the original nested loops
where you compute match_count (the first double loop that iterates i, k and
inner j over expected_idx and checks distances), also check for index equality
(if act_idx == exp_idx) and increment index_match_count and break just like the
separate loop did, ensuring you only count once per (i,k); then remove the
duplicate loop entirely so index_match_count is updated in the same pass as
match_count (use the same variables actual_idx, expected_idx, act_idx, exp_idx,
index_match_count, match_count, total_count to locate and modify the code).
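A rough sketch of the merged single-pass loop, under the assumption that a neighbor counts as a match when either its index or its distance (within `eps`) agrees with some expected neighbor. `count_matches` and `recall_counts` are hypothetical names, and the match criterion should be checked against the real `calc_recall`. Note that the two counters need separate flags: breaking on the first distance match could otherwise hide a later index match.

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

struct recall_counts {
  std::size_t match_count;        // index OR distance agrees
  std::size_t index_match_count;  // index agrees
};

// Hypothetical single-pass version: both counters are updated in one
// traversal over (row, actual neighbor, expected neighbor).
inline recall_counts count_matches(const std::vector<int>& expected_idx,
                                   const std::vector<int>& actual_idx,
                                   const std::vector<float>& expected_dist,
                                   const std::vector<float>& actual_dist,
                                   std::size_t rows,
                                   std::size_t cols,
                                   float eps)
{
  recall_counts c{0, 0};
  for (std::size_t i = 0; i < rows; ++i) {
    for (std::size_t k = 0; k < cols; ++k) {
      const int act_idx    = actual_idx[i * cols + k];
      const float act_dist = actual_dist[i * cols + k];
      bool matched         = false;
      bool index_matched   = false;
      // An index match implies a match, so the scan can stop once it is found.
      for (std::size_t j = 0; j < cols && !index_matched; ++j) {
        const bool same_idx = act_idx == expected_idx[i * cols + j];
        const bool same_dist =
          std::fabs(act_dist - expected_dist[i * cols + j]) <= eps;
        index_matched = same_idx;
        matched       = matched || same_idx || same_dist;
      }
      c.match_count += matched ? 1 : 0;
      c.index_match_count += index_matched ? 1 : 0;
    }
  }
  return c;
}
```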
In `@cpp/tests/neighbors/vpq_utils.cuh`:
- Around line 25-28: The code reads a 4-byte vq_code via reinterpret_cast<const
uint32_t*>(local_data_ptr) (symbols: local_data_ptr, vq_code, data_ptr, ldi,
batch_id), but there is no guarantee that local_data_ptr is 4-byte aligned; add
an explicit runtime check in the host wrapper that prepares vpq data (e.g.,
validate vpq_dataset.data.stride(0) / ldi) using
RAFT_EXPECTS(vpq_dataset.data.stride(0) % 4 == 0, "stride must be 4-byte
aligned") so misaligned strides fail loudly; ensure this check runs before any
device/kernel launch that uses local_data_ptr and document the alignment
requirement in the wrapper's API comment.
- Around line 14-20: The kernel decode_vpq_dataset_kernel currently accepts
pq_table_size but never uses it; either remove pq_table_size from the kernel
signature and from any launch-site argument lists, or use it to validate decoded
PQ indices (e.g., check pq_code < pq_table_size before indexing the codebook) to
prevent out-of-bounds access. Locate the kernel function
decode_vpq_dataset_kernel and all call sites that pass the computed 1u <<
vpq_dataset.pq_bits(), then either delete the pq_table_size parameter from the
function signature and corresponding launches, or add a bounds-check against
pq_table_size where pq_code (or similar PQ index) is used to index the codebook.
Ensure consistency across all occurrences (also apply same change to the
analogous kernel referenced at lines 53-64).
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: a9bd4e85-0f4b-4457-b66b-991a60f896db
📒 Files selected for processing (13)
- cpp/CMakeLists.txt
- cpp/include/cuvs/neighbors/cagra.hpp
- cpp/src/neighbors/detail/cagra/cagra_search.cuh
- cpp/src/neighbors/detail/cagra/compute_distance_vpq-impl.cuh
- cpp/src/neighbors/detail/cagra/compute_distance_vpq.hpp
- cpp/src/neighbors/detail/cagra/compute_distance_vpq_inst.cu.in
- cpp/src/neighbors/detail/cagra/compute_distance_vpq_matrix.json
- cpp/src/neighbors/detail/cagra/device_common.hpp
- cpp/src/neighbors/detail/cagra/search_multi_cta_kernel-inl.cuh
- cpp/src/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh
- cpp/tests/neighbors/ann_cagra.cuh
- cpp/tests/neighbors/ann_utils.cuh
- cpp/tests/neighbors/vpq_utils.cuh
irina-resh-nvda
left a comment
Reviewed mostly the inner nest in compute_distance_vpq_worker (lines 311–417): the loops, the FP8 unpack via as_half2(bi), and the codebook-transpose math in setup_workspace_vpq. Built a minimal Compiler-Explorer reproducer of the consumer-side lds + unpack chain and inspected the SASS on sm_100. Confirmed one LDS.64 per call (with the cbook/query offset folded into the immediate), one F2FP.F16.E5M2.UNPACK_B per as_half2, and no LDL/STL.
Learned a lot from this one. LGTM.
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_planner_base.hpp`:
- Around line 163-176: The helper dispatch_cagra_smem_dtype must handle
internal_dtype::AUTO before the VPQ smem-tag switch; update
dispatch_cagra_smem_dtype to detect cuvs::neighbors::cagra::internal_dtype::AUTO
and remap it to the same concrete dtype used elsewhere (e.g.,
internal_dtype::E5M2) before or inside the switch, then continue to call the
lambda with tag_smem_e5m2 or tag_smem_f16 as appropriate (keep existing calls to
operator()<tag_smem_f16>() and operator()<tag_smem_e5m2>()); this ensures values
forwarded from cagra_jit_launcher_factory.hpp match the resolution logic in
compute_distance_vpq.hpp and avoid the RAFT_FAIL for AUTO.
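One way the AUTO handling could look, as a self-contained sketch. The enum, tags, and function names below are local stand-ins rather than the cuVS definitions, and resolving AUTO to E5M2 is an assumption taken from the "e.g." in the comment above; the real resolution rule lives in `compute_distance_vpq.hpp` and should be mirrored exactly.

```cpp
#include <stdexcept>
#include <type_traits>
#include <utility>

// Local stand-ins for the cuVS types named in the review comment.
enum class internal_dtype { F16, E5M2, AUTO };
struct tag_smem_f16 {};
struct tag_smem_e5m2 {};

// Assumption: AUTO resolves to E5M2; concrete values pass through unchanged.
inline internal_dtype resolve_smem_dtype(internal_dtype d)
{
  return d == internal_dtype::AUTO ? internal_dtype::E5M2 : d;
}

// Resolve AUTO first, then dispatch on the concrete dtype only.
template <typename Lambda>
void dispatch_smem_dtype_sketch(internal_dtype d, Lambda&& l)
{
  switch (resolve_smem_dtype(d)) {
    case internal_dtype::F16:
      std::forward<Lambda>(l).template operator()<tag_smem_f16>();
      return;
    case internal_dtype::E5M2:
      std::forward<Lambda>(l).template operator()<tag_smem_e5m2>();
      return;
    default: throw std::invalid_argument("unsupported smem dtype");
  }
}

// Small callable used to observe which tag was selected.
struct record_tag {
  int* out;
  template <typename Tag>
  void operator()() const
  {
    *out = std::is_same_v<Tag, tag_smem_e5m2> ? 1 : 0;
  }
};
```

Keeping the resolution in one function means the planner and the non-JIT path cannot drift apart on what AUTO means.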
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: e46b33e7-4a90-4f61-b1c3-f89f27d64bb7
📒 Files selected for processing (20)
- cpp/CMakeLists.txt
- cpp/include/cuvs/detail/jit_lto/cagra/cagra_fragments.hpp
- cpp/include/cuvs/neighbors/cagra.hpp
- cpp/src/neighbors/detail/cagra/cagra_search.cuh
- cpp/src/neighbors/detail/cagra/compute_distance.hpp
- cpp/src/neighbors/detail/cagra/compute_distance_vpq-impl.cuh
- cpp/src/neighbors/detail/cagra/compute_distance_vpq.hpp
- cpp/src/neighbors/detail/cagra/compute_distance_vpq_inst.cu.in
- cpp/src/neighbors/detail/cagra/compute_distance_vpq_matrix.json
- cpp/src/neighbors/detail/cagra/device_memory_ops.hpp
- cpp/src/neighbors/detail/cagra/factory.cuh
- cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_jit_launcher_factory.hpp
- cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_planner_base.hpp
- cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_impl.cuh
- cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_kernel.cu.in
- cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_matrix.json
- cpp/src/neighbors/detail/cagra/jit_lto_kernels/setup_workspace_impl.cuh
- cpp/src/neighbors/detail/cagra/jit_lto_kernels/setup_workspace_kernel.cu.in
- cpp/src/neighbors/detail/cagra/jit_lto_kernels/setup_workspace_matrix.json
- cpp/src/neighbors/detail/cagra/packed_type.hpp
💤 Files with no reviewable changes (7)
- cpp/src/neighbors/detail/cagra/jit_lto_kernels/setup_workspace_kernel.cu.in
- cpp/src/neighbors/detail/cagra/packed_type.hpp
- cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_matrix.json
- cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_impl.cuh
- cpp/src/neighbors/detail/cagra/jit_lto_kernels/setup_workspace_impl.cuh
- cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_kernel.cu.in
- cpp/src/neighbors/detail/cagra/jit_lto_kernels/setup_workspace_matrix.json
🚧 Files skipped from review as they are similar to previous changes (3)
- cpp/src/neighbors/detail/cagra/cagra_search.cuh
- cpp/include/cuvs/neighbors/cagra.hpp
- cpp/src/neighbors/detail/cagra/compute_distance_vpq_matrix.json
achirkin
left a comment
Thanks for working on this and sorry for such a long delay in the review process!
Also, I like the packed structure; maybe we should consolidate it with the vectorized types in raft in the future.
    device::ldg_cg(pq_codes[e],
                   reinterpret_cast<const PQ_CODEBOOK_LOAD_T*>(dataset_ptr + 4 + k));
    } else {
      pq_codes[e] = *reinterpret_cast<const PQ_CODEBOOK_LOAD_T*>(dataset_ptr + 4 + k);
Is there a specific reason not to use ldg_cg here? If it's just that the overload is missing, we can add that overload in https://github.com/rapidsai/cuvs/blob/main/cpp/src/neighbors/detail/cagra/device_memory_ops.hpp
    switch (dataset_block_dim) {
      case 128: std::forward<Lambda>(l).template operator()<8u, 128u>(); return;
      case 256: std::forward<Lambda>(l).template operator()<8u, 256u>(); return;
      case 512: std::forward<Lambda>(l).template operator()<8u, 512u>(); return;
      default: break;
Would be nice to reduce these repeated lines (e.g. by introducing a template struct for each template parameter like we do in https://github.com/rapidsai/cuvs/blob/main/cpp/src/neighbors/ivf_pq/ivf_pq_compute_similarity_impl.cuh).
But you follow an already established pattern in this file, so I don't insist we must do it here.
This PR:
E5M2 as smem data type
Using a lower-precision data type helps reduce shared memory bank conflicts and can improve throughput.
Since the quantization error from VQ+PQ is typically larger than the representation error of E5M2, the impact on search recall is expected to be negligible.
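To make the precision trade-off concrete: E5M2 uses 1 sign bit, 5 exponent bits, and 2 mantissa bits, so it has the same sign and exponent layout as IEEE half precision and corresponds to the upper byte of an FP16 value. The host-side sketch below works on raw bit patterns; the function names are hypothetical, and the truncating conversion is a simplification (a real kernel would typically use a rounding conversion).

```cpp
#include <cstdint>

// Hypothetical helpers operating on raw bit patterns, for illustration only.

// FP16 -> E5M2 by dropping the low 8 mantissa bits (truncation, no rounding).
inline std::uint8_t half_bits_to_e5m2_truncate(std::uint16_t half_bits)
{
  return static_cast<std::uint8_t>(half_bits >> 8);
}

// E5M2 -> FP16 is exact: the byte becomes the upper half of the FP16 pattern.
inline std::uint16_t e5m2_to_half_bits(std::uint8_t e5m2_bits)
{
  return static_cast<std::uint16_t>(static_cast<std::uint16_t>(e5m2_bits) << 8);
}
```

For example, 1.0 (`0x3C00`) and 1.5 (`0x3E00`) survive the round trip unchanged, while 1.125 (`0x3C80`) comes back as 1.0 because only two mantissa bits are kept.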
Support for PQ_LEN=8
The current cuVS implementation supports only PQ_LEN = 2 (4 bits per vector element) and 4 (2 bits per vector element).
This PR adds support for PQ_LEN = 8 to enable a higher compression ratio (1 bit per vector element).
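The bits-per-element figures follow from dividing the PQ code width by PQ_LEN. A small sketch of that arithmetic, assuming 8-bit PQ codes (which is what the quoted figures imply) and a dimension divisible by PQ_LEN; the function names are hypothetical.

```cpp
#include <cstdint>

// Bits spent per original vector element: one pq_bits-wide code covers
// pq_len consecutive elements.
constexpr double bits_per_element(std::uint32_t pq_bits, std::uint32_t pq_len)
{
  return static_cast<double>(pq_bits) / static_cast<double>(pq_len);
}

// Bytes of PQ codes per vector (excludes any VQ code or padding).
// Assumes dim is a multiple of pq_len.
constexpr std::uint32_t pq_code_bytes(std::uint32_t dim,
                                      std::uint32_t pq_len,
                                      std::uint32_t pq_bits)
{
  return ((dim / pq_len) * pq_bits + 7u) / 8u;
}
```

Under these assumptions a 128-dimensional vector needs 64, 32, and 16 bytes of PQ codes for PQ_LEN = 2, 4, and 8 respectively.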