diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index de23db13..76e94efc 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -4,7 +4,6 @@ on: push: branches: [ main, master, develop ] pull_request: - branches: [ main, master, develop ] workflow_dispatch: concurrency: diff --git a/.github/workflows/pr-checks.yaml b/.github/workflows/pr-checks.yaml index 4069d45f..6557835b 100644 --- a/.github/workflows/pr-checks.yaml +++ b/.github/workflows/pr-checks.yaml @@ -2,7 +2,6 @@ name: PR Checks on: pull_request: - branches: [ main, master, develop ] concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} diff --git a/.github/workflows/syfon-backend-e2e.yaml b/.github/workflows/syfon-backend-e2e.yaml index 62bb839b..cfc4474d 100644 --- a/.github/workflows/syfon-backend-e2e.yaml +++ b/.github/workflows/syfon-backend-e2e.yaml @@ -2,7 +2,6 @@ name: Syfon Backend E2E on: pull_request: - branches: [ main, master, develop ] workflow_dispatch: concurrency: diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 82d0ec2c..38b2fc29 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -4,7 +4,6 @@ on: push: branches: [ main, master, develop ] pull_request: - branches: [ main, master, develop ] workflow_dispatch: concurrency: diff --git a/README.md b/README.md index 9fa44367..48dc57d9 100644 --- a/README.md +++ b/README.md @@ -3,143 +3,115 @@ --- # NOTICE -git-drs is not yet fully compliant with DRS. It currently works against Gen3 DRS server. Full GA4GH DRS support is expected once v1.6 of the specification has been published. +`git-drs` is not a pure GA4GH DRS client. It targets Syfon/Gen3-style DRS workflows and uses extensions where repo-scale behavior requires them. --- [![Tests](https://github.com/calypr/git-drs/actions/workflows/test.yaml/badge.svg)](https://github.com/calypr/git-drs/actions/workflows/test.yaml) -**Git/DRS orchestration with optional Git LFS compatibility** +**Git/DRS orchestration with Git-compatible pointer workflows** -Git DRS manages Git-facing DRS workflows: local metadata, Git hooks, filter behavior, lookup/register/push/pull orchestration, and optional Git LFS compatibility. Provider-specific transfer, signed URL behavior, and direct cloud inspection live in client code outside this repo. +`git-drs` manages: + +- remote Gen3/Syfon configuration +- local DRS metadata +- pointer-aware push/pull orchestration +- bucket-scoped object reference workflows ## Key Features -- **Unified Workflow**: Manage both code and large data files using standard Git commands -- **DRS Integration**: Built-in support for Gen3 DRS servers -- **Multi-Remote Support**: Work with development, staging, and production servers in one repository -- **Automatic Processing**: Files are processed automatically during commits and pushes -- **Flexible Tracking**: Track individual files, patterns, or entire directories +- unified Git/data workflow around DRS-backed pointers +- Gen3/Syfon integration +- multiple remotes in one repository +- explicit file tracking and hydration +- metadata-only reference support for existing bucket objects ## How It Works -Git DRS works alongside Git LFS when you want LFS-compatible pointers and storage, while still supporting DRS-centric workflows: +At a high level: -1. **Initialization**: Set up repository and DRS server configuration -2. **Automatic Commits**: Create DRS objects during pre-commit hooks -3. **Automatic Pushes**: Register files with DRS servers and upload to configured storage -4. **On-Demand Downloads**: Pull specific files or patterns as needed +1. configure a remote for one `organization/project` +2. let `remote add` bootstrap repo-local `git-drs` state if needed +3. track file patterns with `git drs track` +4. add/commit/push normally +5. remove tracked pointers with `git drs rm` when you want repository deletion to reconcile with remote DRS state +5. hydrate pointer files later with `git drs pull` ## Quick Start -### Installation - ```bash -# Install Git LFS first -brew install git-lfs # macOS -git lfs install --skip-smudge - -# Install Git DRS -/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/calypr/git-drs/refs/heads/main/install.sh)" -- $GIT_DRS_VERSION - -# Install global Git filter configuration for git-drs git drs install -``` - -### Basic Usage - -```bash -# Initialize repository (one-time Git repo setup) -git drs init - -# Add DRS remote -git drs remote add gen3 production \ - --cred /path/to/credentials.json \ - --url https://calypr-public.ohsu.edu \ - --organization my-program \ - --project my-project \ - --bucket my-bucket - -# Required prerequisite (usually steward/admin setup): -# create bucket credentials, then map org/project to full storage roots before users run push/pull -git drs bucket add production \ - --bucket my-bucket \ - --region us-east-1 \ - --access-key "$AWS_ACCESS_KEY_ID" \ - --secret-key "$AWS_SECRET_ACCESS_KEY" \ - --s3-endpoint https://s3.amazonaws.com -git drs bucket add-organization production \ - --organization my-program \ - --path s3://my-bucket/my-program -git drs bucket add-project production \ - --organization my-program \ - --project my-project \ - --path s3://my-bucket/my-program/my-project - -# Track files -git lfs track "*.bam" +git drs remote add gen3 production HTAN_INT/BForePC --cred /path/to/credentials.json +git drs track "*.bam" git add .gitattributes - -# Add and commit files -git add my-file.bam -git commit -m "Add data file" -git push - -# Download files -git lfs pull -I "*.bam" +git add sample.bam +git commit -m "Add sample" +git drs push +git drs ls-files +git drs pull -I "*.bam" ``` -## Documentation - -For detailed setup and usage information: +## Current CLI Shape -- **[Getting Started](docs/getting-started.md)** - Repository setup and basic workflows -- **[Commands Reference](docs/commands.md)** - Complete command documentation -- **[Installation Guide](docs/installation.md)** - Platform-specific installation -- **[Troubleshooting](docs/troubleshooting.md)** - Common issues and solutions -- **[E2E Modes + Local Setup](docs/e2e-modes-and-local-setup.md)** - Local vs remote mode, server config, and end-to-end runbooks -- **[Cloud/Object Integration](docs/adding-s3-files.md)** - Adding files from provider URLs or configured bucket object keys -- **[Developer Guide](docs/developer-guide.md)** - Internals and development +The cleaned CLI intentionally removed legacy commands: -## Supported Servers +- removed: + - `git drs fetch` + - `git drs list` + - `git drs upload` + - `git drs download` +- `git drs pull` is hydration-only +- `git drs ls-files` is the local file inventory command +- `git drs remote add gen3` takes scope as `organization/project` -- **Gen3 Data Commons** (e.g., CALYPR) +Example: -## Supported Environments - -- **Local Development** environments -- **HPC Systems** (e.g., ARC) +```bash +git drs remote add gen3 production HTAN_INT/BForePC --cred /path/to/credentials.json +``` -## Commands Overview +## Bucket Mapping Model + +End users should not need to know the bucket name. + +Push and pull depend on server-side bucket mapping for the requested scope. That mapping is normally provisioned once by a steward/admin using the bucket commands. + +## Common Commands + +| Command | Description | +| --- | --- | +| `git drs install` | Install global `git-drs` filter config | +| `git drs init` | Explicitly initialize or repair repository-local `git-drs` state | +| `git drs remote add gen3 [remote] ` | Add or refresh a Gen3/Syfon remote | +| `git drs remote list` | List configured remotes | +| `git drs remote remove ` | Remove a configured DRS remote | +| `git drs remote set ` | Set the default remote | +| `git drs track ` | Track files or globs | +| `git drs untrack ` | Stop tracking files or globs | +| `git drs rm ...` | Remove tracked DRS/LFS files from Git | +| `git drs ls-files` | List tracked files and localization state | +| `git drs pull` | Hydrate pointer files in the current checkout | +| `git drs push` | Register/upload objects, reconcile committed deletes, and push refs | +| `git drs add-url` | Add an existing provider object by URL or scoped key | +| `git drs add-ref` | Add a local reference to an existing DRS object | +| `git drs query` | Query a DRS object by ID | +| `git drs copy-records` | Copy Syfon records between remotes for one scope | -| Command | Description | -| ---------------------- | ------------------------------------- | -| `git drs install` | Install global git-drs filter config | -| `git drs init` | Initialize repository | -| `git drs remote add` | Add a DRS remote server | -| `git drs remote list` | List configured remotes | -| `git drs remote set` | Set default remote | -| `git drs add-url` | Add files via provider URLs or configured bucket object keys | -| `git lfs track` | Track file patterns with LFS | -| `git lfs ls-files` | List tracked files | -| `git lfs pull` | Download tracked files | -| `git drs fetch` | Fetch metadata from DRS server | -| `git drs push` | Push objects to DRS server | +## Documentation -Use `--help` with any command for details. See [Commands Reference](docs/commands.md) for complete documentation. +- [Getting Started](docs/getting-started.md) +- [Commands Reference](docs/commands.md) +- [Troubleshooting](docs/troubleshooting.md) +- [Developer Guide](docs/developer-guide.md) +- [GA4GH DRS Scalability Gaps](docs/ga4gh-drs-scalability-gaps.md) ## Requirements -- Git LFS installed and configured -- Access credentials for your DRS server -- Go 1.24+ (for building from source) +- Git +- access credentials for the target Gen3/Syfon deployment +- Go 1.26.2+ for local builds ## Support -- **Issues**: [GitHub Issues](https://github.com/calypr/git-drs/issues) -- **Releases**: [GitHub Releases](https://github.com/calypr/git-drs/releases) -- **Documentation**: See `docs/` folder for detailed guides - -## License - -This project is part of the CALYPR data commons ecosystem. +- [GitHub Issues](https://github.com/calypr/git-drs/issues) +- [GitHub Releases](https://github.com/calypr/git-drs/releases) diff --git a/attic/issue-add-include-pattern-to-git-drs-pull.md b/attic/issue-add-include-pattern-to-git-drs-pull.md new file mode 100644 index 00000000..4217ab3b --- /dev/null +++ b/attic/issue-add-include-pattern-to-git-drs-pull.md @@ -0,0 +1,51 @@ +# Add `-I "pattern"` include filter support to `git drs pull` + +## Summary +Add include-pattern filtering to `git drs pull`, similar to legacy `git lfs pull -I "pattern"` workflows. + +## Motivation +Current `git drs pull` behavior pulls based on repository resolution without a user-facing path pattern filter. Users migrating from `git lfs pull -I` expect selective hydration of files by glob/path. + +## Proposed UX +Support: + +```bash +git drs pull -I "results/*.txt" +git drs pull -I "*.bam" -I "data/**" +git drs pull --include "path/to/file" +``` + +Optional: +- `--exclude` parity (if desired in same change or follow-up) + +## Proposed behavior +1. Parse one or more include patterns (`-I`, `--include`). +2. Resolve candidate pointers as usual. +3. Filter by repo-relative path match before download. +4. Download only matched objects; skip others with clear logging. +5. If no pattern supplied, preserve current default behavior. + +## Scope +- `cmd/pull/main.go` CLI flags and pull selection pipeline +- pointer/path inventory layer (where path<->OID candidates are produced) +- docs: `docs/commands.md`, `docs/getting-started.md`, `docs/troubleshooting.md` +- tests for include filtering semantics + +## Acceptance criteria +- [ ] `git drs pull -I ""` works for a single pattern. +- [ ] Repeated `-I` flags are supported. +- [ ] Include matching is against repo-relative paths. +- [ ] Default `git drs pull` behavior unchanged when no `-I` is passed. +- [ ] Help text documents pattern syntax and examples. +- [ ] Unit/integration tests cover positive and negative matches. + +## Testing matrix +- Single file exact path include. +- Wildcard include (`*.bam`, `data/**`). +- Multiple `-I` values. +- No matches (should no-op cleanly and return success unless policy says otherwise). +- Mixed matched/unmatched objects in same pull run. + +## Notes +This closes a usability gap for users transitioning from `git lfs` CLI habits to `git drs` commands while keeping pull behavior explicit and predictable. + diff --git a/cmd/addurl/main_test.go b/cmd/addurl/main_test.go index 26a6456a..060dea16 100644 --- a/cmd/addurl/main_test.go +++ b/cmd/addurl/main_test.go @@ -19,7 +19,6 @@ import ( "github.com/calypr/git-drs/internal/config" "github.com/calypr/git-drs/internal/drsobject" "github.com/calypr/git-drs/internal/gitrepo" - "github.com/calypr/git-drs/internal/lfs" "github.com/calypr/git-drs/internal/precommit_cache" sycloud "github.com/calypr/syfon/client/cloud" ) @@ -100,9 +99,9 @@ func TestRunAddURL_WritesPointerAndLFSObject(t *testing.T) { t.Fatalf("service.Run error: %v", err) } - oid, err := lfs.SyntheticOIDFromETag("abcd1234") + oid, err := placeholderOIDForUnknownSHA("abcd1234", "s3://bucket/path/to/file.bin") if err != nil { - t.Fatalf("SyntheticOIDFromETag: %v", err) + t.Fatalf("placeholderOIDForUnknownSHA: %v", err) } pointerPath := filepath.Join(tempDir, "path/to/file.bin") @@ -120,15 +119,8 @@ func TestRunAddURL_WritesPointerAndLFSObject(t *testing.T) { } lfsObject := filepath.Join(lfsRoot, "objects", oid[0:2], oid[2:4], oid) - if _, err := os.Stat(lfsObject); err != nil { - t.Fatalf("expected LFS object at %s: %v", lfsObject, err) - } - sentinel, err := os.ReadFile(lfsObject) - if err != nil { - t.Fatalf("read sentinel: %v", err) - } - if !lfs.IsAddURLSentinelBytes(sentinel) { - t.Fatalf("expected add-url sentinel payload, got: %q", string(sentinel)) + if _, err := os.Stat(lfsObject); !os.IsNotExist(err) { + t.Fatalf("expected no local LFS object payload at %s, got err=%v", lfsObject, err) } drsObject, err := drsobject.ReadObject(common.DRS_OBJS_PATH, oid) @@ -143,6 +135,26 @@ func TestRunAddURL_WritesPointerAndLFSObject(t *testing.T) { } } +func TestPlaceholderOIDForUnknownSHA(t *testing.T) { + oid1, err := placeholderOIDForUnknownSHA("etag-abc", "s3://bucket/key") + if err != nil { + t.Fatalf("placeholderOIDForUnknownSHA: %v", err) + } + oid2, err := placeholderOIDForUnknownSHA(`"etag-abc"`, "s3://bucket/key") + if err != nil { + t.Fatalf("placeholderOIDForUnknownSHA quoted: %v", err) + } + if oid1 != oid2 { + t.Fatalf("expected trimmed etag handling to be stable: %s vs %s", oid1, oid2) + } + if len(oid1) != 64 { + t.Fatalf("expected 64-char oid, got %q", oid1) + } + if _, err := placeholderOIDForUnknownSHA("", "s3://bucket/key"); err == nil { + t.Fatal("expected empty etag error") + } +} + func TestParseAddURLInput_DoesNotRequireAWSFlags(t *testing.T) { cmd := NewCommand() in, err := parseAddURLInput(cmd, []string{"gs://bucket/path/to/file.bin"}) diff --git a/cmd/addurl/service.go b/cmd/addurl/service.go index 79ad6195..e50ca8c1 100644 --- a/cmd/addurl/service.go +++ b/cmd/addurl/service.go @@ -2,9 +2,10 @@ package addurl import ( "context" + "crypto/sha256" "fmt" "log/slog" - "os" + "strings" "github.com/calypr/git-drs/internal/common" "github.com/calypr/git-drs/internal/config" @@ -186,26 +187,29 @@ func writeAddURLDrsObject(builder drsobject.Builder, file addURLDrsFile, objectP return drsObj, nil } -// ensureLFSObject ensures the LFS object identified by objectInfo exists in the -// repository's LFS storage. If SHA256 is provided, it is trusted and returned. -// Otherwise we create a sentinel object and synthetic OID derived from ETag, -// deferring true checksum validation to first real data use. +// ensureLFSObject returns the LFS pointer OID to use for the add-url target. +// If SHA256 is provided, it is trusted and returned. Otherwise we derive a +// deterministic placeholder OID from provider identity without writing any +// local LFS object payload. func (s *AddURLService) ensureLFSObject(ctx context.Context, objectInfo *sycloud.ObjectInfo, input addURLInput, lfsRoot string) (string, error) { _ = ctx + _ = lfsRoot if input.sha256 != "" { return input.sha256, nil } - oid, err := lfs.SyntheticOIDFromETag(objectInfo.ETag) - if err != nil { - return "", err - } - objPath, err := lfs.WriteAddURLSentinelObject(lfsRoot, oid, objectInfo.ETag, input.objectURL) - if err != nil { - return "", err + return placeholderOIDForUnknownSHA(objectInfo.ETag, input.objectURL) +} + +func placeholderOIDForUnknownSHA(etag string, sourceURL string) (string, error) { + e := strings.TrimSpace(strings.Trim(etag, `"`)) + src := strings.TrimSpace(sourceURL) + if e == "" { + return "", fmt.Errorf("etag is required for placeholder oid") } - if _, err := fmt.Fprintf(os.Stderr, "Added add-url sentinel object at %s\n", objPath); err != nil { - return "", fmt.Errorf("stderr write: %w", err) + if src == "" { + return "", fmt.Errorf("source URL is required for placeholder oid") } - return oid, nil + sum := sha256.Sum256([]byte("git-drs-add-url-placeholder:v2\netag=" + e + "\nsource=" + src + "\n")) + return fmt.Sprintf("%x", sum[:]), nil } diff --git a/cmd/copyrecords/main.go b/cmd/copyrecords/main.go new file mode 100644 index 00000000..c494190c --- /dev/null +++ b/cmd/copyrecords/main.go @@ -0,0 +1,350 @@ +package copyrecords + +import ( + "context" + "encoding/json" + "fmt" + "log/slog" + "strings" + + "github.com/calypr/git-drs/internal/config" + "github.com/calypr/git-drs/internal/drslog" + drsapi "github.com/calypr/syfon/apigen/client/drs" + internalapi "github.com/calypr/syfon/apigen/client/internalapi" + syservices "github.com/calypr/syfon/client/services" + "github.com/spf13/cobra" +) + +var ( + batchSize int +) + +type copyStats struct { + SourceSeen int + Created int + Updated int + Unchanged int + Written int +} + +type indexAPI interface { + List(ctx context.Context, opts syservices.ListRecordsOptions) (internalapi.ListRecordsResponse, error) + BulkDocuments(ctx context.Context, dids []string) ([]internalapi.InternalRecordResponse, error) + CreateBulk(ctx context.Context, req internalapi.BulkCreateRequest) (internalapi.ListRecordsResponse, error) +} + +var Cmd = &cobra.Command{ + Use: "copy-records [source-remote] ", + Short: "Copy Syfon records between remotes for one organization/project scope", + Long: "Read all Syfon records for a source organization/project scope and bulk load them into a target Syfon instance, only merging controlled_access and access_methods for records that already exist on the target.", + Args: cobra.RangeArgs(2, 3), + RunE: func(cmd *cobra.Command, args []string) error { + logger := drslog.GetLogger() + cfg, err := config.LoadConfig() + if err != nil { + return fmt.Errorf("error loading config: %w", err) + } + + sourceRemote := "" + targetRemote := "" + scopeArg := "" + if len(args) == 2 { + targetRemote = args[0] + scopeArg = args[1] + } else { + sourceRemote = args[0] + targetRemote = args[1] + scopeArg = args[2] + } + + srcRemoteName, err := cfg.GetRemoteOrDefault(sourceRemote) + if err != nil { + return fmt.Errorf("error resolving source remote: %w", err) + } + if strings.TrimSpace(targetRemote) == "" { + return fmt.Errorf("target remote is required") + } + dstRemoteName := config.Remote(targetRemote) + if srcRemoteName == dstRemoteName { + return fmt.Errorf("source and target remotes must be different") + } + + srcCfg := cfg.GetRemote(srcRemoteName) + if srcCfg == nil { + return fmt.Errorf("source remote %q not found", srcRemoteName) + } + + org, proj, err := parseScopeArg(scopeArg) + if err != nil { + return err + } + + srcCtx, err := cfg.GetRemoteClient(srcRemoteName, logger) + if err != nil { + return fmt.Errorf("error creating source client: %w", err) + } + dstCtx, err := cfg.GetRemoteClient(dstRemoteName, logger) + if err != nil { + return fmt.Errorf("error creating target client: %w", err) + } + + stats, err := copyProjectRecords(cmd.Context(), logger, srcCtx.Client.Index(), dstCtx.Client.Index(), org, proj, batchSize) + if err != nil { + return err + } + + logger.Info("copy-records complete", + "source_remote", srcRemoteName, + "target_remote", dstRemoteName, + "organization", org, + "project", proj, + "source_seen", stats.SourceSeen, + "created", stats.Created, + "updated", stats.Updated, + "unchanged", stats.Unchanged, + "written", stats.Written, + ) + return nil + }, +} + +func init() { + Cmd.Flags().IntVar(&batchSize, "batch-size", 250, "records per source page and target bulk write") +} + +func parseScopeArg(raw string) (string, string, error) { + raw = strings.TrimSpace(raw) + if raw == "" { + return "", "", fmt.Errorf("scope is required and must be in organization/project form") + } + parts := strings.Split(raw, "/") + if len(parts) != 2 { + return "", "", fmt.Errorf("invalid scope %q: expected organization/project", raw) + } + org := strings.TrimSpace(parts[0]) + project := strings.TrimSpace(parts[1]) + if org == "" || project == "" { + return "", "", fmt.Errorf("invalid scope %q: expected organization/project", raw) + } + return org, project, nil +} + +func copyProjectRecords(ctx context.Context, logger *slog.Logger, src indexAPI, dst indexAPI, org, project string, batchSize int) (copyStats, error) { + if batchSize <= 0 { + batchSize = 250 + } + + stats := copyStats{} + page := 1 + for { + listResp, err := src.List(ctx, syservices.ListRecordsOptions{ + Organization: org, + ProjectID: project, + Limit: batchSize, + Page: page, + }) + if err != nil { + return stats, fmt.Errorf("source list failed for %s/%s page %d: %w", org, project, page, err) + } + records := []internalapi.InternalRecord{} + if listResp.Records != nil { + records = *listResp.Records + } + if len(records) == 0 { + break + } + stats.SourceSeen += len(records) + + toWrite, batchStats, err := buildMergedBatch(ctx, dst, records) + if err != nil { + return stats, err + } + stats.Created += batchStats.Created + stats.Updated += batchStats.Updated + stats.Unchanged += batchStats.Unchanged + + if len(toWrite) > 0 { + resp, err := dst.CreateBulk(ctx, internalapi.BulkCreateRequest{Records: toWrite}) + if err != nil { + return stats, fmt.Errorf("target bulk create failed on page %d: %w", page, err) + } + if resp.Records != nil { + stats.Written += len(*resp.Records) + } else { + stats.Written += len(toWrite) + } + } + + if logger != nil { + logger.Info("copy-records batch complete", + "organization", org, + "project", project, + "page", page, + "source_records", len(records), + "created", batchStats.Created, + "updated", batchStats.Updated, + "unchanged", batchStats.Unchanged, + "written", len(toWrite), + ) + } + + if len(records) < batchSize { + break + } + page++ + } + + return stats, nil +} + +func buildMergedBatch(ctx context.Context, dst indexAPI, source []internalapi.InternalRecord) ([]internalapi.InternalRecord, copyStats, error) { + stats := copyStats{} + if len(source) == 0 { + return nil, stats, nil + } + + dids := make([]string, 0, len(source)) + for _, rec := range source { + did := strings.TrimSpace(rec.Did) + if did == "" { + continue + } + dids = append(dids, did) + } + + existing, err := dst.BulkDocuments(ctx, dids) + if err != nil { + return nil, stats, fmt.Errorf("target bulk documents failed: %w", err) + } + existingByDID := make(map[string]internalapi.InternalRecord, len(existing)) + for _, rec := range existing { + existingByDID[strings.TrimSpace(rec.Did)] = recordResponseToRecord(rec) + } + + out := make([]internalapi.InternalRecord, 0, len(source)) + for _, src := range source { + did := strings.TrimSpace(src.Did) + if did == "" { + continue + } + if dstRec, ok := existingByDID[did]; ok { + merged, changed := mergeExistingRecord(dstRec, src) + if changed { + out = append(out, merged) + stats.Updated++ + } else { + stats.Unchanged++ + } + continue + } + out = append(out, src) + stats.Created++ + } + + return out, stats, nil +} + +func mergeExistingRecord(dst, src internalapi.InternalRecord) (internalapi.InternalRecord, bool) { + merged := dst + changed := false + + controlledAccess := mergeStringLists(dst.ControlledAccess, src.ControlledAccess) + if !equalStringPointers(merged.ControlledAccess, controlledAccess) { + merged.ControlledAccess = controlledAccess + changed = true + } + + accessMethods := mergeAccessMethods(dst.AccessMethods, src.AccessMethods) + if !equalAccessMethodPointers(merged.AccessMethods, accessMethods) { + merged.AccessMethods = accessMethods + changed = true + } + + return merged, changed +} + +func recordResponseToRecord(in internalapi.InternalRecordResponse) internalapi.InternalRecord { + return internalapi.InternalRecord{ + Did: in.Did, + AccessMethods: in.AccessMethods, + ControlledAccess: in.ControlledAccess, + CreatedTime: in.CreatedTime, + Description: in.Description, + FileName: in.FileName, + Hashes: in.Hashes, + Organization: in.Organization, + Project: in.Project, + Size: in.Size, + UpdatedTime: in.UpdatedTime, + Version: in.Version, + } +} + +func mergeStringLists(left, right *[]string) *[]string { + seen := map[string]struct{}{} + out := make([]string, 0) + for _, list := range []*[]string{left, right} { + if list == nil { + continue + } + for _, raw := range *list { + val := strings.TrimSpace(raw) + if val == "" { + continue + } + if _, ok := seen[val]; ok { + continue + } + seen[val] = struct{}{} + out = append(out, val) + } + } + if len(out) == 0 { + return nil + } + return &out +} + +func mergeAccessMethods(left, right *[]drsapi.AccessMethod) *[]drsapi.AccessMethod { + seen := map[string]struct{}{} + out := make([]drsapi.AccessMethod, 0) + for _, list := range []*[]drsapi.AccessMethod{left, right} { + if list == nil { + continue + } + for _, method := range *list { + key := canonicalAccessMethod(method) + if _, ok := seen[key]; ok { + continue + } + seen[key] = struct{}{} + out = append(out, method) + } + } + if len(out) == 0 { + return nil + } + return &out +} + +func canonicalAccessMethod(method drsapi.AccessMethod) string { + b, err := json.Marshal(method) + if err != nil { + return fmt.Sprintf("%s|%v", method.Type, method.AccessId) + } + return string(b) +} + +func equalStringPointers(a, b *[]string) bool { + return equalJSON(a, b) +} + +func equalAccessMethodPointers(a, b *[]drsapi.AccessMethod) bool { + return equalJSON(a, b) +} + +func equalJSON(a, b any) bool { + ab, _ := json.Marshal(a) + bb, _ := json.Marshal(b) + return string(ab) == string(bb) +} diff --git a/cmd/copyrecords/main_test.go b/cmd/copyrecords/main_test.go new file mode 100644 index 00000000..6e1e4528 --- /dev/null +++ b/cmd/copyrecords/main_test.go @@ -0,0 +1,139 @@ +package copyrecords + +import ( + "context" + "testing" + + drsapi "github.com/calypr/syfon/apigen/client/drs" + internalapi "github.com/calypr/syfon/apigen/client/internalapi" + syservices "github.com/calypr/syfon/client/services" +) + +type fakeIndexAPI struct { + listResp internalapi.ListRecordsResponse + bulkDocsResp []internalapi.InternalRecordResponse + createBulkReq []internalapi.BulkCreateRequest +} + +func (f *fakeIndexAPI) List(ctx context.Context, opts syservices.ListRecordsOptions) (internalapi.ListRecordsResponse, error) { + return f.listResp, nil +} + +func (f *fakeIndexAPI) BulkDocuments(ctx context.Context, dids []string) ([]internalapi.InternalRecordResponse, error) { + return f.bulkDocsResp, nil +} + +func (f *fakeIndexAPI) CreateBulk(ctx context.Context, req internalapi.BulkCreateRequest) (internalapi.ListRecordsResponse, error) { + f.createBulkReq = append(f.createBulkReq, req) + return internalapi.ListRecordsResponse{Records: &req.Records}, nil +} + +func TestMergeExistingRecord_UnionsControlledAccessAndAccessMethodsOnly(t *testing.T) { + dstName := "target.bin" + srcName := "source.bin" + desc := "keep target description" + leftCA := []string{"/organization/A/project/P1"} + rightCA := []string{"/organization/A/project/P1", "/organization/A/project/P2"} + leftMethods := []drsapi.AccessMethod{{ + Type: drsapi.AccessMethodTypeS3, + AccessUrl: &struct { + Headers *[]string `json:"headers,omitempty"` + Url string `json:"url"` + }{Url: "s3://bucket/one"}, + }} + rightMethods := []drsapi.AccessMethod{ + leftMethods[0], + { + Type: drsapi.AccessMethodTypeHttps, + AccessUrl: &struct { + Headers *[]string `json:"headers,omitempty"` + Url string `json:"url"` + }{Url: "https://example.org/two"}, + }, + } + + merged, changed := mergeExistingRecord( + internalapi.InternalRecord{ + Did: "did-1", + FileName: &dstName, + Description: &desc, + ControlledAccess: &leftCA, + AccessMethods: &leftMethods, + }, + internalapi.InternalRecord{ + Did: "did-1", + FileName: &srcName, + ControlledAccess: &rightCA, + AccessMethods: &rightMethods, + }, + ) + + if !changed { + t.Fatalf("expected merge to report a change") + } + if merged.FileName == nil || *merged.FileName != dstName { + t.Fatalf("expected target metadata to be preserved, got %+v", merged.FileName) + } + if merged.Description == nil || *merged.Description != desc { + t.Fatalf("expected target description to be preserved") + } + if merged.ControlledAccess == nil || len(*merged.ControlledAccess) != 2 { + t.Fatalf("expected merged controlled access union, got %+v", merged.ControlledAccess) + } + if merged.AccessMethods == nil || len(*merged.AccessMethods) != 2 { + t.Fatalf("expected merged access method union, got %+v", merged.AccessMethods) + } +} + +func TestBuildMergedBatch_CreatesNewAndUpdatesExisting(t *testing.T) { + srcCA := []string{"/organization/A/project/P1"} + newCA := []string{"/organization/A/project/P2"} + srcMethods := []drsapi.AccessMethod{{ + Type: drsapi.AccessMethodTypeS3, + AccessUrl: &struct { + Headers *[]string `json:"headers,omitempty"` + Url string `json:"url"` + }{Url: "s3://bucket/a"}, + }} + newMethods := []drsapi.AccessMethod{{ + Type: drsapi.AccessMethodTypeHttps, + AccessUrl: &struct { + Headers *[]string `json:"headers,omitempty"` + Url string `json:"url"` + }{Url: "https://example.org/b"}, + }} + + target := &fakeIndexAPI{ + bulkDocsResp: []internalapi.InternalRecordResponse{ + { + Did: "did-existing", + ControlledAccess: &srcCA, + AccessMethods: &srcMethods, + }, + }, + } + + source := []internalapi.InternalRecord{ + { + Did: "did-existing", + ControlledAccess: &newCA, + AccessMethods: &newMethods, + }, + { + Did: "did-new", + ControlledAccess: &srcCA, + AccessMethods: &srcMethods, + }, + } + + out, stats, err := buildMergedBatch(context.Background(), target, source) + if err != nil { + t.Fatalf("buildMergedBatch error: %v", err) + } + if len(out) != 2 { + t.Fatalf("expected 2 output records, got %d", len(out)) + } + if stats.Created != 1 || stats.Updated != 1 || stats.Unchanged != 0 { + t.Fatalf("unexpected stats: %+v", stats) + } +} diff --git a/cmd/download/main.go b/cmd/download/main.go deleted file mode 100644 index 599cd2e0..00000000 --- a/cmd/download/main.go +++ /dev/null @@ -1,100 +0,0 @@ -package download - -import ( - "context" - "fmt" - "path/filepath" - "strings" - - "github.com/calypr/git-drs/internal/common" - "github.com/calypr/git-drs/internal/config" - "github.com/calypr/git-drs/internal/drslog" - "github.com/calypr/git-drs/internal/drsremote" - drsapi "github.com/calypr/syfon/apigen/client/drs" - sydownload "github.com/calypr/syfon/client/transfer/download" - "github.com/spf13/cobra" -) - -var remote string -var outdir string - -// Cmd line declaration -var Cmd = &cobra.Command{ - Use: "download ", - Short: "Download a file from a DRS server", - Long: "Download a file from a DRS server, without creating an LFS pointer", - Args: cobra.MinimumNArgs(1), - RunE: func(cmd *cobra.Command, args []string) error { - - logger := drslog.GetLogger() - - config, err := config.LoadConfig() - if err != nil { - return err - } - - remoteName, err := config.GetRemoteOrDefault(remote) - if err != nil { - logger.Error(fmt.Sprintf("Error getting remote: %v", err)) - return err - } - - client, err := config.GetRemoteClient(remoteName, logger) - if err != nil { - return err - } - for _, src := range args { - obj, err := client.Client.DRS().GetObject(context.Background(), src) - if err != nil { - logger.Error(fmt.Sprintf("Error downloading object %s: %v", src, err)) - } else { - common.PrintDRSObject(obj, false) - dstName := src - if obj.Name != nil && *obj.Name != "" { - dstName = filepath.Base(*obj.Name) - } - dstPath := filepath.Join(outdir, dstName) - logger.Info(fmt.Sprintf("Downloading object %s to path %s", src, dstPath)) - accessURL, err := resolveAccessURL(cmd.Context(), client, obj) - if err != nil { - logger.Error(fmt.Sprintf("Error resolving access URL for object %s: %v", src, err)) - continue - } - if err := drsremote.DownloadResolvedToPath(cmd.Context(), client, obj.Id, dstPath, &obj, accessURL, sydownload.DownloadOptions{ - MultipartThreshold: 5 * 1024 * 1024, - Concurrency: 2, - ChunkSize: 64 * 1024 * 1024, - }); err != nil { - logger.Error(fmt.Sprintf("Error downloading object %s to path %s: %v", src, dstPath, err)) - } else { - logger.Info(fmt.Sprintf("Successfully downloaded object %s to path %s", src, dstPath)) - } - } - } - - return nil - }, -} - -func init() { - Cmd.Flags().StringVarP(&remote, "remote", "r", "", "target remote DRS server (default: default_remote)") - Cmd.Flags().StringVarP(&outdir, "outdir", "o", ".", "output directory for downloaded files") -} - -func resolveAccessURL(ctx context.Context, remote *config.GitContext, obj drsapi.DrsObject) (*drsapi.AccessURL, error) { - if remote == nil || remote.Client == nil { - return nil, fmt.Errorf("DRS client unavailable") - } - if obj.AccessMethods == nil || len(*obj.AccessMethods) == 0 { - return nil, fmt.Errorf("no access methods available for DRS object %s", obj.Id) - } - accessType := strings.TrimSpace(string((*obj.AccessMethods)[0].Type)) - if accessType == "" { - return nil, fmt.Errorf("no access type found in access method for DRS object %s", obj.Id) - } - accessURL, err := remote.Client.DRS().GetAccessURL(ctx, obj.Id, accessType) - if err != nil { - return nil, err - } - return &accessURL, nil -} diff --git a/cmd/fetch/fetch_test.go b/cmd/fetch/fetch_test.go deleted file mode 100644 index 37766718..00000000 --- a/cmd/fetch/fetch_test.go +++ /dev/null @@ -1,37 +0,0 @@ -package fetch - -import ( - "testing" - - "github.com/calypr/git-drs/internal/testutils" - "github.com/stretchr/testify/assert" -) - -func TestFetchCmdArgs(t *testing.T) { - // Test with no arguments (valid) - err := Cmd.Args(Cmd, []string{}) - assert.NoError(t, err) - - // Test with 1 argument (valid) - err = Cmd.Args(Cmd, []string{"origin"}) - assert.NoError(t, err) - - // Test with multiple arguments (invalid) - err = Cmd.Args(Cmd, []string{"origin", "extra"}) - assert.Error(t, err) -} - -func TestFetchRun_Error(t *testing.T) { - _ = testutils.SetupTestGitRepo(t) - // No config, should error - err := Cmd.RunE(Cmd, []string{}) - assert.Error(t, err) -} - -func TestFetchRun_InvalidRemote(t *testing.T) { - tmpDir := testutils.SetupTestGitRepo(t) - testutils.CreateDefaultTestConfig(t, tmpDir) - // Fetch from non-existent remote - err := Cmd.RunE(Cmd, []string{"no-remote"}) - assert.Error(t, err) -} diff --git a/cmd/fetch/main.go b/cmd/fetch/main.go deleted file mode 100644 index 0acf089a..00000000 --- a/cmd/fetch/main.go +++ /dev/null @@ -1,66 +0,0 @@ -package fetch - -import ( - "fmt" - "os/exec" - "strings" - - "github.com/calypr/git-drs/internal/config" - "github.com/calypr/git-drs/internal/drslog" - "github.com/spf13/cobra" -) - -var runCommand = func(name string, args ...string) ([]byte, error) { - cmd := exec.Command(name, args...) - return cmd.CombinedOutput() -} - -// Cmd line declaration -var Cmd = &cobra.Command{ - Use: "fetch [remote-name]", - Short: "Fetch LFS objects from remote via standard git-lfs", - Args: func(cmd *cobra.Command, args []string) error { - if len(args) > 1 { - cmd.SilenceUsage = false - return fmt.Errorf("error: accepts at most 1 argument (remote name), received %d\n\nUsage: %s\n\nSee 'git drs fetch --help' for more details", len(args), cmd.UseLine()) - } - return nil - }, - RunE: func(cmd *cobra.Command, args []string) error { - logger := drslog.GetLogger() - - cfg, err := config.LoadConfig() - if err != nil { - return fmt.Errorf("error loading config: %v", err) - } - - var remote config.Remote - if len(args) > 0 { - remote = config.Remote(args[0]) - } else { - remote, err = cfg.GetDefaultRemote() - if err != nil { - logger.Error(fmt.Sprintf("Error getting remote: %v", err)) - return err - } - } - - drsClient, err := cfg.GetRemoteClient(remote, logger) - if err != nil { - logger.Error(fmt.Sprintf("\nerror creating DRS client: %s", err)) - return err - } - _ = drsClient // Remote validation only. - - out, err := runCommand("git", "lfs", "pull", string(remote)) - if err != nil { - msg := strings.TrimSpace(string(out)) - if msg == "" { - msg = err.Error() - } - return fmt.Errorf("git lfs pull failed for remote %q: %s", remote, msg) - } - - return nil - }, -} diff --git a/cmd/initialize/main.go b/cmd/initialize/main.go index 95fcdf65..dd9d6fca 100644 --- a/cmd/initialize/main.go +++ b/cmd/initialize/main.go @@ -39,57 +39,132 @@ var Cmd = &cobra.Command{ }, RunE: func(cmd *cobra.Command, args []string) error { logg := drslog.GetLogger() - - // check if .git dir exists to ensure you're in a git repository - _, err := gitrepo.GitTopLevel() - if err != nil { - return fmt.Errorf("error: not in a git repository. Please run this command in the root of your git repository") + if err := InitializeRepo(logg); err != nil { + return err } + logg.Debug(fmt.Sprintf("Using %d concurrent transfers", transfers)) + return nil + }, +} - // create config file if it doesn't exist - err = config.CreateEmptyConfig() - if err != nil { - return fmt.Errorf("error: unable to create config file: %v", err) - } +// InitializeRepo applies git-drs repository-local setup to the current git repository. +// It is safe to call repeatedly. +func InitializeRepo(logg *slog.Logger) error { + // check if .git dir exists to ensure you're in a git repository + _, err := gitrepo.GitTopLevel() + if err != nil { + return fmt.Errorf("error: not in a git repository. Please run this command in the root of your git repository") + } - // load the config - _, err = config.LoadConfig() - if err != nil { - logg.Debug(fmt.Sprintf("We should probably fix this: %v", err)) - return fmt.Errorf("error: unable to load config file: %v", err) - } + // create config file if it doesn't exist + err = config.CreateEmptyConfig() + if err != nil { + return fmt.Errorf("error: unable to create config file: %v", err) + } - // create drs directories - drsDir := common.DRS_DIR - drsLfsObjsDir := common.DRS_OBJS_PATH - if err := os.MkdirAll(drsDir, 0755); err != nil { - return fmt.Errorf("error: unable to create drs directory: %v", err) - } - if err := os.MkdirAll(drsLfsObjsDir, 0755); err != nil { - return fmt.Errorf("error: unable to create drs lfs objects directory: %v", err) - } + // load the config + _, err = config.LoadConfig() + if err != nil { + logg.Debug(fmt.Sprintf("We should probably fix this: %v", err)) + return fmt.Errorf("error: unable to load config file: %v", err) + } - err = initGitConfig() - if err != nil { - return fmt.Errorf("error initializing git-drs repository config: %v", err) - } + // create drs directories + drsDir := common.DRS_DIR + drsLfsObjsDir := common.DRS_OBJS_PATH + if err := os.MkdirAll(drsDir, 0755); err != nil { + return fmt.Errorf("error: unable to create drs directory: %v", err) + } + if err := os.MkdirAll(drsLfsObjsDir, 0755); err != nil { + return fmt.Errorf("error: unable to create drs lfs objects directory: %v", err) + } - // install pre-push hook - err = installPrePushHook(logg) - if err != nil { - return fmt.Errorf("error installing pre-push hook: %v", err) - } - // install pre-commit hook - err = installPreCommitHook(logg) - if err != nil { - return fmt.Errorf("error installing pre-commit hook: %v", err) - } + err = initGitConfig() + if err != nil { + return fmt.Errorf("error initializing git-drs repository config: %v", err) + } - // final logs - logg.Debug("Git DRS initialized") - logg.Debug(fmt.Sprintf("Using %d concurrent transfers", transfers)) + // install pre-push hook + err = installPrePushHook(logg) + if err != nil { + return fmt.Errorf("error installing pre-push hook: %v", err) + } + // install pre-commit hook + err = installPreCommitHook(logg) + if err != nil { + return fmt.Errorf("error installing pre-commit hook: %v", err) + } + + logg.Debug("Git DRS initialized") + return nil +} + +// EnsureInitialized applies initialization only when the repository does not +// already appear to have git-drs local setup installed. +func EnsureInitialized(logg *slog.Logger) error { + initialized, err := isInitialized() + if err != nil { + return err + } + if initialized { return nil - }, + } + return InitializeRepo(logg) +} + +func isInitialized() (bool, error) { + if _, err := gitrepo.GitTopLevel(); err != nil { + return false, fmt.Errorf("error: not in a git repository. Please run this command in the root of your git repository") + } + + if _, err := os.Stat(common.DRS_DIR); err != nil { + if os.IsNotExist(err) { + return false, nil + } + return false, fmt.Errorf("error checking git-drs directory: %v", err) + } + + if val, err := gitrepo.GetGitConfigString("filter.drs.process"); err != nil || strings.TrimSpace(val) != "git-drs filter" { + return false, err + } + if val, err := gitrepo.GetGitConfigString("filter.drs.clean"); err != nil || strings.TrimSpace(val) != "git-drs clean -- %f" { + return false, err + } + if val, err := gitrepo.GetGitConfigString("filter.drs.smudge"); err != nil || strings.TrimSpace(val) != "git-drs smudge -- %f" { + return false, err + } + if val, err := gitrepo.GetGitConfigString("filter.drs.required"); err != nil || strings.TrimSpace(val) != "true" { + return false, err + } + + preCommitInstalled, err := hookContains("pre-commit", "git drs precommit") + if err != nil { + return false, err + } + if !preCommitInstalled { + return false, nil + } + + prePushInstalled, err := hookContains("pre-push", "git drs pre-push-prepare") + if err != nil { + return false, err + } + return prePushInstalled, nil +} + +func hookContains(name, marker string) (bool, error) { + hooksDir, err := gitrepo.GetGitHooksDir() + if err != nil { + return false, fmt.Errorf("unable to get hooks directory: %w", err) + } + content, err := os.ReadFile(filepath.Join(hooksDir, name)) + if err != nil { + if os.IsNotExist(err) { + return false, nil + } + return false, err + } + return strings.Contains(string(content), marker), nil } func initGitConfig() error { @@ -99,7 +174,10 @@ func initGitConfig() error { // Use git-drs as the long-running filter-process handler. // This replaces the default git-lfs smudge/clean per-invocation commands // with a single persistent process that calls the DRS transfer stack directly. - "filter.drs.process": "git-drs filter", + "filter.drs.clean": "git-drs clean -- %f", + "filter.drs.smudge": "git-drs smudge -- %f", + "filter.drs.process": "git-drs filter", + "filter.drs.required": "true", // Canonical git-drs config keys consumed by clients. "drs.upsert": strconv.FormatBool(upsert), "drs.multipart-threshold": strconv.Itoa(multiPartThreshold), diff --git a/cmd/initialize/main_test.go b/cmd/initialize/main_test.go index 1126a2dd..0c2beab3 100644 --- a/cmd/initialize/main_test.go +++ b/cmd/initialize/main_test.go @@ -6,6 +6,7 @@ import ( "strings" "testing" + "github.com/calypr/git-drs/internal/common" "github.com/calypr/git-drs/internal/drslog" "github.com/calypr/git-drs/internal/gitrepo" "github.com/calypr/git-drs/internal/testutils" @@ -105,4 +106,38 @@ func TestInitConfigValues(t *testing.T) { check("lfs.concurrenttransfers", "8") check("lfs.allowincompletepush", "false") + check("filter.drs.clean", "git-drs clean -- %f") + check("filter.drs.smudge", "git-drs smudge -- %f") + check("filter.drs.process", "git-drs filter") + check("filter.drs.required", "true") +} + +func TestEnsureInitialized(t *testing.T) { + testutils.SetupTestGitRepo(t) + logger := drslog.NewNoOpLogger() + + if err := EnsureInitialized(logger); err != nil { + t.Fatalf("EnsureInitialized error: %v", err) + } + if err := EnsureInitialized(logger); err != nil { + t.Fatalf("EnsureInitialized second call error: %v", err) + } + + if _, err := os.Stat(common.DRS_DIR); err != nil { + t.Fatalf("expected %s to exist: %v", common.DRS_DIR, err) + } + filterProcess, err := gitrepo.GetGitConfigString("filter.drs.process") + if err != nil { + t.Fatalf("GetGitConfigString(filter.drs.process): %v", err) + } + if filterProcess != "git-drs filter" { + t.Fatalf("unexpected filter.drs.process: %q", filterProcess) + } + filterClean, err := gitrepo.GetGitConfigString("filter.drs.clean") + if err != nil { + t.Fatalf("GetGitConfigString(filter.drs.clean): %v", err) + } + if filterClean != "git-drs clean -- %f" { + t.Fatalf("unexpected filter.drs.clean: %q", filterClean) + } } diff --git a/cmd/list/main.go b/cmd/list/main.go deleted file mode 100644 index dfcf7a7c..00000000 --- a/cmd/list/main.go +++ /dev/null @@ -1,59 +0,0 @@ -package list - -import ( - "context" - "fmt" - - "github.com/calypr/git-drs/internal/common" - "github.com/calypr/git-drs/internal/config" - "github.com/calypr/git-drs/internal/drslog" - "github.com/spf13/cobra" -) - -var remote string -var pretty = false - -// Cmd line declaration -var Cmd = &cobra.Command{ - Use: "list", - Short: "List DRS objects in a DRS server", - Long: "List DRS objects in a DRS server", - RunE: func(cmd *cobra.Command, args []string) error { - - logger := drslog.GetLogger() - - config, err := config.LoadConfig() - if err != nil { - return err - } - - remoteName, err := config.GetRemoteOrDefault(remote) - if err != nil { - logger.Error(fmt.Sprintf("Error getting remote: %v", err)) - return err - } - - client, err := config.GetRemoteClient(remoteName, logger) - if err != nil { - return err - } - - objs, err := client.Client.DRS().ListObjects(context.Background(), 1000, 1) - if err != nil { - return err - } - - for _, drsObj := range objs.DrsObjects { - if err := common.PrintDRSObject(drsObj, pretty); err != nil { - return err - } - } - - return nil - }, -} - -func init() { - Cmd.Flags().StringVarP(&remote, "remote", "r", "", "target remote DRS server (default: default_remote)") - Cmd.Flags().BoolVarP(&pretty, "pretty", "p", false, "pretty print JSON output") -} diff --git a/cmd/lsfiles/main.go b/cmd/lsfiles/main.go index 96c3dfd1..d403a33e 100644 --- a/cmd/lsfiles/main.go +++ b/cmd/lsfiles/main.go @@ -1,8 +1,11 @@ package lsfiles import ( + "encoding/json" "fmt" "log/slog" + "os" + "os/exec" "sort" "strings" @@ -10,11 +13,18 @@ import ( "github.com/calypr/git-drs/internal/drslog" "github.com/calypr/git-drs/internal/drsremote" "github.com/calypr/git-drs/internal/lfs" + "github.com/calypr/git-drs/internal/pathspec" + drsapi "github.com/calypr/syfon/apigen/client/drs" "github.com/spf13/cobra" ) var gitRemote string var drsRemote string +var includePatterns []string +var showLong bool +var nameOnly bool +var jsonOutput bool +var drsStatus bool var ( loadConfig = config.LoadConfig @@ -22,40 +32,148 @@ var ( newRemoteClient = func(cfg *config.Config, remote config.Remote, logger *slog.Logger) (*config.GitContext, error) { return cfg.GetRemoteClient(remote, logger) } - loadLFSInventory = lfs.GetAllLfsFiles - lookupScopedObjects = drsremote.ObjectsByHashForScope + loadLFSInventory = func(gitRemoteName, gitRemoteLocation string, branches []string, logger *slog.Logger) (map[string]lfs.LfsFileInfo, error) { + if len(branches) == 0 { + return lfs.GetTrackedLfsFiles(logger) + } + return lfs.GetLfsFilesForRefs(branches, logger) + } + listRemoteRefs = defaultListRemoteRefs + listGitRemotes = defaultListGitRemotes + resolveDefaultRemote = defaultResolveDefaultRemote + lookupScopedObjectsBatch = drsremote.ObjectsByHashesForScope ) type fileRow struct { - OID string - Status string - Path string - Detail string + OID string `json:"oid"` + ShortOID string `json:"short_oid"` + Status string `json:"status"` + Path string `json:"path"` + Localized bool `json:"localized"` + Registered bool `json:"registered,omitempty"` + DRSIDs []string `json:"drs_ids,omitempty"` + Detail string `json:"detail,omitempty"` } -func collectRows(cmd *cobra.Command, gitRemoteName, drsRemoteName string) ([]fileRow, error) { - logger := drslog.GetLogger() +func defaultListRemoteRefs(gitRemoteName string) ([]string, error) { + if strings.TrimSpace(gitRemoteName) == "" { + return nil, nil + } - cfg, err := loadConfig() + cmd := exec.Command("git", "for-each-ref", "--format=%(refname)", "refs/remotes/"+gitRemoteName) + out, err := cmd.Output() if err != nil { - return nil, err + return nil, fmt.Errorf("list refs for remote %s: %w", gitRemoteName, err) } - remoteName, err := resolveRemote(cfg, drsRemoteName) - if err != nil { - logger.Error(fmt.Sprintf("Error getting remote: %v", err)) - return nil, err + lines := strings.Split(string(out), "\n") + refs := make([]string, 0, len(lines)) + for _, line := range lines { + ref := strings.TrimSpace(line) + if ref == "" || strings.HasSuffix(ref, "/HEAD") { + continue + } + refs = append(refs, ref) } + sort.Strings(refs) + return refs, nil +} - client, err := newRemoteClient(cfg, remoteName, logger) +func defaultListGitRemotes() ([]string, error) { + cmd := exec.Command("git", "remote") + out, err := cmd.Output() if err != nil { - return nil, err + return nil, fmt.Errorf("list git remotes: %w", err) + } + + lines := strings.Split(string(out), "\n") + remotes := make([]string, 0, len(lines)) + for _, line := range lines { + name := strings.TrimSpace(line) + if name == "" { + continue + } + remotes = append(remotes, name) + } + sort.Strings(remotes) + return remotes, nil +} + +func defaultResolveDefaultRemote() string { + cfg, err := loadConfig() + if err == nil && cfg != nil { + if remote, err := cfg.GetRemoteOrDefault(""); err == nil { + return strings.TrimSpace(string(remote)) + } } - lfsFiles, err := loadLFSInventory(gitRemoteName, drsRemoteName, []string{}, logger) + remotes, err := listGitRemotes() + if err != nil || len(remotes) == 0 { + return "" + } + for _, remote := range remotes { + if remote == config.ORIGIN { + return remote + } + } + if len(remotes) == 1 { + return remotes[0] + } + return "" +} + +func collectRows(cmd *cobra.Command, gitRemoteName, drsRemoteName string, patterns []string, resolveDRS bool) ([]fileRow, error) { + logger := drslog.GetLogger() + + var client *config.GitContext + if resolveDRS { + cfg, err := loadConfig() + if err != nil { + return nil, err + } + + remoteName, err := resolveRemote(cfg, drsRemoteName) + if err != nil { + logger.Error(fmt.Sprintf("Error getting remote: %v", err)) + return nil, err + } + + client, err = newRemoteClient(cfg, remoteName, logger) + if err != nil { + return nil, err + } + } + + var ( + refs []string + err error + ) + if strings.TrimSpace(gitRemoteName) != "" { + refs, err = listRemoteRefs(gitRemoteName) + if err != nil { + return nil, err + } + } + + lfsFiles, err := loadLFSInventory(gitRemoteName, drsRemoteName, refs, logger) if err != nil { return nil, err } + if len(lfsFiles) == 0 && strings.TrimSpace(gitRemoteName) == "" { + fallbackRemote := resolveDefaultRemote() + if fallbackRemote != "" { + refs, err = listRemoteRefs(fallbackRemote) + if err != nil { + return nil, err + } + if len(refs) > 0 { + lfsFiles, err = loadLFSInventory(fallbackRemote, drsRemoteName, refs, logger) + if err != nil { + return nil, err + } + } + } + } keys := make([]string, 0, len(lfsFiles)) for path := range lfsFiles { @@ -64,28 +182,60 @@ func collectRows(cmd *cobra.Command, gitRemoteName, drsRemoteName string) ([]fil sort.Strings(keys) rows := make([]fileRow, 0, len(keys)) + var drsResults map[string][]drsapi.DrsObject + var drsLookupErr error + if resolveDRS { + oids := make([]string, 0, len(keys)) + seenOIDs := make(map[string]struct{}, len(keys)) + for _, path := range keys { + if !pathspec.MatchesAny(path, patterns) { + continue + } + oid := lfsFiles[path].Oid + if oid == "" { + continue + } + if _, exists := seenOIDs[oid]; exists { + continue + } + seenOIDs[oid] = struct{}{} + oids = append(oids, oid) + } + drsResults, drsLookupErr = lookupScopedObjectsBatch(cmd.Context(), client, oids) + } for _, path := range keys { + if !pathspec.MatchesAny(path, patterns) { + continue + } info := lfsFiles[path] row := fileRow{ - OID: info.Oid, - Path: path, + OID: info.Oid, + ShortOID: shortOID(info.Oid), + Path: path, + Localized: isLocalized(path), + } + row.Status = "-" + if row.Localized { + row.Status = "*" } - results, err := lookupScopedObjects(cmd.Context(), client, info.Oid) - switch { - case err != nil: - row.Status = "error" - row.Detail = err.Error() - case len(results) == 0: - row.Status = "missing" - row.Detail = "-" - default: - row.Status = "present" - ids := make([]string, 0, len(results)) - for _, res := range results { - ids = append(ids, "drs://"+res.Id) + if resolveDRS { + switch { + case drsLookupErr != nil: + row.Detail = drsLookupErr.Error() + default: + results := drsResults[info.Oid] + if len(results) == 0 { + row.Registered = false + break + } + row.Registered = true + row.DRSIDs = make([]string, 0, len(results)) + for _, res := range results { + row.DRSIDs = append(row.DRSIDs, "drs://"+res.Id) + } + row.Detail = strings.Join(row.DRSIDs, ",") } - row.Detail = strings.Join(ids, ",") } rows = append(rows, row) @@ -95,23 +245,80 @@ func collectRows(cmd *cobra.Command, gitRemoteName, drsRemoteName string) ([]fil } func printRows(cmd *cobra.Command, rows []fileRow) error { - if _, err := fmt.Fprintf(cmd.OutOrStdout(), "OID\tSTATUS\tPATH\tDETAIL\n"); err != nil { - return err + if jsonOutput { + enc := json.NewEncoder(cmd.OutOrStdout()) + enc.SetIndent("", " ") + return enc.Encode(rows) } for _, row := range rows { - if _, err := fmt.Fprintf(cmd.OutOrStdout(), "%s\t%s\t%s\t%s\n", row.OID, row.Status, row.Path, row.Detail); err != nil { - return err + switch { + case nameOnly: + if _, err := fmt.Fprintln(cmd.OutOrStdout(), row.Path); err != nil { + return err + } + case drsStatus: + oid := row.ShortOID + if showLong { + oid = row.OID + } + detail := row.Detail + if detail == "" { + detail = "-" + } + if _, err := fmt.Fprintf(cmd.OutOrStdout(), "%s %s %s\t%s\n", oid, row.Status, row.Path, detail); err != nil { + return err + } + default: + oid := row.ShortOID + if showLong { + oid = row.OID + } + if _, err := fmt.Fprintf(cmd.OutOrStdout(), "%s %s %s\n", oid, row.Status, row.Path); err != nil { + return err + } } } return nil } +func shortOID(oid string) string { + if len(oid) <= 10 { + return oid + } + return oid[:10] +} + +func isLocalized(path string) bool { + payload, err := os.ReadFile(path) + if err != nil { + return false + } + _, _, ok := lfs.ParseLFSPointer(payload) + return !ok +} + +func validateOutputFlags() error { + if nameOnly && jsonOutput { + return fmt.Errorf("--name-only and --json are mutually exclusive") + } + if showLong && nameOnly { + return fmt.Errorf("--long and --name-only are mutually exclusive") + } + return nil +} + // Cmd line declaration var Cmd = &cobra.Command{ - Use: "ls-files", - Short: "List local LFS-tracked files and their DRS registration status", + Use: "ls-files [pathspec...]", + Short: "List tracked DRS/LFS pointer files in the repository", + Long: "List tracked DRS/Git-LFS pointer files in the repository. By default this behaves like a local file inventory. Use --drs to also resolve DRS registration status.", RunE: func(cmd *cobra.Command, args []string) error { - rows, err := collectRows(cmd, gitRemote, drsRemote) + if err := validateOutputFlags(); err != nil { + return err + } + patterns := append([]string{}, includePatterns...) + patterns = append(patterns, args...) + rows, err := collectRows(cmd, gitRemote, drsRemote, patterns, drsStatus) if err != nil { return err } @@ -122,4 +329,9 @@ var Cmd = &cobra.Command{ func init() { Cmd.Flags().StringVarP(&gitRemote, "git-remote", "r", "", "target remote Git server (default: origin)") Cmd.Flags().StringVarP(&drsRemote, "drs-remote", "d", "", "target remote DRS server (default: origin)") + Cmd.Flags().StringArrayVarP(&includePatterns, "include", "I", nil, "include pathspec/glob pattern(s)") + Cmd.Flags().BoolVarP(&showLong, "long", "l", false, "show full object IDs") + Cmd.Flags().BoolVarP(&nameOnly, "name-only", "n", false, "show only file paths") + Cmd.Flags().BoolVar(&jsonOutput, "json", false, "emit JSON output") + Cmd.Flags().BoolVar(&drsStatus, "drs", false, "include DRS registration lookup details") } diff --git a/cmd/lsfiles/main_test.go b/cmd/lsfiles/main_test.go index 492b0b4f..d4adf1a6 100644 --- a/cmd/lsfiles/main_test.go +++ b/cmd/lsfiles/main_test.go @@ -5,6 +5,8 @@ import ( "context" "errors" "log/slog" + "os" + "path/filepath" "strings" "testing" @@ -14,18 +16,115 @@ import ( "github.com/spf13/cobra" ) -func TestCollectRowsAndPrintRows(t *testing.T) { +func resetFlagsForTest() { + gitRemote = "" + drsRemote = "" + includePatterns = nil + showLong = false + nameOnly = false + jsonOutput = false + drsStatus = false +} + +func TestCollectRowsLocalDefault(t *testing.T) { + resetFlagsForTest() + + oldLoadLFSInventory := loadLFSInventory + oldLookupScopedObjectsBatch := lookupScopedObjectsBatch + oldResolveDefaultRemote := resolveDefaultRemote + t.Cleanup(func() { + loadLFSInventory = oldLoadLFSInventory + lookupScopedObjectsBatch = oldLookupScopedObjectsBatch + resolveDefaultRemote = oldResolveDefaultRemote + }) + + tmpDir := t.TempDir() + oldWD, err := os.Getwd() + if err != nil { + t.Fatalf("getwd: %v", err) + } + if err := os.Chdir(tmpDir); err != nil { + t.Fatalf("chdir tempdir: %v", err) + } + t.Cleanup(func() { + _ = os.Chdir(oldWD) + }) + + localizedPath := filepath.Join("a", "localized.bin") + pointerPath := filepath.Join("b", "pointer.bin") + if err := os.MkdirAll(filepath.Dir(localizedPath), 0o755); err != nil { + t.Fatalf("mkdir localized dir: %v", err) + } + if err := os.MkdirAll(filepath.Dir(pointerPath), 0o755); err != nil { + t.Fatalf("mkdir pointer dir: %v", err) + } + if err := os.WriteFile(localizedPath, []byte("hydrated-bytes"), 0o644); err != nil { + t.Fatalf("write localized file: %v", err) + } + pointerContent := "version https://git-lfs.github.com/spec/v1\noid sha256:" + strings.Repeat("b", 64) + "\nsize 12\n" + if err := os.WriteFile(pointerPath, []byte(pointerContent), 0o644); err != nil { + t.Fatalf("write pointer file: %v", err) + } + + loadLFSInventory = func(gitRemoteName, gitRemoteLocation string, branches []string, logger *slog.Logger) (map[string]lfs.LfsFileInfo, error) { + return map[string]lfs.LfsFileInfo{ + localizedPath: {Name: localizedPath, Oid: strings.Repeat("a", 64)}, + pointerPath: {Name: pointerPath, Oid: strings.Repeat("b", 64)}, + }, nil + } + lookupScopedObjectsBatch = func(ctx context.Context, drsCtx *config.GitContext, checksums []string) (map[string][]drsapi.DrsObject, error) { + t.Fatalf("unexpected remote lookup for checksums %v", checksums) + return nil, nil + } + resolveDefaultRemote = func() string { return "" } + + cmd := &cobra.Command{} + rows, err := collectRows(cmd, "", "", nil, false) + if err != nil { + t.Fatalf("collectRows returned error: %v", err) + } + if len(rows) != 2 { + t.Fatalf("expected 2 rows, got %d", len(rows)) + } + if rows[0].Path != localizedPath || rows[0].Status != "*" || !rows[0].Localized { + t.Fatalf("unexpected localized row: %+v", rows[0]) + } + if rows[1].Path != pointerPath || rows[1].Status != "-" || rows[1].Localized { + t.Fatalf("unexpected pointer row: %+v", rows[1]) + } + + var out bytes.Buffer + cmd.SetOut(&out) + if err := printRows(cmd, rows); err != nil { + t.Fatalf("printRows returned error: %v", err) + } + got := out.String() + if !strings.Contains(got, rows[0].ShortOID+" * "+localizedPath+"\n") { + t.Fatalf("missing localized row: %q", got) + } + if !strings.Contains(got, rows[1].ShortOID+" - "+pointerPath+"\n") { + t.Fatalf("missing pointer row: %q", got) + } +} + +func TestCollectRowsWithDRSLookupAndFilters(t *testing.T) { + resetFlagsForTest() + oldLoadConfig := loadConfig oldResolveRemote := resolveRemote oldNewRemoteClient := newRemoteClient oldLoadLFSInventory := loadLFSInventory - oldLookupScopedObjects := lookupScopedObjects + oldListRemoteRefs := listRemoteRefs + oldLookupScopedObjectsBatch := lookupScopedObjectsBatch + oldResolveDefaultRemote := resolveDefaultRemote t.Cleanup(func() { loadConfig = oldLoadConfig resolveRemote = oldResolveRemote newRemoteClient = oldNewRemoteClient loadLFSInventory = oldLoadLFSInventory - lookupScopedObjects = oldLookupScopedObjects + listRemoteRefs = oldListRemoteRefs + lookupScopedObjectsBatch = oldLookupScopedObjectsBatch + resolveDefaultRemote = oldResolveDefaultRemote }) loadConfig = func() (*config.Config, error) { @@ -37,58 +136,234 @@ func TestCollectRowsAndPrintRows(t *testing.T) { newRemoteClient = func(cfg *config.Config, remote config.Remote, logger *slog.Logger) (*config.GitContext, error) { return &config.GitContext{}, nil } + loadLFSInventory = func(gitRemoteName, gitRemoteLocation string, branches []string, logger *slog.Logger) (map[string]lfs.LfsFileInfo, error) { return map[string]lfs.LfsFileInfo{ - "b/file2.bin": {Name: "b/file2.bin", Oid: strings.Repeat("b", 64)}, - "a/file1.bin": {Name: "a/file1.bin", Oid: strings.Repeat("a", 64)}, - "c/file3.bin": {Name: "c/file3.bin", Oid: strings.Repeat("c", 64)}, + "a/file1.bin": {Name: "a/file1.bin", Oid: strings.Repeat("a", 64)}, + "data/file2.bam": {Name: "data/file2.bam", Oid: strings.Repeat("b", 64)}, + "data/file3.txt": {Name: "data/file3.txt", Oid: strings.Repeat("c", 64)}, }, nil } - lookupScopedObjects = func(ctx context.Context, drsCtx *config.GitContext, checksum string) ([]drsapi.DrsObject, error) { - switch checksum { - case strings.Repeat("a", 64): - return []drsapi.DrsObject{{Id: "did-1"}}, nil - case strings.Repeat("b", 64): + listRemoteRefs = func(remote string) ([]string, error) { + if remote == "" { return nil, nil - default: - return nil, errors.New("lookup failed") } + return []string{"refs/remotes/dev/main"}, nil + } + lookupScopedObjectsBatch = func(ctx context.Context, drsCtx *config.GitContext, checksums []string) (map[string][]drsapi.DrsObject, error) { + got := map[string][]drsapi.DrsObject{} + for _, checksum := range checksums { + switch checksum { + case strings.Repeat("b", 64): + got[checksum] = []drsapi.DrsObject{{Id: "did-1"}} + default: + got[checksum] = nil + } + } + return got, nil } + resolveDefaultRemote = func() string { return "" } cmd := &cobra.Command{} - rows, err := collectRows(cmd, "", "") + rows, err := collectRows(cmd, "dev", "", []string{"data/**"}, true) if err != nil { t.Fatalf("collectRows returned error: %v", err) } - if len(rows) != 3 { - t.Fatalf("expected 3 rows, got %d", len(rows)) + if len(rows) != 2 { + t.Fatalf("expected 2 rows, got %d", len(rows)) } - if rows[0].Path != "a/file1.bin" || rows[0].Status != "present" || rows[0].Detail != "drs://did-1" { - t.Fatalf("unexpected first row: %+v", rows[0]) + if rows[0].Path != "data/file2.bam" || !rows[0].Registered || rows[0].Detail != "drs://did-1" { + t.Fatalf("unexpected registered row: %+v", rows[0]) } - if rows[1].Path != "b/file2.bin" || rows[1].Status != "missing" || rows[1].Detail != "-" { - t.Fatalf("unexpected second row: %+v", rows[1]) - } - if rows[2].Path != "c/file3.bin" || rows[2].Status != "error" || rows[2].Detail != "lookup failed" { - t.Fatalf("unexpected third row: %+v", rows[2]) + if rows[1].Path != "data/file3.txt" || rows[1].Registered || rows[1].Detail != "" { + t.Fatalf("unexpected unregistered row: %+v", rows[1]) } + drsStatus = true + showLong = true var out bytes.Buffer cmd.SetOut(&out) if err := printRows(cmd, rows); err != nil { t.Fatalf("printRows returned error: %v", err) } got := out.String() - if !strings.Contains(got, "OID\tSTATUS\tPATH\tDETAIL\n") { - t.Fatalf("missing header in output: %q", got) + if !strings.Contains(got, rows[0].OID+" - data/file2.bam\tdrs://did-1\n") { + t.Fatalf("missing registered row: %q", got) + } + if !strings.Contains(got, rows[1].OID+" - data/file3.txt\t-\n") { + t.Fatalf("missing unregistered row: %q", got) + } +} + +func TestCollectRowsWithDRSLookupBatchError(t *testing.T) { + resetFlagsForTest() + + oldLoadConfig := loadConfig + oldResolveRemote := resolveRemote + oldNewRemoteClient := newRemoteClient + oldLoadLFSInventory := loadLFSInventory + oldListRemoteRefs := listRemoteRefs + oldLookupScopedObjectsBatch := lookupScopedObjectsBatch + oldResolveDefaultRemote := resolveDefaultRemote + t.Cleanup(func() { + loadConfig = oldLoadConfig + resolveRemote = oldResolveRemote + newRemoteClient = oldNewRemoteClient + loadLFSInventory = oldLoadLFSInventory + listRemoteRefs = oldListRemoteRefs + lookupScopedObjectsBatch = oldLookupScopedObjectsBatch + resolveDefaultRemote = oldResolveDefaultRemote + }) + + loadConfig = func() (*config.Config, error) { return &config.Config{}, nil } + resolveRemote = func(cfg *config.Config, name string) (config.Remote, error) { + return config.Remote("origin"), nil + } + newRemoteClient = func(cfg *config.Config, remote config.Remote, logger *slog.Logger) (*config.GitContext, error) { + return &config.GitContext{}, nil + } + loadLFSInventory = func(gitRemoteName, gitRemoteLocation string, branches []string, logger *slog.Logger) (map[string]lfs.LfsFileInfo, error) { + return map[string]lfs.LfsFileInfo{ + "data/file2.bam": {Name: "data/file2.bam", Oid: strings.Repeat("b", 64)}, + "data/file3.txt": {Name: "data/file3.txt", Oid: strings.Repeat("c", 64)}, + }, nil + } + listRemoteRefs = func(remote string) ([]string, error) { + if remote == "" { + return nil, nil + } + return []string{"refs/remotes/dev/main"}, nil + } + lookupScopedObjectsBatch = func(ctx context.Context, drsCtx *config.GitContext, checksums []string) (map[string][]drsapi.DrsObject, error) { + return nil, errors.New("lookup failed") + } + resolveDefaultRemote = func() string { return "" } + + cmd := &cobra.Command{} + rows, err := collectRows(cmd, "dev", "", []string{"data/**"}, true) + if err != nil { + t.Fatalf("collectRows returned error: %v", err) + } + if len(rows) != 2 { + t.Fatalf("expected 2 rows, got %d", len(rows)) + } + for _, row := range rows { + if row.Detail != "lookup failed" { + t.Fatalf("expected shared batch lookup error, got row=%+v", row) + } + } +} + +func TestCollectRowsUsesRemoteRefsWhenGitRemoteProvided(t *testing.T) { + resetFlagsForTest() + + oldLoadLFSInventory := loadLFSInventory + oldListRemoteRefs := listRemoteRefs + oldResolveDefaultRemote := resolveDefaultRemote + t.Cleanup(func() { + loadLFSInventory = oldLoadLFSInventory + listRemoteRefs = oldListRemoteRefs + resolveDefaultRemote = oldResolveDefaultRemote + }) + + listRemoteRefs = func(remote string) ([]string, error) { + if remote != "dev" { + t.Fatalf("unexpected remote %q", remote) + } + return []string{"refs/remotes/dev/main", "refs/remotes/dev/release"}, nil + } + + loadLFSInventory = func(gitRemoteName, gitRemoteLocation string, branches []string, logger *slog.Logger) (map[string]lfs.LfsFileInfo, error) { + if gitRemoteName != "dev" { + t.Fatalf("unexpected git remote name %q", gitRemoteName) + } + if len(branches) != 2 || branches[0] != "refs/remotes/dev/main" || branches[1] != "refs/remotes/dev/release" { + t.Fatalf("unexpected refs %v", branches) + } + return map[string]lfs.LfsFileInfo{ + "data/file.bam": {Name: "data/file.bam", Oid: strings.Repeat("a", 64)}, + }, nil } - if !strings.Contains(got, rows[0].OID+"\tpresent\ta/file1.bin\tdrs://did-1\n") { - t.Fatalf("missing present row: %q", got) + resolveDefaultRemote = func() string { + t.Fatal("default remote fallback should not be used when --git-remote is set") + return "" } - if !strings.Contains(got, rows[1].OID+"\tmissing\tb/file2.bin\t-\n") { - t.Fatalf("missing missing row: %q", got) + + cmd := &cobra.Command{} + rows, err := collectRows(cmd, "dev", "", nil, false) + if err != nil { + t.Fatalf("collectRows returned error: %v", err) + } + if len(rows) != 1 || rows[0].Path != "data/file.bam" { + t.Fatalf("unexpected rows %+v", rows) + } +} + +func TestCollectRowsFallsBackToDefaultRemoteWhenLocalInventoryEmpty(t *testing.T) { + resetFlagsForTest() + + oldLoadLFSInventory := loadLFSInventory + oldListRemoteRefs := listRemoteRefs + oldResolveDefaultRemote := resolveDefaultRemote + t.Cleanup(func() { + loadLFSInventory = oldLoadLFSInventory + listRemoteRefs = oldListRemoteRefs + resolveDefaultRemote = oldResolveDefaultRemote + }) + + callCount := 0 + loadLFSInventory = func(gitRemoteName, gitRemoteLocation string, branches []string, logger *slog.Logger) (map[string]lfs.LfsFileInfo, error) { + callCount++ + if callCount == 1 { + if gitRemoteName != "" || len(branches) != 0 { + t.Fatalf("first inventory call should be local-only, got remote=%q refs=%v", gitRemoteName, branches) + } + return map[string]lfs.LfsFileInfo{}, nil + } + if gitRemoteName != "dev" { + t.Fatalf("expected fallback remote dev, got %q", gitRemoteName) + } + if len(branches) != 1 || branches[0] != "refs/remotes/dev/main" { + t.Fatalf("unexpected fallback refs: %v", branches) + } + return map[string]lfs.LfsFileInfo{ + "data/file2.bam": {Name: "data/file2.bam", Oid: strings.Repeat("b", 64)}, + }, nil + } + resolveDefaultRemote = func() string { return "dev" } + listRemoteRefs = func(remote string) ([]string, error) { + if remote != "dev" { + t.Fatalf("expected fallback remote query for dev, got %q", remote) + } + return []string{"refs/remotes/dev/main"}, nil } - if !strings.Contains(got, rows[2].OID+"\terror\tc/file3.bin\tlookup failed\n") { - t.Fatalf("missing error row: %q", got) + + cmd := &cobra.Command{} + rows, err := collectRows(cmd, "", "", nil, false) + if err != nil { + t.Fatalf("collectRows returned error: %v", err) + } + if len(rows) != 1 || rows[0].Path != "data/file2.bam" { + t.Fatalf("unexpected rows: %+v", rows) + } + if callCount != 2 { + t.Fatalf("expected 2 inventory calls, got %d", callCount) + } +} + +func TestValidateOutputFlags(t *testing.T) { + resetFlagsForTest() + + nameOnly = true + jsonOutput = true + if err := validateOutputFlags(); err == nil { + t.Fatal("expected name-only/json conflict") + } + + resetFlagsForTest() + nameOnly = true + showLong = true + if err := validateOutputFlags(); err == nil { + t.Fatal("expected long/name-only conflict") } } diff --git a/cmd/ping/main.go b/cmd/ping/main.go new file mode 100644 index 00000000..e02d070f --- /dev/null +++ b/cmd/ping/main.go @@ -0,0 +1,137 @@ +package ping + +import ( + "context" + "fmt" + "log/slog" + "strings" + + "github.com/calypr/git-drs/internal/config" + "github.com/calypr/git-drs/internal/drslog" + "github.com/spf13/cobra" +) + +type statusInfo struct { + Remote config.Remote + IsDefault bool + RemoteType string + Endpoint string + Organization string + Project string + Bucket string + StoragePrefix string + AuthMode string +} + +var pingHealth = func(ctx context.Context, gc *config.GitContext) error { + return gc.Client.Health().Ping(ctx) +} + +var Cmd = &cobra.Command{ + Use: "ping [remote-name]", + Short: "Show effective remote setup and verify the remote responds", + Args: func(cmd *cobra.Command, args []string) error { + if len(args) > 1 { + cmd.SilenceUsage = false + return fmt.Errorf("error: accepts at most 1 argument (remote name), received %d\n\nUsage: %s\n\nSee 'git drs ping --help' for more details", len(args), cmd.UseLine()) + } + return nil + }, + RunE: func(cmd *cobra.Command, args []string) error { + logger := drslog.GetLogger() + status, gc, err := resolveStatus(args, logger) + if err != nil { + return err + } + printStatus(status) + + if err := pingHealth(cmd.Context(), gc); err != nil { + return fmt.Errorf("remote health check failed for %q (%s): %w", status.Remote, status.Endpoint, err) + } + fmt.Println("health: ok") + return nil + }, +} + +func resolveStatus(args []string, logger *slog.Logger) (statusInfo, *config.GitContext, error) { + cfg, err := config.LoadConfig() + if err != nil { + return statusInfo{}, nil, err + } + + var remoteArg string + if len(args) == 1 { + remoteArg = args[0] + } + remoteName, err := cfg.GetRemoteOrDefault(remoteArg) + if err != nil { + return statusInfo{}, nil, err + } + + remoteCfg := cfg.GetRemote(remoteName) + if remoteCfg == nil { + return statusInfo{}, nil, fmt.Errorf("no remote configuration found for %q", remoteName) + } + + gc, err := cfg.GetRemoteClient(remoteName, logger) + if err != nil { + return statusInfo{}, nil, err + } + + status := statusInfo{ + Remote: remoteName, + IsDefault: remoteName == cfg.DefaultRemote, + Endpoint: remoteCfg.GetEndpoint(), + Organization: remoteCfg.GetOrganization(), + Project: remoteCfg.GetProjectId(), + Bucket: gc.BucketName, + StoragePrefix: gc.StoragePrefix, + AuthMode: authMode(gc), + } + switch remoteCfg.(type) { + case *config.Gen3Remote: + status.RemoteType = string(config.Gen3ServerType) + case *config.LocalRemote: + status.RemoteType = string(config.LocalServerType) + default: + status.RemoteType = "unknown" + } + + return status, gc, nil +} + +func printStatus(status statusInfo) { + def := "" + if status.IsDefault { + def = " (default)" + } + fmt.Printf("remote: %s%s\n", status.Remote, def) + fmt.Printf("type: %s\n", status.RemoteType) + fmt.Printf("endpoint: %s\n", status.Endpoint) + fmt.Printf("organization: %s\n", blankIfEmpty(status.Organization)) + fmt.Printf("project: %s\n", blankIfEmpty(status.Project)) + fmt.Printf("bucket: %s\n", blankIfEmpty(status.Bucket)) + fmt.Printf("storage_prefix: %s\n", blankIfEmpty(status.StoragePrefix)) + fmt.Printf("auth: %s\n", status.AuthMode) +} + +func authMode(gc *config.GitContext) string { + if gc == nil || gc.Credential == nil { + return "none" + } + if strings.TrimSpace(gc.Credential.AccessToken) != "" { + return "bearer" + } + if strings.TrimSpace(gc.Credential.KeyID) != "" || strings.TrimSpace(gc.Credential.APIKey) != "" { + return "basic" + } + return "none" +} + +func blankIfEmpty(v string) string { + v = strings.TrimSpace(v) + if v == "" { + return "-" + } + return v +} diff --git a/cmd/ping/main_test.go b/cmd/ping/main_test.go new file mode 100644 index 00000000..8efff3c7 --- /dev/null +++ b/cmd/ping/main_test.go @@ -0,0 +1,132 @@ +package ping + +import ( + "bytes" + "context" + "io" + "os" + "strings" + "testing" + + "github.com/calypr/git-drs/internal/config" + "github.com/calypr/git-drs/internal/drslog" + "github.com/calypr/git-drs/internal/gitrepo" + "github.com/calypr/git-drs/internal/testutils" +) + +func TestPingCmdArgs(t *testing.T) { + if err := Cmd.Args(Cmd, nil); err != nil { + t.Fatalf("unexpected error with no args: %v", err) + } + if err := Cmd.Args(Cmd, []string{"origin"}); err != nil { + t.Fatalf("unexpected error with one arg: %v", err) + } + if err := Cmd.Args(Cmd, []string{"origin", "extra"}); err == nil { + t.Fatal("expected error for extra args") + } +} + +func TestResolveStatusLocalRemote(t *testing.T) { + tmpDir := testutils.SetupTestGitRepo(t) + testutils.CreateTestConfig(t, tmpDir, &config.Config{ + DefaultRemote: config.Remote(config.ORIGIN), + Remotes: map[config.Remote]config.RemoteSelect{ + config.Remote(config.ORIGIN): { + Local: &config.LocalRemote{ + BaseURL: "http://127.0.0.1:8080", + ProjectID: "end_to_end_test", + Bucket: "cbds", + Organization: "calypr", + BasicUsername: "drs-user", + BasicPassword: "drs-pass", + }, + }, + }, + }) + if err := gitrepo.SetBucketMapping("calypr", "end_to_end_test", "cbds", "prefix"); err != nil { + t.Fatalf("SetBucketMapping failed: %v", err) + } + + status, _, err := resolveStatus(nil, drslog.NewNoOpLogger()) + if err != nil { + t.Fatalf("resolveStatus returned error: %v", err) + } + if status.Remote != "origin" || !status.IsDefault { + t.Fatalf("unexpected remote selection: %+v", status) + } + if status.RemoteType != "local" || status.Endpoint != "http://127.0.0.1:8080" { + t.Fatalf("unexpected remote type/endpoint: %+v", status) + } + if status.Organization != "calypr" || status.Project != "end_to_end_test" { + t.Fatalf("unexpected scope: %+v", status) + } + if status.Bucket != "cbds" || status.StoragePrefix != "prefix" { + t.Fatalf("unexpected bucket scope: %+v", status) + } + if status.AuthMode != "none" { + t.Fatalf("expected auth mode none from client credential shape, got %+v", status) + } +} + +func TestPingRunEPrintsStatusAndHealth(t *testing.T) { + tmpDir := testutils.SetupTestGitRepo(t) + testutils.CreateTestConfig(t, tmpDir, &config.Config{ + DefaultRemote: config.Remote(config.ORIGIN), + Remotes: map[config.Remote]config.RemoteSelect{ + config.Remote(config.ORIGIN): { + Local: &config.LocalRemote{ + BaseURL: "http://127.0.0.1:8080", + ProjectID: "end_to_end_test", + Bucket: "cbds", + Organization: "calypr", + }, + }, + }, + }) + if err := gitrepo.SetBucketMapping("calypr", "end_to_end_test", "cbds", "prefix"); err != nil { + t.Fatalf("SetBucketMapping failed: %v", err) + } + + oldHealth := pingHealth + pingHealth = func(ctx context.Context, gc *config.GitContext) error { + if gc == nil || gc.ProjectId != "end_to_end_test" { + t.Fatalf("unexpected git context: %+v", gc) + } + return nil + } + t.Cleanup(func() { pingHealth = oldHealth }) + + oldStdout := os.Stdout + r, w, err := os.Pipe() + if err != nil { + t.Fatalf("pipe: %v", err) + } + os.Stdout = w + t.Cleanup(func() { os.Stdout = oldStdout }) + + runErr := Cmd.RunE(Cmd, nil) + _ = w.Close() + if runErr != nil { + t.Fatalf("Cmd.RunE returned error: %v", runErr) + } + + var buf bytes.Buffer + if _, err := io.Copy(&buf, r); err != nil { + t.Fatalf("read stdout: %v", err) + } + got := buf.String() + for _, want := range []string{ + "remote: origin (default)", + "type: local", + "endpoint: http://127.0.0.1:8080", + "organization: calypr", + "project: end_to_end_test", + "bucket: cbds", + "storage_prefix: prefix", + "health: ok", + } { + if !strings.Contains(got, want) { + t.Fatalf("expected output to contain %q, got %q", want, got) + } + } +} diff --git a/cmd/precommit/main.go b/cmd/precommit/main.go index ac2f3047..53ca5a12 100644 --- a/cmd/precommit/main.go +++ b/cmd/precommit/main.go @@ -26,6 +26,7 @@ import ( "os/exec" "path/filepath" "sort" + "strconv" "strings" "time" @@ -33,8 +34,14 @@ import ( ) const ( - cacheVersionDir = "drs/pre-commit/v1" - lfsSpecLine = "version https://git-lfs.github.com/spec/v1" + cacheVersionDir = "drs/pre-commit/v1" + lfsSpecLine = "version https://git-lfs.github.com/spec/v1" + defaultDirectCommitWarningThreshold = int64(10 * 1024 * 1024) +) + +var ( + directCommitWarningThresholdBytes = defaultDirectCommitWarningThreshold + confirmOversizedDirectGitCommit = promptOversizedDirectGitCommit ) type PathEntry struct { @@ -67,6 +74,11 @@ type Change struct { Status string // raw status, e.g. "A", "M", "D", "R100" } +type OversizedStagedFile struct { + Path string + Size int64 +} + // Cmd line declaration var Cmd = &cobra.Command{ Use: "precommit", @@ -114,6 +126,19 @@ func run(ctx context.Context) error { if len(changes) == 0 { return nil } + oversized, err := collectOversizedPlainGitStagedFiles(ctx, changes, directCommitWarningThresholdBytes) + if err != nil { + return err + } + if len(oversized) > 0 { + allowed, err := confirmOversizedDirectGitCommit(oversized) + if err != nil { + return err + } + if !allowed { + return fmt.Errorf("commit aborted so you can track large files before committing them directly to Git") + } + } now := time.Now().UTC().Format(time.RFC3339) @@ -349,6 +374,92 @@ func stagedLFSOID(ctx context.Context, path string) (string, bool, error) { return "", false, nil } +func stagedBlobSize(ctx context.Context, path string) (int64, error) { + out, err := git(ctx, "cat-file", "-s", ":"+path) + if err != nil { + return 0, err + } + size, err := strconv.ParseInt(strings.TrimSpace(string(out)), 10, 64) + if err != nil { + return 0, fmt.Errorf("parse staged blob size for %s: %w", path, err) + } + return size, nil +} + +func collectOversizedPlainGitStagedFiles(ctx context.Context, changes []Change, thresholdBytes int64) ([]OversizedStagedFile, error) { + if thresholdBytes <= 0 { + return nil, nil + } + var oversized []OversizedStagedFile + seen := make(map[string]struct{}) + for _, ch := range changes { + if ch.Kind != KindAdd && ch.Kind != KindModify && ch.Kind != KindRename { + continue + } + path := ch.NewPath + if path == "" { + continue + } + if _, ok := seen[path]; ok { + continue + } + seen[path] = struct{}{} + + _, isLFS, err := stagedLFSOID(ctx, path) + if err != nil { + continue + } + if isLFS { + continue + } + + size, err := stagedBlobSize(ctx, path) + if err != nil { + return nil, err + } + if size <= thresholdBytes { + continue + } + oversized = append(oversized, OversizedStagedFile{Path: path, Size: size}) + } + sort.Slice(oversized, func(i, j int) bool { return oversized[i].Path < oversized[j].Path }) + return oversized, nil +} + +func promptOversizedDirectGitCommit(files []OversizedStagedFile) (bool, error) { + if len(files) == 0 { + return true, nil + } + + fmt.Fprintf(os.Stderr, "\nWarning: the following staged files are being committed directly to Git and exceed %s:\n\n", humanBytes(directCommitWarningThresholdBytes)) + for _, f := range files { + fmt.Fprintf(os.Stderr, " - %s (%s)\n", f.Path, humanBytes(f.Size)) + } + fmt.Fprintln(os.Stderr, "\nIf these should be managed by git-drs, track them first and re-add them.") + fmt.Fprint(os.Stderr, "Continue committing these files directly to GitHub? [y/N]: ") + + reader := bufio.NewReader(os.Stdin) + line, err := reader.ReadString('\n') + if err != nil && !errors.Is(err, io.EOF) { + return false, err + } + answer := strings.ToLower(strings.TrimSpace(line)) + return answer == "y" || answer == "yes", nil +} + +func humanBytes(n int64) string { + const unit = int64(1024) + if n < unit { + return fmt.Sprintf("%d B", n) + } + div, exp := unit, 0 + for q := n / unit; q >= unit; q /= unit { + div *= unit + exp++ + } + return fmt.Sprintf("%.1f %ciB", float64(n)/float64(div), "KMGTPE"[exp]) +} + func gitRevParseGitDir(ctx context.Context) (string, error) { out, err := git(ctx, "rev-parse", "--git-dir") if err != nil { diff --git a/cmd/precommit/main_test.go b/cmd/precommit/main_test.go index 8a0fb0c6..5ee5cc9a 100644 --- a/cmd/precommit/main_test.go +++ b/cmd/precommit/main_test.go @@ -114,6 +114,85 @@ func TestHandleUpsertWritesLFSPointerCache(t *testing.T) { } } +func TestCollectOversizedPlainGitStagedFiles(t *testing.T) { + repo := setupGitRepo(t) + oldwd := mustChdir(t, repo) + t.Cleanup(func() { _ = os.Chdir(oldwd) }) + + plainPath := filepath.Join(repo, "data", "large.bin") + if err := os.MkdirAll(filepath.Dir(plainPath), 0o755); err != nil { + t.Fatalf("mkdir: %v", err) + } + if err := os.WriteFile(plainPath, []byte("plain oversized payload"), 0o644); err != nil { + t.Fatalf("write plain file: %v", err) + } + gitCmd(t, repo, "add", "data/large.bin") + + pointerPath := filepath.Join(repo, "data", "pointer.bin") + lfsPointer := strings.Join([]string{ + "version https://git-lfs.github.com/spec/v1", + "oid sha256:deadbeef", + "size 999", + "", + }, "\n") + if err := os.WriteFile(pointerPath, []byte(lfsPointer), 0o644); err != nil { + t.Fatalf("write pointer file: %v", err) + } + gitCmd(t, repo, "add", "data/pointer.bin") + + changes, err := stagedChanges(context.Background()) + if err != nil { + t.Fatalf("stagedChanges: %v", err) + } + files, err := collectOversizedPlainGitStagedFiles(context.Background(), changes, 1) + if err != nil { + t.Fatalf("collectOversizedPlainGitStagedFiles: %v", err) + } + if len(files) != 1 { + t.Fatalf("expected 1 oversized plain file, got %d: %+v", len(files), files) + } + if files[0].Path != "data/large.bin" { + t.Fatalf("unexpected oversized file path: %+v", files[0]) + } +} + +func TestRunAbortsWhenOversizedPlainGitCommitIsRejected(t *testing.T) { + repo := setupGitRepo(t) + oldwd := mustChdir(t, repo) + t.Cleanup(func() { _ = os.Chdir(oldwd) }) + + path := filepath.Join(repo, "data", "large.bin") + if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { + t.Fatalf("mkdir: %v", err) + } + if err := os.WriteFile(path, []byte("plain oversized payload"), 0o644); err != nil { + t.Fatalf("write file: %v", err) + } + gitCmd(t, repo, "add", "data/large.bin") + + oldThreshold := directCommitWarningThresholdBytes + oldPrompt := confirmOversizedDirectGitCommit + t.Cleanup(func() { + directCommitWarningThresholdBytes = oldThreshold + confirmOversizedDirectGitCommit = oldPrompt + }) + directCommitWarningThresholdBytes = 1 + confirmOversizedDirectGitCommit = func(files []OversizedStagedFile) (bool, error) { + if len(files) != 1 || files[0].Path != "data/large.bin" { + t.Fatalf("unexpected prompt files: %+v", files) + } + return false, nil + } + + err := run(context.Background()) + if err == nil { + t.Fatal("expected run to abort when oversized file warning is rejected") + } + if !strings.Contains(err.Error(), "commit aborted") { + t.Fatalf("unexpected error: %v", err) + } +} + func setupGitRepo(t *testing.T) string { t.Helper() dir := t.TempDir() diff --git a/cmd/prepush/io_helpers.go b/cmd/prepush/io_helpers.go new file mode 100644 index 00000000..75dee003 --- /dev/null +++ b/cmd/prepush/io_helpers.go @@ -0,0 +1,27 @@ +package prepush + +import ( + "fmt" + "io" + "os" +) + +func bufferStdin(stdin io.Reader, createTempFile func(dir, pattern string) (*os.File, error)) (*os.File, error) { + tmp, err := createTempFile("", "prepush-stdin-*") + if err != nil { + return nil, fmt.Errorf("error creating temp file for stdin: %w", err) + } + + if _, err := io.Copy(tmp, stdin); err != nil { + _ = tmp.Close() + _ = os.Remove(tmp.Name()) + return nil, fmt.Errorf("error buffering stdin: %w", err) + } + + if _, err := tmp.Seek(0, 0); err != nil { + _ = tmp.Close() + _ = os.Remove(tmp.Name()) + return nil, fmt.Errorf("error seeking temp stdin: %w", err) + } + return tmp, nil +} diff --git a/cmd/prepush/main.go b/cmd/prepush/main.go index 9e2e841e..6ff4a411 100644 --- a/cmd/prepush/main.go +++ b/cmd/prepush/main.go @@ -1,7 +1,6 @@ package prepush import ( - "bufio" "bytes" "context" "encoding/base64" @@ -18,6 +17,7 @@ import ( "github.com/calypr/git-drs/internal/common" "github.com/calypr/git-drs/internal/config" + "github.com/calypr/git-drs/internal/drsdelete" "github.com/calypr/git-drs/internal/drslog" "github.com/calypr/git-drs/internal/drsmap" "github.com/calypr/git-drs/internal/drsobject" @@ -25,7 +25,6 @@ import ( "github.com/calypr/git-drs/internal/lfs" "github.com/calypr/git-drs/internal/precommit_cache" drsapi "github.com/calypr/syfon/apigen/client/drs" - syfoncommon "github.com/calypr/syfon/common" "github.com/spf13/cobra" ) @@ -86,6 +85,10 @@ func (s *PrePushService) Run(args []string, stdin io.Reader) error { myLogger.Debug("Warning. Skipping DRS preparation. Error getting remote configuration.") return nil } + drsClient, err := cfg.GetRemoteClient(remote, myLogger) + if err != nil { + return err + } scope, err := gitrepo.ResolveBucketScope( remoteConfig.GetOrganization(), @@ -117,6 +120,10 @@ func (s *PrePushService) Run(args []string, stdin io.Reader) error { myLogger.Error(fmt.Sprintf("error reading pushed refs: %v", err)) return err } + if _, err := drsdelete.ReconcileCommittedDeletes(ctx, drsClient, drsDeleteRefs(refs), myLogger); err != nil { + myLogger.Error(fmt.Sprintf("delete reconciliation failed: %v", err)) + return err + } branches := branchesFromRefs(refs) cache, cacheReady := openCache(ctx, myLogger) @@ -244,9 +251,6 @@ func toMetadataCandidate(c drsapi.DrsObjectCandidate) metadataCandidate { URL: accURL, }, } - if authzMap := syfoncommon.AuthzMapFromAccessMethodAuthorizations(am.Authorizations); len(authzMap) > 0 { - m.Authorizations = authzMap - } out.AccessMethods = append(out.AccessMethods, m) } } @@ -288,7 +292,7 @@ func submitPendingLFSMeta(ctx context.Context, remote config.Remote, endpoint st if err != nil { return fmt.Errorf("failed to create pending metadata request: %w", err) } - httpReq.Header.Set("Content-Type", "application/json") + httpReq.Header.Set("Content-Type", "application/vnd.git-lfs+json") httpReq.Header.Set("Accept", "application/vnd.git-lfs+json") if authHeader, ok := resolveRemoteAuthHeader(string(remote)); ok { httpReq.Header.Set("Authorization", authHeader) @@ -350,87 +354,6 @@ func parseRemoteArgs(args []string) (string, string) { return gitRemoteName, gitRemoteLocation } -type pushedRef struct { - LocalRef string - LocalSHA string - RemoteRef string - RemoteSHA string -} - -func bufferStdin(stdin io.Reader, createTempFile func(dir, pattern string) (*os.File, error)) (*os.File, error) { - tmp, err := createTempFile("", "prepush-stdin-*") - if err != nil { - return nil, fmt.Errorf("error creating temp file for stdin: %w", err) - } - - if _, err := io.Copy(tmp, stdin); err != nil { - _ = tmp.Close() - _ = os.Remove(tmp.Name()) - return nil, fmt.Errorf("error buffering stdin: %w", err) - } - - if _, err := tmp.Seek(0, 0); err != nil { - _ = tmp.Close() - _ = os.Remove(tmp.Name()) - return nil, fmt.Errorf("error seeking temp stdin: %w", err) - } - return tmp, nil -} - -// readPushedBranches reads git push lines from the provided temp file, -// extracts unique local branch names for refs under `refs/heads/` and -// returns them sorted. The file is rewound to the start before returning. -func readPushedRefs(f io.ReadSeeker) ([]pushedRef, error) { - // Ensure we read from start - // example: - // refs/heads/main 67890abcdef1234567890abcdef1234567890abcd refs/heads/main 12345abcdef67890abcdef1234567890abcdef12 - if _, err := f.Seek(0, 0); err != nil { - return nil, err - } - scanner := bufio.NewScanner(f) - refs := make([]pushedRef, 0) - for scanner.Scan() { - line := scanner.Text() - fields := strings.Fields(line) - if len(fields) < 4 { - continue - } - refs = append(refs, pushedRef{ - LocalRef: fields[0], - LocalSHA: fields[1], - RemoteRef: fields[2], - RemoteSHA: fields[3], - }) - } - if err := scanner.Err(); err != nil { - return nil, err - } - // Rewind so caller can reuse the file - if _, err := f.Seek(0, 0); err != nil { - return nil, err - } - return refs, nil -} - -func branchesFromRefs(refs []pushedRef) []string { - const prefix = "refs/heads/" - set := make(map[string]struct{}) - for _, ref := range refs { - if strings.HasPrefix(ref.LocalRef, prefix) { - branch := strings.TrimPrefix(ref.LocalRef, prefix) - if branch != "" { - set[branch] = struct{}{} - } - } - } - branches := make([]string, 0, len(set)) - for b := range set { - branches = append(branches, b) - } - sort.Strings(branches) - return branches -} - func openCache(ctx context.Context, logger *slog.Logger) (*precommit_cache.Cache, bool) { cache, err := precommit_cache.Open(ctx) if err != nil { @@ -561,45 +484,3 @@ func gitOutput(ctx context.Context, args ...string) (string, error) { } return string(out), nil } - -// readPushedBranches reads git push lines from the provided temp file, -// extracts unique local branch names for refs under `refs/heads/` and -// returns them sorted. The file is rewound to the start before returning. -func readPushedBranches(f *os.File) ([]string, error) { - // Ensure we read from start - // example: - // refs/heads/main 67890abcdef1234567890abcdef1234567890abcd refs/heads/main 12345abcdef67890abcdef1234567890abcdef12 - if _, err := f.Seek(0, 0); err != nil { - return nil, err - } - scanner := bufio.NewScanner(f) - set := make(map[string]struct{}) - for scanner.Scan() { - line := scanner.Text() - fields := strings.Fields(line) - if len(fields) < 1 { - continue - } - localRef := fields[0] - const prefix = "refs/heads/" - if strings.HasPrefix(localRef, prefix) { - branch := strings.TrimPrefix(localRef, prefix) - if branch != "" { - set[branch] = struct{}{} - } - } - } - if err := scanner.Err(); err != nil { - return nil, err - } - branches := make([]string, 0, len(set)) - for b := range set { - branches = append(branches, b) - } - sort.Strings(branches) - // Rewind so caller can reuse the file - if _, err := f.Seek(0, 0); err != nil { - return nil, err - } - return branches, nil -} diff --git a/cmd/prepush/main_test.go b/cmd/prepush/main_test.go index 5fad785b..99b0a164 100644 --- a/cmd/prepush/main_test.go +++ b/cmd/prepush/main_test.go @@ -100,7 +100,7 @@ func TestLfsFilesFromCache(t *testing.T) { } } -func TestReadPushedBranches(t *testing.T) { +func TestReadPushedRefsAndBranchesFromRefs(t *testing.T) { tests := []struct { name string input string @@ -145,12 +145,11 @@ func TestReadPushedBranches(t *testing.T) { t.Fatalf("write temp: %v", err) } - // readPushedBranches seeks to 0 itself, but we pass the *os.File - // which must be valid. - branches, err := readPushedBranches(tmp) + refs, err := readPushedRefs(tmp) if err != nil { - t.Fatalf("readPushedBranches error: %v", err) + t.Fatalf("readPushedRefs error: %v", err) } + branches := branchesFromRefs(refs) if len(branches) != len(tt.expected) { t.Errorf("expected %d branches, got %d: %v", len(tt.expected), len(branches), branches) @@ -363,8 +362,8 @@ func TestSubmitPendingLFSMetaRequestWiring(t *testing.T) { if gotAuth != "Bearer test-token" { t.Fatalf("expected auth header, got %q", gotAuth) } - if gotContentType != "application/json" { - t.Fatalf("expected content-type application/json, got %q", gotContentType) + if gotContentType != "application/vnd.git-lfs+json" { + t.Fatalf("expected content-type application/vnd.git-lfs+json, got %q", gotContentType) } if gotAccept != "application/vnd.git-lfs+json" { t.Fatalf("expected accept header application/vnd.git-lfs+json, got %q", gotAccept) diff --git a/cmd/prepush/pushed_refs.go b/cmd/prepush/pushed_refs.go new file mode 100644 index 00000000..6db5298e --- /dev/null +++ b/cmd/prepush/pushed_refs.go @@ -0,0 +1,76 @@ +package prepush + +import ( + "bufio" + "io" + "sort" + "strings" + + "github.com/calypr/git-drs/internal/drsdelete" +) + +type pushedRef struct { + LocalRef string + LocalSHA string + RemoteRef string + RemoteSHA string +} + +// readPushedRefs parses git's pre-push stdin format and rewinds the reader +// before returning so callers can reuse the buffered input. +func readPushedRefs(f io.ReadSeeker) ([]pushedRef, error) { + if _, err := f.Seek(0, 0); err != nil { + return nil, err + } + scanner := bufio.NewScanner(f) + refs := make([]pushedRef, 0) + for scanner.Scan() { + fields := strings.Fields(scanner.Text()) + if len(fields) < 4 { + continue + } + refs = append(refs, pushedRef{ + LocalRef: fields[0], + LocalSHA: fields[1], + RemoteRef: fields[2], + RemoteSHA: fields[3], + }) + } + if err := scanner.Err(); err != nil { + return nil, err + } + if _, err := f.Seek(0, 0); err != nil { + return nil, err + } + return refs, nil +} + +func branchesFromRefs(refs []pushedRef) []string { + const prefix = "refs/heads/" + set := make(map[string]struct{}) + for _, ref := range refs { + if strings.HasPrefix(ref.LocalRef, prefix) { + branch := strings.TrimPrefix(ref.LocalRef, prefix) + if branch != "" { + set[branch] = struct{}{} + } + } + } + branches := make([]string, 0, len(set)) + for branch := range set { + branches = append(branches, branch) + } + sort.Strings(branches) + return branches +} + +func drsDeleteRefs(refs []pushedRef) []drsdelete.RefUpdate { + out := make([]drsdelete.RefUpdate, 0, len(refs)) + for _, ref := range refs { + out = append(out, drsdelete.RefUpdate{ + OldSHA: strings.TrimSpace(ref.RemoteSHA), + NewSHA: strings.TrimSpace(ref.LocalSHA), + }) + } + return out +} diff --git a/cmd/pull/main.go b/cmd/pull/main.go index cf352d73..2019515c 100644 --- a/cmd/pull/main.go +++ b/cmd/pull/main.go @@ -3,30 +3,41 @@ package pull import ( "context" "fmt" + "io" + "log/slog" "net/url" "os" - "os/exec" + "path/filepath" + "sort" "strings" - "github.com/bytedance/sonic" "github.com/calypr/git-drs/internal/common" "github.com/calypr/git-drs/internal/config" "github.com/calypr/git-drs/internal/drslog" "github.com/calypr/git-drs/internal/drsremote" "github.com/calypr/git-drs/internal/lfs" + "github.com/calypr/git-drs/internal/pathspec" drsapi "github.com/calypr/syfon/apigen/client/drs" + sycommon "github.com/calypr/syfon/client/common" "github.com/spf13/cobra" ) -var runCommand = func(name string, args ...string) ([]byte, error) { - cmd := exec.Command(name, args...) - return cmd.CombinedOutput() -} +var includePatterns []string +var dryRun bool + +var ( + loadCfg = config.LoadConfig + resolveRemote = func(cfg *config.Config, name string) (config.Remote, error) { return cfg.GetRemoteOrDefault(name) } + newRemoteClient = func(cfg *config.Config, remote config.Remote, logger *slog.Logger) (*config.GitContext, error) { + return cfg.GetRemoteClient(remote, logger) + } + loadWorktreeInventory = lfs.GetWorktreeLfsFiles +) var Cmd = &cobra.Command{ Use: "pull [remote-name]", - Short: "Pull using the standard Git + Git LFS flow", - Long: "Pull using the standard Git + Git LFS flow (git pull, git lfs pull, git lfs checkout).", + Short: "Download DRS pointer file content into the current checkout", + Long: "Hydrate DRS/Git-LFS pointer files in the current checkout. By default this mirrors git lfs pull semantics for the worktree rather than running git pull.", Args: func(cmd *cobra.Command, args []string) error { if len(args) > 1 { cmd.SilenceUsage = false @@ -37,7 +48,7 @@ var Cmd = &cobra.Command{ RunE: func(cmd *cobra.Command, args []string) error { logg := drslog.GetLogger() - cfg, err := config.LoadConfig() + cfg, err := loadCfg() if err != nil { return fmt.Errorf("error loading config: %v", err) } @@ -46,55 +57,54 @@ var Cmd = &cobra.Command{ if len(args) > 0 { remote = config.Remote(args[0]) } else { - remote, err = cfg.GetDefaultRemote() + remote, err = resolveRemote(cfg, "") if err != nil { logg.Error(fmt.Sprintf("Error getting remote: %v", err)) return err } } - drsCtx, err := cfg.GetRemoteClient(remote, logg) + drsCtx, err := newRemoteClient(cfg, remote, logg) if err != nil { logg.Error(fmt.Sprintf("error creating DRS client: %s", err)) return err } - _ = drsCtx // Remote validation only. - if out, err := runCommand("git", "pull", string(remote)); err != nil { - msg := strings.TrimSpace(string(out)) - if msg == "" { - msg = err.Error() - } - return fmt.Errorf("git pull failed for remote %q: %s", remote, msg) + inventory, err := loadWorktreeInventory(logg) + if err != nil { + return fmt.Errorf("failed to discover pointer files in worktree: %w", err) } - - var parsed struct { - Files []lfs.LfsFileInfo `json:"files"` + pointers := collectPointerFiles(inventory, includePatterns) + if len(pointers) == 0 { + logg.Debug("no matching pointer files to hydrate") + return nil } - out, err := runCommand("git", "lfs", "ls-files", "--json") - if err != nil { - msg := commandMessage(out, err) - if !isMissingGitLFS(msg) { - return fmt.Errorf("git lfs ls-files failed: %s", msg) - } - lfsFiles, inventoryErr := lfs.GetAllLfsFiles(string(remote), "", []string{"HEAD"}, logg) - if inventoryErr != nil { - return fmt.Errorf("git lfs ls-files failed: %s; fallback inventory failed: %w", msg, inventoryErr) - } - parsed.Files = make([]lfs.LfsFileInfo, 0, len(lfsFiles)) - for _, f := range lfsFiles { - parsed.Files = append(parsed.Files, f) + + progress := newPullProgressRenderer(os.Stderr) + progress.OnPlan(pointers) + defer progress.Finish() + + if dryRun { + for _, f := range pointers { + if _, err := fmt.Fprintln(cmd.OutOrStdout(), f.Name); err != nil { + return err + } } - } else if err := lfsjsonUnmarshal(out, &parsed); err != nil { - return fmt.Errorf("failed to parse git lfs ls-files output: %w", err) + return nil } ctx := context.Background() - missingOIDs := make([]string, 0, len(parsed.Files)) - seenMissing := make(map[string]struct{}, len(parsed.Files)) - for _, f := range parsed.Files { - if f.Downloaded { + missingOIDs := make([]string, 0, len(pointers)) + seenMissing := make(map[string]struct{}, len(pointers)) + for _, f := range pointers { + cachePath, err := lfs.ObjectPath(common.LFS_OBJS_PATH, f.Oid) + if err != nil { + return fmt.Errorf("failed to resolve LFS object path for %s: %w", f.Oid, err) + } + if _, err := os.Stat(cachePath); err == nil { continue + } else if !os.IsNotExist(err) { + return fmt.Errorf("failed to stat cached object for %s: %w", f.Oid, err) } if _, seen := seenMissing[f.Oid]; seen { continue @@ -131,40 +141,38 @@ var Cmd = &cobra.Command{ logg.Debug(fmt.Sprintf("bulk access prefetch failed; continuing per-object: %v", err)) } } - for _, f := range parsed.Files { - if f.Downloaded { - continue - } + for _, f := range pointers { dstPath, err := lfs.ObjectPath(common.LFS_OBJS_PATH, f.Oid) if err != nil { return fmt.Errorf("failed to resolve LFS object path for %s: %w", f.Oid, err) } + if _, err := os.Stat(dstPath); err == nil { + continue + } else if !os.IsNotExist(err) { + return fmt.Errorf("failed to stat cache path %s: %w", dstPath, err) + } + progress.OnDownloadStart(f) + downloadCtx := progressContextForPointer(ctx, progress, f) if obj, ok := prefetched[f.Oid]; ok { if accessURL, ok := prefetchedAccess[obj.Id]; ok { objCopy := obj - if err := drsremote.DownloadResolvedToCachePath(ctx, drsCtx, f.Oid, dstPath, &objCopy, &accessURL); err != nil { + if err := drsremote.DownloadResolvedToCachePath(downloadCtx, drsCtx, f.Oid, dstPath, &objCopy, &accessURL); err != nil { debugCtx := buildPullDownloadDebugContext(ctx, drsCtx, f.Oid) return fmt.Errorf("failed to download oid %s to %s: %w\npull-debug: %s", f.Oid, dstPath, err, debugCtx) } continue } } - if err := drsremote.DownloadToCachePath(ctx, drsCtx, logg, f.Oid, dstPath); err != nil { + if err := drsremote.DownloadToCachePath(downloadCtx, drsCtx, logg, f.Oid, dstPath); err != nil { debugCtx := buildPullDownloadDebugContext(ctx, drsCtx, f.Oid) return fmt.Errorf("failed to download oid %s to %s: %w\npull-debug: %s", f.Oid, dstPath, err, debugCtx) } } } else { - logg.Debug("no missing LFS objects to download") + logg.Debug("no missing pointer objects to download") } - if out, err := runCommand("git", "lfs", "checkout"); err != nil { - msg := commandMessage(out, err) - if !isMissingGitLFS(msg) { - return fmt.Errorf("git lfs checkout failed: %s", msg) - } - } - if err := checkoutDownloadedFiles(parsed.Files); err != nil { + if err := checkoutDownloadedFiles(pointers, progress); err != nil { return err } @@ -172,19 +180,42 @@ var Cmd = &cobra.Command{ }, } -func commandMessage(out []byte, err error) string { - msg := strings.TrimSpace(string(out)) - if msg == "" && err != nil { - msg = err.Error() +type pointerFile struct { + Name string + Oid string + Size int64 +} + +func collectPointerFiles(inventory map[string]lfs.LfsFileInfo, patterns []string) []pointerFile { + keys := make([]string, 0, len(inventory)) + for path := range inventory { + if !pathspec.MatchesAny(path, patterns) { + continue + } + keys = append(keys, path) + } + sort.Strings(keys) + + files := make([]pointerFile, 0, len(keys)) + for _, path := range keys { + info := inventory[path] + files = append(files, pointerFile{Name: path, Oid: info.Oid, Size: info.Size}) } - return msg + return files } -func isMissingGitLFS(msg string) bool { - return strings.Contains(msg, "git: 'lfs' is not a git command") +func progressContextForPointer(ctx context.Context, progress *pullProgressRenderer, file pointerFile) context.Context { + ctx = sycommon.WithOid(ctx, file.Name) + return sycommon.WithProgress(ctx, func(ev sycommon.ProgressEvent) error { + if ev.Event != "progress" { + return nil + } + progress.OnDownloadProgress(file.Name, ev.BytesSoFar, file.Size) + return nil + }) } -func checkoutDownloadedFiles(files []lfs.LfsFileInfo) error { +func checkoutDownloadedFiles(files []pointerFile, progress *pullProgressRenderer) error { for _, f := range files { if strings.TrimSpace(f.Name) == "" || strings.TrimSpace(f.Oid) == "" { continue @@ -193,21 +224,39 @@ func checkoutDownloadedFiles(files []lfs.LfsFileInfo) error { if err != nil { return fmt.Errorf("failed to resolve cached object for %s: %w", f.Oid, err) } - payload, err := os.ReadFile(srcPath) + src, err := os.Open(srcPath) if err != nil { return fmt.Errorf("failed to read cached object %s: %w", srcPath, err) } - if err := os.WriteFile(f.Name, payload, 0o644); err != nil { + progress.OnCheckoutStart(f) + if dir := filepath.Dir(f.Name); dir != "." { + if err := os.MkdirAll(dir, 0o755); err != nil { + src.Close() + return fmt.Errorf("failed to create directory for %s: %w", f.Name, err) + } + } + dst, err := os.OpenFile(f.Name, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0o644) + if err != nil { + src.Close() return fmt.Errorf("failed to checkout %s: %w", f.Name, err) } + if _, err := io.Copy(dst, src); err != nil { + dst.Close() + src.Close() + return fmt.Errorf("failed to checkout %s: %w", f.Name, err) + } + if err := dst.Close(); err != nil { + src.Close() + return fmt.Errorf("failed to finalize checkout for %s: %w", f.Name, err) + } + if err := src.Close(); err != nil { + return fmt.Errorf("failed to close cached object %s: %w", srcPath, err) + } + progress.OnCompleted(f) } return nil } -var lfsjsonUnmarshal = func(data []byte, v any) error { - return sonic.ConfigFastest.Unmarshal(data, v) -} - func buildPullDownloadDebugContext(ctx context.Context, drsCtx *config.GitContext, oid string) string { recs, err := drsremote.ObjectsByHashForScope(ctx, drsCtx, oid) if err != nil { @@ -242,3 +291,8 @@ func buildPullDownloadDebugContext(ctx context.Context, drsCtx *config.GitContex } return fmt.Sprintf("oid=%s did=%s size=%d access_methods=%s", oid, strings.TrimSpace(match.Id), match.Size, strings.Join(methods, ", ")) } + +func init() { + Cmd.Flags().StringArrayVarP(&includePatterns, "include", "I", nil, "include pathspec/glob pattern(s)") + Cmd.Flags().BoolVar(&dryRun, "dry-run", false, "list matching pointer files without downloading them") +} diff --git a/cmd/pull/progress.go b/cmd/pull/progress.go new file mode 100644 index 00000000..f1a6d345 --- /dev/null +++ b/cmd/pull/progress.go @@ -0,0 +1,185 @@ +package pull + +import ( + "fmt" + "io" + + "github.com/calypr/git-drs/internal/progressui" +) + +const pullNonTTYProgressInterval = progressui.NonTTYProgressInterval + +type pullProgressPhase string + +const ( + pullProgressPending pullProgressPhase = "pending" + pullProgressDownloading pullProgressPhase = "downloading" + pullProgressCheckingOut pullProgressPhase = "checking_out" + pullProgressCompleted pullProgressPhase = "completed" +) + +type pullFileProgress struct { + path string + total int64 + current int64 + phase pullProgressPhase +} + +type pullProgressRenderer struct { + base *progressui.Renderer + planned bool + files map[string]*pullFileProgress + fileOrder []string +} + +func newPullProgressRenderer(out io.Writer) *pullProgressRenderer { + return &pullProgressRenderer{ + base: progressui.NewRenderer(out), + files: make(map[string]*pullFileProgress), + } +} + +func isPullWriterTTY(w io.Writer) bool { + return progressui.IsWriterTTY(w) +} + +func (r *pullProgressRenderer) render(force bool) { + lines := make([]string, 0, len(r.fileOrder)) + for _, id := range r.fileOrder { + item := r.files[id] + if item == nil { + continue + } + lines = append(lines, r.renderLine(item)) + } + r.base.Render(force, lines) +} + +func (r *pullProgressRenderer) OnPlan(files []pointerFile) { + r.planned = len(files) > 0 + r.files = make(map[string]*pullFileProgress, len(files)) + r.fileOrder = r.fileOrder[:0] + for _, file := range files { + r.files[file.Name] = &pullFileProgress{ + path: file.Name, + total: file.Size, + phase: pullProgressPending, + } + r.fileOrder = append(r.fileOrder, file.Name) + } + if r.planned { + r.render(true) + } +} + +func (r *pullProgressRenderer) OnDownloadStart(file pointerFile) { + if !r.planned { + return + } + item, ok := r.files[file.Name] + if !ok { + return + } + item.path = file.Name + if file.Size > 0 { + item.total = file.Size + } + item.phase = pullProgressDownloading + r.render(false) +} + +func (r *pullProgressRenderer) OnDownloadProgress(id string, bytesSoFar int64, total int64) { + if !r.planned { + return + } + item, ok := r.files[id] + if !ok { + return + } + if total > 0 { + item.total = total + } + if bytesSoFar > item.current { + item.current = bytesSoFar + } + item.phase = pullProgressDownloading + r.render(false) +} + +func (r *pullProgressRenderer) OnCheckoutStart(file pointerFile) { + if !r.planned { + return + } + item, ok := r.files[file.Name] + if !ok { + return + } + item.phase = pullProgressCheckingOut + if item.total == 0 && file.Size > 0 { + item.total = file.Size + } + r.render(false) +} + +func (r *pullProgressRenderer) OnCompleted(file pointerFile) { + if !r.planned { + return + } + item, ok := r.files[file.Name] + if !ok { + return + } + if item.total == 0 && file.Size > 0 { + item.total = file.Size + } + if item.total > 0 { + item.current = item.total + } + item.phase = pullProgressCompleted + r.render(false) +} + +func (r *pullProgressRenderer) Finish() { + if !r.planned { + return + } + lines := make([]string, 0, len(r.fileOrder)) + for _, id := range r.fileOrder { + item := r.files[id] + if item == nil { + continue + } + lines = append(lines, r.renderLine(item)) + } + r.base.Finish(lines) + r.planned = false +} + +func (r *pullProgressRenderer) renderLine(file *pullFileProgress) string { + label := "preparing pull" + if file != nil && file.path != "" { + label = progressui.TrimLabel(file.path, 48) + } + + prefix := "" + if file != nil { + switch file.phase { + case pullProgressDownloading, pullProgressCheckingOut: + if !(file.total > 0 && file.current >= file.total) { + prefix = r.base.Spinner() + " " + } + } + } + + current := int64(0) + total := int64(0) + if file != nil { + current = file.current + total = file.total + } + bar := progressui.RenderProgressBar(current, total, 24) + pct := progressui.RenderPercent(current, total) + bytesLabel := progressui.RenderByteProgress(current, total, current >= total) + + return fmt.Sprintf("%s%s %s %s %s", prefix, label, bar, pct, bytesLabel) +} diff --git a/cmd/pull/progress_test.go b/cmd/pull/progress_test.go new file mode 100644 index 00000000..77682c71 --- /dev/null +++ b/cmd/pull/progress_test.go @@ -0,0 +1,85 @@ +package pull + +import ( + "bytes" + "strings" + "testing" + "time" +) + +func TestPullProgressRendererTTY(t *testing.T) { + var out bytes.Buffer + r := newPullProgressRenderer(&out) + r.base.SetTTY(true) + r.base.SetClock(func() time.Time { return time.Unix(0, 0) }) + + files := []pointerFile{ + {Name: "a.bin", Oid: "oid-1", Size: 100}, + {Name: "b.bin", Oid: "oid-2", Size: 100}, + } + r.OnPlan(files) + r.OnDownloadStart(files[0]) + r.OnDownloadProgress("a.bin", 50, 100) + r.OnCheckoutStart(files[0]) + r.OnCompleted(files[0]) + r.OnCompleted(files[1]) + + got := out.String() + if !strings.Contains(got, "a.bin [============") { + t.Fatalf("expected progress bar output for a.bin, got %q", got) + } + if !strings.Contains(got, "100.0% 100 B/100 B") { + t.Fatalf("expected completed byte summary, got %q", got) + } +} + +func TestPullProgressRendererNonTTYThrottles(t *testing.T) { + var out bytes.Buffer + now := time.Unix(0, 0) + r := newPullProgressRenderer(&out) + r.base.SetTTY(false) + r.base.SetClock(func() time.Time { return now }) + + file := pointerFile{Name: "a.bin", Oid: "oid-1", Size: 100} + r.OnPlan([]pointerFile{file}) + initial := out.String() + if !strings.Contains(initial, "a.bin") { + t.Fatalf("expected initial non-tty progress line, got %q", initial) + } + + r.OnDownloadStart(file) + r.OnDownloadProgress("a.bin", 10, 100) + if got := out.String(); got != initial { + t.Fatalf("expected throttled output before interval, got %q", got) + } + + now = now.Add(pullNonTTYProgressInterval) + r.OnCompleted(file) + got := out.String() + if !strings.Contains(got, "100.0% 100 B/100 B") { + t.Fatalf("expected rendered completion after interval, got %q", got) + } +} + +func TestPullProgressRendererNoSpinnerAtFullDownloadedBytes(t *testing.T) { + var out bytes.Buffer + r := newPullProgressRenderer(&out) + r.base.SetTTY(true) + r.base.SetClock(func() time.Time { return time.Unix(0, 0) }) + + file := pointerFile{Name: "a.bin", Oid: "oid-1", Size: 100} + r.OnPlan([]pointerFile{file}) + r.OnDownloadStart(file) + r.OnDownloadProgress("a.bin", 100, 100) + + got := out.String() + if strings.Contains(got, "/ a.bin [========================] 100.0% 100 B/100 B") || + strings.Contains(got, "| a.bin [========================] 100.0% 100 B/100 B") || + strings.Contains(got, "- a.bin [========================] 100.0% 100 B/100 B") || + strings.Contains(got, "\\ a.bin [========================] 100.0% 100 B/100 B") { + t.Fatalf("expected no spinner prefix on fully downloaded file, got %q", got) + } + if !strings.Contains(got, "a.bin [========================] 100.0% 100 B/100 B") { + t.Fatalf("expected completed byte line without spinner, got %q", got) + } +} diff --git a/cmd/pull/pull_test.go b/cmd/pull/pull_test.go index 41c999ac..7e8e6c74 100644 --- a/cmd/pull/pull_test.go +++ b/cmd/pull/pull_test.go @@ -1,34 +1,81 @@ package pull import ( + "bytes" + "log/slog" "testing" "github.com/calypr/git-drs/internal/config" - "github.com/calypr/git-drs/internal/testutils" - "github.com/stretchr/testify/assert" + "github.com/calypr/git-drs/internal/lfs" ) -func TestPullCmdArgs(t *testing.T) { - err := Cmd.Args(Cmd, []string{}) - assert.NoError(t, err) +func resetPullFlagsForTest() { + includePatterns = nil + dryRun = false +} - err = Cmd.Args(Cmd, []string{"origin"}) - assert.NoError(t, err) +func TestCollectPointerFilesFiltersAndSorts(t *testing.T) { + resetPullFlagsForTest() - err = Cmd.Args(Cmd, []string{"origin", "extra"}) - assert.Error(t, err) -} + inventory := map[string]lfs.LfsFileInfo{ + "data/b.bin": {Name: "data/b.bin", Oid: "bbbb", Size: 2}, + "data/a.bin": {Name: "data/a.bin", Oid: "aaaa", Size: 1}, + "misc/c.bin": {Name: "misc/c.bin", Oid: "cccc", Size: 3}, + } -func TestPullRun_LoadConfigError(t *testing.T) { - _ = testutils.SetupTestGitRepo(t) - err := Cmd.RunE(Cmd, []string{}) - assert.Error(t, err) + files := collectPointerFiles(inventory, []string{"data/**"}) + if len(files) != 2 { + t.Fatalf("expected 2 files, got %d", len(files)) + } + if files[0].Name != "data/a.bin" || files[1].Name != "data/b.bin" { + t.Fatalf("unexpected file order: %+v", files) + } } -func TestPullRun_DefaultRemoteError(t *testing.T) { - tmpDir := testutils.SetupTestGitRepo(t) - testutils.CreateTestConfig(t, tmpDir, &config.Config{}) +func TestPullDryRunListsMatchingPaths(t *testing.T) { + resetPullFlagsForTest() + + oldLoadCfg := loadCfg + oldResolveRemote := resolveRemote + oldNewRemoteClient := newRemoteClient + oldInventory := loadWorktreeInventory + t.Cleanup(func() { + loadCfg = oldLoadCfg + resolveRemote = oldResolveRemote + newRemoteClient = oldNewRemoteClient + loadWorktreeInventory = oldInventory + }) + + loadCfg = func() (*config.Config, error) { return &config.Config{}, nil } + resolveRemote = func(cfg *config.Config, name string) (config.Remote, error) { return config.Remote("origin"), nil } + newRemoteClient = func(cfg *config.Config, remote config.Remote, logger *slog.Logger) (*config.GitContext, error) { + return &config.GitContext{}, nil + } + loadWorktreeInventory = func(_ *slog.Logger) (map[string]lfs.LfsFileInfo, error) { + return map[string]lfs.LfsFileInfo{ + "data/a.bin": {Name: "data/a.bin", Oid: "aaaa", Size: 1}, + "misc/b.bin": {Name: "misc/b.bin", Oid: "bbbb", Size: 2}, + }, nil + } + + includePatterns = []string{"data/**"} + dryRun = true + + var out bytes.Buffer + Cmd.SetOut(&out) + Cmd.SetErr(&out) + Cmd.SetArgs([]string{"--dry-run"}) + t.Cleanup(func() { + Cmd.SetOut(nil) + Cmd.SetErr(nil) + Cmd.SetArgs(nil) + resetPullFlagsForTest() + }) - err := Cmd.RunE(Cmd, []string{}) - assert.Error(t, err) + if err := Cmd.RunE(Cmd, []string{}); err != nil { + t.Fatalf("RunE returned error: %v", err) + } + if got := out.String(); got != "data/a.bin\n" { + t.Fatalf("unexpected dry-run output: %q", got) + } } diff --git a/cmd/push/main.go b/cmd/push/main.go index 4d0445ca..e9877f3b 100644 --- a/cmd/push/main.go +++ b/cmd/push/main.go @@ -3,10 +3,12 @@ package push import ( "context" "fmt" + "os" "os/exec" "strings" "github.com/calypr/git-drs/internal/config" + "github.com/calypr/git-drs/internal/drsdelete" "github.com/calypr/git-drs/internal/drslog" "github.com/calypr/git-drs/internal/lfs" "github.com/calypr/git-drs/internal/pushsync" @@ -14,12 +16,15 @@ import ( ) var pushWithHooks bool +var pushForceUpload bool var runCommand = func(name string, args ...string) ([]byte, error) { cmd := exec.Command(name, args...) return cmd.CombinedOutput() } +var gitOutputFn = gitOutput + var Cmd = &cobra.Command{ Use: "push [remote-name]", Short: "Upload/register DRS objects and push Git refs", @@ -55,15 +60,32 @@ var Cmd = &cobra.Command{ myLogger.Debug(fmt.Sprintf("Error creating DRS client: %s", err)) return err } + drsClient.ForceUpload = pushForceUpload lfsFiles, err := lfs.GetAllLfsFiles(string(remote), "", []string{"HEAD"}, myLogger) if err != nil { return fmt.Errorf("failed to discover LFS files to push: %w", err) } ctx := context.Background() - if err := pushsync.BatchSyncForPush(drsClient, ctx, lfsFiles); err != nil { + deleteRefs, err := currentDeleteRefUpdates(ctx) + if err != nil { + return fmt.Errorf("failed to resolve delete reconciliation base: %w", err) + } + if _, err := drsdelete.ReconcileCommittedDeletes(ctx, drsClient, deleteRefs, myLogger); err != nil { + return fmt.Errorf("failed to reconcile deletes: %w", err) + } + progress := newUploadProgressRenderer(os.Stderr) + if err := pushsync.BatchSyncForPush(drsClient, ctx, lfsFiles, progress); err != nil { + progress.Finish() return fmt.Errorf("failed batch register/upload workflow: %w", err) } + progress.Finish() + switch { + case len(lfsFiles) == 0: + fmt.Fprintln(os.Stdout, "No git-drs tracked files found; pushing Git refs only.") + case !progress.HadUploads(): + fmt.Fprintln(os.Stdout, "No DRS payload uploads needed; all tracked objects are already available remotely.") + } pushArgs := []string{"push"} if !pushWithHooks { @@ -84,4 +106,29 @@ var Cmd = &cobra.Command{ func init() { Cmd.Flags().BoolVar(&pushWithHooks, "with-hooks", false, "Run git push with local hooks enabled (invokes pre-push)") + Cmd.Flags().BoolVar(&pushForceUpload, "force-upload", false, "Upload payload bytes even when a matching downloadable object already exists remotely") +} + +func currentDeleteRefUpdates(ctx context.Context) ([]drsdelete.RefUpdate, error) { + head, err := gitOutputFn(ctx, "rev-parse", "HEAD") + if err != nil { + return nil, err + } + upstream, err := gitOutputFn(ctx, "rev-parse", "--verify", "@{upstream}") + if err != nil { + return nil, nil + } + return []drsdelete.RefUpdate{{ + OldSHA: upstream, + NewSHA: head, + }}, nil +} + +func gitOutput(ctx context.Context, args ...string) (string, error) { + cmd := exec.CommandContext(ctx, "git", args...) + out, err := cmd.CombinedOutput() + if err != nil { + return "", fmt.Errorf("git %s: %s", strings.Join(args, " "), strings.TrimSpace(string(out))) + } + return strings.TrimSpace(string(out)), nil } diff --git a/cmd/push/main_test.go b/cmd/push/main_test.go new file mode 100644 index 00000000..b18b8a87 --- /dev/null +++ b/cmd/push/main_test.go @@ -0,0 +1,58 @@ +package push + +import ( + "context" + "fmt" + "testing" + + "github.com/calypr/git-drs/internal/drsdelete" +) + +func TestCurrentDeleteRefUpdatesUsesUpstreamWhenConfigured(t *testing.T) { + oldFn := gitOutputFn + gitOutputFn = func(ctx context.Context, args ...string) (string, error) { + switch fmt.Sprint(args) { + case "[rev-parse HEAD]": + return "head-sha", nil + case "[rev-parse --verify @{upstream}]": + return "upstream-sha", nil + default: + t.Fatalf("unexpected git args: %v", args) + return "", nil + } + } + t.Cleanup(func() { gitOutputFn = oldFn }) + + got, err := currentDeleteRefUpdates(context.Background()) + if err != nil { + t.Fatalf("currentDeleteRefUpdates returned error: %v", err) + } + want := []drsdelete.RefUpdate{{OldSHA: "upstream-sha", NewSHA: "head-sha"}} + if len(got) != len(want) || got[0] != want[0] { + t.Fatalf("unexpected delete refs: got %+v want %+v", got, want) + } +} + +func TestCurrentDeleteRefUpdatesSkipsWhenUpstreamMissing(t *testing.T) { + oldFn := gitOutputFn + gitOutputFn = func(ctx context.Context, args ...string) (string, error) { + switch fmt.Sprint(args) { + case "[rev-parse HEAD]": + return "head-sha", nil + case "[rev-parse --verify @{upstream}]": + return "", fmt.Errorf("git rev-parse --verify @{upstream}: fatal: no upstream configured") + default: + t.Fatalf("unexpected git args: %v", args) + return "", nil + } + } + t.Cleanup(func() { gitOutputFn = oldFn }) + + got, err := currentDeleteRefUpdates(context.Background()) + if err != nil { + t.Fatalf("currentDeleteRefUpdates returned error: %v", err) + } + if got != nil { + t.Fatalf("expected nil delete refs when upstream is missing, got %+v", got) + } +} diff --git a/cmd/push/progress.go b/cmd/push/progress.go new file mode 100644 index 00000000..a22156ec --- /dev/null +++ b/cmd/push/progress.go @@ -0,0 +1,155 @@ +package push + +import ( + "fmt" + "io" + "sync" + + "github.com/calypr/git-drs/internal/progressui" + "github.com/calypr/git-drs/internal/pushsync" +) + +type uploadFileProgress struct { + path string + total int64 + current int64 + started bool + completed bool +} + +type uploadProgressRenderer struct { + mu sync.Mutex + base *progressui.Renderer + planned bool + plan pushsync.UploadPlanSummary + files map[string]*uploadFileProgress + fileOrder []string +} + +func newUploadProgressRenderer(out io.Writer) *uploadProgressRenderer { + return &uploadProgressRenderer{ + base: progressui.NewRenderer(out), + files: make(map[string]*uploadFileProgress), + } +} + +func (r *uploadProgressRenderer) renderLocked(force bool) { + lines := make([]string, 0, len(r.fileOrder)) + for idx, oid := range r.fileOrder { + file := r.files[oid] + if file == nil { + continue + } + lines = append(lines, r.renderLine(idx, len(r.fileOrder), file)) + } + r.base.Render(force, lines) +} + +func (r *uploadProgressRenderer) OnUploadPlan(plan pushsync.UploadPlanSummary) { + r.mu.Lock() + defer r.mu.Unlock() + + r.plan = plan + r.planned = plan.TotalFiles > 0 + r.files = make(map[string]*uploadFileProgress, len(plan.Files)) + r.fileOrder = r.fileOrder[:0] + for _, file := range plan.Files { + r.files[file.OID] = &uploadFileProgress{ + path: file.Path, + total: file.Bytes, + } + r.fileOrder = append(r.fileOrder, file.OID) + } + if r.planned { + r.renderLocked(true) + } +} + +func (r *uploadProgressRenderer) OnUploadProgress(ev pushsync.UploadProgressEvent) { + r.mu.Lock() + defer r.mu.Unlock() + + if !r.planned { + return + } + file, ok := r.files[ev.OID] + if !ok { + return + } + if ev.Path != "" { + file.path = ev.Path + } + if ev.TotalBytes > 0 { + file.total = ev.TotalBytes + } + if ev.BytesSoFar > file.current { + file.current = ev.BytesSoFar + } + if ev.Phase == pushsync.UploadProgressUploading { + file.started = true + } + if ev.Phase == pushsync.UploadProgressCompleted && !file.completed { + file.started = true + file.completed = true + if file.total > 0 { + file.current = file.total + } + } + r.renderLocked(false) +} + +func (r *uploadProgressRenderer) Finish() { + r.mu.Lock() + defer r.mu.Unlock() + + if !r.planned { + return + } + lines := make([]string, 0, len(r.fileOrder)) + for idx, oid := range r.fileOrder { + file := r.files[oid] + if file == nil { + continue + } + lines = append(lines, r.renderLine(idx, len(r.fileOrder), file)) + } + r.base.Finish(lines) + r.planned = false +} + +func (r *uploadProgressRenderer) HadUploads() bool { + r.mu.Lock() + defer r.mu.Unlock() + return r != nil && r.planned +} + +func (r *uploadProgressRenderer) renderLine(idx int, total int, file *uploadFileProgress) string { + label := "preparing upload" + if file != nil && file.path != "" { + label = progressui.TrimLabel(file.path, 48) + } + prefix := "" + if file != nil { + switch { + case file.started && !file.completed: + prefix = r.base.Spinner() + " " + } + } + + current := int64(0) + totalBytes := int64(0) + completed := false + if file != nil { + current = file.current + totalBytes = file.total + completed = file.completed + } + displayCurrent := progressui.VisibleProgressBytes(current, totalBytes, completed) + bar := progressui.RenderProgressBar(displayCurrent, totalBytes, 24) + pct := progressui.RenderPercentCapped(displayCurrent, totalBytes, completed) + bytesLabel := progressui.RenderByteProgress(displayCurrent, totalBytes, completed) + + _ = idx + _ = total + return fmt.Sprintf("%s%s %s %s %s", prefix, label, bar, pct, bytesLabel) +} diff --git a/cmd/push/progress_test.go b/cmd/push/progress_test.go new file mode 100644 index 00000000..b9c66cb6 --- /dev/null +++ b/cmd/push/progress_test.go @@ -0,0 +1,177 @@ +package push + +import ( + "bytes" + "strings" + "sync" + "testing" + "time" + + "github.com/calypr/git-drs/internal/pushsync" +) + +func TestUploadProgressRendererTTY(t *testing.T) { + var out bytes.Buffer + r := newUploadProgressRenderer(&out) + r.base.SetTTY(true) + + r.OnUploadPlan(pushsync.UploadPlanSummary{ + Files: []pushsync.UploadPlanFile{ + {OID: "oid-1", Path: "a.bin", Bytes: 100}, + {OID: "oid-2", Path: "b.bin", Bytes: 100}, + }, + TotalFiles: 2, + TotalBytes: 200, + }) + r.OnUploadProgress(pushsync.UploadProgressEvent{OID: "oid-1", Path: "a.bin", BytesSoFar: 0, TotalBytes: 100, Phase: pushsync.UploadProgressUploading}) + r.OnUploadProgress(pushsync.UploadProgressEvent{OID: "oid-1", Path: "a.bin", BytesSoFar: 50, BytesSinceLast: 50, TotalBytes: 100, Phase: pushsync.UploadProgressUploading}) + r.OnUploadProgress(pushsync.UploadProgressEvent{OID: "oid-1", Path: "a.bin", BytesSoFar: 100, TotalBytes: 100, Phase: pushsync.UploadProgressCompleted}) + r.OnUploadProgress(pushsync.UploadProgressEvent{OID: "oid-2", Path: "b.bin", BytesSoFar: 0, TotalBytes: 100, Phase: pushsync.UploadProgressUploading}) + r.OnUploadProgress(pushsync.UploadProgressEvent{OID: "oid-2", Path: "b.bin", BytesSoFar: 100, TotalBytes: 100, Phase: pushsync.UploadProgressCompleted}) + r.Finish() + + got := out.String() + if !strings.Contains(got, "a.bin [============ ] 50.0% 50 B/100 B") { + t.Fatalf("expected first file uploading line, got %q", got) + } + if !strings.Contains(got, "b.bin [ ] 0.0% 0 B/100 B") { + t.Fatalf("expected second file pending line, got %q", got) + } + if !strings.Contains(got, "b.bin [========================] 100.0% 100 B/100 B") { + t.Fatalf("expected completed second file line, got %q", got) + } + if strings.Contains(got, "(uploading)") || strings.Contains(got, "(pending)") || strings.Contains(got, "(complete)") { + t.Fatalf("did not expect parenthesized state text, got %q", got) + } + if !strings.HasSuffix(got, "\n") { + t.Fatalf("expected trailing newline, got %q", got) + } +} + +func TestUploadProgressRendererNonTTYThrottles(t *testing.T) { + var out bytes.Buffer + r := newUploadProgressRenderer(&out) + r.base.SetTTY(false) + now := time.Unix(0, 0) + r.base.SetClock(func() time.Time { return now }) + + r.OnUploadPlan(pushsync.UploadPlanSummary{ + Files: []pushsync.UploadPlanFile{{OID: "oid-1", Path: "a.bin", Bytes: 100}}, + TotalFiles: 1, + TotalBytes: 100, + }) + first := out.String() + if first == "" { + t.Fatal("expected initial non-tty progress line") + } + + r.OnUploadProgress(pushsync.UploadProgressEvent{OID: "oid-1", Path: "a.bin", BytesSoFar: 10, BytesSinceLast: 10, TotalBytes: 100, Phase: pushsync.UploadProgressUploading}) + if out.String() != first { + t.Fatalf("expected throttled output to remain unchanged, got %q", out.String()) + } + + now = now.Add(3 * time.Second) + r.OnUploadProgress(pushsync.UploadProgressEvent{OID: "oid-1", Path: "a.bin", BytesSoFar: 100, TotalBytes: 100, Phase: pushsync.UploadProgressCompleted}) + got := out.String() + if strings.Count(got, "\n") < 2 { + t.Fatalf("expected throttled summary updates, got %q", got) + } + if !strings.Contains(got, "a.bin [========================] 100.0% 100 B/100 B") { + t.Fatalf("expected non-tty progress summary, got %q", got) + } + if strings.Contains(got, "1/1") || strings.Contains(got, "[*]") { + t.Fatalf("did not expect positional or completion prefix clutter, got %q", got) + } +} + +func TestUploadProgressRendererDoesNotShowFullCompletionBeforeCompleteEvent(t *testing.T) { + var out bytes.Buffer + r := newUploadProgressRenderer(&out) + r.base.SetTTY(true) + + total := int64(500 * 1024 * 1024) + r.OnUploadPlan(pushsync.UploadPlanSummary{ + Files: []pushsync.UploadPlanFile{{OID: "oid-1", Path: "large.bin", Bytes: total}}, + TotalFiles: 1, + TotalBytes: total, + }) + r.OnUploadProgress(pushsync.UploadProgressEvent{ + OID: "oid-1", + Path: "large.bin", + BytesSoFar: total, + TotalBytes: total, + Phase: pushsync.UploadProgressUploading, + }) + + got := out.String() + if !strings.Contains(got, "99.9%") { + t.Fatalf("expected in-flight upload to stay below 100%%, got %q", got) + } + if !strings.Contains(got, "<500.0 MiB/500.0 MiB") { + t.Fatalf("expected in-flight byte label to avoid full equality, got %q", got) + } + if strings.Contains(got, "100.0%") { + t.Fatalf("did not expect in-flight upload to render as 100%%, got %q", got) + } +} + +func TestUploadProgressRendererHadUploads(t *testing.T) { + var out bytes.Buffer + r := newUploadProgressRenderer(&out) + if r.HadUploads() { + t.Fatal("expected fresh renderer to report no uploads") + } + + r.OnUploadPlan(pushsync.UploadPlanSummary{ + Files: []pushsync.UploadPlanFile{{OID: "oid-1", Path: "a.bin", Bytes: 1}}, + TotalFiles: 1, + TotalBytes: 1, + }) + if !r.HadUploads() { + t.Fatal("expected renderer to report uploads after a non-empty plan") + } + + r.Finish() + if r.HadUploads() { + t.Fatal("expected renderer to reset after finish") + } +} + +func TestUploadProgressRendererConcurrentProgress(t *testing.T) { + var out bytes.Buffer + r := newUploadProgressRenderer(&out) + r.base.SetTTY(false) + + r.OnUploadPlan(pushsync.UploadPlanSummary{ + Files: []pushsync.UploadPlanFile{ + {OID: "oid-1", Path: "a.bin", Bytes: 100}, + {OID: "oid-2", Path: "b.bin", Bytes: 100}, + }, + TotalFiles: 2, + TotalBytes: 200, + }) + + events := []pushsync.UploadProgressEvent{ + {OID: "oid-1", Path: "a.bin", BytesSoFar: 10, BytesSinceLast: 10, TotalBytes: 100, Phase: pushsync.UploadProgressUploading}, + {OID: "oid-2", Path: "b.bin", BytesSoFar: 20, BytesSinceLast: 20, TotalBytes: 100, Phase: pushsync.UploadProgressUploading}, + {OID: "oid-1", Path: "a.bin", BytesSoFar: 100, TotalBytes: 100, Phase: pushsync.UploadProgressCompleted}, + {OID: "oid-2", Path: "b.bin", BytesSoFar: 100, TotalBytes: 100, Phase: pushsync.UploadProgressCompleted}, + } + + var wg sync.WaitGroup + wg.Add(len(events)) + for _, ev := range events { + ev := ev + go func() { + defer wg.Done() + r.OnUploadProgress(ev) + }() + } + wg.Wait() + r.Finish() + + got := out.String() + if !strings.Contains(got, "a.bin") || !strings.Contains(got, "b.bin") { + t.Fatalf("expected both files in concurrent progress output, got %q", got) + } +} diff --git a/cmd/remote/add/add_test.go b/cmd/remote/add/add_test.go index 2923b5d8..e71c3745 100644 --- a/cmd/remote/add/add_test.go +++ b/cmd/remote/add/add_test.go @@ -1,8 +1,13 @@ package add import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" "testing" + bucketapi "github.com/calypr/syfon/apigen/client/bucketapi" "github.com/stretchr/testify/assert" ) @@ -12,5 +17,89 @@ func TestAddCmd(t *testing.T) { } func TestGen3Cmd(t *testing.T) { - assert.Equal(t, "gen3 [remote-name]", Gen3Cmd.Use) + assert.Equal(t, "gen3 [remote-name] ", Gen3Cmd.Use) +} + +func TestParseScopeArg(t *testing.T) { + t.Run("splits org and project on slash", func(t *testing.T) { + org, project, err := parseScopeArg("HTAN_INT/BForePC") + if err != nil { + t.Fatalf("parseScopeArg returned error: %v", err) + } + if org != "HTAN_INT" || project != "BForePC" { + t.Fatalf("unexpected scope parse result: %q/%q", org, project) + } + }) + + t.Run("rejects legacy single token input", func(t *testing.T) { + _, _, err := parseScopeArg("BForePC") + if err == nil { + t.Fatal("expected invalid scope error") + } + }) + + t.Run("rejects empty org or project", func(t *testing.T) { + for _, raw := range []string{"/BForePC", "HTAN_INT/", "HTAN_INT//BForePC"} { + _, _, err := parseScopeArg(raw) + if err == nil { + t.Fatalf("expected invalid scope error for %q", raw) + } + } + }) +} + +func TestResolveBucketScopeFromServer(t *testing.T) { + t.Run("matches project resource", func(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/data/buckets" { + t.Fatalf("unexpected path: %s", r.URL.Path) + } + if got := r.Header.Get("Authorization"); got != "Bearer test-token" { + t.Fatalf("unexpected auth header: %q", got) + } + _, _ = w.Write([]byte(`{"S3_BUCKETS":{"cbds":{"programs":["/organization/HTAN_INT/project/BForePC"]}}}`)) + })) + defer srv.Close() + + scope, err := resolveBucketScopeFromServer(context.Background(), srv.URL, "test-token", "HTAN_INT", "BForePC") + if err != nil { + t.Fatalf("resolveBucketScopeFromServer returned error: %v", err) + } + if scope.Bucket != "cbds" { + t.Fatalf("unexpected bucket: %+v", scope) + } + }) + + t.Run("falls back to org resource", func(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = w.Write([]byte(`{"S3_BUCKETS":{"cbds":{"programs":["/organization/HTAN_INT"]}}}`)) + })) + defer srv.Close() + + scope, err := resolveBucketScopeFromServer(context.Background(), srv.URL, "test-token", "HTAN_INT", "BForePC") + if err != nil { + t.Fatalf("resolveBucketScopeFromServer returned error: %v", err) + } + if scope.Bucket != "cbds" { + t.Fatalf("unexpected bucket: %+v", scope) + } + }) + + t.Run("no match", func(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + resp := bucketapi.BucketsResponse{S3BUCKETS: map[string]bucketapi.BucketMetadata{ + "cbds": {}, + }} + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(resp); err != nil { + t.Fatalf("encode response: %v", err) + } + })) + defer srv.Close() + + _, err := resolveBucketScopeFromServer(context.Background(), srv.URL, "test-token", "HTAN_INT", "BForePC") + if err == nil { + t.Fatal("expected error when no matching bucket is visible") + } + }) } diff --git a/cmd/remote/add/gen3.go b/cmd/remote/add/gen3.go index 9f9ceed7..2028bfa3 100644 --- a/cmd/remote/add/gen3.go +++ b/cmd/remote/add/gen3.go @@ -2,42 +2,46 @@ package add import ( "context" + "encoding/json" "fmt" "log/slog" + "net/http" "strings" "github.com/calypr/data-client/credentials" + "github.com/calypr/git-drs/cmd/initialize" "github.com/calypr/git-drs/internal/common" "github.com/calypr/git-drs/internal/config" "github.com/calypr/git-drs/internal/drslog" "github.com/calypr/git-drs/internal/gitrepo" + bucketapi "github.com/calypr/syfon/apigen/client/bucketapi" conf "github.com/calypr/syfon/client/config" + syfoncommon "github.com/calypr/syfon/common" "github.com/spf13/cobra" ) var Gen3Cmd = &cobra.Command{ - Use: "gen3 [remote-name]", + Use: "gen3 [remote-name] ", Args: func(cmd *cobra.Command, args []string) error { - if len(args) > 1 { + if len(args) < 1 || len(args) > 2 { cmd.SilenceUsage = false - return fmt.Errorf("error: accepts at most 1 argument (remote name), received %d\n\nUsage: %s\n\nSee 'git drs remote add gen3 --help' for more details", len(args), cmd.UseLine()) + return fmt.Errorf("error: expected [remote-name] , received %d arguments\n\nUsage: %s\n\nSee 'git drs remote add gen3 --help' for more details", len(args), cmd.UseLine()) } return nil }, RunE: func(cmd *cobra.Command, args []string) error { logg := drslog.GetLogger() - // make sure at least one of the credentials params is provided - if credFile == "" && fenceToken == "" && len(args) == 0 { - return fmt.Errorf("error: Gen3 requires a credentials file or accessToken to setup project locally. Please provide either a --cred or --token flag. See 'git drs remote add gen3 --help' for more details") - } - remoteName := config.ORIGIN - if len(args) > 0 { + scopeArg := "" + if len(args) == 1 { + scopeArg = args[0] + } else { remoteName = args[0] + scopeArg = args[1] } - err := gen3Init(remoteName, credFile, fenceToken, project, organization, bucket, logg) + err := gen3Init(remoteName, credFile, fenceToken, scopeArg, logg) if err != nil { return fmt.Errorf("error configuring gen3 server: %v", err) } @@ -45,31 +49,16 @@ var Gen3Cmd = &cobra.Command{ }, } -func gen3Init(remoteName, credFile, fenceToken, project, organization, bucket string, logg *slog.Logger) error { +func gen3Init(remoteName, credFile, fenceToken, scopeArg string, logg *slog.Logger) error { if remoteName == "" { return fmt.Errorf("remote name is required") } - if project == "" { - return fmt.Errorf("project is required for Gen3 remote") + if err := initialize.EnsureInitialized(logg); err != nil { + return fmt.Errorf("failed to initialize repository: %w", err) } - - resolvedBucket := strings.TrimSpace(bucket) - resolvedStoragePrefix := "" - if strings.TrimSpace(organization) != "" { - scope, err := gitrepo.ResolveBucketScope(organization, project, resolvedBucket, "") - if err != nil { - return fmt.Errorf("failed resolving bucket mapping for organization=%q project=%q: %w", organization, project, err) - } - resolvedBucket = strings.TrimSpace(scope.Bucket) - resolvedStoragePrefix = strings.TrimSpace(scope.Prefix) - } - if resolvedBucket == "" { - if strings.TrimSpace(organization) == "" { - return fmt.Errorf("bucket is required when organization is empty") - } - if strings.TrimSpace(resolvedBucket) == "" { - return fmt.Errorf("bucket is required (or configure mapping first with `git drs bucket add-project --organization %s --project %s --path :///`)", organization, project) - } + organization, project, err := parseScopeArg(scopeArg) + if err != nil { + return err } var accessToken, apiKey, keyID, apiEndpoint string @@ -99,13 +88,13 @@ func gen3Init(remoteName, credFile, fenceToken, project, organization, bucket st default: existing, err := configure.Load(remoteName) - if err == nil { + if err != nil { + return fmt.Errorf("failed to load %s config: %w", remoteName, err) + } else { accessToken = existing.AccessToken apiKey = existing.APIKey keyID = existing.KeyID apiEndpoint = existing.APIEndpoint - } else { - return fmt.Errorf("must provide either --cred or --token (or have existing profile %s)", remoteName) } } @@ -113,6 +102,33 @@ func gen3Init(remoteName, credFile, fenceToken, project, organization, bucket st return fmt.Errorf("could not determine Gen3 API endpoint") } + cred := &conf.Credential{ + Profile: remoteName, + APIEndpoint: apiEndpoint, + APIKey: apiKey, + KeyID: keyID, + AccessToken: accessToken, // may be stale + UseShepherd: "false", + MinShepherdVersion: "", + } + + if err := credentials.EnsureValidCredential(context.Background(), cred, logg); err != nil { + return fmt.Errorf("failed to verify/refresh Gen3 credential: %w", config.WrapCredentialValidationError(remoteName, err)) + } + + scope, err := gitrepo.ResolveBucketScope(organization, project, "", "") + if err != nil { + scope, err = resolveBucketScopeFromServer(context.Background(), apiEndpoint, strings.TrimSpace(cred.AccessToken), organization, project) + if err != nil { + return fmt.Errorf("failed resolving bucket mapping for organization=%q project=%q: %w", organization, project, err) + } + } + resolvedBucket := strings.TrimSpace(scope.Bucket) + resolvedStoragePrefix := strings.TrimSpace(scope.Prefix) + if resolvedBucket == "" { + return fmt.Errorf("no bucket mapping found for organization=%q project=%q", organization, project) + } + remoteGen3 := config.RemoteSelect{ Gen3: &config.Gen3Remote{ Endpoint: apiEndpoint, @@ -129,21 +145,6 @@ func gen3Init(remoteName, credFile, fenceToken, project, organization, bucket st } logg.Debug(fmt.Sprintf("Remote added/updated: %s → %s (project: %s, bucket: %s, storage_prefix: %s)", remoteName, apiEndpoint, project, resolvedBucket, resolvedStoragePrefix)) - // Step 3: Ensure credential profile is up-to-date (refreshes token if needed) - cred := &conf.Credential{ - Profile: remoteName, - APIEndpoint: apiEndpoint, - APIKey: apiKey, - KeyID: keyID, - AccessToken: accessToken, // may be stale - UseShepherd: "false", // or preserve from existing? - MinShepherdVersion: "", - } - - if err := credentials.EnsureValidCredential(context.Background(), cred, logg); err != nil { - return fmt.Errorf("failed to verify/refresh Gen3 credential: %w", err) - } - if err := configure.Save(cred); err != nil { return fmt.Errorf("failed to configure/update Gen3 profile: %w", err) } @@ -163,3 +164,95 @@ func gen3Init(remoteName, credFile, fenceToken, project, organization, bucket st logg.Debug(fmt.Sprintf("Gen3 profile '%s' configured and token refreshed successfully", remoteName)) return nil } + +func parseScopeArg(raw string) (string, string, error) { + raw = strings.TrimSpace(raw) + if raw == "" { + return "", "", fmt.Errorf("organization/project scope is required") + } + + parts := strings.Split(raw, "/") + if len(parts) != 2 { + return "", "", fmt.Errorf("invalid scope %q: expected organization/project", raw) + } + organization := strings.TrimSpace(parts[0]) + project := strings.TrimSpace(parts[1]) + if organization == "" || project == "" { + return "", "", fmt.Errorf("invalid scope %q: expected organization/project", raw) + } + return organization, project, nil +} + +func resolveBucketScopeFromServer(ctx context.Context, endpoint, token, organization, project string) (gitrepo.ResolvedBucketScope, error) { + if strings.TrimSpace(endpoint) == "" { + return gitrepo.ResolvedBucketScope{}, fmt.Errorf("missing API endpoint for server bucket lookup") + } + if strings.TrimSpace(token) == "" { + return gitrepo.ResolvedBucketScope{}, fmt.Errorf("missing access token for server bucket lookup") + } + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, strings.TrimRight(endpoint, "/")+"/data/buckets", nil) + if err != nil { + return gitrepo.ResolvedBucketScope{}, fmt.Errorf("build bucket list request: %w", err) + } + req.Header.Set("Authorization", "Bearer "+strings.TrimSpace(token)) + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return gitrepo.ResolvedBucketScope{}, fmt.Errorf("request bucket list: %w", err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return gitrepo.ResolvedBucketScope{}, fmt.Errorf("bucket list failed with status %d", resp.StatusCode) + } + + var payload bucketapi.BucketsResponse + if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil { + return gitrepo.ResolvedBucketScope{}, fmt.Errorf("decode bucket list response: %w", err) + } + + projectResource, err := syfoncommon.ResourcePath(organization, project) + if err != nil { + return gitrepo.ResolvedBucketScope{}, err + } + orgResource, err := syfoncommon.ResourcePath(organization, "") + if err != nil { + return gitrepo.ResolvedBucketScope{}, err + } + + if bucket, ok := findBucketByResource(payload, projectResource); ok { + return gitrepo.ResolvedBucketScope{Bucket: bucket}, nil + } + if bucket, ok := findBucketByResource(payload, orgResource); ok { + return gitrepo.ResolvedBucketScope{Bucket: bucket}, nil + } + + return gitrepo.ResolvedBucketScope{}, fmt.Errorf("no visible server bucket matched organization=%q project=%q", organization, project) +} + +func findBucketByResource(payload bucketapi.BucketsResponse, resource string) (string, bool) { + resource = syfoncommon.NormalizeAccessResource(resource) + if resource == "" { + return "", false + } + var match string + for bucket, meta := range payload.S3BUCKETS { + if meta.Programs == nil { + continue + } + for _, candidate := range *meta.Programs { + if syfoncommon.NormalizeAccessResource(candidate) != resource { + continue + } + if match != "" && match != bucket { + return "", false + } + match = bucket + break + } + } + if match == "" { + return "", false + } + return match, true +} diff --git a/cmd/remote/add/init.go b/cmd/remote/add/init.go index 55f848f2..e156d376 100644 --- a/cmd/remote/add/init.go +++ b/cmd/remote/add/init.go @@ -3,14 +3,10 @@ package add import "github.com/spf13/cobra" var ( - apiEndpoint string - bucket string credFile string fenceToken string localPassword string localUsername string - project string - organization string ) // Cmd line declaration @@ -20,17 +16,10 @@ var Cmd = &cobra.Command{ } func init() { - Gen3Cmd.Flags().StringVar(&apiEndpoint, "url", "", "[gen3] Specify the API endpoint of the data commons") - Gen3Cmd.Flags().StringVar(&bucket, "bucket", "", "[gen3] Specify the bucket name") - Gen3Cmd.Flags().StringVar(&credFile, "cred", "", "[gen3] Specify the gen3 credential file that you want to use") - Gen3Cmd.Flags().StringVar(&fenceToken, "token", "", "[gen3] Specify the token to be used as a replacement for a credential file for temporary access") - Gen3Cmd.Flags().StringVar(&project, "project", "", "[gen3] Specify the gen3 project ID in the format -") - Gen3Cmd.Flags().StringVar(&organization, "organization", "", "[gen3] Optional organization/program scope (use with --project as project id)") + Gen3Cmd.Flags().StringVar(&credFile, "cred", "", "[gen3] Import a Gen3 credential file into this profile") + Gen3Cmd.Flags().StringVar(&fenceToken, "token", "", "[gen3] Use a temporary bearer token issued from fence") Cmd.AddCommand(Gen3Cmd) - LocalCmd.Flags().StringVarP(&project, "project", "p", "", "Project ID") - LocalCmd.Flags().StringVar(&bucket, "bucket", "", "Bucket Name") - LocalCmd.Flags().StringVar(&organization, "organization", "", "Organization Name") LocalCmd.Flags().StringVar(&localUsername, "username", "", "Username for local DRS HTTP basic auth") LocalCmd.Flags().StringVar(&localPassword, "password", "", "Password for local DRS HTTP basic auth") Cmd.AddCommand(LocalCmd) diff --git a/cmd/remote/add/local.go b/cmd/remote/add/local.go index c0d61b29..8a572d9d 100644 --- a/cmd/remote/add/local.go +++ b/cmd/remote/add/local.go @@ -1,33 +1,61 @@ package add import ( + "context" + "encoding/json" "fmt" + "net/http" "strings" + "github.com/calypr/git-drs/cmd/initialize" "github.com/calypr/git-drs/internal/config" + "github.com/calypr/git-drs/internal/drslog" "github.com/calypr/git-drs/internal/gitrepo" + bucketapi "github.com/calypr/syfon/apigen/client/bucketapi" + syfoncommon "github.com/calypr/syfon/common" "github.com/spf13/cobra" ) var LocalCmd = &cobra.Command{ - Use: "local ", + Use: "local ", Short: "Add a local DRS server", - Long: "Add a local DRS server by specifying its base URL, e.g., http://localhost:8000. Optional --username/--password configures basic auth for git-lfs and helper flows.", - Args: cobra.ExactArgs(2), + Long: "Add a local DRS server by specifying its base URL and scope. Optional --username/--password configures basic auth for helper flows.", + Args: cobra.ExactArgs(3), RunE: func(cmd *cobra.Command, args []string) error { remoteName := args[0] url := args[1] + scopeArg := args[2] + if err := initialize.EnsureInitialized(drslog.GetLogger()); err != nil { + return fmt.Errorf("failed to initialize repository: %w", err) + } if url == "" { return fmt.Errorf("URL cannot be empty") } + organization, project, err := parseScopeArg(scopeArg) + if err != nil { + return err + } + scope, err := gitrepo.ResolveBucketScope(organization, project, "", "") + if err != nil { + scope, err = resolveBucketScopeFromLocalServer(context.Background(), url, strings.TrimSpace(localUsername), strings.TrimSpace(localPassword), organization, project) + if err != nil { + return fmt.Errorf("failed resolving bucket mapping for organization=%q project=%q: %w", organization, project, err) + } + } + resolvedBucket := strings.TrimSpace(scope.Bucket) + resolvedStoragePrefix := strings.TrimSpace(scope.Prefix) + if resolvedBucket == "" { + return fmt.Errorf("no bucket mapping found for organization=%q project=%q", organization, project) + } remoteSelect := config.RemoteSelect{ Local: &config.LocalRemote{ - BaseURL: url, - ProjectID: project, - Bucket: bucket, - Organization: organization, + BaseURL: url, + ProjectID: project, + Bucket: resolvedBucket, + Organization: organization, + StoragePrefix: resolvedStoragePrefix, }, } @@ -54,3 +82,49 @@ var LocalCmd = &cobra.Command{ return nil }, } + +func resolveBucketScopeFromLocalServer(ctx context.Context, endpoint, username, password, organization, project string) (gitrepo.ResolvedBucketScope, error) { + if strings.TrimSpace(endpoint) == "" { + return gitrepo.ResolvedBucketScope{}, fmt.Errorf("missing API endpoint for server bucket lookup") + } + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, strings.TrimRight(endpoint, "/")+"/data/buckets", nil) + if err != nil { + return gitrepo.ResolvedBucketScope{}, fmt.Errorf("build bucket list request: %w", err) + } + if username != "" || password != "" { + req.SetBasicAuth(username, password) + } + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return gitrepo.ResolvedBucketScope{}, fmt.Errorf("request bucket list: %w", err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return gitrepo.ResolvedBucketScope{}, fmt.Errorf("bucket list failed with status %d", resp.StatusCode) + } + + var payload bucketapi.BucketsResponse + if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil { + return gitrepo.ResolvedBucketScope{}, fmt.Errorf("decode bucket list response: %w", err) + } + + projectResource, err := syfoncommon.ResourcePath(organization, project) + if err != nil { + return gitrepo.ResolvedBucketScope{}, err + } + orgResource, err := syfoncommon.ResourcePath(organization, "") + if err != nil { + return gitrepo.ResolvedBucketScope{}, err + } + + if bucket, ok := findBucketByResource(payload, projectResource); ok { + return gitrepo.ResolvedBucketScope{Bucket: bucket}, nil + } + if bucket, ok := findBucketByResource(payload, orgResource); ok { + return gitrepo.ResolvedBucketScope{Bucket: bucket}, nil + } + + return gitrepo.ResolvedBucketScope{}, fmt.Errorf("no visible server bucket matched organization=%q project=%q", organization, project) +} diff --git a/cmd/remote/add/local_test.go b/cmd/remote/add/local_test.go index 80e908a6..1f5c42a8 100644 --- a/cmd/remote/add/local_test.go +++ b/cmd/remote/add/local_test.go @@ -1,14 +1,101 @@ package add import ( + "context" + "net/http" + "net/http/httptest" + "os" + "path/filepath" "testing" + "github.com/calypr/git-drs/internal/common" + "github.com/calypr/git-drs/internal/gitrepo" + "github.com/calypr/git-drs/internal/testutils" "github.com/stretchr/testify/assert" ) func TestAddLocalRemote(t *testing.T) { assert.NotNil(t, LocalCmd) - assert.Equal(t, "local ", LocalCmd.Use) + assert.Equal(t, "local ", LocalCmd.Use) assert.NotNil(t, LocalCmd.Flag("username")) assert.NotNil(t, LocalCmd.Flag("password")) + assert.Nil(t, LocalCmd.Flag("organization")) + assert.Nil(t, LocalCmd.Flag("project")) + assert.Nil(t, LocalCmd.Flag("bucket")) +} + +func TestResolveBucketScopeFromLocalServer(t *testing.T) { + t.Run("matches project resource", func(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/data/buckets" { + t.Fatalf("unexpected path: %s", r.URL.Path) + } + user, pass, ok := r.BasicAuth() + if !ok || user != "drs-user" || pass != "drs-pass" { + t.Fatalf("unexpected basic auth: ok=%v user=%q pass=%q", ok, user, pass) + } + _, _ = w.Write([]byte(`{"S3_BUCKETS":{"cbds":{"programs":["/organization/calypr/project/end_to_end_test"]}}}`)) + })) + defer srv.Close() + + scope, err := resolveBucketScopeFromLocalServer(context.Background(), srv.URL, "drs-user", "drs-pass", "calypr", "end_to_end_test") + if err != nil { + t.Fatalf("resolveBucketScopeFromLocalServer returned error: %v", err) + } + if scope.Bucket != "cbds" { + t.Fatalf("unexpected bucket: %+v", scope) + } + }) +} + +func TestLocalRemoteAddEnsuresInitialization(t *testing.T) { + testutils.SetupTestGitRepo(t) + localUsername = "" + localPassword = "" + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/data/buckets" { + t.Fatalf("unexpected path: %s", r.URL.Path) + } + _, _ = w.Write([]byte(`{"S3_BUCKETS":{"cbds":{"programs":["/organization/calypr/project/end_to_end_test"]}}}`)) + })) + defer srv.Close() + + if err := LocalCmd.RunE(LocalCmd, []string{"origin", srv.URL, "calypr/end_to_end_test"}); err != nil { + t.Fatalf("LocalCmd.RunE returned error: %v", err) + } + + if _, err := os.Stat(common.DRS_DIR); err != nil { + t.Fatalf("expected %s to exist: %v", common.DRS_DIR, err) + } + + filterProcess, err := gitrepo.GetGitConfigString("filter.drs.process") + if err != nil { + t.Fatalf("GetGitConfigString(filter.drs.process): %v", err) + } + if filterProcess != "git-drs filter" { + t.Fatalf("unexpected filter.drs.process: %q", filterProcess) + } + filterClean, err := gitrepo.GetGitConfigString("filter.drs.clean") + if err != nil { + t.Fatalf("GetGitConfigString(filter.drs.clean): %v", err) + } + if filterClean != "git-drs clean -- %f" { + t.Fatalf("unexpected filter.drs.clean: %q", filterClean) + } + filterSmudge, err := gitrepo.GetGitConfigString("filter.drs.smudge") + if err != nil { + t.Fatalf("GetGitConfigString(filter.drs.smudge): %v", err) + } + if filterSmudge != "git-drs smudge -- %f" { + t.Fatalf("unexpected filter.drs.smudge: %q", filterSmudge) + } + + preCommit, err := os.ReadFile(filepath.Join(".git", "hooks", "pre-commit")) + if err != nil { + t.Fatalf("read pre-commit hook: %v", err) + } + if string(preCommit) == "" { + t.Fatalf("expected pre-commit hook to be installed") + } } diff --git a/cmd/remote/list.go b/cmd/remote/list.go index 8723bc97..86a231c5 100644 --- a/cmd/remote/list.go +++ b/cmd/remote/list.go @@ -3,11 +3,21 @@ package remote import ( "fmt" + "github.com/calypr/data-client/credentials" "github.com/calypr/git-drs/internal/config" "github.com/calypr/git-drs/internal/drslog" + syconf "github.com/calypr/syfon/client/config" "github.com/spf13/cobra" ) +var ( + loadConfig = config.LoadConfig + loadProfileCredential = func(profile string) (*syconf.Credential, error) { + return syconf.NewConfigure(drslog.GetLogger()).Load(profile) + } + ensureValidCredential = credentials.EnsureValidCredential +) + var ListCmd = &cobra.Command{ Use: "list", Short: "List DRS repos", @@ -20,7 +30,7 @@ var ListCmd = &cobra.Command{ }, RunE: func(cmd *cobra.Command, args []string) error { logg := drslog.GetLogger() - cfg, err := config.LoadConfig() + cfg, err := loadConfig() if err != nil { logg.Debug(fmt.Sprintf("Error loading config: %s", err)) return err @@ -53,6 +63,16 @@ var ListCmd = &cobra.Command{ } fmt.Printf("%s %-10s %-8s %s\n", marker, name, remoteType, endpoint) + if remoteSelect.Gen3 != nil { + cred, err := loadProfileCredential(string(name)) + if err != nil { + logg.Warn(fmt.Sprintf("remote %s credential check skipped: %v", name, err)) + continue + } + if err := ensureValidCredential(cmd.Context(), cred, logg); err != nil { + logg.Warn(config.WrapCredentialValidationError(string(name), err).Error()) + } + } } return nil }, diff --git a/cmd/remote/remote_test.go b/cmd/remote/remote_test.go index b03b3121..c1d33aab 100644 --- a/cmd/remote/remote_test.go +++ b/cmd/remote/remote_test.go @@ -1,9 +1,14 @@ package remote import ( + "context" + "log/slog" + "os/exec" "testing" + "github.com/calypr/git-drs/internal/config" "github.com/calypr/git-drs/internal/testutils" + syconf "github.com/calypr/syfon/client/config" "github.com/stretchr/testify/assert" ) @@ -21,6 +26,22 @@ func TestRemoteListRun(t *testing.T) { tmpDir := testutils.SetupTestGitRepo(t) testutils.CreateDefaultTestConfig(t, tmpDir) + oldLoadProfileCredential := loadProfileCredential + oldEnsureValidCredential := ensureValidCredential + t.Cleanup(func() { + loadProfileCredential = oldLoadProfileCredential + ensureValidCredential = oldEnsureValidCredential + }) + + loadProfileCredential = func(profile string) (*syconf.Credential, error) { + return &syconf.Credential{Profile: profile, AccessToken: "token", APIEndpoint: "https://example.test"}, nil + } + called := false + ensureValidCredential = func(ctx context.Context, cred *syconf.Credential, _ *slog.Logger) error { + called = true + return nil + } + // Capture stdout output := testutils.CaptureStdout(t, func() { err := ListCmd.RunE(ListCmd, []string{}) @@ -29,6 +50,7 @@ func TestRemoteListRun(t *testing.T) { assert.Contains(t, output, "origin") assert.Contains(t, output, "gen3") + assert.True(t, called, "expected remote list to validate the configured credential") } func TestRemoteSetArgs(t *testing.T) { @@ -44,3 +66,89 @@ func TestRemoteSetArgs(t *testing.T) { err = SetCmd.Args(SetCmd, []string{"origin", "extra"}) assert.Error(t, err) } + +func TestRemoteRemoveArgs(t *testing.T) { + err := RemoveCmd.Args(RemoveCmd, []string{"origin"}) + assert.NoError(t, err) + + err = RemoveCmd.Args(RemoveCmd, []string{}) + assert.Error(t, err) + + err = RemoveCmd.Args(RemoveCmd, []string{"origin", "extra"}) + assert.Error(t, err) +} + +func TestRemoteRemoveRunReassignsDefaultAndCleansKeys(t *testing.T) { + tmpDir := testutils.SetupTestGitRepo(t) + testutils.CreateTestConfig(t, tmpDir, &config.Config{ + DefaultRemote: "origin", + Remotes: map[config.Remote]config.RemoteSelect{ + "origin": { + Gen3: &config.Gen3Remote{ + Endpoint: "https://origin.example", + ProjectID: "origin-proj", + Bucket: "origin-bucket", + }, + }, + "backup": { + Gen3: &config.Gen3Remote{ + Endpoint: "https://backup.example", + ProjectID: "backup-proj", + Bucket: "backup-bucket", + }, + }, + }, + }) + + for _, args := range [][]string{ + {"config", "drs.remote.origin.token", "token"}, + {"config", "drs.remote.origin.username", "alice"}, + {"config", "drs.remote.origin.password", "secret"}, + {"config", "remote.origin.lfsurl", "https://origin.example/info/lfs"}, + } { + cmd := exec.Command("git", args...) + cmd.Dir = tmpDir + err := cmd.Run() + assert.NoError(t, err) + } + + err := RemoveCmd.RunE(RemoveCmd, []string{"origin"}) + assert.NoError(t, err) + + cfg, err := config.LoadConfig() + assert.NoError(t, err) + assert.NotContains(t, cfg.Remotes, config.Remote("origin")) + assert.Equal(t, config.Remote("backup"), cfg.DefaultRemote) + + for _, key := range []string{ + "drs.remote.origin.type", + "drs.remote.origin.endpoint", + "drs.remote.origin.project", + "drs.remote.origin.bucket", + "drs.remote.origin.token", + "drs.remote.origin.username", + "drs.remote.origin.password", + "remote.origin.lfsurl", + } { + val, err := exec.Command("git", "config", "--get", key).CombinedOutput() + assert.Empty(t, string(val)) + assert.Error(t, err) + } +} + +func TestRemoteRemoveRunClearsDefaultWhenLastRemoteRemoved(t *testing.T) { + tmpDir := testutils.SetupTestGitRepo(t) + testutils.CreateDefaultTestConfig(t, tmpDir) + + err := RemoveCmd.RunE(RemoveCmd, []string{"origin"}) + assert.NoError(t, err) + + cfg, err := config.LoadConfig() + assert.NoError(t, err) + assert.Empty(t, cfg.Remotes) + assert.Equal(t, config.Remote(""), cfg.DefaultRemote) + + val, err := exec.Command("git", "config", "--get", "drs.default-remote").CombinedOutput() + assert.Empty(t, string(val)) + assert.Error(t, err) +} diff --git a/cmd/remote/remove.go b/cmd/remote/remove.go new file mode 100644 index 00000000..a5f5dbdc --- /dev/null +++ b/cmd/remote/remove.go @@ -0,0 +1,59 @@ +package remote + +import ( + "fmt" + "sort" + + "github.com/calypr/git-drs/internal/config" + "github.com/calypr/git-drs/internal/drslog" + "github.com/spf13/cobra" +) + +var RemoveCmd = &cobra.Command{ + Use: "remove ", + Aliases: []string{"rm"}, + Short: "Remove a DRS remote", + Long: "Remove a configured DRS remote and repair the default remote if needed.", + Args: func(cmd *cobra.Command, args []string) error { + if len(args) != 1 { + cmd.SilenceUsage = false + return fmt.Errorf("error: requires exactly 1 argument (remote name), received %d\n\nUsage: %s\n\nRun 'git drs remote list' to see available remotes or 'git drs remote remove --help' for more details", len(args), cmd.UseLine()) + } + return nil + }, + RunE: func(cmd *cobra.Command, args []string) error { + remoteName := config.Remote(args[0]) + logger := drslog.GetLogger() + + cfg, err := config.LoadConfig() + if err != nil { + return fmt.Errorf("failed to load config: %w", err) + } + + if _, ok := cfg.Remotes[remoteName]; !ok { + availableRemotes := make([]string, 0, len(cfg.Remotes)) + for name := range cfg.Remotes { + availableRemotes = append(availableRemotes, string(name)) + } + sort.Strings(availableRemotes) + return fmt.Errorf( + "remote '%s' not found.\nAvailable remotes: %v", + remoteName, + availableRemotes, + ) + } + + updated, err := config.RemoveRemote(remoteName) + if err != nil { + return fmt.Errorf("failed to remove remote: %w", err) + } + + if updated.DefaultRemote == "" { + logger.Debug(fmt.Sprintf("Removed remote %s; no default remote remains", remoteName)) + return nil + } + + logger.Debug(fmt.Sprintf("Removed remote %s; default remote is now %s", remoteName, updated.DefaultRemote)) + return nil + }, +} diff --git a/cmd/remote/root.go b/cmd/remote/root.go index 7d865720..45a1963d 100644 --- a/cmd/remote/root.go +++ b/cmd/remote/root.go @@ -14,5 +14,6 @@ var Cmd = &cobra.Command{ func init() { Cmd.AddCommand(add.Cmd) Cmd.AddCommand(ListCmd) + Cmd.AddCommand(RemoveCmd) Cmd.AddCommand(SetCmd) } diff --git a/cmd/rm/main.go b/cmd/rm/main.go new file mode 100644 index 00000000..a58124d1 --- /dev/null +++ b/cmd/rm/main.go @@ -0,0 +1,58 @@ +package rm + +import ( + "context" + "fmt" + "os/exec" + "path/filepath" + "strings" + + "github.com/calypr/git-drs/internal/drslog" + "github.com/calypr/git-drs/internal/lfs" + "github.com/spf13/cobra" +) + +var runCommand = func(name string, args ...string) error { + cmd := exec.Command(name, args...) + return cmd.Run() +} + +var Cmd = &cobra.Command{ + Use: "rm ...", + Short: "Remove tracked git-drs files", + Args: cobra.MinimumNArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + return run(cmd.Context(), args) + }, +} + +func run(ctx context.Context, args []string) error { + tracked, err := lfs.GetTrackedLfsFiles(drslog.GetLogger()) + if err != nil { + return err + } + + type removal struct { + path string + oid string + } + planned := make([]removal, 0, len(args)) + for _, raw := range args { + path := filepath.ToSlash(filepath.Clean(raw)) + info, ok := tracked[path] + if !ok || strings.TrimSpace(info.Oid) == "" { + return fmt.Errorf("%s is not a tracked git-drs/LFS file", raw) + } + planned = append(planned, removal{path: path, oid: "sha256:" + strings.TrimPrefix(strings.TrimSpace(info.Oid), "sha256:")}) + } + + gitArgs := []string{"rm", "--"} + for _, item := range planned { + gitArgs = append(gitArgs, item.path) + } + if err := runCommand("git", gitArgs...); err != nil { + return err + } + + return nil +} diff --git a/cmd/rm/main_test.go b/cmd/rm/main_test.go new file mode 100644 index 00000000..16c51f28 --- /dev/null +++ b/cmd/rm/main_test.go @@ -0,0 +1,54 @@ +package rm + +import ( + "context" + "os" + "os/exec" + "path/filepath" + "testing" +) + +func TestRunRemovesTrackedFile(t *testing.T) { + repo := t.TempDir() + runGitCmd(t, repo, "init") + runGitCmd(t, repo, "config", "user.email", "test@example.com") + runGitCmd(t, repo, "config", "user.name", "Test User") + runGitCmd(t, repo, "config", "filter.drs.clean", "cat") + runGitCmd(t, repo, "config", "filter.drs.smudge", "cat") + runGitCmd(t, repo, "config", "filter.drs.process", "cat") + runGitCmd(t, repo, "config", "filter.drs.required", "false") + + if err := os.WriteFile(filepath.Join(repo, ".gitattributes"), []byte("*.dat filter=drs diff=drs merge=drs -text\n"), 0o644); err != nil { + t.Fatalf("write .gitattributes: %v", err) + } + path := filepath.Join(repo, "data.dat") + if err := os.WriteFile(path, []byte("version https://git-lfs.github.com/spec/v1\noid sha256:aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa\nsize 12\n"), 0o644); err != nil { + t.Fatalf("write pointer file: %v", err) + } + runGitCmd(t, repo, "add", ".") + runGitCmd(t, repo, "commit", "-m", "add pointer") + + oldWD, _ := os.Getwd() + if err := os.Chdir(repo); err != nil { + t.Fatalf("chdir repo: %v", err) + } + t.Cleanup(func() { _ = os.Chdir(oldWD) }) + + if err := run(context.Background(), []string{"data.dat"}); err != nil { + t.Fatalf("run returned error: %v", err) + } + + if _, err := os.Stat(path); !os.IsNotExist(err) { + t.Fatalf("expected file removed from worktree, stat err=%v", err) + } +} + +func runGitCmd(t *testing.T, dir string, args ...string) { + t.Helper() + cmd := exec.Command("git", args...) + cmd.Dir = dir + out, err := cmd.CombinedOutput() + if err != nil { + t.Fatalf("git %v failed: %v\n%s", args, err, string(out)) + } +} diff --git a/cmd/root.go b/cmd/root.go index ddfc95ac..4365f5c3 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -5,28 +5,24 @@ import ( "github.com/calypr/git-drs/cmd/addurl" "github.com/calypr/git-drs/cmd/bucket" "github.com/calypr/git-drs/cmd/clean" + "github.com/calypr/git-drs/cmd/copyrecords" deleteCmd "github.com/calypr/git-drs/cmd/delete" "github.com/calypr/git-drs/cmd/deleteproject" - - "github.com/calypr/git-drs/cmd/download" - "github.com/calypr/git-drs/cmd/fetch" "github.com/calypr/git-drs/cmd/filter" "github.com/calypr/git-drs/cmd/initialize" "github.com/calypr/git-drs/cmd/install" - - "github.com/calypr/git-drs/cmd/list" "github.com/calypr/git-drs/cmd/lsfiles" + "github.com/calypr/git-drs/cmd/ping" "github.com/calypr/git-drs/cmd/precommit" "github.com/calypr/git-drs/cmd/prepush" "github.com/calypr/git-drs/cmd/pull" "github.com/calypr/git-drs/cmd/push" "github.com/calypr/git-drs/cmd/query" "github.com/calypr/git-drs/cmd/remote" + "github.com/calypr/git-drs/cmd/rm" "github.com/calypr/git-drs/cmd/smudge" "github.com/calypr/git-drs/cmd/track" "github.com/calypr/git-drs/cmd/untrack" - - "github.com/calypr/git-drs/cmd/upload" "github.com/calypr/git-drs/cmd/version" "github.com/spf13/cobra" ) @@ -46,11 +42,13 @@ func init() { RootCmd.AddCommand(initialize.Cmd) RootCmd.AddCommand(version.Cmd) + RootCmd.AddCommand(ping.Cmd) RootCmd.AddCommand(filter.Cmd) RootCmd.AddCommand(clean.Cmd) + RootCmd.AddCommand(copyrecords.Cmd) RootCmd.AddCommand(smudge.Cmd) RootCmd.AddCommand(remote.Cmd) - RootCmd.AddCommand(fetch.Cmd) + RootCmd.AddCommand(rm.Cmd) RootCmd.AddCommand(pull.Cmd) RootCmd.AddCommand(push.Cmd) RootCmd.AddCommand(precommit.Cmd) @@ -63,10 +61,7 @@ func init() { RootCmd.AddCommand(bucket.Cmd) RootCmd.AddCommand(track.Cmd) RootCmd.AddCommand(untrack.Cmd) - RootCmd.AddCommand(list.Cmd) RootCmd.AddCommand(lsfiles.Cmd) - RootCmd.AddCommand(upload.Cmd) - RootCmd.AddCommand(download.Cmd) RootCmd.AddCommand(install.Cmd) RootCmd.CompletionOptions.HiddenDefaultCmd = true diff --git a/cmd/upload/main.go b/cmd/upload/main.go deleted file mode 100644 index 992b8f52..00000000 --- a/cmd/upload/main.go +++ /dev/null @@ -1,99 +0,0 @@ -package upload - -import ( - "fmt" - "os" - "path/filepath" - "strings" - - "github.com/calypr/git-drs/internal/common" - "github.com/calypr/git-drs/internal/config" - "github.com/calypr/git-drs/internal/drslog" - "github.com/calypr/git-drs/internal/drsobject" - "github.com/calypr/git-drs/internal/drsremote" - syupload "github.com/calypr/syfon/client/transfer/upload" - "github.com/spf13/cobra" -) - -var remote string - -// Cmd line declaration -var Cmd = &cobra.Command{ - Use: "upload ", - Short: "Upload a file to a DRS server", - Long: "Upload a file to a DRS server, without creating an LFS pointer", - Args: cobra.MinimumNArgs(1), - RunE: func(cmd *cobra.Command, args []string) error { - - logger := drslog.GetLogger() - - config, err := config.LoadConfig() - if err != nil { - return err - } - - remoteName, err := config.GetRemoteOrDefault(remote) - if err != nil { - logger.Error(fmt.Sprintf("Error getting remote: %v", err)) - return err - } - - client, err := config.GetRemoteClient(remoteName, logger) - if err != nil { - return err - } - - remoteConfig := config.GetRemote(remoteName) - organization := "" - project := "" - storagePrefix := "" - bucketName := "" - if remoteConfig != nil { - organization = remoteConfig.GetOrganization() - project = remoteConfig.GetProjectId() - storagePrefix = remoteConfig.GetStoragePrefix() - bucketName = remoteConfig.GetBucketName() - } - - for _, src := range args { - if s, err := os.Stat(src); err != nil { - logger.Error(fmt.Sprintf("Error stating file %s: %v", src, err)) - return err - } else if s.IsDir() { - logger.Error(fmt.Sprintf("Skipping directory %s", src)) - continue - } else { - sha256, err := common.CalculateFileSHA256(src) - if err != nil { - logger.Error(fmt.Sprintf("Error calculating SHA256 for file %s: %v", src, err)) - return err - } - - objs, err := drsremote.ObjectsByHashForScope(cmd.Context(), client, sha256) - if err != nil || len(objs) == 0 { - did := sha256 - name := filepath.Base(src) - drsObj, err := drsobject.BuildWithPrefix(name, sha256, s.Size(), did, bucketName, organization, project, storagePrefix) - if err != nil { - return fmt.Errorf("build DRS object for %s: %w", src, err) - } - registered, err := syupload.RegisterFile(cmd.Context(), client.Client.Data(), client.Client.DRS(), drsObj, src, bucketName) - if err != nil { - return fmt.Errorf("error uploading %s: %v", src, err) - } - if registered != nil { - logger.Info(fmt.Sprintf("Successfully uploaded %s to server with DRS ID %s", src, registered.Id)) - } - } else { - logger.Info(fmt.Sprintf("File %s already exists on server with DRS ID %s, skipping upload", src, strings.TrimSpace(objs[0].Id))) - } - } - } - - return nil - }, -} - -func init() { - Cmd.Flags().StringVarP(&remote, "remote", "r", "", "target remote DRS server (default: default_remote)") -} diff --git a/coverage/combined.html b/coverage/combined.html index f00491d3..bcd8b776 100644 --- a/coverage/combined.html +++ b/coverage/combined.html @@ -61,109 +61,135 @@ - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + + + + + + + + + + + + + + + + + + + + + + + + + + + @@ -526,7 +552,7 @@ "os" "path/filepath" - "github.com/calypr/git-drs/internal/cloud" + sycloud "github.com/calypr/syfon/client/cloud" "github.com/spf13/cobra" ) @@ -577,7 +603,7 @@ // printResolvedInfo writes a human-readable summary of resolved Git/LFS and // cloud object information to the command's stdout for user confirmation. -func printResolvedInfo(cmd *cobra.Command, gitCommonDir, lfsRoot string, objectInfo *cloud.ObjectInfo, pathArg string, isTracked bool, sha256 string) error { +func printResolvedInfo(cmd *cobra.Command, gitCommonDir, lfsRoot string, objectInfo *sycloud.ObjectInfo, pathArg string, isTracked bool, sha256 string) error { if _, err := fmt.Fprintf(cmd.OutOrStdout(), ` Resolved Git LFS Object Info ---------------------------- @@ -651,29 +677,34 @@ // NewCommand constructs the Cobra command for the `add-url` subcommand, // wiring usage, argument validation and the RunE handler. -func NewCommand() *cobra.Command { +func NewCommand() *cobra.Command { cmd := &cobra.Command{ - Use: "add-url <cloud-url> [path]", - Short: "Add a file to the Git DRS repo using a cloud object URL", + Use: "add-url <object-url-or-key> [path]", + Short: "Add a file from a provider URL or configured bucket object key", Args: func(cmd *cobra.Command, args []string) error { if len(args) < 1 || len(args) > 2 { - return errors.New("usage: add-url <cloud-url> [path]") + return errors.New("usage: add-url <object-url-or-key> [path]") } return nil }, RunE: runAddURL, } - addFlags(cmd) + addFlags(cmd) return cmd } // addFlags registers optional expected SHA256 checksum. -func addFlags(cmd *cobra.Command) { +func addFlags(cmd *cobra.Command) { cmd.Flags().String( "sha256", "", "Expected SHA256 checksum (optional)", ) + cmd.Flags().String( + "scheme", + "", + "Storage scheme for object-key mode (for example: s3 or gs)", + ) } // runAddURL is the Cobra RunE wrapper that delegates execution to the service. @@ -688,72 +719,132 @@ "fmt" "net/url" "os" + "path" "strings" - "github.com/calypr/git-drs/internal/cloud" + "github.com/calypr/git-drs/internal/gitrepo" + sycloud "github.com/calypr/syfon/client/cloud" "github.com/spf13/cobra" ) // addURLInput holds the parsed CLI state for the add-url command. type addURLInput struct { - objectURL string - path string - sha256 string - objectParams cloud.ObjectParameters + sourceArg string + objectURL string + path string + sha256 string + scheme string } -// parseAddURLInput parses CLI args and flags into an addURLInput and constructs -// cloud.ObjectParameters for metadata inspection. -func parseAddURLInput(cmd *cobra.Command, args []string) (addURLInput, error) { - objectURL := args[0] +// parseAddURLInput parses CLI args and flags into an addURLInput. +func parseAddURLInput(cmd *cobra.Command, args []string) (addURLInput, error) { + sourceArg := strings.TrimSpace(args[0]) - pathArg, err := resolvePathArg(objectURL, args) + pathArg, err := resolvePathArg(sourceArg, args) if err != nil { return addURLInput{}, err } - sha256Param, err := cmd.Flags().GetString("sha256") + sha256Param, err := cmd.Flags().GetString("sha256") if err != nil { return addURLInput{}, fmt.Errorf("read flag sha256: %w", err) } + scheme, err := cmd.Flags().GetString("scheme") + if err != nil { + return addURLInput{}, fmt.Errorf("read flag scheme: %w", err) + } - return addURLInput{ - objectURL: objectURL, + return addURLInput{ + sourceArg: sourceArg, path: pathArg, sha256: sha256Param, - objectParams: cloud.ObjectParameters{ - ObjectURL: objectURL, - S3Region: firstNonEmpty(os.Getenv("AWS_REGION"), os.Getenv("AWS_DEFAULT_REGION"), os.Getenv("TEST_BUCKET_REGION")), - S3Endpoint: firstNonEmpty(os.Getenv("AWS_ENDPOINT_URL_S3"), os.Getenv("AWS_ENDPOINT_URL"), os.Getenv("TEST_BUCKET_ENDPOINT")), - S3AccessKey: firstNonEmpty(os.Getenv("AWS_ACCESS_KEY_ID"), os.Getenv("TEST_BUCKET_ACCESS_KEY")), - S3SecretKey: firstNonEmpty(os.Getenv("AWS_SECRET_ACCESS_KEY"), os.Getenv("TEST_BUCKET_SECRET_KEY")), - SHA256: sha256Param, - DestinationPath: pathArg, - }, + scheme: strings.ToLower(strings.TrimSpace(scheme)), }, nil } // resolvePathArg returns the explicit destination path argument when provided, -// otherwise derives the worktree path from the given cloud URL path component. -func resolvePathArg(objectURL string, args []string) (string, error) { +// otherwise derives the worktree path from the given cloud URL or object key. +func resolvePathArg(sourceArg string, args []string) (string, error) { if len(args) == 2 { return args[1], nil } - u, err := url.Parse(objectURL) + if looksLikeCloudURL(sourceArg) { + u, err := url.Parse(sourceArg) + if err != nil { + return "", err + } + return strings.TrimPrefix(u.Path, "/"), nil + } + return strings.Trim(strings.TrimSpace(sourceArg), "/"), nil +} + +func buildObjectParameters(objectURL, pathArg, sha256 string) sycloud.ObjectParameters { + return sycloud.ObjectParameters{ + ObjectURL: objectURL, + S3Region: firstNonEmpty(os.Getenv("AWS_REGION"), os.Getenv("AWS_DEFAULT_REGION"), os.Getenv("TEST_BUCKET_REGION")), + S3Endpoint: firstNonEmpty(os.Getenv("AWS_ENDPOINT_URL_S3"), os.Getenv("AWS_ENDPOINT_URL"), os.Getenv("TEST_BUCKET_ENDPOINT")), + S3AccessKey: firstNonEmpty(os.Getenv("AWS_ACCESS_KEY_ID"), os.Getenv("TEST_BUCKET_ACCESS_KEY")), + S3SecretKey: firstNonEmpty(os.Getenv("AWS_SECRET_ACCESS_KEY"), os.Getenv("TEST_BUCKET_SECRET_KEY")), + SHA256: sha256, + DestinationPath: pathArg, + } +} + +func looksLikeCloudURL(raw string) bool { + u, err := url.Parse(strings.TrimSpace(raw)) if err != nil { - return "", err + return false + } + if strings.TrimSpace(u.Scheme) == "" { + return false + } + switch strings.ToLower(strings.TrimSpace(u.Scheme)) { + case "s3", "gs", "gcs", "azblob", "http", "https": + return strings.TrimSpace(u.Host) != "" + default: + return false + } +} + +func resolveObjectURL(input addURLInput, scope gitrepo.ResolvedBucketScope) (string, error) { + if looksLikeCloudURL(input.sourceArg) { + return input.sourceArg, nil } - return strings.TrimPrefix(u.Path, "/"), nil + if input.scheme == "" { + return "", fmt.Errorf("object key mode requires --scheme because local bucket mappings store bucket/prefix but not provider scheme") + } + key := joinObjectKey(scope.Prefix, input.sourceArg) + switch input.scheme { + case "s3": + return fmt.Sprintf("s3://%s/%s", scope.Bucket, key), nil + case "gs", "gcs": + return fmt.Sprintf("gs://%s/%s", scope.Bucket, key), nil + case "azblob", "az": + return "", fmt.Errorf("object key mode for Azure requires a full azblob:// URL because the local mapping does not store account_name") + default: + return "", fmt.Errorf("unsupported --scheme %q (expected s3 or gs, or pass a full object URL)", input.scheme) + } +} + +func joinObjectKey(prefix, key string) string { + parts := make([]string, 0, 2) + if p := strings.Trim(strings.TrimSpace(prefix), "/"); p != "" { + parts = append(parts, p) + } + if k := strings.Trim(strings.TrimSpace(key), "/"); k != "" { + parts = append(parts, k) + } + return path.Join(parts...) } -func firstNonEmpty(values ...string) string { - for _, v := range values { +func firstNonEmpty(values ...string) string { + for _, v := range values { v = strings.TrimSpace(v) - if v != "" { + if v != "" { return v } } - return "" + return "" } @@ -790,16 +881,20 @@ import ( "context" + "crypto/sha256" "fmt" "log/slog" - "os" + "strings" - "github.com/calypr/git-drs/internal/cloud" "github.com/calypr/git-drs/internal/common" "github.com/calypr/git-drs/internal/config" "github.com/calypr/git-drs/internal/drslog" - "github.com/calypr/git-drs/internal/drsmap" + "github.com/calypr/git-drs/internal/drsobject" + "github.com/calypr/git-drs/internal/drstrack" "github.com/calypr/git-drs/internal/lfs" + drsapi "github.com/calypr/syfon/apigen/client/drs" + sycloud "github.com/calypr/syfon/client/cloud" + "github.com/google/uuid" "github.com/spf13/cobra" ) @@ -807,7 +902,7 @@ // behavior (logger factory, object inspection, LFS helpers, config loader, etc.). type AddURLService struct { newLogger func(string, bool) (*slog.Logger, error) - inspectObject func(ctx context.Context, input cloud.ObjectParameters) (*cloud.ObjectInfo, error) + inspectObject func(ctx context.Context, input sycloud.ObjectParameters) (*sycloud.ObjectInfo, error) isLFSTracked func(path string) (bool, error) getGitRoots func(ctx context.Context) (string, string, error) gitLFSTrack func(ctx context.Context, path string) (bool, error) @@ -816,131 +911,186 @@ // NewAddURLService constructs an AddURLService populated with production // implementations of its dependencies. -func NewAddURLService() *AddURLService { +func NewAddURLService() *AddURLService { return &AddURLService{ newLogger: drslog.NewLogger, - inspectObject: cloud.InspectObjectForLFS, + inspectObject: sycloud.InspectObject, isLFSTracked: lfs.IsLFSTracked, getGitRoots: lfs.GetGitRootDirectories, - gitLFSTrack: lfs.GitLFSTrackReadOnly, + gitLFSTrack: drstrack.TrackReadOnly, loadConfig: config.LoadConfig, } } -// Run executes the add-url workflow: parse CLI input, inspect the cloud object, +// Run executes the add-url workflow: parse CLI input, resolve the target bucket +// scope, inspect the provider object through the client-owned cloud package, // ensure the LFS object exists in local storage, write a pointer file, update // the pre-commit cache (best-effort), optionally add a tracking entry, and // record the DRS mapping. -func (s *AddURLService) Run(cmd *cobra.Command, args []string) error { +func (s *AddURLService) Run(cmd *cobra.Command, args []string) error { ctx := cmd.Context() - if ctx == nil { + if ctx == nil { ctx = context.Background() } - logger, err := s.newLogger("", false) + logger, err := s.newLogger("", false) if err != nil { return fmt.Errorf("error creating logger: %v", err) } - input, err := parseAddURLInput(cmd, args) + input, err := parseAddURLInput(cmd, args) if err != nil { return err } - objectInfo, err := s.inspectObject(ctx, input.objectParams) + cfg, err := s.loadConfig() if err != nil { - return err + return fmt.Errorf("error getting config: %v", err) } - isTracked, err := s.isLFSTracked(input.path) + remote, err := cfg.GetDefaultRemote() if err != nil { - return fmt.Errorf("check LFS tracking for %s: %w", input.path, err) + return err } - gitCommonDir, lfsRoot, err := s.getGitRoots(ctx) - if err != nil { - return fmt.Errorf("get git root directories: %w", err) + remoteConfig := cfg.GetRemote(remote) + if remoteConfig == nil { + return fmt.Errorf("error getting remote configuration for %s", remote) } - if err := printResolvedInfo(cmd, gitCommonDir, lfsRoot, objectInfo, input.path, isTracked, input.sha256); err != nil { + org, project, scope, err := resolveTargetScope(remoteConfig) + if err != nil { return err } - oid, err := s.ensureLFSObject(ctx, objectInfo, input, lfsRoot) + input.objectURL, err = resolveObjectURL(input, scope) if err != nil { return err } - if err := writePointerFile(input.path, oid, objectInfo.SizeBytes); err != nil { + objectInfo, err := s.inspectObject(ctx, buildObjectParameters(input.objectURL, input.path, input.sha256)) + if err != nil { return err } - if err := updatePrecommitCache(ctx, logger, input.path, oid, input.objectURL); err != nil { - logger.Warn("pre-commit cache update skipped", "error", err) + isTracked, err := s.isLFSTracked(input.path) + if err != nil { + return fmt.Errorf("check LFS tracking for %s: %w", input.path, err) + } + + gitCommonDir, lfsRoot, err := s.getGitRoots(ctx) + if err != nil { + return fmt.Errorf("get git root directories: %w", err) } - if err := maybeTrackLFS(ctx, s.gitLFSTrack, input.path, isTracked); err != nil { + if err := printResolvedInfo(cmd, gitCommonDir, lfsRoot, objectInfo, input.path, isTracked, input.sha256); err != nil { return err } - cfg, err := s.loadConfig() + oid, err := s.ensureLFSObject(ctx, objectInfo, input, lfsRoot) if err != nil { - return fmt.Errorf("error getting config: %v", err) + return err } - remote, err := cfg.GetDefaultRemote() - if err != nil { + if err := writePointerFile(input.path, oid, objectInfo.SizeBytes); err != nil { return err } - remoteConfig := cfg.GetRemote(remote) - if remoteConfig == nil { - return fmt.Errorf("error getting remote configuration for %s", remote) + if err := updatePrecommitCache(ctx, logger, input.path, oid, input.objectURL); err != nil { + logger.Warn("pre-commit cache update skipped", "error", err) } - org, project, scope, err := resolveTargetScope(remoteConfig) - if err != nil { + if err := maybeTrackLFS(ctx, s.gitLFSTrack, input.path, isTracked); err != nil { return err } - builder := common.NewObjectBuilder(scope.Bucket, project) + builder := drsobject.NewBuilder(scope.Bucket, project) builder.Organization = org builder.StoragePrefix = scope.Prefix - file := lfs.LfsFileInfo{ + file := addURLDrsFile{ Name: input.path, Size: objectInfo.SizeBytes, Oid: oid, } - if _, err := drsmap.WriteDrsFile(builder, file, &input.objectURL); err != nil { - return fmt.Errorf("error WriteDrsFile: %v", err) + if _, err := writeAddURLDrsObject(builder, file, input.objectURL); err != nil { + return fmt.Errorf("write local DRS object: %w", err) } - return nil + return nil +} + +type addURLDrsFile struct { + Name string + Size int64 + Oid string +} + +func writeAddURLDrsObject(builder drsobject.Builder, file addURLDrsFile, objectPath string) (*drsapi.DrsObject, error) { + existing, err := drsobject.ReadObject(common.DRS_OBJS_PATH, file.Oid) + var drsObj *drsapi.DrsObject + if err == nil && existing != nil { + drsObj = existing + name := file.Name + drsObj.Name = &name + drsObj.Size = file.Size + } else { + drsID := uuid.NewSHA1(drsobject.UUIDNamespace, []byte(fmt.Sprintf("%s:%s", builder.Project, drsobject.NormalizeOid(file.Oid)))).String() + drsObj, err = builder.Build(file.Name, file.Oid, file.Size, drsID) + if err != nil { + return nil, fmt.Errorf("error building DRS object for oid %s: %w", file.Oid, err) + } + } + + if objectPath != "" { + if drsObj.AccessMethods != nil && len(*drsObj.AccessMethods) > 0 { + am := &(*drsObj.AccessMethods)[0] + am.AccessUrl = &struct { + Headers *[]string `json:"headers,omitempty"` + Url string `json:"url"` + }{Url: objectPath} + } else { + drsObj.AccessMethods = &[]drsapi.AccessMethod{{ + Type: drsapi.AccessMethodTypeS3, + AccessUrl: &struct { + Headers *[]string `json:"headers,omitempty"` + Url string `json:"url"` + }{Url: objectPath}, + }} + } + } + + if err := drsobject.WriteObject(common.DRS_OBJS_PATH, drsObj, file.Oid); err != nil { + return nil, fmt.Errorf("error writing DRS object for oid %s: %w", file.Oid, err) + } + return drsObj, nil } -// ensureLFSObject ensures the LFS object identified by objectInfo exists in the -// repository's LFS storage. If SHA256 is provided, it is trusted and returned. -// Otherwise we create a sentinel object and synthetic OID derived from ETag, -// deferring true checksum validation to first real data use. -func (s *AddURLService) ensureLFSObject(ctx context.Context, objectInfo *cloud.ObjectInfo, input addURLInput, lfsRoot string) (string, error) { +// ensureLFSObject returns the LFS pointer OID to use for the add-url target. +// If SHA256 is provided, it is trusted and returned. Otherwise we derive a +// deterministic placeholder OID from provider identity without writing any +// local LFS object payload. +func (s *AddURLService) ensureLFSObject(ctx context.Context, objectInfo *sycloud.ObjectInfo, input addURLInput, lfsRoot string) (string, error) { _ = ctx + _ = lfsRoot if input.sha256 != "" { return input.sha256, nil } - oid, err := lfs.SyntheticOIDFromETag(objectInfo.ETag) - if err != nil { - return "", err - } - objPath, err := lfs.WriteAddURLSentinelObject(lfsRoot, oid, objectInfo.ETag, input.objectURL) - if err != nil { - return "", err + return placeholderOIDForUnknownSHA(objectInfo.ETag, input.objectURL) +} + +func placeholderOIDForUnknownSHA(etag string, sourceURL string) (string, error) { + e := strings.TrimSpace(strings.Trim(etag, `"`)) + src := strings.TrimSpace(sourceURL) + if e == "" { + return "", fmt.Errorf("etag is required for placeholder oid") } - if _, err := fmt.Fprintf(os.Stderr, "Added add-url sentinel object at %s\n", objPath); err != nil { - return "", fmt.Errorf("stderr write: %w", err) + if src == "" { + return "", fmt.Errorf("source URL is required for placeholder oid") } - return oid, nil + sum := sha256.Sum256([]byte("git-drs-add-url-placeholder:v2\netag=" + e + "\nsource=" + src + "\n")) + return fmt.Sprintf("%x", sum[:]), nil } @@ -957,11 +1107,11 @@ "strings" "time" - gitauth "github.com/calypr/git-drs/internal/auth" + "github.com/calypr/data-client/credentials" "github.com/calypr/git-drs/internal/common" "github.com/calypr/git-drs/internal/drslog" "github.com/calypr/git-drs/internal/gitrepo" - "github.com/calypr/syfon/client/conf" + conf "github.com/calypr/syfon/client/config" "github.com/spf13/cobra" ) @@ -1195,7 +1345,7 @@ if prof, err := configure.Load(remoteName); err == nil { token = strings.TrimSpace(prof.AccessToken) if token == "" { - if ensureErr := gitauth.EnsureValidCredential(context.Background(), prof, drslog.GetLogger()); ensureErr == nil { + if ensureErr := credentials.EnsureValidCredential(context.Background(), prof, drslog.GetLogger()); ensureErr == nil { _ = configure.Save(prof) token = strings.TrimSpace(prof.AccessToken) } @@ -1326,135 +1476,487 @@ } -