diff --git a/internal/httputil/body.go b/internal/httputil/body.go new file mode 100644 index 0000000..9af1655 --- /dev/null +++ b/internal/httputil/body.go @@ -0,0 +1,20 @@ +// Package httputil provides shared HTTP utility functions. +package httputil + +import ( + "bytes" + "io" +) + +// ReconstitutedBody returns a new ReadCloser that first replays the captured +// bytes, then continues reading from the remaining original body. Close +// delegates to the original body's Close method. +func ReconstitutedBody(captured []byte, orig io.ReadCloser) io.ReadCloser { + return struct { + io.Reader + io.Closer + }{ + Reader: io.MultiReader(bytes.NewReader(captured), orig), + Closer: orig, + } +} diff --git a/internal/httputil/body_test.go b/internal/httputil/body_test.go new file mode 100644 index 0000000..f6aca16 --- /dev/null +++ b/internal/httputil/body_test.go @@ -0,0 +1,42 @@ +package httputil_test + +import ( + "io" + "strings" + "testing" + + "github.com/yesdevnull/trenchcoat/internal/httputil" +) + +func TestReconstitutedBody(t *testing.T) { + original := "hello world" + captured := []byte("hello ") + remaining := io.NopCloser(strings.NewReader("world")) + + body := httputil.ReconstitutedBody(captured, remaining) + + got, err := io.ReadAll(body) + if err != nil { + t.Fatal(err) + } + if string(got) != original { + t.Errorf("got %q, want %q", string(got), original) + } + + if err := body.Close(); err != nil { + t.Errorf("unexpected close error: %v", err) + } +} + +func TestReconstitutedBodyEmpty(t *testing.T) { + remaining := io.NopCloser(strings.NewReader("full body")) + body := httputil.ReconstitutedBody(nil, remaining) + + got, err := io.ReadAll(body) + if err != nil { + t.Fatal(err) + } + if string(got) != "full body" { + t.Errorf("got %q, want %q", string(got), "full body") + } +} diff --git a/internal/matcher/matcher.go b/internal/matcher/matcher.go index edadc59..3df78e7 100644 --- a/internal/matcher/matcher.go +++ b/internal/matcher/matcher.go @@ -4,7 +4,6 @@ package matcher import ( - "bytes" "fmt" "io" "net/http" @@ -16,6 +15,7 @@ import ( "github.com/bmatcuk/doublestar/v4" "github.com/yesdevnull/trenchcoat/internal/coat" + "github.com/yesdevnull/trenchcoat/internal/httputil" ) // maxBodyMatchSize is the maximum request body size (in bytes) that the matcher @@ -36,7 +36,8 @@ const ( // entry is a compiled coat with pre-computed matching metadata. type entry struct { coat coat.Coat - index int // original definition order + index int // original definition order + filePath string // source file path, empty for programmatic coats uriType uriMatchType regex *regexp.Regexp // only for regex URIs bodyRegex *regexp.Regexp // only for body_match: regex @@ -54,8 +55,9 @@ type entry struct { type MatchResult struct { Name string Coat coat.Coat - ResponseIdx int // index into Responses for sequence coats, -1 for singular - Exhausted bool // true if sequence is exhausted (once mode) + FilePath string // source file path (empty for programmatic coats) + ResponseIdx int // index into Responses for sequence coats, -1 for singular + Exhausted bool // true if sequence is exhausted (once mode) } // Matcher matches HTTP requests to coat definitions. @@ -142,17 +144,35 @@ func New(coats []coat.Coat) *Matcher { return &Matcher{entries: entries} } +// NewWithPaths creates a Matcher from coats with associated file paths. +// The paths slice must be the same length as coats (use "" for programmatic coats). +// Panics if len(paths) != len(coats). +func NewWithPaths(coats []coat.Coat, paths []string) *Matcher { + if len(paths) != len(coats) { + panic(fmt.Sprintf("matcher.NewWithPaths: len(paths)=%d != len(coats)=%d", len(paths), len(coats))) + } + m := New(coats) + for i := range m.entries { + m.entries[i].filePath = paths[i] + } + return m +} + +// errBodyTooLarge is returned by the body reader when the request body exceeds +// maxBodyMatchSize. +var errBodyTooLarge = fmt.Errorf("request body exceeds %d bytes", maxBodyMatchSize) + // lazyBodyReader creates a function that lazily reads the request body on first // call, bounded to maxBodyMatchSize. The request body is reconstituted so // downstream handlers still see the full body. -func lazyBodyReader(req *http.Request) func() (string, bool) { +func lazyBodyReader(req *http.Request) func() (string, error) { var reqBodyStr string var bodyRead bool - var bodyReadErr bool + var readErr error - return func() (string, bool) { + return func() (string, error) { if bodyRead { - return reqBodyStr, bodyReadErr + return reqBodyStr, readErr } bodyRead = true if req.Body != nil { @@ -162,34 +182,26 @@ func lazyBodyReader(req *http.Request) func() (string, bool) { limited := io.LimitReader(origBody, maxBodyMatchSize+1) allRead, err := io.ReadAll(limited) if err != nil { - bodyReadErr = true + readErr = fmt.Errorf("reading request body: %w", err) + req.Body = httputil.ReconstitutedBody(allRead, origBody) + return reqBodyStr, readErr } // If we read more than maxBodyMatchSize bytes, treat it as too large // for body matching, but still restore the full body for downstream use. - var reqBody []byte if len(allRead) > maxBodyMatchSize { - bodyReadErr = true - reqBody = allRead[:maxBodyMatchSize] + readErr = errBodyTooLarge + reqBodyStr = string(allRead[:maxBodyMatchSize]) } else { - reqBody = allRead + reqBodyStr = string(allRead) } - // Convert to string once to avoid repeated allocations in matchesBody. - reqBodyStr = string(reqBody) - // Reconstitute req.Body as the bytes already read plus the remaining // unread original body so downstream handlers see the full body, and // ensure Close() still delegates to the original body's Close(). - req.Body = struct { - io.Reader - io.Closer - }{ - Reader: io.MultiReader(bytes.NewReader(allRead), origBody), - Closer: origBody, - } + req.Body = httputil.ReconstitutedBody(allRead, origBody) } - return reqBodyStr, bodyReadErr + return reqBodyStr, readErr } } @@ -227,7 +239,7 @@ type candidate struct { } // findCandidates evaluates all entries against the request and returns matching candidates. -func (m *Matcher) findCandidates(req *http.Request, getBody func() (string, bool)) []candidate { +func (m *Matcher) findCandidates(req *http.Request, getBody func() (string, error)) []candidate { var candidates []candidate for _, e := range m.entries { if !matchesMethod(e, req.Method) { @@ -262,8 +274,9 @@ func selectBest(candidates []candidate) *MatchResult { best := candidates[0].entry result := &MatchResult{ - Name: best.resolvedName(), - Coat: best.coat, + Name: best.resolvedName(), + Coat: best.coat, + FilePath: best.filePath, } idx, exhausted := resolveSequence(best) @@ -540,12 +553,12 @@ func matchesQuery(e *entry, rawQuery string, queryValues map[string][]string) bo return true } -func matchesBody(e *entry, getBody func() (string, bool)) bool { +func matchesBody(e *entry, getBody func() (string, error)) bool { if e.coat.Request.Body == nil { return true // No body constraint — matches anything. } - body, readErr := getBody() - if readErr { + body, err := getBody() + if err != nil { return false // Treat read errors as non-match. } switch e.coat.Request.BodyMatch { diff --git a/internal/matcher/matcher_test.go b/internal/matcher/matcher_test.go index 8a93c8c..2728934 100644 --- a/internal/matcher/matcher_test.go +++ b/internal/matcher/matcher_test.go @@ -3,6 +3,8 @@ package matcher_test import ( "fmt" "net/http" + "net/http/httptest" + "path/filepath" "strings" "testing" @@ -1044,6 +1046,37 @@ func TestMatch_BodyMatch_ExplicitExact(t *testing.T) { } } +// --- Body I/O error handling --- + +type errorReader struct { + err error +} + +func (r *errorReader) Read(p []byte) (int, error) { + return 0, r.err +} + +func (r *errorReader) Close() error { + return nil +} + +func TestMatchBodyIOError(t *testing.T) { + coats := []coat.Coat{ + { + Name: "body-match", + Request: coat.Request{URI: "/test", Body: coat.StringPtr("expected")}, + Response: &coat.Response{Code: 200}, + }, + } + m := matcher.New(coats) + + req := httptest.NewRequest("GET", "/test", &errorReader{err: fmt.Errorf("disk failure")}) + result := m.Match(req) + if result != nil { + t.Error("expected no match when body read fails, but got a match") + } +} + // --- Body size limit --- func TestMatch_BodyExceedsMaxSize_NoMatch(t *testing.T) { @@ -1277,6 +1310,30 @@ func TestMatchVerbose_MatchReturnsNoMismatches(t *testing.T) { } } +// --- FilePath propagation --- + +func TestMatchResultFilePath(t *testing.T) { + coats := []coat.Coat{ + { + Name: "test", + Request: coat.Request{URI: "/test"}, + Response: &coat.Response{Code: 200}, + }, + } + wantPath := filepath.Join(t.TempDir(), "test.yaml") + paths := []string{wantPath} + m := matcher.NewWithPaths(coats, paths) + + req := httptest.NewRequest("GET", "/test", nil) + result := m.Match(req) + if result == nil { + t.Fatal("expected match") + } + if result.FilePath != wantPath { + t.Errorf("got FilePath %q, want %q", result.FilePath, wantPath) + } +} + // --- Helpers --- func newRequest(t *testing.T, method, uri string, headers map[string]string) *http.Request { diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index f25523f..55b6bcb 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -75,7 +75,8 @@ type Proxy struct { mu sync.Mutex counters map[string]int // for append dedup mode - captures sync.WaitGroup + captures sync.WaitGroup + captureSem chan struct{} // bounds concurrent capture goroutines } // New creates a new Proxy. @@ -146,7 +147,8 @@ func New(cfg Config) (*Proxy, error) { return http.ErrUseLastResponse }, }, - counters: make(map[string]int), + counters: make(map[string]int), + captureSem: make(chan struct{}, 20), } return p, nil @@ -200,10 +202,27 @@ func (p *Proxy) WaitCaptures() { p.captures.Wait() } -// Shutdown gracefully stops the proxy, waiting for pending captures to complete. +// Shutdown gracefully stops the proxy, waiting for pending captures to complete +// within half the timeout, then draining HTTP connections with the remaining time. func (p *Proxy) Shutdown(timeout time.Duration) error { - p.captures.Wait() - ctx, cancel := context.WithTimeout(context.Background(), timeout) + // Split timeout: half for capture drain, half for HTTP drain. + captureDrain := timeout / 2 + httpDrain := timeout - captureDrain + + // Wait for captures with timeout. + done := make(chan struct{}) + go func() { + p.captures.Wait() + close(done) + }() + select { + case <-done: + // All captures finished. + case <-time.After(captureDrain): + p.logger.Warn("timed out waiting for pending captures to complete") + } + + ctx, cancel := context.WithTimeout(context.Background(), httpDrain) defer cancel() return p.httpServer.Shutdown(ctx) } @@ -301,7 +320,18 @@ func (p *Proxy) handleRequest(w http.ResponseWriter, r *http.Request) { // Capture if applicable. if shouldCapture { - p.captures.Go(func() { p.captureCoat(r, reqBody, upstreamResp, respBody) }) + capReq := captureRequest{ + Method: r.Method, + URI: r.URL.Path, + RawQuery: r.URL.RawQuery, + Header: r.Header.Clone(), + Body: reqBody, + } + p.captures.Go(func() { + p.captureSem <- struct{}{} + defer func() { <-p.captureSem }() + p.captureCoatFromCopy(capReq, upstreamResp, respBody) + }) } } @@ -317,7 +347,7 @@ func (p *Proxy) shouldCapture(urlPath string) bool { return matched } -func (p *Proxy) captureCoat(r *http.Request, reqBody []byte, resp *http.Response, respBody []byte) { +func (p *Proxy) captureCoatFromCopy(req captureRequest, resp *http.Response, respBody []byte) { // Decompress response body for human-readable coat capture. captureBody := respBody decompressed := false @@ -341,9 +371,9 @@ func (p *Proxy) captureCoat(r *http.Request, reqBody []byte, resp *http.Response var reqHeaders, respHeaders map[string]string if !p.config.NoHeaders { reqHeaders = make(map[string]string) - for k := range r.Header { + for k := range req.Header { if !p.isStrippedHeader(k) { - reqHeaders[k] = r.Header.Get(k) + reqHeaders[k] = req.Header.Get(k) } } @@ -368,10 +398,10 @@ func (p *Proxy) captureCoat(r *http.Request, reqBody []byte, resp *http.Response coatDef := coatFile{ Coats: []coatEntry{ { - Name: fmt.Sprintf("%s %s", r.Method, r.URL.Path), + Name: fmt.Sprintf("%s %s", req.Method, req.URI), Request: coatRequest{ - Method: r.Method, - URI: r.URL.Path, + Method: req.Method, + URI: req.URI, }, Response: coatResponse{ Code: resp.StatusCode, @@ -386,17 +416,17 @@ func (p *Proxy) captureCoat(r *http.Request, reqBody []byte, resp *http.Response coatDef.Coats[0].Request.Headers = reqHeaders } - if p.captureBodyEnabled() && len(reqBody) > 0 { - body := string(reqBody) + if p.captureBodyEnabled() && len(req.Body) > 0 { + body := string(req.Body) coatDef.Coats[0].Request.Body = &body } - if r.URL.RawQuery != "" { - coatDef.Coats[0].Request.Query = r.URL.RawQuery + if req.RawQuery != "" { + coatDef.Coats[0].Request.Query = req.RawQuery } // Generate filename. - filename := p.generateFilename(r.Method, r.URL.Path, resp.StatusCode) + filename := p.generateFilename(req.Method, req.URI, resp.StatusCode) if filename == "" { // Skip dedup — file already exists. return @@ -428,6 +458,16 @@ func (p *Proxy) captureCoat(r *http.Request, reqBody []byte, resp *http.Response } } +// captureRequest holds request data copied from *http.Request before the +// handler returns, so the capture goroutine doesn't reference the original. +type captureRequest struct { + Method string + URI string + RawQuery string + Header http.Header + Body []byte +} + // nameTemplateData provides fields for custom file name templates. type nameTemplateData struct { Method string diff --git a/internal/proxy/proxy_test.go b/internal/proxy/proxy_test.go index fbb5b43..6051af9 100644 --- a/internal/proxy/proxy_test.go +++ b/internal/proxy/proxy_test.go @@ -4,12 +4,15 @@ import ( "bytes" "compress/gzip" "encoding/json" + "fmt" "io" + "log/slog" "net/http" "net/http/httptest" "os" "path/filepath" "strings" + "sync" "testing" "time" @@ -1251,3 +1254,98 @@ func TestProxy_NameTemplate(t *testing.T) { t.Fatalf("expected file POST-api_v1_users-201.yaml, found: %v", allFiles) } } + +func TestProxyCapturesConcurrencyBounded(t *testing.T) { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + _, _ = fmt.Fprint(w, "ok") + })) + defer upstream.Close() + + writeDir := t.TempDir() + p, err := proxy.New(proxy.Config{ + UpstreamURL: upstream.URL, + WriteDir: writeDir, + Dedupe: "append", + Logger: slog.New(slog.NewTextHandler(io.Discard, nil)), + }) + if err != nil { + t.Fatal(err) + } + addr, err := p.Start("127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + + const numRequests = 30 + var wg sync.WaitGroup + client := &http.Client{Timeout: 5 * time.Second} + for i := 0; i < numRequests; i++ { + wg.Add(1) + go func(n int) { + defer wg.Done() + resp, err := client.Get(fmt.Sprintf("http://%s/path/%d", addr, n)) + if err != nil { + t.Errorf("request %d failed: %v", n, err) + return + } + _ = resp.Body.Close() + }(i) + } + wg.Wait() + + p.WaitCaptures() + + files, err := filepath.Glob(filepath.Join(writeDir, "*.yaml")) + if err != nil { + t.Fatal(err) + } + if len(files) == 0 { + t.Fatal("expected captured coat files, got none") + } + if len(files) < 20 { + t.Errorf("expected at least 20 coat files, got %d", len(files)) + } + + _ = p.Shutdown(5 * time.Second) +} + +func TestProxyShutdownRespectsTimeoutWithPendingCaptures(t *testing.T) { + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + _, _ = fmt.Fprint(w, "ok") + })) + defer upstream.Close() + + writeDir := t.TempDir() + + p, err := proxy.New(proxy.Config{ + UpstreamURL: upstream.URL, + WriteDir: writeDir, + Logger: slog.New(slog.NewTextHandler(io.Discard, nil)), + }) + if err != nil { + t.Fatal(err) + } + addr, err := p.Start("127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + + resp, err := http.Get("http://" + addr + "/test") + if err != nil { + t.Fatal(err) + } + _ = resp.Body.Close() + + done := make(chan error, 1) + go func() { + done <- p.Shutdown(500 * time.Millisecond) + }() + + select { + case <-done: + case <-time.After(3 * time.Second): + t.Fatal("Shutdown did not return within timeout") + } +} diff --git a/internal/server/server.go b/internal/server/server.go index 403cd26..8446a46 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -20,6 +20,7 @@ import ( "time" "github.com/yesdevnull/trenchcoat/internal/coat" + "github.com/yesdevnull/trenchcoat/internal/httputil" "github.com/yesdevnull/trenchcoat/internal/matcher" ) @@ -61,11 +62,13 @@ func New(loaded []coat.LoadedCoat, cfg Config) *Server { cfg.Logger = slog.Default() } + coats, paths := extractCoatsAndPaths(loaded) + s := &Server{ logger: cfg.Logger, verbose: cfg.Verbose, recordCalls: cfg.RecordCalls, - matcher: matcher.New(extractCoats(loaded)), + matcher: matcher.NewWithPaths(coats, paths), coats: loaded, calls: make(map[string][]CapturedRequest), } @@ -81,13 +84,15 @@ func New(loaded []coat.LoadedCoat, cfg Config) *Server { return s } -// extractCoats returns just the Coat values from a slice of LoadedCoat. -func extractCoats(loaded []coat.LoadedCoat) []coat.Coat { +// extractCoatsAndPaths returns the Coat values and file paths from a slice of LoadedCoat. +func extractCoatsAndPaths(loaded []coat.LoadedCoat) ([]coat.Coat, []string) { coats := make([]coat.Coat, len(loaded)) + paths := make([]string, len(loaded)) for i, lc := range loaded { coats[i] = lc.Coat + paths[i] = lc.FilePath } - return coats + return coats, paths } // Start begins listening on the configured port. It returns the actual @@ -148,7 +153,8 @@ func (s *Server) Shutdown(timeout time.Duration) error { // Reload replaces the loaded coats and rebuilds the matcher. func (s *Server) Reload(loaded []coat.LoadedCoat) { - m := matcher.New(extractCoats(loaded)) + coats, paths := extractCoatsAndPaths(loaded) + m := matcher.NewWithPaths(coats, paths) s.mu.Lock() s.coats = loaded @@ -163,7 +169,6 @@ func (s *Server) handleRequest(w http.ResponseWriter, r *http.Request) { s.mu.RLock() m := s.matcher - allCoats := s.coats s.mu.RUnlock() var result *matcher.MatchResult @@ -186,34 +191,7 @@ func (s *Server) handleRequest(w http.ResponseWriter, r *http.Request) { s.recordCall(result.Name, r) } - // Look up the coat's source file path for logging. - // Match on name+method+uri, but only use the path when there is a - // single unique file path to avoid attributing the request to the wrong - // file when duplicate coats exist across different files or names are empty. - var coatFilePath string - if s.verbose { - var firstPath string - hasMatch := false - ambiguous := false - resultMethod := normalizeMethod(result.Coat.Request.Method) - for _, lc := range allCoats { - if lc.Coat.Name == result.Coat.Name && - lc.Coat.Request.URI == result.Coat.Request.URI && - normalizeMethod(lc.Coat.Request.Method) == resultMethod { - if !hasMatch { - firstPath = lc.FilePath - hasMatch = true - } else if lc.FilePath != firstPath { - // Multiple matching coats from different files: path is ambiguous. - ambiguous = true - break - } - } - } - if hasMatch && !ambiguous { - coatFilePath = firstPath - } - } + coatFilePath := result.FilePath // Handle exhausted sequences. if result.Exhausted { @@ -235,7 +213,7 @@ func (s *Server) handleRequest(w http.ResponseWriter, r *http.Request) { // Resolve body_file before setting headers so error responses stay clean. body := resp.Body if resp.BodyFile != "" { - bodyBytes, err := resolveBodyFile(resp.BodyFile, result.Coat, allCoats) + bodyBytes, err := resolveBodyFile(resp.BodyFile, result.FilePath) if err != nil { s.logger.Error("body_file resolution failed", "path", resp.BodyFile, "error", err) w.Header().Set("Content-Type", "application/json") @@ -250,7 +228,7 @@ func (s *Server) handleRequest(w http.ResponseWriter, r *http.Request) { } // Render response body templates. - body = renderTemplate(body, r) + body = renderTemplate(body, r, s.logger) // Apply delay with optional jitter (context-aware so it cancels if the client disconnects). if resp.DelayMs > 0 || resp.DelayJitterMs > 0 { @@ -415,13 +393,7 @@ func (s *Server) recordCall(name string, r *http.Request) { } } // Reconstruct r.Body: captured bytes + remaining unread original body. - r.Body = struct { - io.Reader - io.Closer - }{ - Reader: io.MultiReader(bytes.NewReader(headBytes), origBody), - Closer: origBody, - } + r.Body = httputil.ReconstitutedBody(headBytes, origBody) } s.callsMu.Lock() @@ -451,16 +423,6 @@ func (s *Server) Calls(name string) []CapturedRequest { return out } -// normalizeMethod returns the effective HTTP method, applying the same rules -// as the matcher: uppercase, default to GET when empty. -func normalizeMethod(method string) string { - m := strings.ToUpper(method) - if m == "" { - return "GET" - } - return m -} - // ResetCalls clears all recorded call data. func (s *Server) ResetCalls() { s.callsMu.Lock() @@ -498,7 +460,7 @@ func (td templateData) Segment(n int) string { // renderTemplate parses and executes a Go text/template with request context. // Returns the original body if it contains no template directives or if parsing fails. -func renderTemplate(body string, r *http.Request) string { +func renderTemplate(body string, r *http.Request, logger *slog.Logger) string { if !strings.Contains(body, "{{") { return body } @@ -521,13 +483,7 @@ func renderTemplate(body string, r *http.Request) string { // Reconstruct r.Body so that downstream handlers can still read // the entire request body: first the bytes we've captured, then // the remaining unread bytes from the original body. - r.Body = struct { - io.Reader - io.Closer - }{ - Reader: io.MultiReader(bytes.NewReader(bodyBytes), origBody), - Closer: origBody, - } + r.Body = httputil.ReconstitutedBody(bodyBytes, origBody) } data := templateData{ @@ -540,28 +496,14 @@ func renderTemplate(body string, r *http.Request) string { var buf bytes.Buffer if err := tmpl.Execute(&buf, data); err != nil { + logger.Warn("response template execution failed", "error", err) return body } return buf.String() } // resolveBodyFile resolves a body_file path relative to the coat file's location. -func resolveBodyFile(bodyFile string, c coat.Coat, allCoats []coat.LoadedCoat) ([]byte, error) { - // Find the file path for this coat, detecting ambiguous matches. - var coatFilePath string - cMethod := normalizeMethod(c.Request.Method) - for _, lc := range allCoats { - if lc.Coat.Name == c.Name && - lc.Coat.Request.URI == c.Request.URI && - normalizeMethod(lc.Coat.Request.Method) == cMethod { - if coatFilePath == "" { - coatFilePath = lc.FilePath - } else if lc.FilePath != coatFilePath { - return nil, fmt.Errorf("ambiguous coat source for body_file %q: multiple coats match name=%q uri=%q method=%q", bodyFile, c.Name, c.Request.URI, c.Request.Method) - } - } - } - +func resolveBodyFile(bodyFile string, coatFilePath string) ([]byte, error) { // Reject absolute paths — body_file must always be relative. if filepath.IsAbs(bodyFile) { return nil, fmt.Errorf("body_file must be a relative path, got absolute path") @@ -592,6 +534,32 @@ func resolveBodyFile(bodyFile string, c coat.Coat, allCoats []coat.LoadedCoat) ( return nil, fmt.Errorf("unable to resolve body_file path: %w", err) } + // Early containment check before any Stat calls to avoid leaking + // file existence information for paths outside the base directory. + if !isSubPath(absBase, absResolved) { + return nil, fmt.Errorf("body_file path escapes the coat file directory") + } + + // Check that the base directory exists before attempting symlink resolution. + // Use Lstat to avoid following symlinks during the pre-check. + if _, statErr := os.Lstat(absBase); statErr != nil { + if os.IsNotExist(statErr) { + return nil, fmt.Errorf("body_file base directory %q not found", baseDir) + } + return nil, fmt.Errorf("body_file base directory %q: %w", baseDir, statErr) + } + + // Check if the target file exists before attempting symlink resolution, + // so the error message is clear ("not found") rather than confusing + // ("unable to resolve symlinks"). Use Lstat to avoid following symlinks + // to paths outside the base directory before EvalSymlinks runs. + if _, statErr := os.Lstat(absResolved); statErr != nil { + if os.IsNotExist(statErr) { + return nil, fmt.Errorf("body_file %q not found", bodyFile) + } + return nil, fmt.Errorf("body_file %q: %w", bodyFile, statErr) + } + // Resolve symlinks to prevent escapes via symlinked paths. canonicalBase, err := filepath.EvalSymlinks(absBase) if err != nil { diff --git a/internal/server/server_test.go b/internal/server/server_test.go index 172df3f..4b17345 100644 --- a/internal/server/server_test.go +++ b/internal/server/server_test.go @@ -1,8 +1,10 @@ package server_test import ( + "bytes" "encoding/json" "io" + "log/slog" "net/http" "os" "path/filepath" @@ -150,7 +152,7 @@ func TestServe_BodyFile(t *testing.T) { func TestServe_BodyFile_Missing(t *testing.T) { srv := startServer(t, []coat.LoadedCoat{ { - FilePath: "/tmp/nonexistent/coat.yaml", + FilePath: filepath.Join(t.TempDir(), "coat.yaml"), Coat: coat.Coat{ Name: "missing-file", Request: coat.Request{Method: "GET", URI: "/data"}, @@ -548,15 +550,12 @@ func TestServe_Sequence_DefaultCycle(t *testing.T) { assertEqual(t, "third (cycle)", "a", readBody(t, resp3)) } -// --- resolveBodyFile ambiguity tests --- +// --- resolveBodyFile tests: duplicate coats and body_file resolution --- -func TestServe_BodyFile_AmbiguousCoatSources(t *testing.T) { - // Two coat files define a coat with the same name/URI/method but different - // file paths. resolveBodyFile should detect the ambiguity and return 500. +func TestServe_BodyFile_DuplicateCoatsFirstWins(t *testing.T) { dirA := t.TempDir() dirB := t.TempDir() - // Create body files in both directories. if err := os.WriteFile(filepath.Join(dirA, "data.json"), []byte(`{"from": "A"}`), 0644); err != nil { t.Fatal(err) } @@ -589,13 +588,9 @@ func TestServe_BodyFile_AmbiguousCoatSources(t *testing.T) { } defer func() { _ = resp.Body.Close() }() - assertEqual(t, "status", 500, resp.StatusCode) - - var errBody map[string]string - if err := json.NewDecoder(resp.Body).Decode(&errBody); err != nil { - t.Fatalf("failed to decode error body: %v", err) - } - assertEqual(t, "error", "body_file not found", errBody["error"]) + assertEqual(t, "status", 200, resp.StatusCode) + body := readBody(t, resp) + assertEqual(t, "body", `{"from": "A"}`, body) } func TestServe_BodyFile_SameCoatFilePath_NoAmbiguity(t *testing.T) { @@ -788,6 +783,88 @@ func assertEqual[T comparable](t *testing.T, field string, expected, actual T) { } } +func TestServe_BodyFile_Missing_ClearLogMessage(t *testing.T) { + var logBuf bytes.Buffer + logger := slog.New(slog.NewTextHandler(&logBuf, &slog.HandlerOptions{Level: slog.LevelError})) + + coats := []coat.LoadedCoat{ + { + FilePath: filepath.Join(t.TempDir(), "coat.yaml"), + Coat: coat.Coat{ + Name: "missing-file", + Request: coat.Request{Method: "GET", URI: "/data"}, + Response: &coat.Response{Code: 200, BodyFile: "does-not-exist.json"}, + }, + }, + } + + srv := server.New(coats, server.Config{Logger: logger}) + addr, err := srv.Start("127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { _ = srv.Shutdown(5 * time.Second) }) + + resp, err := httpClient.Get("http://" + addr + "/data") + if err != nil { + t.Fatal(err) + } + defer func() { _ = resp.Body.Close() }() + _ = readBody(t, resp) + + assertEqual(t, "status", 500, resp.StatusCode) + + logOutput := logBuf.String() + if strings.Contains(logOutput, "symlink") { + t.Errorf("log message should not mention symlinks for a missing file, got: %s", logOutput) + } + if !strings.Contains(logOutput, "not found") { + t.Errorf("log message should contain 'not found', got: %s", logOutput) + } +} + +func TestServe_TemplateExecutionErrorLogged(t *testing.T) { + var logBuf bytes.Buffer + logger := slog.New(slog.NewTextHandler(&logBuf, &slog.HandlerOptions{Level: slog.LevelWarn})) + + coats := []coat.LoadedCoat{ + { + Coat: coat.Coat{ + Name: "bad-template", + Request: coat.Request{URI: "/tmpl"}, + Response: &coat.Response{ + Code: 200, + Body: `result: {{call .Method}}`, + }, + }, + }, + } + + srv := server.New(coats, server.Config{Logger: logger}) + addr, err := srv.Start("127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { _ = srv.Shutdown(5 * time.Second) }) + + resp, err := httpClient.Get("http://" + addr + "/tmpl") + if err != nil { + t.Fatal(err) + } + defer func() { _ = resp.Body.Close() }() + + body := readBody(t, resp) + assertEqual(t, "status", 200, resp.StatusCode) + if !strings.Contains(body, "{{call .Method}}") { + t.Errorf("expected raw template in response body, got: %s", body) + } + + logOutput := logBuf.String() + if !strings.Contains(logOutput, "template execution failed") { + t.Errorf("expected 'template execution failed' in logs, got: %s", logOutput) + } +} + func TestCalls_ReturnsClonedHeaders(t *testing.T) { coats := []coat.LoadedCoat{ { diff --git a/internal/server/verbose_test.go b/internal/server/verbose_test.go index 7f85b1c..e663882 100644 --- a/internal/server/verbose_test.go +++ b/internal/server/verbose_test.go @@ -172,9 +172,9 @@ func TestServe_VerboseLogging_UnnamedCoatFilePath(t *testing.T) { } } -func TestServe_VerboseLogging_AmbiguousCoatOmitsFilePath(t *testing.T) { +func TestServe_VerboseLogging_DuplicateCoatsLogsFirstFilePath(t *testing.T) { // When multiple coats share the same name+method+URI across different - // files, the file path should be omitted to avoid misattribution. + // files, the first coat's file path is logged (deterministic by order). coats := []coat.LoadedCoat{ { FilePath: "/fake/a.yaml", @@ -213,9 +213,8 @@ func TestServe_VerboseLogging_AmbiguousCoatOmitsFilePath(t *testing.T) { assertEqual(t, "status", 200, resp.StatusCode) logOutput := logBuf.String() - // File path should NOT appear because the match is ambiguous. - if strings.Contains(logOutput, "file=") { - t.Errorf("expected no 'file=' in log output for ambiguous coats, got:\n%s", logOutput) + if !strings.Contains(logOutput, "file=/fake/a.yaml") { + t.Errorf("expected 'file=/fake/a.yaml' in log output, got:\n%s", logOutput) } }