Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 49 additions & 0 deletions rust/lance-linalg/src/distance/dot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,44 @@ pub fn dot<T: Dot>(from: &[T], to: &[T]) -> f32 {
T::dot(from, to)
}

/// Dot product between two f32 slices, dispatched to the widest SIMD backend
/// available at runtime. See [`crate::distance::l2::l2_f32`] for why this is
/// needed on top of the generic [`dot`].
#[inline]
pub fn dot_f32(x: &[f32], y: &[f32]) -> f32 {
#[cfg(target_arch = "x86_64")]
{
use lance_core::utils::cpu::SimdSupport;
if matches!(*SIMD_SUPPORT, SimdSupport::Avx512 | SimdSupport::Avx512FP16) {
// SAFETY: guarded by the runtime AVX-512 detection above.
return unsafe { dot_f32_avx512(x, y) };
}
}
dot(x, y)
}

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
unsafe fn dot_f32_avx512(x: &[f32], y: &[f32]) -> f32 {
use std::arch::x86_64::*;
debug_assert_eq!(x.len(), y.len());
let n = x.len();
let mut acc = _mm512_setzero_ps();
let mut i = 0usize;
while i + 16 <= n {
let a = _mm512_loadu_ps(x.as_ptr().add(i));
let b = _mm512_loadu_ps(y.as_ptr().add(i));
acc = _mm512_fmadd_ps(a, b, acc);
i += 16;
}
let mut sum = _mm512_reduce_add_ps(acc);
while i < n {
sum += x[i] * y[i];
i += 1;
}
sum
}

