diff --git a/index/keyset_test.go b/index/keyset_test.go index 3dd43d4..67a382b 100644 --- a/index/keyset_test.go +++ b/index/keyset_test.go @@ -35,3 +35,16 @@ func TestKeySet_Multi(t *testing.T) { }) require.ElementsMatch(t, vs, [][]byte{[]byte("baz"), []byte("quux")}) } + +func TestKeySet_DuplicateKeys(t *testing.T) { + ks := index.NewKeySet([]byte("baz"), []byte("quux"), []byte("baz")) + require.EqualValues(t, "baz", ks.First()) + require.True(t, ks.Exists([]byte("baz"))) + require.True(t, ks.Exists([]byte("quux"))) + require.False(t, ks.Exists([]byte("foo"))) + vs := [][]byte{} + ks.Foreach(func(bs index.Key) { + vs = append(vs, bs) + }) + require.ElementsMatch(t, vs, [][]byte{[]byte("baz"), []byte("quux"), []byte("baz")}) +} diff --git a/part_index.go b/part_index.go index c227a98..917ffea 100644 --- a/part_index.go +++ b/part_index.go @@ -367,10 +367,7 @@ func (r *partIndexTxn) reindex(idKey index.Key, old object, new object) { if !unique { oldKey = encodeNonUniqueKey(idKey, oldKey) } - _, hadOld := r.tx.Delete(oldKey) - if !unique && !hadOld { - panic("BUG: delete did not find old object") - } + r.tx.Delete(oldKey) } }, ) diff --git a/part_index_test.go b/part_index_test.go new file mode 100644 index 0000000..59283e2 --- /dev/null +++ b/part_index_test.go @@ -0,0 +1,85 @@ +// SPDX-License-Identifier: Apache-2.0 +// Copyright Authors of Cilium + +package statedb + +import ( + "strconv" + "strings" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/cilium/statedb/index" +) + +type partIndexDuplicateKeyObject struct { + ID uint64 + Tags []string +} + +func (partIndexDuplicateKeyObject) TableHeader() []string { + return []string{"ID", "Tags"} +} + +func (obj partIndexDuplicateKeyObject) TableRow() []string { + return []string{strconv.FormatUint(obj.ID, 10), strings.Join(obj.Tags, ",")} +} + +func TestPartIndex_ReindexDuplicateKeys(t *testing.T) { + idIndex := Index[partIndexDuplicateKeyObject, uint64]{ + Name: "id", + FromObject: func(obj partIndexDuplicateKeyObject) index.KeySet { + return index.NewKeySet(index.Uint64(obj.ID)) + }, + FromKey: index.Uint64, + Unique: true, + } + tagIndex := Index[partIndexDuplicateKeyObject, string]{ + Name: "tag", + FromObject: func(obj partIndexDuplicateKeyObject) index.KeySet { + return index.StringSlice(obj.Tags) + }, + FromKey: index.String, + Unique: false, + } + + db := New() + table, err := NewTable(db, "part-index-test", idIndex, tagIndex) + require.NoError(t, err) + + wtxn := db.WriteTxn(table) + _, _, err = table.Insert(wtxn, partIndexDuplicateKeyObject{ + ID: 1, + Tags: []string{"duplicate", "duplicate"}, + }) + require.NoError(t, err) + txn := wtxn.Commit() + + require.Len(t, Collect(table.List(txn, tagIndex.Query("duplicate"))), 1) + + var ( + oldObj partIndexDuplicateKeyObject + hadOld bool + ) + require.NotPanics(t, func() { + wtxn = db.WriteTxn(table) + defer wtxn.Abort() + + oldObj, hadOld, err = table.Insert(wtxn, partIndexDuplicateKeyObject{ + ID: 1, + Tags: []string{"replacement"}, + }) + if err == nil { + txn = wtxn.Commit() + } + }) + require.NoError(t, err) + require.True(t, hadOld) + require.Equal(t, []string{"duplicate", "duplicate"}, oldObj.Tags) + + require.Empty(t, Collect(table.List(txn, tagIndex.Query("duplicate")))) + objs := Collect(table.List(txn, tagIndex.Query("replacement"))) + require.Len(t, objs, 1) + require.Equal(t, uint64(1), objs[0].ID) +}