diff --git a/README.md b/README.md index 4167010..5a58ebf 100644 --- a/README.md +++ b/README.md @@ -72,31 +72,33 @@ In most cases, Thruster should work out of the box with no additional configuration. But if you need to customize its behavior, there are a few environment variables that you can set. -| Variable Name | Description | Default Value | -|-----------------------------|---------------------------------------------------------|---------------| -| `TLS_DOMAIN` | Comma-separated list of domain names to use for TLS provisioning. If not set, TLS will be disabled. | None | -| `TARGET_PORT` | The port that your Puma server should run on. Thruster will set `PORT` to this value when starting your server. | 3000 | -| `CACHE_SIZE` | The size of the HTTP cache in bytes. | 64MB | -| `MAX_CACHE_ITEM_SIZE` | The maximum size of a single item in the HTTP cache in bytes. | 1MB | -| `GZIP_COMPRESSION_ENABLED` | Whether to enable gzip compression for responses. Set to `0` or `false` to disable. | Enabled | -| `GZIP_COMPRESSION_DISABLE_ON_AUTH` | If set to `true`, disable gzip compression for authenticated requests with `Cookie`, `Authorization`, or `X-Csrf-Token` headers. | `false` | -| `GZIP_COMPRESSION_JITTER` | The amount of random jitter (in bytes) to add to the compressed response size to mitigate BREACH attacks. Set to `0` to disable. | 32 | -| `X_SENDFILE_ENABLED` | Whether to enable X-Sendfile support. Set to `0` or `false` to disable. | Enabled | -| `MAX_REQUEST_BODY` | The maximum size of a request body in bytes. Requests larger than this size will be refused; `0` means no maximum size is enforced. | `0` | -| `STORAGE_PATH` | The path to store Thruster's internal state. Provisioned TLS certificates will be stored here, so that they will not need to be requested every time your application is started. | `./storage/thruster` | -| `BAD_GATEWAY_PAGE` | Path to an HTML file to serve when the backend server returns a 502 Bad Gateway error. If there is no file at the specific path, Thruster will serve an empty 502 response instead. Because Thruster boots very quickly, a custom page can be a useful way to show that your application is starting up. | `./public/502.html` | -| `HTTP_PORT` | The port to listen on for HTTP traffic. | 80 | -| `HTTPS_PORT` | The port to listen on for HTTPS traffic. | 443 | -| `HTTP_IDLE_TIMEOUT` | The maximum time in seconds that a client can be idle before the connection is closed. | 60 | -| `HTTP_READ_TIMEOUT` | The maximum time in seconds that a client can take to send the request headers and body. | 30 | -| `HTTP_WRITE_TIMEOUT` | The maximum time in seconds during which the client must read the response. | 30 | -| `H2C_ENABLED` | Set to `1` or `true` to enable h2c (http/2 cleartext) | Disabled | -| `ACME_DIRECTORY` | The URL of the ACME directory to use for TLS certificate provisioning. | `https://acme-v02.api.letsencrypt.org/directory` (Let's Encrypt production) | -| `EAB_KID` | The EAB key identifier to use when provisioning TLS certificates, if required. | None | -| `EAB_HMAC_KEY` | The Base64-encoded EAB HMAC key to use when provisioning TLS certificates, if required. | None | -| `FORWARD_HEADERS` | Whether to forward X-Forwarded-* headers from the client. | Disabled when running with TLS; enabled otherwise | -| `LOG_REQUESTS` | Log all requests. Set to `0` or `false` to disable request logging | Enabled | -| `DEBUG` | Set to `1` or `true` to enable debug logging. | Disabled | +| Variable Name | Description | Default Value | +|------------------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|-----------------------------------------------------------------------------| +| `TLS_DOMAIN` | Comma-separated list of domain names to use for TLS provisioning. If not set, TLS will be disabled. | None | +| `TARGET_PORT` | The port that your Puma server should run on. Thruster will set `PORT` to this value when starting your server. | 3000 | +| `CACHE_SIZE` | The size of the HTTP cache in bytes. | 64MB | +| `MAX_CACHE_ITEM_SIZE` | The maximum size of a single item in the HTTP cache in bytes. | 1MB | +| `GZIP_COMPRESSION_ENABLED` | Whether to enable gzip compression for responses. Set to `0` or `false` to disable. | Enabled | +| `GZIP_COMPRESSION_DISABLE_ON_AUTH` | If set to `true`, disable gzip compression for authenticated requests with `Cookie`, `Authorization`, or `X-Csrf-Token` headers. | `false` | +| `GZIP_COMPRESSION_JITTER` | The amount of random jitter (in bytes) to add to the compressed response size to mitigate BREACH attacks. Set to `0` to disable. | 32 | +| `X_SENDFILE_ENABLED` | Whether to enable X-Sendfile support. Set to `0` or `false` to disable. | Enabled | +| `MAX_REQUEST_BODY` | The maximum size of a request body in bytes. Requests larger than this size will be refused; `0` means no maximum size is enforced. | `0` | +| `STORAGE_PATH` | The path to store Thruster's internal state. Provisioned TLS certificates will be stored here, so that they will not need to be requested every time your application is started. | `./storage/thruster` | +| `BAD_GATEWAY_PAGE` | Path to an HTML file to serve when the backend server returns a 502 Bad Gateway error. If there is no file at the specific path, Thruster will serve an empty 502 response instead. Because Thruster boots very quickly, a custom page can be a useful way to show that your application is starting up. | `./public/502.html` | +| `HTTP_PORT` | The port to listen on for HTTP traffic. | 80 | +| `HTTPS_PORT` | The port to listen on for HTTPS traffic. | 443 | +| `HTTP_IDLE_TIMEOUT` | The maximum time in seconds that a client can be idle before the connection is closed. | 60 | +| `HTTP_READ_TIMEOUT` | The maximum time in seconds that a client can take to send the request headers and body. | 30 | +| `HTTP_WRITE_TIMEOUT` | The maximum time in seconds during which the client must read the response. | 30 | +| `H2C_ENABLED` | Set to `1` or `true` to enable h2c (http/2 cleartext) | Disabled | +| `WAIT_FOR_TARGET_PORT` | If set to `1` or `true`, Thruster will wait for the upstream application to bind to its port before starting the proxy server. | Disabled | +| `WAIT_FOR_TARGET_PORT_TIMEOUT` | The maximum time in seconds to wait for the upstream port to open. | 60 | +| `ACME_DIRECTORY` | The URL of the ACME directory to use for TLS certificate provisioning. | `https://acme-v02.api.letsencrypt.org/directory` (Let's Encrypt production) | +| `EAB_KID` | The EAB key identifier to use when provisioning TLS certificates, if required. | None | +| `EAB_HMAC_KEY` | The Base64-encoded EAB HMAC key to use when provisioning TLS certificates, if required. | None | +| `FORWARD_HEADERS` | Whether to forward X-Forwarded-* headers from the client. | Disabled when running with TLS; enabled otherwise | +| `LOG_REQUESTS` | Log all requests. Set to `0` or `false` to disable request logging | Enabled | +| `DEBUG` | Set to `1` or `true` to enable debug logging. | Disabled | To prevent naming clashes with your application's own environment variables, Thruster's environment variables can optionally be prefixed with `THRUSTER_`. diff --git a/internal/config.go b/internal/config.go index fd2887d..b4500fd 100644 --- a/internal/config.go +++ b/internal/config.go @@ -40,6 +40,9 @@ const ( defaultGzipCompressionDisableOnAuth = false defaultGzipCompressionJitter = 32 + + defaultWaitForTargetPort = false + defaultWaitForTargetPortTimeout = 60 * time.Second ) type Config struct { @@ -70,6 +73,9 @@ type Config struct { H2CEnabled bool + WaitForTargetPort bool + WaitForTargetPortTimeout time.Duration + ForwardHeaders bool LogLevel slog.Level @@ -114,6 +120,9 @@ func NewConfig() (*Config, error) { H2CEnabled: getEnvBool("H2C_ENABLED", defaultH2CEnabled), + WaitForTargetPort: getEnvBool("WAIT_FOR_TARGET_PORT", defaultWaitForTargetPort), + WaitForTargetPortTimeout: getEnvDuration("WAIT_FOR_TARGET_PORT_TIMEOUT", defaultWaitForTargetPortTimeout), + LogLevel: logLevel, LogRequests: getEnvBool("LOG_REQUESTS", defaultLogRequests), } diff --git a/internal/config_test.go b/internal/config_test.go index 5dc1d9b..e578eaf 100644 --- a/internal/config_test.go +++ b/internal/config_test.go @@ -106,6 +106,8 @@ func TestConfig_defaults(t *testing.T) { assert.Equal(t, defaultCacheSize, c.CacheSizeBytes) assert.Equal(t, slog.LevelInfo, c.LogLevel) assert.Equal(t, false, c.H2CEnabled) + assert.Equal(t, false, c.WaitForTargetPort) + assert.Equal(t, 60*time.Second, c.WaitForTargetPortTimeout) } func TestConfig_override_defaults_with_env_vars(t *testing.T) { @@ -121,6 +123,8 @@ func TestConfig_override_defaults_with_env_vars(t *testing.T) { usingEnvVar(t, "H2C_ENABLED", "true") usingEnvVar(t, "GZIP_COMPRESSION_DISABLE_ON_AUTH", "true") usingEnvVar(t, "GZIP_COMPRESSION_JITTER", "64") + usingEnvVar(t, "WAIT_FOR_TARGET_PORT", "true") + usingEnvVar(t, "WAIT_FOR_TARGET_PORT_TIMEOUT", "5") c, err := NewConfig() require.NoError(t, err) @@ -136,6 +140,8 @@ func TestConfig_override_defaults_with_env_vars(t *testing.T) { assert.Equal(t, true, c.H2CEnabled) assert.Equal(t, true, c.GzipCompressionDisableOnAuth) assert.Equal(t, 64, c.GzipCompressionJitter) + assert.Equal(t, true, c.WaitForTargetPort) + assert.Equal(t, 5*time.Second, c.WaitForTargetPortTimeout) } func TestConfig_override_defaults_with_env_vars_using_prefix(t *testing.T) { @@ -147,6 +153,8 @@ func TestConfig_override_defaults_with_env_vars_using_prefix(t *testing.T) { usingEnvVar(t, "THRUSTER_DEBUG", "1") usingEnvVar(t, "THRUSTER_LOG_REQUESTS", "0") usingEnvVar(t, "THRUSTER_H2C_ENABLED", "1") + usingEnvVar(t, "THRUSTER_WAIT_FOR_TARGET_PORT", "1") + usingEnvVar(t, "THRUSTER_WAIT_FOR_TARGET_PORT_TIMEOUT", "10") c, err := NewConfig() require.NoError(t, err) @@ -158,6 +166,8 @@ func TestConfig_override_defaults_with_env_vars_using_prefix(t *testing.T) { assert.Equal(t, slog.LevelDebug, c.LogLevel) assert.Equal(t, false, c.LogRequests) assert.Equal(t, true, c.H2CEnabled) + assert.Equal(t, true, c.WaitForTargetPort) + assert.Equal(t, 10*time.Second, c.WaitForTargetPortTimeout) } func TestConfig_prefixed_variables_take_precedence_over_non_prefixed(t *testing.T) { diff --git a/internal/service.go b/internal/service.go index c34ba73..e5bab0f 100644 --- a/internal/service.go +++ b/internal/service.go @@ -1,23 +1,74 @@ package internal import ( + "context" "fmt" "log/slog" + "net" "net/url" "os" + "os/signal" + "strconv" + "time" ) +type serviceTimeouts struct { + dialTimeout time.Duration + portCheckInterval time.Duration + fastFailureWait time.Duration + gracefulShutdown time.Duration + shutdownEscalation time.Duration + sigkillWait time.Duration + finalReapWait time.Duration +} + type Service struct { - config *Config + config *Config + dial dialer + timeouts serviceTimeouts } +type upstreamResult struct { + exitCode int + err error +} + +// dialer type for injecting net.DialContext in tests +type dialer func(ctx context.Context, network, address string) (net.Conn, error) + func NewService(config *Config) *Service { return &Service{ config: config, + dial: func(ctx context.Context, network, address string) (net.Conn, error) { + var d net.Dialer + return d.DialContext(ctx, network, address) + }, + timeouts: serviceTimeouts{ + dialTimeout: 100 * time.Millisecond, + portCheckInterval: 100 * time.Millisecond, + fastFailureWait: 10 * time.Millisecond, + gracefulShutdown: 10 * time.Second, + shutdownEscalation: 5 * time.Second, + sigkillWait: 1 * time.Second, + finalReapWait: 5 * time.Second, + }, } } func (s *Service) Run() int { + // Initialize the signal channel early so it can be managed cleanly + // and we don't miss signals while booting. + signalChan := make(chan os.Signal, 1) + signal.Notify(signalChan, terminationSignals...) + defer signal.Stop(signalChan) + + if s.config.WaitForTargetPort { + if s.isPortOpen() { + slog.Error("Target port is already in use before starting upstream", "port", s.config.TargetPort) + return 1 + } + } + handlerOptions := HandlerOptions{ cache: s.cache(), targetUrl: s.targetUrl(), @@ -36,23 +87,229 @@ func (s *Service) Run() int { server := NewServer(s.config, handler) upstream := NewUpstreamProcess(s.config.UpstreamCommand, s.config.UpstreamArgs...) + s.setEnvironment() + + resultChan := make(chan upstreamResult, 1) + + // Start the upstream process + go func() { + exitCode, err := upstream.Run() + resultChan <- upstreamResult{exitCode: exitCode, err: err} + }() + + if s.config.WaitForTargetPort { + upstreamRes, sig, err := s.waitForTargetPort(resultChan, signalChan, s.dial) + if err != nil { + slog.Error("Failed waiting for target port", "error", err) + + // Upstream has already exited, no need to signal + if upstreamRes != nil { + slog.Info("Upstream process is already dead.") + return resolveExitCode(*upstreamRes, 1) + } + + // Determine which signal to use and calculate the appropriate fallback code + relaySig := defaultTerminationSignal + fallbackExitCode := exitCodeFromSignal(defaultTerminationSignal) + + if sig != nil { + relaySig = sig + fallbackExitCode = exitCodeFromSignal(sig) + } + + // Upstream is still running but port never opened (or we were interrupted), shut it down + return s.terminateUpstream(upstream, resultChan, relaySig, s.timeouts.shutdownEscalation, fallbackExitCode) + } + slog.Info("Upstream service is bound to port, starting proxy server.") + } else { + // Non-blocking wait to catch fast command failures without penalizing success. + timer := time.NewTimer(s.timeouts.fastFailureWait) + select { + case result := <-resultChan: + stopTimer(timer) + slog.Error("Upstream process exited prematurely", "command", s.config.UpstreamCommand, "exit_code", result.exitCode, "error", result.err) + return resolveExitCode(result, 1) + case <-timer.C: + // Upstream is running (or didn't fail instantly), proceed + } + } + if err := server.Start(); err != nil { - return 1 + slog.Error("Failed to start proxy server", "error", err) + return s.terminateUpstream(upstream, resultChan, defaultTerminationSignal, s.timeouts.shutdownEscalation, exitCodeFromSignal(defaultTerminationSignal)) } defer server.Stop() - s.setEnvironment() + return s.awaitTermination(upstream, resultChan, signalChan) +} - exitCode, err := upstream.Run() - if err != nil { - slog.Error("Failed to start wrapped process", "command", s.config.UpstreamCommand, "args", s.config.UpstreamArgs, "error", err) - return 1 +// Private + +// terminateUpstream encapsulates the escalation policy: signal -> wait -> force kill -> reap +func (s *Service) terminateUpstream(upstream *UpstreamProcess, resultChan <-chan upstreamResult, sig os.Signal, timeout time.Duration, fallbackExitCode int) int { + if upstream == nil { + return fallbackExitCode + } + + // Wait for the upstream process to either start successfully or fail immediately. + // This prevents signaling before the process is fully initialized. + select { + case result := <-resultChan: + slog.Info("Upstream process terminated before signal could be sent.") + return resolveExitCode(result, fallbackExitCode) + case <-upstream.Started(): + // Process is running (or failed to start), proceed + } + + slog.Info("Sending signal to upstream process...", "signal", sig) + if err := upstream.Signal(sig); err != nil { + slog.Error("Failed to send signal to upstream process", "error", err) + } + + if res, ok := waitResult(resultChan, timeout); ok { + slog.Info("Upstream process terminated after signal.") + return resolveExitCode(res, fallbackExitCode) } - return exitCode + slog.Warn("Upstream process did not terminate within timeout, killing it.", "timeout", timeout) + + // Map safely to os.Kill to honor the UpstreamProcess cross-platform encapsulation + if err := upstream.Signal(os.Kill); err != nil { + slog.Error("Failed to send KILL signal to upstream process", "error", err) + } + + if res, ok := waitResult(resultChan, s.timeouts.sigkillWait); ok { + return resolveExitCode(res, fallbackExitCode) + } + + // Ensure we do not orphan the running process, but provide a hard upper bound + slog.Error("Upstream process still running after os.Kill, waiting for OS to reap it...", "timeout", s.timeouts.finalReapWait) + if res, ok := waitResult(resultChan, s.timeouts.finalReapWait); ok { + return resolveExitCode(res, fallbackExitCode) + } + + slog.Error("Upstream process completely unresponsive, exiting wrapper and abandoning child.") + return fallbackExitCode } -// Private +func (s *Service) isPortOpen() bool { + portStr := strconv.Itoa(s.config.TargetPort) + addrs := []string{ + net.JoinHostPort("127.0.0.1", portStr), + net.JoinHostPort("::1", portStr), + } + + for _, addr := range addrs { + dialCtx, cancelDial := context.WithTimeout(context.Background(), s.timeouts.dialTimeout) + conn, err := s.dial(dialCtx, "tcp", addr) + cancelDial() + + if err == nil { + conn.Close() // Port is open + return true + } + } + return false +} + +func (s *Service) waitForTargetPort(resultChan <-chan upstreamResult, signalChan <-chan os.Signal, dial dialer) (*upstreamResult, os.Signal, error) { + ctx, cancel := context.WithTimeout(context.Background(), s.config.WaitForTargetPortTimeout) + defer cancel() + + portStr := strconv.Itoa(s.config.TargetPort) + addrs := []string{ + net.JoinHostPort("127.0.0.1", portStr), + net.JoinHostPort("::1", portStr), + } + slog.Info("Waiting for upstream to bind to port", "addresses", addrs) + + tryDial := func() bool { + for _, addr := range addrs { + // Cap the dial attempt to prevent unnecessary blocks on individual checks + dialCtx, cancelDial := context.WithTimeout(ctx, s.timeouts.dialTimeout) + conn, err := dial(dialCtx, "tcp", addr) + cancelDial() + + if err == nil { + conn.Close() // Port is open + return true + } + } + return false + } + + prematureExitError := func(result upstreamResult) (*upstreamResult, os.Signal, error) { + if result.err != nil { + return &result, nil, fmt.Errorf("upstream process exited prematurely with code %d: %w", result.exitCode, result.err) + } + return &result, nil, fmt.Errorf("upstream process exited prematurely with code %d", result.exitCode) + } + + checkPrematureExit := func() (*upstreamResult, os.Signal, error) { + select { + case result := <-resultChan: + return prematureExitError(result) + default: + return nil, nil, nil + } + } + + // Attempt a fast TCP connection immediately to prevent unnecessary latency on boot + if tryDial() { + return checkPrematureExit() + } + + // Fallback to checking continuously + ticker := time.NewTicker(s.timeouts.portCheckInterval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return nil, nil, fmt.Errorf("timed out after %v waiting for port %d", s.config.WaitForTargetPortTimeout, s.config.TargetPort) + + case sig := <-signalChan: + return nil, sig, fmt.Errorf("received signal %v while waiting for target port", sig) + + case result := <-resultChan: + return prematureExitError(result) + + case <-ticker.C: + if tryDial() { + return checkPrematureExit() + } + } + } +} + +func (s *Service) awaitTermination(upstream *UpstreamProcess, resultChan <-chan upstreamResult, signalChan <-chan os.Signal) int { + handleResult := func(result upstreamResult) int { + slog.Info("Wrapped process exited on its own.", "exit_code", result.exitCode) + if result.err != nil { + slog.Error("Wrapped process failed", "command", s.config.UpstreamCommand, "args", s.config.UpstreamArgs, "error", result.err) + // Return the upstream's exit code if available, fallback to 1 otherwise + return resolveExitCode(result, 1) + } + return result.exitCode + } + + select { + case result := <-resultChan: + return handleResult(result) + + case sig := <-signalChan: + // Prioritize an already-available upstream result to prevent race conditions + select { + case result := <-resultChan: + return handleResult(result) + default: + } + + slog.Info("Received signal, shutting down.", "signal", sig) + fallback := exitCodeFromSignal(sig) + return s.terminateUpstream(upstream, resultChan, sig, s.timeouts.gracefulShutdown, fallback) + } +} func (s *Service) cache() Cache { return NewMemoryCache(s.config.CacheSizeBytes, s.config.MaxCacheItemSizeBytes) @@ -65,5 +322,36 @@ func (s *Service) targetUrl() *url.URL { func (s *Service) setEnvironment() { // Set PORT to be inherited by the upstream process. - os.Setenv("PORT", fmt.Sprintf("%d", s.config.TargetPort)) + os.Setenv("PORT", strconv.Itoa(s.config.TargetPort)) +} + +// stopTimer safely stops a timer and drains its channel if it already fired +func stopTimer(t *time.Timer) { + if !t.Stop() { + select { + case <-t.C: + default: + } + } +} + +// waitResult wraps waiting for a process result with a timer, returning whether a result was received +func waitResult(resultChan <-chan upstreamResult, timeout time.Duration) (upstreamResult, bool) { + timer := time.NewTimer(timeout) + defer stopTimer(timer) + + select { + case res := <-resultChan: + return res, true + case <-timer.C: + return upstreamResult{}, false + } +} + +// resolveExitCode returns the upstream's exit code if non-zero, otherwise the provided fallback +func resolveExitCode(result upstreamResult, fallback int) int { + if result.exitCode != 0 { + return result.exitCode + } + return fallback } diff --git a/internal/service_test.go b/internal/service_test.go new file mode 100644 index 0000000..8155729 --- /dev/null +++ b/internal/service_test.go @@ -0,0 +1,390 @@ +package internal + +import ( + "context" + "errors" + "net" + "os" + "os/signal" + "strconv" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestService_Run_PortAlreadyInUse(t *testing.T) { + service := &Service{ + config: &Config{ + UpstreamCommand: os.Args[0], + UpstreamArgs: []string{"-test.run=TestHelperProcess"}, + TargetPort: 3000, + WaitForTargetPort: true, + }, + dial: func(ctx context.Context, network, address string) (net.Conn, error) { + // Simulate port already being open before upstream starts + client, server := net.Pipe() + server.Close() + return client, nil + }, + timeouts: serviceTimeouts{ + dialTimeout: 10 * time.Millisecond, + }, + } + + exitCode := service.Run() + + // Should fail immediately because the port is already open + assert.Equal(t, 1, exitCode) +} + +func TestService_waitForTargetPort(t *testing.T) { + t.Run("success when port opens", func(t *testing.T) { + service := &Service{ + config: &Config{ + TargetPort: 3000, + WaitForTargetPortTimeout: 2 * time.Second, // Relaxed margins for CI + }, + timeouts: serviceTimeouts{ + dialTimeout: 100 * time.Millisecond, + portCheckInterval: 100 * time.Millisecond, + }, + } + + resultChan := make(chan upstreamResult, 1) + signalChan := make(chan os.Signal, 1) + + dialAttempts := 0 + mockDial := func(ctx context.Context, network, address string) (net.Conn, error) { + dialAttempts++ + // First few iterations fail, then it opens + if dialAttempts >= 3 { + client, server := net.Pipe() + server.Close() // Close the other end to prevent resource leaks + return client, nil + } + return nil, errors.New("connection refused") + } + + _, _, err := service.waitForTargetPort(resultChan, signalChan, mockDial) + + require.NoError(t, err) + assert.GreaterOrEqual(t, dialAttempts, 3) + }) + + t.Run("timeout when port never opens", func(t *testing.T) { + service := &Service{ + config: &Config{ + TargetPort: 3000, + WaitForTargetPortTimeout: 300 * time.Millisecond, // Slightly longer to prevent flakes + }, + timeouts: serviceTimeouts{ + dialTimeout: 50 * time.Millisecond, + portCheckInterval: 50 * time.Millisecond, + }, + } + + resultChan := make(chan upstreamResult, 1) + signalChan := make(chan os.Signal, 1) + + mockDial := func(ctx context.Context, network, address string) (net.Conn, error) { + return nil, errors.New("connection refused") + } + + _, _, err := service.waitForTargetPort(resultChan, signalChan, mockDial) + + require.Error(t, err) + assert.Contains(t, err.Error(), "timed out after") + }) + + t.Run("returns error if upstream exits early", func(t *testing.T) { + service := &Service{ + config: &Config{ + TargetPort: 3000, + WaitForTargetPortTimeout: 2 * time.Second, + }, + timeouts: serviceTimeouts{ + dialTimeout: 100 * time.Millisecond, + portCheckInterval: 100 * time.Millisecond, + }, + } + + resultChan := make(chan upstreamResult, 1) + signalChan := make(chan os.Signal, 1) + + mockDial := func(ctx context.Context, network, address string) (net.Conn, error) { + return nil, errors.New("connection refused") + } + + // Simulate the upstream crashing immediately (e.g. bad command) + resultChan <- upstreamResult{exitCode: 1, err: errors.New("command not found")} + + res, _, err := service.waitForTargetPort(resultChan, signalChan, mockDial) + + require.Error(t, err) + require.NotNil(t, res) + assert.Contains(t, err.Error(), "upstream process exited prematurely") + }) + + t.Run("aborts early if signal is received", func(t *testing.T) { + service := &Service{ + config: &Config{ + TargetPort: 3000, + WaitForTargetPortTimeout: 5 * time.Second, // Long enough to ensure we don't hit normal timeout + }, + timeouts: serviceTimeouts{ + dialTimeout: 100 * time.Millisecond, + portCheckInterval: 100 * time.Millisecond, + }, + } + + resultChan := make(chan upstreamResult, 1) + signalChan := make(chan os.Signal, 1) + + mockDial := func(ctx context.Context, network, address string) (net.Conn, error) { + return nil, errors.New("connection refused") + } + + // Fire a signal shortly after starting to interrupt the wait loop + go func() { + time.Sleep(50 * time.Millisecond) + signalChan <- os.Interrupt + }() + + _, sig, err := service.waitForTargetPort(resultChan, signalChan, mockDial) + + require.Error(t, err) + assert.Equal(t, os.Interrupt, sig) + assert.Contains(t, err.Error(), "received signal interrupt while waiting for target port") + }) +} + +func TestService_Run_WaitForTargetPortFailureCleanup(t *testing.T) { + service := &Service{ + config: &Config{ + UpstreamCommand: os.Args[0], + UpstreamArgs: []string{"-test.run=TestHelperProcess"}, + TargetPort: 3000, + WaitForTargetPort: true, + WaitForTargetPortTimeout: 300 * time.Millisecond, + CacheSizeBytes: 1024, + MaxCacheItemSizeBytes: 1024, + }, + dial: func(ctx context.Context, network, address string) (net.Conn, error) { + return nil, errors.New("connection refused") + }, + timeouts: serviceTimeouts{ + dialTimeout: 100 * time.Millisecond, + portCheckInterval: 100 * time.Millisecond, + fastFailureWait: 100 * time.Millisecond, + gracefulShutdown: 1 * time.Second, + shutdownEscalation: 1 * time.Second, + sigkillWait: 1 * time.Second, + finalReapWait: 1 * time.Second, + }, + } + + // Start helper in an env that expects to catch an interrupt and gracefully die + t.Setenv("GO_WANT_HELPER_PROCESS", "1") + + exitCode := service.Run() + + // The wrapper correctly favors the actual exit code over the fallback. + // On Unix, defaultTerminationSignal is SIGTERM, so the helper catches it and yields 143. + // On Windows, it is os.Kill, which cannot be caught, and the OS kills the process (typically yielding 1). + expectedExitCode := exitCodeFromSignal(defaultTerminationSignal) + assert.Contains(t, []int{1, expectedExitCode}, exitCode) +} + +func TestService_Run_EscalatesToKill(t *testing.T) { + service := &Service{ + config: &Config{ + UpstreamCommand: os.Args[0], + UpstreamArgs: []string{"-test.run=TestHelperProcess"}, + TargetPort: 3000, + WaitForTargetPort: true, + WaitForTargetPortTimeout: 150 * time.Millisecond, + CacheSizeBytes: 1024, + MaxCacheItemSizeBytes: 1024, + }, + dial: func(ctx context.Context, network, address string) (net.Conn, error) { + return nil, errors.New("connection refused") + }, + timeouts: serviceTimeouts{ + dialTimeout: 100 * time.Millisecond, + portCheckInterval: 100 * time.Millisecond, + fastFailureWait: 100 * time.Millisecond, + gracefulShutdown: 500 * time.Millisecond, + shutdownEscalation: 500 * time.Millisecond, + sigkillWait: 500 * time.Millisecond, + finalReapWait: 500 * time.Millisecond, + }, + } + + // Start helper in an env that expects to ignore the first signal + t.Setenv("GO_WANT_HELPER_PROCESS", "1") + t.Setenv("GO_WANT_HELPER_PROCESS_IGNORE_SIGNAL", "1") + + exitCode := service.Run() + + // Since we initiated the shutdown and escalated to KILL, the OS forcefully reaps the child. + // This results in 137 on POSIX, and 1 on Windows. + assert.Contains(t, []int{1, 137}, exitCode) +} + +func TestService_Run_FastFailure_ImmediateExit(t *testing.T) { + t.Setenv("GO_WANT_HELPER_PROCESS", "1") + t.Setenv("GO_WANT_HELPER_PROCESS_EXIT_CODE", "42") + + service := &Service{ + config: &Config{ + UpstreamCommand: os.Args[0], + UpstreamArgs: []string{"-test.run=TestHelperProcess"}, + WaitForTargetPort: false, + HttpPort: 0, // Random OS port + TargetPort: 3000, + CacheSizeBytes: 1024, + MaxCacheItemSizeBytes: 1024, + }, + timeouts: serviceTimeouts{ + fastFailureWait: 100 * time.Millisecond, + }, + } + + exitCode := service.Run() + assert.Equal(t, 42, exitCode) +} + +func TestService_Run_FastFailure_Proceeds(t *testing.T) { + t.Setenv("GO_WANT_HELPER_PROCESS", "1") + // Exits automatically after a delay (post-fastFailureWait) so the test gracefully finishes + t.Setenv("GO_WANT_HELPER_PROCESS_EXIT_AFTER", "500ms") + + service := &Service{ + config: &Config{ + UpstreamCommand: os.Args[0], + UpstreamArgs: []string{"-test.run=TestHelperProcess"}, + WaitForTargetPort: false, + HttpPort: 0, // Random OS port + TargetPort: 3000, + CacheSizeBytes: 1024, + MaxCacheItemSizeBytes: 1024, + }, + timeouts: serviceTimeouts{ + fastFailureWait: 100 * time.Millisecond, + gracefulShutdown: 1 * time.Second, + finalReapWait: 1 * time.Second, + }, + } + + exitCode := service.Run() + + // Since the upstream successfully bypassed the fastFailureWait check, + // the proxy server should have started, reaching awaitTermination(). + // Upon the helper exiting automatically, awaitTermination returns the 0 exit code. + assert.Equal(t, 0, exitCode) +} + +func TestService_awaitTermination_NormalExit(t *testing.T) { + // Provide a dummy config so s.config.UpstreamCommand doesn't panic on error logs + service := &Service{ + config: &Config{ + UpstreamCommand: "dummy", + UpstreamArgs: []string{"-v"}, + }, + timeouts: serviceTimeouts{ + gracefulShutdown: 1 * time.Second, + finalReapWait: 1 * time.Second, + }, + } + resultChan := make(chan upstreamResult, 1) + signalChan := make(chan os.Signal, 1) + + // Simulate successful graceful exit + resultChan <- upstreamResult{exitCode: 0, err: nil} + exitCode := service.awaitTermination(nil, resultChan, signalChan) + assert.Equal(t, 0, exitCode) + + // Simulate error exit + resultChan <- upstreamResult{exitCode: 1, err: errors.New("crash")} + exitCode = service.awaitTermination(nil, resultChan, signalChan) + assert.Equal(t, 1, exitCode) +} + +// TestHelperProcess is used to run a hermetic sub-process that correctly +// handles OS-level semantics in a portable way. +func TestHelperProcess(t *testing.T) { + if os.Getenv("GO_WANT_HELPER_PROCESS") != "1" { + return + } + + if codeStr := os.Getenv("GO_WANT_HELPER_PROCESS_EXIT_CODE"); codeStr != "" { + if code, err := strconv.Atoi(codeStr); err == nil { + os.Exit(code) + } + } + + if delayStr := os.Getenv("GO_WANT_HELPER_PROCESS_EXIT_AFTER"); delayStr != "" { + if delay, err := time.ParseDuration(delayStr); err == nil { + time.Sleep(delay) + os.Exit(0) + } + } + + if os.Getenv("GO_WANT_HELPER_PROCESS_IGNORE_SIGNAL") == "1" { + c := make(chan os.Signal, 1) + signal.Notify(c, terminationSignals...) + defer signal.Stop(c) + <-c + // Ignore the signal and wait to be killed forcefully (but bounded) + time.Sleep(15 * time.Second) + os.Exit(0) + } + + // Block until signal is received + c := make(chan os.Signal, 1) + signal.Notify(c, terminationSignals...) + defer signal.Stop(c) + sig := <-c + + // Exit gracefully using cross-platform signal exit map + os.Exit(exitCodeFromSignal(sig)) +} + +func TestService_awaitTermination_Signal(t *testing.T) { + service := &Service{ + config: &Config{ + UpstreamCommand: os.Args[0], + UpstreamArgs: []string{"-test.run=TestHelperProcess"}, + }, + timeouts: serviceTimeouts{ + gracefulShutdown: 5 * time.Second, + finalReapWait: 1 * time.Second, + }, + } + + // Run the test suite itself as the hermetic target process + upstream := NewUpstreamProcess(os.Args[0], "-test.run=TestHelperProcess") + upstream.cmd.Env = append(os.Environ(), "GO_WANT_HELPER_PROCESS=1") + + resultChan := make(chan upstreamResult, 1) + signalChan := make(chan os.Signal, 1) + + // Start the mocked upstream process + go func() { + exitCode, err := upstream.Run() + resultChan <- upstreamResult{exitCode: exitCode, err: err} + }() + <-upstream.Started() // Ensure it is ready to receive signals + + // Mock injecting an OS signal directly into the injected channel + // We relay unmodified, so sending os.Interrupt will hit the helper which catches it + signalChan <- os.Interrupt + + exitCode := service.awaitTermination(upstream, resultChan, signalChan) + + // Ensure the signal relay returned the correct exit code observed from the upstream + expectedExitCode := exitCodeFromSignal(os.Interrupt) + assert.Equal(t, expectedExitCode, exitCode) +} diff --git a/internal/signals_unix.go b/internal/signals_unix.go new file mode 100644 index 0000000..0a7730b --- /dev/null +++ b/internal/signals_unix.go @@ -0,0 +1,31 @@ +//go:build !windows + +package internal + +import ( + "os" + "syscall" +) + +var terminationSignals = []os.Signal{os.Interrupt, syscall.SIGTERM} +var defaultTerminationSignal os.Signal = syscall.SIGTERM + +func exitCodeFromSignal(sig os.Signal) int { + if sysSig, ok := sig.(syscall.Signal); ok { + return 128 + int(sysSig) + } + return 1 +} + +func exitCodeFromProcessState(state *os.ProcessState) int { + if status, ok := state.Sys().(syscall.WaitStatus); ok { + if status.Signaled() { + return 128 + int(status.Signal()) + } + } + return state.ExitCode() +} + +func forwardSignal(p *os.Process, sig os.Signal) error { + return p.Signal(sig) +} diff --git a/internal/signals_windows.go b/internal/signals_windows.go new file mode 100644 index 0000000..09b9bd5 --- /dev/null +++ b/internal/signals_windows.go @@ -0,0 +1,29 @@ +//go:build windows + +package internal + +import ( + "os" +) + +// Windows does not have SIGTERM natively, and os.Interrupt is unreliable via os.Process.Signal. +var terminationSignals = []os.Signal{os.Interrupt} +var defaultTerminationSignal = os.Kill + +func exitCodeFromSignal(sig os.Signal) int { + return 1 +} + +func exitCodeFromProcessState(state *os.ProcessState) int { + return state.ExitCode() +} + +func forwardSignal(p *os.Process, sig os.Signal) error { + if sig == os.Interrupt || sig == os.Kill { + // os.Interrupt and os.Kill are not reliably supported via os.Process.Signal on Windows unless + // CREATE_NEW_PROCESS_GROUP is used with specific Windows console APIs. + // Escalating to Kill ensures the child process reliably terminates. + return p.Kill() + } + return p.Signal(sig) +} diff --git a/internal/upstream_process.go b/internal/upstream_process.go index 4295f00..80f9bfc 100644 --- a/internal/upstream_process.go +++ b/internal/upstream_process.go @@ -2,65 +2,61 @@ package internal import ( "errors" - "log/slog" "os" "os/exec" - "os/signal" - "syscall" + "sync" ) +var ErrProcessNotRunning = errors.New("process not running") + type UpstreamProcess struct { - Started chan struct{} - cmd *exec.Cmd + started chan struct{} + cmd *exec.Cmd + startOnce sync.Once } func NewUpstreamProcess(name string, arg ...string) *UpstreamProcess { return &UpstreamProcess{ - Started: make(chan struct{}, 1), + started: make(chan struct{}), cmd: exec.Command(name, arg...), } } +func (p *UpstreamProcess) Started() <-chan struct{} { + return p.started +} + func (p *UpstreamProcess) Run() (int, error) { p.cmd.Stdin = os.Stdin p.cmd.Stdout = os.Stdout p.cmd.Stderr = os.Stderr err := p.cmd.Start() + + // Broadcast that the start attempt has concluded (unblocks waiters on success OR failure) + p.startOnce.Do(func() { + close(p.started) + }) + if err != nil { return 0, err } - p.Started <- struct{}{} - - go p.handleSignals() err = p.cmd.Wait() return p.handleExitCode(err) } func (p *UpstreamProcess) Signal(sig os.Signal) error { - return p.cmd.Process.Signal(sig) -} - -func (p *UpstreamProcess) handleSignals() { - ch := make(chan os.Signal, 1) - signal.Notify(ch, syscall.SIGINT, syscall.SIGTERM) - - sig := <-ch - slog.Info("Relaying signal to upstream process", "signal", sig.String()) - _ = p.Signal(sig) + if p.cmd == nil || p.cmd.Process == nil { + return ErrProcessNotRunning + } + return forwardSignal(p.cmd.Process, sig) } func (p *UpstreamProcess) handleExitCode(err error) (int, error) { - var exitErr *exec.ExitError - if errors.As(err, &exitErr) { - if status, ok := exitErr.Sys().(syscall.WaitStatus); ok { - if status.Signaled() { - return 128 + int(status.Signal()), nil - } - } - return exitErr.ExitCode(), nil + if exitErr, ok := err.(*exec.ExitError); ok { + return exitCodeFromProcessState(exitErr.ProcessState), nil } return 0, err diff --git a/internal/upstream_process_test.go b/internal/upstream_process_test.go index 27d1ed1..039b9b0 100644 --- a/internal/upstream_process_test.go +++ b/internal/upstream_process_test.go @@ -1,10 +1,14 @@ package internal import ( - "syscall" + "os" + "os/signal" + "runtime" "testing" + "time" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestUpstreamProcess(t *testing.T) { @@ -16,23 +20,73 @@ func TestUpstreamProcess(t *testing.T) { assert.Equal(t, 1, exitCode) }) - t.Run("signal a process to stop it", func(t *testing.T) { + t.Run("returns error if signaled before running", func(t *testing.T) { + p := NewUpstreamProcess("echo", "hello") + + // Attempt to signal before p.Run() is called + err := p.Signal(os.Interrupt) + + require.Error(t, err) + assert.ErrorIs(t, err, ErrProcessNotRunning) + }) +} + +// TestUpstreamHelperProcess is a hermetic cross-platform target for signal testing +func TestUpstreamHelperProcess(t *testing.T) { + if os.Getenv("GO_WANT_UPSTREAM_HELPER_PROCESS") != "1" { + return + } + + c := make(chan os.Signal, 1) + signal.Notify(c, terminationSignals...) + defer signal.Stop(c) + sig := <-c + + os.Exit(exitCodeFromSignal(sig)) +} + +func TestUpstreamProcess_Signal(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("Windows does not support graceful interrupt signaling natively via os.Process") + } + + t.Run("successfully signals a running process", func(t *testing.T) { var exitCode int - var err error + var runErr error done := make(chan struct{}) - p := NewUpstreamProcess("sleep", "10") + p := NewUpstreamProcess(os.Args[0], "-test.run=TestUpstreamHelperProcess") + p.cmd.Env = append(os.Environ(), "GO_WANT_UPSTREAM_HELPER_PROCESS=1") + + t.Cleanup(func() { + if p.cmd != nil && p.cmd.Process != nil { + _ = p.cmd.Process.Kill() + } + }) go func() { - exitCode, err = p.Run() + exitCode, runErr = p.Run() close(done) }() - <-p.Started - assert.NoError(t, p.Signal(syscall.SIGTERM)) - <-done + select { + case <-p.Started(): + // Process has been spawned + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for upstream to start") + } + err := p.Signal(os.Interrupt) assert.NoError(t, err) - assert.Equal(t, 128+int(syscall.SIGTERM), exitCode) + + select { + case <-done: + // Process exited + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for upstream to exit") + } + + assert.NoError(t, runErr) + assert.Equal(t, exitCodeFromSignal(os.Interrupt), exitCode) }) }