From 22e1ae9c9044e7abfa5403c24fec3c010b4dd714 Mon Sep 17 00:00:00 2001 From: Markus Unterwaditzer Date: Fri, 22 May 2026 22:55:47 +0200 Subject: [PATCH 1/2] Initial implementation of PR checkout `forge pr checkout 123` will now check out PR/MR locally: 1. Bunch of API calls (forge-specific) 2. Add fork as remote if necessary 3. Make new branch with upstream from fork I tested this manually on Codeberg and GitHub. An agent also tested it on GitLab and Bitbucket. Since there are no e2e tests in this repo, I mainly constrained myself to testing the CLI with a mock forge. There are no CLI tests that spawn HTTP servers and run CLI, so I refrained from that. Existing fields on the structs were preserved for backwards compat, let me know if you want me to remove them since this library is young, or whether PRBranch should be flattened into its parent structs. --- bitbucket/bitbucket.go | 31 +- bitbucket/bitbucket_test.go | 5 +- bitbucket/prs.go | 67 ++++- bitbucket/prs_test.go | 12 +- gitea/prs.go | 24 +- github/prs.go | 28 +- gitlab/prs.go | 22 ++ internal/cli/pr.go | 170 +++++++++++ internal/cli/pr_checkout_test.go | 500 +++++++++++++++++++++++++++++++ internal/resolve/resolve.go | 27 ++ types.go | 22 +- 11 files changed, 859 insertions(+), 49 deletions(-) create mode 100644 internal/cli/pr_checkout_test.go diff --git a/bitbucket/bitbucket.go b/bitbucket/bitbucket.go index d3f6ff9..0f06097 100644 --- a/bitbucket/bitbucket.go +++ b/bitbucket/bitbucket.go @@ -17,6 +17,23 @@ var bitbucketAPI = "https://api.bitbucket.org/2.0" // setBitbucketAPI overrides the Bitbucket API base URL (for testing). func setBitbucketAPI(url string) { bitbucketAPI = url } +type bbCloneLink struct { + Href string `json:"href"` + Name string `json:"name"` // "https" or "ssh" +} + +func parseCloneURLs(links []bbCloneLink) (cloneURL, sshURL string) { + for _, link := range links { + switch link.Name { + case "https": + cloneURL = link.Href + case "ssh": + sshURL = link.Href + } + } + return +} + type bitbucketForge struct { token string httpClient *http.Client @@ -69,10 +86,7 @@ type bbRepository struct { Avatar struct { Href string `json:"href"` } `json:"avatar"` - Clone []struct { - Href string `json:"href"` - Name string `json:"name"` - } `json:"clone"` + Clone []bbCloneLink `json:"clone"` } `json:"links"` CreatedOn string `json:"created_on"` UpdatedOn string `json:"updated_on"` @@ -157,14 +171,7 @@ func convertBitbucketRepo(bb bbRepository) forge.Repository { LogoURL: bb.Links.Avatar.Href, } - for _, c := range bb.Links.Clone { - switch c.Name { - case "https": - result.CloneURL = c.Href - case "ssh": - result.SSHURL = c.Href - } - } + result.CloneURL, result.SSHURL = parseCloneURLs(bb.Links.Clone) if bb.Owner != nil { result.Owner = bb.Owner.Username diff --git a/bitbucket/bitbucket_test.go b/bitbucket/bitbucket_test.go index a95e6a3..b99f683 100644 --- a/bitbucket/bitbucket_test.go +++ b/bitbucket/bitbucket_test.go @@ -42,10 +42,7 @@ func TestBitbucketGetRepo(t *testing.T) { Avatar struct { Href string `json:"href"` } `json:"avatar"` - Clone []struct { - Href string `json:"href"` - Name string `json:"name"` - } `json:"clone"` + Clone []bbCloneLink `json:"clone"` }{ HTML: struct { Href string `json:"href"` diff --git a/bitbucket/prs.go b/bitbucket/prs.go index 7d8ae61..57b5a4f 100644 --- a/bitbucket/prs.go +++ b/bitbucket/prs.go @@ -5,6 +5,7 @@ import ( "fmt" "io" "net/http" + "strings" "time" forge "github.com/git-pkgs/forge" @@ -21,22 +22,29 @@ func (f *bitbucketForge) PullRequests() forge.PullRequestService { return &bitbucketPRService{token: f.token, httpClient: f.httpClient} } +type bbPRBranch struct { + Branch struct { + Name string `json:"name"` + } `json:"branch"` + Commit *struct { + Hash string `json:"hash"` + } `json:"commit"` + Repository *struct { + FullName string `json:"full_name"` + Links struct { + Clone []bbCloneLink `json:"clone"` + } `json:"links"` + } `json:"repository"` +} + type bbPullRequest struct { - ID int `json:"id"` - Title string `json:"title"` - Description string `json:"description"` - State string `json:"state"` // OPEN, MERGED, DECLINED, SUPERSEDED - Source struct { - Branch struct { - Name string `json:"name"` - } `json:"branch"` - } `json:"source"` - Destination struct { - Branch struct { - Name string `json:"name"` - } `json:"branch"` - } `json:"destination"` - Author *struct { + ID int `json:"id"` + Title string `json:"title"` + Description string `json:"description"` + State string `json:"state"` // OPEN, MERGED, DECLINED, SUPERSEDED + Source bbPRBranch `json:"source"` + Destination bbPRBranch `json:"destination"` + Author *struct { Username string `json:"username"` DisplayName string `json:"display_name"` Links struct { @@ -88,6 +96,35 @@ func convertBitbucketPR(bb bbPullRequest) forge.PullRequest { DiffURL: bb.Links.Diff.Href, } + var destFullName string + result.BaseBranch = &forge.PRBranch{Ref: bb.Destination.Branch.Name} + if bb.Destination.Commit != nil { + result.BaseBranch.SHA = bb.Destination.Commit.Hash + } + if bb.Destination.Repository != nil { + destFullName = bb.Destination.Repository.FullName + } + + result.HeadBranch = &forge.PRBranch{Ref: bb.Source.Branch.Name} + if bb.Source.Commit != nil { + result.HeadBranch.SHA = bb.Source.Commit.Hash + } + if bb.Source.Repository != nil && bb.Source.Repository.FullName != destFullName { + cloneURL, sshURL := parseCloneURLs(bb.Source.Repository.Links.Clone) + parts := strings.Split(bb.Source.Repository.FullName, "/") + var owner, name string + if len(parts) >= 2 { + owner = parts[0] + name = parts[1] + } + result.HeadBranch.Fork = &forge.ForkInfo{ + Owner: owner, + Name: name, + CloneURL: cloneURL, + SSHURL: sshURL, + } + } + switch bb.State { case "OPEN": result.State = "open" diff --git a/bitbucket/prs_test.go b/bitbucket/prs_test.go index b40c76b..dcde77e 100644 --- a/bitbucket/prs_test.go +++ b/bitbucket/prs_test.go @@ -17,18 +17,10 @@ func TestBitbucketGetPR(t *testing.T) { Title: "Add feature", Description: "New feature PR", State: "OPEN", - Source: struct { - Branch struct { - Name string `json:"name"` - } `json:"branch"` - }{Branch: struct { + Source: bbPRBranch{Branch: struct { Name string `json:"name"` }{Name: "feature-branch"}}, - Destination: struct { - Branch struct { - Name string `json:"name"` - } `json:"branch"` - }{Branch: struct { + Destination: bbPRBranch{Branch: struct { Name string `json:"name"` }{Name: "main"}}, Author: &struct { diff --git a/gitea/prs.go b/gitea/prs.go index ceeb335..6522cef 100644 --- a/gitea/prs.go +++ b/gitea/prs.go @@ -48,11 +48,29 @@ func convertGiteaPR(pr *gitea.PullRequest) forge.PullRequest { result.State = stateOpen } - if pr.Head != nil { - result.Head = pr.Head.Name - } + var baseRepoID int64 if pr.Base != nil { result.Base = pr.Base.Name + result.BaseBranch = &forge.PRBranch{ + Ref: pr.Base.Ref, + SHA: pr.Base.Sha, + } + baseRepoID = pr.Base.RepoID + } + if pr.Head != nil { + result.Head = pr.Head.Name + result.HeadBranch = &forge.PRBranch{ + Ref: pr.Head.Ref, + SHA: pr.Head.Sha, + } + if pr.Head.RepoID != baseRepoID && pr.Head.Repository != nil { + result.HeadBranch.Fork = &forge.ForkInfo{ + Owner: pr.Head.Repository.Owner.UserName, + Name: pr.Head.Repository.Name, + CloneURL: pr.Head.Repository.CloneURL, + SSHURL: pr.Head.Repository.SSHURL, + } + } } if pr.Poster != nil { diff --git a/github/prs.go b/github/prs.go index 0e1be56..3502105 100644 --- a/github/prs.go +++ b/github/prs.go @@ -81,11 +81,33 @@ func convertGitHubPR(pr *github.PullRequest) forge.PullRequest { } } - if h := pr.GetHead(); h != nil { - result.Head = h.GetRef() - } + var baseFullName string if b := pr.GetBase(); b != nil { result.Base = b.GetRef() + result.BaseBranch = &forge.PRBranch{ + Ref: b.GetRef(), + SHA: b.GetSHA(), + } + if repo := b.GetRepo(); repo != nil { + baseFullName = repo.GetFullName() + } + } + if h := pr.GetHead(); h != nil { + result.Head = h.GetRef() + result.HeadBranch = &forge.PRBranch{ + Ref: h.GetRef(), + SHA: h.GetSHA(), + } + if repo := h.GetRepo(); repo != nil { + if repo.GetFullName() != baseFullName { + result.HeadBranch.Fork = &forge.ForkInfo{ + Owner: repo.GetOwner().GetLogin(), + Name: repo.GetName(), + CloneURL: repo.GetCloneURL(), + SSHURL: repo.GetSSHURL(), + } + } + } } if u := pr.GetMergedBy(); u != nil { diff --git a/gitlab/prs.go b/gitlab/prs.go index bc26a03..265893e 100644 --- a/gitlab/prs.go +++ b/gitlab/prs.go @@ -36,6 +36,12 @@ func convertGitLabMR(mr *gitlab.MergeRequest) forge.PullRequest { HTMLURL: mr.WebURL, } + result.BaseBranch = &forge.PRBranch{Ref: mr.TargetBranch} + result.HeadBranch = &forge.PRBranch{ + Ref: mr.SourceBranch, + SHA: mr.SHA, + } + // Normalize "opened" to "open" if result.State == stateOpened { result.State = stateOpen @@ -184,6 +190,22 @@ func (s *gitLabPRService) Get(ctx context.Context, owner, repo string, number in return nil, err } result := convertGitLabMR(mr) + + if mr.SourceProjectID != mr.TargetProjectID { + sourceProject, _, err := s.client.Projects.GetProject(mr.SourceProjectID, nil) + if err == nil && sourceProject != nil { + if result.HeadBranch == nil { + result.HeadBranch = &forge.PRBranch{Ref: mr.SourceBranch, SHA: mr.SHA} + } + result.HeadBranch.Fork = &forge.ForkInfo{ + Owner: sourceProject.Namespace.Path, + Name: sourceProject.Path, + CloneURL: sourceProject.HTTPURLToRepo, + SSHURL: sourceProject.SSHURLToRepo, + } + } + } + return &result, nil } diff --git a/internal/cli/pr.go b/internal/cli/pr.go index e17b501..80451ee 100644 --- a/internal/cli/pr.go +++ b/internal/cli/pr.go @@ -1,8 +1,10 @@ package cli import ( + "context" "fmt" "os" + "os/exec" "strconv" "strings" @@ -36,6 +38,7 @@ func init() { prCmd.AddCommand(prCommentCmd()) prCmd.AddCommand(prReactionsCmd()) prCmd.AddCommand(prReactCmd()) + prCmd.AddCommand(prCheckoutCmd()) } func prViewCmd() *cobra.Command { @@ -538,3 +541,170 @@ func prCommentCmd() *cobra.Command { cmd.Flags().StringVarP(&flagBody, "body", "b", "", "Comment body") return cmd } + +func prCheckoutCmd() *cobra.Command { + var ( + flagRemoteName string + flagBranch string + flagDetach bool + flagForce bool + ) + + cmd := &cobra.Command{ + Use: "checkout ", + Short: "Check out a pull request locally", + Long: `Check out a pull request's head branch locally. + +If the PR is from a fork, the fork repository is added as a remote +(named after the fork owner by default), and the branch is fetched +and checked out with upstream tracking configured. + +For same-repo PRs, the branch is fetched and checked out.`, + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + number, err := strconv.Atoi(args[0]) + if err != nil { + return fmt.Errorf("invalid PR number: %s", args[0]) + } + + forge, owner, repoName, _, err := resolve.Repo(flagRepo, flagForgeType) + if err != nil { + return err + } + + ctx := cmd.Context() + + if !flagForce { + status, _ := exec.CommandContext(ctx, "git", "status", "--porcelain").Output() + if len(status) > 0 { + return fmt.Errorf("you have uncommitted changes; commit or stash them, or use --force") + } + } + + pr, err := forge.PullRequests().Get(ctx, owner, repoName, number) + if err != nil { + return fmt.Errorf("getting PR #%d: %w", number, err) + } + + // remoteRef is the branch name on the remote (PR's head branch) + remoteRef := pr.Head + if pr.HeadBranch != nil && pr.HeadBranch.Ref != "" { + remoteRef = pr.HeadBranch.Ref + } + + // localBranch is what we'll name the local branch (defaults to remote ref) + localBranch := remoteRef + if flagBranch != "" { + localBranch = flagBranch + } + + if !flagForce && !flagDetach { + if err := exec.CommandContext(ctx, "git", "rev-parse", "--verify", "--quiet", localBranch).Run(); err == nil { + _, _ = fmt.Fprintf(os.Stderr, "warning: local branch %q already exists and will be reset\n", localBranch) + } + } + + if pr.HeadBranch != nil && pr.HeadBranch.Fork != nil { + return checkoutForkPR(ctx, pr, remoteRef, localBranch, flagRemoteName, flagDetach) + } + + return checkoutSameRepoPR(ctx, remoteRef, localBranch, flagDetach) + }, + } + + cmd.Flags().StringVar(&flagRemoteName, "remote-name", "", "Name for fork remote (default: fork owner)") + cmd.Flags().StringVarP(&flagBranch, "branch", "b", "", "Local branch name (default: same as remote)") + cmd.Flags().BoolVar(&flagDetach, "detach", false, "Checkout in detached HEAD mode") + cmd.Flags().BoolVarP(&flagForce, "force", "f", false, "Force checkout even with uncommitted changes or existing branch") + return cmd +} + +func checkoutForkPR(ctx context.Context, pr *forges.PullRequest, remoteRef, localBranch, flagRemoteName string, detach bool) error { + fork := pr.HeadBranch.Fork + remoteName := flagRemoteName + if remoteName == "" { + remoteName = fork.Owner + } + if remoteName == "" { + remoteName = "fork" + } + + cloneURL := fork.CloneURL + if cloneURL == "" { + cloneURL = fork.SSHURL + } + if cloneURL == "" { + return fmt.Errorf("no clone URL available for fork repository") + } + + remoteName, err := ensureRemote(ctx, remoteName, cloneURL) + if err != nil { + return err + } + + refspec := fmt.Sprintf("+refs/heads/%s:refs/remotes/%s/%s", remoteRef, remoteName, remoteRef) + fetchCmd := exec.CommandContext(ctx, "git", "fetch", "--", remoteName, refspec) + fetchCmd.Stdout = os.Stdout + fetchCmd.Stderr = os.Stderr + if err := fetchCmd.Run(); err != nil { + return fmt.Errorf("fetching %s/%s: %w", remoteName, remoteRef, err) + } + + return gitCheckout(ctx, localBranch, remoteName+"/"+remoteRef, detach) +} + +func checkoutSameRepoPR(ctx context.Context, remoteRef, localBranch string, detach bool) error { + refspec := fmt.Sprintf("+refs/heads/%s:refs/remotes/origin/%s", remoteRef, remoteRef) + fetchCmd := exec.CommandContext(ctx, "git", "fetch", "--", "origin", refspec) + fetchCmd.Stdout = os.Stdout + fetchCmd.Stderr = os.Stderr + if err := fetchCmd.Run(); err != nil { + return fmt.Errorf("fetching origin/%s: %w", remoteRef, err) + } + + return gitCheckout(ctx, localBranch, "origin/"+remoteRef, detach) +} + +func ensureRemote(ctx context.Context, preferredName, cloneURL string) (string, error) { + remotes, err := exec.CommandContext(ctx, "git", "remote", "-v").Output() + if err == nil { + for _, line := range strings.Split(string(remotes), "\n") { + fields := strings.Fields(line) + if len(fields) >= 2 && fields[1] == cloneURL { + return fields[0], nil + } + } + } + + existingURL, err := exec.CommandContext(ctx, "git", "remote", "get-url", preferredName).Output() + if err != nil { + addCmd := exec.CommandContext(ctx, "git", "remote", "add", "--", preferredName, cloneURL) + addCmd.Stdout = os.Stdout + addCmd.Stderr = os.Stderr + if err := addCmd.Run(); err != nil { + return "", fmt.Errorf("adding remote %s: %w", preferredName, err) + } + return preferredName, nil + } + + if strings.TrimSpace(string(existingURL)) == cloneURL { + return preferredName, nil + } + + return "", fmt.Errorf("remote %q already exists with a different URL; use --remote-name to specify a different name", preferredName) +} + +func gitCheckout(ctx context.Context, branchName, ref string, detach bool) error { + var checkoutCmd *exec.Cmd + if detach { + checkoutCmd = exec.CommandContext(ctx, "git", "checkout", "--detach", ref) + } else { + checkoutCmd = exec.CommandContext(ctx, "git", "checkout", "-B", branchName, ref) + } + checkoutCmd.Stdout = os.Stdout + checkoutCmd.Stderr = os.Stderr + if err := checkoutCmd.Run(); err != nil { + return fmt.Errorf("checking out %s: %w", ref, err) + } + return nil +} diff --git a/internal/cli/pr_checkout_test.go b/internal/cli/pr_checkout_test.go new file mode 100644 index 0000000..6a31ac7 --- /dev/null +++ b/internal/cli/pr_checkout_test.go @@ -0,0 +1,500 @@ +package cli + +import ( + "bytes" + "context" + "os" + "os/exec" + "path/filepath" + "strings" + "testing" + + forges "github.com/git-pkgs/forge" + "github.com/git-pkgs/forge/internal/resolve" +) + +// mockPRService implements forges.PullRequestService for testing. +type mockPRService struct { + pr *forges.PullRequest + err error +} + +func (m *mockPRService) Get(_ context.Context, _, _ string, _ int) (*forges.PullRequest, error) { + return m.pr, m.err +} + +func (m *mockPRService) List(_ context.Context, _, _ string, _ forges.ListPROpts) ([]forges.PullRequest, error) { + return nil, nil +} + +func (m *mockPRService) Create(_ context.Context, _, _ string, _ forges.CreatePROpts) (*forges.PullRequest, error) { + return nil, nil +} + +func (m *mockPRService) Update(_ context.Context, _, _ string, _ int, _ forges.UpdatePROpts) (*forges.PullRequest, error) { + return nil, nil +} + +func (m *mockPRService) Close(_ context.Context, _, _ string, _ int) error { + return nil +} + +func (m *mockPRService) Reopen(_ context.Context, _, _ string, _ int) error { + return nil +} + +func (m *mockPRService) Merge(_ context.Context, _, _ string, _ int, _ forges.MergePROpts) error { + return nil +} + +func (m *mockPRService) Diff(_ context.Context, _, _ string, _ int) (string, error) { + return "", nil +} + +func (m *mockPRService) CreateComment(_ context.Context, _, _ string, _ int, _ string) (*forges.Comment, error) { + return nil, nil +} + +func (m *mockPRService) ListComments(_ context.Context, _, _ string, _ int) ([]forges.Comment, error) { + return nil, nil +} + +func (m *mockPRService) ListReactions(_ context.Context, _, _ string, _ int, _ int64) ([]forges.Reaction, error) { + return nil, nil +} + +func (m *mockPRService) AddReaction(_ context.Context, _, _ string, _ int, _ int64, _ string) (*forges.Reaction, error) { + return nil, nil +} + +func (m *mockPRService) ListURL(_ string) string { + return "" +} + +// mockForge implements forges.Forge for testing. +type mockForge struct { + prService *mockPRService +} + +func (m *mockForge) Repos() forges.RepoService { return nil } +func (m *mockForge) Issues() forges.IssueService { return nil } +func (m *mockForge) PullRequests() forges.PullRequestService { return m.prService } +func (m *mockForge) Labels() forges.LabelService { return nil } +func (m *mockForge) Milestones() forges.MilestoneService { return nil } +func (m *mockForge) Releases() forges.ReleaseService { return nil } +func (m *mockForge) CI() forges.CIService { return nil } +func (m *mockForge) Branches() forges.BranchService { return nil } +func (m *mockForge) DeployKeys() forges.DeployKeyService { return nil } +func (m *mockForge) Secrets() forges.SecretService { return nil } +func (m *mockForge) Notifications() forges.NotificationService { return nil } +func (m *mockForge) Reviews() forges.ReviewService { return nil } +func (m *mockForge) Files() forges.FileService { return nil } +func (m *mockForge) Collaborators() forges.CollaboratorService { return nil } +func (m *mockForge) CommitStatuses() forges.CommitStatusService { return nil } +func (m *mockForge) GetRateLimit(_ context.Context) (*forges.RateLimit, error) { + return nil, forges.ErrNotSupported +} + +// setupTestRepo creates a temporary git repository with an initial commit +// and an origin remote pointing to a fake URL. +func setupTestRepo(t *testing.T, originURL string) string { + t.Helper() + dir := t.TempDir() + + commands := [][]string{ + {"git", "init"}, + {"git", "config", "user.email", "test@test.com"}, + {"git", "config", "user.name", "Test User"}, + } + + for _, args := range commands { + cmd := exec.Command(args[0], args[1:]...) + cmd.Dir = dir + if out, err := cmd.CombinedOutput(); err != nil { + t.Fatalf("git command %v failed: %v\n%s", args, err, out) + } + } + + // Create an initial commit so we have a valid HEAD + testFile := filepath.Join(dir, "README.md") + if err := os.WriteFile(testFile, []byte("# Test\n"), 0644); err != nil { + t.Fatalf("writing test file: %v", err) + } + + commands = [][]string{ + {"git", "add", "README.md"}, + {"git", "commit", "-m", "Initial commit"}, + {"git", "remote", "add", "origin", originURL}, + } + + for _, args := range commands { + cmd := exec.Command(args[0], args[1:]...) + cmd.Dir = dir + if out, err := cmd.CombinedOutput(); err != nil { + t.Fatalf("git command %v failed: %v\n%s", args, err, out) + } + } + + return dir +} + +// setupBareRepo creates a bare git repository that can be used as a remote. +func setupBareRepo(t *testing.T) string { + t.Helper() + dir := t.TempDir() + + cmd := exec.Command("git", "init", "--bare") + cmd.Dir = dir + if out, err := cmd.CombinedOutput(); err != nil { + t.Fatalf("git init --bare failed: %v\n%s", err, out) + } + + return dir +} + +// pushBranchToRemote creates a branch and pushes it to a remote. +func pushBranchToRemote(t *testing.T, repoDir, remoteName, branchName string) { + t.Helper() + + // Create a file and commit on a new branch + testFile := filepath.Join(repoDir, branchName+".txt") + if err := os.WriteFile(testFile, []byte("content for "+branchName+"\n"), 0644); err != nil { + t.Fatalf("writing test file: %v", err) + } + + commands := [][]string{ + {"git", "checkout", "-b", branchName}, + {"git", "add", "."}, + {"git", "commit", "-m", "Add " + branchName}, + {"git", "push", remoteName, branchName}, + {"git", "checkout", "-"}, + } + + for _, args := range commands { + cmd := exec.Command(args[0], args[1:]...) + cmd.Dir = repoDir + if out, err := cmd.CombinedOutput(); err != nil { + t.Fatalf("git command %v failed: %v\n%s", args, err, out) + } + } +} + +func TestPRCheckout(t *testing.T) { + tests := []struct { + name string + pr *forges.PullRequest + args []string + setupOrigin bool // whether to create and push to origin + setupFork bool // whether to create a fork remote + wantBranch string + wantRemote string // expected remote name for fork PRs + wantErr string + }{ + { + name: "same-repo PR checks out branch", + pr: &forges.PullRequest{ + Number: 42, + Head: "feature-branch", + HeadBranch: &forges.PRBranch{ + Ref: "feature-branch", + SHA: "abc123", + }, + }, + args: []string{"pr", "checkout", "42"}, + setupOrigin: true, + wantBranch: "feature-branch", + }, + { + name: "fork PR adds remote and checks out", + pr: &forges.PullRequest{ + Number: 42, + Head: "feature", + HeadBranch: &forges.PRBranch{ + Ref: "feature", + SHA: "abc123", + Fork: &forges.ForkInfo{ + Owner: "contributor", + Name: "repo", + CloneURL: "FORK_URL_PLACEHOLDER", // will be replaced + }, + }, + }, + args: []string{"pr", "checkout", "42"}, + setupFork: true, + wantBranch: "feature", + wantRemote: "contributor", + }, + { + name: "fork PR with custom remote name", + pr: &forges.PullRequest{ + Number: 42, + Head: "feature", + HeadBranch: &forges.PRBranch{ + Ref: "feature", + SHA: "abc123", + Fork: &forges.ForkInfo{ + Owner: "contributor", + Name: "repo", + CloneURL: "FORK_URL_PLACEHOLDER", + }, + }, + }, + args: []string{"pr", "checkout", "42", "--remote-name", "upstream"}, + setupFork: true, + wantBranch: "feature", + wantRemote: "upstream", + }, + { + name: "detach mode", + pr: &forges.PullRequest{ + Number: 42, + Head: "feature-branch", + HeadBranch: &forges.PRBranch{ + Ref: "feature-branch", + SHA: "abc123", + }, + }, + args: []string{"pr", "checkout", "42", "--detach"}, + setupOrigin: true, + wantBranch: "", // detached HEAD + }, + { + name: "checkout with custom branch name", + pr: &forges.PullRequest{ + Number: 42, + Head: "feature-branch", + HeadBranch: &forges.PRBranch{ + Ref: "feature-branch", + SHA: "abc123", + }, + }, + args: []string{"pr", "checkout", "42", "-b", "my-local-branch"}, + setupOrigin: true, + wantBranch: "my-local-branch", + }, + { + name: "invalid PR number", + args: []string{"pr", "checkout", "notanumber"}, + wantErr: "invalid PR number", + }, + { + name: "fork PR without clone URL", + pr: &forges.PullRequest{ + Number: 42, + Head: "feature", + HeadBranch: &forges.PRBranch{ + Ref: "feature", + SHA: "abc123", + Fork: &forges.ForkInfo{ + Owner: "contributor", + Name: "repo", + // CloneURL and SSHURL both empty + }, + }, + }, + args: []string{"pr", "checkout", "42"}, + wantErr: "no clone URL available", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Skip tests that need real git operations in short mode + if testing.Short() && (tt.setupOrigin || tt.setupFork) { + t.Skip("skipping git integration test in short mode") + } + + // Reset flags to defaults before each test + // Find the checkout command and reset its flags + checkoutCmd, _, _ := rootCmd.Find([]string{"pr", "checkout"}) + if checkoutCmd != nil { + _ = checkoutCmd.Flags().Set("detach", "false") + _ = checkoutCmd.Flags().Set("force", "false") + _ = checkoutCmd.Flags().Set("branch", "") + _ = checkoutCmd.Flags().Set("remote-name", "") + } + + var workDir string + + // For git integration tests, set up repos + if tt.setupOrigin || tt.setupFork { + originDir := setupBareRepo(t) + workDir = setupTestRepo(t, originDir) + + if tt.setupOrigin { + branchName := tt.pr.HeadBranch.Ref + pushBranchToRemote(t, workDir, "origin", branchName) + } + + if tt.setupFork { + forkDir := setupBareRepo(t) + tt.pr.HeadBranch.Fork.CloneURL = forkDir + + branchName := tt.pr.HeadBranch.Ref + cmd := exec.Command("git", "remote", "add", "tempfork", forkDir) + cmd.Dir = workDir + if out, err := cmd.CombinedOutput(); err != nil { + t.Fatalf("adding temp fork remote: %v\n%s", err, out) + } + pushBranchToRemote(t, workDir, "tempfork", branchName) + cmd = exec.Command("git", "remote", "remove", "tempfork") + cmd.Dir = workDir + if out, err := cmd.CombinedOutput(); err != nil { + t.Fatalf("removing temp fork remote: %v\n%s", err, out) + } + } + } else if tt.pr != nil { + // For error tests that still need a git context, create a minimal repo + originDir := setupBareRepo(t) + workDir = setupTestRepo(t, originDir) + } + + // Change to work directory for the test + if workDir != "" { + oldWd, err := os.Getwd() + if err != nil { + t.Fatalf("getting working directory: %v", err) + } + if err := os.Chdir(workDir); err != nil { + t.Fatalf("changing to work directory: %v", err) + } + t.Cleanup(func() { + _ = os.Chdir(oldWd) + }) + } + + // Set up mock forge + if tt.pr != nil { + resolve.SetTestForge( + &mockForge{prService: &mockPRService{pr: tt.pr}}, + "testowner", "testrepo", "github.com", + ) + t.Cleanup(resolve.ResetTestForge) + } + + // Execute command + var buf bytes.Buffer + rootCmd.SetOut(&buf) + rootCmd.SetErr(&buf) + rootCmd.SetArgs(tt.args) + + err := rootCmd.Execute() + + // Check error + if tt.wantErr != "" { + if err == nil { + t.Fatalf("expected error containing %q, got nil", tt.wantErr) + } + if !strings.Contains(err.Error(), tt.wantErr) { + t.Errorf("expected error containing %q, got %q", tt.wantErr, err.Error()) + } + return + } + + if err != nil { + t.Fatalf("unexpected error: %v\noutput: %s", err, buf.String()) + } + + if workDir == "" { + return // no git state to verify + } + + // Verify branch + if tt.wantBranch != "" { + cmd := exec.Command("git", "branch", "--show-current") + cmd.Dir = workDir + out, err := cmd.Output() + if err != nil { + t.Fatalf("getting current branch: %v", err) + } + gotBranch := strings.TrimSpace(string(out)) + if gotBranch != tt.wantBranch { + t.Errorf("branch: want %q, got %q", tt.wantBranch, gotBranch) + } + } else { + // Detached HEAD - verify no branch + cmd := exec.Command("git", "branch", "--show-current") + cmd.Dir = workDir + out, _ := cmd.Output() + if strings.TrimSpace(string(out)) != "" { + t.Errorf("expected detached HEAD, but on branch %q", strings.TrimSpace(string(out))) + } + } + + // Verify remote for fork PRs + if tt.wantRemote != "" { + cmd := exec.Command("git", "remote", "-v") + cmd.Dir = workDir + out, err := cmd.Output() + if err != nil { + t.Fatalf("listing remotes: %v", err) + } + if !strings.Contains(string(out), tt.wantRemote) { + t.Errorf("expected remote %q in output:\n%s", tt.wantRemote, out) + } + } + }) + } +} + +func TestPRCheckoutUncommittedChanges(t *testing.T) { + if testing.Short() { + t.Skip("skipping git integration test in short mode") + } + + originDir := setupBareRepo(t) + workDir := setupTestRepo(t, originDir) + + // Push a branch to origin + pushBranchToRemote(t, workDir, "origin", "feature-branch") + + // Create uncommitted changes + testFile := filepath.Join(workDir, "uncommitted.txt") + if err := os.WriteFile(testFile, []byte("uncommitted\n"), 0644); err != nil { + t.Fatalf("writing test file: %v", err) + } + + oldWd, _ := os.Getwd() + if err := os.Chdir(workDir); err != nil { + t.Fatalf("changing to work directory: %v", err) + } + t.Cleanup(func() { _ = os.Chdir(oldWd) }) + + pr := &forges.PullRequest{ + Number: 42, + Head: "feature-branch", + HeadBranch: &forges.PRBranch{ + Ref: "feature-branch", + SHA: "abc123", + }, + } + + resolve.SetTestForge( + &mockForge{prService: &mockPRService{pr: pr}}, + "testowner", "testrepo", "github.com", + ) + t.Cleanup(resolve.ResetTestForge) + + // Without --force should fail + var buf bytes.Buffer + rootCmd.SetOut(&buf) + rootCmd.SetErr(&buf) + rootCmd.SetArgs([]string{"pr", "checkout", "42"}) + + err := rootCmd.Execute() + if err == nil { + t.Fatal("expected error for uncommitted changes, got nil") + } + if !strings.Contains(err.Error(), "uncommitted changes") { + t.Errorf("expected error about uncommitted changes, got: %v", err) + } + + // With --force should succeed + buf.Reset() + rootCmd.SetArgs([]string{"pr", "checkout", "42", "--force"}) + + err = rootCmd.Execute() + if err != nil { + t.Fatalf("unexpected error with --force: %v", err) + } +} diff --git a/internal/resolve/resolve.go b/internal/resolve/resolve.go index a5b0f0f..a38b3cf 100644 --- a/internal/resolve/resolve.go +++ b/internal/resolve/resolve.go @@ -20,6 +20,13 @@ var ( remoteName = "origin" hostOverride string forgeTypeOverride string + + // testForge allows tests to inject a mock forge. When set, Repo() returns + // this forge directly without network or git resolution. + testForge forges.Forge + testOwner string + testRepo string + testDomain string ) // SetRemote sets which git remote to read when resolving the current @@ -50,6 +57,23 @@ func SetForgeType(forgeType string) { } } +// SetTestForge configures a mock forge for testing. When set, Repo() returns +// this forge directly without network or git resolution. +func SetTestForge(forge forges.Forge, owner, repo, domain string) { + testForge = forge + testOwner = owner + testRepo = repo + testDomain = domain +} + +// ResetTestForge clears the test forge configuration. +func ResetTestForge() { + testForge = nil + testOwner = "" + testRepo = "" + testDomain = "" +} + var builders = forges.ForgeBuilders{ GitHub: ghforge.NewWithBase, GitLab: glforge.New, @@ -60,6 +84,9 @@ var builders = forges.ForgeBuilders{ // git remote. The -R flag takes precedence; otherwise we read the "origin" // remote URL and parse it. func Repo(flagRepo, flagForgeType string) (forge forges.Forge, owner, repo, domain string, err error) { + if testForge != nil { + return testForge, testOwner, testRepo, testDomain, nil + } if flagRepo != "" { return repoFromFlag(flagRepo, flagForgeType) } diff --git a/types.go b/types.go index b7e0738..a24adfa 100644 --- a/types.go +++ b/types.go @@ -229,6 +229,22 @@ type UpdateIssueOpts struct { Milestone *string } +// ForkInfo holds minimal repository info needed for PR checkout from forks. +type ForkInfo struct { + Owner string `json:"owner"` + Name string `json:"name,omitempty"` + CloneURL string `json:"clone_url,omitempty"` + SSHURL string `json:"ssh_url,omitempty"` +} + +// PRBranch holds branch info including the repository it belongs to. +// For same-repo PRs, Fork is nil. For fork PRs, Fork points to the source repo. +type PRBranch struct { + Ref string `json:"ref"` // branch name + SHA string `json:"sha,omitempty"` // commit SHA + Fork *ForkInfo `json:"fork,omitempty"` // nil if same repo as target +} + // PullRequest holds normalized metadata about a pull request (or merge request). type PullRequest struct { Number int `json:"number"` @@ -241,8 +257,10 @@ type PullRequest struct { Reviewers []User `json:"reviewers,omitempty"` Labels []Label `json:"labels,omitempty"` Milestone *Milestone `json:"milestone,omitempty"` - Head string `json:"head"` // head branch - Base string `json:"base"` // base branch + Head string `json:"head"` // head branch name (for backward compat) + Base string `json:"base"` // base branch name (for backward compat) + HeadBranch *PRBranch `json:"head_branch,omitempty"` // rich head branch info with repo + BaseBranch *PRBranch `json:"base_branch,omitempty"` // rich base branch info Mergeable bool `json:"mergeable"` Merged bool `json:"merged"` MergedBy *User `json:"merged_by,omitempty"` From 3e111583ca39e315a5cb87eff6ff9f0d9dc3119c Mon Sep 17 00:00:00 2001 From: Markus Unterwaditzer Date: Sat, 23 May 2026 18:12:19 +0200 Subject: [PATCH 2/2] Address review feedback --- bitbucket/prs.go | 12 ++-- bitbucket/prs_test.go | 4 +- gitea/prs.go | 10 ++- github/prs.go | 8 +-- github/prs_test.go | 6 +- gitlab/prs.go | 27 ++++---- internal/cli/pr.go | 107 +++++++++++++++---------------- internal/cli/pr_checkout_test.go | 97 +++------------------------- internal/resolve/resolve.go | 6 ++ types.go | 6 +- 10 files changed, 99 insertions(+), 184 deletions(-) diff --git a/bitbucket/prs.go b/bitbucket/prs.go index 57b5a4f..e32d289 100644 --- a/bitbucket/prs.go +++ b/bitbucket/prs.go @@ -89,25 +89,23 @@ func convertBitbucketPR(bb bbPullRequest) forge.PullRequest { Number: bb.ID, Title: bb.Title, Body: bb.Description, - Head: bb.Source.Branch.Name, - Base: bb.Destination.Branch.Name, + Head: forge.PRBranch{Ref: bb.Source.Branch.Name}, + Base: forge.PRBranch{Ref: bb.Destination.Branch.Name}, Comments: bb.CommentCount, HTMLURL: bb.Links.HTML.Href, DiffURL: bb.Links.Diff.Href, } var destFullName string - result.BaseBranch = &forge.PRBranch{Ref: bb.Destination.Branch.Name} if bb.Destination.Commit != nil { - result.BaseBranch.SHA = bb.Destination.Commit.Hash + result.Base.SHA = bb.Destination.Commit.Hash } if bb.Destination.Repository != nil { destFullName = bb.Destination.Repository.FullName } - result.HeadBranch = &forge.PRBranch{Ref: bb.Source.Branch.Name} if bb.Source.Commit != nil { - result.HeadBranch.SHA = bb.Source.Commit.Hash + result.Head.SHA = bb.Source.Commit.Hash } if bb.Source.Repository != nil && bb.Source.Repository.FullName != destFullName { cloneURL, sshURL := parseCloneURLs(bb.Source.Repository.Links.Clone) @@ -117,7 +115,7 @@ func convertBitbucketPR(bb bbPullRequest) forge.PullRequest { owner = parts[0] name = parts[1] } - result.HeadBranch.Fork = &forge.ForkInfo{ + result.Head.Fork = &forge.ForkInfo{ Owner: owner, Name: name, CloneURL: cloneURL, diff --git a/bitbucket/prs_test.go b/bitbucket/prs_test.go index dcde77e..bcdce32 100644 --- a/bitbucket/prs_test.go +++ b/bitbucket/prs_test.go @@ -74,8 +74,8 @@ func TestBitbucketGetPR(t *testing.T) { assertEqual(t, "Title", "Add feature", pr.Title) assertEqual(t, "Body", "New feature PR", pr.Body) assertEqual(t, "State", "open", pr.State) - assertEqual(t, "Head", "feature-branch", pr.Head) - assertEqual(t, "Base", "main", pr.Base) + assertEqual(t, "Head", "feature-branch", pr.Head.Ref) + assertEqual(t, "Base", "main", pr.Base.Ref) assertEqual(t, "Author.Login", "author1", pr.Author.Login) assertEqualInt(t, "Comments", 3, pr.Comments) assertEqualBool(t, "Merged", false, pr.Merged) diff --git a/gitea/prs.go b/gitea/prs.go index 6522cef..a0920ae 100644 --- a/gitea/prs.go +++ b/gitea/prs.go @@ -50,21 +50,19 @@ func convertGiteaPR(pr *gitea.PullRequest) forge.PullRequest { var baseRepoID int64 if pr.Base != nil { - result.Base = pr.Base.Name - result.BaseBranch = &forge.PRBranch{ + result.Base = forge.PRBranch{ Ref: pr.Base.Ref, SHA: pr.Base.Sha, } baseRepoID = pr.Base.RepoID } if pr.Head != nil { - result.Head = pr.Head.Name - result.HeadBranch = &forge.PRBranch{ + result.Head = forge.PRBranch{ Ref: pr.Head.Ref, SHA: pr.Head.Sha, } - if pr.Head.RepoID != baseRepoID && pr.Head.Repository != nil { - result.HeadBranch.Fork = &forge.ForkInfo{ + if pr.Head.RepoID != baseRepoID && pr.Head.Repository != nil && pr.Head.Repository.Owner != nil { + result.Head.Fork = &forge.ForkInfo{ Owner: pr.Head.Repository.Owner.UserName, Name: pr.Head.Repository.Name, CloneURL: pr.Head.Repository.CloneURL, diff --git a/github/prs.go b/github/prs.go index 3502105..b4f97f3 100644 --- a/github/prs.go +++ b/github/prs.go @@ -83,8 +83,7 @@ func convertGitHubPR(pr *github.PullRequest) forge.PullRequest { var baseFullName string if b := pr.GetBase(); b != nil { - result.Base = b.GetRef() - result.BaseBranch = &forge.PRBranch{ + result.Base = forge.PRBranch{ Ref: b.GetRef(), SHA: b.GetSHA(), } @@ -93,14 +92,13 @@ func convertGitHubPR(pr *github.PullRequest) forge.PullRequest { } } if h := pr.GetHead(); h != nil { - result.Head = h.GetRef() - result.HeadBranch = &forge.PRBranch{ + result.Head = forge.PRBranch{ Ref: h.GetRef(), SHA: h.GetSHA(), } if repo := h.GetRepo(); repo != nil { if repo.GetFullName() != baseFullName { - result.HeadBranch.Fork = &forge.ForkInfo{ + result.Head.Fork = &forge.ForkInfo{ Owner: repo.GetOwner().GetLogin(), Name: repo.GetName(), CloneURL: repo.GetCloneURL(), diff --git a/github/prs_test.go b/github/prs_test.go index 6444011..8ff6a49 100644 --- a/github/prs_test.go +++ b/github/prs_test.go @@ -69,8 +69,8 @@ func TestGitHubGetPR(t *testing.T) { assertEqualBool(t, "Draft", false, pr.Draft) assertEqualBool(t, "Merged", false, pr.Merged) assertEqualBool(t, "Mergeable", true, pr.Mergeable) - assertEqual(t, "Head", "feature-branch", pr.Head) - assertEqual(t, "Base", "main", pr.Base) + assertEqual(t, "Head", "feature-branch", pr.Head.Ref) + assertEqual(t, "Base", "main", pr.Base.Ref) assertEqual(t, "Author.Login", "octocat", pr.Author.Login) assertEqualInt(t, "Comments", 2, pr.Comments) assertEqualInt(t, "Additions", 10, pr.Additions) @@ -174,7 +174,7 @@ func TestGitHubListPRs(t *testing.T) { t.Fatalf("expected 2 PRs, got %d", len(prs)) } assertEqual(t, "prs[0].Title", "First PR", prs[0].Title) - assertEqual(t, "prs[0].Head", "feature-1", prs[0].Head) + assertEqual(t, "prs[0].Head", "feature-1", prs[0].Head.Ref) assertEqual(t, "prs[1].Title", "Second PR", prs[1].Title) } diff --git a/gitlab/prs.go b/gitlab/prs.go index 265893e..6251ab9 100644 --- a/gitlab/prs.go +++ b/gitlab/prs.go @@ -2,10 +2,11 @@ package gitlab import ( "context" - forge "github.com/git-pkgs/forge" + "fmt" "net/http" "time" + forge "github.com/git-pkgs/forge" gitlab "gitlab.com/gitlab-org/api/client-go" ) @@ -28,20 +29,14 @@ func convertGitLabMR(mr *gitlab.MergeRequest) forge.PullRequest { Body: mr.Description, State: mr.State, // "opened", "closed", "merged" Draft: mr.Draft, - Head: mr.SourceBranch, - Base: mr.TargetBranch, + Head: forge.PRBranch{Ref: mr.SourceBranch, SHA: mr.SHA}, + Base: forge.PRBranch{Ref: mr.TargetBranch}, Merged: mr.State == "merged", Comments: int(mr.UserNotesCount), // ChangesCount is a string in the GitLab API HTMLURL: mr.WebURL, } - result.BaseBranch = &forge.PRBranch{Ref: mr.TargetBranch} - result.HeadBranch = &forge.PRBranch{ - Ref: mr.SourceBranch, - SHA: mr.SHA, - } - // Normalize "opened" to "open" if result.State == stateOpened { result.State = stateOpen @@ -123,8 +118,8 @@ func convertBasicGitLabMR(mr *gitlab.BasicMergeRequest) forge.PullRequest { Body: mr.Description, State: mr.State, Draft: mr.Draft, - Head: mr.SourceBranch, - Base: mr.TargetBranch, + Head: forge.PRBranch{Ref: mr.SourceBranch}, + Base: forge.PRBranch{Ref: mr.TargetBranch}, Merged: mr.State == "merged", HTMLURL: mr.WebURL, } @@ -193,11 +188,11 @@ func (s *gitLabPRService) Get(ctx context.Context, owner, repo string, number in if mr.SourceProjectID != mr.TargetProjectID { sourceProject, _, err := s.client.Projects.GetProject(mr.SourceProjectID, nil) - if err == nil && sourceProject != nil { - if result.HeadBranch == nil { - result.HeadBranch = &forge.PRBranch{Ref: mr.SourceBranch, SHA: mr.SHA} - } - result.HeadBranch.Fork = &forge.ForkInfo{ + if err != nil { + return nil, fmt.Errorf("getting source project: %w", err) + } + if sourceProject != nil { + result.Head.Fork = &forge.ForkInfo{ Owner: sourceProject.Namespace.Path, Name: sourceProject.Path, CloneURL: sourceProject.HTTPURLToRepo, diff --git a/internal/cli/pr.go b/internal/cli/pr.go index 80451ee..ca22cf1 100644 --- a/internal/cli/pr.go +++ b/internal/cli/pr.go @@ -103,7 +103,7 @@ func printPRDetails(pr *forges.PullRequest) { _, _ = fmt.Fprintf(os.Stdout, "#%d %s\n", pr.Number, output.Sanitize(pr.Title)) _, _ = fmt.Fprintf(os.Stdout, "State: %s\n", pr.State) _, _ = fmt.Fprintf(os.Stdout, "Author: %s\n", output.Sanitize(pr.Author.Login)) - _, _ = fmt.Fprintf(os.Stdout, "Branch: %s -> %s\n", pr.Head, pr.Base) + _, _ = fmt.Fprintf(os.Stdout, "Branch: %s -> %s\n", pr.Head.Ref, pr.Base.Ref) if pr.Draft { _, _ = fmt.Fprintln(os.Stdout, "Draft: yes") @@ -211,7 +211,7 @@ func prListCmd() *cobra.Command { strconv.Itoa(pr.Number), title, output.Sanitize(pr.Author.Login), - pr.Head, + pr.Head.Ref, pr.UpdatedAt.Format("2006-01-02"), } } @@ -574,23 +574,13 @@ For same-repo PRs, the branch is fetched and checked out.`, ctx := cmd.Context() - if !flagForce { - status, _ := exec.CommandContext(ctx, "git", "status", "--porcelain").Output() - if len(status) > 0 { - return fmt.Errorf("you have uncommitted changes; commit or stash them, or use --force") - } - } - pr, err := forge.PullRequests().Get(ctx, owner, repoName, number) if err != nil { return fmt.Errorf("getting PR #%d: %w", number, err) } // remoteRef is the branch name on the remote (PR's head branch) - remoteRef := pr.Head - if pr.HeadBranch != nil && pr.HeadBranch.Ref != "" { - remoteRef = pr.HeadBranch.Ref - } + remoteRef := pr.Head.Ref // localBranch is what we'll name the local branch (defaults to remote ref) localBranch := remoteRef @@ -598,29 +588,23 @@ For same-repo PRs, the branch is fetched and checked out.`, localBranch = flagBranch } - if !flagForce && !flagDetach { - if err := exec.CommandContext(ctx, "git", "rev-parse", "--verify", "--quiet", localBranch).Run(); err == nil { - _, _ = fmt.Fprintf(os.Stderr, "warning: local branch %q already exists and will be reset\n", localBranch) - } - } - - if pr.HeadBranch != nil && pr.HeadBranch.Fork != nil { - return checkoutForkPR(ctx, pr, remoteRef, localBranch, flagRemoteName, flagDetach) + if pr.Head.Fork != nil { + return checkoutForkPR(ctx, pr, remoteRef, localBranch, flagRemoteName, flagDetach, flagForce) } - return checkoutSameRepoPR(ctx, remoteRef, localBranch, flagDetach) + return checkoutSameRepoPR(ctx, remoteRef, localBranch, flagDetach, flagForce) }, } cmd.Flags().StringVar(&flagRemoteName, "remote-name", "", "Name for fork remote (default: fork owner)") cmd.Flags().StringVarP(&flagBranch, "branch", "b", "", "Local branch name (default: same as remote)") cmd.Flags().BoolVar(&flagDetach, "detach", false, "Checkout in detached HEAD mode") - cmd.Flags().BoolVarP(&flagForce, "force", "f", false, "Force checkout even with uncommitted changes or existing branch") + cmd.Flags().BoolVarP(&flagForce, "force", "f", false, "Reset the local branch to the remote state even if it has diverged") return cmd } -func checkoutForkPR(ctx context.Context, pr *forges.PullRequest, remoteRef, localBranch, flagRemoteName string, detach bool) error { - fork := pr.HeadBranch.Fork +func checkoutForkPR(ctx context.Context, pr *forges.PullRequest, remoteRef, localBranch, flagRemoteName string, detach, force bool) error { + fork := pr.Head.Fork remoteName := flagRemoteName if remoteName == "" { remoteName = fork.Owner @@ -642,27 +626,11 @@ func checkoutForkPR(ctx context.Context, pr *forges.PullRequest, remoteRef, loca return err } - refspec := fmt.Sprintf("+refs/heads/%s:refs/remotes/%s/%s", remoteRef, remoteName, remoteRef) - fetchCmd := exec.CommandContext(ctx, "git", "fetch", "--", remoteName, refspec) - fetchCmd.Stdout = os.Stdout - fetchCmd.Stderr = os.Stderr - if err := fetchCmd.Run(); err != nil { - return fmt.Errorf("fetching %s/%s: %w", remoteName, remoteRef, err) - } - - return gitCheckout(ctx, localBranch, remoteName+"/"+remoteRef, detach) + return gitCheckout(ctx, remoteName, remoteRef, localBranch, detach, force) } -func checkoutSameRepoPR(ctx context.Context, remoteRef, localBranch string, detach bool) error { - refspec := fmt.Sprintf("+refs/heads/%s:refs/remotes/origin/%s", remoteRef, remoteRef) - fetchCmd := exec.CommandContext(ctx, "git", "fetch", "--", "origin", refspec) - fetchCmd.Stdout = os.Stdout - fetchCmd.Stderr = os.Stderr - if err := fetchCmd.Run(); err != nil { - return fmt.Errorf("fetching origin/%s: %w", remoteRef, err) - } - - return gitCheckout(ctx, localBranch, "origin/"+remoteRef, detach) +func checkoutSameRepoPR(ctx context.Context, remoteRef, localBranch string, detach, force bool) error { + return gitCheckout(ctx, resolve.RemoteName(), remoteRef, localBranch, detach, force) } func ensureRemote(ctx context.Context, preferredName, cloneURL string) (string, error) { @@ -694,17 +662,48 @@ func ensureRemote(ctx context.Context, preferredName, cloneURL string) (string, return "", fmt.Errorf("remote %q already exists with a different URL; use --remote-name to specify a different name", preferredName) } -func gitCheckout(ctx context.Context, branchName, ref string, detach bool) error { - var checkoutCmd *exec.Cmd +func gitCheckout(ctx context.Context, remote, remoteRef, localBranch string, detach, force bool) error { + refspec := fmt.Sprintf("+refs/heads/%s:refs/remotes/%s/%s", remoteRef, remote, remoteRef) + fetchCmd := exec.CommandContext(ctx, "git", "fetch", "--", remote, refspec) + fetchCmd.Stdout = os.Stdout + fetchCmd.Stderr = os.Stderr + if err := fetchCmd.Run(); err != nil { + return fmt.Errorf("fetching %s/%s: %w", remote, remoteRef, err) + } + + ref := remote + "/" + remoteRef + if detach { - checkoutCmd = exec.CommandContext(ctx, "git", "checkout", "--detach", ref) - } else { - checkoutCmd = exec.CommandContext(ctx, "git", "checkout", "-B", branchName, ref) + cmd := exec.CommandContext(ctx, "git", "checkout", "--detach", ref) + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + return cmd.Run() } - checkoutCmd.Stdout = os.Stdout - checkoutCmd.Stderr = os.Stderr - if err := checkoutCmd.Run(); err != nil { - return fmt.Errorf("checking out %s: %w", ref, err) + + // Try creating a new branch + if exec.CommandContext(ctx, "git", "checkout", "-b", localBranch, ref).Run() == nil { + return nil } - return nil + + // Branch exists - switch to it and try to fast-forward + cmd := exec.CommandContext(ctx, "git", "checkout", localBranch) + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + if err := cmd.Run(); err != nil { + return fmt.Errorf("checking out %s: %w", localBranch, err) + } + + if exec.CommandContext(ctx, "git", "merge", "--ff-only", ref).Run() == nil { + return nil + } + + if !force { + return fmt.Errorf("local branch %q has diverged from %s; use --force to reset it", localBranch, ref) + } + + _, _ = fmt.Fprintf(os.Stderr, "warning: resetting %q to %s (local commits will be lost)\n", localBranch, ref) + resetCmd := exec.CommandContext(ctx, "git", "reset", "--hard", ref) + resetCmd.Stdout = os.Stdout + resetCmd.Stderr = os.Stderr + return resetCmd.Run() } diff --git a/internal/cli/pr_checkout_test.go b/internal/cli/pr_checkout_test.go index 6a31ac7..58272c0 100644 --- a/internal/cli/pr_checkout_test.go +++ b/internal/cli/pr_checkout_test.go @@ -194,8 +194,7 @@ func TestPRCheckout(t *testing.T) { name: "same-repo PR checks out branch", pr: &forges.PullRequest{ Number: 42, - Head: "feature-branch", - HeadBranch: &forges.PRBranch{ + Head: forges.PRBranch{ Ref: "feature-branch", SHA: "abc123", }, @@ -208,8 +207,7 @@ func TestPRCheckout(t *testing.T) { name: "fork PR adds remote and checks out", pr: &forges.PullRequest{ Number: 42, - Head: "feature", - HeadBranch: &forges.PRBranch{ + Head: forges.PRBranch{ Ref: "feature", SHA: "abc123", Fork: &forges.ForkInfo{ @@ -228,8 +226,7 @@ func TestPRCheckout(t *testing.T) { name: "fork PR with custom remote name", pr: &forges.PullRequest{ Number: 42, - Head: "feature", - HeadBranch: &forges.PRBranch{ + Head: forges.PRBranch{ Ref: "feature", SHA: "abc123", Fork: &forges.ForkInfo{ @@ -248,8 +245,7 @@ func TestPRCheckout(t *testing.T) { name: "detach mode", pr: &forges.PullRequest{ Number: 42, - Head: "feature-branch", - HeadBranch: &forges.PRBranch{ + Head: forges.PRBranch{ Ref: "feature-branch", SHA: "abc123", }, @@ -262,8 +258,7 @@ func TestPRCheckout(t *testing.T) { name: "checkout with custom branch name", pr: &forges.PullRequest{ Number: 42, - Head: "feature-branch", - HeadBranch: &forges.PRBranch{ + Head: forges.PRBranch{ Ref: "feature-branch", SHA: "abc123", }, @@ -281,8 +276,7 @@ func TestPRCheckout(t *testing.T) { name: "fork PR without clone URL", pr: &forges.PullRequest{ Number: 42, - Head: "feature", - HeadBranch: &forges.PRBranch{ + Head: forges.PRBranch{ Ref: "feature", SHA: "abc123", Fork: &forges.ForkInfo{ @@ -322,15 +316,15 @@ func TestPRCheckout(t *testing.T) { workDir = setupTestRepo(t, originDir) if tt.setupOrigin { - branchName := tt.pr.HeadBranch.Ref + branchName := tt.pr.Head.Ref pushBranchToRemote(t, workDir, "origin", branchName) } if tt.setupFork { forkDir := setupBareRepo(t) - tt.pr.HeadBranch.Fork.CloneURL = forkDir + tt.pr.Head.Fork.CloneURL = forkDir - branchName := tt.pr.HeadBranch.Ref + branchName := tt.pr.Head.Ref cmd := exec.Command("git", "remote", "add", "tempfork", forkDir) cmd.Dir = workDir if out, err := cmd.CombinedOutput(); err != nil { @@ -351,16 +345,7 @@ func TestPRCheckout(t *testing.T) { // Change to work directory for the test if workDir != "" { - oldWd, err := os.Getwd() - if err != nil { - t.Fatalf("getting working directory: %v", err) - } - if err := os.Chdir(workDir); err != nil { - t.Fatalf("changing to work directory: %v", err) - } - t.Cleanup(func() { - _ = os.Chdir(oldWd) - }) + t.Chdir(workDir) } // Set up mock forge @@ -436,65 +421,3 @@ func TestPRCheckout(t *testing.T) { }) } } - -func TestPRCheckoutUncommittedChanges(t *testing.T) { - if testing.Short() { - t.Skip("skipping git integration test in short mode") - } - - originDir := setupBareRepo(t) - workDir := setupTestRepo(t, originDir) - - // Push a branch to origin - pushBranchToRemote(t, workDir, "origin", "feature-branch") - - // Create uncommitted changes - testFile := filepath.Join(workDir, "uncommitted.txt") - if err := os.WriteFile(testFile, []byte("uncommitted\n"), 0644); err != nil { - t.Fatalf("writing test file: %v", err) - } - - oldWd, _ := os.Getwd() - if err := os.Chdir(workDir); err != nil { - t.Fatalf("changing to work directory: %v", err) - } - t.Cleanup(func() { _ = os.Chdir(oldWd) }) - - pr := &forges.PullRequest{ - Number: 42, - Head: "feature-branch", - HeadBranch: &forges.PRBranch{ - Ref: "feature-branch", - SHA: "abc123", - }, - } - - resolve.SetTestForge( - &mockForge{prService: &mockPRService{pr: pr}}, - "testowner", "testrepo", "github.com", - ) - t.Cleanup(resolve.ResetTestForge) - - // Without --force should fail - var buf bytes.Buffer - rootCmd.SetOut(&buf) - rootCmd.SetErr(&buf) - rootCmd.SetArgs([]string{"pr", "checkout", "42"}) - - err := rootCmd.Execute() - if err == nil { - t.Fatal("expected error for uncommitted changes, got nil") - } - if !strings.Contains(err.Error(), "uncommitted changes") { - t.Errorf("expected error about uncommitted changes, got: %v", err) - } - - // With --force should succeed - buf.Reset() - rootCmd.SetArgs([]string{"pr", "checkout", "42", "--force"}) - - err = rootCmd.Execute() - if err != nil { - t.Fatalf("unexpected error with --force: %v", err) - } -} diff --git a/internal/resolve/resolve.go b/internal/resolve/resolve.go index a38b3cf..953262e 100644 --- a/internal/resolve/resolve.go +++ b/internal/resolve/resolve.go @@ -39,6 +39,12 @@ func SetRemote(name string) { } } +// RemoteName returns the name of the git remote being used for resolution. +// This is "origin" by default, or whatever was set via SetRemote. +func RemoteName() string { + return remoteName +} + // SetHost forces a specific forge domain, taking precedence over FORGE_HOST, // --forge-type, and git remote detection. The CLI calls this from the --host // persistent flag. An empty string is ignored. diff --git a/types.go b/types.go index a24adfa..bc04c18 100644 --- a/types.go +++ b/types.go @@ -257,10 +257,8 @@ type PullRequest struct { Reviewers []User `json:"reviewers,omitempty"` Labels []Label `json:"labels,omitempty"` Milestone *Milestone `json:"milestone,omitempty"` - Head string `json:"head"` // head branch name (for backward compat) - Base string `json:"base"` // base branch name (for backward compat) - HeadBranch *PRBranch `json:"head_branch,omitempty"` // rich head branch info with repo - BaseBranch *PRBranch `json:"base_branch,omitempty"` // rich base branch info + Head PRBranch `json:"head"` + Base PRBranch `json:"base"` Mergeable bool `json:"mergeable"` Merged bool `json:"merged"` MergedBy *User `json:"merged_by,omitempty"`