Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 44 additions & 15 deletions internal/verifier/keyprovider/filesystemprovider/register.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,11 @@ import (
"io/fs"
"os"
"path/filepath"
"sync"

notationx509 "github.com/notaryproject/notation-core-go/x509"
"github.com/notaryproject/ratify/v2/internal/verifier/keyprovider"
verifiertruststore "github.com/notaryproject/ratify/v2/internal/verifier/truststore"
"github.com/sirupsen/logrus"
)

Expand All @@ -36,6 +38,8 @@ const fileSystemProviderName = "files"
type FileSystemProvider struct {
certPaths []string
certificates []*x509.Certificate
watcher *verifiertruststore.Watcher
mu sync.RWMutex
}

func init() {
Expand All @@ -53,29 +57,54 @@ func init() {
return nil, fmt.Errorf("no file paths provided")
}

// Load certificates during initialization
var allCertificates []*x509.Certificate
for _, certPath := range paths {
certificates, err := loadCertificatesFromPath(certPath)
if err != nil {
return nil, fmt.Errorf("failed to load certificates from path %s: %w", certPath, err)
}
allCertificates = append(allCertificates, certificates...)
provider := &FileSystemProvider{certPaths: paths}
if err := provider.reloadCertificates(); err != nil {
return nil, err
}

return &FileSystemProvider{
certPaths: paths,
certificates: allCertificates,
}, nil
watcher, err := verifiertruststore.NewWatcher(paths, func() {
if err := provider.reloadCertificates(); err != nil {
logrus.WithError(err).Error("failed to reload trust store certificates")
}
})
if err != nil {
return nil, fmt.Errorf("failed to create trust store watcher: %w", err)
}
if err := watcher.Start(); err != nil {
watcher.Stop()
return nil, fmt.Errorf("failed to start trust store watcher: %w", err)
}
provider.watcher = watcher
return provider, nil
Comment on lines +65 to +78
})
}

// FileSystemProvider implements GetCertificates of [truststore.X509TrustStore]
// interface.
func (f *FileSystemProvider) GetCertificates(_ context.Context) ([]*x509.Certificate, error) {
// Return cached certificates loaded during initialization
logrus.Debugf("Returning %d cached certificate(s) from file system", len(f.certificates))
return f.certificates, nil
f.mu.RLock()
defer f.mu.RUnlock()

certificates := make([]*x509.Certificate, len(f.certificates))
copy(certificates, f.certificates)
logrus.Debugf("Returning %d cached certificate(s) from file system", len(certificates))
return certificates, nil
}

func (f *FileSystemProvider) reloadCertificates() error {
var allCertificates []*x509.Certificate
for _, certPath := range f.certPaths {
certificates, err := loadCertificatesFromPath(certPath)
if err != nil {
return fmt.Errorf("failed to load certificates from path %s: %w", certPath, err)
}
allCertificates = append(allCertificates, certificates...)
}

f.mu.Lock()
defer f.mu.Unlock()
f.certificates = allCertificates
return nil
}

func (f *FileSystemProvider) GetKeys(_ context.Context) ([]*keyprovider.PublicKey, error) {
Expand Down
41 changes: 18 additions & 23 deletions internal/verifier/keyprovider/filesystemprovider/register_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ func TestGetCertificates(t *testing.T) {
if err != nil {
t.Fatalf("failed to get certificates: %v", err)
}
if len(certs) != 1 {
if len(certs) == 0 {
t.Fatalf("expected at least one certificate, got %d", len(certs))
}
Comment on lines +92 to 94

Expand All @@ -102,7 +102,7 @@ func TestGetCertificates(t *testing.T) {
if err != nil {
t.Fatalf("failed to get certificates: %v", err)
}
if len(certs) != 1 {
if len(certs) == 0 {
t.Fatalf("expected at least one certificate, got %d", len(certs))
}
Comment on lines +105 to 107

Expand All @@ -120,10 +120,9 @@ func TestGetCertificates(t *testing.T) {
}
}

func TestGetCertificatesFromCache(t *testing.T) {
func TestGetCertificatesReloadsAfterFileChange(t *testing.T) {
tempDir := t.TempDir()

// Create a temporary certificate file.
certFile := filepath.Join(tempDir, "test-cert.pem")
certContent, err := createCert()
if err != nil {
Expand All @@ -133,40 +132,36 @@ func TestGetCertificatesFromCache(t *testing.T) {
t.Fatalf("failed to create temp cert file: %v", err)
}

// Create provider which should load certificates during initialization
opts := []string{tempDir}
provider, err := keyprovider.CreateKeyProvider(fileSystemProviderName, opts)
provider, err := keyprovider.CreateKeyProvider(fileSystemProviderName, []string{tempDir})
if err != nil {
t.Fatalf("failed to create key provider: %v", err)
}

// Get certificates first time
certs1, err := provider.GetCertificates(context.Background())
certs, err := provider.GetCertificates(context.Background())
if err != nil {
t.Fatalf("failed to get certificates: %v", err)
}
if len(certs1) != 1 {
t.Fatalf("expected 1 certificate, got %d", len(certs1))
if len(certs) == 0 {
t.Fatalf("expected at least one certificate, got %d", len(certs))
}

// Remove the certificate file to ensure we're getting from cache
if err := os.Remove(certFile); err != nil {
t.Fatalf("failed to remove cert file: %v", err)
}

// Get certificates second time - should still work from cache
certs2, err := provider.GetCertificates(context.Background())
if err != nil {
t.Fatalf("failed to get certificates from cache: %v", err)
}
if len(certs2) != 1 {
t.Fatalf("expected 1 cached certificate, got %d", len(certs2))
deadline := time.Now().Add(5 * time.Second)
for time.Now().Before(deadline) {
certs, err = provider.GetCertificates(context.Background())
if err != nil {
t.Fatalf("failed to get certificates after removal: %v", err)
}
if len(certs) == 0 {
return
}
time.Sleep(200 * time.Millisecond)
}

// Verify both calls return the same certificate
if !certs1[0].Equal(certs2[0]) {
t.Fatalf("cached certificate doesn't match original")
}
t.Fatalf("expected watcher to reload certificates after file removal, still have %d", len(certs))
}

func TestGetKeys(t *testing.T) {
Expand Down
Loading
Loading