diff --git a/cmd/cloud/regions/list.go b/cmd/cloud/regions/list.go index 5eb65d68..cbe6eb86 100644 --- a/cmd/cloud/regions/list.go +++ b/cmd/cloud/regions/list.go @@ -52,7 +52,14 @@ func (c *ListController) Run(cmd *cobra.Command, args []string) (fctl.Renderable return nil, err } - organizationID, apiClient, err := fctl.NewMembershipClientForOrganizationFromFlags(cmd, relyingParty, fctl.NewPTermDialog(), profileName, *profile) + organizationID, apiClient, err := fctl.NewMembershipClientForOrganizationFromFlagsWithScopes( + cmd, + relyingParty, + fctl.NewPTermDialog(), + profileName, + *profile, + []string{"organization:ListRegions"}, + ) if err != nil { return nil, err } diff --git a/cmd/stack/create.go b/cmd/stack/create.go index 73bf7bd3..e04418d7 100644 --- a/cmd/stack/create.go +++ b/cmd/stack/create.go @@ -70,7 +70,19 @@ func (c *CreateController) Run(cmd *cobra.Command, args []string) (fctl.Renderab return nil, err } - organizationID, apiClient, err := fctl.NewMembershipClientForOrganizationFromFlags(cmd, relyingParty, fctl.NewPTermDialog(), profileName, *profile) + region := fctl.GetString(cmd, regionFlag) + requiredScopes := []string{ + "organization:CreateStack", + "organization:ReadRegion", + } + if region == "" { + requiredScopes = append(requiredScopes, "organization:ListRegions") + } + if !fctl.GetBool(cmd, nowaitFlag) { + requiredScopes = append(requiredScopes, "organization:ReadStack") + } + + organizationID, apiClient, err := fctl.NewMembershipClientForOrganizationFromFlagsWithScopes(cmd, relyingParty, fctl.NewPTermDialog(), profileName, *profile, requiredScopes) if err != nil { return nil, err } @@ -85,7 +97,6 @@ func (c *CreateController) Run(cmd *cobra.Command, args []string) (fctl.Renderab } } - region := fctl.GetString(cmd, regionFlag) if region == "" { listRegionsRequest := operations.ListRegionsRequest{ OrganizationID: organizationID, @@ -152,7 +163,7 @@ func (c *CreateController) Run(cmd *cobra.Command, args []string) (fctl.Renderab specifiedVersion := fctl.GetString(cmd, versionFlag) if specifiedVersion == "" { var options []string - for _, version := range availableVersionsResponse.GetRegionVersionsResponse.GetData() { + for _, version := range sortRegionVersionsByLatest(availableVersionsResponse.GetRegionVersionsResponse.GetData()) { options = append(options, version.GetName()) } @@ -219,7 +230,18 @@ func (c *CreateController) Run(cmd *cobra.Command, args []string) (fctl.Renderab fctl.BasicTextCyan.WithWriter(cmd.OutOrStdout()).Println("Your portal will be reachable on: " + portal) - // todo: need a long running client with auto refresh + if !profile.RootTokens.ID.Claims.HasStackAccess(organizationID, stackData.GetID()) { + fctl.NewPTermDialog().Info("Refreshing profile accesses...") + tokens, err := fctl.RefreshRootTokens(cmd.Context(), relyingParty, *profile.RootTokens) + if err != nil { + return nil, fmt.Errorf("refreshing profile accesses: %w", err) + } + profile.UpdateRootToken(tokens) + if err := fctl.WriteProfile(cmd, profileName, *profile); err != nil { + return nil, fmt.Errorf("writing refreshed profile: %w", err) + } + } + stackClient, err := fctl.NewStackClient( cmd, relyingParty, diff --git a/cmd/stack/upgrade.go b/cmd/stack/upgrade.go index ca6f5d70..1d56931b 100644 --- a/cmd/stack/upgrade.go +++ b/cmd/stack/upgrade.go @@ -6,7 +6,6 @@ import ( "github.com/pterm/pterm" "github.com/spf13/cobra" - "golang.org/x/mod/semver" "github.com/formancehq/fctl/internal/membershipclient/v3" "github.com/formancehq/fctl/internal/membershipclient/v3/models/components" @@ -181,9 +180,9 @@ func retrieveUpgradableVersion(ctx context.Context, organization string, stack c if versionName == *currentVersion { continue } - if !semver.IsValid(versionName) || semver.Compare(versionName, *currentVersion) >= 1 { + if isVersionNewerThanCurrent(versionName, *currentVersion) { upgradeOptions = append(upgradeOptions, versionName) } } - return upgradeOptions, nil + return sortVersionNamesByLatest(upgradeOptions), nil } diff --git a/cmd/stack/version_sort.go b/cmd/stack/version_sort.go new file mode 100644 index 00000000..7d793243 --- /dev/null +++ b/cmd/stack/version_sort.go @@ -0,0 +1,91 @@ +package stack + +import ( + "sort" + "strings" + + "golang.org/x/mod/semver" + + "github.com/formancehq/fctl/internal/membershipclient/v3/models/components" +) + +func sortRegionVersionsByLatest(versions []components.Version) []components.Version { + sorted := append([]components.Version(nil), versions...) + sort.SliceStable(sorted, func(i, j int) bool { + return compareVersionNamesByLatest(sorted[i].GetName(), sorted[j].GetName()) + }) + return sorted +} + +func sortVersionNamesByLatest(versions []string) []string { + sorted := append([]string(nil), versions...) + sort.SliceStable(sorted, func(i, j int) bool { + return compareVersionNamesByLatest(sorted[i], sorted[j]) + }) + return sorted +} + +func compareVersionNamesByLatest(a, b string) bool { + normalizedA, validA := normalizeSemver(a) + normalizedB, validB := normalizeSemver(b) + + if validA && validB { + return semver.Compare(normalizedA, normalizedB) > 0 + } + if validA != validB { + return validA + } + return a > b +} + +func isVersionNewerThanCurrent(candidate, current string) bool { + normalizedCandidate, validCandidate := normalizeSemver(candidate) + normalizedCurrent, validCurrent := normalizeSemver(current) + if validCandidate && validCurrent { + return semver.Compare(normalizedCandidate, normalizedCurrent) > 0 + } + return true +} + +func normalizeSemver(version string) (string, bool) { + version = strings.TrimSpace(version) + if version == "" { + return "", false + } + if !strings.HasPrefix(version, "v") { + version = "v" + version + } + if semver.IsValid(version) { + return version, true + } + + suffixStart := len(version) + for _, separator := range []string{"-", "+"} { + if index := strings.Index(version, separator); index >= 0 && index < suffixStart { + suffixStart = index + } + } + + core := strings.TrimPrefix(version[:suffixStart], "v") + suffix := version[suffixStart:] + parts := strings.Split(core, ".") + if len(parts) > 3 { + return "", false + } + for len(parts) < 3 { + parts = append(parts, "0") + } + for _, part := range parts { + if part == "" { + return "", false + } + for _, r := range part { + if r < '0' || r > '9' { + return "", false + } + } + } + + normalized := "v" + strings.Join(parts, ".") + suffix + return normalized, semver.IsValid(normalized) +} diff --git a/cmd/stack/version_sort_test.go b/cmd/stack/version_sort_test.go new file mode 100644 index 00000000..a9cfb322 --- /dev/null +++ b/cmd/stack/version_sort_test.go @@ -0,0 +1,56 @@ +package stack + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/formancehq/fctl/internal/membershipclient/v3/models/components" +) + +func TestSortRegionVersionsByLatest(t *testing.T) { + versions := []components.Version{ + {Name: "v1.2.0"}, + {Name: "1.10.0"}, + {Name: "v2.0.0-rc.1"}, + {Name: "v2.0.0"}, + {Name: "v1.9.0"}, + } + + sorted := sortRegionVersionsByLatest(versions) + + require.Equal(t, []string{ + "v2.0.0", + "v2.0.0-rc.1", + "1.10.0", + "v1.9.0", + "v1.2.0", + }, []string{ + sorted[0].GetName(), + sorted[1].GetName(), + sorted[2].GetName(), + sorted[3].GetName(), + sorted[4].GetName(), + }) + require.Equal(t, "v1.2.0", versions[0].GetName()) +} + +func TestSortVersionNamesByLatest(t *testing.T) { + require.Equal(t, + []string{"v1.11.0", "1.10.0", "v1.2.0"}, + sortVersionNamesByLatest([]string{"v1.2.0", "v1.11.0", "1.10.0"}), + ) +} + +func TestSortVersionNamesByLatestWithShortVersions(t *testing.T) { + require.Equal(t, + []string{"v3.2-rc", "v3.1", "v3.0", "v2.2", "v2.1"}, + sortVersionNamesByLatest([]string{"v3.0", "v2.2", "v3.1", "v3.2-rc", "v2.1"}), + ) +} + +func TestIsVersionNewerThanCurrent(t *testing.T) { + require.True(t, isVersionNewerThanCurrent("1.10.0", "v1.9.0")) + require.False(t, isVersionNewerThanCurrent("v1.9.0", "1.10.0")) + require.True(t, isVersionNewerThanCurrent("v3.2-rc", "v3.1")) +} diff --git a/pkg/authentication.go b/pkg/authentication.go index a9932582..fe9fd04e 100644 --- a/pkg/authentication.go +++ b/pkg/authentication.go @@ -248,6 +248,44 @@ func Refresh(ctx context.Context, relyingParty client.RelyingParty, token Access }, nil } +func RefreshRootTokens(ctx context.Context, relyingParty client.RelyingParty, tokens Tokens) (*Tokens, error) { + newToken, err := client.RefreshTokens[*IDTokenClaims](ctx, relyingParty, tokens.Access.Refresh, "", "") + if err != nil { + return nil, newErrInvalidAuthentication(err) + } + + accessTokenClaims := AccessTokenClaims{} + if _, err := oidc.ParseToken(newToken.AccessToken, &accessTokenClaims); err != nil { + return nil, newErrInvalidAuthentication(err) + } + + refreshToken := newToken.RefreshToken + if refreshToken == "" { + refreshToken = tokens.Access.Refresh + } + + idToken := tokens.ID.Token + idTokenClaims := tokens.ID.Claims + if newToken.IDToken != "" && newToken.IDTokenClaims != nil { + idToken = newToken.IDToken + idTokenClaims = *newToken.IDTokenClaims + } + + return &Tokens{ + Access: AccessToken{ + TokenWithClaims: TokenWithClaims[AccessTokenClaims]{ + Token: newToken.AccessToken, + Claims: accessTokenClaims, + }, + Refresh: refreshToken, + }, + ID: IDToken{ + Token: idToken, + Claims: idTokenClaims, + }, + }, nil +} + func FetchStackToken(ctx context.Context, httpClient *http.Client, stackURI, token string) (*oauth2.Token, error) { form := url.Values{ "grant_type": []string{"urn:ietf:params:oauth:grant-type:jwt-bearer"}, @@ -323,6 +361,26 @@ type AccessToken struct { Refresh string `json:"refreshToken"` } +func (t AccessToken) MissingScopes(scopes ...string) []string { + tokenScopes := make(map[string]struct{}, len(t.Claims.Scopes)) + for _, scope := range t.Claims.Scopes { + tokenScopes[scope] = struct{}{} + } + + missingScopes := make([]string, 0) + for _, scope := range scopes { + if _, ok := tokenScopes[scope]; !ok { + missingScopes = append(missingScopes, scope) + } + } + + return missingScopes +} + +func (t AccessToken) HasScopes(scopes ...string) bool { + return len(t.MissingScopes(scopes...)) == 0 +} + func (t AccessToken) ToOAuth2() *oauth2.Token { return &oauth2.Token{ AccessToken: t.Token, diff --git a/pkg/authentication_test.go b/pkg/authentication_test.go new file mode 100644 index 00000000..5e324179 --- /dev/null +++ b/pkg/authentication_test.go @@ -0,0 +1,40 @@ +package fctl + +import ( + "reflect" + "testing" + + "github.com/formancehq/go-libs/v4/oidc" +) + +func TestAccessTokenMissingScopes(t *testing.T) { + token := AccessToken{ + TokenWithClaims: TokenWithClaims[AccessTokenClaims]{ + Claims: AccessTokenClaims{ + Scopes: oidc.SpaceDelimitedArray{ + "organization:CreateStack", + "organization:ReadRegion", + }, + }, + }, + } + + missingScopes := token.MissingScopes( + "organization:CreateStack", + "organization:ListRegions", + "organization:ReadRegion", + ) + + expected := []string{"organization:ListRegions"} + if !reflect.DeepEqual(missingScopes, expected) { + t.Fatalf("expected missing scopes %v, got %v", expected, missingScopes) + } + + if token.HasScopes("organization:CreateStack", "organization:ListRegions") { + t.Fatal("expected HasScopes to return false for missing scope") + } + + if !token.HasScopes("organization:CreateStack", "organization:ReadRegion") { + t.Fatal("expected HasScopes to return true when all scopes are present") + } +} diff --git a/pkg/clients.go b/pkg/clients.go index a7f1a798..0530e27e 100644 --- a/pkg/clients.go +++ b/pkg/clients.go @@ -113,6 +113,26 @@ func EnsureOrganizationAccess( profileName string, profile Profile, organizationID string, +) (*AccessToken, error) { + return EnsureOrganizationAccessWithScopes( + cmd, + relyingParty, + dialog, + profileName, + profile, + organizationID, + nil, + ) +} + +func EnsureOrganizationAccessWithScopes( + cmd *cobra.Command, + relyingParty client.RelyingParty, + dialog Dialog, + profileName string, + profile Profile, + organizationID string, + requiredScopes []string, ) (*AccessToken, error) { if !profile.RootTokens.ID.Claims.HasOrganizationAccess(organizationID) { return nil, fmt.Errorf("no access to organization %s found in your authentication profile, "+ @@ -170,6 +190,19 @@ func EnsureOrganizationAccess( } } + if organizationToken != nil && len(requiredScopes) > 0 && !organizationToken.HasScopes(requiredScopes...) { + dialog.Info("Organization token is missing required scopes, requesting new authentication...") + tokens, err := authenticate() + if err != nil { + return nil, fmt.Errorf("failed to authenticate for organization: %w", err) + } + + organizationToken = &tokens.Access + if missingScopes := organizationToken.MissingScopes(requiredScopes...); len(missingScopes) > 0 { + return nil, fmt.Errorf("authenticated organization token is missing required scopes: %s", strings.Join(missingScopes, ", ")) + } + } + if organizationToken != originalOrganizationToken { if err := WriteOrganizationToken(cmd, profileName, *organizationToken); err != nil { return nil, err @@ -391,14 +424,35 @@ func NewMembershipClientForOrganization( profile Profile, organizationID string, ) (*membershipclient.SDK, error) { + return NewMembershipClientForOrganizationWithScopes( + cmd, + relyingParty, + dialog, + profileName, + profile, + organizationID, + nil, + ) +} - organizationToken, err := EnsureOrganizationAccess( +func NewMembershipClientForOrganizationWithScopes( + cmd *cobra.Command, + relyingParty client.RelyingParty, + dialog Dialog, + profileName string, + profile Profile, + organizationID string, + requiredScopes []string, +) (*membershipclient.SDK, error) { + + organizationToken, err := EnsureOrganizationAccessWithScopes( cmd, relyingParty, dialog, profileName, profile, organizationID, + requiredScopes, ) if err != nil { return nil, err @@ -411,12 +465,13 @@ func NewMembershipClientForOrganization( ), nil } -func NewMembershipClientForOrganizationFromFlags( +func NewMembershipClientForOrganizationFromFlagsWithScopes( cmd *cobra.Command, relyingParty client.RelyingParty, dialog Dialog, profileName string, profile Profile, + requiredScopes []string, ) (string, *membershipclient.SDK, error) { organizationID, err := ResolveOrganizationID(cmd, profile) @@ -424,11 +479,21 @@ func NewMembershipClientForOrganizationFromFlags( return "", nil, err } - client, err := NewMembershipClientForOrganization(cmd, relyingParty, dialog, profileName, profile, organizationID) + client, err := NewMembershipClientForOrganizationWithScopes(cmd, relyingParty, dialog, profileName, profile, organizationID, requiredScopes) return organizationID, client, err } +func NewMembershipClientForOrganizationFromFlags( + cmd *cobra.Command, + relyingParty client.RelyingParty, + dialog Dialog, + profileName string, + profile Profile, +) (string, *membershipclient.SDK, error) { + return NewMembershipClientForOrganizationFromFlagsWithScopes(cmd, relyingParty, dialog, profileName, profile, nil) +} + func NewStackClient( cmd *cobra.Command, relyingParty client.RelyingParty,