diff --git a/python/python/lance/dataset.py b/python/python/lance/dataset.py index 147b5b4b9e3..1b82e5298ed 100644 --- a/python/python/lance/dataset.py +++ b/python/python/lance/dataset.py @@ -3305,6 +3305,7 @@ def _create_index_impl( streaming_coreset_rate: Optional[int] = None, streaming_refine_passes: Optional[int] = None, skip_transpose: bool = False, + rabitq_model: Optional[str] = None, require_commit: bool = True, **kwargs, ) -> Index: @@ -3626,6 +3627,9 @@ def _create_index_impl( if skip_transpose: kwargs["skip_transpose"] = True + if rabitq_model is not None: + kwargs["rabitq_model"] = rabitq_model + # Add fragment_ids and index_uuid to kwargs if provided for # distributed indexing if fragment_ids is not None: @@ -3948,6 +3952,7 @@ def create_index_uncommitted( streaming_coreset_rate: Optional[int] = None, streaming_refine_passes: Optional[int] = None, skip_transpose: bool = False, + rabitq_model: Optional[str] = None, **kwargs, ) -> Index: """ @@ -3974,6 +3979,12 @@ def create_index_uncommitted( requirement: - ``fragment_ids`` must be provided + - ``rabitq_model`` (``IVF_RQ`` only): a JSON string produced by + ``lance.lance.indices.build_rq_model``. It must be identical across all + workers for their segments to be mergeable, since it pins the RaBitQ + rotation so every segment rotates vectors the same way. If omitted, each + call generates its own random rotation, which is only safe for a single, + non-merged segment. 
Returns ------- @@ -4006,6 +4017,7 @@ def create_index_uncommitted( streaming_coreset_rate=streaming_coreset_rate, streaming_refine_passes=streaming_refine_passes, skip_transpose=skip_transpose, + rabitq_model=rabitq_model, require_commit=False, **kwargs, ) diff --git a/python/python/lance/lance/indices/__init__.pyi b/python/python/lance/lance/indices/__init__.pyi index fc5d03b80bd..727469c4bc3 100644 --- a/python/python/lance/lance/indices/__init__.pyi +++ b/python/python/lance/lance/indices/__init__.pyi @@ -67,6 +67,11 @@ def transform_vectors( pq_codebook: pa.Array, dst_uri: str, ): ... +def build_rq_model( + dimension: int, + num_bits: int = 1, + dtype: str = "float32", +) -> str: ... class IndexSegmentDescription: uuid: str diff --git a/python/python/tests/test_vector_index.py b/python/python/tests/test_vector_index.py index 8898d8853ff..c74aaa28684 100644 --- a/python/python/tests/test_vector_index.py +++ b/python/python/tests/test_vector_index.py @@ -3035,6 +3035,51 @@ def test_commit_existing_index_segments_accepts_index_metadata(tmp_path): assert 0 < len(results) <= 5 +def test_distributed_ivf_rq_shared_rotation(tmp_path): + """Two IVF_RQ segments built on separate fragments with one shared RaBitQ rotation + merge into a single committed, queryable index. 
The shared ``rabitq_model`` (from + ``lance.lance.indices.build_rq_model``) is what makes the independently built + segments mergeable.""" + from lance.lance import indices + + dim = 32 + ds = _make_sample_dataset_base( + tmp_path, "dist_rq_merge", n_rows=512, dim=dim, max_rows_per_file=256 + ) + frags = ds.get_fragments() + assert len(frags) == 2 + + ivf_model = IndicesBuilder(ds, "vector").train_ivf( + num_partitions=2, + distance_type="l2", + sample_rate=8, + ) + rabitq_model = indices.build_rq_model(dimension=dim, num_bits=1) + base_kwargs = { + "column": "vector", + "index_type": "IVF_RQ", + "num_partitions": 2, + "num_bits": 1, + "ivf_centroids": ivf_model.centroids, + "rabitq_model": rabitq_model, + } + first = ds.create_index_uncommitted( + **base_kwargs, + fragment_ids=[frags[0].fragment_id], + ) + second = ds.create_index_uncommitted( + **base_kwargs, + fragment_ids=[frags[1].fragment_id], + ) + + merged = ds.merge_existing_index_segments([first, second]) + ds = ds.commit_existing_index_segments("vector_idx", "vector", [merged]) + + q = np.random.rand(dim).astype(np.float32) + results = ds.to_table(nearest={"column": "vector", "q": q, "k": 5}) + assert 0 < len(results) <= 5 + + def test_index_segment_builder_builds_vector_segments(tmp_path): ds = _make_sample_dataset_base(tmp_path, "segment_builder_ds", 2000, 128) frags = ds.get_fragments() diff --git a/python/src/dataset.rs b/python/src/dataset.rs index 45350e92109..2deeb610be0 100644 --- a/python/src/dataset.rs +++ b/python/src/dataset.rs @@ -20,6 +20,7 @@ use blob::LanceBlobFile; use chrono::{Duration, TimeDelta, Utc}; use futures::{StreamExt, TryFutureExt}; use lance_index::vector::bq::RQBuildParams; +use lance_index::vector::bq::storage::RabitQuantizationMetadata; use log::error; use object_store::path::Path; use pyo3::exceptions::{PyStopIteration, PyTypeError}; @@ -4439,6 +4440,13 @@ fn prepare_vector_index_params( pq_params.codebook = Some(codebook.values().clone()) }; + if let Some(r) = 
kwargs.get_item("rabitq_model")? { + let json: String = r.extract()?; + let meta: RabitQuantizationMetadata = serde_json::from_str(&json) + .map_err(|e| PyValueError::new_err(format!("Invalid rabitq_model JSON: {e}")))?; + rq_params.rotation = Some(meta); + }; + if let Some(version) = kwargs.get_item("index_file_version")? { let version: String = version.extract()?; index_file_version = IndexFileVersion::try_from(&version) diff --git a/python/src/indices.rs b/python/src/indices.rs index cf93579b867..a0fec3c8812 100644 --- a/python/src/indices.rs +++ b/python/src/indices.rs @@ -358,6 +358,61 @@ fn train_pq_model<'py>( codebook.to_pyarrow(py) } +/// Mint one RaBitQ rotation and return it as a JSON string. +/// +/// Distributed IVF_RQ builds must pin a single rotation across all workers so that +/// independently built per-fragment segments rotate vectors identically and their +/// binary codes remain comparable when merged. A driver calls this once and broadcasts +/// the resulting string to every `create_index_uncommitted(..., rabitq_model=...)` call. +/// +/// The rotation is always the "fast" rotation since its sign vector is JSON-serializable, +/// whereas the "matrix" rotation stores a dense matrix in a binary buffer that is dropped by +/// the JSON wire format. `dtype` is accepted for API symmetry but does not affect the fast +/// rotation. +/// +/// # Example (Python) +/// +/// ```python +/// from lance.lance import indices +/// +/// # Mint one model and broadcast `model` to every worker. 
+/// model = indices.build_rq_model(dimension=128, num_bits=1) +/// seg = ds.create_index_uncommitted( +/// column="vector", +/// index_type="IVF_RQ", +/// num_partitions=256, +/// ivf_centroids=centroids, +/// rabitq_model=model, +/// fragment_ids=my_fragments, +/// ) +/// ``` +#[pyfunction] +#[pyo3(signature = (dimension, num_bits=1, dtype="float32"))] +pub fn build_rq_model(dimension: usize, num_bits: u8, dtype: &str) -> PyResult<String> { + use arrow::datatypes::{Float16Type, Float32Type, Float64Type}; + use lance_index::vector::bq::builder::RabitQuantizer; + use lance_index::vector::bq::RQRotationType; + use lance_index::vector::quantizer::Quantization; + + if !dimension.is_multiple_of(u8::BITS as usize) { + return Err(PyValueError::new_err( + "dimension must be divisible by 8 for IVF_RQ", + )); + } + let dim = dimension as i32; + let rotation = RQRotationType::Fast; + let quantizer = match dtype.to_lowercase().as_str() { + "float16" => RabitQuantizer::new_with_rotation::<Float16Type>(num_bits, dim, rotation), + "float32" => RabitQuantizer::new_with_rotation::<Float32Type>(num_bits, dim, rotation), + "float64" => RabitQuantizer::new_with_rotation::<Float64Type>(num_bits, dim, rotation), + other => { + return Err(PyValueError::new_err(format!("unsupported dtype: {other}"))); + } + }; + serde_json::to_string(&quantizer.metadata(None)) + .map_err(|e| PyValueError::new_err(format!("failed to serialize RQ model: {e}"))) +} + #[allow(clippy::too_many_arguments)] async fn do_transform_vectors( dataset: &Dataset, @@ -752,6 +807,7 @@ pub fn register_indices(py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> { let indices = PyModule::new(py, "indices")?; indices.add_wrapped(wrap_pyfunction!(train_ivf_model))?; indices.add_wrapped(wrap_pyfunction!(train_pq_model))?; + indices.add_wrapped(wrap_pyfunction!(build_rq_model))?; indices.add_wrapped(wrap_pyfunction!(transform_vectors))?; indices.add_wrapped(wrap_pyfunction!(shuffle_transformed_vectors))?; indices.add_wrapped(wrap_pyfunction!(load_shuffled_vectors))?; diff
--git a/rust/lance-index/src/vector/bq.rs b/rust/lance-index/src/vector/bq.rs index 1df04d4b134..51439e2c905 100644 --- a/rust/lance-index/src/vector/bq.rs +++ b/rust/lance-index/src/vector/bq.rs @@ -14,6 +14,7 @@ use lance_core::{Error, Result}; use num_traits::Float; use serde::{Deserialize, Serialize}; +use crate::vector::bq::storage::RabitQuantizationMetadata; use crate::vector::quantizer::QuantizerBuildParams; pub mod builder; @@ -104,10 +105,16 @@ impl FromStr for RQRotationType { } } -#[derive(Clone, Debug, PartialEq, Eq)] +#[derive(Clone, Debug)] pub struct RQBuildParams { pub num_bits: u8, pub rotation_type: RQRotationType, + /// Optional pre-built rotation to reuse instead of generating a fresh random one. + /// + /// Distributed `IVF_RQ` builds mint one rotation and broadcast it so every segment + /// rotates vectors identically. This is transient build-time state and is never + /// persisted to the `RabitQuantization` params proto. + pub rotation: Option<RabitQuantizationMetadata>, } pub fn validate_rq_num_bits(num_bits: u8) -> Result<()> { @@ -155,6 +162,7 @@ impl RQBuildParams { Self { num_bits, rotation_type: RQRotationType::default(), + rotation: None, } } @@ -162,6 +170,7 @@ impl RQBuildParams { Self { num_bits, rotation_type, + rotation: None, } } } @@ -190,6 +199,7 @@ impl Default for RQBuildParams { Self { num_bits: 1, rotation_type: RQRotationType::default(), + rotation: None, } } } diff --git a/rust/lance-index/src/vector/bq/builder.rs b/rust/lance-index/src/vector/bq/builder.rs index f98c370aefc..70e084472d7 100644 --- a/rust/lance-index/src/vector/bq/builder.rs +++ b/rust/lance-index/src/vector/bq/builder.rs @@ -23,7 +23,7 @@ use crate::vector::bq::storage::{ use crate::vector::bq::transform::{ADD_FACTORS_FIELD, SCALE_FACTORS_FIELD}; use crate::vector::bq::{ RQBuildParams, RQRotationType, rabit_binary_code_bytes, - rotation::{apply_fast_rotation, random_fast_rotation_signs}, + rotation::{apply_fast_rotation, fast_rotation_signs_len, random_fast_rotation_signs},
validate_supported_rq_num_bits, }; use crate::vector::quantizer::{Quantization, Quantizer, QuantizerBuildParams}; @@ -329,6 +329,46 @@ impl Quantization for RabitQuantizer { )); } + // Reuse a supplied rotation instead of generating a fresh random one. + if let Some(meta) = &params.rotation { + let expected_code_dim = dim * params.num_bits as usize; + if meta.num_bits != params.num_bits || meta.code_dim as usize != expected_code_dim { + return Err(Error::invalid_input(format!( + "supplied RaBitQ rotation does not match build params: rotation \ + num_bits={}, code_dim={}; expected num_bits={}, code_dim={}", + meta.num_bits, meta.code_dim, params.num_bits, expected_code_dim + ))); + } + + match meta.rotation_type { + RQRotationType::Fast => { + let signs = meta.fast_rotation_signs.as_ref().ok_or_else(|| { + Error::invalid_input("supplied fast RaBitQ rotation is missing signs") + })?; + let expected_len = fast_rotation_signs_len(meta.code_dim as usize); + if signs.len() != expected_len { + return Err(Error::invalid_input(format!( + "supplied fast RaBitQ rotation signs length {} does not match \ + expected {} for code_dim={}", + signs.len(), + expected_len, + meta.code_dim + ))); + } + } + RQRotationType::Matrix => { + if meta.rotate_mat.is_none() { + return Err(Error::invalid_input( + "use the fast rotation for distributed builds", + )); + } + } + } + return Ok(Self { + metadata: meta.clone(), + }); + } + let q = match data.as_fixed_size_list().value_type() { DataType::Float16 => Self::new_with_rotation::<Float16Type>( params.num_bits, @@ -594,6 +634,124 @@ mod tests { ); } + fn sample_fsl(n: usize, dim: usize) -> FixedSizeListArray { + let values: Vec<f32> = (0..n * dim).map(|i| ((i * 31 % 17) as f32) - 8.0).collect(); + FixedSizeListArray::try_new_from_values(Float32Array::from(values), dim as i32).unwrap() + } + + fn quantized_codes(q: &RabitQuantizer, data: &FixedSizeListArray) -> Vec<u8> { + use arrow::datatypes::UInt8Type; + q.quantize(data) + .unwrap() + .as_fixed_size_list() +
.values() + .as_primitive::<UInt8Type>() + .values() + .to_vec() + } + + #[test] + fn test_shared_fast_rotation_gives_identical_codes() { + let dim = 32; + let seed = RabitQuantizer::new_with_rotation::<Float32Type>(1, dim, RQRotationType::Fast); + let json = serde_json::to_string(&seed.metadata(None)).unwrap(); + let meta: RabitQuantizationMetadata = serde_json::from_str(&json).unwrap(); + + let params = RQBuildParams { + num_bits: 1, + rotation_type: RQRotationType::Fast, + rotation: Some(meta), + }; + let data = sample_fsl(8, dim as usize); + let q_a = RabitQuantizer::build(&data, DistanceType::L2, &params).unwrap(); + let q_b = RabitQuantizer::build(&data, DistanceType::L2, &params).unwrap(); + + assert_eq!( + quantized_codes(&q_a, &data), + quantized_codes(&q_b, &data), + "shared rotation must yield identical codes" + ); + } + + #[test] + fn test_unpinned_rotation_gives_different_codes() { + let dim = 32; + let params = RQBuildParams::new(1); + let data = sample_fsl(8, dim as usize); + let q_a = RabitQuantizer::build(&data, DistanceType::L2, &params).unwrap(); + let q_b = RabitQuantizer::build(&data, DistanceType::L2, &params).unwrap(); + + assert_ne!( + quantized_codes(&q_a, &data), + quantized_codes(&q_b, &data), + "independent unpinned rotations must yield different codes" + ); + } + + #[test] + fn test_build_rejects_rotation_with_mismatched_code_dim() { + let seed = RabitQuantizer::new_with_rotation::<Float32Type>(1, 16, RQRotationType::Fast); + let params = RQBuildParams { + num_bits: 1, + rotation_type: RQRotationType::Fast, + rotation: Some(seed.metadata(None)), + }; + let data = sample_fsl(4, 32); + let err = RabitQuantizer::build(&data, DistanceType::L2, &params).unwrap_err(); + assert!( + err.to_string().contains("does not match build params"), + "{}", + err + ); + } + + #[test] + fn test_build_rejects_fast_rotation_with_bad_signs_length() { + let dim = 16; + let seed = RabitQuantizer::new_with_rotation::<Float32Type>(1, dim, RQRotationType::Fast); + let mut meta = seed.metadata(None); + // Corrupt the signs to the
wrong length (valid would be 4 * ceil(16/8) = 8). + meta.fast_rotation_signs = Some(vec![0u8; 7]); + let params = RQBuildParams { + num_bits: 1, + rotation_type: RQRotationType::Fast, + rotation: Some(meta), + }; + let data = sample_fsl(4, dim as usize); + let err = RabitQuantizer::build(&data, DistanceType::L2, &params).unwrap_err(); + assert!(err.to_string().contains("signs length"), "{}", err); + } + + #[test] + fn test_matrix_rotation_lost_through_json_is_rejected() { + let dim = 16; + let seed = RabitQuantizer::new_with_rotation::<Float32Type>(1, dim, RQRotationType::Matrix); + let meta = seed.metadata(None); + assert!(meta.rotate_mat.is_some()); + + let json = serde_json::to_string(&meta).unwrap(); + let parsed: RabitQuantizationMetadata = serde_json::from_str(&json).unwrap(); + assert_eq!(parsed.rotation_type, RQRotationType::Matrix); + assert!( + parsed.rotate_mat.is_none(), + "matrix is expected to be dropped by JSON serialization" + ); + + let params = RQBuildParams { + num_bits: 1, + rotation_type: RQRotationType::Matrix, + rotation: Some(parsed), + }; + let data = sample_fsl(4, dim as usize); + let err = RabitQuantizer::build(&data, DistanceType::L2, &params).unwrap_err(); + assert!( + err.to_string() + .contains("fast rotation for distributed builds"), + "{}", + err + ); + } + #[test] fn test_rabit_quantizer_rejects_unsupported_num_bits() { let vectors = Float32Array::from(vec![0.0f32; 4 * 32]); diff --git a/rust/lance-index/src/vector/bq/rotation.rs b/rust/lance-index/src/vector/bq/rotation.rs index 4f4895ac198..de8c8acccb3 100644 --- a/rust/lance-index/src/vector/bq/rotation.rs +++ b/rust/lance-index/src/vector/bq/rotation.rs @@ -138,9 +138,13 @@ fn sign_bytes_per_round(dim: usize) -> usize { dim.div_ceil(8) } +pub(crate) fn fast_rotation_signs_len(dim: usize) -> usize { + FAST_ROTATION_ROUNDS * sign_bytes_per_round(dim) +} + pub fn random_fast_rotation_signs(dim: usize) -> Vec<u8> { // Each round needs one random sign bit per dimension.
- let mut signs = vec![0u8; FAST_ROTATION_ROUNDS * sign_bytes_per_round(dim)]; + let mut signs = vec![0u8; fast_rotation_signs_len(dim)]; rand::rng().fill_bytes(&mut signs); signs } diff --git a/rust/lance/src/index/vector.rs b/rust/lance/src/index/vector.rs index ff7d2383c67..04c9d31a0e8 100644 --- a/rust/lance/src/index/vector.rs +++ b/rust/lance/src/index/vector.rs @@ -1880,6 +1880,7 @@ fn derive_rabit_params(rabit_quantizer: &RabitQuantizer) -> RQBuildParams { RQBuildParams { num_bits: rabit_quantizer.num_bits(), rotation_type: rabit_quantizer.rotation_type(), + rotation: None, } }