diff --git a/packages/cmd/login_status.go b/packages/cmd/login_status.go new file mode 100644 index 00000000..62bda7ac --- /dev/null +++ b/packages/cmd/login_status.go @@ -0,0 +1,629 @@ +package cmd + +import ( + "encoding/json" + "errors" + "fmt" + "net/http" + "os" + "strings" + "time" + + "github.com/Infisical/infisical-merge/packages/api" + "github.com/Infisical/infisical-merge/packages/config" + "github.com/Infisical/infisical-merge/packages/util" + "github.com/fatih/color" + jwt "github.com/golang-jwt/jwt/v5" + "github.com/rs/zerolog/log" + "github.com/spf13/cobra" +) + +const ( + statusAuthenticated = "authenticated" + statusExpired = "expired" + statusRejected = "rejected" + + principalKindUser = "user" + principalKindMachineIdentity = "machine-identity" + principalKindServiceToken = "service-token" + + verifyStateVerified = "verified" + verifyStateRejected = "rejected" + verifyStateUnknown = "unknown" + verifyStateSkipped = "skipped" + + tokenSourceLoginSession = "infisical login (keyring)" + + verifyTimeout = 10 * time.Second + + authTokenTypeAccess = "accessToken" + authTokenTypeIdentityAccess = "identityAccessToken" +) + +var ( + boldStyle = color.New(color.Bold) + greenStyle = color.New(color.FgGreen, color.Bold) + redStyle = color.New(color.FgRed, color.Bold) +) + +var loginStatusCmd = &cobra.Command{ + Use: "status", + Short: "View the current authentication status", + Long: "Reports whether the CLI is authenticated to Infisical and, when available, the organization the active session is scoped to.", + DisableFlagsInUseLine: true, + Example: "infisical login status", + Args: cobra.NoArgs, + Run: runLoginStatus, +} + +func runLoginStatus(cmd *cobra.Command, args []string) { + jsonOutput, _ := cmd.Flags().GetBool("json") + + // Machine identity / service token domain comes from the env/flag-driven + // config.INFISICAL_URL. The user-session domain is whatever the user's + // saved config points at, which GetCurrentLoggedInUserDetails(true) writes + // back into config.INFISICAL_URL — so capture the env value first. + envDomain := strings.TrimSuffix(config.INFISICAL_URL, "/api") + + flagToken, _ := cmd.Flags().GetString("token") + flagToken = strings.TrimSpace(flagToken) + if flagToken != "" { + if !cmd.Flags().Changed("domain") { + if _, envSet := os.LookupEnv("INFISICAL_API_URL"); !envSet { + util.PrintErrorMessageAndExit("--token requires --domain (or INFISICAL_API_URL) to be set so the status reflects the correct Infisical instance") + } + } + ctx, err := buildContextFromToken(flagToken, "--token flag", envDomain) + if err != nil { + util.PrintErrorMessageAndExit(err.Error()) + } + ctx.verification = verifySession(ctx) + emitLoginStatus([]loginStatusContext{ctx}, jsonOutput) + if shouldExitWithError(ctx) { + os.Exit(1) + } + return + } + + var sessions []loginStatusContext + + if token, source, ok := detectMachineIdentityEnvToken(); ok { + ctx, err := buildContextFromToken(token, source, envDomain) + if err != nil { + util.PrintErrorMessageAndExit(err.Error()) + } + sessions = append(sessions, ctx) + } + + loggedInUserDetails, err := util.GetCurrentLoggedInUserDetails(true) + if err != nil && !errors.Is(err, util.ErrUserNotLoggedIn) { + util.HandleError(err, "Unable to read logged-in user details") + } + if loggedInUserDetails.IsUserLoggedIn { + userDomain := strings.TrimSuffix(config.INFISICAL_URL, "/api") + sessions = append(sessions, buildUserContext(loggedInUserDetails, userDomain)) + } + + if len(sessions) == 0 { + renderNotAuthenticated(jsonOutput) + os.Exit(1) + } + + for i := range sessions { + sessions[i].verification = verifySession(sessions[i]) + } + + emitLoginStatus(sessions, jsonOutput) + + for _, s := range sessions { + if shouldExitWithError(s) { + os.Exit(1) + } + } +} + +func buildUserContext(details util.LoggedInUserDetails, domain string) loginStatusContext { + claims, claimsErr := parseLoginJWTClaims(details.UserCredentials.JTWToken) + return loginStatusContext{ + kind: principalKindUser, + domain: domain, + loggedInUser: details, + rawToken: details.UserCredentials.JTWToken, + claims: claims, + claimsErr: claimsErr, + } +} + +// classifyToken determines whether a raw credential is a service token, a user +// session JWT, or a machine identity access token JWT. Service tokens are +// recognized by their "st." prefix; JWTs are dispatched on the authTokenType +// claim the backend stamps into every token it signs. For very old JWTs that +// pre-date that claim, falls back to looking for identityId / userId so we +// preserve back-compat without misclassifying a user JWT as a machine identity. +func classifyToken(token string) (string, loginTokenClaims, error) { + if strings.HasPrefix(token, "st.") { + return principalKindServiceToken, loginTokenClaims{}, nil + } + + claims, err := parseLoginJWTClaims(token) + if err != nil { + return "", claims, err + } + + switch claims.AuthTokenType { + case authTokenTypeIdentityAccess: + return principalKindMachineIdentity, claims, nil + case authTokenTypeAccess: + return principalKindUser, claims, nil + case "": + // Legacy tokens issued before authTokenType existed. + if claims.IdentityID != "" { + return principalKindMachineIdentity, claims, nil + } + if claims.UserID != "" { + return principalKindUser, claims, nil + } + return principalKindMachineIdentity, claims, nil + default: + return "", claims, fmt.Errorf("unsupported token type %q (CLI only accepts user access tokens and machine identity access tokens)", claims.AuthTokenType) + } +} + +// buildContextFromToken constructs a status context for any externally-supplied +// token (--token flag or environment variable). The principal kind is derived +// from the token itself rather than from where it came from. +func buildContextFromToken(token, source, domain string) (loginStatusContext, error) { + kind, claims, classifyErr := classifyToken(token) + if classifyErr != nil && kind == "" { + return loginStatusContext{}, classifyErr + } + + ctx := loginStatusContext{ + kind: kind, + domain: domain, + rawToken: token, + tokenSource: source, + } + if kind != principalKindServiceToken { + ctx.claims = claims + ctx.claimsErr = classifyErr + } + return ctx, nil +} + +func isContextExpired(ctx loginStatusContext) bool { + if ctx.kind == principalKindUser && ctx.loggedInUser.LoginExpired { + return true + } + if ctx.kind == principalKindServiceToken { + return false + } + return ctx.claimsErr == nil && ctx.claims.ExpiresAt != nil && !ctx.claims.ExpiresAt.After(time.Now()) +} + +func contextStatus(ctx loginStatusContext) string { + if isContextExpired(ctx) { + return statusExpired + } + if ctx.verification.state == verifyStateRejected { + return statusRejected + } + return statusAuthenticated +} + +func shouldExitWithError(ctx loginStatusContext) bool { + s := contextStatus(ctx) + return s == statusExpired || s == statusRejected +} + +func emitLoginStatus(sessions []loginStatusContext, jsonOutput bool) { + if jsonOutput { + if err := writeLoginStatusJSON(buildJSONOutput(sessions)); err != nil { + util.HandleError(err, "Unable to encode JSON output") + } + return + } + for i, s := range sessions { + if i > 0 { + util.PrintlnStdout("") + } + renderHuman(s) + } +} + +func detectMachineIdentityEnvToken() (token, source string, ok bool) { + if v := strings.TrimSpace(os.Getenv(util.INFISICAL_TOKEN_NAME)); v != "" { + return v, fmt.Sprintf("%s environment variable", util.INFISICAL_TOKEN_NAME), true + } + return "", "", false +} + +type loginStatusContext struct { + kind string + domain string + loggedInUser util.LoggedInUserDetails // populated when kind == principalKindUser + rawToken string // bearer credential used for backend verification + tokenSource string // populated for machine-identity / service-token + claims loginTokenClaims + claimsErr error + verification verificationResult +} + +type verificationResult struct { + state string + reason string +} + +type loginStatusJSONOutput struct { + Sessions []loginStatusSessionJSON `json:"sessions"` +} + +type loginStatusSessionJSON struct { + PrincipalType string `json:"principalType,omitempty"` + Status string `json:"status,omitempty"` + Domain string `json:"domain,omitempty"` + Email string `json:"email,omitempty"` + UserID string `json:"userId,omitempty"` + AuthMethod string `json:"authMethod,omitempty"` + TokenSource string `json:"tokenSource,omitempty"` + Identity *loginStatusIdentityJSON `json:"identity,omitempty"` + Token *loginStatusTokenJSON `json:"token,omitempty"` + Organization *string `json:"organization,omitempty"` + SubOrganization *string `json:"subOrganization,omitempty"` + Verification *loginStatusVerificationJSON `json:"verification,omitempty"` +} + +type loginStatusTokenJSON struct { + Exp int64 `json:"exp,omitempty"` +} + +type loginStatusIdentityJSON struct { + ID string `json:"id,omitempty"` + Name string `json:"name,omitempty"` +} + +type loginStatusVerificationJSON struct { + State string `json:"state"` + Reason string `json:"reason,omitempty"` +} + +func buildJSONOutput(sessions []loginStatusContext) loginStatusJSONOutput { + out := loginStatusJSONOutput{Sessions: make([]loginStatusSessionJSON, 0, len(sessions))} + for _, ctx := range sessions { + out.Sessions = append(out.Sessions, buildSessionJSON(ctx)) + } + return out +} + +func buildSessionJSON(ctx loginStatusContext) loginStatusSessionJSON { + session := loginStatusSessionJSON{ + PrincipalType: ctx.kind, + Domain: ctx.domain, + AuthMethod: authMethodLabel(ctx), + TokenSource: tokenSourceLabel(ctx), + Status: contextStatus(ctx), + Verification: verificationJSON(ctx.verification), + } + + switch ctx.kind { + case principalKindUser: + session.Email = ctx.loggedInUser.UserCredentials.Email + if ctx.claimsErr != nil { + return session + } + if ctx.claims.UserID != "" { + session.UserID = ctx.claims.UserID + } + if ctx.claims.ExpiresAt != nil { + session.Token = &loginStatusTokenJSON{Exp: ctx.claims.ExpiresAt.Unix()} + } + if ctx.claims.OrganizationID != "" { + session.Organization = &ctx.claims.OrganizationID + } + if ctx.claims.SubOrganizationID != "" { + session.SubOrganization = &ctx.claims.SubOrganizationID + } + + case principalKindMachineIdentity: + if ctx.claimsErr != nil { + return session + } + if ctx.claims.IdentityID != "" || ctx.claims.IdentityName != "" { + session.Identity = &loginStatusIdentityJSON{ + ID: ctx.claims.IdentityID, + Name: ctx.claims.IdentityName, + } + } + if ctx.claims.ExpiresAt != nil { + session.Token = &loginStatusTokenJSON{Exp: ctx.claims.ExpiresAt.Unix()} + } + if ctx.claims.OrgID != "" { + session.Organization = &ctx.claims.OrgID + } + } + + return session +} + +func verificationJSON(v verificationResult) *loginStatusVerificationJSON { + if v.state == "" { + return nil + } + return &loginStatusVerificationJSON{State: v.state, Reason: v.reason} +} + +func writeLoginStatusJSON(out loginStatusJSONOutput) error { + data, err := json.MarshalIndent(out, "", " ") + if err != nil { + return err + } + util.PrintlnStdout(string(data)) + return nil +} + +func renderHuman(ctx loginStatusContext) { + label := principalLabel(ctx) + status := contextStatus(ctx) + + if status == statusAuthenticated { + util.PrintfStdout("%s Authenticated as %s\n", greenStyle.Sprint("✓"), boldStyle.Sprint(label)) + } else { + util.PrintfStdout("%s Failed to authenticate as %s\n", redStyle.Sprint("x"), boldStyle.Sprint(label)) + } + + if status != statusAuthenticated { + if status == statusExpired { + printStatusItem("Reason", "session expired") + } + if line := verificationLine(ctx.verification); line != "" { + printStatusItem("Reason", line) + } + } + + if ctx.domain != "" { + printStatusItem("Domain", ctx.domain) + } + if method := authMethodLabel(ctx); method != "" { + printStatusItem("Auth method", method) + } + if ctx.kind == principalKindMachineIdentity && ctx.claimsErr == nil && ctx.claims.IdentityID != "" { + printStatusItem("Identity", ctx.claims.IdentityID) + } + if source := tokenSourceLabel(ctx); source != "" { + printStatusItem("Token source", source) + } + if ctx.kind != principalKindServiceToken { + printStatusItem("Token expiration", tokenStatusLine(ctx.claims, ctx.claimsErr)) + } + if ctx.kind == principalKindUser && ctx.claimsErr == nil && ctx.claims.UserID != "" { + printStatusItem("User ID", ctx.claims.UserID) + } + if org := organizationLineFor(ctx); org != "" { + printStatusItem("Organization", org) + } + if ctx.kind == principalKindUser && ctx.claimsErr == nil && ctx.claims.SubOrganizationID != "" { + printStatusItem("Sub-organization", ctx.claims.SubOrganizationID) + } + + if status != statusAuthenticated { + switch ctx.kind { + case principalKindUser: + util.PrintlnStdout(" - Run `infisical login` to re-authenticate.") + case principalKindMachineIdentity: + util.PrintlnStdout(" - Verify the domain being used or run `infisical login` to re-authenticate and re-export your token.") + case principalKindServiceToken: + util.PrintlnStdout(" - Verify the service token has not been revoked or expired in your Infisical instance.") + } + } +} + +func verificationLine(v verificationResult) string { + labels := map[string]string{ + verifyStateVerified: "verified", + verifyStateRejected: "rejected", + verifyStateUnknown: "unreachable", + verifyStateSkipped: "skipped", + } + label, ok := labels[v.state] + if !ok { + return "" + } + if v.reason != "" && v.state != verifyStateVerified { + return fmt.Sprintf("%s (%s)", label, v.reason) + } + return label +} + +func principalLabel(ctx loginStatusContext) string { + switch ctx.kind { + case principalKindUser: + if email := ctx.loggedInUser.UserCredentials.Email; email != "" { + return email + } + return "user" + case principalKindMachineIdentity: + if ctx.claimsErr == nil && ctx.claims.IdentityName != "" { + return ctx.claims.IdentityName + } + return "machine identity" + case principalKindServiceToken: + return "service token" + } + return "" +} + +func authMethodLabel(ctx loginStatusContext) string { + if ctx.claimsErr == nil { + return ctx.claims.AuthMethod + } + return "unknown" +} + +func tokenSourceLabel(ctx loginStatusContext) string { + if ctx.tokenSource != "" { + return ctx.tokenSource + } + if ctx.kind == principalKindUser { + return tokenSourceLoginSession + } + return "" +} + +func organizationLineFor(ctx loginStatusContext) string { + switch ctx.kind { + case principalKindUser: + return orgStatusLine(ctx.claims.OrganizationID, ctx.claimsErr) + case principalKindMachineIdentity: + return orgStatusLine(ctx.claims.OrgID, ctx.claimsErr) + } + return "" +} + +func tokenStatusLine(claims loginTokenClaims, claimsErr error) string { + if claimsErr != nil { + return "unknown (could not parse token)" + } + if claims.ExpiresAt == nil { + return "no expiration set" + } + return formatExpiry(claims.ExpiresAt.Time) +} + +func orgStatusLine(orgID string, claimsErr error) string { + if claimsErr != nil { + log.Debug().Err(claimsErr).Msg("login status: unable to decode token payload") + return "unknown (could not parse token)" + } + if orgID == "" { + return "none (token is not scoped to an organization)" + } + return orgID +} + +func printStatusItem(key, value string) { + util.PrintfStdout(" - %s: %s\n", key, boldStyle.Sprint(value)) +} + +type loginTokenClaims struct { + // Token kind discriminator stamped by the backend on every JWT it issues. + AuthTokenType string `json:"authTokenType"` + + // User session JWT claims + UserID string `json:"userId"` + OrganizationID string `json:"organizationId"` + SubOrganizationID string `json:"subOrganizationId"` + + // Machine identity access token JWT claims + IdentityID string `json:"identityId"` + IdentityName string `json:"identityName"` + AuthMethod string `json:"authMethod"` + OrgID string `json:"orgId"` + + jwt.RegisteredClaims +} + +func parseLoginJWTClaims(token string) (loginTokenClaims, error) { + var claims loginTokenClaims + parser := jwt.NewParser() + if _, _, err := parser.ParseUnverified(token, &claims); err != nil { + return loginTokenClaims{}, err + } + return claims, nil +} + +func formatExpiry(expiresAt time.Time) string { + remaining := time.Until(expiresAt) + if remaining <= 0 { + return "expired" + } + hours := int(remaining.Hours()) + if hours >= 24 { + days := hours / 24 + return fmt.Sprintf("%dd %dh", days, hours%24) + } + if hours > 0 { + return fmt.Sprintf("%dh %dm", hours, int(remaining.Minutes())%60) + } + return fmt.Sprintf("%dm", int(remaining.Minutes())) +} + +func renderNotAuthenticated(jsonOutput bool) { + if jsonOutput { + if err := writeLoginStatusJSON(loginStatusJSONOutput{Sessions: []loginStatusSessionJSON{}}); err != nil { + util.HandleError(err, "Unable to encode JSON output") + } + return + } + util.PrintfStdout("%s You are not authenticated.\nRun `infisical login` to log in.\n", redStyle.Sprint("x")) +} + +// verifySession asks the backend whether the credential associated with the +// context is still valid. Local-only signals (missing token, +// already-expired-by-clock) short-circuit the network call. +func verifySession(ctx loginStatusContext) verificationResult { + if ctx.rawToken == "" { + return verificationResult{state: verifyStateSkipped, reason: "no token available"} + } + if isContextExpired(ctx) { + return verificationResult{state: verifyStateSkipped, reason: "locally expired"} + } + switch ctx.kind { + case principalKindServiceToken: + return performVerification(ctx.rawToken, ctx.domain, "/api/v2/service-token", http.MethodGet) + case principalKindMachineIdentity: + return performVerification(ctx.rawToken, ctx.domain, "/api/v1/identities/details", http.MethodGet) + } + return performVerification(ctx.rawToken, ctx.domain, "/api/v1/auth/checkAuth", http.MethodPost) +} + +func performVerification(token, domain, path, method string) verificationResult { + httpClient, err := util.GetRestyClientWithCustomHeaders() + if err != nil { + return verificationResult{state: verifyStateUnknown, reason: err.Error()} + } + httpClient. + SetAuthToken(token). + SetHeader("Accept", "application/json"). + SetTimeout(verifyTimeout) + + url := strings.TrimRight(domain, "/") + path + req := httpClient.R().SetHeader("User-Agent", api.USER_AGENT) + + var ( + statusCode int + callErr error + ) + switch method { + case http.MethodGet: + resp, e := req.Get(url) + callErr = e + if resp != nil { + statusCode = resp.StatusCode() + } + default: + resp, e := req.Post(url) + callErr = e + if resp != nil { + statusCode = resp.StatusCode() + } + } + + if callErr != nil { + log.Debug().Err(callErr).Str("url", url).Msg("login status: backend verification call failed") + return verificationResult{state: verifyStateUnknown, reason: "network error"} + } + switch { + case statusCode >= 200 && statusCode < 300: + return verificationResult{state: verifyStateVerified} + case statusCode == http.StatusUnauthorized, statusCode == http.StatusForbidden: + return verificationResult{state: verifyStateRejected, reason: fmt.Sprintf("HTTP %d", statusCode)} + default: + return verificationResult{state: verifyStateUnknown, reason: fmt.Sprintf("HTTP %d", statusCode)} + } +} + +func init() { + loginStatusCmd.Flags().Bool("json", false, "Output the login status as JSON") + loginStatusCmd.Flags().String("token", "", "Inspect this machine identity access token instead of the active session or environment variables") + loginCmd.AddCommand(loginStatusCmd) +} diff --git a/packages/cmd/login_status_test.go b/packages/cmd/login_status_test.go new file mode 100644 index 00000000..d5e3dc9d --- /dev/null +++ b/packages/cmd/login_status_test.go @@ -0,0 +1,470 @@ +package cmd + +import ( + "encoding/base64" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "sync/atomic" + "testing" + "time" + + "github.com/Infisical/infisical-merge/packages/util" + jwt "github.com/golang-jwt/jwt/v5" +) + +func TestFormatExpiry(t *testing.T) { + now := time.Now() + + cases := []struct { + name string + expiresAt time.Time + wantPrefix string + wantExact string + }{ + { + name: "already expired", + expiresAt: now.Add(-1 * time.Minute), + wantExact: "expired", + }, + { + name: "exactly at now is expired", + expiresAt: now, + wantExact: "expired", + }, + { + name: "minutes only", + expiresAt: now.Add(15*time.Minute + 30*time.Second), + wantPrefix: "15m", + }, + { + name: "hours and minutes", + expiresAt: now.Add(5*time.Hour + 30*time.Minute), + wantPrefix: "5h ", + }, + { + name: "days and hours", + expiresAt: now.Add(50*time.Hour + 30*time.Second), + wantPrefix: "2d ", + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + got := formatExpiry(tc.expiresAt) + if tc.wantExact != "" { + if got != tc.wantExact { + t.Errorf("formatExpiry(%v) = %q, want %q", tc.expiresAt, got, tc.wantExact) + } + return + } + if !strings.HasPrefix(got, tc.wantPrefix) { + t.Errorf("formatExpiry(%v) = %q, want prefix %q", tc.expiresAt, got, tc.wantPrefix) + } + }) + } +} + +func TestParseLoginJWTClaims(t *testing.T) { + t.Run("happy path with org and sub-org", func(t *testing.T) { + exp := time.Now().Add(time.Hour).Unix() + token := makeUnsignedJWT(t, map[string]any{ + "organizationId": "org-1", + "subOrganizationId": "sub-1", + "exp": exp, + }) + + claims, err := parseLoginJWTClaims(token) + if err != nil { + t.Fatalf("parseLoginJWTClaims: unexpected error: %v", err) + } + if claims.OrganizationID != "org-1" { + t.Errorf("OrganizationID = %q, want %q", claims.OrganizationID, "org-1") + } + if claims.SubOrganizationID != "sub-1" { + t.Errorf("SubOrganizationID = %q, want %q", claims.SubOrganizationID, "sub-1") + } + if claims.ExpiresAt == nil || claims.ExpiresAt.Unix() != exp { + t.Errorf("ExpiresAt = %v, want unix %d", claims.ExpiresAt, exp) + } + }) + + t.Run("machine identity claims parse", func(t *testing.T) { + exp := time.Now().Add(time.Hour).Unix() + token := makeUnsignedJWT(t, map[string]any{ + "identityId": "id-123", + "identityName": "my-ci-bot", + "authMethod": "universal-auth", + "orgId": "org-1", + "exp": exp, + }) + + claims, err := parseLoginJWTClaims(token) + if err != nil { + t.Fatalf("parseLoginJWTClaims: unexpected error: %v", err) + } + if claims.IdentityID != "id-123" { + t.Errorf("IdentityID = %q, want %q", claims.IdentityID, "id-123") + } + if claims.IdentityName != "my-ci-bot" { + t.Errorf("IdentityName = %q, want %q", claims.IdentityName, "my-ci-bot") + } + if claims.AuthMethod != "universal-auth" { + t.Errorf("AuthMethod = %q, want %q", claims.AuthMethod, "universal-auth") + } + if claims.OrgID != "org-1" { + t.Errorf("OrgID = %q, want %q", claims.OrgID, "org-1") + } + }) + + t.Run("token without organization claims still parses", func(t *testing.T) { + token := makeUnsignedJWT(t, map[string]any{ + "exp": time.Now().Add(time.Hour).Unix(), + }) + claims, err := parseLoginJWTClaims(token) + if err != nil { + t.Fatalf("parseLoginJWTClaims: unexpected error: %v", err) + } + if claims.OrganizationID != "" { + t.Errorf("OrganizationID = %q, want empty", claims.OrganizationID) + } + if claims.SubOrganizationID != "" { + t.Errorf("SubOrganizationID = %q, want empty", claims.SubOrganizationID) + } + }) + + t.Run("malformed token returns error", func(t *testing.T) { + if _, err := parseLoginJWTClaims("not-a-jwt"); err == nil { + t.Errorf("parseLoginJWTClaims(%q) = nil error, want error", "not-a-jwt") + } + }) + + t.Run("non-base64 payload returns error", func(t *testing.T) { + if _, err := parseLoginJWTClaims("aaa.!!!.ccc"); err == nil { + t.Errorf("parseLoginJWTClaims with bad payload = nil error, want error") + } + }) +} + +func TestDetectMachineIdentityEnvToken(t *testing.T) { + t.Run("no env var set", func(t *testing.T) { + t.Setenv(util.INFISICAL_TOKEN_NAME, "") + + if _, _, ok := detectMachineIdentityEnvToken(); ok { + t.Errorf("detectMachineIdentityEnvToken() = ok, want !ok when no env var set") + } + }) + + t.Run("returns INFISICAL_TOKEN when set", func(t *testing.T) { + t.Setenv(util.INFISICAL_TOKEN_NAME, "st.abc.def") + + token, source, ok := detectMachineIdentityEnvToken() + if !ok { + t.Fatalf("detectMachineIdentityEnvToken() = !ok, want ok") + } + if token != "st.abc.def" { + t.Errorf("token = %q, want %q", token, "st.abc.def") + } + if !strings.Contains(source, util.INFISICAL_TOKEN_NAME) { + t.Errorf("source = %q, want it to contain %q", source, util.INFISICAL_TOKEN_NAME) + } + }) + + t.Run("whitespace-only env value is ignored", func(t *testing.T) { + t.Setenv(util.INFISICAL_TOKEN_NAME, " ") + + if _, _, ok := detectMachineIdentityEnvToken(); ok { + t.Errorf("detectMachineIdentityEnvToken() = ok for whitespace-only value, want !ok") + } + }) +} + +func TestContextStatus(t *testing.T) { + cases := []struct { + name string + ctx loginStatusContext + want string + }{ + { + name: "user not expired, no verification", + ctx: loginStatusContext{kind: principalKindUser}, + want: statusAuthenticated, + }, + { + name: "locally expired user trumps backend", + ctx: loginStatusContext{ + kind: principalKindUser, + loggedInUser: util.LoggedInUserDetails{LoginExpired: true}, + verification: verificationResult{state: verifyStateVerified}, + }, + want: statusExpired, + }, + { + name: "machine identity locally expired", + ctx: loginStatusContext{ + kind: principalKindMachineIdentity, + claims: loginTokenClaims{ + RegisteredClaims: jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(time.Now().Add(-1 * time.Minute)), + }, + }, + }, + want: statusExpired, + }, + { + name: "backend rejected downgrades to rejected", + ctx: loginStatusContext{ + kind: principalKindUser, + verification: verificationResult{state: verifyStateRejected}, + }, + want: statusRejected, + }, + { + name: "unknown verification stays authenticated", + ctx: loginStatusContext{ + kind: principalKindUser, + verification: verificationResult{state: verifyStateUnknown}, + }, + want: statusAuthenticated, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + if got := contextStatus(tc.ctx); got != tc.want { + t.Errorf("contextStatus = %q, want %q", got, tc.want) + } + }) + } +} + +func TestVerifySession_NoToken(t *testing.T) { + got := verifySession(loginStatusContext{kind: principalKindUser}) + if got.state != verifyStateSkipped { + t.Errorf("verifySession no-token = %q, want %q", got.state, verifyStateSkipped) + } +} + +func TestVerifySession_LocallyExpiredSkipsCall(t *testing.T) { + var hit int32 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt32(&hit, 1) + w.WriteHeader(http.StatusOK) + })) + t.Cleanup(server.Close) + + ctx := loginStatusContext{ + kind: principalKindUser, + rawToken: "tok", + domain: server.URL, + loggedInUser: util.LoggedInUserDetails{LoginExpired: true}, + } + got := verifySession(ctx) + if got.state != verifyStateSkipped { + t.Errorf("verifySession locally-expired = %q, want %q", got.state, verifyStateSkipped) + } + if atomic.LoadInt32(&hit) != 0 { + t.Errorf("verifySession hit server %d times for locally-expired ctx, want 0", hit) + } +} + +func TestPerformVerification(t *testing.T) { + cases := []struct { + name string + statusCode int + path string + method string + want string + }{ + {name: "user 200 verified", statusCode: http.StatusOK, path: "/api/v1/auth/checkAuth", method: http.MethodPost, want: verifyStateVerified}, + {name: "user 401 rejected", statusCode: http.StatusUnauthorized, path: "/api/v1/auth/checkAuth", method: http.MethodPost, want: verifyStateRejected}, + {name: "user 403 rejected", statusCode: http.StatusForbidden, path: "/api/v1/auth/checkAuth", method: http.MethodPost, want: verifyStateRejected}, + {name: "user 500 unknown", statusCode: http.StatusInternalServerError, path: "/api/v1/auth/checkAuth", method: http.MethodPost, want: verifyStateUnknown}, + {name: "machine identity 200 verified", statusCode: http.StatusOK, path: "/api/v1/identities/details", method: http.MethodGet, want: verifyStateVerified}, + {name: "machine identity 401 rejected", statusCode: http.StatusUnauthorized, path: "/api/v1/identities/details", method: http.MethodGet, want: verifyStateRejected}, + {name: "machine identity 403 rejected", statusCode: http.StatusForbidden, path: "/api/v1/identities/details", method: http.MethodGet, want: verifyStateRejected}, + {name: "service token 200 verified", statusCode: http.StatusOK, path: "/api/v2/service-token", method: http.MethodGet, want: verifyStateVerified}, + {name: "service token 401 rejected", statusCode: http.StatusUnauthorized, path: "/api/v2/service-token", method: http.MethodGet, want: verifyStateRejected}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + var gotPath, gotMethod, gotAuth string + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotPath = r.URL.Path + gotMethod = r.Method + gotAuth = r.Header.Get("Authorization") + w.WriteHeader(tc.statusCode) + })) + t.Cleanup(server.Close) + + got := performVerification("tok-abc", server.URL, tc.path, tc.method) + if got.state != tc.want { + t.Errorf("performVerification state = %q, want %q (detail=%q)", got.state, tc.want, got.reason) + } + if gotPath != tc.path { + t.Errorf("server saw path %q, want %q", gotPath, tc.path) + } + if gotMethod != tc.method { + t.Errorf("server saw method %q, want %q", gotMethod, tc.method) + } + if gotAuth != "Bearer tok-abc" { + t.Errorf("server saw Authorization %q, want %q", gotAuth, "Bearer tok-abc") + } + }) + } +} + +func TestPerformVerification_NetworkError(t *testing.T) { + // Close the server immediately so dialing fails. + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) + server.Close() + + got := performVerification("tok", server.URL, "/api/v1/auth/checkAuth", http.MethodPost) + if got.state != verifyStateUnknown { + t.Errorf("performVerification network-error state = %q, want %q", got.state, verifyStateUnknown) + } +} + +func TestClassifyToken(t *testing.T) { + exp := time.Now().Add(time.Hour).Unix() + + t.Run("service token prefix", func(t *testing.T) { + kind, _, err := classifyToken("st.abc.def") + if err != nil { + t.Fatalf("classifyToken: unexpected error: %v", err) + } + if kind != principalKindServiceToken { + t.Errorf("kind = %q, want %q", kind, principalKindServiceToken) + } + }) + + t.Run("identity access token routes to machine identity", func(t *testing.T) { + token := makeUnsignedJWT(t, map[string]any{ + "authTokenType": authTokenTypeIdentityAccess, + "identityId": "id-1", + "exp": exp, + }) + kind, claims, err := classifyToken(token) + if err != nil { + t.Fatalf("classifyToken: unexpected error: %v", err) + } + if kind != principalKindMachineIdentity { + t.Errorf("kind = %q, want %q", kind, principalKindMachineIdentity) + } + if claims.IdentityID != "id-1" { + t.Errorf("claims.IdentityID = %q, want %q", claims.IdentityID, "id-1") + } + }) + + t.Run("user access token routes to user even without keyring", func(t *testing.T) { + token := makeUnsignedJWT(t, map[string]any{ + "authTokenType": authTokenTypeAccess, + "userId": "u-1", + "organizationId": "org-1", + "exp": exp, + }) + kind, claims, err := classifyToken(token) + if err != nil { + t.Fatalf("classifyToken: unexpected error: %v", err) + } + if kind != principalKindUser { + t.Errorf("kind = %q, want %q (user JWT misclassified)", kind, principalKindUser) + } + if claims.OrganizationID != "org-1" { + t.Errorf("claims.OrganizationID = %q, want %q", claims.OrganizationID, "org-1") + } + }) + + t.Run("unsupported authTokenType returns error", func(t *testing.T) { + token := makeUnsignedJWT(t, map[string]any{ + "authTokenType": "refreshToken", + "userId": "u-1", + "exp": exp, + }) + _, _, err := classifyToken(token) + if err == nil { + t.Fatalf("classifyToken: expected error for refresh token, got nil") + } + }) + + t.Run("legacy JWT with identityId falls back to machine identity", func(t *testing.T) { + token := makeUnsignedJWT(t, map[string]any{ + "identityId": "id-1", + "exp": exp, + }) + kind, _, err := classifyToken(token) + if err != nil { + t.Fatalf("classifyToken: unexpected error: %v", err) + } + if kind != principalKindMachineIdentity { + t.Errorf("kind = %q, want %q", kind, principalKindMachineIdentity) + } + }) + + t.Run("legacy JWT with userId falls back to user", func(t *testing.T) { + token := makeUnsignedJWT(t, map[string]any{ + "userId": "u-1", + "exp": exp, + }) + kind, _, err := classifyToken(token) + if err != nil { + t.Fatalf("classifyToken: unexpected error: %v", err) + } + if kind != principalKindUser { + t.Errorf("kind = %q, want %q", kind, principalKindUser) + } + }) + + t.Run("malformed JWT returns parse error", func(t *testing.T) { + _, _, err := classifyToken("not-a-jwt") + if err == nil { + t.Fatalf("classifyToken: expected parse error, got nil") + } + }) +} + +func TestBuildContextFromToken(t *testing.T) { + t.Run("user JWT via --token verifies against auth/checkAuth", func(t *testing.T) { + token := makeUnsignedJWT(t, map[string]any{ + "authTokenType": authTokenTypeAccess, + "userId": "u-1", + "exp": time.Now().Add(time.Hour).Unix(), + }) + ctx, err := buildContextFromToken(token, "--token flag", "https://example.test") + if err != nil { + t.Fatalf("buildContextFromToken: unexpected error: %v", err) + } + if ctx.kind != principalKindUser { + t.Fatalf("ctx.kind = %q, want %q", ctx.kind, principalKindUser) + } + if ctx.tokenSource != "--token flag" { + t.Errorf("ctx.tokenSource = %q, want %q", ctx.tokenSource, "--token flag") + } + if ctx.loggedInUser.IsUserLoggedIn { + t.Errorf("ctx.loggedInUser.IsUserLoggedIn = true, want false (no keyring session)") + } + }) + + t.Run("unsupported token type surfaces error", func(t *testing.T) { + token := makeUnsignedJWT(t, map[string]any{ + "authTokenType": "mfaToken", + "exp": time.Now().Add(time.Hour).Unix(), + }) + if _, err := buildContextFromToken(token, "--token flag", "https://example.test"); err == nil { + t.Errorf("buildContextFromToken: expected error for mfaToken, got nil") + } + }) +} + +func makeUnsignedJWT(t *testing.T, claims map[string]any) string { + t.Helper() + headerJSON, _ := json.Marshal(map[string]string{"alg": "none", "typ": "JWT"}) + payloadJSON, err := json.Marshal(claims) + if err != nil { + t.Fatalf("marshal claims: %v", err) + } + enc := base64.RawURLEncoding + return enc.EncodeToString(headerJSON) + "." + enc.EncodeToString(payloadJSON) + "." +} diff --git a/packages/util/credentials.go b/packages/util/credentials.go index 68a17cdf..194f9327 100644 --- a/packages/util/credentials.go +++ b/packages/util/credentials.go @@ -19,6 +19,8 @@ type LoggedInUserDetails struct { UserCredentials models.UserCredentials } +var ErrUserNotLoggedIn = errors.New("we couldn't find your logged in details, try running [infisical login] then try again") + func StoreUserCredsInKeyRing(userCred *models.UserCredentials) error { userCredMarshalled, err := json.Marshal(userCred) if err != nil { @@ -69,7 +71,7 @@ func GetCurrentLoggedInUserDetails(setConfigVariables bool) (LoggedInUserDetails userCreds, err := GetUserCredsFromKeyRing(configFile.LoggedInUserEmail) if err != nil { if strings.Contains(err.Error(), "credentials not found in system keyring") { - return LoggedInUserDetails{}, errors.New("we couldn't find your logged in details, try running [infisical login] then try again") + return LoggedInUserDetails{}, ErrUserNotLoggedIn } else { return LoggedInUserDetails{}, fmt.Errorf("failed to fetch credentials from keyring because [err=%s]", err) }