diff --git a/internal/verifier/keyprovider/filesystemprovider/register.go b/internal/verifier/keyprovider/filesystemprovider/register.go index c3708aeb0..d70d49589 100644 --- a/internal/verifier/keyprovider/filesystemprovider/register.go +++ b/internal/verifier/keyprovider/filesystemprovider/register.go @@ -23,9 +23,11 @@ import ( "io/fs" "os" "path/filepath" + "sync" notationx509 "github.com/notaryproject/notation-core-go/x509" "github.com/notaryproject/ratify/v2/internal/verifier/keyprovider" + verifiertruststore "github.com/notaryproject/ratify/v2/internal/verifier/truststore" "github.com/sirupsen/logrus" ) @@ -36,6 +38,8 @@ const fileSystemProviderName = "files" type FileSystemProvider struct { certPaths []string certificates []*x509.Certificate + watcher *verifiertruststore.Watcher + mu sync.RWMutex } func init() { @@ -53,29 +57,54 @@ func init() { return nil, fmt.Errorf("no file paths provided") } - // Load certificates during initialization - var allCertificates []*x509.Certificate - for _, certPath := range paths { - certificates, err := loadCertificatesFromPath(certPath) - if err != nil { - return nil, fmt.Errorf("failed to load certificates from path %s: %w", certPath, err) - } - allCertificates = append(allCertificates, certificates...) + provider := &FileSystemProvider{certPaths: paths} + if err := provider.reloadCertificates(); err != nil { + return nil, err } - return &FileSystemProvider{ - certPaths: paths, - certificates: allCertificates, - }, nil + watcher, err := verifiertruststore.NewWatcher(paths, func() { + if err := provider.reloadCertificates(); err != nil { + logrus.WithError(err).Error("failed to reload trust store certificates") + } + }) + if err != nil { + return nil, fmt.Errorf("failed to create trust store watcher: %w", err) + } + if err := watcher.Start(); err != nil { + watcher.Stop() + return nil, fmt.Errorf("failed to start trust store watcher: %w", err) + } + provider.watcher = watcher + return provider, nil }) } // FileSystemProvider implements GetCertificates of [truststore.X509TrustStore] // interface. func (f *FileSystemProvider) GetCertificates(_ context.Context) ([]*x509.Certificate, error) { - // Return cached certificates loaded during initialization - logrus.Debugf("Returning %d cached certificate(s) from file system", len(f.certificates)) - return f.certificates, nil + f.mu.RLock() + defer f.mu.RUnlock() + + certificates := make([]*x509.Certificate, len(f.certificates)) + copy(certificates, f.certificates) + logrus.Debugf("Returning %d cached certificate(s) from file system", len(certificates)) + return certificates, nil +} + +func (f *FileSystemProvider) reloadCertificates() error { + var allCertificates []*x509.Certificate + for _, certPath := range f.certPaths { + certificates, err := loadCertificatesFromPath(certPath) + if err != nil { + return fmt.Errorf("failed to load certificates from path %s: %w", certPath, err) + } + allCertificates = append(allCertificates, certificates...) + } + + f.mu.Lock() + defer f.mu.Unlock() + f.certificates = allCertificates + return nil } func (f *FileSystemProvider) GetKeys(_ context.Context) ([]*keyprovider.PublicKey, error) { diff --git a/internal/verifier/keyprovider/filesystemprovider/register_test.go b/internal/verifier/keyprovider/filesystemprovider/register_test.go index be993f6f1..d0ff11705 100644 --- a/internal/verifier/keyprovider/filesystemprovider/register_test.go +++ b/internal/verifier/keyprovider/filesystemprovider/register_test.go @@ -89,7 +89,7 @@ func TestGetCertificates(t *testing.T) { if err != nil { t.Fatalf("failed to get certificates: %v", err) } - if len(certs) != 1 { + if len(certs) == 0 { t.Fatalf("expected at least one certificate, got %d", len(certs)) } @@ -102,7 +102,7 @@ func TestGetCertificates(t *testing.T) { if err != nil { t.Fatalf("failed to get certificates: %v", err) } - if len(certs) != 1 { + if len(certs) == 0 { t.Fatalf("expected at least one certificate, got %d", len(certs)) } @@ -120,10 +120,9 @@ func TestGetCertificates(t *testing.T) { } } -func TestGetCertificatesFromCache(t *testing.T) { +func TestGetCertificatesReloadsAfterFileChange(t *testing.T) { tempDir := t.TempDir() - // Create a temporary certificate file. certFile := filepath.Join(tempDir, "test-cert.pem") certContent, err := createCert() if err != nil { @@ -133,40 +132,36 @@ func TestGetCertificatesFromCache(t *testing.T) { t.Fatalf("failed to create temp cert file: %v", err) } - // Create provider which should load certificates during initialization - opts := []string{tempDir} - provider, err := keyprovider.CreateKeyProvider(fileSystemProviderName, opts) + provider, err := keyprovider.CreateKeyProvider(fileSystemProviderName, []string{tempDir}) if err != nil { t.Fatalf("failed to create key provider: %v", err) } - // Get certificates first time - certs1, err := provider.GetCertificates(context.Background()) + certs, err := provider.GetCertificates(context.Background()) if err != nil { t.Fatalf("failed to get certificates: %v", err) } - if len(certs1) != 1 { - t.Fatalf("expected 1 certificate, got %d", len(certs1)) + if len(certs) == 0 { + t.Fatalf("expected at least one certificate, got %d", len(certs)) } - // Remove the certificate file to ensure we're getting from cache if err := os.Remove(certFile); err != nil { t.Fatalf("failed to remove cert file: %v", err) } - // Get certificates second time - should still work from cache - certs2, err := provider.GetCertificates(context.Background()) - if err != nil { - t.Fatalf("failed to get certificates from cache: %v", err) - } - if len(certs2) != 1 { - t.Fatalf("expected 1 cached certificate, got %d", len(certs2)) + deadline := time.Now().Add(5 * time.Second) + for time.Now().Before(deadline) { + certs, err = provider.GetCertificates(context.Background()) + if err != nil { + t.Fatalf("failed to get certificates after removal: %v", err) + } + if len(certs) == 0 { + return + } + time.Sleep(200 * time.Millisecond) } - // Verify both calls return the same certificate - if !certs1[0].Equal(certs2[0]) { - t.Fatalf("cached certificate doesn't match original") - } + t.Fatalf("expected watcher to reload certificates after file removal, still have %d", len(certs)) } func TestGetKeys(t *testing.T) { diff --git a/internal/verifier/truststore/watcher.go b/internal/verifier/truststore/watcher.go new file mode 100644 index 000000000..d08ff766c --- /dev/null +++ b/internal/verifier/truststore/watcher.go @@ -0,0 +1,275 @@ +/* +Copyright The Ratify Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package truststore + +import ( + "crypto/sha256" + "errors" + "fmt" + "io/fs" + "os" + "path/filepath" + "sync" + "time" + + "github.com/fsnotify/fsnotify" + "github.com/sirupsen/logrus" +) + +var ( + debounceInterval = 2 * time.Second + pollInterval = 30 * time.Second +) + +type ChangeCallback func() + +type Watcher struct { + watcher *fsnotify.Watcher + paths []string + callback ChangeCallback + done chan struct{} + hashes map[string][32]byte + + mu sync.Mutex + stopOnce sync.Once +} + +func NewWatcher(paths []string, callback ChangeCallback) (*Watcher, error) { + if len(paths) == 0 { + return nil, fmt.Errorf("at least one path must be provided") + } + if callback == nil { + return nil, fmt.Errorf("callback must not be nil") + } + + fsWatcher, err := fsnotify.NewWatcher() + if err != nil { + return nil, fmt.Errorf("failed to create fsnotify watcher: %w", err) + } + + w := &Watcher{ + watcher: fsWatcher, + paths: append([]string(nil), paths...), + callback: callback, + done: make(chan struct{}), + hashes: make(map[string][32]byte), + } + w.snapshotHashes() + return w, nil +} + +func (w *Watcher) Start() error { + for _, path := range w.paths { + if err := w.addPath(path); err != nil { + logrus.WithError(err).Warnf("failed to watch path %s", path) + } + } + + go w.watch() + go w.pollLoop() + return nil +} + +func (w *Watcher) Stop() { + w.stopOnce.Do(func() { + close(w.done) + if err := w.watcher.Close(); err != nil && !errors.Is(err, fsnotify.ErrClosed) { + logrus.WithError(err).Error("error closing trust store watcher") + } + }) +} + +func (w *Watcher) AddPath(path string) error { + w.mu.Lock() + for _, existing := range w.paths { + if existing == path { + w.mu.Unlock() + return nil + } + } + w.mu.Unlock() + + if err := w.addPath(path); err != nil { + return err + } + + w.mu.Lock() + w.paths = append(w.paths, path) + w.mu.Unlock() + w.snapshotHashes() + return nil +} + +func (w *Watcher) addPath(path string) error { + info, err := os.Stat(path) + if err != nil { + return fmt.Errorf("failed to stat path %s: %w", path, err) + } + if err := w.watcher.Add(path); err != nil { + return fmt.Errorf("failed to watch path %s: %w", path, err) + } + if info.IsDir() { + parent := filepath.Dir(path) + if parent != path { + _ = w.watcher.Add(parent) + } + } + return nil +} + +func (w *Watcher) watch() { + var debounceTimer *time.Timer + for { + select { + case event, ok := <-w.watcher.Events: + if !ok { + return + } + if event.Op&(fsnotify.Write|fsnotify.Create|fsnotify.Remove|fsnotify.Rename) == 0 { + continue + } + logrus.Debugf("trust store watcher event: %s %s", event.Op, event.Name) + if event.Op&(fsnotify.Remove|fsnotify.Rename) != 0 { + _ = w.watcher.Add(event.Name) + } + if debounceTimer != nil { + debounceTimer.Stop() + } + eventName := event.Name + debounceTimer = time.AfterFunc(debounceInterval, func() { + logrus.Infof("trust store cert change detected: %s", eventName) + w.callback() + w.snapshotHashes() + }) + case err, ok := <-w.watcher.Errors: + if !ok { + return + } + logrus.WithError(err).Error("trust store watcher error") + case <-w.done: + if debounceTimer != nil { + debounceTimer.Stop() + } + return + } + } +} + +func (w *Watcher) pollLoop() { + ticker := time.NewTicker(pollInterval) + defer ticker.Stop() + for { + select { + case <-ticker.C: + if w.certsChanged() { + logrus.Info("trust store cert change detected via polling") + w.callback() + w.snapshotHashes() + } + case <-w.done: + return + } + } +} + +func (w *Watcher) certsChanged() bool { + currentHashes := w.currentHashes() + + w.mu.Lock() + defer w.mu.Unlock() + + if len(currentHashes) != len(w.hashes) { + return true + } + for path, currentHash := range currentHashes { + previousHash, ok := w.hashes[path] + if !ok || previousHash != currentHash { + return true + } + } + return false +} + +func (w *Watcher) currentHashes() map[string][32]byte { + w.mu.Lock() + paths := append([]string(nil), w.paths...) + w.mu.Unlock() + + hashes := make(map[string][32]byte) + for _, path := range paths { + if err := snapshotPathHashes(path, hashes); err != nil && !errors.Is(err, fs.ErrNotExist) { + logrus.WithError(err).Warnf("failed to snapshot trust store path %s", path) + } + } + return hashes +} + +func snapshotPathHashes(path string, hashes map[string][32]byte) error { + info, err := os.Stat(path) + if err != nil { + return err + } + if info.IsDir() { + return snapshotDirHashes(path, hashes) + } + + hash, err := hashFile(path) + if err != nil { + return err + } + hashes[path] = hash + return nil +} + +func snapshotDirHashes(path string, hashes map[string][32]byte) error { + root, err := os.OpenRoot(path) + if err != nil { + return fmt.Errorf("failed to open root %s: %w", path, err) + } + defer root.Close() + + return fs.WalkDir(root.FS(), ".", func(filePath string, d fs.DirEntry, err error) error { + if err != nil { + return err + } + if d.IsDir() { + return nil + } + + data, err := root.ReadFile(filePath) + if err != nil { + return err + } + hashes[filepath.Join(path, filePath)] = sha256.Sum256(data) + return nil + }) +} + +func hashFile(path string) ([32]byte, error) { + data, err := os.ReadFile(path) + if err != nil { + return [32]byte{}, err + } + return sha256.Sum256(data), nil +} + +func (w *Watcher) snapshotHashes() { + currentHashes := w.currentHashes() + + w.mu.Lock() + defer w.mu.Unlock() + w.hashes = currentHashes +} diff --git a/internal/verifier/truststore/watcher_test.go b/internal/verifier/truststore/watcher_test.go new file mode 100644 index 000000000..3097ea81e --- /dev/null +++ b/internal/verifier/truststore/watcher_test.go @@ -0,0 +1,266 @@ +/* +Copyright The Ratify Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package truststore + +import ( + "errors" + "os" + "path/filepath" + "sync/atomic" + "testing" + "time" + + "github.com/fsnotify/fsnotify" +) + +func TestNewWatcher_NoPaths(t *testing.T) { + _, err := NewWatcher(nil, func() {}) + if err == nil { + t.Fatal("expected error for nil paths") + } +} + +func TestNewWatcher_NilCallback(t *testing.T) { + dir := t.TempDir() + _, err := NewWatcher([]string{dir}, nil) + if err == nil { + t.Fatal("expected error for nil callback") + } +} + +func TestWatcher_FileWriteTriggersCallback(t *testing.T) { + setWatcherIntervals(t, 10*time.Millisecond, pollInterval) + + dir := t.TempDir() + certFile := filepath.Join(dir, "ca.crt") + if err := os.WriteFile(certFile, []byte("initial-cert"), 0o644); err != nil { + t.Fatal(err) + } + + var calls atomic.Int32 + watcher, err := NewWatcher([]string{dir}, func() { calls.Add(1) }) + if err != nil { + t.Fatal(err) + } + if err := watcher.Start(); err != nil { + t.Fatal(err) + } + defer watcher.Stop() + + time.Sleep(100 * time.Millisecond) + if err := os.WriteFile(certFile, []byte("rotated-cert"), 0o644); err != nil { + t.Fatal(err) + } + + waitFor(t, time.Second, func() bool { return calls.Load() > 0 }, "expected callback after cert write") +} + +func TestWatcher_WatchHandlesEventsAndErrors(t *testing.T) { + setWatcherIntervals(t, 10*time.Millisecond, pollInterval) + + dir := t.TempDir() + certFile := filepath.Join(dir, "ca.crt") + if err := os.WriteFile(certFile, []byte("initial-cert"), 0o644); err != nil { + t.Fatal(err) + } + + var calls atomic.Int32 + watcher, err := NewWatcher([]string{dir}, func() { calls.Add(1) }) + if err != nil { + t.Fatal(err) + } + defer watcher.Stop() + + go watcher.watch() + watcher.watcher.Errors <- errors.New("synthetic watcher error") + watcher.watcher.Events <- fsnotify.Event{Name: dir, Op: fsnotify.Remove} + + waitFor(t, time.Second, func() bool { return calls.Load() > 0 }, "expected callback after synthetic watcher event") +} + +func TestWatcher_PollDetectsChanges(t *testing.T) { + dir := t.TempDir() + certFile := filepath.Join(dir, "ca.crt") + if err := os.WriteFile(certFile, []byte("original"), 0o644); err != nil { + t.Fatal(err) + } + + watcher, err := NewWatcher([]string{dir}, func() {}) + if err != nil { + t.Fatal(err) + } + + watcher.snapshotHashes() + if err := os.WriteFile(certFile, []byte("modified"), 0o644); err != nil { + t.Fatal(err) + } + if !watcher.certsChanged() { + t.Fatal("expected poller to detect modified cert") + } +} + +func TestWatcher_PollDetectsRemovedFile(t *testing.T) { + dir := t.TempDir() + certFile := filepath.Join(dir, "ca.crt") + if err := os.WriteFile(certFile, []byte("original"), 0o644); err != nil { + t.Fatal(err) + } + + watcher, err := NewWatcher([]string{dir}, func() {}) + if err != nil { + t.Fatal(err) + } + + watcher.snapshotHashes() + if err := os.Remove(certFile); err != nil { + t.Fatal(err) + } + if !watcher.certsChanged() { + t.Fatal("expected poller to detect removed cert") + } +} + +func TestWatcher_CurrentHashesHandlesRemovedDirectFile(t *testing.T) { + certFile := filepath.Join(t.TempDir(), "ca.crt") + if err := os.WriteFile(certFile, []byte("original"), 0o644); err != nil { + t.Fatal(err) + } + + watcher, err := NewWatcher([]string{certFile}, func() {}) + if err != nil { + t.Fatal(err) + } + + if err := os.Remove(certFile); err != nil { + t.Fatal(err) + } + if hashes := watcher.currentHashes(); len(hashes) != 0 { + t.Fatalf("expected no hashes after removing direct file, got %d", len(hashes)) + } +} + +func TestSnapshotPathHashes_File(t *testing.T) { + certFile := filepath.Join(t.TempDir(), "ca.crt") + if err := os.WriteFile(certFile, []byte("original"), 0o644); err != nil { + t.Fatal(err) + } + + hashes := make(map[string][32]byte) + if err := snapshotPathHashes(certFile, hashes); err != nil { + t.Fatal(err) + } + if _, ok := hashes[certFile]; !ok { + t.Fatal("expected file hash to be captured") + } +} + +func TestSnapshotPathHashes_MissingPath(t *testing.T) { + hashes := make(map[string][32]byte) + if err := snapshotPathHashes(filepath.Join(t.TempDir(), "missing.crt"), hashes); err == nil { + t.Fatal("expected missing path error") + } +} + +func TestWatcher_PollLoopTriggersCallback(t *testing.T) { + setWatcherIntervals(t, debounceInterval, 10*time.Millisecond) + + dir := t.TempDir() + certFile := filepath.Join(dir, "ca.crt") + if err := os.WriteFile(certFile, []byte("original"), 0o644); err != nil { + t.Fatal(err) + } + + var calls atomic.Int32 + watcher, err := NewWatcher([]string{dir}, func() { calls.Add(1) }) + if err != nil { + t.Fatal(err) + } + defer watcher.Stop() + + watcher.snapshotHashes() + if err := os.WriteFile(certFile, []byte("modified"), 0o644); err != nil { + t.Fatal(err) + } + + go watcher.pollLoop() + waitFor(t, time.Second, func() bool { return calls.Load() > 0 }, "expected callback after polling detected change") +} + +func TestWatcher_AddPathAndStop(t *testing.T) { + dir1 := t.TempDir() + dir2 := t.TempDir() + + watcher, err := NewWatcher([]string{dir1}, func() {}) + if err != nil { + t.Fatal(err) + } + if err := watcher.Start(); err != nil { + t.Fatal(err) + } + + if err := watcher.AddPath(dir2); err != nil { + t.Fatalf("failed to add path: %v", err) + } + if err := watcher.AddPath(dir2); err != nil { + t.Fatalf("duplicate add should not error: %v", err) + } + if len(watcher.paths) != 2 { + t.Fatalf("expected 2 watched paths, got %d", len(watcher.paths)) + } + + watcher.Stop() + watcher.Stop() +} + +func TestWatcher_AddPathMissingPath(t *testing.T) { + dir := t.TempDir() + + watcher, err := NewWatcher([]string{dir}, func() {}) + if err != nil { + t.Fatal(err) + } + defer watcher.Stop() + + if err := watcher.AddPath(filepath.Join(dir, "missing")); err == nil { + t.Fatal("expected error adding missing path") + } +} + +func setWatcherIntervals(t *testing.T, debounce, poll time.Duration) { + t.Helper() + + originalDebounce := debounceInterval + originalPoll := pollInterval + debounceInterval = debounce + pollInterval = poll + t.Cleanup(func() { + debounceInterval = originalDebounce + pollInterval = originalPoll + }) +} + +func waitFor(t *testing.T, timeout time.Duration, condition func() bool, message string) { + t.Helper() + + deadline := time.Now().Add(timeout) + for time.Now().Before(deadline) { + if condition() { + return + } + time.Sleep(10 * time.Millisecond) + } + t.Fatal(message) +}