Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ func TestClientLayeredCache(t *testing.T) {
}

l1 := newRistrettoCache[*User](t)
l2 := newGORMCache[*User](t, "user_cache")
l2, _ := newGORMCache[*User](t, "user_cache")

apiCallCount := 0
apiUpstream := UpstreamFunc[*User](func(ctx context.Context, key string) (*User, error) {
Expand Down
29 changes: 25 additions & 4 deletions gorm.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package cachex

import (
"cmp"
"context"
"encoding/json"
"time"
Expand Down Expand Up @@ -60,12 +61,29 @@ func (g *GORMCache[T]) prefixedKey(key string) string {

// Migrate creates or updates the cache table schema
func (g *GORMCache[T]) Migrate(ctx context.Context) error {
if err := g.db.WithContext(ctx).Table(g.tableName).AutoMigrate(&cacheEntry{}); err != nil {
tx := cmp.Or(GetGORMTx(ctx), g.db)
if err := tx.WithContext(ctx).Table(g.tableName).AutoMigrate(&cacheEntry{}); err != nil {
return errors.Wrapf(err, "failed to migrate cache table for table: %s", g.tableName)
}
return nil
}

type ctxKeyGORMTx struct{}

// WithGORMTx attaches a GORM transaction to the context.
// All GORMCache operations using this context will execute within the transaction.
// The transaction must be committed or rolled back by the caller.
func WithGORMTx(ctx context.Context, tx *gorm.DB) context.Context {
return context.WithValue(ctx, ctxKeyGORMTx{}, tx)
}

// GetGORMTx retrieves the GORM transaction from the context.
// Returns nil if no transaction is attached to the context.
func GetGORMTx(ctx context.Context) *gorm.DB {
tx, _ := ctx.Value(ctxKeyGORMTx{}).(*gorm.DB)
return tx
}

// Set stores a value in the cache
func (g *GORMCache[T]) Set(ctx context.Context, key string, value T) error {
data, err := json.Marshal(value)
Expand All @@ -78,7 +96,8 @@ func (g *GORMCache[T]) Set(ctx context.Context, key string, value T) error {
Value: data,
}

if err := g.db.WithContext(ctx).
tx := cmp.Or(GetGORMTx(ctx), g.db)
if err := tx.WithContext(ctx).
Table(g.tableName).
Clauses(clause.OnConflict{
Columns: []clause.Column{{Name: "key"}},
Expand All @@ -96,7 +115,8 @@ func (g *GORMCache[T]) Get(ctx context.Context, key string) (T, error) {
var zero T
var entry cacheEntry

if err := g.db.WithContext(ctx).
tx := cmp.Or(GetGORMTx(ctx), g.db)
if err := tx.WithContext(ctx).
Table(g.tableName).
Where("key = ?", g.prefixedKey(key)).
First(&entry).Error; err != nil {
Expand All @@ -116,7 +136,8 @@ func (g *GORMCache[T]) Get(ctx context.Context, key string) (T, error) {

// Del removes a value from the cache
func (g *GORMCache[T]) Del(ctx context.Context, key string) error {
if err := g.db.WithContext(ctx).
tx := cmp.Or(GetGORMTx(ctx), g.db)
if err := tx.WithContext(ctx).
Table(g.tableName).
Where("key = ?", g.prefixedKey(key)).
Delete(nil).Error; err != nil {
Expand Down
146 changes: 142 additions & 4 deletions gorm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package cachex

import (
"context"
"strings"
"testing"

"github.com/stretchr/testify/assert"
Expand All @@ -10,20 +11,20 @@ import (
"gorm.io/gorm"
)

func newGORMCache[T any](tb testing.TB, tableName string) *GORMCache[T] {
func newGORMCache[T any](tb testing.TB, tableName string) (*GORMCache[T], *gorm.DB) {
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
require.NoError(tb, err)
cache := NewGORMCache[T](&GORMCacheConfig{
DB: db,
TableName: tableName,
})
require.NoError(tb, cache.Migrate(context.Background()))
return cache
return cache, db
}

func TestGORMCacheBasics(t *testing.T) {
ctx := context.Background()
cache := newGORMCache[string](t, "test_cache")
cache, _ := newGORMCache[string](t, "test_cache")

require.NoError(t, cache.Set(ctx, "key1", "value1"))

Expand All @@ -39,7 +40,7 @@ func TestGORMCacheBasics(t *testing.T) {

func TestGORMCacheWithBytes(t *testing.T) {
ctx := context.Background()
cache := newGORMCache[[]byte](t, "bytes_cache")
cache, _ := newGORMCache[[]byte](t, "bytes_cache")

testData := []byte("raw binary data \x00\x01\x02")

Expand Down Expand Up @@ -84,3 +85,140 @@ func TestGORMCacheConfigWithPrefix(t *testing.T) {
require.NoError(t, err)
assert.Equal(t, "dev-secret-456", devValue)
}

func TestGORMCacheTransactionCommit(t *testing.T) {
ctx := context.Background()
cache, db := newGORMCache[string](t, "tx_commit_cache")

require.NoError(t, cache.Set(ctx, "other_key", "other_value"))

tx := db.Begin()
Comment thread
molon marked this conversation as resolved.
txCtx := WithGORMTx(ctx, tx)

require.NoError(t, cache.Set(txCtx, "tx_key", "tx_value"))

value, err := cache.Get(txCtx, "tx_key")
require.NoError(t, err)
assert.Equal(t, "tx_value", value, "should read value within transaction")

// SQLite write lock: when a transaction has pending writes, concurrent reads from
// outside the transaction may fail with "no such table" due to SQLite's locking behavior.
// Both errors (key not found or table locked) prove transaction isolation.
_, err = cache.Get(ctx, "tx_key")
assert.True(t, IsErrKeyNotFound(err) || strings.Contains(err.Error(), "no such table"),
"should not read uncommitted value outside transaction (got: %v)", err)

require.NoError(t, tx.Commit().Error)

value, err = cache.Get(ctx, "tx_key")
require.NoError(t, err)
assert.Equal(t, "tx_value", value, "should find key after transaction commit")
}

func TestGORMCacheTransactionRollback(t *testing.T) {
ctx := context.Background()
cache, db := newGORMCache[string](t, "tx_rollback_cache")

require.NoError(t, cache.Set(ctx, "exists_key", "exists_value"))

tx := db.Begin()
Comment thread
molon marked this conversation as resolved.
txCtx := WithGORMTx(ctx, tx)

require.NoError(t, cache.Set(txCtx, "rollback_key", "rollback_value"))
require.NoError(t, cache.Set(txCtx, "exists_key", "exists_value2"))

require.NoError(t, tx.Rollback().Error)

_, err := cache.Get(ctx, "rollback_key")
assert.True(t, IsErrKeyNotFound(err), "should not find key after transaction rollback")

value, err := cache.Get(ctx, "exists_key")
require.NoError(t, err)
assert.Equal(t, "exists_value", value, "should find key after transaction rollback")
}

func TestGORMCacheTransactionIsolation(t *testing.T) {
ctx := context.Background()
cache, db := newGORMCache[string](t, "tx_isolation_cache")

require.NoError(t, cache.Set(ctx, "isolation_key", "original_value"))

value, err := cache.Get(ctx, "isolation_key")
require.NoError(t, err)
assert.Equal(t, "original_value", value, "should read original value before transaction")

tx := db.Begin()
Comment thread
molon marked this conversation as resolved.
txCtx := WithGORMTx(ctx, tx)

require.NoError(t, cache.Set(txCtx, "isolation_key", "updated_value"))

txValue, err := cache.Get(txCtx, "isolation_key")
require.NoError(t, err)
assert.Equal(t, "updated_value", txValue, "should see updated value inside transaction")

require.NoError(t, tx.Commit().Error)

finalValue, err := cache.Get(ctx, "isolation_key")
require.NoError(t, err)
assert.Equal(t, "updated_value", finalValue, "should see updated value after commit")
}

func TestGORMCacheTransactionDelete(t *testing.T) {
ctx := context.Background()
cache, db := newGORMCache[string](t, "tx_delete_cache")

require.NoError(t, cache.Set(ctx, "del_key", "del_value"))

value, err := cache.Get(ctx, "del_key")
require.NoError(t, err)
assert.Equal(t, "del_value", value, "should read value before transaction")

tx := db.Begin()
Comment thread
molon marked this conversation as resolved.
txCtx := WithGORMTx(ctx, tx)

require.NoError(t, cache.Del(txCtx, "del_key"))

_, err = cache.Get(txCtx, "del_key")
assert.True(t, IsErrKeyNotFound(err), "should not find key inside transaction after delete")

require.NoError(t, tx.Commit().Error)

_, err = cache.Get(ctx, "del_key")
assert.True(t, IsErrKeyNotFound(err), "should not find key after transaction commit")
}

func TestGORMCacheWithClientTransaction(t *testing.T) {
type User struct {
ID string
Name string
}

ctx := context.Background()
cache, db := newGORMCache[*User](t, "client_tx_cache")

fetchCount := 0
upstream := UpstreamFunc[*User](func(ctx context.Context, key string) (*User, error) {
fetchCount++
return &User{ID: key, Name: "User " + key}, nil
})

client := NewClient(cache, upstream)

require.NoError(t, cache.Set(ctx, "other_key", &User{ID: "other", Name: "Other User"}))

tx := db.Begin()
txCtx := WithGORMTx(ctx, tx)

user1 := &User{ID: "user1", Name: "User One"}
require.NoError(t, client.Set(txCtx, "user1", user1))

value, err := client.Get(txCtx, "user1")
require.NoError(t, err)
assert.Equal(t, "User One", value.Name, "should read value within transaction")

require.NoError(t, tx.Rollback().Error)

_, err = cache.Get(ctx, "user1")
assert.True(t, IsErrKeyNotFound(err), "should not find value in cache after transaction rollback")
assert.Equal(t, 0, fetchCount, "should not fetch from upstream during rollback test")
}
2 changes: 1 addition & 1 deletion ristretto_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (
)

func newRistrettoCache[T any](tb testing.TB) *RistrettoCache[T] {
cache, err := NewRistrettoCache[T](DefaultRistrettoCacheConfig[T]())
cache, err := NewRistrettoCache(DefaultRistrettoCacheConfig[T]())
require.NoError(tb, err)
tb.Cleanup(func() { _ = cache.Close() })
return cache
Expand Down
Loading