diff --git a/cmd/mint/main.go b/cmd/mint/main.go index 4f9cdd0..67eb812 100644 --- a/cmd/mint/main.go +++ b/cmd/mint/main.go @@ -30,6 +30,10 @@ var ( // language-neutral input (numbers, symbols, etc.). const neutralLang = "neutral" +// posixLocale is the POSIX/C locale name, used as a locale env var value +// that carries no usable language subtag. +const posixLocale = "POSIX" + func main() { os.Exit(run()) } @@ -41,7 +45,15 @@ func run() int { defer stop() if err := newRootCmd().ExecuteContext(ctx); err != nil { - if errors.Is(err, context.Canceled) { + // Compare against context.Cause(ctx), not context.Canceled: net/http + // surfaces context.Cause(ctx), which signal.NotifyContext sets to a + // private signalError rather than context.Canceled. Checking errors.Is + // against the actual cause (instead of just ctx.Err() != nil) also + // makes sure an unrelated error isn't misreported as a clean interrupt + // merely because a signal happened to arrive around the same time. + // (context.Cause(ctx) is nil until ctx is done, and errors.Is(err, nil) + // is always false for a non-nil err, so no separate nil check is needed.) + if errors.Is(err, context.Cause(ctx)) { // Interrupted by the user — exit quietly with the conventional code. return 130 } @@ -502,7 +514,7 @@ func getSystemLanguage() string { lang, _, _ = strings.Cut(lang, ".") // Extract primary language subtag: "en_US" → "en"; ignore "C" / "POSIX" code, _, _ := strings.Cut(lang, "_") - if code == "" || code == "C" || code == "POSIX" { + if code == "" || code == "C" || code == posixLocale { continue } diff --git a/cmd/mint/main_test.go b/cmd/mint/main_test.go index 28e5630..115c42b 100644 --- a/cmd/mint/main_test.go +++ b/cmd/mint/main_test.go @@ -11,7 +11,9 @@ import ( "os" "slices" "strings" + "syscall" "testing" + "time" "github.com/min0625/mint/internal/llm" ) @@ -286,12 +288,12 @@ func TestGetSystemLanguage(t *testing.T) { {"lang without region", "en", "", "en"}, {"C locale skipped, uses LC_ALL", "C", "fr_FR.UTF-8", "fr"}, {"C.UTF-8 locale skipped, uses LC_ALL", "C.UTF-8", "fr_FR.UTF-8", "fr"}, - {"POSIX locale skipped, uses LC_ALL", "POSIX", "de_DE.UTF-8", "de"}, + {"POSIX locale skipped, uses LC_ALL", posixLocale, "de_DE.UTF-8", "de"}, {"LC_ALL used when LANG empty", "", "ja_JP.UTF-8", "ja"}, {"LC_ALL overrides LANG when both set", "en_US.UTF-8", "ja_JP.UTF-8", "ja"}, {"both empty returns empty string", "", "", ""}, {"LC_ALL is C, falls through to LANG", "de_DE.UTF-8", "C", "de"}, - {"LC_ALL is POSIX, falls through to LANG", "ko_KR.UTF-8", "POSIX", "ko"}, + {"LC_ALL is POSIX, falls through to LANG", "ko_KR.UTF-8", posixLocale, "ko"}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -369,9 +371,10 @@ func TestResolveInputStdinReadError(t *testing.T) { } } -// captureStdout replaces os.Stdout with a pipe and returns a function that -// restores os.Stdout and returns whatever was written. -func captureStdout(t *testing.T) func() string { +// captureOutput redirects *target (os.Stdout or os.Stderr) to a pipe and +// returns a function that restores the original file and returns whatever +// was written. +func captureOutput(t *testing.T, target **os.File) func() string { t.Helper() r, w, err := os.Pipe() @@ -379,13 +382,21 @@ func captureStdout(t *testing.T) func() string { t.Fatal(err) } - old := os.Stdout - os.Stdout = w + old := *target + *target = w + + // Safety net: restores target even if the test fails (t.Fatal) before + // calling the returned flush function. + t.Cleanup(func() { + *target = old + _ = w.Close() + _ = r.Close() + }) return func() string { _ = w.Close() - os.Stdout = old + *target = old var sb strings.Builder @@ -396,6 +407,22 @@ func captureStdout(t *testing.T) func() string { } } +// captureStdout replaces os.Stdout with a pipe and returns a function that +// restores os.Stdout and returns whatever was written. +func captureStdout(t *testing.T) func() string { + t.Helper() + + return captureOutput(t, &os.Stdout) +} + +// captureStderr replaces os.Stderr with a pipe and returns a function that +// restores os.Stderr and returns whatever was written. +func captureStderr(t *testing.T) func() string { + t.Helper() + + return captureOutput(t, &os.Stderr) +} + func TestNewRootCmdLangNeutral(t *testing.T) { t.Setenv("MINT_PROVIDER", "openai") t.Setenv("MINT_API_KEY", "test") @@ -651,24 +678,88 @@ func TestRunError(t *testing.T) { defer func() { os.Args = old }() // Suppress stderr to keep test output clean. - rr, ww, err := os.Pipe() - if err != nil { - t.Fatal(err) + flushErr := captureStderr(t) + + if code := run(); code != 1 { + t.Errorf("expected exit code 1, got %d", code) + } + + _ = flushErr() +} + +// TestRunInterrupted covers both signals run() registers via +// signal.NotifyContext (os.Interrupt and syscall.SIGTERM): either one must +// cancel an in-flight request and exit quietly with code 130. +func TestRunInterrupted(t *testing.T) { + signals := []struct { + name string + sig syscall.Signal + }{ + {"SIGINT", syscall.SIGINT}, + {"SIGTERM", syscall.SIGTERM}, } - oldErr := os.Stderr + for _, tt := range signals { + t.Run(tt.name, func(t *testing.T) { + started := make(chan struct{}) + done := make(chan struct{}) - os.Stderr = ww - defer func() { - _ = ww.Close() + srv := httptest.NewServer(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { + close(started) + <-done // held open until the test releases it, regardless of client-side cancellation + })) - os.Stderr = oldErr - _, _ = io.Copy(io.Discard, rr) - _ = rr.Close() - }() + defer func() { + close(done) + srv.Close() + }() - if code := run(); code != 1 { - t.Errorf("expected exit code 1, got %d", code) + t.Setenv("MINT_PROVIDER", "openai") + t.Setenv("MINT_API_KEY", "test") + t.Setenv("MINT_BASE_URL", srv.URL) + t.Setenv("MINT_MODEL_NAME", "test-model") + + old := os.Args + + os.Args = []string{"mint", "--target", "en", "hello"} + defer func() { os.Args = old }() + + flushOut := captureStdout(t) + flushErr := captureStderr(t) + + codeCh := make(chan int, 1) + + go func() { codeCh <- run() }() + + select { + case <-started: + case <-time.After(5 * time.Second): + t.Fatal("request never reached the server") + } + + if err := syscall.Kill(os.Getpid(), tt.sig); err != nil { + t.Fatalf("failed to send %s: %v", tt.name, err) + } + + var code int + + select { + case code = <-codeCh: + case <-time.After(5 * time.Second): + t.Fatalf("run() did not return after %s", tt.name) + } + + stderrOutput := flushErr() + _ = flushOut() + + if code != 130 { + t.Errorf("expected exit code 130, got %d", code) + } + + if stderrOutput != "" { + t.Errorf("expected no stderr output on interrupt, got: %q", stderrOutput) + } + }) } }