Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 62 additions & 0 deletions internal/proxy_handler.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,22 @@
package internal

import (
"context"
"errors"
"log/slog"
"net"
"net/http"
"net/http/httputil"
"net/url"
"os"
)

// StatusClientClosedRequest is a non-standard status code, following nginx
// convention, used when the client disconnects before we are able to respond.
// Recording it (rather than a 502) keeps client-cancelled requests out of the
// 5xx bucket in access logs and metrics.
const StatusClientClosedRequest = 499

func NewProxyHandler(targetUrl *url.URL, badGatewayPage string, forwardHeaders bool) http.Handler {
return &httputil.ReverseProxy{
Rewrite: func(r *httputil.ProxyRequest) {
Expand All @@ -29,13 +37,35 @@ func ProxyErrorHandler(badGatewayPage string) func(w http.ResponseWriter, r *htt
}

return func(w http.ResponseWriter, r *http.Request, err error) {
if isClientCancellation(err) {
// The client disconnected before we could respond, so there is
// nothing to send. We still set a status code so the request is
// recorded in the access log, but as a client-closed request
// rather than an upstream failure -- otherwise client cancellations
// (a fetch aborted, a Turbo Frame swapped, a navigation away) show
// up as 502s and pollute error dashboards.
slog.Debug("Client disconnected before response", "path", r.URL.Path)
w.WriteHeader(StatusClientClosedRequest)
return
}

slog.Info("Unable to proxy request", "path", r.URL.Path, "error", err)

if isRequestEntityTooLarge(err) {
w.WriteHeader(http.StatusRequestEntityTooLarge)
return
}

if isGatewayTimeout(err) {
w.WriteHeader(http.StatusGatewayTimeout)
return
}

if isChunkedEncodingError(err) {
w.WriteHeader(http.StatusBadRequest)
return
}

if content != nil {
w.Header().Set("Content-Type", "text/html")
w.WriteHeader(http.StatusBadGateway)
Expand Down Expand Up @@ -64,11 +94,43 @@ func setXForwarded(r *httputil.ProxyRequest, forwardHeaders bool) {
}
}

func isClientCancellation(err error) bool {
return errors.Is(err, context.Canceled)
}

func isRequestEntityTooLarge(err error) bool {
var maxBytesError *http.MaxBytesError
return errors.As(err, &maxBytesError)
}

func isGatewayTimeout(err error) bool {
var netErr net.Error
if errors.As(err, &netErr) {
return netErr.Timeout()
}
return false
}

func isChunkedEncodingError(err error) bool {
if err == nil {
return false
}

// The chunked encoding support in the stdlib returns these failures as
// plain errors built with errors.New, so matching them means string
// matching on the error message, unfortunately.
switch err.Error() {
case "invalid byte in chunk length",
"http chunk length too large",
"malformed chunked encoding",
"trailer header without chunked transfer encoding",
"too many trailers":
return true
}

return false
}

func createProxyTransport() *http.Transport {
// The default transport requests compressed responses even if the client
// didn't. If it receives a compressed response but the client wants
Expand Down
171 changes: 171 additions & 0 deletions internal/proxy_handler_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
package internal

import (
"bytes"
"context"
"errors"
"fmt"
"log/slog"
"net"
"net/http"
"net/http/httptest"
"net/url"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestProxyErrorHandler_clientCancellationReturnsClientClosedRequest(t *testing.T) {
handler := ProxyErrorHandler("")

w := httptest.NewRecorder()
r := httptest.NewRequest("GET", "/", nil)

handler(w, r, context.Canceled)

assert.Equal(t, StatusClientClosedRequest, w.Code)
assert.Empty(t, w.Body.String())
}

func TestProxyErrorHandler_wrappedClientCancellationReturnsClientClosedRequest(t *testing.T) {
handler := ProxyErrorHandler("")

w := httptest.NewRecorder()
r := httptest.NewRequest("GET", "/", nil)

handler(w, r, fmt.Errorf("proxying request: %w", context.Canceled))

assert.Equal(t, StatusClientClosedRequest, w.Code)
}

func TestProxyErrorHandler_clientCancellationIsNotLoggedAsProxyError(t *testing.T) {
var buf bytes.Buffer
original := slog.Default()
slog.SetDefault(slog.New(slog.NewJSONHandler(&buf, &slog.HandlerOptions{Level: slog.LevelInfo})))
defer slog.SetDefault(original)
Comment thread
alexspeller marked this conversation as resolved.

handler := ProxyErrorHandler("")
handler(httptest.NewRecorder(), httptest.NewRequest("GET", "/", nil), context.Canceled)

assert.NotContains(t, buf.String(), "Unable to proxy request")
}

func TestProxyErrorHandler_upstreamErrorReturnsBadGateway(t *testing.T) {
handler := ProxyErrorHandler("")

w := httptest.NewRecorder()
r := httptest.NewRequest("GET", "/", nil)

handler(w, r, errors.New("dial tcp [::1]:3000: connect: connection refused"))

assert.Equal(t, http.StatusBadGateway, w.Code)
}

func TestProxyErrorHandler_connectionRefusedReturnsBadGateway(t *testing.T) {
handler := ProxyErrorHandler("")

w := httptest.NewRecorder()
r := httptest.NewRequest("GET", "/", nil)

// A real connection-refused error is a net.Error, but it is not a timeout,
// so it must still be treated as a bad gateway rather than a 504.
err := &net.OpError{Op: "dial", Net: "tcp", Err: errors.New("connect: connection refused")}
require.False(t, err.Timeout())

handler(w, r, err)

assert.Equal(t, http.StatusBadGateway, w.Code)
}

func TestProxyErrorHandler_upstreamTimeoutReturnsGatewayTimeout(t *testing.T) {
handler := ProxyErrorHandler("")

w := httptest.NewRecorder()
r := httptest.NewRequest("GET", "/", nil)

// context.DeadlineExceeded satisfies net.Error with Timeout() == true, the
// same shape the transport returns when an upstream read/dial times out.
handler(w, r, context.DeadlineExceeded)

assert.Equal(t, http.StatusGatewayTimeout, w.Code)
}

func TestProxyErrorHandler_chunkedEncodingErrorReturnsBadRequest(t *testing.T) {
handler := ProxyErrorHandler("")

w := httptest.NewRecorder()
r := httptest.NewRequest("GET", "/", nil)

handler(w, r, errors.New("malformed chunked encoding"))

assert.Equal(t, http.StatusBadRequest, w.Code)
}

func TestProxyErrorHandler_entityTooLargeReturnsRequestEntityTooLarge(t *testing.T) {
handler := ProxyErrorHandler("")

w := httptest.NewRecorder()
r := httptest.NewRequest("GET", "/", nil)

handler(w, r, &http.MaxBytesError{})

assert.Equal(t, http.StatusRequestEntityTooLarge, w.Code)
}

// End-to-end: a client that disconnects mid-request must be recorded as a
// client-closed request (499), not an upstream failure (502).
func TestProxyHandler_clientDisconnectIsRecordedAsClientClosedRequest(t *testing.T) {
upstreamReached := make(chan struct{})
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
close(upstreamReached)
<-r.Context().Done() // block until the client goes away
}))
defer upstream.Close()

targetUrl, err := url.Parse(upstream.URL)
require.NoError(t, err)

proxy := NewProxyHandler(targetUrl, "", false)

var capturedStatus int
done := make(chan struct{})
front := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
recorder := &statusRecorder{ResponseWriter: w, status: http.StatusOK}
proxy.ServeHTTP(recorder, r)
capturedStatus = recorder.status
close(done)
}))
defer front.Close()

ctx, cancel := context.WithCancel(context.Background())
req, err := http.NewRequestWithContext(ctx, "GET", front.URL, nil)
require.NoError(t, err)

clientDone := make(chan struct{})
go func() {
defer close(clientDone)
resp, err := http.DefaultClient.Do(req)
if resp != nil {
_ = resp.Body.Close()
}
_ = err // the client cancels the request, so an error is expected
}()

<-upstreamReached
cancel()
<-done
<-clientDone

assert.Equal(t, StatusClientClosedRequest, capturedStatus)
}

type statusRecorder struct {
http.ResponseWriter
status int
}

func (r *statusRecorder) WriteHeader(status int) {
r.status = status
r.ResponseWriter.WriteHeader(status)
}
Loading