diff --git a/python/python/lance/indices/builder.py b/python/python/lance/indices/builder.py index c31ea0a7a0c..d3d61c5f8ff 100644 --- a/python/python/lance/indices/builder.py +++ b/python/python/lance/indices/builder.py @@ -56,7 +56,9 @@ def __init__(self, dataset, column: str): """ self.dataset = dataset self.column = self._normalize_column(column) - self.dimension = self.dataset.schema.field(self.column[0]).type.list_size + self.dimension = self._vector_dimension( + self.dataset.schema.field(self.column[0]).type + ) def train_ivf( self, @@ -199,7 +201,6 @@ def train_pq( from lance.lance import indices num_rows = self._count_rows(fragment_ids) - self.dataset.schema.field(self.column[0]).type.list_size num_subvectors = self._normalize_pq_params(num_subvectors, self.dimension) self._verify_pq_sample_rate(num_rows, sample_rate) distance_type = ivf_model.distance_type @@ -359,7 +360,7 @@ def transform_vectors( """ from lance.lance import indices - dimension = self.dataset.schema.field(self.column[0]).type.list_size + dimension = self.dimension num_subvectors = pq.num_subvectors distance_type = ivf.distance_type if fragments is None: @@ -452,7 +453,7 @@ def load_shuffled_vectors( The PQ model used to create the inputs. """ - pq_dimension = self.dataset.schema.field(self.column[0]).type.list_size + pq_dimension = self.dimension num_subvectors = pq.num_subvectors distance_type = ivf.distance_type @@ -578,28 +579,46 @@ def _normalize_column(self, column): if c not in self.dataset.schema.names: raise KeyError(f"{c} not found in schema") field = self.dataset.schema.field(c) - if not ( - pa.types.is_fixed_size_list(field.type) - or ( - isinstance(field.type, pa.FixedShapeTensorType) - and len(field.type.shape) == 1 - ) - ): + vector_type = self._describe_vector_type(field.type) + if vector_type is None: raise TypeError( - f"Vector column {c} must be FixedSizeListArray " + f"Vector column {c} must be FixedSizeListArray, " + "list (multivector), or " f"1-dimensional FixedShapeTensorArray, got {field.type}" ) + _, value_type = vector_type if not ( - pa.types.is_floating(field.type.value_type) - or pa.types.is_unsigned_integer(field.type.value_type) + pa.types.is_floating(value_type) + or pa.types.is_unsigned_integer(value_type) ): raise TypeError( f"Vector column {c} must have floating or unsigned integer " - f"value type, got {field.type.value_type}" + f"value type, got {value_type}" ) return column + def _vector_dimension(self, data_type): + vector_type = self._describe_vector_type(data_type) + if vector_type is not None: + return vector_type[0] + raise TypeError( + "Vector column must be FixedSizeListArray, " + "list (multivector), or " + f"1-dimensional FixedShapeTensorArray, got {data_type}" + ) + + def _describe_vector_type(self, data_type): + if pa.types.is_fixed_size_list(data_type): + return data_type.list_size, data_type.value_type + if pa.types.is_list(data_type) and pa.types.is_fixed_size_list( + data_type.value_type + ): + return data_type.value_type.list_size, data_type.value_type.value_type + if isinstance(data_type, pa.FixedShapeTensorType) and len(data_type.shape) == 1: + return data_type.shape[0], data_type.value_type + return None + @dataclass class IndexConfig: diff --git a/python/python/tests/test_indices.py b/python/python/tests/test_indices.py index 88cae659561..7f6595f2ecc 100644 --- a/python/python/tests/test_indices.py +++ b/python/python/tests/test_indices.py @@ -65,6 +65,32 @@ def mostly_null_dataset(tmpdir, request): return ds +def make_multivector_dataset(tmpdir): + dimension = 4 + vector_type = pa.list_(pa.list_(pa.float32(), dimension)) + schema = pa.schema( + [ + pa.field("id", pa.int64()), + pa.field("embeddings", vector_type), + ] + ) + table = pa.Table.from_pylist( + [ + {"id": 1, "embeddings": [[0.1, 0.2, 0.3, 0.4], [0.2, 0.3, 0.4, 0.5]]}, + {"id": 2, "embeddings": [[0.8, 0.7, 0.6, 0.5]]}, + {"id": 3, "embeddings": [[0.3, 0.2, 0.1, 0.0], [0.9, 0.8, 0.7, 0.6]]}, + {"id": 4, "embeddings": [[0.4, 0.1, 0.2, 0.3]]}, + ], + schema=schema, + ) + ds = lance.write_dataset( + table, + pathlib.Path(tmpdir) / "multivector_fragment_ivf", + max_rows_per_file=2, + ) + return ds, dimension + + def test_ivf_centroids(tmpdir, rand_dataset): ivf = IndicesBuilder(rand_dataset, "vectors").train_ivf(sample_rate=16) @@ -218,6 +244,88 @@ def test_ivf_centroids_fragment_ids(tmpdir): assert np.allclose(second_centroid, 10.0, atol=1e-4) +def test_ivf_centroids_multivector_fragment_ids(tmpdir): + ds, dimension = make_multivector_dataset(tmpdir) + builder = IndicesBuilder(ds, "embeddings") + assert builder.dimension == dimension + + centroids = pa.FixedSizeListArray.from_arrays( + pa.array( + [ + 0.1, + 0.2, + 0.3, + 0.4, + 0.8, + 0.7, + 0.6, + 0.5, + ], + type=pa.float32(), + ), + dimension, + ) + fragment_ids = [fragment.fragment_id for fragment in ds.get_fragments()] + + index = ds.create_index_uncommitted( + "embeddings", + index_type="IVF_HNSW_SQ", + metric="cosine", + num_partitions=2, + fragment_ids=fragment_ids, + index_uuid="00000000-0000-4000-8000-000000000001", + ivf_centroids=centroids, + ) + + assert index.uuid == "00000000-0000-4000-8000-000000000001" + assert index.fragment_ids == set(fragment_ids) + assert index.name == "embeddings_idx" + + +def test_indices_builder_multivector_distributed_dimensions(tmpdir, monkeypatch): + ds, dimension = make_multivector_dataset(tmpdir) + builder = IndicesBuilder(ds, "embeddings") + centroids = pa.FixedSizeListArray.from_arrays( + pa.array([0.1, 0.2, 0.3, 0.4], type=pa.float32()), + dimension, + ) + codebook = pa.FixedSizeListArray.from_arrays( + pa.array([0.1, 0.2, 0.3, 0.4], type=pa.float32()), + dimension, + ) + ivf = IvfModel(centroids, "l2") + pq = PqModel(2, codebook) + + from lance.lance import indices + + captured_dimensions = {} + + def train_pq_model(*args): + captured_dimensions["train_pq"] = args[2] + return codebook + + def transform_vectors(*args): + captured_dimensions["transform_vectors"] = args[2] + + def load_shuffled_vectors(*args): + captured_dimensions["load_shuffled_vectors"] = args[6] + + monkeypatch.setattr(indices, "train_pq_model", train_pq_model) + monkeypatch.setattr(indices, "transform_vectors", transform_vectors) + monkeypatch.setattr(indices, "load_shuffled_vectors", load_shuffled_vectors) + monkeypatch.setattr(builder, "_verify_pq_sample_rate", lambda *args: None) + + builder.train_pq(ivf, num_subvectors=2, sample_rate=2) + builder.transform_vectors(ivf, pq, str(pathlib.Path(tmpdir) / "transformed")) + builder.load_shuffled_vectors(["sorted"], str(tmpdir), ivf, pq) + + assert captured_dimensions == { + "train_pq": dimension, + "transform_vectors": dimension, + "load_shuffled_vectors": dimension, + } + + def test_pq_fragment_ids(rand_dataset): fragment_id = rand_dataset.get_fragments()[0].fragment_id ivf = IndicesBuilder(rand_dataset, "vectors").train_ivf( diff --git a/python/src/dataset.rs b/python/src/dataset.rs index d31cb870d0b..0f834484f30 100644 --- a/python/src/dataset.rs +++ b/python/src/dataset.rs @@ -4337,17 +4337,27 @@ fn prepare_vector_index_params( )); } + let centroid_type = match column_type { + DataType::List(field) + if matches!(field.data_type(), DataType::FixedSizeList(_, _)) => + { + field.data_type() + } + _ => column_type, + }; + // It's important that the centroids are the same data type // as the vectors that will be indexed. let mut centroids: Arc = batch.column(0).clone(); - if centroids.data_type() != column_type { - centroids = cast_with_options(centroids.as_ref(), column_type, &Default::default()) - .map_err(|e| { - PyValueError::new_err(format!( - "Failed to cast centroids to column type: {}", - e - )) - })?; + if centroids.data_type() != centroid_type { + centroids = + cast_with_options(centroids.as_ref(), centroid_type, &Default::default()) + .map_err(|e| { + PyValueError::new_err(format!( + "Failed to cast centroids to vector type: {}", + e + )) + })?; } let centroids = as_fixed_size_list_array(centroids.as_ref());