/// Negative [Dot] distance.
#[inline]
pub fn dot_distance<T: Dot>(from: &[T], to: &[T]) -> f32 {
Expand Down Expand Up @@ -329,6 +367,17 @@ mod tests {
use num_traits::{Float, FromPrimitive};
use proptest::prelude::*;

#[test]
fn test_dot_f32_dispatch_matches_scalar() {
use approx::assert_relative_eq;
// Covers tail handling for lengths around the 16-lane AVX-512 stride.
for dim in [1usize, 7, 15, 16, 17, 31, 33, 64, 100, 1024] {
let x: Vec<f32> = (0..dim).map(|i| (i as f32) * 0.5 - 3.0).collect();
let y: Vec<f32> = (0..dim).map(|i| (i as f32) * -0.25 + 1.5).collect();
assert_relative_eq!(dot_f32(&x, &y), dot(&x, &y), max_relative = 1e-5);
}
}

#[test]
fn test_dot() {
let x: Vec<f32> = (0..20).map(|v| v as f32).collect();
Expand Down
55 changes: 55 additions & 0 deletions rust/lance-linalg/src/distance/l2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,51 @@ pub fn l2<T: L2>(from: &[T], to: &[T]) -> f32 {
T::l2(from, to)
}

/// L2 distance between two f32 slices, dispatched to the widest SIMD backend
/// available at runtime.
///
/// On x86_64 with AVX-512 this uses 16-wide f32 lanes; otherwise it falls back
/// to [`l2`], which auto-vectorizes to the compiled target (AVX2 on the default
/// `haswell` build). Lance ships an AVX2-baseline binary, so the generic
/// [`l2`] never emits AVX-512 even on capable CPUs — this dispatcher recovers
/// that throughput for callers in the hot path (e.g. the in-memory HNSW index).
#[inline]
pub fn l2_f32(x: &[f32], y: &[f32]) -> f32 {
#[cfg(target_arch = "x86_64")]
{
use lance_core::utils::cpu::SimdSupport;
if matches!(*SIMD_SUPPORT, SimdSupport::Avx512 | SimdSupport::Avx512FP16) {
// SAFETY: guarded by the runtime AVX-512 detection above.
return unsafe { l2_f32_avx512(x, y) };
}
}
l2(x, y)
}

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
unsafe fn l2_f32_avx512(x: &[f32], y: &[f32]) -> f32 {
use std::arch::x86_64::*;
debug_assert_eq!(x.len(), y.len());
let n = x.len();
let mut acc = _mm512_setzero_ps();
let mut i = 0usize;
while i + 16 <= n {
let a = _mm512_loadu_ps(x.as_ptr().add(i));
let b = _mm512_loadu_ps(y.as_ptr().add(i));
let diff = _mm512_sub_ps(a, b);
acc = _mm512_fmadd_ps(diff, diff, acc);
i += 16;
}
let mut sum = _mm512_reduce_add_ps(acc);
while i < n {
let diff = x[i] - y[i];
sum += diff * diff;
i += 1;
}
sum
}

/// Calculate L2 distance between two uint8 slices.
#[inline]
pub fn l2_distance_uint_scalar(key: &[u8], target: &[u8]) -> f32 {
Expand Down Expand Up @@ -466,6 +511,16 @@ mod tests {
arbitrary_bf16, arbitrary_f16, arbitrary_f32, arbitrary_f64, arbitrary_vector_pair,
};

#[test]
fn test_l2_f32_dispatch_matches_scalar() {
// Covers tail handling for lengths around the 16-lane AVX-512 stride.
for dim in [1usize, 7, 15, 16, 17, 31, 33, 64, 100, 1024] {
let x: Vec<f32> = (0..dim).map(|i| (i as f32) * 0.5 - 3.0).collect();
let y: Vec<f32> = (0..dim).map(|i| (i as f32) * -0.25 + 1.5).collect();
assert_relative_eq!(l2_f32(&x, &y), l2(&x, &y), max_relative = 1e-5);
}
}

#[test]
fn test_euclidean_distance() {
let mat = FixedSizeListArray::from_iter_primitive::<Float32Type, _, _>(
Expand Down
28 changes: 17 additions & 11 deletions rust/lance/benches/mem_wal/vector/hnsw/mem_wal_hnsw_bench.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ struct Args {
seed: u64,
clusters: usize,
noise: f32,
query_repeats: usize,
}

impl Default for Args {
Expand All @@ -48,6 +49,7 @@ impl Default for Args {
seed: 100,
clusters: 4096,
noise: 0.05,
query_repeats: 1,
}
}
}
Expand Down Expand Up @@ -111,18 +113,21 @@ fn main() -> Result<(), Box<dyn Error>> {

let search_query_ids = query_ids(&args, args.queries);
let query_start = Instant::now();
let hits: usize = search_query_ids
.par_iter()
.map(|row| {
let query = snapshot.vector(*row as u32);
let results = graph
.search(query, SearchParams::new(args.k, args.ef_search), &snapshot)
.expect("search should succeed");
usize::from(results.iter().any(|result| result.id as usize == *row))
})
.sum();
let mut hits = 0usize;
for _ in 0..args.query_repeats {
hits = search_query_ids
.par_iter()
.map(|row| {
let query = snapshot.vector(*row as u32);
let results = graph
.search(query, SearchParams::new(args.k, args.ef_search), &snapshot)
.expect("search should succeed");
usize::from(results.iter().any(|result| result.id as usize == *row))
})
.sum();
}
let query_s = query_start.elapsed().as_secs_f64();
let query_qps = args.queries as f64 / query_s;
let query_qps = (args.queries * args.query_repeats) as f64 / query_s;
let self_recall = hits as f64 / args.queries as f64;

let truth_query_ids = query_ids(&args, args.truth_queries);
Expand Down Expand Up @@ -204,6 +209,7 @@ fn parse_args() -> Result<Args, Box<dyn Error>> {
"--seed" => args.seed = value.parse()?,
"--clusters" => args.clusters = value.parse()?,
"--noise" => args.noise = value.parse()?,
"--query-repeats" => args.query_repeats = value.parse()?,
_ => return Err(format!("unknown argument: {flag}").into()),
}
}
Expand Down
38 changes: 22 additions & 16 deletions rust/lance/benches/mem_wal/vector/hnsw/mem_wal_hnswlib_bench.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ struct Args {
uint64_t seed = 100;
size_t clusters = 4096;
float noise = 0.05f;
size_t query_repeats = 1;
};

Args parse_args(int argc, char **argv);
Expand Down Expand Up @@ -129,24 +130,27 @@ int main(int argc, char **argv) {
std::vector<size_t> queries = query_ids(args, args.queries);
auto query_start = clock_now();
std::atomic<size_t> hits{0};
parallel_for(0, queries.size(), args.threads, [&](size_t idx, size_t) {
size_t row = queries[idx];
std::priority_queue<std::pair<float, hnswlib::labeltype>> result =
index.searchKnn(data.data() + row * args.dim, args.k);
bool found = false;
while (!result.empty()) {
if (static_cast<size_t>(result.top().second) == row) {
found = true;
break;
for (size_t rep = 0; rep < args.query_repeats; ++rep) {
hits.store(0, std::memory_order_relaxed);
parallel_for(0, queries.size(), args.threads, [&](size_t idx, size_t) {
size_t row = queries[idx];
std::priority_queue<std::pair<float, hnswlib::labeltype>> result =
index.searchKnn(data.data() + row * args.dim, args.k);
bool found = false;
while (!result.empty()) {
if (static_cast<size_t>(result.top().second) == row) {
found = true;
break;
}
result.pop();
}
result.pop();
}
if (found) {
hits.fetch_add(1, std::memory_order_relaxed);
}
});
if (found) {
hits.fetch_add(1, std::memory_order_relaxed);
}
});
}
double query_s = elapsed_seconds(query_start);
double query_qps = static_cast<double>(args.queries) / query_s;
double query_qps = static_cast<double>(args.queries * args.query_repeats) / query_s;
double self_recall = static_cast<double>(hits.load()) / static_cast<double>(args.queries);

std::vector<size_t> truth_queries = query_ids(args, args.truth_queries);
Expand Down Expand Up @@ -241,6 +245,8 @@ Args parse_args(int argc, char **argv) {
args.clusters = parse_size(value);
} else if (flag == "--noise") {
args.noise = std::stof(value);
} else if (flag == "--query-repeats") {
args.query_repeats = parse_size(value);
} else {
throw std::invalid_argument("unknown argument: " + flag);
}
Expand Down
94 changes: 94 additions & 0 deletions rust/lance/benches/mem_wal/vector/hnsw/run_parity_suite.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
#!/usr/bin/env bash
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The Lance Authors
#
# Runs the Lance HNSW primitive vs hnswlib across memtable sizes (100k/500k/1M),
# capturing insert/query throughput, recall, peak RSS (/usr/bin/time -v) and CPU
# counters (perf stat, if available). Designed for the parity benchmark loop.

set -euo pipefail

SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd -P)"
# Cargo writes to the workspace target dir at the repo root; resolve it robustly.
REPO_ROOT="$(git -C "$SCRIPT_DIR" rev-parse --show-toplevel 2>/dev/null)"
REPO_ROOT="${REPO_ROOT:-$(cd "$SCRIPT_DIR/../../../../.." && pwd -P)}"
HNSWLIB_DIR="${HNSWLIB_DIR:-$HOME/oss/hnswlib}"
# Honor CARGO_TARGET_DIR so haswell vs target-cpu=native builds can coexist.
TARGET_DIR="${CARGO_TARGET_DIR:-$REPO_ROOT/target}"
OUT_DIR="${OUT_DIR:-$TARGET_DIR/parity_suite}"

SIZES="${SIZES:-100000 500000 1000000}"
DIM="${DIM:-1024}"
QUERIES="${QUERIES:-5000}"
QUERY_REPEATS="${QUERY_REPEATS:-20}"
TRUTH_QUERIES="${TRUTH_QUERIES:-200}"
K="${K:-10}"
M="${M:-12}"
EF_CONSTRUCTION="${EF_CONSTRUCTION:-64}"
EF_SEARCH="${EF_SEARCH:-64}"
THREADS="${THREADS:-$(getconf _NPROCESSORS_ONLN 2>/dev/null || sysctl -n hw.ncpu)}"
SEED="${SEED:-100}"
CLUSTERS="${CLUSTERS:-4096}"
NOISE="${NOISE:-0.05}"

if [ ! -d "$HNSWLIB_DIR/hnswlib" ]; then
echo "ERROR: HNSWLIB_DIR must point to a hnswlib checkout, got: $HNSWLIB_DIR" >&2
exit 1
fi

mkdir -p "$OUT_DIR" "$TARGET_DIR/release"

# Locate /usr/bin/time (GNU time, for peak RSS). Fall back to no wrapper.
TIME_BIN=""
if [ -x /usr/bin/time ]; then
TIME_BIN="/usr/bin/time"
fi
echo "=== Building Lance HNSW benchmark (release) ==="
cargo bench -p lance --bench mem_wal_hnsw_bench --no-run 2>&1 | tail -3
LANCE_BIN="$(find "$TARGET_DIR/release/deps" -maxdepth 1 -type f -perm -111 -name 'mem_wal_hnsw_bench-*' | sort | tail -n 1)"

echo "=== Building hnswlib benchmark (g++ -O3 -march=native) ==="
g++ -std=c++17 -O3 -march=native -DNDEBUG -pthread \
-I "$HNSWLIB_DIR" \
"$SCRIPT_DIR/mem_wal_hnswlib_bench.cpp" \
-o "$TARGET_DIR/release/hnswlib_bench"
HNSWLIB_BIN="$TARGET_DIR/release/hnswlib_bench"

run_one() {
local impl="$1" bin="$2" rows="$3"
local tag="${impl}_r${rows}"
local args=(
--rows "$rows" --dim "$DIM" --queries "$QUERIES" --truth-queries "$TRUTH_QUERIES"
--k "$K" --m "$M" --ef-construction "$EF_CONSTRUCTION" --ef-search "$EF_SEARCH"
--threads "$THREADS" --seed "$SEED" --clusters "$CLUSTERS" --noise "$NOISE"
--query-repeats "$QUERY_REPEATS"
)
local out="$OUT_DIR/${tag}.out"
local timef="$OUT_DIR/${tag}.time"
echo "--- run $tag ---"
if [ -n "$TIME_BIN" ]; then
"$TIME_BIN" -v "$bin" "${args[@]}" >"$out" 2>"$timef" || cat "$timef"
else
"$bin" "${args[@]}" >"$out" 2>&1
fi
grep -E '^(bench|result)' "$out" || true
if [ -f "$timef" ]; then
grep -E 'Maximum resident set size|Elapsed \(wall|Percent of CPU' "$timef" || true
fi
}

echo "=== Parity suite: sizes=[$SIZES] dim=$DIM threads=$THREADS ==="
for rows in $SIZES; do
run_one lance "$LANCE_BIN" "$rows"
run_one hnswlib "$HNSWLIB_BIN" "$rows"
done

echo "=== SUMMARY (json lines) ==="
grep -h '^{' "$OUT_DIR"/*.out 2>/dev/null || true
echo "=== peak RSS (KB) ==="
for f in "$OUT_DIR"/*.time; do
[ -f "$f" ] || continue
rss=$(grep 'Maximum resident set size' "$f" | grep -oE '[0-9]+' | head -1)
printf '%-28s %s\n' "$(basename "$f" .time)" "${rss:-NA}"
done
echo "=== results written to $OUT_DIR ==="
6 changes: 3 additions & 3 deletions rust/lance/src/dataset/mem_wal/hnsw/storage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use arrow_array::types::Float32Type;
use arrow_array::{Array, ArrayRef, FixedSizeListArray, Float32Array, RecordBatch, UInt64Array};
use arrow_schema::{DataType, Field, Schema as ArrowSchema, SchemaRef};
use lance_core::{Error, ROW_ID, Result};
use lance_linalg::distance::{DistanceType, Dot, L2, cosine_distance};
use lance_linalg::distance::{DistanceType, cosine_distance, dot_f32, l2_f32};

use super::graph::ScoredPoint;

Expand Down Expand Up @@ -79,8 +79,8 @@ pub trait VectorSource: Send + Sync {
/// distance kernels.
pub fn compute_f32_distance(query: &[f32], vector: &[f32], distance_type: DistanceType) -> f32 {
match distance_type {
DistanceType::L2 => f32::l2(query, vector),
DistanceType::Dot => f32::dot(query, vector),
DistanceType::L2 => l2_f32(query, vector),
DistanceType::Dot => dot_f32(query, vector),
DistanceType::Cosine => cosine_distance(query, vector),
DistanceType::Hamming => f32::INFINITY,
}
Expand Down
Loading