diff --git a/faiss_vector_index_bivf.go b/faiss_vector_index_bivf.go index 596e19d3..7996298e 100644 --- a/faiss_vector_index_bivf.go +++ b/faiss_vector_index_bivf.go @@ -251,16 +251,18 @@ func (b *faissBinaryIndex) setNProbe(nprobe int32) { } func (b *faissBinaryIndex) trainAndAdd(trainingData *vectorSet, vecsToAdd *vectorSet) error { + nlist := determineCentroids(trainingData.nvecs) + nvecsToTrain := nlist * 40 * b.dim() // train the backing index with the floatData var err error if b.backing.IsSQIndex() { - err = b.backing.Train(trainingData.floatData) + err = b.backing.Train(trainingData.floatData[:nvecsToTrain]) if err != nil { return err } } - err = b.binary.Train(trainingData.binaryData) + err = b.binary.Train(trainingData.binaryData[:(nvecsToTrain / 8)]) if err != nil { return err } diff --git a/faiss_vector_index_float32.go b/faiss_vector_index_float32.go index bfcd7267..63af82d1 100644 --- a/faiss_vector_index_float32.go +++ b/faiss_vector_index_float32.go @@ -155,7 +155,9 @@ func (f *faissFloat32Index) setNProbe(nprobe int32) { } func (f *faissFloat32Index) trainAndAdd(trainingData *vectorSet, vecsToAdd *vectorSet) error { - err := f.idx.Train(trainingData.floatData) + nlist := determineCentroids(trainingData.nvecs) + nvecsToTrain := nlist * 40 * f.dim() + err := f.idx.Train(trainingData.floatData[:nvecsToTrain]) if err != nil { return err }