From a5bd8a8126dbbe40f6ae44185b0dc4d6ddfa14de Mon Sep 17 00:00:00 2001 From: gstamatakis95 <126914070+gstamatakis95@users.noreply.github.com> Date: Sun, 31 May 2026 12:47:12 +0200 Subject: [PATCH 1/2] feat(python): add shared RaBitQ rotation for distributed IVF_RQ builds --- python/python/lance/dataset.py | 12 ++ .../python/lance/lance/indices/__init__.pyi | 6 + python/python/tests/test_vector_index.py | 45 +++++ python/src/dataset.rs | 8 + python/src/indices.rs | 74 ++++++++ rust/lance-index/src/vector/bq.rs | 12 +- rust/lance-index/src/vector/bq/builder.rs | 160 +++++++++++++++++- rust/lance-index/src/vector/bq/rotation.rs | 6 +- rust/lance/src/index/vector.rs | 1 + 9 files changed, 321 insertions(+), 3 deletions(-) diff --git a/python/python/lance/dataset.py b/python/python/lance/dataset.py index 4cbd39fdebb..a3aea17c6b2 100644 --- a/python/python/lance/dataset.py +++ b/python/python/lance/dataset.py @@ -3304,6 +3304,7 @@ def _create_index_impl( *, target_partition_size: Optional[int] = None, skip_transpose: bool = False, + rq_rotation: Optional[str] = None, require_commit: bool = True, **kwargs, ) -> Index: @@ -3619,6 +3620,9 @@ def _create_index_impl( if skip_transpose: kwargs["skip_transpose"] = True + if rq_rotation is not None: + kwargs["rq_rotation"] = rq_rotation + # Add fragment_ids and index_uuid to kwargs if provided for # distributed indexing if fragment_ids is not None: @@ -3919,6 +3923,7 @@ def create_index_uncommitted( *, target_partition_size: Optional[int] = None, skip_transpose: bool = False, + rq_rotation: Optional[str] = None, **kwargs, ) -> Index: """ @@ -3945,6 +3950,12 @@ def create_index_uncommitted( requirement: - ``fragment_ids`` must be provided + - ``rq_rotation`` (``IVF_RQ`` only): a JSON string produced by + ``lance.lance.indices.build_rq_rotation``. It must be identical across all + workers for their segments to be mergeable, since it pins the RaBitQ + rotation so every segment rotates vectors the same way. 
If omitted, each + call generates its own random rotation, which is only safe for a single, + non-merged segment. Returns ------- @@ -3974,6 +3985,7 @@ def create_index_uncommitted( index_uuid=index_uuid, target_partition_size=target_partition_size, skip_transpose=skip_transpose, + rq_rotation=rq_rotation, require_commit=False, **kwargs, ) diff --git a/python/python/lance/lance/indices/__init__.pyi b/python/python/lance/lance/indices/__init__.pyi index fc5d03b80bd..8afb1761f86 100644 --- a/python/python/lance/lance/indices/__init__.pyi +++ b/python/python/lance/lance/indices/__init__.pyi @@ -67,6 +67,12 @@ def transform_vectors( pq_codebook: pa.Array, dst_uri: str, ): ... +def build_rq_rotation( + dimension: int, + num_bits: int = 1, + rotation_type: str = "fast", + dtype: str = "float32", +) -> str: ... class IndexSegmentDescription: uuid: str diff --git a/python/python/tests/test_vector_index.py b/python/python/tests/test_vector_index.py index 4cf1c4947e3..49ec6c28501 100644 --- a/python/python/tests/test_vector_index.py +++ b/python/python/tests/test_vector_index.py @@ -3023,6 +3023,51 @@ def test_commit_existing_index_segments_accepts_index_metadata(tmp_path): assert 0 < len(results) <= 5 +def test_distributed_ivf_rq_shared_rotation(tmp_path): + """Two IVF_RQ segments built on separate fragments with one shared RaBitQ rotation + merge into a single committed, queryable index. 
The shared ``rq_rotation`` (from + ``lance.lance.indices.build_rq_rotation``) is what makes the independently built + segments mergeable.""" + from lance.lance import indices + + dim = 32 + ds = _make_sample_dataset_base( + tmp_path, "dist_rq_merge", n_rows=512, dim=dim, max_rows_per_file=256 + ) + frags = ds.get_fragments() + assert len(frags) == 2 + + ivf_model = IndicesBuilder(ds, "vector").train_ivf( + num_partitions=2, + distance_type="l2", + sample_rate=8, + ) + rq_rotation = indices.build_rq_rotation(dimension=dim, num_bits=1) + base_kwargs = { + "column": "vector", + "index_type": "IVF_RQ", + "num_partitions": 2, + "num_bits": 1, + "ivf_centroids": ivf_model.centroids, + "rq_rotation": rq_rotation, + } + first = ds.create_index_uncommitted( + **base_kwargs, + fragment_ids=[frags[0].fragment_id], + ) + second = ds.create_index_uncommitted( + **base_kwargs, + fragment_ids=[frags[1].fragment_id], + ) + + merged = ds.merge_existing_index_segments([first, second]) + ds = ds.commit_existing_index_segments("vector_idx", "vector", [merged]) + + q = np.random.rand(dim).astype(np.float32) + results = ds.to_table(nearest={"column": "vector", "q": q, "k": 5}) + assert 0 < len(results) <= 5 + + def test_index_segment_builder_builds_vector_segments(tmp_path): ds = _make_sample_dataset_base(tmp_path, "segment_builder_ds", 2000, 128) frags = ds.get_fragments() diff --git a/python/src/dataset.rs b/python/src/dataset.rs index 0f834484f30..3143854d50a 100644 --- a/python/src/dataset.rs +++ b/python/src/dataset.rs @@ -20,6 +20,7 @@ use blob::LanceBlobFile; use chrono::{Duration, TimeDelta, Utc}; use futures::{StreamExt, TryFutureExt}; use lance_index::vector::bq::RQBuildParams; +use lance_index::vector::bq::storage::RabitQuantizationMetadata; use log::error; use object_store::path::Path; use pyo3::exceptions::{PyStopIteration, PyTypeError}; @@ -4429,6 +4430,13 @@ fn prepare_vector_index_params( pq_params.codebook = Some(codebook.values().clone()) }; + if let Some(r) = 
kwargs.get_item("rq_rotation")? { + let json: String = r.extract()?; + let meta: RabitQuantizationMetadata = serde_json::from_str(&json) + .map_err(|e| PyValueError::new_err(format!("Invalid rq_rotation JSON: {e}")))?; + rq_params.rotation = Some(meta); + }; + if let Some(version) = kwargs.get_item("index_file_version")? { let version: String = version.extract()?; index_file_version = IndexFileVersion::try_from(&version) diff --git a/python/src/indices.rs b/python/src/indices.rs index cf93579b867..7a7667fbc18 100644 --- a/python/src/indices.rs +++ b/python/src/indices.rs @@ -358,6 +358,79 @@ fn train_pq_model<'py>( codebook.to_pyarrow(py) } +/// Mint one RaBitQ rotation and return it as a JSON string. +/// +/// Distributed IVF_RQ builds must pin a single rotation across all workers so that +/// independently built per-fragment segments rotate vectors identically and their +/// binary codes remain comparable when merged. A driver calls this once and broadcasts +/// the resulting string to every `create_index_uncommitted(..., rq_rotation=...)` call. +/// +/// Only the "fast" rotation is supported since its sign vector is JSON-serializable, whereas +/// the "matrix" rotation stores a dense matrix in a binary buffer that is dropped by the +/// JSON wire format. `dtype` is accepted for API symmetry but does not affect the fast +/// rotation. +/// +/// # Example (Python) +/// +/// ```python +/// from lance.lance import indices +/// +/// # Mint one rotation and broadcast `rot` to every worker. 
+/// rot = indices.build_rq_rotation(dimension=128, num_bits=1) +/// seg = ds.create_index_uncommitted( +/// column="vector", +/// index_type="IVF_RQ", +/// num_partitions=256, +/// ivf_centroids=centroids, +/// rq_rotation=rot, +/// fragment_ids=my_fragments, +/// ) +/// ``` +#[pyfunction] +#[pyo3(signature = (dimension, num_bits=1, rotation_type="fast", dtype="float32"))] +pub fn build_rq_rotation( + dimension: usize, + num_bits: u8, + rotation_type: &str, + dtype: &str, +) -> PyResult<String> { + use arrow::datatypes::{Float16Type, Float32Type, Float64Type}; + use lance_index::vector::bq::builder::RabitQuantizer; + use lance_index::vector::bq::RQRotationType; + use lance_index::vector::quantizer::Quantization; + + if !dimension.is_multiple_of(u8::BITS as usize) { + return Err(PyValueError::new_err( + "dimension must be divisible by 8 for IVF_RQ", + )); + } + let rotation = match rotation_type.to_lowercase().as_str() { + "fast" => RQRotationType::Fast, + "matrix" => { + return Err(PyValueError::new_err( + "matrix rotation cannot be serialized to JSON for distributed builds; \ + use rotation_type='fast'", + )); + } + other => { + return Err(PyValueError::new_err(format!( + "unknown rotation_type: {other}; expected 'fast'" + ))); + } + }; + let dim = dimension as i32; + let quantizer = match dtype.to_lowercase().as_str() { + "float16" => RabitQuantizer::new_with_rotation::<Float16Type>(num_bits, dim, rotation), + "float32" => RabitQuantizer::new_with_rotation::<Float32Type>(num_bits, dim, rotation), + "float64" => RabitQuantizer::new_with_rotation::<Float64Type>(num_bits, dim, rotation), + other => { + return Err(PyValueError::new_err(format!("unsupported dtype: {other}"))); + } + }; + serde_json::to_string(&quantizer.metadata(None)) + .map_err(|e| PyValueError::new_err(format!("failed to serialize RQ rotation: {e}"))) +} + #[allow(clippy::too_many_arguments)] async fn do_transform_vectors( dataset: &Dataset, @@ -752,6 +825,7 @@ pub fn register_indices(py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> { let 
indices = PyModule::new(py, "indices")?; indices.add_wrapped(wrap_pyfunction!(train_ivf_model))?; indices.add_wrapped(wrap_pyfunction!(train_pq_model))?; + indices.add_wrapped(wrap_pyfunction!(build_rq_rotation))?; indices.add_wrapped(wrap_pyfunction!(transform_vectors))?; indices.add_wrapped(wrap_pyfunction!(shuffle_transformed_vectors))?; indices.add_wrapped(wrap_pyfunction!(load_shuffled_vectors))?; diff --git a/rust/lance-index/src/vector/bq.rs b/rust/lance-index/src/vector/bq.rs index a0a16b22169..62de70f2bf3 100644 --- a/rust/lance-index/src/vector/bq.rs +++ b/rust/lance-index/src/vector/bq.rs @@ -14,6 +14,7 @@ use lance_core::{Error, Result}; use num_traits::Float; use serde::{Deserialize, Serialize}; +use crate::vector::bq::storage::RabitQuantizationMetadata; use crate::vector::quantizer::QuantizerBuildParams; pub mod builder; @@ -100,10 +101,16 @@ impl FromStr for RQRotationType { } } -#[derive(Clone, Debug, PartialEq, Eq)] +#[derive(Clone, Debug)] pub struct RQBuildParams { pub num_bits: u8, pub rotation_type: RQRotationType, + /// Optional pre-built rotation to reuse instead of generating a fresh random one. + /// + /// Distributed `IVF_RQ` builds mint one rotation and broadcast it so every segment + /// rotates vectors identically. This is transient build-time state and is never + /// persisted to the `RabitQuantization` params proto. 
+ pub rotation: Option<RabitQuantizationMetadata>, } impl RQBuildParams { @@ -111,6 +118,7 @@ impl RQBuildParams { Self { num_bits, rotation_type: RQRotationType::default(), + rotation: None, } } @@ -118,6 +126,7 @@ impl RQBuildParams { Self { num_bits, rotation_type, + rotation: None, } } } @@ -146,6 +155,7 @@ impl Default for RQBuildParams { Self { num_bits: 1, rotation_type: RQRotationType::default(), + rotation: None, } } } diff --git a/rust/lance-index/src/vector/bq/builder.rs b/rust/lance-index/src/vector/bq/builder.rs index 491e14d3af9..c61c35374e1 100644 --- a/rust/lance-index/src/vector/bq/builder.rs +++ b/rust/lance-index/src/vector/bq/builder.rs @@ -22,7 +22,7 @@ use crate::vector::bq::storage::{ use crate::vector::bq::transform::{ADD_FACTORS_FIELD, SCALE_FACTORS_FIELD}; use crate::vector::bq::{ RQBuildParams, RQRotationType, - rotation::{apply_fast_rotation, random_fast_rotation_signs}, + rotation::{apply_fast_rotation, fast_rotation_signs_len, random_fast_rotation_signs}, }; use crate::vector::quantizer::{Quantization, Quantizer, QuantizerBuildParams}; @@ -324,6 +324,46 @@ impl Quantization for RabitQuantizer { )); } + // Reuse a supplied rotation instead of generating a fresh random one. 
+ if let Some(meta) = &params.rotation { + let expected_code_dim = dim * params.num_bits as usize; + if meta.num_bits != params.num_bits || meta.code_dim as usize != expected_code_dim { + return Err(Error::invalid_input(format!( + "supplied RaBitQ rotation does not match build params: rotation \ + num_bits={}, code_dim={}; expected num_bits={}, code_dim={}", + meta.num_bits, meta.code_dim, params.num_bits, expected_code_dim + ))); + } + + match meta.rotation_type { + RQRotationType::Fast => { + let signs = meta.fast_rotation_signs.as_ref().ok_or_else(|| { + Error::invalid_input("supplied fast RaBitQ rotation is missing signs") + })?; + let expected_len = fast_rotation_signs_len(meta.code_dim as usize); + if signs.len() != expected_len { + return Err(Error::invalid_input(format!( + "supplied fast RaBitQ rotation signs length {} does not match \ + expected {} for code_dim={}", + signs.len(), + expected_len, + meta.code_dim + ))); + } + } + RQRotationType::Matrix => { + if meta.rotate_mat.is_none() { + return Err(Error::invalid_input( + "use the fast rotation for distributed builds", + )); + } + } + } + return Ok(Self { + metadata: meta.clone(), + }); + } + let q = match data.as_fixed_size_list().value_type() { DataType::Float16 => Self::new_with_rotation::<Float16Type>( params.num_bits, @@ -582,4 +622,122 @@ mod tests { err ); } + + fn sample_fsl(n: usize, dim: usize) -> FixedSizeListArray { + let values: Vec<f32> = (0..n * dim).map(|i| ((i * 31 % 17) as f32) - 8.0).collect(); + FixedSizeListArray::try_new_from_values(Float32Array::from(values), dim as i32).unwrap() + } + + fn quantized_codes(q: &RabitQuantizer, data: &FixedSizeListArray) -> Vec<u8> { + use arrow::datatypes::UInt8Type; + q.quantize(data) + .unwrap() + .as_fixed_size_list() + .values() + .as_primitive::<UInt8Type>() + .values() + .to_vec() + } + + #[test] + fn test_shared_fast_rotation_gives_identical_codes() { + let dim = 32; + let seed = RabitQuantizer::new_with_rotation::<Float32Type>(1, dim, RQRotationType::Fast); + let json = 
serde_json::to_string(&seed.metadata(None)).unwrap(); + let meta: RabitQuantizationMetadata = serde_json::from_str(&json).unwrap(); + + let params = RQBuildParams { + num_bits: 1, + rotation_type: RQRotationType::Fast, + rotation: Some(meta), + }; + let data = sample_fsl(8, dim as usize); + let q_a = RabitQuantizer::build(&data, DistanceType::L2, &params).unwrap(); + let q_b = RabitQuantizer::build(&data, DistanceType::L2, &params).unwrap(); + + assert_eq!( + quantized_codes(&q_a, &data), + quantized_codes(&q_b, &data), + "shared rotation must yield identical codes" + ); + } + + #[test] + fn test_unpinned_rotation_gives_different_codes() { + let dim = 32; + let params = RQBuildParams::new(1); + let data = sample_fsl(8, dim as usize); + let q_a = RabitQuantizer::build(&data, DistanceType::L2, &params).unwrap(); + let q_b = RabitQuantizer::build(&data, DistanceType::L2, &params).unwrap(); + + assert_ne!( + quantized_codes(&q_a, &data), + quantized_codes(&q_b, &data), + "independent unpinned rotations must yield different codes" + ); + } + + #[test] + fn test_build_rejects_rotation_with_mismatched_code_dim() { + let seed = RabitQuantizer::new_with_rotation::<Float32Type>(1, 16, RQRotationType::Fast); + let params = RQBuildParams { + num_bits: 1, + rotation_type: RQRotationType::Fast, + rotation: Some(seed.metadata(None)), + }; + let data = sample_fsl(4, 32); + let err = RabitQuantizer::build(&data, DistanceType::L2, &params).unwrap_err(); + assert!( + err.to_string().contains("does not match build params"), + "{}", + err + ); + } + + #[test] + fn test_build_rejects_fast_rotation_with_bad_signs_length() { + let dim = 16; + let seed = RabitQuantizer::new_with_rotation::<Float32Type>(1, dim, RQRotationType::Fast); + let mut meta = seed.metadata(None); + // Corrupt the signs to the wrong length (valid would be 4 * ceil(16/8) = 8). 
+ meta.fast_rotation_signs = Some(vec![0u8; 7]); + let params = RQBuildParams { + num_bits: 1, + rotation_type: RQRotationType::Fast, + rotation: Some(meta), + }; + let data = sample_fsl(4, dim as usize); + let err = RabitQuantizer::build(&data, DistanceType::L2, &params).unwrap_err(); + assert!(err.to_string().contains("signs length"), "{}", err); + } + + #[test] + fn test_matrix_rotation_lost_through_json_is_rejected() { + let dim = 16; + let seed = RabitQuantizer::new_with_rotation::<Float32Type>(1, dim, RQRotationType::Matrix); + let meta = seed.metadata(None); + assert!(meta.rotate_mat.is_some()); + + let json = serde_json::to_string(&meta).unwrap(); + let parsed: RabitQuantizationMetadata = serde_json::from_str(&json).unwrap(); + assert_eq!(parsed.rotation_type, RQRotationType::Matrix); + assert!( + parsed.rotate_mat.is_none(), + "matrix is expected to be dropped by JSON serialization" + ); + + let params = RQBuildParams { + num_bits: 1, + rotation_type: RQRotationType::Matrix, + rotation: Some(parsed), + }; + let data = sample_fsl(4, dim as usize); + let err = RabitQuantizer::build(&data, DistanceType::L2, &params).unwrap_err(); + assert!( + err.to_string() + .contains("fast rotation for distributed builds"), + "{}", + err + ); + } } diff --git a/rust/lance-index/src/vector/bq/rotation.rs b/rust/lance-index/src/vector/bq/rotation.rs index de4fbf549f1..2346c772cb3 100644 --- a/rust/lance-index/src/vector/bq/rotation.rs +++ b/rust/lance-index/src/vector/bq/rotation.rs @@ -138,9 +138,13 @@ fn sign_bytes_per_round(dim: usize) -> usize { dim.div_ceil(8) } +pub(crate) fn fast_rotation_signs_len(dim: usize) -> usize { + FAST_ROTATION_ROUNDS * sign_bytes_per_round(dim) +} + pub fn random_fast_rotation_signs(dim: usize) -> Vec<u8> { // Each round needs one random sign bit per dimension. 
- let mut signs = vec![0u8; FAST_ROTATION_ROUNDS * sign_bytes_per_round(dim)]; + let mut signs = vec![0u8; fast_rotation_signs_len(dim)]; rand::rng().fill_bytes(&mut signs); signs } diff --git a/rust/lance/src/index/vector.rs b/rust/lance/src/index/vector.rs index 87b32344ec6..526ba62be56 100644 --- a/rust/lance/src/index/vector.rs +++ b/rust/lance/src/index/vector.rs @@ -1877,6 +1877,7 @@ fn derive_rabit_params(rabit_quantizer: &RabitQuantizer) -> RQBuildParams { RQBuildParams { num_bits: rabit_quantizer.num_bits(), rotation_type: rabit_quantizer.rotation_type(), + rotation: None, } } From c010e39d220ce1678db66047df9335f4ff8ca6e3 Mon Sep 17 00:00:00 2001 From: gstamatakis95 <126914070+gstamatakis95@users.noreply.github.com> Date: Mon, 1 Jun 2026 22:02:32 +0200 Subject: [PATCH 2/2] refactor(python): rename RaBitQ rotation API per review --- python/python/lance/dataset.py | 14 +++---- .../python/lance/lance/indices/__init__.pyi | 3 +- python/python/tests/test_vector_index.py | 8 ++-- python/src/dataset.rs | 4 +- python/src/indices.rs | 42 ++++++------------- 5 files changed, 26 insertions(+), 45 deletions(-) diff --git a/python/python/lance/dataset.py b/python/python/lance/dataset.py index dcf462fdcb7..1b82e5298ed 100644 --- a/python/python/lance/dataset.py +++ b/python/python/lance/dataset.py @@ -3305,7 +3305,7 @@ def _create_index_impl( streaming_coreset_rate: Optional[int] = None, streaming_refine_passes: Optional[int] = None, skip_transpose: bool = False, - rq_rotation: Optional[str] = None, + rabitq_model: Optional[str] = None, require_commit: bool = True, **kwargs, ) -> Index: @@ -3627,8 +3627,8 @@ def _create_index_impl( if skip_transpose: kwargs["skip_transpose"] = True - if rq_rotation is not None: - kwargs["rq_rotation"] = rq_rotation + if rabitq_model is not None: + kwargs["rabitq_model"] = rabitq_model # Add fragment_ids and index_uuid to kwargs if provided for # distributed indexing @@ -3952,7 +3952,7 @@ def create_index_uncommitted( 
streaming_coreset_rate: Optional[int] = None, streaming_refine_passes: Optional[int] = None, skip_transpose: bool = False, - rq_rotation: Optional[str] = None, + rabitq_model: Optional[str] = None, **kwargs, ) -> Index: """ @@ -3979,8 +3979,8 @@ def create_index_uncommitted( requirement: - ``fragment_ids`` must be provided - - ``rq_rotation`` (``IVF_RQ`` only): a JSON string produced by - ``lance.lance.indices.build_rq_rotation``. It must be identical across all + - ``rabitq_model`` (``IVF_RQ`` only): a JSON string produced by + ``lance.lance.indices.build_rq_model``. It must be identical across all workers for their segments to be mergeable, since it pins the RaBitQ rotation so every segment rotates vectors the same way. If omitted, each call generates its own random rotation, which is only safe for a single, @@ -4017,7 +4017,7 @@ def create_index_uncommitted( streaming_coreset_rate=streaming_coreset_rate, streaming_refine_passes=streaming_refine_passes, skip_transpose=skip_transpose, - rq_rotation=rq_rotation, + rabitq_model=rabitq_model, require_commit=False, **kwargs, ) diff --git a/python/python/lance/lance/indices/__init__.pyi b/python/python/lance/lance/indices/__init__.pyi index 8afb1761f86..727469c4bc3 100644 --- a/python/python/lance/lance/indices/__init__.pyi +++ b/python/python/lance/lance/indices/__init__.pyi @@ -67,10 +67,9 @@ def transform_vectors( pq_codebook: pa.Array, dst_uri: str, ): ... -def build_rq_rotation( +def build_rq_model( dimension: int, num_bits: int = 1, - rotation_type: str = "fast", dtype: str = "float32", ) -> str: ... 
diff --git a/python/python/tests/test_vector_index.py b/python/python/tests/test_vector_index.py index 25c330b20f1..c74aaa28684 100644 --- a/python/python/tests/test_vector_index.py +++ b/python/python/tests/test_vector_index.py @@ -3037,8 +3037,8 @@ def test_commit_existing_index_segments_accepts_index_metadata(tmp_path): def test_distributed_ivf_rq_shared_rotation(tmp_path): """Two IVF_RQ segments built on separate fragments with one shared RaBitQ rotation - merge into a single committed, queryable index. The shared ``rq_rotation`` (from - ``lance.lance.indices.build_rq_rotation``) is what makes the independently built + merge into a single committed, queryable index. The shared ``rabitq_model`` (from + ``lance.lance.indices.build_rq_model``) is what makes the independently built segments mergeable.""" from lance.lance import indices @@ -3054,14 +3054,14 @@ def test_distributed_ivf_rq_shared_rotation(tmp_path): distance_type="l2", sample_rate=8, ) - rq_rotation = indices.build_rq_rotation(dimension=dim, num_bits=1) + rabitq_model = indices.build_rq_model(dimension=dim, num_bits=1) base_kwargs = { "column": "vector", "index_type": "IVF_RQ", "num_partitions": 2, "num_bits": 1, "ivf_centroids": ivf_model.centroids, - "rq_rotation": rq_rotation, + "rabitq_model": rabitq_model, } first = ds.create_index_uncommitted( **base_kwargs, diff --git a/python/src/dataset.rs b/python/src/dataset.rs index ea92736be68..2deeb610be0 100644 --- a/python/src/dataset.rs +++ b/python/src/dataset.rs @@ -4440,10 +4440,10 @@ fn prepare_vector_index_params( pq_params.codebook = Some(codebook.values().clone()) }; - if let Some(r) = kwargs.get_item("rq_rotation")? { + if let Some(r) = kwargs.get_item("rabitq_model")? 
{ let json: String = r.extract()?; let meta: RabitQuantizationMetadata = serde_json::from_str(&json) - .map_err(|e| PyValueError::new_err(format!("Invalid rq_rotation JSON: {e}")))?; + .map_err(|e| PyValueError::new_err(format!("Invalid rabitq_model JSON: {e}")))?; rq_params.rotation = Some(meta); }; diff --git a/python/src/indices.rs b/python/src/indices.rs index 7a7667fbc18..a0fec3c8812 100644 --- a/python/src/indices.rs +++ b/python/src/indices.rs @@ -363,11 +363,11 @@ fn train_pq_model<'py>( /// Distributed IVF_RQ builds must pin a single rotation across all workers so that /// independently built per-fragment segments rotate vectors identically and their /// binary codes remain comparable when merged. A driver calls this once and broadcasts -/// the resulting string to every `create_index_uncommitted(..., rq_rotation=...)` call. +/// the resulting string to every `create_index_uncommitted(..., rabitq_model=...)` call. /// -/// Only the "fast" rotation is supported since its sign vector is JSON-serializable, whereas -/// the "matrix" rotation stores a dense matrix in a binary buffer that is dropped by the -/// JSON wire format. `dtype` is accepted for API symmetry but does not affect the fast +/// The rotation is always the "fast" rotation since its sign vector is JSON-serializable, +/// whereas the "matrix" rotation stores a dense matrix in a binary buffer that is dropped by +/// the JSON wire format. `dtype` is accepted for API symmetry but does not affect the fast /// rotation. /// /// # Example (Python) @@ -375,25 +375,20 @@ fn train_pq_model<'py>( /// ```python /// from lance.lance import indices /// -/// # Mint one rotation and broadcast `rot` to every worker. -/// rot = indices.build_rq_rotation(dimension=128, num_bits=1) +/// # Mint one model and broadcast `model` to every worker. 
+/// model = indices.build_rq_model(dimension=128, num_bits=1) /// seg = ds.create_index_uncommitted( /// column="vector", /// index_type="IVF_RQ", /// num_partitions=256, /// ivf_centroids=centroids, -/// rq_rotation=rot, +/// rabitq_model=model, /// fragment_ids=my_fragments, /// ) /// ``` #[pyfunction] -#[pyo3(signature = (dimension, num_bits=1, rotation_type="fast", dtype="float32"))] -pub fn build_rq_rotation( - dimension: usize, - num_bits: u8, - rotation_type: &str, - dtype: &str, -) -> PyResult<String> { +#[pyo3(signature = (dimension, num_bits=1, dtype="float32"))] +pub fn build_rq_model(dimension: usize, num_bits: u8, dtype: &str) -> PyResult<String> { use arrow::datatypes::{Float16Type, Float32Type, Float64Type}; use lance_index::vector::bq::builder::RabitQuantizer; use lance_index::vector::bq::RQRotationType; @@ -404,21 +399,8 @@ pub fn build_rq_rotation( "dimension must be divisible by 8 for IVF_RQ", )); } - let rotation = match rotation_type.to_lowercase().as_str() { - "fast" => RQRotationType::Fast, - "matrix" => { - return Err(PyValueError::new_err( - "matrix rotation cannot be serialized to JSON for distributed builds; \ - use rotation_type='fast'", - )); - } - other => { - return Err(PyValueError::new_err(format!( - "unknown rotation_type: {other}; expected 'fast'" - ))); - } - }; let dim = dimension as i32; + let rotation = RQRotationType::Fast; let quantizer = match dtype.to_lowercase().as_str() { "float16" => RabitQuantizer::new_with_rotation::<Float16Type>(num_bits, dim, rotation), "float32" => RabitQuantizer::new_with_rotation::<Float32Type>(num_bits, dim, rotation), @@ -428,7 +410,7 @@ pub fn build_rq_rotation( } }; serde_json::to_string(&quantizer.metadata(None)) - .map_err(|e| PyValueError::new_err(format!("failed to serialize RQ rotation: {e}"))) + .map_err(|e| PyValueError::new_err(format!("failed to serialize RQ model: {e}"))) } #[allow(clippy::too_many_arguments)] @@ -825,7 +807,7 @@ pub fn register_indices(py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> { let indices = 
PyModule::new(py, "indices")?; indices.add_wrapped(wrap_pyfunction!(train_ivf_model))?; indices.add_wrapped(wrap_pyfunction!(train_pq_model))?; - indices.add_wrapped(wrap_pyfunction!(build_rq_rotation))?; + indices.add_wrapped(wrap_pyfunction!(build_rq_model))?; indices.add_wrapped(wrap_pyfunction!(transform_vectors))?; indices.add_wrapped(wrap_pyfunction!(shuffle_transformed_vectors))?; indices.add_wrapped(wrap_pyfunction!(load_shuffled_vectors))?;