diff --git a/.github/workflows/deploy.yml b/.github/workflows/deploy.yml new file mode 100644 index 00000000..2e3ec0c9 --- /dev/null +++ b/.github/workflows/deploy.yml @@ -0,0 +1,81 @@ +name: Syftbox Deploy + +# This workflow deploys Syftbox to development and staging environments. +# For production releases, use the release.yml workflow instead. + +on: + workflow_dispatch: + inputs: + environment: + description: 'Environment to deploy to' + required: true + default: 'dev' + type: choice + options: + - dev + - stage + +jobs: + build-and-deploy: + # Build and deploy to target environment + runs-on: macos-latest + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Setup Go + uses: actions/setup-go@v4 + with: + go-version: '1.21' + + - name: Install just + uses: taiki-e/install-action@just + + - name: Install GoReleaser + run: | + brew install --cask goreleaser/tap/goreleaser + goreleaser --version + + - name: Setup toolchain + run: just setup-toolchain + + - name: Setup SSH + run: | + mkdir -p ~/.ssh + + # Use environment-specific SSH private key + case "${{ inputs.environment }}" in + "dev") + echo "${{ secrets.SSH_PRIVATE_KEY_DEV }}" > ~/.ssh/id_rsa + ssh-keyscan -H ${{ secrets.SSH_HOST_DEV }} >> ~/.ssh/known_hosts + ;; + "stage") + echo "${{ secrets.SSH_PRIVATE_KEY_STAGE }}" > ~/.ssh/id_rsa + ssh-keyscan -H ${{ secrets.SSH_HOST_STAGE }} >> ~/.ssh/known_hosts + ;; + *) + echo "Unknown environment: ${{ inputs.environment }}" + exit 1 + ;; + esac + + chmod 700 ~/.ssh + chmod 600 ~/.ssh/id_rsa + + - name: Deploy to ${{ inputs.environment }} + run: | + case "${{ inputs.environment }}" in + "dev") + REMOTE="${{ secrets.SSH_USER_DEV }}@${{ secrets.SSH_HOST_DEV }}" + ;; + "stage") + REMOTE="${{ secrets.SSH_USER_STAGE }}@${{ secrets.SSH_HOST_STAGE }}" + ;; + *) + echo "Unknown environment: ${{ inputs.environment }}" + exit 1 + ;; + esac + + just deploy $REMOTE diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml new file mode 100644 index 00000000..5038b436 --- /dev/null +++ b/.github/workflows/release.yml @@ -0,0 +1,131 @@ +name: Syftbox Release + +# This workflow creates a new release and deploys to production. +# For dev/stage deployments, use the deploy.yml workflow instead. + +on: + workflow_dispatch: + inputs: + version_type: + description: 'Version type for the release' + required: true + type: choice + options: + - patch + - minor + - major + +jobs: + version: + # Handle version bumping and tagging + runs-on: macos-latest + outputs: + version: ${{ steps.bump-version.outputs.version }} + + steps: + - name: Checkout code + uses: actions/checkout@v4 + with: + fetch-depth: 0 # Required for svu to work properly with git history + + - name: Setup Go + uses: actions/setup-go@v4 + with: + go-version: '1.21' + + - name: Install just + uses: taiki-e/install-action@just + + - name: Install svu + run: go install github.com/caarlos0/svu@latest + + - name: Install jq + run: brew install jq + + - name: Setup git config + env: + GH_TOKEN: ${{ github.token }} + run: | + git config user.email "${GITHUB_ACTOR_ID}+${GITHUB_ACTOR}@users.noreply.github.com" + git config user.name "$(gh api /users/${GITHUB_ACTOR} | jq .name -r)" + + - name: Show current version + run: | + echo "Current version information:" + just show-version + + - name: Bump version + id: bump-version + run: | + echo "Releasing version for production deployment..." + just release ${{ inputs.version_type }} + version=$(git describe --tags --abbrev=0) + echo "version=${version}" >> $GITHUB_OUTPUT + + - name: Push version changes + run: | + # Set a new remote URL using HTTPS with the github token + git remote set-url origin https://x-access-token:${{ github.token }}@github.com/${{ github.repository }}.git + + # Push the current branch to the remote repo + git push origin + + # Push the tag to the remote repo + git push origin --tags + + - name: Show new version + run: | + echo "New version information:" + just show-version + + build-and-deploy: + needs: version + # Build and deploy to production + runs-on: macos-latest + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Setup Go + uses: actions/setup-go@v4 + with: + go-version: '1.21' + + - name: Install just + uses: taiki-e/install-action@just + + - name: Install GoReleaser + run: | + brew install --cask goreleaser/tap/goreleaser + goreleaser --version + + - name: Setup toolchain + run: just setup-toolchain + + - name: Setup SSH + run: | + mkdir -p ~/.ssh + echo "${{ secrets.SSH_PRIVATE_KEY_PROD }}" > ~/.ssh/id_rsa + ssh-keyscan -H ${{ secrets.SSH_HOST_PROD }} >> ~/.ssh/known_hosts + chmod 600 ~/.ssh/id_rsa + chmod 700 ~/.ssh + + - name: Deploy to production + run: | + REMOTE="${{ secrets.SSH_USER_PROD }}@${{ secrets.SSH_HOST_PROD }}" + just deploy $REMOTE + + - name: Create release + uses: ncipollo/release-action@v1 + with: + tag: ${{ needs.version.outputs.version }} + name: ${{ needs.version.outputs.version }} + draft: true + allowUpdates: true + omitBodyDuringUpdate: true + makeLatest: true + generateReleaseNotes: true + artifacts: | + releases/*.tar.gz + releases/*.zip diff --git a/README.md b/README.md index 1262d1fc..c780a2dd 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,27 @@ # SyftBox -## Quickstart +SyftBox is an open-source protocol that enables developers and organizations to build, deploy, and federate privacy-preserving computations seamlessly across a network. Unlock the ability to run computations on distributed datasets without centralizing data—preserving security while gaining valuable insights. + +Read the [documentation](https://www.syftbox.net) for more details. + +> [!WARNING] +> This project is a rewrite of the [original Python version](https://github.com/OpenMined/syft). Consequently, the linked documentation may not fully reflect the current implementation. + +## Quick Start + +Using the GUI, from https://github.com/OpenMined/SyftUI/releases + +On macOS and Linux. +``` +curl -fsSL https://syftbox.net/install.sh | sh +``` + +On Windows using Powershell +``` +powershell -ExecutionPolicy ByPass -c "irm https://syftbox.net/install.ps1 | iex" +``` + +## Contributing ### Install Go Follow the official [Go installation guide](https://golang.org/doc/install) to set up Go on your system. @@ -26,29 +47,4 @@ Verify your setup by running the tests: just test ``` - -SyftBox is an open-source protocol that enables developers and organizations to build, deploy, and federate privacy-preserving computations seamlessly across a network. Unlock the ability to run computations on distributed datasets without centralizing data—preserving security while gaining valuable insights. - -Read the [documentation](https://syftbox-documentation.openmined.org/get-started) for more details. - -> [!WARNING] -> This project is a rewrite of the [original Python version](https://github.com/OpenMined/syft). Consequently, the linked documentation may not fully reflect the current implementation. - -## Installation - -Using the GUI, from https://github.com/OpenMined/SyftUI/releases - - -On macOS and Linux. -``` -curl -fsSL https://syftboxdev.openmined.org/install.sh | sh -``` - -On Windows using Powershell -``` -powershell -ExecutionPolicy ByPass -c "irm https://syftboxdev.openmined.org/install.ps1 | iex" -``` - -## Contributing - -See the [development guide](./DEVELOPMENT.md) to get started +See the [development guide](./DEVELOPMENT.md) for more details diff --git a/cmd/client/main.go b/cmd/client/main.go index 2d20b1a2..452e8169 100644 --- a/cmd/client/main.go +++ b/cmd/client/main.go @@ -22,9 +22,7 @@ import ( ) var ( - home, _ = os.UserHomeDir() - oldProdURL = "syftbox.openmined.org" - oldStageURL = "syftboxstage.openmined.org" + home, _ = os.UserHomeDir() ) var rootCmd = &cobra.Command{ @@ -166,11 +164,9 @@ func loadConfig(cmd *cobra.Command) (*config.Config, error) { return nil, fmt.Errorf("config read: %w", err) } - // perform migrations // this will error out because a re-auth with server will be required - if strings.Contains(cfg.ServerURL, oldProdURL) || - strings.Contains(cfg.ServerURL, oldStageURL) { - return nil, fmt.Errorf("legacy config detected. please run `syftbox login` to re-authenticate") + if strings.Contains(cfg.ServerURL, "openmined.org") { + return nil, fmt.Errorf("legacy server detected. run `syftbox login` to re-authenticate") } return cfg, nil diff --git a/cmd/client/main_test.go b/cmd/client/main_test.go index aeb4ea37..728e03ed 100644 --- a/cmd/client/main_test.go +++ b/cmd/client/main_test.go @@ -12,7 +12,7 @@ import ( func TestLoadConfigEnv(t *testing.T) { t.Setenv("SYFTBOX_EMAIL", "test@example.com") - t.Setenv("SYFTBOX_SERVER_URL", "https://test.openmined.org") + t.Setenv("SYFTBOX_SERVER_URL", "https://test.syftbox.net") t.Setenv("SYFTBOX_CLIENT_URL", "http://localhost:7938") t.Setenv("SYFTBOX_APPS_ENABLED", "true") t.Setenv("SYFTBOX_REFRESH_TOKEN", "test-refresh-token") @@ -34,7 +34,7 @@ func TestLoadConfigEnv(t *testing.T) { require.NoError(t, err) assert.Equal(t, "test@example.com", cfg.Email) - assert.Equal(t, "https://test.openmined.org", cfg.ServerURL) + assert.Equal(t, "https://test.syftbox.net", cfg.ServerURL) assert.Equal(t, "http://localhost:7938", cfg.ClientURL) assert.Equal(t, true, cfg.AppsEnabled) assert.Equal(t, "test-refresh-token", cfg.RefreshToken) @@ -55,7 +55,7 @@ func TestLoadConfigJSON(t *testing.T) { { "email": "test@example.com", "data_dir": "/tmp/syftbox-test-json", - "server_url": "https://test-json.openmined.org", + "server_url": "https://test-json.syftbox.net", "client_url": "http://localhost:8080", "refresh_token": "test-refresh-token-json", "access_token": "test-access-token-json" @@ -78,7 +78,7 @@ func TestLoadConfigJSON(t *testing.T) { require.Equal(t, dummyConfigFile, cfg.Path) assert.Equal(t, "test@example.com", cfg.Email) assert.Equal(t, "/tmp/syftbox-test-json", cfg.DataDir) - assert.Equal(t, "https://test-json.openmined.org", cfg.ServerURL) + assert.Equal(t, "https://test-json.syftbox.net", cfg.ServerURL) assert.Equal(t, "http://localhost:8080", cfg.ClientURL) assert.Equal(t, "test-refresh-token-json", cfg.RefreshToken) assert.Equal(t, "test-access-token-json", cfg.AccessToken) // can read, but not persist! diff --git a/cmd/server/main.go b/cmd/server/main.go index 9379ae3e..e9a8d687 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -32,7 +32,6 @@ const ( var ( dotenvLoaded bool - prodEnv bool ) var rootCmd = &cobra.Command{ @@ -84,19 +83,31 @@ func init() { } else { dotenvLoaded = true } - - prodEnv = os.Getenv("SYFTBOX_ENV") == "PROD" } func main() { // Setup logger - var handler slog.Handler - if prodEnv { - handler = slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{ + logger := slog.New(setupHandler()) + slog.SetDefault(logger) + + // Setup root context with signal handling + ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) + defer stop() + + // server go brr + if err := rootCmd.ExecuteContext(ctx); err != nil { + os.Exit(1) + } +} + +func setupHandler() slog.Handler { + switch os.Getenv("SYFTBOX_ENV") { + case "PROD", "STAGE": + return slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{ Level: slog.LevelDebug, }) - } else { - handler = tint.NewHandler(os.Stdout, &tint.Options{ + default: + return tint.NewHandler(os.Stdout, &tint.Options{ Level: slog.LevelDebug, AddSource: true, TimeFormat: time.DateTime, @@ -108,17 +119,6 @@ func main() { }, }) } - logger := slog.New(handler) - slog.SetDefault(logger) - - // Setup root context with signal handling - ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) - defer stop() - - // server go brr - if err := rootCmd.ExecuteContext(ctx); err != nil { - os.Exit(1) - } } // loadConfig initializes viper, reads config file/env vars, and maps values to config diff --git a/config/server.example.yaml b/config/server.example.yaml index ee3eab4a..bb244fa2 100644 --- a/config/server.example.yaml +++ b/config/server.example.yaml @@ -24,7 +24,7 @@ auth: # whether to enable auth enabled: true # issuer of the JWT tokens (required) - token_issuer: https://example.com + token_issuer: https://test.syftbox.net # secret for the refresh token (required) # recommended to use SYFTBOX_AUTH_REFRESH_TOKEN_SECRET env var refresh_token_secret: refresh_token_secret diff --git a/docker/Dockerfile.client b/docker/Dockerfile.client new file mode 100644 index 00000000..4238d43c --- /dev/null +++ b/docker/Dockerfile.client @@ -0,0 +1,45 @@ +# Build stage +FROM golang:1.24.3-alpine AS builder + +# Install build dependencies +RUN apk add --no-cache git + +# Set working directory +WORKDIR /app + +# Copy go mod and sum files +COPY go.mod go.sum ./ + +# Download dependencies +RUN go mod download + +# Copy the source code +COPY . . + +# Build the client binary +RUN CGO_ENABLED=0 GOOS=linux go build -a -installsuffix cgo -o syftbox ./cmd/client + +# Final stage +FROM alpine:latest + +# Install ca-certificates for HTTPS +RUN apk --no-cache add ca-certificates bash + +WORKDIR /root/ + +# Copy the binary from builder +COPY --from=builder /app/syftbox . + +# Create base directories +RUN mkdir -p /root/.syftbox + +# Copy entrypoint script +COPY docker/entrypoint-client.sh /entrypoint.sh +RUN chmod +x /entrypoint.sh + +# Expose the daemon API port +EXPOSE 7938 + +# Use entrypoint script +ENTRYPOINT ["/entrypoint.sh"] +CMD ["--help"] \ No newline at end of file diff --git a/docker/Dockerfile.server b/docker/Dockerfile.server new file mode 100644 index 00000000..c45401be --- /dev/null +++ b/docker/Dockerfile.server @@ -0,0 +1,39 @@ +# Build stage +FROM golang:1.24.3-alpine AS builder + +# Install build dependencies +RUN apk add --no-cache git + +# Set working directory +WORKDIR /app + +# Copy go mod and sum files +COPY go.mod go.sum ./ + +# Download dependencies +RUN go mod download + +# Copy the source code +COPY . . + +# Build the server binary +RUN CGO_ENABLED=0 GOOS=linux go build -a -installsuffix cgo -o server ./cmd/server + +# Final stage +FROM alpine:latest + +# Install ca-certificates for HTTPS and mc (MinIO client) +RUN apk --no-cache add ca-certificates curl && \ + curl -sSL https://dl.min.io/client/mc/release/linux-amd64/mc -o /usr/local/bin/mc && \ + chmod +x /usr/local/bin/mc + +WORKDIR /root/ + +# Copy the binary from builder +COPY --from=builder /app/server . + +# Expose the server port (adjust if needed) +EXPOSE 8080 + +# Run the server +CMD ["./server"] \ No newline at end of file diff --git a/docker/docker-compose-client.yml b/docker/docker-compose-client.yml new file mode 100644 index 00000000..c9ee406e --- /dev/null +++ b/docker/docker-compose-client.yml @@ -0,0 +1,26 @@ +version: '3.8' + +services: + client: + build: + context: .. + dockerfile: docker/Dockerfile.client + image: syftbox-client:latest + container_name: syftbox-client-${CLIENT_EMAIL:-default} + ports: + - "${CLIENT_PORT:-7938}:7938" + environment: + - SYFTBOX_SERVER_URL=${SYFTBOX_SERVER_URL:-http://syftbox-server:8080} + - SYFTBOX_AUTH_ENABLED=0 + volumes: + - ${SYFTBOX_CLIENTS_DIR:-~/.syftbox/clients}:/data/clients + networks: + - syftbox-network + stdin_open: true + tty: true + command: ${CLIENT_EMAIL:---help} + +networks: + syftbox-network: + external: true + name: docker_syftbox-network \ No newline at end of file diff --git a/docker/docker-compose.yml b/docker/docker-compose.yml new file mode 100644 index 00000000..ede3c7fa --- /dev/null +++ b/docker/docker-compose.yml @@ -0,0 +1,69 @@ +services: + minio: + image: minio/minio:RELEASE.2025-04-22T22-12-26Z + container_name: syftbox-minio + ports: + - "9000:9000" + - "9001:9001" + environment: + - MINIO_ROOT_USER=minioadmin + - MINIO_ROOT_PASSWORD=minioadmin + volumes: + - minio-data:/data + - ../minio/init.d:/etc/minio/init.d + command: server /data --console-address ':9001' + networks: + - syftbox-network + healthcheck: + test: ["CMD", "curl", "-f", "http://localhost:9000/minio/health/live"] + interval: 5s + timeout: 5s + retries: 5 + + server: + build: + context: .. + dockerfile: docker/Dockerfile.server + container_name: syftbox-server + ports: + - "8080:8080" + environment: + - SYFTBOX_ENV=DEV + - SYFTBOX_AUTH_ENABLED=0 + - SYFTBOX_EMAIL_ENABLED=0 + - SYFTBOX_BLOB_REGION=us-east-1 + - SYFTBOX_BLOB_BUCKET_NAME=syftbox-local + - SYFTBOX_BLOB_ENDPOINT=http://minio:9000 + - SYFTBOX_BLOB_ACCESS_KEY=ptSLdKiwOi2LYQFZYEZ6 + - SYFTBOX_BLOB_SECRET_KEY=GMDvYrAhWDkB2DyFMn8gU8I8Bg0fT3JGT6iEB7P8 + - SYFTBOX_HTTP_ADDR=0.0.0.0:8080 + networks: + - syftbox-network + depends_on: + minio: + condition: service_healthy + restart: unless-stopped + command: > + sh -c " + # Wait for MinIO to be ready and run setup + until mc alias set local http://minio:9000 minioadmin minioadmin >/dev/null 2>&1; do + echo 'Waiting for MinIO...' + sleep 1 + done + echo 'Running MinIO setup...' + # Update the setup script to use the correct endpoint + sed 's|http://localhost:9000|http://minio:9000|g' /etc/minio/init.d/setup.sh > /tmp/setup.sh + chmod +x /tmp/setup.sh + /tmp/setup.sh || true + echo 'Starting server...' + ./server + " + volumes: + - ../minio/init.d:/etc/minio/init.d:ro + +networks: + syftbox-network: + driver: bridge + +volumes: + minio-data: \ No newline at end of file diff --git a/docker/entrypoint-client.sh b/docker/entrypoint-client.sh new file mode 100644 index 00000000..0dba331a --- /dev/null +++ b/docker/entrypoint-client.sh @@ -0,0 +1,77 @@ +#!/bin/bash +set -e + +# Function to setup client configuration +setup_client() { + local email="$1" + local server_url="${SYFTBOX_SERVER_URL:-http://syftbox-server:8080}" + local client_dir="/data/clients/${email}" + local config_file="${client_dir}/config.json" + local data_dir="${client_dir}/SyftBox" + + # Create directories if they don't exist + mkdir -p "${client_dir}" + mkdir -p "${data_dir}" + + # Set environment variables for this session + export SYFTBOX_CONFIG_PATH="${config_file}" + export SYFTBOX_DATA_DIR="${data_dir}" + export SYFTBOX_EMAIL="${email}" + export SYFTBOX_SERVER_URL="${server_url}" + + # For local dev, create a simple config that bypasses auth + if [[ "${server_url}" == *"syftbox-server"* ]] || [[ "${server_url}" == *"localhost"* ]] || [[ "${server_url}" == *"127.0.0.1"* ]]; then + echo "Setting up local dev config (auth bypass)" + cat > "${config_file}" << EOF +{ + "data_dir": "${data_dir}", + "email": "${email}", + "server_url": "${server_url}", + "client_url": "http://localhost:7938" +} +EOF + fi + + # Create symlinks for easier access in container + ln -sf "${config_file}" /root/.syftbox/config.json + ln -sf "${data_dir}" /root/SyftBox + + echo "Client setup for ${email}:" + echo " Config: ${config_file}" + echo " Data: ${data_dir}" + echo " Server: ${server_url}" +} + +# If email is provided as first argument, setup the client +if [[ "$1" =~ ^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$ ]]; then + email="$1" + shift + setup_client "${email}" + + # If no additional command provided, run daemon + if [ $# -eq 0 ]; then + echo "Starting daemon for ${email}..." + exec ./syftbox daemon + else + # Run the provided command + exec ./syftbox "$@" + fi +elif [ "$1" = "login" ] && [ -n "$2" ]; then + # Handle login command specially + email="$2" + shift 2 + setup_client "${email}" + + # For local dev servers, skip actual login and just setup config + if [[ "${SYFTBOX_SERVER_URL:-http://syftbox-server:8080}" == *"syftbox-server"* ]] || [[ "${SYFTBOX_SERVER_URL:-http://syftbox-server:8080}" == *"localhost"* ]] || [[ "${SYFTBOX_SERVER_URL:-http://syftbox-server:8080}" == *"127.0.0.1"* ]]; then + echo "Local dev server detected - skipping OTP login" + echo "Config created at: ${SYFTBOX_CONFIG_PATH}" + echo "You can now run: just run-docker-client-daemon ${email}" + exit 0 + else + exec ./syftbox login "$@" + fi +else + # Pass through to syftbox + exec ./syftbox "$@" +fi \ No newline at end of file diff --git a/docs/index.md b/docs/index.md new file mode 100644 index 00000000..56b3672a --- /dev/null +++ b/docs/index.md @@ -0,0 +1,23 @@ +# SyftBox Documentation + +- [SyftBox Permissions System](permissions.md) + - [Overview](permissions.md#overview) + - [Design Philosophy: Unix-Inspired Permissions](permissions.md#design-philosophy-unix-inspired-permissions) + - [User Identity](permissions.md#user-identity) + - [Architecture Overview](permissions.md#architecture-overview) + - [Core Components](permissions.md#core-components) + - [ACL Service](permissions.md#1-acl-service-internalserveraclaclgo) + - [Tree Structure](permissions.md#2-tree-structure-internalserveracltreego) + - [Testing and Validation](permissions.md#testing-and-validation) + - [Test Coverage](permissions.md#test-coverage) + - [Example Test Case](permissions.md#example-test-case) + - [Best Practices](permissions.md#best-practices) + - [Permission Design](permissions.md#1-permission-design) + - [Rule Organization](permissions.md#2-rule-organization) + - [Security Considerations](permissions.md#3-security-considerations) + - [Performance Optimization](permissions.md#4-performance-optimization) + - [Integration Points](permissions.md#integration-points) + - [Server Integration](permissions.md#server-integration-internalserverservergo) + - [Client Synchronization](permissions.md#client-synchronization-internalclientsync) + - [Migration and Compatibility](permissions.md#migration-and-compatibility) + - [Conclusion](permissions.md#conclusion) diff --git a/docs/permissions.md b/docs/permissions.md new file mode 100644 index 00000000..c7ddaf4e --- /dev/null +++ b/docs/permissions.md @@ -0,0 +1,592 @@ +# SyftBox Permissions System + +## Overview + +SyftBox implements a simple but powerful file-based Access Control List (ACL) system that mimics Unix permissions while providing more granular control over distributed file systems. The system is designed to secure data sharing across federated networks while maintaining ease of use and flexibility. + +## SyftBox Permissions in 1 Minute ⚡ + +SyftBox uses **email-based permissions** to control who can access your files. Each user has a datasite with a default private root permission: + +### Default Folder Structure +``` +/datasites/ +└── your@email.org/ + ├── syft.pub.yaml # Root permissions (private by default) + ├── public/ + │ └── syft.pub.yaml # Public files + └── shared/ + └── syft.pub.yaml # Shared with specific collaborators +``` + +### Default Root Permissions +Your root `/datasites/your@email.org/syft.pub.yaml` starts private: +```yaml +rules: + - pattern: "**" + access: + read: [] # No one can read + write: [] # No one can write + admin: [] # No one can modify permissions +``` + +### Example Public and Shared Folders +Create `/datasites/your@email.org/public/syft.pub.yaml`: +```yaml +terminal: true # optional `terminal: true`, child directories don't inherit parent permissions +rules: + - pattern: "**" + access: + read: ["*"] # Everyone can read +``` + +Create `/datasites/your@email.org/shared/syft.pub.yaml`: +```yaml +terminal: true # optional When `terminal: true`, child directories don't inherit parent permissions +rules: + - pattern: "**" + access: + read: ["alice@university.edu", "bob@research.org"] + write: ["alice@university.edu"] +``` + +**Key Points:** +- **Email addresses** identify users (`alice@university.edu`) +- **Patterns** match files (`*.txt`, `folder/*`, `**` for everything) +- **Permissions**: `read` (view), `write` (edit), `admin` (change permissions) +- **`"*"`** means public access +- **Root is private** by default - create subfolders for sharing +- **More specific rules** override general ones + + + +## Design Philosophy: Unix-Inspired Permissions + +The SyftBox permission system draws inspiration from Unix file permissions but extends beyond the traditional user/group/other model to support: + +- **Distributed users**: Multiple users identified by email addresses across different nodes +- **Hierarchical rules**: Directory-based permission inheritance +- **Pattern matching**: Glob-based file pattern permissions +- **Granular access**: Separate read, write, and admin permissions +- **Terminal inheritance**: Controlled permission propagation + +## User Identity + +In SyftBox, users are identified by their **email addresses**. This federated approach allows researchers, organizations, and collaborators from different institutions to securely share data while maintaining clear identity management. Examples of user identifiers include: + +- `*` - Special identifier meaning "everyone" (public access) +- `alice@research.org` - A researcher at a research institution +- `bob@university.edu` - A professor at a university +- `team@company.com` - A team or service account at a company +- `*@company.com` - Email's can use glob patterns to allow org level access + +## Architecture Overview + +The permission system consists of several key components working together: + +``` +┌─────────────────┐ ┌──────────────────┐ ┌─────────────────┐ +│ ACL Service │────│ Rule Cache │────│ Tree Structure │ +│ │ │ │ │ │ +│ - Rule lookup │ │ - O(1) lookups │ │ - N-ary tree │ +│ - Access check │ │ - Cache hits │ │ - Path traversal│ +│ - File limits │ │ - Invalidation │ │ - Rule storage │ +└─────────────────┘ └──────────────────┘ └─────────────────┘ + │ │ │ + └───────────────────────┼───────────────────────┘ + │ + ┌─────────────────────┐ + │ syft.pub.yaml │ + │ Configuration │ + │ │ + │ - Rules patterns │ + │ - Access grants │ + │ - Terminal flags │ + └─────────────────────┘ +``` + +## Core Components + +### 1. ACL Service (`internal/server/acl/acl.go`) + +The `AclService` is the main orchestrator that manages access control: + +```go +type AclService struct { + tree *Tree // Hierarchical rule storage + cache *RuleCache // Performance optimization +} +``` + +**Key methods:** +- `CanAccess()` - Primary access control check +- `GetRule()` - Rule lookup with caching +- `AddRuleSet()` - Dynamic rule management + +### 2. Tree Structure (`internal/server/acl/tree.go`) + +The permission tree stores rules in an n-ary tree for efficient path-based lookups: + +```go +type Tree struct { + root *Node +} +``` + +**Features:** +- **O(depth)** rule lookups +- **Terminal nodes** to prevent inheritance +- **Dynamic rule addition/removal** +- **Path-based traversal** + +### 3. Rule System (`internal/server/acl/rule.go`) + +Rules define the actual access control logic: + +```go +type Rule struct { + fullPattern string // Complete path + glob pattern + rule *aclspec.Rule // Access specifications + node *Node // Associated tree node +} +``` + +### 4. Access Levels (`internal/server/acl/level.go`) + +The system defines five distinct access levels using bit flags: + +```go +type AccessLevel uint8 + +const ( + AccessRead AccessLevel = 1 << iota // 0001 + AccessCreate // 0010 + AccessWrite // 0100 + AccessReadACL // 1000 + AccessWriteACL // 10000 +) +``` + +## Permission Files: `syft.pub.yaml` + +### File Format + +Permission files follow the YAML format and are always named `syft.pub.yaml`: + +```yaml +terminal: true +rules: + - pattern: "*.md" + access: + read: ["*"] + - pattern: "private/*.txt" + access: + read: ["alice@research.org", "bob@university.edu"] + write: ["alice@research.org"] + - pattern: "**" + access: {} +``` + +### Configuration Elements + +#### Terminal Flag +- **Purpose**: Controls permission inheritance +- **Values**: `true` (stop inheritance) or `false` (allow inheritance) +- **Impact**: When `terminal: true`, child directories don't inherit parent permissions + +#### Rules Array +Each rule contains: +- **`pattern`**: Glob pattern matching files/directories +- **`access`**: Permission specifications + +#### Access Specifications +- **`admin`**: Can modify ACL files (`syft.pub.yaml`) - specified by email addresses +- **`write`**: Can create, modify, and delete files - specified by email addresses +- **`read`**: Can view and download files - specified by email addresses +- **Special value `"*"`**: Grants access to everyone (public access) + +### Example Configurations + +#### Public Documentation with Private Admin +```yaml +terminal: true +rules: + - pattern: "docs/*.md" + access: + read: ["*"] + write: ["maintainer@company.com"] + - pattern: "admin/*" + access: + admin: ["admin@company.com"] + - pattern: "**" + access: + read: ["team@company.com"] +``` + +#### Collaborative Project Structure +```yaml +terminal: false +rules: + - pattern: "src/**/*.go" + access: + read: ["*"] + write: ["dev1@company.com", "dev2@university.edu"] + - pattern: "tests/**" + access: + read: ["*"] + write: ["dev1@company.com", "qa@company.com"] + - pattern: "config/*.yaml" + access: + admin: ["devops@company.com"] + - pattern: "**" + access: {} +``` + +#### Research Data Sharing +```yaml +terminal: true +rules: + - pattern: "public_datasets/*" + access: + read: ["*"] + - pattern: "shared_analysis/*.ipynb" + access: + read: ["alice@research.org", "bob@university.edu", "charlie@institute.org"] + write: ["alice@research.org"] + - pattern: "raw_data/*" + access: + read: ["alice@research.org", "bob@university.edu"] + admin: ["data-owner@research.org"] +``` + +## Access Control Flow + +### Step-by-Step Permission Check + +1. **Owner Check** (`internal/server/acl/acl.go`) + ```go + if user.IsOwner { + return nil // Owners bypass all restrictions + } + ``` + +2. **Rule Lookup** (`internal/server/acl/acl.go`) + - Check cache for O(1) performance + - Traverse tree structure for O(depth) lookup + - Find most specific matching rule + +3. **ACL File Elevation** (`internal/server/acl/acl.go`) + ```go + if isAcl && level == AccessWrite { + level = AccessWriteACL // Elevate to admin requirement + } + ``` + +4. **File Limits Check** (`internal/server/acl/rule.go`) + - Validate file size limits + - Check directory permissions + - Verify symlink restrictions + +5. **Access Verification** (`internal/server/acl/rule.go`) + - Check user permissions against required level + - Handle permission hierarchy (admin > write > read) + +### Permission Hierarchy + +The system enforces a clear permission hierarchy: + +``` +Admin (AccessWriteACL) + ├─ Can modify syft.pub.yaml files + ├─ Full write access + └─ Full read access + │ +Write (AccessWrite) + ├─ Can create/modify/delete files + ├─ Subject to file limits + └─ Full read access + │ +Read (AccessRead) + └─ Can view and download files +``` + +## Tree Structure and Inheritance + +### Hierarchical Rule Storage + +The tree structure mirrors the file system hierarchy: + +``` +/ +├─ users/ +│ ├─ alice/ +│ │ ├─ syft.pub.yaml (terminal: true) +│ │ └─ documents/ +│ └─ bob/ +└─ shared/ + ├─ syft.pub.yaml (terminal: false) + └─ projects/ + └─ project1/ + └─ syft.pub.yaml (terminal: true) +``` + +### Rule Resolution (`internal/server/acl/tree.go`) + +The `GetNearestNodeWithRules()` method traverses up the tree to find applicable rules: + +```go +func (t *Tree) GetNearestNodeWithRules(path string) *Node { + parts := pathParts(path) + var candidate *Node + current := t.root + + for _, part := range parts { + if current.IsTerminal() { + break // Stop at terminal nodes + } + + child, exists := current.GetChild(part) + if !exists { + break + } + + current = child + if child.Rules() != nil { + candidate = current + } + } + + return candidate +} +``` + +## Caching System + +### Performance Optimization + +The rule cache provides O(1) lookups for frequently accessed paths: + +- **Cache hits**: Return immediately without tree traversal +- **Cache misses**: Perform tree lookup and cache result +- **Cache invalidation**: Clear affected entries when rules change + +### Cache Management (`internal/server/acl/acl.go`) + +```go +func (s *AclService) RemoveRuleSet(path string) bool { + s.cache.DeletePrefix(path) // Invalidate cache entries + return s.tree.RemoveRuleSet(path) +} +``` + +## File Limits and Restrictions + +### Supported Limits (`internal/aclspec/limits.go`) + +```go +type Limits struct { + MaxFileSize int64 `yaml:"maxFileSize,omitempty"` + MaxFiles uint32 `yaml:"maxFiles,omitempty"` + AllowDirs bool `yaml:"allowDirs,omitempty"` + AllowSymlinks bool `yaml:"allowSymlinks,omitempty"` +} +``` + +### Enforcement Logic (`internal/server/acl/rule.go`) + +```go +func (r *Rule) CheckLimits(info *File) error { + limits := r.rule.Limits + + if limits.MaxFileSize > 0 && info.Size > limits.MaxFileSize { + return ErrFileSizeExceeded + } + + if !limits.AllowDirs && (info.IsDir || strings.Count(info.Path, pathSep) > 0) { + return ErrDirsNotAllowed + } + + if !limits.AllowSymlinks && info.IsSymlink { + return ErrSymlinksNotAllowed + } + + return nil +} +``` + +## Pattern Matching + +### Glob Pattern Support + +The system supports powerful glob patterns for flexible file matching: + +- **`**`** - Matches all files and directories recursively +- **`*.ext`** - Matches all files with specific extension +- **`dir/*`** - Matches all direct children of directory +- **`dir/**`** - Matches all descendants of directory +- **`specific.txt`** - Matches exact filename + +### Pattern Examples + +```yaml +rules: + - pattern: "**/*.py" # All Python files + access: + read: ["dev1@company.com", "dev2@university.edu"] + + - pattern: "tests/**" # Everything in tests directory + access: + write: ["qa@company.com"] + + - pattern: "config.yaml" # Specific file + access: + admin: ["devops@company.com"] + + - pattern: "public/*" # Direct children only + access: + read: ["*"] +``` + +## Implementation Details + +### User Representation (`internal/server/acl/types.go`) + +```go +type User struct { + ID string // User identifier (email address) + IsOwner bool // Owner bypass flag +} +``` + +### File Information (`internal/server/acl/types.go`) + +```go +type File struct { + Path string // File system path + IsDir bool // Directory flag + IsSymlink bool // Symbolic link flag + Size int64 // File size in bytes +} +``` + +### Rule Set Management (`internal/aclspec/ruleset.go`) + +```go +type RuleSet struct { + Rules []*Rule `yaml:"rules,omitempty"` + Terminal bool `yaml:"terminal,omitempty"` + Path string `yaml:"-"` // Internal use only +} +``` + +## Error Handling + +The system defines specific error types for different access violations: + +```go +var ( + ErrAdminRequired = errors.New("admin access required") + ErrWriteRequired = errors.New("write access required") + ErrReadRequired = errors.New("read access required") + ErrDirsNotAllowed = errors.New("directories not allowed") + ErrSymlinksNotAllowed = errors.New("symlinks not allowed") + ErrFileSizeExceeded = errors.New("file size exceeds limits") +) +``` + +## Testing and Validation + +### Test Coverage + +The system includes comprehensive tests covering: + +- **Rule resolution** (`internal/server/acl/acl_test.go`) +- **Access control enforcement** (`internal/server/acl/acl_test.go`) +- **File limit validation** (`internal/server/acl/acl_test.go`) +- **Cache behavior** (`internal/server/acl/acl_test.go`) +- **Tree operations** (`internal/server/acl/tree_test.go`) + +### Example Test Case + +```go +func TestAclServiceCanAccess(t *testing.T) { + service := NewAclService() + + ruleset := aclspec.NewRuleSet( + "user", + aclspec.SetTerminal, + aclspec.NewRule("public/*.txt", aclspec.PublicReadAccess(), aclspec.DefaultLimits()), + aclspec.NewRule("private/*.txt", aclspec.PrivateAccess(), aclspec.DefaultLimits()), + ) + + err := service.AddRuleSet(ruleset) + assert.NoError(t, err) + + // Test access control + regularUser := &User{ID: aclspec.Everyone, IsOwner: false} + publicFile := &File{Path: "user/public/doc.txt", Size: 100} + + err = service.CanAccess(regularUser, publicFile, AccessRead) + assert.NoError(t, err) // Should succeed +} +``` + +## Best Practices + +### 1. Permission Design +- Start with restrictive permissions and grant access as needed +- Use terminal nodes to prevent unintended inheritance +- Group related files using consistent patterns + +### 2. Rule Organization +- Place more specific rules before general ones +- Always include a default `**` catch-all rule +- Use descriptive user group names + +### 3. Security Considerations +- Monitor for overly permissive `*` grants +- Implement file size limits for public areas + +### 4. Performance Optimization +- Keep rule sets small and focused +- Use terminal nodes to limit tree traversal +- Monitor cache hit rates + +## Integration Points + +### Server Integration (`internal/server/server.go`) + +The ACL service integrates with the SyftBox server to protect all file operations: + +- **File uploads**: Check write permissions +- **File downloads**: Verify read access +- **Directory listings**: Filter based on permissions +- **ACL modifications**: Require admin access + +### Client Synchronization (`internal/client/sync/`) + +The sync engine respects ACL permissions during: + +- **Upload operations**: Validate write access before sync +- **Download filtering**: Only sync permitted files +- **Conflict resolution**: Consider permission changes + +## Migration and Compatibility + +The system includes migration support for updating ACL formats: + +- **Backward compatibility**: Support older permission formats +- **Automatic migration**: Upgrade rules during system updates +- **Validation**: Ensure rule consistency after migration + +## Conclusion + +The SyftBox permission system provides a robust, Unix-inspired access control mechanism tailored for distributed file systems. By combining hierarchical rules, pattern matching, and efficient caching, it offers both security and performance for federated data sharing scenarios. + +The system's design emphasizes: +- **Flexibility**: Glob patterns and granular permissions +- **Performance**: O(1) cached lookups and O(depth) tree traversal +- **Security**: Default-deny policies and owner protections +- **Usability**: YAML configuration and inheritance patterns + +This architecture enables secure, scalable data sharing across distributed networks while maintaining the intuitive permission model familiar to Unix users. \ No newline at end of file diff --git a/internal/aclspec/access_test.go b/internal/aclspec/access_test.go new file mode 100644 index 00000000..c38405c5 --- /dev/null +++ b/internal/aclspec/access_test.go @@ -0,0 +1,261 @@ +package aclspec + +import ( + "testing" + + mapset "github.com/deckarep/golang-set/v2" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gopkg.in/yaml.v3" +) + +func TestNewAccess(t *testing.T) { + // Test creating Access with different user email combinations + // This tests the core constructor and validates that sets are properly initialized + admin := []string{"admin@example.com", "owner@company.org"} + write := []string{"writer1@openmined.org", "writer2@research.edu", "collab@university.ac.uk"} + read := []string{"reader@public.org"} + + access := NewAccess(admin, write, read) + + // Verify all sets are properly initialized and contain expected user emails + assert.True(t, access.Admin.Contains("admin@example.com")) + assert.True(t, access.Admin.Contains("owner@company.org")) + assert.Equal(t, 2, access.Admin.Cardinality()) + + assert.True(t, access.Write.Contains("writer1@openmined.org")) + assert.True(t, access.Write.Contains("writer2@research.edu")) + assert.True(t, access.Write.Contains("collab@university.ac.uk")) + assert.Equal(t, 3, access.Write.Cardinality()) + + assert.True(t, access.Read.Contains("reader@public.org")) + assert.Equal(t, 1, access.Read.Cardinality()) +} + +func TestNewAccessWithEmptyLists(t *testing.T) { + // Test edge case of creating Access with empty lists + // This ensures the constructor handles empty inputs gracefully + access := NewAccess([]string{}, []string{}, []string{}) + + assert.Equal(t, 0, access.Admin.Cardinality()) + assert.Equal(t, 0, access.Write.Cardinality()) + assert.Equal(t, 0, access.Read.Cardinality()) +} + +func TestPrivateAccess(t *testing.T) { + // Test that PrivateAccess creates an Access object with no permissions + // This is a critical security test - private should mean NO access for anyone + access := PrivateAccess() + + assert.Equal(t, 0, access.Admin.Cardinality(), "Private access should have no admin users") + assert.Equal(t, 0, access.Write.Cardinality(), "Private access should have no write users") + assert.Equal(t, 0, access.Read.Cardinality(), "Private access should have no read users") +} + +func TestPublicReadAccess(t *testing.T) { + // Test that PublicReadAccess grants read access to everyone but nothing else + // This validates the most common public sharing scenario + access := PublicReadAccess() + + assert.Equal(t, 0, access.Admin.Cardinality(), "Public read should have no admin users") + assert.Equal(t, 0, access.Write.Cardinality(), "Public read should have no write users") + assert.Equal(t, 1, access.Read.Cardinality(), "Public read should have exactly one read entry") + assert.True(t, access.Read.Contains(Everyone), "Public read should grant read access to everyone") +} + +func TestPublicReadWriteAccess(t *testing.T) { + // Test that PublicReadWriteAccess grants write (and implicitly read) access to everyone + // This tests the dangerous but sometimes necessary full public access + access := PublicReadWriteAccess() + + assert.Equal(t, 0, access.Admin.Cardinality(), "Public read-write should have no admin users") + assert.Equal(t, 1, access.Write.Cardinality(), "Public read-write should have exactly one write entry") + assert.True(t, access.Write.Contains(Everyone), "Public read-write should grant write access to everyone") + assert.Equal(t, 0, access.Read.Cardinality(), "Public read-write should not set read (write implies read)") +} + +func TestSharedReadAccess(t *testing.T) { + // Test creating shared read access for specific user emails + // This validates the common scenario of sharing read access with specific collaborators + users := []string{"alice@research.org", "bob@university.edu", "charlie@company.com"} + access := SharedReadAccess(users...) + + assert.Equal(t, 0, access.Admin.Cardinality(), "Shared read should have no admin users") + assert.Equal(t, 0, access.Write.Cardinality(), "Shared read should have no write users") + assert.Equal(t, 3, access.Read.Cardinality(), "Shared read should have exactly the specified users") + + for _, user := range users { + assert.True(t, access.Read.Contains(user), "Shared read should contain user %s", user) + } +} + +func TestSharedReadWriteAccess(t *testing.T) { + // Test creating shared read-write access for specific user emails + // This validates collaborative scenarios where specific users need write access + users := []string{"maintainer1@openmined.org", "maintainer2@research.edu"} + access := SharedReadWriteAccess(users...) + + assert.Equal(t, 0, access.Admin.Cardinality(), "Shared read-write should have no admin users") + assert.Equal(t, 2, access.Write.Cardinality(), "Shared read-write should have exactly the specified users") + assert.Equal(t, 0, access.Read.Cardinality(), "Shared read-write should not set read (write implies read)") + + for _, user := range users { + assert.True(t, access.Write.Contains(user), "Shared read-write should contain user %s", user) + } +} + +func TestAccessUnmarshalYAML(t *testing.T) { + // Test unmarshaling Access from YAML format + // This is critical for loading syft.pub.yaml files from disk + yamlData := ` +admin: ["admin@example.com", "owner@company.org"] +read: ["reader1@public.org", "reader2@university.edu", "reader3@research.net"] +write: ["writer@openmined.org"] +` + + var access Access + err := yaml.Unmarshal([]byte(yamlData), &access) + require.NoError(t, err, "YAML unmarshaling should succeed") + + // Verify that all user emails were properly loaded into their respective sets + assert.Equal(t, 2, access.Admin.Cardinality()) + assert.True(t, access.Admin.Contains("admin@example.com")) + assert.True(t, access.Admin.Contains("owner@company.org")) + + assert.Equal(t, 3, access.Read.Cardinality()) + assert.True(t, access.Read.Contains("reader1@public.org")) + assert.True(t, access.Read.Contains("reader2@university.edu")) + assert.True(t, access.Read.Contains("reader3@research.net")) + + assert.Equal(t, 1, access.Write.Cardinality()) + assert.True(t, access.Write.Contains("writer@openmined.org")) +} + +func TestAccessUnmarshalYAMLWithMissingSections(t *testing.T) { + // Test unmarshaling when some access sections are missing + // This ensures the system handles partial configurations gracefully + yamlData := ` +read: ["reader@example.com"] +# write and admin sections intentionally missing +` + + var access Access + err := yaml.Unmarshal([]byte(yamlData), &access) + require.NoError(t, err, "YAML unmarshaling should succeed even with missing sections") + + // Missing sections should result in empty sets, not nil + assert.Equal(t, 0, access.Admin.Cardinality(), "Missing admin section should result in empty set") + assert.Equal(t, 0, access.Write.Cardinality(), "Missing write section should result in empty set") + assert.Equal(t, 1, access.Read.Cardinality(), "Read section should be properly loaded") + assert.True(t, access.Read.Contains("reader@example.com")) +} + +func TestAccessUnmarshalYAMLWithEmptyLists(t *testing.T) { + // Test unmarshaling with explicitly empty lists + // This ensures empty lists in YAML are handled correctly + yamlData := ` +admin: [] +read: [] +write: ["writer@company.com"] +` + + var access Access + err := yaml.Unmarshal([]byte(yamlData), &access) + require.NoError(t, err, "YAML unmarshaling should succeed with empty lists") + + assert.Equal(t, 0, access.Admin.Cardinality()) + assert.Equal(t, 0, access.Read.Cardinality()) + assert.Equal(t, 1, access.Write.Cardinality()) + assert.True(t, access.Write.Contains("writer@company.com")) +} + +func TestAccessUnmarshalYAMLInvalidFormat(t *testing.T) { + // Test that invalid YAML format is properly rejected + // This ensures the system fails safely on malformed input + yamlData := ` +admin: "should_be_a_list_not_string" +` + + var access Access + err := yaml.Unmarshal([]byte(yamlData), &access) + assert.Error(t, err, "Invalid YAML format should be rejected") +} + +func TestAccessMarshalYAML(t *testing.T) { + // Test marshaling Access to YAML format + // This is critical for saving syft.pub.yaml files to disk + access := NewAccess( + []string{"admin@example.com", "owner@company.org"}, + []string{"writer@openmined.org"}, + []string{"reader1@public.org", "reader2@university.edu"}, + ) + + data, err := yaml.Marshal(access) + require.NoError(t, err, "YAML marshaling should succeed") + + // Unmarshal back to verify the round-trip works correctly + var unmarshaled Access + err = yaml.Unmarshal(data, &unmarshaled) + require.NoError(t, err, "Round-trip unmarshaling should succeed") + + // Verify all data survived the round-trip + assert.Equal(t, access.Admin.Cardinality(), unmarshaled.Admin.Cardinality()) + assert.Equal(t, access.Write.Cardinality(), unmarshaled.Write.Cardinality()) + assert.Equal(t, access.Read.Cardinality(), unmarshaled.Read.Cardinality()) + + // Check that all user emails are preserved + for _, user := range access.Admin.ToSlice() { + assert.True(t, unmarshaled.Admin.Contains(user)) + } + for _, user := range access.Write.ToSlice() { + assert.True(t, unmarshaled.Write.Contains(user)) + } + for _, user := range access.Read.ToSlice() { + assert.True(t, unmarshaled.Read.Contains(user)) + } +} + +func TestAccessMarshalYAMLWithNilSets(t *testing.T) { + // Test marshaling when some sets are nil + // This tests the defensive programming in MarshalYAML + access := Access{ + Admin: mapset.NewSet("admin@example.com"), + Write: nil, // Intentionally nil + Read: mapset.NewSet("reader@public.org"), + } + + data, err := yaml.Marshal(access) + require.NoError(t, err, "YAML marshaling should handle nil sets gracefully") + + // Verify the marshaled data is valid YAML + var result map[string][]string + err = yaml.Unmarshal(data, &result) + require.NoError(t, err, "Marshaled YAML should be valid") + + // Nil sets should not appear in the output + assert.Contains(t, result, "admin") + assert.Contains(t, result, "read") + assert.NotContains(t, result, "write", "Nil sets should not appear in marshaled YAML") +} + +func TestAccessMarshalYAMLWithEmptySets(t *testing.T) { + // Test marshaling when sets are empty but not nil + // This verifies that empty sets are handled consistently + access := NewAccess( + []string{}, // Empty admin + []string{}, // Empty write + []string{"reader@example.com"}, // Non-empty read + ) + + data, err := yaml.Marshal(access) + require.NoError(t, err, "YAML marshaling should handle empty sets") + + var result map[string][]string + err = yaml.Unmarshal(data, &result) + require.NoError(t, err, "Marshaled YAML should be valid") + + // Empty sets should appear as empty arrays in YAML + assert.Equal(t, []string{}, result["admin"]) + assert.Equal(t, []string{}, result["write"]) + assert.Contains(t, result["read"], "reader@example.com") +} \ No newline at end of file diff --git a/internal/aclspec/aclspec.go b/internal/aclspec/aclspec.go index 3cb1b890..6558fdfa 100644 --- a/internal/aclspec/aclspec.go +++ b/internal/aclspec/aclspec.go @@ -14,30 +14,37 @@ const ( UnsetTerminal = false ) -// IsAclFile checks if the path is an syft.pub.yaml file -func IsAclFile(path string) bool { +// IsACLFile checks if the path is an syft.pub.yaml file +func IsACLFile(path string) bool { return strings.HasSuffix(path, AclFileName) } -// AsAclPath converts any path to exact acl file path -func AsAclPath(path string) string { - if IsAclFile(path) { +// AsACLPath converts any path to exact acl file path +func AsACLPath(path string) string { + if IsACLFile(path) { return path } return filepath.Join(path, AclFileName) } -// WithoutAclPath truncates syft.pub.yaml from the path -func WithoutAclPath(path string) string { +// WithoutACLPath truncates syft.pub.yaml from the path +func WithoutACLPath(path string) string { return strings.TrimSuffix(path, AclFileName) } // Exists checks if the ACL file exists at the given path +// For security reasons, symlinks are not allowed as ACL files func Exists(path string) bool { - aclPath := AsAclPath(path) - stat, err := os.Stat(aclPath) + aclPath := AsACLPath(path) + stat, err := os.Lstat(aclPath) // Use Lstat to not follow symlinks if os.IsNotExist(err) { return false } + + // Reject symlinks for security reasons + if stat.Mode()&os.ModeSymlink != 0 { + return false + } + return stat.Size() > 0 } diff --git a/internal/aclspec/aclspec_test.go b/internal/aclspec/aclspec_test.go new file mode 100644 index 00000000..1199325b --- /dev/null +++ b/internal/aclspec/aclspec_test.go @@ -0,0 +1,357 @@ +package aclspec + +import ( + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestIsACLFile(t *testing.T) { + // Test the core ACL file detection logic + // This is critical for the system to recognize permission files correctly + testCases := []struct { + path string + expected bool + desc string + }{ + { + path: "syft.pub.yaml", + expected: true, + desc: "Direct ACL filename should be detected", + }, + { + path: "/path/to/syft.pub.yaml", + expected: true, + desc: "ACL file with full path should be detected", + }, + { + path: "folder/subfolder/syft.pub.yaml", + expected: true, + desc: "ACL file in nested path should be detected", + }, + { + path: "not_an_acl_file.yaml", + expected: false, + desc: "Regular YAML file should not be detected as ACL", + }, + { + path: "syft.pub.yaml.backup", + expected: false, + desc: "ACL filename with suffix should not be detected", + }, + { + path: "prefix_syft.pub.yaml", + expected: true, + desc: "ACL filename with prefix should be detected (suffix match)", + }, + { + path: "", + expected: false, + desc: "Empty path should not be detected as ACL", + }, + { + path: "syft.pub.yml", + expected: false, + desc: "Wrong extension (.yml vs .yaml) should not be detected", + }, + { + path: "SYFT.PUB.YAML", + expected: false, + desc: "Case-sensitive check - uppercase should not match", + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + result := IsACLFile(tc.path) + assert.Equal(t, tc.expected, result, "Path: %s", tc.path) + }) + } +} + +func TestAsACLPath(t *testing.T) { + // Test converting directory paths to ACL file paths + // This ensures the system can correctly locate ACL files for directories + testCases := []struct { + input string + expected string + desc string + }{ + { + input: "/home/user", + expected: "/home/user/syft.pub.yaml", + desc: "Directory path should get ACL filename appended", + }, + { + input: "/home/user/syft.pub.yaml", + expected: "/home/user/syft.pub.yaml", + desc: "ACL file path should remain unchanged", + }, + { + input: "", + expected: "syft.pub.yaml", + desc: "Empty path should result in just the ACL filename", + }, + { + input: "relative/path", + expected: "relative/path/syft.pub.yaml", + desc: "Relative path should get ACL filename appended", + }, + { + input: "/", + expected: "/syft.pub.yaml", + desc: "Root directory should get ACL filename appended", + }, + { + input: "folder/syft.pub.yaml", + expected: "folder/syft.pub.yaml", + desc: "Path already ending with ACL filename should be unchanged", + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + result := AsACLPath(tc.input) + assert.Equal(t, tc.expected, result, "Input: %s", tc.input) + }) + } +} + +func TestWithoutACLPath(t *testing.T) { + // Test removing ACL filename from paths + // This is used to get the directory path from an ACL file path + testCases := []struct { + input string + expected string + desc string + }{ + { + input: "/home/user/syft.pub.yaml", + expected: "/home/user/", + desc: "ACL file path should have filename removed", + }, + { + input: "syft.pub.yaml", + expected: "", + desc: "Bare ACL filename should result in empty string", + }, + { + input: "/home/user/other.yaml", + expected: "/home/user/other.yaml", + desc: "Non-ACL file path should remain unchanged", + }, + { + input: "/home/user/", + expected: "/home/user/", + desc: "Directory path without ACL filename should remain unchanged", + }, + { + input: "", + expected: "", + desc: "Empty path should remain empty", + }, + { + input: "folder/subfolder/syft.pub.yaml", + expected: "folder/subfolder/", + desc: "Nested ACL file path should have filename removed", + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + result := WithoutACLPath(tc.input) + assert.Equal(t, tc.expected, result, "Input: %s", tc.input) + }) + } +} + +func TestExists(t *testing.T) { + // Test ACL file existence detection + // This validates the system can correctly detect existing ACL files on disk + + // Create a temporary directory for testing + tempDir := t.TempDir() + + // Create a test ACL file with some content + aclFilePath := filepath.Join(tempDir, AclFileName) + err := os.WriteFile(aclFilePath, []byte("terminal: true\nrules: []"), 0644) + require.NoError(t, err, "Should be able to create test ACL file") + + // Test that existing non-empty ACL file is detected + assert.True(t, Exists(tempDir), "Should detect existing ACL file in directory") + assert.True(t, Exists(aclFilePath), "Should detect existing ACL file by direct path") + + // Create an empty ACL file + emptyAclDir := filepath.Join(tempDir, "empty") + err = os.Mkdir(emptyAclDir, 0755) + require.NoError(t, err, "Should be able to create test directory") + + emptyAclFile := filepath.Join(emptyAclDir, AclFileName) + err = os.WriteFile(emptyAclFile, []byte{}, 0644) + require.NoError(t, err, "Should be able to create empty ACL file") + + // Test that empty ACL file is not considered as existing + assert.False(t, Exists(emptyAclDir), "Empty ACL file should not be considered as existing") + + // Test that non-existent path returns false + nonExistentPath := filepath.Join(tempDir, "nonexistent") + assert.False(t, Exists(nonExistentPath), "Non-existent path should return false") + + // Test that directory without ACL file returns false + noAclDir := filepath.Join(tempDir, "no_acl") + err = os.Mkdir(noAclDir, 0755) + require.NoError(t, err, "Should be able to create test directory") + + assert.False(t, Exists(noAclDir), "Directory without ACL file should return false") +} + +func TestExistsWithSymlinks(t *testing.T) { + // Test that symlinks are rejected for security reasons + // ACL files must be regular files, not symlinks + + tempDir := t.TempDir() + + // Create actual ACL file + realAclDir := filepath.Join(tempDir, "real") + err := os.Mkdir(realAclDir, 0755) + require.NoError(t, err) + + realAclFile := filepath.Join(realAclDir, AclFileName) + err = os.WriteFile(realAclFile, []byte("terminal: true"), 0644) + require.NoError(t, err) + + // Verify real ACL file is detected + assert.True(t, Exists(realAclDir), "Real ACL file should be detected") + + // Create symlink to the ACL file + symlinkDir := filepath.Join(tempDir, "symlink") + err = os.Mkdir(symlinkDir, 0755) + require.NoError(t, err) + + symlinkAclFile := filepath.Join(symlinkDir, AclFileName) + err = os.Symlink(realAclFile, symlinkAclFile) + require.NoError(t, err) + + // Test that symlinked ACL file is REJECTED for security reasons + assert.False(t, Exists(symlinkDir), "Symlinked ACL files should be rejected for security") + + // Test that LoadFromFile also rejects symlinks + _, err = LoadFromFile(symlinkDir) + assert.Error(t, err, "LoadFromFile should reject symlinked ACL files") + assert.Contains(t, err.Error(), "symlinks are not allowed", "Error should mention symlink restriction") + + // Create broken symlink + brokenDir := filepath.Join(tempDir, "broken") + err = os.Mkdir(brokenDir, 0755) + require.NoError(t, err) + + brokenSymlink := filepath.Join(brokenDir, AclFileName) + err = os.Symlink("/nonexistent/file", brokenSymlink) + require.NoError(t, err) + + // Test that broken symlink is also rejected + assert.False(t, Exists(brokenDir), "Broken symlinks should be rejected") + + // Test that LoadFromFile also rejects broken symlinks + _, err = LoadFromFile(brokenDir) + assert.Error(t, err, "LoadFromFile should reject broken symlinks") + assert.Contains(t, err.Error(), "symlinks are not allowed", "Error should mention symlink restriction") +} + +func TestSymlinkSecurityComprehensive(t *testing.T) { + // Comprehensive test to ensure symlinks are rejected at all entry points + // This is critical for security - no ACL operation should follow symlinks + + tempDir := t.TempDir() + + // Create a legitimate ACL file + realAclFile := filepath.Join(tempDir, "real.yaml") + aclContent := ` +terminal: true +rules: + - pattern: "**" + access: + read: ["admin@example.com"] +` + err := os.WriteFile(realAclFile, []byte(aclContent), 0644) + require.NoError(t, err) + + // Create a directory with a symlinked ACL file + symlinkDir := filepath.Join(tempDir, "symlinked") + err = os.Mkdir(symlinkDir, 0755) + require.NoError(t, err) + + symlinkAclFile := filepath.Join(symlinkDir, AclFileName) + err = os.Symlink(realAclFile, symlinkAclFile) + require.NoError(t, err) + + // Test all ACL operations reject symlinks + + // 1. Exists should return false for symlinked ACL + assert.False(t, Exists(symlinkDir), "Exists() should reject symlinked ACL files") + + // 2. LoadFromFile should error for symlinked ACL + _, err = LoadFromFile(symlinkDir) + assert.Error(t, err, "LoadFromFile() should reject symlinked ACL files") + assert.Contains(t, err.Error(), "symlinks are not allowed", "Error should mention symlink restriction") + + // 3. AsACLPath and IsACLFile should work normally (they don't check file existence) + aclPath := AsACLPath(symlinkDir) + assert.Equal(t, symlinkAclFile, aclPath, "AsACLPath should work normally") + assert.True(t, IsACLFile(aclPath), "IsACLFile should work normally") + + // 4. Test that WithoutACLPath works normally + dirPath := WithoutACLPath(aclPath) + // Note: WithoutACLPath may add trailing separator, so we use filepath.Clean for comparison + assert.Equal(t, symlinkDir, filepath.Clean(dirPath), "WithoutACLPath should work normally") + + // Verify that regular files still work + regularDir := filepath.Join(tempDir, "regular") + err = os.Mkdir(regularDir, 0755) + require.NoError(t, err) + + regularAclFile := filepath.Join(regularDir, AclFileName) + err = os.WriteFile(regularAclFile, []byte(aclContent), 0644) + require.NoError(t, err) + + // Regular ACL files should work normally + assert.True(t, Exists(regularDir), "Regular ACL files should work") + + _, err = LoadFromFile(regularDir) + assert.NoError(t, err, "LoadFromFile should work with regular ACL files") +} + +func TestConstants(t *testing.T) { + // Test that the constants have expected values + // This ensures the constants are correctly defined and haven't been accidentally changed + assert.Equal(t, "syft.pub.yaml", AclFileName, "ACL filename constant should be correct") + assert.Equal(t, "*", Everyone, "Everyone constant should be correct") + assert.Equal(t, "**", AllFiles, "AllFiles pattern should be correct") + assert.Equal(t, true, SetTerminal, "SetTerminal constant should be true") + assert.Equal(t, false, UnsetTerminal, "UnsetTerminal constant should be false") +} + +func TestPathEdgeCases(t *testing.T) { + // Test edge cases in path manipulation functions + // This ensures robust handling of unusual but valid path scenarios + + // Test with paths containing special characters + specialPath := "/home/user with spaces/syft.pub.yaml" + assert.True(t, IsACLFile(specialPath), "Paths with spaces should be handled correctly") + + // Test with Unicode characters + unicodePath := "/home/用户/syft.pub.yaml" + assert.True(t, IsACLFile(unicodePath), "Unicode paths should be handled correctly") + + // Test with multiple slashes + multiSlashPath := "/home//user///syft.pub.yaml" + expected := "/home//user///syft.pub.yaml" + assert.Equal(t, expected, AsACLPath(multiSlashPath), "Multiple slashes should be preserved") + + // Test with backslashes (Windows-style paths) + windowsPath := "C:\\Users\\user\\syft.pub.yaml" + assert.True(t, IsACLFile(windowsPath), "Windows-style paths should be detected") +} \ No newline at end of file diff --git a/internal/aclspec/limits_test.go b/internal/aclspec/limits_test.go new file mode 100644 index 00000000..e119d51b --- /dev/null +++ b/internal/aclspec/limits_test.go @@ -0,0 +1,123 @@ +package aclspec + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestDefaultLimits(t *testing.T) { + // Test that DefaultLimits creates a Limits object with expected default values + // This is critical because these defaults affect security and functionality + limits := DefaultLimits() + + // Verify the limits object is not nil + assert.NotNil(t, limits, "DefaultLimits should return a non-nil Limits object") + + // Test file count limit - default should be 0 (unlimited) + assert.Equal(t, uint32(0), limits.MaxFiles, + "Default MaxFiles should be 0 (unlimited) to not restrict file count by default") + + // Test file size limit - default should be 0 (unlimited) + assert.Equal(t, int64(0), limits.MaxFileSize, + "Default MaxFileSize should be 0 (unlimited) to not restrict file size by default") + + // Test directory permission - default should allow directories + assert.True(t, limits.AllowDirs, + "Default should allow directories since most use cases need directory creation") + + // Test symlink permission - default should NOT allow symlinks for security + assert.False(t, limits.AllowSymlinks, + "Default should not allow symlinks for security reasons (prevent symlink attacks)") +} + +func TestLimitsStruct(t *testing.T) { + // Test creating Limits with custom values + // This validates the struct can hold different configurations correctly + customLimits := &Limits{ + MaxFiles: 100, + MaxFileSize: 1024 * 1024, // 1MB + AllowDirs: false, + AllowSymlinks: true, + } + + assert.Equal(t, uint32(100), customLimits.MaxFiles, "Custom MaxFiles should be preserved") + assert.Equal(t, int64(1024*1024), customLimits.MaxFileSize, "Custom MaxFileSize should be preserved") + assert.False(t, customLimits.AllowDirs, "Custom AllowDirs should be preserved") + assert.True(t, customLimits.AllowSymlinks, "Custom AllowSymlinks should be preserved") +} + +func TestLimitsZeroValues(t *testing.T) { + // Test that zero values in Limits struct behave as expected + // This is important for understanding the semantic meaning of zero values + var limits Limits + + // Zero values should represent the most restrictive settings + assert.Equal(t, uint32(0), limits.MaxFiles, "Zero value MaxFiles should be 0") + assert.Equal(t, int64(0), limits.MaxFileSize, "Zero value MaxFileSize should be 0") + assert.False(t, limits.AllowDirs, "Zero value AllowDirs should be false (more restrictive)") + assert.False(t, limits.AllowSymlinks, "Zero value AllowSymlinks should be false (more secure)") +} + +func TestLimitsIndependence(t *testing.T) { + // Test that multiple calls to DefaultLimits return independent objects + // This ensures modifications to one instance don't affect others + limits1 := DefaultLimits() + limits2 := DefaultLimits() + + // Verify they start with the same values + assert.Equal(t, limits1.MaxFiles, limits2.MaxFiles) + assert.Equal(t, limits1.MaxFileSize, limits2.MaxFileSize) + assert.Equal(t, limits1.AllowDirs, limits2.AllowDirs) + assert.Equal(t, limits1.AllowSymlinks, limits2.AllowSymlinks) + + // Modify one instance + limits1.MaxFiles = 50 + limits1.AllowDirs = false + + // Verify the other instance is unchanged + assert.Equal(t, uint32(0), limits2.MaxFiles, "Modifying one instance should not affect another") + assert.True(t, limits2.AllowDirs, "Modifying one instance should not affect another") +} + +func TestLimitsExtremeValues(t *testing.T) { + // Test Limits with extreme values to ensure robust handling + // This validates the struct can handle boundary conditions + extremeLimits := &Limits{ + MaxFiles: ^uint32(0), // Maximum uint32 value + MaxFileSize: 9223372036854775807, // Maximum int64 value + AllowDirs: true, + AllowSymlinks: true, + } + + // These extreme values should be preserved without overflow or corruption + assert.Equal(t, ^uint32(0), extremeLimits.MaxFiles, "Maximum uint32 value should be preserved") + assert.Equal(t, int64(9223372036854775807), extremeLimits.MaxFileSize, "Maximum int64 value should be preserved") + assert.True(t, extremeLimits.AllowDirs, "Boolean values should be preserved") + assert.True(t, extremeLimits.AllowSymlinks, "Boolean values should be preserved") +} + +func TestLimitsSemantics(t *testing.T) { + // Test the semantic meaning of limit values + // This documents and validates the intended behavior of different settings + + // Test unlimited semantics (0 values) + unlimited := &Limits{ + MaxFiles: 0, + MaxFileSize: 0, + } + + // Zero should mean "no limit" for numeric fields + // This is a common convention in Unix systems + assert.Equal(t, uint32(0), unlimited.MaxFiles, "Zero MaxFiles should mean unlimited") + assert.Equal(t, int64(0), unlimited.MaxFileSize, "Zero MaxFileSize should mean unlimited") + + // Test specific limits + limited := &Limits{ + MaxFiles: 10, + MaxFileSize: 1000, + } + + assert.True(t, limited.MaxFiles > 0, "Non-zero MaxFiles should impose a limit") + assert.True(t, limited.MaxFileSize > 0, "Non-zero MaxFileSize should impose a limit") +} \ No newline at end of file diff --git a/internal/aclspec/migrate.go b/internal/aclspec/migrate.go deleted file mode 100644 index 091121e9..00000000 --- a/internal/aclspec/migrate.go +++ /dev/null @@ -1,36 +0,0 @@ -package aclspec - -import ( - "fmt" - - "gopkg.in/yaml.v3" -) - -type PermissionType string - -const ( - Read PermissionType = "read" - Create PermissionType = "create" - Write PermissionType = "write" - Execute PermissionType = "admin" -) - -type LegacyRule struct { - Path string `yaml:"path"` - User string `yaml:"user"` - Permissions []PermissionType `yaml:"permissions"` -} - -type LegacyPermission struct { - Rules []LegacyRule -} - -// UnmarshalYAML implements the yaml.Unmarshaler interface. -// This allows LegacyPermission to be unmarshalled from a YAML sequence -// directly into its Rules field. -func (lp *LegacyPermission) UnmarshalYAML(node *yaml.Node) error { - if node.Kind != yaml.SequenceNode { - return fmt.Errorf("cannot unmarshal %s into LegacyPermission.Rules, expected a sequence", node.Tag) - } - return node.Decode(&lp.Rules) -} diff --git a/internal/aclspec/rule_test.go b/internal/aclspec/rule_test.go new file mode 100644 index 00000000..ff6be3c1 --- /dev/null +++ b/internal/aclspec/rule_test.go @@ -0,0 +1,247 @@ +package aclspec + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestNewRule(t *testing.T) { + // Test creating a new Rule with all components + // This validates the core Rule constructor works correctly + pattern := "*.txt" + access := PublicReadAccess() + limits := DefaultLimits() + + rule := NewRule(pattern, access, limits) + + // Verify all components are properly assigned + assert.NotNil(t, rule, "NewRule should return a non-nil Rule") + assert.Equal(t, pattern, rule.Pattern, "Pattern should be preserved") + assert.Equal(t, access, rule.Access, "Access should be preserved") + assert.Equal(t, limits, rule.Limits, "Limits should be preserved") +} + +func TestNewRuleWithDifferentPatterns(t *testing.T) { + // Test creating rules with various pattern types + // This ensures the constructor handles different glob patterns correctly + testCases := []struct { + pattern string + desc string + }{ + { + pattern: "**", + desc: "Universal pattern should be preserved", + }, + { + pattern: "*.go", + desc: "Extension-based pattern should be preserved", + }, + { + pattern: "specific.txt", + desc: "Specific filename pattern should be preserved", + }, + { + pattern: "folder/**/*.py", + desc: "Complex nested pattern should be preserved", + }, + { + pattern: "", + desc: "Empty pattern should be preserved (even if invalid)", + }, + } + + access := PrivateAccess() + limits := DefaultLimits() + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + rule := NewRule(tc.pattern, access, limits) + assert.Equal(t, tc.pattern, rule.Pattern, tc.desc) + }) + } +} + +func TestNewRuleWithDifferentAccess(t *testing.T) { + // Test creating rules with different access configurations + // This validates rules can be created with various permission levels + pattern := "test.txt" + limits := DefaultLimits() + + testCases := []struct { + access *Access + desc string + }{ + { + access: PrivateAccess(), + desc: "Private access should be preserved", + }, + { + access: PublicReadAccess(), + desc: "Public read access should be preserved", + }, + { + access: PublicReadWriteAccess(), + desc: "Public read-write access should be preserved", + }, + { + access: SharedReadAccess("user1", "user2"), + desc: "Shared read access should be preserved", + }, + { + access: SharedReadWriteAccess("maintainer"), + desc: "Shared read-write access should be preserved", + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + rule := NewRule(pattern, tc.access, limits) + assert.Equal(t, tc.access, rule.Access, tc.desc) + }) + } +} + +func TestNewRuleWithDifferentLimits(t *testing.T) { + // Test creating rules with different limit configurations + // This validates rules can enforce various restrictions + pattern := "test.txt" + access := PrivateAccess() + + testCases := []struct { + limits *Limits + desc string + }{ + { + limits: DefaultLimits(), + desc: "Default limits should be preserved", + }, + { + limits: &Limits{ + MaxFileSize: 1024, + MaxFiles: 10, + AllowDirs: false, + AllowSymlinks: false, + }, + desc: "Custom restrictive limits should be preserved", + }, + { + limits: &Limits{ + MaxFileSize: 0, // Unlimited + MaxFiles: 0, // Unlimited + AllowDirs: true, + AllowSymlinks: true, + }, + desc: "Custom permissive limits should be preserved", + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + rule := NewRule(pattern, access, tc.limits) + assert.Equal(t, tc.limits, rule.Limits, tc.desc) + }) + } +} + +func TestNewRuleWithNilInputs(t *testing.T) { + // Test creating rules with nil inputs + // This documents behavior with invalid inputs (system should handle gracefully) + pattern := "test.txt" + + // Test with nil access (this might be invalid but shouldn't crash) + rule := NewRule(pattern, nil, DefaultLimits()) + assert.Equal(t, pattern, rule.Pattern, "Pattern should be preserved even with nil access") + assert.Nil(t, rule.Access, "Nil access should be preserved") + assert.NotNil(t, rule.Limits, "Limits should be preserved") + + // Test with nil limits (this might be invalid but shouldn't crash) + rule = NewRule(pattern, PrivateAccess(), nil) + assert.Equal(t, pattern, rule.Pattern, "Pattern should be preserved even with nil limits") + assert.NotNil(t, rule.Access, "Access should be preserved") + assert.Nil(t, rule.Limits, "Nil limits should be preserved") +} + +func TestNewDefaultRule(t *testing.T) { + // Test creating a default rule (catch-all pattern) + // This validates the convenience constructor for the most common default rule + access := PrivateAccess() + limits := DefaultLimits() + + rule := NewDefaultRule(access, limits) + + // Verify the rule uses the universal pattern + assert.NotNil(t, rule, "NewDefaultRule should return a non-nil Rule") + assert.Equal(t, AllFiles, rule.Pattern, "Default rule should use the AllFiles pattern (**)") + assert.Equal(t, access, rule.Access, "Access should be preserved") + assert.Equal(t, limits, rule.Limits, "Limits should be preserved") +} + +func TestNewDefaultRuleWithDifferentAccess(t *testing.T) { + // Test default rule creation with various access levels + // This ensures default rules can be created with any permission configuration + limits := DefaultLimits() + + testCases := []struct { + access *Access + desc string + }{ + { + access: PrivateAccess(), + desc: "Default rule with private access", + }, + { + access: PublicReadAccess(), + desc: "Default rule with public read access", + }, + { + access: SharedReadWriteAccess("admin"), + desc: "Default rule with shared access", + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + rule := NewDefaultRule(tc.access, limits) + assert.Equal(t, AllFiles, rule.Pattern, "All default rules should use AllFiles pattern") + assert.Equal(t, tc.access, rule.Access, tc.desc) + }) + } +} + +func TestRulePatternConstants(t *testing.T) { + // Test that rules correctly use the expected pattern constants + // This ensures consistency in pattern usage across the system + + // Default rule should use AllFiles constant + defaultRule := NewDefaultRule(PrivateAccess(), DefaultLimits()) + assert.Equal(t, "**", defaultRule.Pattern, "Default rule should use the AllFiles constant value") + assert.Equal(t, AllFiles, defaultRule.Pattern, "Default rule should use AllFiles constant") + + // Custom rule should preserve custom pattern + customRule := NewRule("custom/*.txt", PrivateAccess(), DefaultLimits()) + assert.NotEqual(t, AllFiles, customRule.Pattern, "Custom rule should not use AllFiles pattern") + assert.Equal(t, "custom/*.txt", customRule.Pattern, "Custom rule should preserve its pattern") +} + +func TestRuleIndependence(t *testing.T) { + // Test that multiple rules are independent objects + // This ensures modifying one rule doesn't affect others + access1 := PrivateAccess() + access2 := PublicReadAccess() + limits1 := DefaultLimits() + limits2 := &Limits{MaxFileSize: 1000} + + rule1 := NewRule("*.txt", access1, limits1) + rule2 := NewRule("*.md", access2, limits2) + + // Verify rules are independent + assert.NotEqual(t, rule1.Pattern, rule2.Pattern, "Rules should have different patterns") + assert.NotEqual(t, rule1.Access, rule2.Access, "Rules should have different access") + assert.NotEqual(t, rule1.Limits, rule2.Limits, "Rules should have different limits") + + // Verify shared objects are actually the same references when intended + rule3 := NewRule("*.go", access1, limits1) + assert.Equal(t, rule1.Access, rule3.Access, "Rules sharing the same access object should reference the same object") + assert.Equal(t, rule1.Limits, rule3.Limits, "Rules sharing the same limits object should reference the same object") +} \ No newline at end of file diff --git a/internal/aclspec/ruleset.go b/internal/aclspec/ruleset.go index 68ebcbc9..eedf09b2 100644 --- a/internal/aclspec/ruleset.go +++ b/internal/aclspec/ruleset.go @@ -19,7 +19,7 @@ type RuleSet struct { // NewRuleSet creates a new RuleSet instance with the given path, terminal flag, and initial rules. func NewRuleSet(path string, terminal bool, rules ...*Rule) *RuleSet { return &RuleSet{ - Path: WithoutAclPath(path), + Path: WithoutACLPath(path), Terminal: terminal, Rules: rules, } @@ -30,8 +30,21 @@ func (r *RuleSet) AllRules() []*Rule { } // LoadFromFile loads a RuleSet from the specified file path +// For security reasons, symlinks are not allowed as ACL files func LoadFromFile(path string) (*RuleSet, error) { - aclPath := AsAclPath(path) + aclPath := AsACLPath(path) + + // Check if file is a symlink before opening + stat, err := os.Lstat(aclPath) + if err != nil { + return nil, err + } + + // Reject symlinks for security reasons + if stat.Mode()&os.ModeSymlink != 0 { + return nil, fmt.Errorf("symlinks are not allowed as ACL files: %s", aclPath) + } + fd, err := os.Open(aclPath) if err != nil { return nil, err @@ -42,7 +55,7 @@ func LoadFromFile(path string) (*RuleSet, error) { // LoadFromReader creates a RuleSet by reading and parsing YAML content from the provided reader. // The path parameter is used to set the internal path of the RuleSet. -func LoadFromReader(path string, reader io.ReadCloser) (*RuleSet, error) { +func LoadFromReader(path string, reader io.Reader) (*RuleSet, error) { data, err := io.ReadAll(reader) if err != nil { return nil, err @@ -53,12 +66,12 @@ func LoadFromReader(path string, reader io.ReadCloser) (*RuleSet, error) { return nil, err } - ruleset.Path = WithoutAclPath(path) + ruleset.Path = WithoutACLPath(path) return setDefaults(&ruleset) } func (r *RuleSet) Save() error { - aclPath := AsAclPath(r.Path) + aclPath := AsACLPath(r.Path) file, err := os.Create(aclPath) if err != nil { return fmt.Errorf("failed to create file %s: %w", r.Path, err) diff --git a/internal/aclspec/ruleset_test.go b/internal/aclspec/ruleset_test.go new file mode 100644 index 00000000..aee03c62 --- /dev/null +++ b/internal/aclspec/ruleset_test.go @@ -0,0 +1,368 @@ +package aclspec + +import ( + "io" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gopkg.in/yaml.v3" +) + +func TestNewRuleSet(t *testing.T) { + // Test creating a new RuleSet with basic configuration + // This validates the core RuleSet constructor + path := "test/path" + terminal := true + rule1 := NewRule("*.txt", PublicReadAccess(), DefaultLimits()) + rule2 := NewRule("*.md", PrivateAccess(), DefaultLimits()) + + ruleset := NewRuleSet(path, terminal, rule1, rule2) + + // Verify all components are properly assigned + assert.NotNil(t, ruleset, "NewRuleSet should return a non-nil RuleSet") + assert.Equal(t, path, ruleset.Path, "Path should be preserved") + assert.Equal(t, terminal, ruleset.Terminal, "Terminal flag should be preserved") + assert.Len(t, ruleset.Rules, 2, "Should have exactly 2 rules") + assert.Contains(t, ruleset.Rules, rule1, "Should contain first rule") + assert.Contains(t, ruleset.Rules, rule2, "Should contain second rule") +} + +func TestNewRuleSetWithAclPath(t *testing.T) { + // Test that NewRuleSet correctly handles ACL file paths + // This ensures the constructor normalizes paths by removing ACL filename + aclPath := "test/path/syft.pub.yaml" + expectedPath := "test/path/" + + ruleset := NewRuleSet(aclPath, false) + + assert.Equal(t, expectedPath, ruleset.Path, "ACL filename should be stripped from path") +} + +func TestNewRuleSetWithNoRules(t *testing.T) { + // Test creating a RuleSet with no initial rules + // This validates the constructor works with empty rule sets + ruleset := NewRuleSet("test/path", false) + + assert.Equal(t, "test/path", ruleset.Path) + assert.False(t, ruleset.Terminal) + assert.Empty(t, ruleset.Rules, "RuleSet with no rules should have empty Rules slice") +} + +func TestAllRules(t *testing.T) { + // Test the AllRules method returns the correct rules + // This is a simple getter but important for the interface + rule1 := NewRule("*.txt", PublicReadAccess(), DefaultLimits()) + rule2 := NewRule("*.md", PrivateAccess(), DefaultLimits()) + ruleset := NewRuleSet("test", false, rule1, rule2) + + rules := ruleset.AllRules() + + assert.Len(t, rules, 2, "AllRules should return all rules") + assert.Contains(t, rules, rule1, "Should contain first rule") + assert.Contains(t, rules, rule2, "Should contain second rule") +} + +func TestLoadFromReader(t *testing.T) { + // Test loading RuleSet from YAML reader + // This validates the core YAML parsing functionality + yamlContent := ` +terminal: true +rules: + - pattern: "*.txt" + access: + read: ["*"] + - pattern: "private/*" + access: + read: ["admin"] + write: ["admin"] +` + + reader := io.NopCloser(strings.NewReader(yamlContent)) + ruleset, err := LoadFromReader("test/path", reader) + + require.NoError(t, err, "LoadFromReader should succeed with valid YAML") + assert.NotNil(t, ruleset, "Should return a non-nil RuleSet") + + // Verify basic properties + assert.Equal(t, "test/path", ruleset.Path, "Path should be set correctly") + assert.True(t, ruleset.Terminal, "Terminal flag should be parsed correctly") + + // Should have the 2 explicit rules plus 1 default rule added by setDefaults + assert.Len(t, ruleset.Rules, 3, "Should have 2 explicit rules + 1 default rule") + + // Verify the first rule + rule1 := ruleset.Rules[0] + assert.Equal(t, "*.txt", rule1.Pattern, "First rule pattern should be correct") + assert.True(t, rule1.Access.Read.Contains("*"), "First rule should grant read access to everyone") + + // Verify the second rule + rule2 := ruleset.Rules[1] + assert.Equal(t, "private/*", rule2.Pattern, "Second rule pattern should be correct") + assert.True(t, rule2.Access.Read.Contains("admin"), "Second rule should grant read access to admin") + assert.True(t, rule2.Access.Write.Contains("admin"), "Second rule should grant write access to admin") +} + +func TestLoadFromReaderWithMinimalYAML(t *testing.T) { + // Test loading with minimal YAML (only terminal flag) + // This validates default rule injection works correctly + yamlContent := `terminal: false` + + reader := io.NopCloser(strings.NewReader(yamlContent)) + ruleset, err := LoadFromReader("test", reader) + + require.NoError(t, err, "LoadFromReader should succeed with minimal YAML") + assert.False(t, ruleset.Terminal, "Terminal flag should be parsed correctly") + + // Should have exactly 1 default rule added by setDefaults + assert.Len(t, ruleset.Rules, 1, "Should have 1 default rule") + assert.Equal(t, "**", ruleset.Rules[0].Pattern, "Default rule should have AllFiles pattern") +} + +func TestLoadFromReaderWithEmptyYAML(t *testing.T) { + // Test loading with completely empty YAML + // This validates the system handles empty files gracefully + yamlContent := `` + + reader := io.NopCloser(strings.NewReader(yamlContent)) + ruleset, err := LoadFromReader("test", reader) + + require.NoError(t, err, "LoadFromReader should succeed with empty YAML") + assert.False(t, ruleset.Terminal, "Terminal should default to false") + + // Should have exactly 1 default rule added by setDefaults + assert.Len(t, ruleset.Rules, 1, "Should have 1 default rule") + assert.Equal(t, "**", ruleset.Rules[0].Pattern, "Default rule should have AllFiles pattern") +} + +func TestLoadFromReaderWithInvalidYAML(t *testing.T) { + // Test that invalid YAML is properly rejected + // This ensures the system fails safely on malformed input + yamlContent := ` +invalid: yaml: content: + - missing + proper: structure +` + + reader := io.NopCloser(strings.NewReader(yamlContent)) + ruleset, err := LoadFromReader("test", reader) + + assert.Error(t, err, "LoadFromReader should fail with invalid YAML") + assert.Nil(t, ruleset, "Should return nil RuleSet on error") +} + +func TestLoadFromFile(t *testing.T) { + // Test loading RuleSet from file on disk + // This validates file I/O integration + tempDir := t.TempDir() + aclFile := filepath.Join(tempDir, AclFileName) + + yamlContent := ` +terminal: true +rules: + - pattern: "*.go" + access: + read: ["developers@company.com"] +` + + err := os.WriteFile(aclFile, []byte(yamlContent), 0644) + require.NoError(t, err, "Should be able to write test file") + + ruleset, err := LoadFromFile(tempDir) + require.NoError(t, err, "LoadFromFile should succeed") + + assert.Equal(t, tempDir, ruleset.Path, "Path should be directory (without ACL filename)") + assert.True(t, ruleset.Terminal, "Terminal flag should be loaded correctly") + assert.Len(t, ruleset.Rules, 2, "Should have 1 explicit rule + 1 default rule") +} + +func TestLoadFromFileNonExistent(t *testing.T) { + // Test loading from non-existent file + // This validates error handling for missing files + nonExistentPath := "/path/that/does/not/exist" + + ruleset, err := LoadFromFile(nonExistentPath) + assert.Error(t, err, "LoadFromFile should fail for non-existent file") + assert.Nil(t, ruleset, "Should return nil RuleSet on error") +} + +func TestRuleSetSave(t *testing.T) { + // Test saving RuleSet to file + // This validates YAML serialization and file I/O + tempDir := t.TempDir() + + // Create a RuleSet with test data + rule1 := NewRule("*.txt", PublicReadAccess(), DefaultLimits()) + rule2 := NewRule("secret/*", PrivateAccess(), &Limits{MaxFileSize: 1024}) + ruleset := NewRuleSet(tempDir, true, rule1, rule2) + + err := ruleset.Save() + require.NoError(t, err, "Save should succeed") + + // Verify the file was created + aclFile := filepath.Join(tempDir, AclFileName) + assert.FileExists(t, aclFile, "ACL file should be created") + + // Load the file back and verify content + content, err := os.ReadFile(aclFile) + require.NoError(t, err, "Should be able to read saved file") + + var loaded RuleSet + err = yaml.Unmarshal(content, &loaded) + require.NoError(t, err, "Saved YAML should be valid") + + assert.True(t, loaded.Terminal, "Terminal flag should be preserved") + assert.Len(t, loaded.Rules, 2, "All rules should be saved") +} + +func TestRuleSetSaveInvalidPath(t *testing.T) { + // Test saving to invalid path + // This validates error handling for file I/O failures + ruleset := NewRuleSet("/invalid/path/that/cannot/be/created", false) + + err := ruleset.Save() + assert.Error(t, err, "Save should fail for invalid path") +} + +func TestSetDefaults(t *testing.T) { + // Test the setDefaults function directly + // This validates the default rule injection logic + + // Test with nil rules + ruleset := &RuleSet{Path: "test", Terminal: false, Rules: nil} + result, err := setDefaults(ruleset) + + require.NoError(t, err, "setDefaults should succeed with nil rules") + assert.Len(t, result.Rules, 1, "Should add exactly one default rule") + assert.Equal(t, "**", result.Rules[0].Pattern, "Default rule should have AllFiles pattern") +} + +func TestSetDefaultsWithExistingDefaultRule(t *testing.T) { + // Test setDefaults when a default rule already exists + // This ensures default rules aren't duplicated + existingDefault := NewRule("**", PrivateAccess(), DefaultLimits()) + customRule := NewRule("*.txt", PublicReadAccess(), DefaultLimits()) + + ruleset := &RuleSet{ + Path: "test", + Terminal: false, + Rules: []*Rule{customRule, existingDefault}, + } + + result, err := setDefaults(ruleset) + + require.NoError(t, err, "setDefaults should succeed with existing default rule") + assert.Len(t, result.Rules, 2, "Should not add additional default rule") + + // Verify the existing default rule is preserved + hasDefault := false + for _, rule := range result.Rules { + if rule.Pattern == "**" { + hasDefault = true + break + } + } + assert.True(t, hasDefault, "Should preserve existing default rule") +} + +func TestSetDefaultsValidation(t *testing.T) { + // Test setDefaults validation of rule requirements + // This ensures invalid rules are properly rejected + + // Test with empty pattern + invalidRule := NewRule("", PrivateAccess(), DefaultLimits()) + ruleset := &RuleSet{ + Path: "test", + Terminal: false, + Rules: []*Rule{invalidRule}, + } + + result, err := setDefaults(ruleset) + assert.Error(t, err, "setDefaults should reject rules with empty patterns") + assert.Nil(t, result, "Should return nil on validation error") + + // Test with nil access + invalidRule2 := NewRule("*.txt", nil, DefaultLimits()) + ruleset2 := &RuleSet{ + Path: "test", + Terminal: false, + Rules: []*Rule{invalidRule2}, + } + + result, err = setDefaults(ruleset2) + assert.Error(t, err, "setDefaults should reject rules with nil access") + assert.Nil(t, result, "Should return nil on validation error") +} + +func TestSetDefaultsLimitsInjection(t *testing.T) { + // Test that setDefaults adds default limits to rules missing them + // This ensures all rules have proper limits configuration + ruleWithoutLimits := NewRule("*.txt", PrivateAccess(), nil) + + ruleset := &RuleSet{ + Path: "test", + Terminal: false, + Rules: []*Rule{ruleWithoutLimits}, + } + + result, err := setDefaults(ruleset) + + require.NoError(t, err, "setDefaults should succeed and inject limits") + assert.NotNil(t, result.Rules[0].Limits, "Rule should have limits after setDefaults") + + // Verify the limits are the default limits + expectedLimits := DefaultLimits() + assert.Equal(t, expectedLimits.MaxFiles, result.Rules[0].Limits.MaxFiles) + assert.Equal(t, expectedLimits.MaxFileSize, result.Rules[0].Limits.MaxFileSize) + assert.Equal(t, expectedLimits.AllowDirs, result.Rules[0].Limits.AllowDirs) + assert.Equal(t, expectedLimits.AllowSymlinks, result.Rules[0].Limits.AllowSymlinks) +} + +func TestRuleSetRoundTrip(t *testing.T) { + // Test complete round-trip: create -> save -> load -> verify + // This validates the entire serialization/deserialization pipeline + tempDir := t.TempDir() + + // Create original RuleSet + original := NewRuleSet(tempDir, true, + NewRule("*.go", SharedReadAccess("dev1@company.com", "dev2@university.edu"), DefaultLimits()), + NewRule("docs/*", PublicReadAccess(), DefaultLimits()), + ) + + // Save to file + err := original.Save() + require.NoError(t, err, "Save should succeed") + + // Load from file + loaded, err := LoadFromFile(tempDir) + require.NoError(t, err, "Load should succeed") + + // Verify key properties are preserved + assert.Equal(t, original.Path, loaded.Path, "Path should be preserved") + assert.Equal(t, original.Terminal, loaded.Terminal, "Terminal flag should be preserved") + + // Note: loaded will have additional default rule, so we check the original rules exist + assert.True(t, len(loaded.Rules) >= len(original.Rules), "Loaded rules should include all original rules") + + // Verify specific rules exist (order might differ due to default rule injection) + hasGoRule := false + hasDocsRule := false + for _, rule := range loaded.Rules { + if rule.Pattern == "*.go" { + hasGoRule = true + assert.True(t, rule.Access.Read.Contains("dev1@company.com"), "Go rule should preserve dev1 access") + assert.True(t, rule.Access.Read.Contains("dev2@university.edu"), "Go rule should preserve dev2 access") + } + if rule.Pattern == "docs/*" { + hasDocsRule = true + assert.True(t, rule.Access.Read.Contains("*"), "Docs rule should preserve public access") + // Note: Limits are not serialized to YAML (yaml:"-" tag), so they get default values after round-trip + assert.NotNil(t, rule.Limits, "Docs rule should have default limits after round-trip") + } + } + assert.True(t, hasGoRule, "Should preserve *.go rule") + assert.True(t, hasDocsRule, "Should preserve docs/* rule") +} diff --git a/internal/client/config/config.go b/internal/client/config/config.go index 53709b58..46ccd313 100644 --- a/internal/client/config/config.go +++ b/internal/client/config/config.go @@ -17,7 +17,7 @@ var ( home, _ = os.UserHomeDir() DefaultConfigPath = filepath.Join(home, ".syftbox", "config.json") DefaultDataDir = filepath.Join(home, "SyftBox") - DefaultServerURL = "https://syftboxdev.openmined.org" + DefaultServerURL = "https://syftbox.net" DefaultClientURL = "http://localhost:7938" DefaultLogFilePath = filepath.Join(home, ".syftbox", "logs", "syftbox.log") DefaultAppsEnabled = true diff --git a/internal/client/middleware/cors.go b/internal/client/middleware/cors.go index fe51011c..ac1ed2be 100644 --- a/internal/client/middleware/cors.go +++ b/internal/client/middleware/cors.go @@ -9,9 +9,14 @@ import ( // control plane cors config var corsConfig = cors.Config{ - AllowAllOrigins: true, - AllowMethods: []string{"GET", "POST", "PUT", "PATCH", "DELETE", "HEAD"}, - AllowHeaders: []string{"Origin", "Content-Length", "Content-Type", "Authorization"}, + AllowAllOrigins: true, + AllowMethods: []string{"GET", "POST", "PUT", "PATCH", "DELETE", "HEAD"}, + AllowHeaders: []string{ + "Origin", + "Content-Length", + "Content-Type", + "Authorization", + }, AllowCredentials: true, MaxAge: 12 * time.Hour, } diff --git a/internal/client/sync/sync_engine.go b/internal/client/sync/sync_engine.go index 95293e45..dfae471b 100644 --- a/internal/client/sync/sync_engine.go +++ b/internal/client/sync/sync_engine.go @@ -2,6 +2,7 @@ package sync import ( "context" + "encoding/json" "errors" "fmt" "log/slog" @@ -470,6 +471,8 @@ func (se *SyncEngine) handleSocketEvents(ctx context.Context) { go se.handlePriorityError(msg) case syftmsg.MsgFileWrite: go se.handlePriorityDownload(msg) + case syftmsg.MsgHttp: + go se.processHttpMessage(msg) default: slog.Debug("websocket unhandled type", "type", msg.Type) } @@ -497,6 +500,64 @@ func (se *SyncEngine) handleWatcherEvents(ctx context.Context) { } } +func (se *SyncEngine) processHttpMessage(msg *syftmsg.Message) { + slog.Debug("processHttpMessage", "msgType", msg.Type, "msgId", msg.Id) + httpMsg, ok := msg.Data.(*syftmsg.HttpMsg) + if !ok { + slog.Error("processHttpMessage: invalid type assertion for msg.Data", + "msgType", msg.Type, + "msgId", msg.Id, + "dataType", fmt.Sprintf("%T", msg.Data)) + return + } + + slog.Debug("handle", "msgType", msg.Type, "msgId", msg.Id, "httpMsg", httpMsg) + + // Unwrap the into a syftmsg.SyftRPCMessage + syftRPCMsg, err := syftmsg.NewSyftRPCMessage(*httpMsg) + if err != nil { + slog.Error("processHttpMessage: failed to create syftRPCMsg", + "error", err, + "msgType", msg.Type, + "msgId", msg.Id, + "httpMsg", httpMsg) + return + } + + // rpc message file name + fileName := syftRPCMsg.ID.String() + "." + string(httpMsg.Type) + + filePath := filepath.Join( + se.workspace.DatasiteAbsPath(syftRPCMsg.URL.ToLocalPath()), // app_data/{app_name}/rpc/{endpoint} + fileName, + ) + + slog.Debug("Received RPC message", "RPCPath", filePath) + + // Convert the syftRPCMsg to json + jsonRPCMsg, err := json.Marshal(syftRPCMsg) + if err != nil { + slog.Error("handleHttp marshal syftRPCMsg", + "error", err, + "msgId", syftRPCMsg.ID, + "filePath", filePath) + return + } + + // write the RPCMsg to the file + err = os.WriteFile(filePath, jsonRPCMsg, 0644) + if err != nil { + slog.Error("handleHttp write file", + "error", err, + "filePath", filePath, + "msgId", syftRPCMsg.ID, + "fileSize", len(jsonRPCMsg)) + return + } + + slog.Debug("SyftRPC Message", "msg", string(jsonRPCMsg), "filePath", filePath) +} + func (se *SyncEngine) handleSystem(msg *syftmsg.Message) { systemMsg := msg.Data.(syftmsg.System) slog.Info("handle", "msgType", msg.Type, "msgId", msg.Id, "serverVersion", systemMsg.SystemVersion) diff --git a/internal/client/sync/sync_engine_test.go b/internal/client/sync/sync_engine_test.go deleted file mode 100644 index cae32326..00000000 --- a/internal/client/sync/sync_engine_test.go +++ /dev/null @@ -1,81 +0,0 @@ -package sync - -// func TestSyncEngineFullSync(t *testing.T) { -// dummyDatasite, err := datasite.NewLocalDatasite("~/SyftBox", "yash@openmined.org") -// assert.NoError(t, err) - -// sdk, err := syftsdk.New("https://syftboxdev.openmined.org") -// sdk.Login("yash@openmined.org") -// assert.NoError(t, err) - -// ignore := NewSyncIgnore(dummyDatasite.DatasitesDir, dummyDatasite.DatasitesDir) -// watcher := NewFileWatcher(dummyDatasite.DatasitesDir) - -// syncEngine := NewSyncEngine(dummyDatasite, sdk, ignore, watcher) -// err = syncEngine.RunSync(context.Background()) -// assert.NoError(t, err) -// } - -// func TestSyncEngineReconcile(t *testing.T) { -// syncEngine := &SyncEngine{} - -// journal := map[string]*FileMetadata{ -// "/test/file1": { -// Path: "/test/file1", -// ETag: "123", -// Version: "1", -// Size: 100, -// LastModified: time.Now(), -// }, -// "/test/file4": { -// Path: "/test/file4", -// ETag: "sadlajklsd", -// Version: "1", -// Size: 1012310, -// LastModified: time.Now(), -// }, -// } - -// localState := map[string]*FileMetadata{ -// "/test/file1": { -// Path: "/test/file1", -// ETag: "123", -// Version: "1", -// Size: 100, -// LastModified: time.Now(), -// }, -// "/test/file3": { -// Path: "/test/file3", -// ETag: "ashldk", -// Version: "1", -// Size: 10, -// LastModified: time.Now(), -// }, -// } - -// remoteState := map[string]*FileMetadata{ -// "/test/file1": { -// Path: "/test/file1", -// ETag: "defg", -// Version: "2", // new version -// Size: 100, -// LastModified: time.Now(), -// }, -// "/test/file4": { -// Path: "/test/file4", -// ETag: "sadlajklsd", -// Version: "1", -// Size: 1012310, -// LastModified: time.Now(), -// }, -// } - -// // this should download file1, delete file4, and write file3 -// result := syncEngine.reconcile(localState, remoteState, journal) -// assert.Equal(t, 1, len(result.RemoteWrites)) -// assert.Equal(t, 1, len(result.LocalWrites)) -// assert.Equal(t, 1, len(result.RemoteDeletes)) -// assert.Equal(t, "/test/file3", result.RemoteWrites["/test/file3"].Path) -// assert.Equal(t, "/test/file1", result.LocalWrites["/test/file1"].Path) -// assert.Equal(t, "/test/file4", result.RemoteDeletes["/test/file4"].Path) -// } diff --git a/internal/client/sync/sync_ignore.go b/internal/client/sync/sync_ignore.go index 898a815d..b2555068 100644 --- a/internal/client/sync/sync_ignore.go +++ b/internal/client/sync/sync_ignore.go @@ -1,6 +1,14 @@ package sync import ( + "bufio" + "fmt" + "log/slog" + "os" + "path/filepath" + "strings" + + "github.com/openmined/syftbox/internal/utils" gitignore "github.com/sabhiram/go-gitignore" ) @@ -37,11 +45,60 @@ type SyncIgnoreList struct { } func NewSyncIgnoreList(baseDir string) *SyncIgnoreList { - ignore := gitignore.CompileIgnoreLines(defaultIgnoreLines...) - return &SyncIgnoreList{baseDir: baseDir, ignore: ignore} + return &SyncIgnoreList{baseDir: baseDir} +} + +func (s *SyncIgnoreList) Load() { + ignorePath := filepath.Join(s.baseDir, "syftignore") + ignoreLines := defaultIgnoreLines + + // read the syftignore file if it exists + if utils.FileExists(ignorePath) { + customRules, err := readIgnoreFile(ignorePath) + if err != nil { + slog.Warn("failed to read syftignore file", "path", ignorePath, "error", err) + } else if len(customRules) > 0 { + ignoreLines = append(ignoreLines, customRules...) + slog.Info("loaded syftignore file", "path", ignorePath, "rules", len(customRules)) + } + } + + s.ignore = gitignore.CompileIgnoreLines(ignoreLines...) } func (s *SyncIgnoreList) ShouldIgnore(path string) bool { - // todo strip baseDir from relPath return s.ignore.MatchesPath(path) } + +func readIgnoreFile(path string) ([]string, error) { + if path == "" { + return nil, fmt.Errorf("ignore file path is empty") + } + + file, err := os.Open(path) + if err != nil { + return nil, fmt.Errorf("failed to open ignore file: %w", err) + } + defer file.Close() + + ignoreLines := []string{} + scanner := bufio.NewScanner(file) + + for scanner.Scan() { + line := scanner.Text() + line = strings.TrimSpace(line) + + // comments, empty lines, and null bytes + if strings.HasPrefix(line, "#") || line == "" || strings.Contains(line, "\x00") { + continue + } + + ignoreLines = append(ignoreLines, line) + } + + if err := scanner.Err(); err != nil { + return nil, fmt.Errorf("error reading ignore file: %w", err) + } + + return ignoreLines, nil +} diff --git a/internal/client/sync/sync_manager.go b/internal/client/sync/sync_manager.go index 31e5c64e..61782c2e 100644 --- a/internal/client/sync/sync_manager.go +++ b/internal/client/sync/sync_manager.go @@ -39,6 +39,10 @@ func NewManager(workspace *workspace.Workspace, sdk *syftsdk.SyftSDK) (*SyncMana func (m *SyncManager) Start(ctx context.Context) error { slog.Info("sync manager start") + + // load the ignore list + m.ignore.Load() + if err := m.watcher.Start(ctx); err != nil { return fmt.Errorf("failed to start watcher: %w", err) } diff --git a/internal/server/acl/acl.go b/internal/server/acl/acl.go index 921dbeaa..d631bbaf 100644 --- a/internal/server/acl/acl.go +++ b/internal/server/acl/acl.go @@ -1,47 +1,54 @@ package acl import ( + "fmt" + "log/slog" + "strings" + "github.com/openmined/syftbox/internal/aclspec" ) -// AclService helps to manage and enforce access control rules for file system operations. -type AclService struct { - tree *Tree - cache *RuleCache +// ACLService helps to manage and enforce access control rules for file system operations. +type ACLService struct { + tree *ACLTree + cache *ACLCache } -// NewAclService creates a new ACL service instance -func NewAclService() *AclService { - return &AclService{ - tree: NewTree(), - cache: NewRuleCache(), +// NewACLService creates a new ACL service instance +func NewACLService() *ACLService { + return &ACLService{ + tree: NewACLTree(), + cache: NewACLCache(), } } -func (s *AclService) LoadRuleSets(ruleSets []*aclspec.RuleSet) error { - for _, ruleSet := range ruleSets { - if err := s.tree.AddRuleSet(ruleSet); err != nil { - return err - } +// AddRuleSet adds or updates a new set of rules to the service. +func (s *ACLService) AddRuleSet(ruleSet *aclspec.RuleSet) (ACLVersion, error) { + node, err := s.tree.AddRuleSet(ruleSet) + if err != nil { + return 0, err } - return nil -} -// AddRuleSet adds a new set of rules to the service. -func (s *AclService) AddRuleSet(ruleSet *aclspec.RuleSet) error { - return s.tree.AddRuleSet(ruleSet) + deleted := s.cache.DeletePrefix(ruleSet.Path) + slog.Debug("updated rule set", "path", node.path, "version", node.version, "cache.deleted", deleted) + return node.version, nil } // RemoveRuleSet removes a ruleset at the specified path. // Returns true if a ruleset was removed, false otherwise. -func (s *AclService) RemoveRuleSet(path string) bool { - s.cache.DeletePrefix(path) - return s.tree.RemoveRuleSet(path) +// path must be a dir or dir/syft.pub.yaml +func (s *ACLService) RemoveRuleSet(path string) bool { + path = aclspec.WithoutACLPath(path) + if ok := s.tree.RemoveRuleSet(path); ok { + deleted := s.cache.DeletePrefix(path) + slog.Debug("deleted cached rules", "path", path, "count", deleted) + return true + } + return false } // GetRule finds the most specific rule applicable to the given path. -func (s *AclService) GetRule(path string) (*Rule, error) { - // Normalize path to use forward slashes for glob matching +func (s *ACLService) GetRule(path string) (*ACLRule, error) { path = ACLNormPath(path) // cache hit @@ -51,9 +58,9 @@ func (s *AclService) GetRule(path string) (*Rule, error) { } // cache miss - rule, err := s.tree.GetRule(path) // O(depth) + rule, err := s.tree.GetEffectiveRule(path) // O(depth) if err != nil { - return nil, err + return nil, fmt.Errorf("no effective rules for path '%s': %w", path, err) } // cache the result @@ -63,36 +70,45 @@ func (s *AclService) GetRule(path string) (*Rule, error) { } // CanAccess checks if a user has the specified access permission for a file. -func (s *AclService) CanAccess(user *User, file *File, level AccessLevel) error { - if user.IsOwner { +func (s *ACLService) CanAccess(user *User, file *File, level AccessLevel) error { + // early return if user is the owner + if isOwner(file.Path, user.ID) { return nil } + // get the effective rule for the file rule, err := s.GetRule(file.Path) if err != nil { return err } - isAcl := aclspec.IsAclFile(file.Path) + // Elevate ACL file writes to admin level + if aclspec.IsACLFile(file.Path) && level >= AccessCreate { + level = AccessAdmin + } - // elevate action for ACL files - if isAcl && level == AccessWrite { - level = AccessWriteACL - } else if level == AccessWrite { - // writes need to be checked against the file limits + // Check file limits for write operations + if level >= AccessCreate { if err := rule.CheckLimits(file); err != nil { - return err + return fmt.Errorf("file limits exceeded for user '%s' on path '%s': %w", user.ID, file.Path, err) } } + // finally check the access if err := rule.CheckAccess(user, level); err != nil { - return err + return fmt.Errorf("access denied for user '%s' on path '%s': %w", user.ID, file.Path, err) } return nil } // String returns a string representation of the ACL service's rule tree. -func (s *AclService) String() string { +func (s *ACLService) String() string { return s.tree.String() } + +// checks if the user is the owner of the path +func isOwner(path string, user string) bool { + path = ACLNormPath(path) + return strings.HasPrefix(path, user) +} diff --git a/internal/server/acl/acl_test.go b/internal/server/acl/acl_test.go index d46a1fc9..6bf74227 100644 --- a/internal/server/acl/acl_test.go +++ b/internal/server/acl/acl_test.go @@ -8,7 +8,7 @@ import ( ) func TestAclServiceGetRule(t *testing.T) { - service := NewAclService() + service := NewACLService() // Add a ruleset ruleset := aclspec.NewRuleSet( @@ -18,8 +18,9 @@ func TestAclServiceGetRule(t *testing.T) { aclspec.NewRule("*.txt", aclspec.PrivateAccess(), aclspec.DefaultLimits()), ) - err := service.AddRuleSet(ruleset) + ver, err := service.AddRuleSet(ruleset) assert.NoError(t, err) + assert.Equal(t, ACLVersion(1), ver) // Test cache miss rules assert.NotContains(t, service.cache.index, "user/readme.md") @@ -42,47 +43,49 @@ func TestAclServiceGetRule(t *testing.T) { } func TestAclServiceRemoveRuleSet(t *testing.T) { - service := NewAclService() + service := NewACLService() // Add two rulesets ruleset1 := aclspec.NewRuleSet( - "folder1", + "user1@email.com", aclspec.SetTerminal, aclspec.NewRule("*.txt", aclspec.PublicReadAccess(), aclspec.DefaultLimits()), ) ruleset2 := aclspec.NewRuleSet( - "folder2", + "user2@email.com", aclspec.SetTerminal, aclspec.NewRule("*.txt", aclspec.PrivateAccess(), aclspec.DefaultLimits()), ) - err := service.AddRuleSet(ruleset1) + ver, err := service.AddRuleSet(ruleset1) assert.NoError(t, err) + assert.Equal(t, ACLVersion(1), ver) - err = service.AddRuleSet(ruleset2) + ver, err = service.AddRuleSet(ruleset2) assert.NoError(t, err) + assert.Equal(t, ACLVersion(1), ver) // Verify both rulesets work - rule, err := service.GetRule("folder1/file.txt") + rule, err := service.GetRule("user1@email.com/file.txt") assert.NoError(t, err) assert.NotNil(t, rule) - rule, err = service.GetRule("folder2/file.txt") + rule, err = service.GetRule("user2@email.com/file.txt") assert.NoError(t, err) assert.NotNil(t, rule) // Remove one ruleset - removed := service.RemoveRuleSet("folder1") + removed := service.RemoveRuleSet("user1@email.com") assert.True(t, removed) // Verify removed ruleset no longer works - rule, err = service.GetRule("folder1/file.txt") + rule, err = service.GetRule("user1@email.com/file.txt") assert.Error(t, err) assert.Nil(t, rule) // Verify other ruleset still works - rule, err = service.GetRule("folder2/file.txt") + rule, err = service.GetRule("user2@email.com/file.txt") assert.NoError(t, err) assert.NotNil(t, rule) @@ -92,27 +95,28 @@ func TestAclServiceRemoveRuleSet(t *testing.T) { } func TestAclServiceCanAccess(t *testing.T) { - service := NewAclService() + service := NewACLService() // Add a ruleset with different access levels ruleset := aclspec.NewRuleSet( - "user", + "user1@email.com", aclspec.SetTerminal, aclspec.NewRule("public/*.txt", aclspec.PublicReadAccess(), aclspec.DefaultLimits()), aclspec.NewRule("private/*.txt", aclspec.PrivateAccess(), aclspec.DefaultLimits()), aclspec.NewRule("**", aclspec.PrivateAccess(), aclspec.DefaultLimits()), ) - err := service.AddRuleSet(ruleset) + ver, err := service.AddRuleSet(ruleset) assert.NoError(t, err) + assert.Equal(t, ACLVersion(1), ver) // Test cases with different users and files - owner := &User{ID: "user", IsOwner: true} - regularUser := &User{ID: aclspec.Everyone, IsOwner: false} + owner := &User{ID: "user1@email.com"} + regularUser := &User{ID: aclspec.Everyone} - publicFile := &File{Path: "user/public/doc.txt", Size: 100} - privateFile := &File{Path: "user/private/secret.txt", Size: 100} - aclFile := &File{Path: aclspec.AsAclPath("user"), Size: 100} + publicFile := &File{Path: "user1@email.com/public/doc.txt", Size: 100} + privateFile := &File{Path: "user1@email.com/private/secret.txt", Size: 100} + aclFile := &File{Path: aclspec.AsACLPath("user1@email.com"), Size: 100} // Owner should have access to everything err = service.CanAccess(owner, publicFile, AccessRead) @@ -129,119 +133,126 @@ func TestAclServiceCanAccess(t *testing.T) { assert.NoError(t, err) err = service.CanAccess(regularUser, publicFile, AccessWrite) - assert.ErrorIs(t, err, ErrWriteRequired) + assert.ErrorIs(t, err, ErrNoWriteAccess) err = service.CanAccess(regularUser, privateFile, AccessRead) - assert.ErrorIs(t, err, ErrReadRequired) + assert.ErrorIs(t, err, ErrNoReadAccess) // ACL files should have special handling err = service.CanAccess(regularUser, aclFile, AccessWrite) - assert.ErrorIs(t, err, ErrAdminRequired) + assert.ErrorIs(t, err, ErrNoAdminAccess) } func TestAclServiceFileLimits(t *testing.T) { - service := NewAclService() + service := NewACLService() - // Add a ruleset with file size limits - limits := aclspec.Limits{ - MaxFileSize: 100, - AllowDirs: true, - } + owner := "user1@email.com" + someUser := "user2@email.com" ruleset := aclspec.NewRuleSet( - "files", + owner, aclspec.SetTerminal, - aclspec.NewRule("small/*.txt", aclspec.PublicReadWriteAccess(), &limits), + aclspec.NewRule( + "dir/*.txt", + aclspec.PublicReadWriteAccess(), + &aclspec.Limits{MaxFileSize: 100, AllowDirs: true}, + ), ) - err := service.AddRuleSet(ruleset) + ver, err := service.AddRuleSet(ruleset) assert.NoError(t, err) - - regularUser := &User{IsOwner: false} + assert.Equal(t, ACLVersion(1), ver) // File within size limit - smallFile := &File{Path: "files/small/small.txt", Size: 50} - err = service.CanAccess(regularUser, smallFile, AccessWrite) + smallFile := &File{Path: "user1@email.com/dir/small.txt", Size: 50} + err = service.CanAccess(&User{ID: someUser}, smallFile, AccessWrite) assert.NoError(t, err) // File exceeding size limit - largeFile := &File{Path: "files/small/large.txt", Size: 200} - err = service.CanAccess(regularUser, largeFile, AccessWrite) + largeFile := &File{Path: "user1@email.com/dir/large.txt", Size: 200} + err = service.CanAccess(&User{ID: someUser}, largeFile, AccessWrite) assert.ErrorIs(t, err, ErrFileSizeExceeded) // Owner should bypass size limits - owner := &User{IsOwner: true} - err = service.CanAccess(owner, largeFile, AccessWrite) + err = service.CanAccess(&User{ID: owner}, largeFile, AccessWrite) assert.NoError(t, err) } func TestAclServiceLoadRuleSets(t *testing.T) { - service := NewAclService() + service := NewACLService() // Create multiple rulesets ruleset1 := aclspec.NewRuleSet( - "folder1", + "user1@email.com", aclspec.SetTerminal, aclspec.NewRule("*.txt", aclspec.PublicReadAccess(), aclspec.DefaultLimits()), ) ruleset2 := aclspec.NewRuleSet( - "folder2", + "user2@email.com", aclspec.SetTerminal, aclspec.NewRule("*.md", aclspec.PrivateAccess(), aclspec.DefaultLimits()), ) // Load multiple rulesets at once - err := service.LoadRuleSets([]*aclspec.RuleSet{ruleset1, ruleset2}) + ver, err := service.AddRuleSet(ruleset1) + assert.NoError(t, err) + assert.Equal(t, ACLVersion(1), ver) + + ver, err = service.AddRuleSet(ruleset2) + assert.NoError(t, err) + assert.Equal(t, ACLVersion(1), ver) assert.NoError(t, err) // Verify both rulesets work - rule, err := service.GetRule("folder1/file.txt") + rule, err := service.GetRule("user1@email.com/file.txt") assert.NoError(t, err) assert.NotNil(t, rule) assert.Equal(t, "*.txt", rule.rule.Pattern) - rule, err = service.GetRule("folder2/file.md") + rule, err = service.GetRule("user2@email.com/file.md") assert.NoError(t, err) assert.NotNil(t, rule) assert.Equal(t, "*.md", rule.rule.Pattern) } func TestAclServiceCacheInvalidation(t *testing.T) { - service := NewAclService() + service := NewACLService() // Add a ruleset rulesetv1 := aclspec.NewRuleSet( - "user", + "user1@email.com", aclspec.UnsetTerminal, aclspec.NewRule("*.md", aclspec.PublicReadAccess(), aclspec.DefaultLimits()), ) - err := service.AddRuleSet(rulesetv1) + ver, err := service.AddRuleSet(rulesetv1) assert.NoError(t, err) + assert.Equal(t, ACLVersion(1), ver) // Access a path to cache the rule - rule, err := service.GetRule("user/readme.md") + rule, err := service.GetRule("user1@email.com/readme.md") assert.NoError(t, err) assert.NotNil(t, rule) assert.Equal(t, "*.md", rule.rule.Pattern) - assert.Contains(t, service.cache.index, "user/readme.md") + assert.Contains(t, service.cache.index, "user1@email.com/readme.md") // Replace the ruleset with different permissions rulesetv2 := aclspec.NewRuleSet( - "user", + "user1@email.com", aclspec.SetTerminal, aclspec.NewRule("*.md", aclspec.PrivateAccess(), aclspec.DefaultLimits()), ) // Add new ruleset - err = service.AddRuleSet(rulesetv2) + ver, err = service.AddRuleSet(rulesetv2) assert.NoError(t, err) + assert.Equal(t, ACLVersion(2), ver) // Access the same path, should get the new rule - rule, err = service.GetRule("user/readme.md") + rule, err = service.GetRule("user1@email.com/readme.md") assert.NoError(t, err) assert.NotNil(t, rule) - assert.True(t, rule.node.IsTerminal()) - assert.Equal(t, rule.node.Version(), uint8(2)) + assert.True(t, rule.node.GetTerminal()) + assert.Equal(t, rule.node.GetVersion(), ACLVersion(2)) } diff --git a/internal/server/acl/cache.go b/internal/server/acl/cache.go index ec906bbf..32a88e50 100644 --- a/internal/server/acl/cache.go +++ b/internal/server/acl/cache.go @@ -5,65 +5,62 @@ import ( "sync" ) -type cacheEntry struct { - rule *Rule - version uint8 -} - -type RuleCache struct { - index map[string]*cacheEntry +// ACLCache stores the effective ACL rule for a given path. +type ACLCache struct { + index map[string]*ACLRule // Normalized ACLPath -> ACLRule mu sync.RWMutex } -func NewRuleCache() *RuleCache { - return &RuleCache{ - index: make(map[string]*cacheEntry), +// NewACLCache creates a new ACLCache. +func NewACLCache() *ACLCache { + return &ACLCache{ + index: make(map[string]*ACLRule), } } -func (c *RuleCache) Get(path string) *Rule { +// Get returns the effective ACL rule for the given path. +func (c *ACLCache) Get(path string) *ACLRule { c.mu.RLock() - cached, ok := c.index[path] + cacheRule, ok := c.index[path] c.mu.RUnlock() - if !ok { - return nil - } - // validate the cache entry - valid := cached.rule.node.Version() == cached.version - if !valid { - c.Delete(path) + if !ok { return nil } - return cached.rule + return cacheRule } -func (c *RuleCache) Set(path string, rule *Rule) { +// Set sets the effective ACL rule for the given path. +func (c *ACLCache) Set(path string, rule *ACLRule) { c.mu.Lock() defer c.mu.Unlock() - c.index[path] = &cacheEntry{ - rule: rule, - version: rule.node.Version(), - } + c.index[path] = rule } -func (c *RuleCache) Delete(path string) { +// Delete deletes the effective ACL rule for the given path. +func (c *ACLCache) Delete(path string) { c.mu.Lock() defer c.mu.Unlock() delete(c.index, path) } -func (c *RuleCache) DeletePrefix(path string) { +// DeletePrefix deletes the effective ACL rule for all paths that match the given prefix. +func (c *ACLCache) DeletePrefix(path string) int { c.mu.Lock() defer c.mu.Unlock() + deleted := 0 + // iterate over index keys and delete the entry for k := range c.index { if strings.HasPrefix(k, path) { delete(c.index, k) + deleted++ } } + + return deleted } diff --git a/internal/server/acl/cache_test.go b/internal/server/acl/cache_test.go new file mode 100644 index 00000000..5ff55ba7 --- /dev/null +++ b/internal/server/acl/cache_test.go @@ -0,0 +1,347 @@ +package acl + +import ( + "fmt" + "sync" + "testing" + + "github.com/openmined/syftbox/internal/aclspec" + "github.com/stretchr/testify/assert" +) + +func TestNewACLCache(t *testing.T) { + // Test creating a new cache + // This validates the constructor initializes the cache correctly + cache := NewACLCache() + + assert.NotNil(t, cache, "NewACLCache should return non-nil cache") + assert.NotNil(t, cache.index, "Cache should have initialized index map") + assert.Empty(t, cache.index, "New cache should start empty") +} + +func TestACLCacheBasicOperations(t *testing.T) { + // Test basic cache operations: Set, Get, Delete + // This validates the core cache functionality + cache := NewACLCache() + + // Create a mock rule for testing + mockNode := NewACLNode("test", "testuser", false, 1) + mockRule := &ACLRule{ + fullPattern: "test/*.txt", + rule: aclspec.NewRule("*.txt", aclspec.PrivateAccess(), aclspec.DefaultLimits()), + node: mockNode, + } + + // Test Get on empty cache + result := cache.Get("test/file.txt") + assert.Nil(t, result, "Get should return nil for non-existent entries") + + // Test Set and Get + cache.Set("test/file.txt", mockRule) + result = cache.Get("test/file.txt") + assert.Equal(t, mockRule, result, "Get should return the cached rule") + + // Test Delete + cache.Delete("test/file.txt") + result = cache.Get("test/file.txt") + assert.Nil(t, result, "Get should return nil after deletion") +} + +func TestACLCacheVersionValidation(t *testing.T) { + // Test that cache validates rule versions to detect stale entries + // This is critical for cache invalidation when rules are updated + cache := NewACLCache() + + // Create a node and rule + mockNode := NewACLNode("test", "testuser", false, 1) + mockRule := &ACLRule{ + fullPattern: "test/*.txt", + rule: aclspec.NewRule("*.txt", aclspec.PrivateAccess(), aclspec.DefaultLimits()), + node: mockNode, + } + + // Cache the rule with current node version + cache.Set("test/file.txt", mockRule) + + // Verify we can retrieve it + result := cache.Get("test/file.txt") + assert.Equal(t, mockRule, result, "Should retrieve cached rule with valid version") + + // Simulate node version change (like when rules are updated) + mockNode.SetRules([]*aclspec.Rule{ + aclspec.NewRule("*.md", aclspec.PublicReadAccess(), aclspec.DefaultLimits()), + }, false) + + // The cache currently doesn't validate versions, so this test documents + // that the cache may return stale entries after node updates + result = cache.Get("test/file.txt") + assert.Equal(t, mockRule, result, "Cache currently returns stale entries (no version validation)") +} + +func TestACLCacheDeletePrefix(t *testing.T) { + // Test the DeletePrefix operation which removes multiple entries + // This is important for bulk cache invalidation when directory rules change + cache := NewACLCache() + + // Create multiple cache entries with related paths + mockNode := NewACLNode("test", "testuser", false, 1) + mockRule := &ACLRule{ + fullPattern: "test/*.txt", + rule: aclspec.NewRule("*.txt", aclspec.PrivateAccess(), aclspec.DefaultLimits()), + node: mockNode, + } + + // Add entries with common prefix + cache.Set("project/src/file1.go", mockRule) + cache.Set("project/src/file2.go", mockRule) + cache.Set("project/docs/readme.md", mockRule) + cache.Set("project/tests/test1.go", mockRule) + cache.Set("other/file.txt", mockRule) // Different prefix + + // Verify all entries exist + assert.NotNil(t, cache.Get("project/src/file1.go")) + assert.NotNil(t, cache.Get("project/src/file2.go")) + assert.NotNil(t, cache.Get("project/docs/readme.md")) + assert.NotNil(t, cache.Get("project/tests/test1.go")) + assert.NotNil(t, cache.Get("other/file.txt")) + + // Delete entries with "project/src" prefix + cache.DeletePrefix("project/src") + + // Verify only the prefixed entries were removed + assert.Nil(t, cache.Get("project/src/file1.go"), "Should remove project/src/file1.go") + assert.Nil(t, cache.Get("project/src/file2.go"), "Should remove project/src/file2.go") + assert.NotNil(t, cache.Get("project/docs/readme.md"), "Should keep project/docs/readme.md") + assert.NotNil(t, cache.Get("project/tests/test1.go"), "Should keep project/tests/test1.go") + assert.NotNil(t, cache.Get("other/file.txt"), "Should keep other/file.txt") + + // Delete broader prefix + cache.DeletePrefix("project") + + // Verify all project entries are removed + assert.Nil(t, cache.Get("project/docs/readme.md"), "Should remove project/docs/readme.md") + assert.Nil(t, cache.Get("project/tests/test1.go"), "Should remove project/tests/test1.go") + assert.NotNil(t, cache.Get("other/file.txt"), "Should keep other/file.txt") +} + +func TestACLCacheDeletePrefixEdgeCases(t *testing.T) { + // Test DeletePrefix with edge cases and boundary conditions + // This ensures robust handling of unusual prefix patterns + cache := NewACLCache() + + mockNode := NewACLNode("test", "testuser", false, 1) + mockRule := &ACLRule{ + fullPattern: "test/*.txt", + rule: aclspec.NewRule("*.txt", aclspec.PrivateAccess(), aclspec.DefaultLimits()), + node: mockNode, + } + + // Add test entries + cache.Set("", mockRule) // Empty path + cache.Set("a", mockRule) // Single character + cache.Set("ab", mockRule) // Two characters + cache.Set("abc", mockRule) // Three characters + cache.Set("abd", mockRule) // Similar prefix + + // Test deleting with empty prefix (should remove all entries that start with empty string, which is all entries) + cache.DeletePrefix("") + assert.Nil(t, cache.Get(""), "Should remove empty path entry") + assert.Nil(t, cache.Get("a"), "Should remove single character entry (empty prefix matches all)") + + // Re-add entries for next test + cache.Set("a", mockRule) + cache.Set("ab", mockRule) + cache.Set("abc", mockRule) + cache.Set("abd", mockRule) + + // Test deleting with single character prefix + cache.DeletePrefix("a") + assert.Nil(t, cache.Get("a"), "Should remove 'a'") + assert.Nil(t, cache.Get("ab"), "Should remove 'ab'") + assert.Nil(t, cache.Get("abc"), "Should remove 'abc'") + assert.Nil(t, cache.Get("abd"), "Should remove 'abd'") + + // Test deleting non-existent prefix (should be safe) + cache.DeletePrefix("nonexistent") + // Should not crash or cause issues +} + +func TestACLCacheConcurrency(t *testing.T) { + // Test that cache operations are thread-safe + // This validates the mutex protection works correctly + cache := NewACLCache() + + mockNode := NewACLNode("test", "testuser", false, 1) + mockRule := &ACLRule{ + fullPattern: "test/*.txt", + rule: aclspec.NewRule("*.txt", aclspec.PrivateAccess(), aclspec.DefaultLimits()), + node: mockNode, + } + + const numGoroutines = 10 + const numOperations = 100 + + // Test concurrent Set operations + var wg sync.WaitGroup + for i := 0; i < numGoroutines; i++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + for j := 0; j < numOperations; j++ { + key := fmt.Sprintf("test/%d/%d.txt", id, j) + cache.Set(key, mockRule) + } + }(i) + } + wg.Wait() + + // Test concurrent Get operations + for i := 0; i < numGoroutines; i++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + for j := 0; j < numOperations; j++ { + key := fmt.Sprintf("test/%d/%d.txt", id, j) + result := cache.Get(key) + assert.NotNil(t, result, "Should find cached entry") + } + }(i) + } + wg.Wait() + + // Test concurrent Delete operations + for i := 0; i < numGoroutines; i++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + for j := 0; j < numOperations; j++ { + key := fmt.Sprintf("test/%d/%d.txt", id, j) + cache.Delete(key) + } + }(i) + } + wg.Wait() + + // Verify all entries were deleted + for i := 0; i < numGoroutines; i++ { + for j := 0; j < numOperations; j++ { + key := fmt.Sprintf("test/%d/%d.txt", i, j) + result := cache.Get(key) + assert.Nil(t, result, "Entry should be deleted") + } + } +} + +func TestACLCacheMixedConcurrentOperations(t *testing.T) { + // Test mixed concurrent operations (Set, Get, Delete, DeletePrefix) + // This validates thread safety under realistic usage patterns + cache := NewACLCache() + + mockNode := NewACLNode("test", "testuser", false, 1) + mockRule := &ACLRule{ + fullPattern: "test/*.txt", + rule: aclspec.NewRule("*.txt", aclspec.PrivateAccess(), aclspec.DefaultLimits()), + node: mockNode, + } + + const numWorkers = 5 + const duration = 100 // Number of operations per worker + + var wg sync.WaitGroup + + // Worker 1: Continuous Set operations + wg.Add(1) + go func() { + defer wg.Done() + for i := 0; i < duration; i++ { + key := fmt.Sprintf("set/worker/%d.txt", i) + cache.Set(key, mockRule) + } + }() + + // Worker 2: Continuous Get operations + wg.Add(1) + go func() { + defer wg.Done() + for i := 0; i < duration; i++ { + key := fmt.Sprintf("set/worker/%d.txt", i%10) // Access recently set items + cache.Get(key) + } + }() + + // Worker 3: Continuous Delete operations + wg.Add(1) + go func() { + defer wg.Done() + for i := 0; i < duration; i++ { + key := fmt.Sprintf("set/worker/%d.txt", i) + cache.Delete(key) + } + }() + + // Worker 4: Periodic DeletePrefix operations + wg.Add(1) + go func() { + defer wg.Done() + for i := 0; i < duration/10; i++ { + prefix := fmt.Sprintf("set/worker") + cache.DeletePrefix(prefix) + } + }() + + // Worker 5: Mixed operations + wg.Add(1) + go func() { + defer wg.Done() + for i := 0; i < duration; i++ { + key := fmt.Sprintf("mixed/%d.txt", i) + cache.Set(key, mockRule) + cache.Get(key) + if i%5 == 0 { + cache.Delete(key) + } + } + }() + + wg.Wait() + + // Test should complete without deadlocks or race conditions + // The exact final state is unpredictable due to concurrency, + // but the operations should all complete successfully +} + +func TestACLCacheMemoryManagement(t *testing.T) { + // Test that cache doesn't leak memory with repeated operations + // This validates proper cleanup of cache entries + cache := NewACLCache() + + mockNode := NewACLNode("test", "testuser", false, 1) + mockRule := &ACLRule{ + fullPattern: "test/*.txt", + rule: aclspec.NewRule("*.txt", aclspec.PrivateAccess(), aclspec.DefaultLimits()), + node: mockNode, + } + + // Add many entries + for i := 0; i < 1000; i++ { + key := fmt.Sprintf("test/%d.txt", i) + cache.Set(key, mockRule) + } + + // Verify entries exist + assert.Equal(t, 1000, len(cache.index), "Should have 1000 entries") + + // Clear using DeletePrefix + cache.DeletePrefix("test") + + // Verify all entries are removed + assert.Equal(t, 0, len(cache.index), "Should have 0 entries after DeletePrefix") + + // Add entries again to test reuse + for i := 0; i < 100; i++ { + key := fmt.Sprintf("reuse/%d.txt", i) + cache.Set(key, mockRule) + } + + assert.Equal(t, 100, len(cache.index), "Should be able to reuse cache after clearing") +} \ No newline at end of file diff --git a/internal/server/acl/level.go b/internal/server/acl/level.go index 8f71cfa4..524ccafe 100644 --- a/internal/server/acl/level.go +++ b/internal/server/acl/level.go @@ -8,23 +8,41 @@ const ( AccessRead AccessLevel = 1 << iota AccessCreate AccessWrite - AccessReadACL - AccessWriteACL + AccessAdmin ) func (a AccessLevel) String() string { - switch a { - case AccessRead: - return "Read" - case AccessCreate: - return "Create" - case AccessWrite: - return "Write" - case AccessReadACL: - return "ReadACL" - case AccessWriteACL: - return "WriteACL" - default: + if a == 0 { + return "None" + } + + var parts []string + + if (a & AccessRead) == AccessRead { + parts = append(parts, "Read") + } + if (a & AccessCreate) == AccessCreate { + parts = append(parts, "Create") + } + if (a & AccessWrite) == AccessWrite { + parts = append(parts, "Write") + } + if (a & AccessAdmin) == AccessAdmin { + parts = append(parts, "Admin") + } + + if len(parts) == 0 { return "Unknown" } + + if len(parts) == 1 { + return parts[0] + } + + // For multiple permissions, join with "+" + result := parts[0] + for i := 1; i < len(parts); i++ { + result += "+" + parts[i] + } + return result } diff --git a/internal/server/acl/level_test.go b/internal/server/acl/level_test.go new file mode 100644 index 00000000..2635543c --- /dev/null +++ b/internal/server/acl/level_test.go @@ -0,0 +1,211 @@ +package acl + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestAccessLevelString(t *testing.T) { + // Test the String() method for all AccessLevel constants + // This validates that each access level has the correct string representation + // which is important for logging, debugging, and user-facing error messages + testCases := []struct { + level AccessLevel + expected string + desc string + }{ + { + level: AccessRead, + expected: "Read", + desc: "AccessRead should return 'Read'", + }, + { + level: AccessCreate, + expected: "Create", + desc: "AccessCreate should return 'Create'", + }, + { + level: AccessWrite, + expected: "Write", + desc: "AccessWrite should return 'Write'", + }, + { + level: AccessAdmin, + expected: "Admin", + desc: "AccessAdmin should return 'Admin'", + }, + { + level: 0, + expected: "None", + desc: "Zero value should return 'None'", + }, + { + level: AccessLevel(16), + expected: "Unknown", + desc: "Undefined values should return 'Unknown'", + }, + { + level: AccessRead | AccessWrite, + expected: "Read+Write", + desc: "Combined permissions should be joined with '+'", + }, + { + level: AccessRead | AccessCreate | AccessWrite, + expected: "Read+Create+Write", + desc: "Multiple combined permissions should be joined in order", + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + result := tc.level.String() + assert.Equal(t, tc.expected, result, tc.desc) + }) + } +} + +func TestAccessLevelStringUnknown(t *testing.T) { + // Test String() method with invalid/unknown access level values + // This ensures the system handles undefined access levels gracefully + unknownLevel := AccessLevel(16) // Invalid access level (higher than any defined bit) + result := unknownLevel.String() + assert.Equal(t, "Unknown", result, "Unknown access levels should return 'Unknown'") +} + +func TestAccessLevelValues(t *testing.T) { + // Test that AccessLevel constants have the expected values + // Using bit flags: 1 << iota creates powers of 2 + + assert.Equal(t, AccessLevel(1), AccessRead, "AccessRead should be 1 (1 << 0)") + assert.Equal(t, AccessLevel(2), AccessCreate, "AccessCreate should be 2 (1 << 1)") + assert.Equal(t, AccessLevel(4), AccessWrite, "AccessWrite should be 4 (1 << 2)") + assert.Equal(t, AccessLevel(8), AccessAdmin, "AccessAdmin should be 8 (1 << 3)") +} + +func TestAccessLevelUniqueness(t *testing.T) { + // Test that all AccessLevel constants are unique + // This prevents accidental duplicate values that could cause permission conflicts + levels := []AccessLevel{ + AccessRead, + AccessCreate, + AccessWrite, + AccessAdmin, + } + + // Check that no two levels have the same value + for i, level1 := range levels { + for j, level2 := range levels { + if i != j { + assert.NotEqual(t, level1, level2, + "AccessLevel constants should be unique: %s and %s have the same value", + level1.String(), level2.String()) + } + } + } +} + +func TestAccessLevelHierarchy(t *testing.T) { + // Test the logical hierarchy of access levels + // This documents the intended permission hierarchy in the system + + // Verify the ordering based on bit values + assert.True(t, AccessRead < AccessCreate, "Read should be lower than Create") + assert.True(t, AccessCreate < AccessWrite, "Create should be lower than Write") + assert.True(t, AccessWrite < AccessAdmin, "Write should be lower than Admin") + assert.True(t, AccessRead < AccessAdmin, "Read should be lower than Admin") +} + +func TestAccessLevelZeroValue(t *testing.T) { + // Test the zero value of AccessLevel + // This ensures the default/uninitialized value behaves correctly + var zeroLevel AccessLevel + + assert.Equal(t, AccessLevel(0), zeroLevel, "Zero value should be 0") + assert.Equal(t, "None", zeroLevel.String(), "Zero value should return 'None'") + + // Zero should not match any defined permission + assert.NotEqual(t, AccessRead, zeroLevel, "Zero should not equal AccessRead") + assert.NotEqual(t, AccessCreate, zeroLevel, "Zero should not equal AccessCreate") + assert.NotEqual(t, AccessWrite, zeroLevel, "Zero should not equal AccessWrite") + assert.NotEqual(t, AccessAdmin, zeroLevel, "Zero should not equal AccessAdmin") +} + +func TestAccessLevelCasting(t *testing.T) { + // Test that AccessLevel can be properly cast from and to uint8 + // This validates the underlying type compatibility + + // Test casting from uint8 + readLevel := AccessLevel(1) + assert.Equal(t, AccessRead, readLevel, "Should be able to cast uint8 to AccessLevel") + + // Test casting to uint8 + readValue := uint8(AccessRead) + assert.Equal(t, uint8(1), readValue, "Should be able to cast AccessLevel to uint8") + + // Test round-trip casting + originalLevel := AccessAdmin + castValue := uint8(originalLevel) + backToLevel := AccessLevel(castValue) + assert.Equal(t, originalLevel, backToLevel, "Round-trip casting should preserve value") +} + +func TestAccessLevelStringConsistency(t *testing.T) { + // Test that String() method is consistent across multiple calls + // This ensures no side effects or state changes in the String() method + level := AccessWrite + + firstCall := level.String() + secondCall := level.String() + thirdCall := level.String() + + assert.Equal(t, firstCall, secondCall, "String() should return consistent results") + assert.Equal(t, secondCall, thirdCall, "String() should return consistent results") + assert.Equal(t, "Write", firstCall, "String() should return correct value") +} + +func TestAccessLevelEdgeCases(t *testing.T) { + // Test edge cases and boundary values + // This ensures robust handling of unusual but possible values + + // Test maximum uint8 value + maxLevel := AccessLevel(255) + assert.Equal(t, "Read+Create+Write+Admin", maxLevel.String(), "Maximum value should show all known bits set") + + // Test values between defined constants + betweenLevels := AccessLevel(16) // Higher than all defined constants + assert.Equal(t, "Unknown", betweenLevels.String(), "Undefined values should be unknown") + + // Test that undefined values are handled correctly + undefinedLevel := AccessLevel(32) + assert.Equal(t, "Unknown", undefinedLevel.String(), "Higher undefined values should be unknown") +} + +func TestAccessLevelBitOperations(t *testing.T) { + // Test bit flag operations + // This validates that permissions can be combined and checked using bitwise operations + + // Test combining permissions + readWrite := AccessRead | AccessWrite + assert.Equal(t, AccessLevel(5), readWrite, "Read | Write should be 5 (1 | 4)") + assert.Equal(t, "Read+Write", readWrite.String(), "Combined permissions should show both") + + // Test checking individual permissions + allPerms := AccessRead | AccessCreate | AccessWrite | AccessAdmin + assert.True(t, (allPerms & AccessRead) == AccessRead, "Should have Read permission") + assert.True(t, (allPerms & AccessCreate) == AccessCreate, "Should have Create permission") + assert.True(t, (allPerms & AccessWrite) == AccessWrite, "Should have Write permission") + assert.True(t, (allPerms & AccessAdmin) == AccessAdmin, "Should have Admin permission") + + // Test absence of permissions + readOnly := AccessRead + assert.True(t, (readOnly & AccessRead) == AccessRead, "Should have Read permission") + assert.False(t, (readOnly & AccessWrite) == AccessWrite, "Should not have Write permission") + assert.False(t, (readOnly & AccessAdmin) == AccessAdmin, "Should not have Admin permission") + + // Test removing permissions + allButAdmin := allPerms &^ AccessAdmin + assert.True(t, (allButAdmin & AccessRead) == AccessRead, "Should still have Read") + assert.True(t, (allButAdmin & AccessWrite) == AccessWrite, "Should still have Write") + assert.False(t, (allButAdmin & AccessAdmin) == AccessAdmin, "Should not have Admin") +} \ No newline at end of file diff --git a/internal/server/acl/node.go b/internal/server/acl/node.go index b21cf5b8..44837218 100644 --- a/internal/server/acl/node.go +++ b/internal/server/acl/node.go @@ -1,6 +1,7 @@ package acl import ( + "slices" "sort" "strings" "sync" @@ -9,133 +10,185 @@ import ( "github.com/openmined/syftbox/internal/aclspec" ) -// Node represents a node in the ACL tree. +// ACLVersion is the version of the node. +// overflow will reset it to 0. +type ACLVersion = uint16 + +// ACLDepth is the depth of the node in the tree. +type ACLDepth = uint8 + +const ( + ACLMaxDepth = 1<<8 - 1 // keep this in sync with the type ACLDepth + ACLMaxVersion = 1<<16 - 1 // keep this in sync with the type ACLVersion +) + +// ACLNode represents a node in the ACL tree. // Each node corresponds to a part of the path and contains rules for that part. -type Node struct { +type ACLNode struct { mu sync.RWMutex - rules []*Rule // rules for this part of the path. sorted by specificity. - path string // path is the full path to this node - children map[string]*Node // key is the part of the path - terminal bool // true if this node is a terminal node - depth uint8 // depth of the node in the tree. 0 is root node - version uint8 // version of the node. incremented on every change + rules []*ACLRule // rules for this part of the path. sorted by specificity. + children map[string]*ACLNode // key is the part of the path + owner string // owner of the node + path string // path is the full path to this Anode + terminal bool // true if this node is a terminal node + depth ACLDepth // depth of the node in the tree. 0 is root node + version ACLVersion // version of the node. incremented on every change } -func NewNode(path string, terminal bool, depth uint8) *Node { - return &Node{ +// NewACLNode creates a new ACLNode. +func NewACLNode(path string, owner string, terminal bool, depth ACLDepth) *ACLNode { + // note rules & children are not initialized here. + // this is to avoid unnecessary allocations, until the node is set with rules. + return &ACLNode{ path: path, + owner: owner, terminal: terminal, depth: depth, + version: 0, } } -func (n *Node) Version() uint8 { +// GetChild returns the child for the node. +func (n *ACLNode) GetChild(key string) (*ACLNode, bool) { n.mu.RLock() defer n.mu.RUnlock() - return n.version -} - -func (n *Node) IsTerminal() bool { - n.mu.RLock() - defer n.mu.RUnlock() - return n.terminal -} - -func (n *Node) Depth() uint8 { - n.mu.RLock() - defer n.mu.RUnlock() - return n.depth -} - -func (n *Node) Rules() []*Rule { - n.mu.RLock() - defer n.mu.RUnlock() - return n.rules + child, exists := n.children[key] + return child, exists } -func (n *Node) SetChild(key string, child *Node) { +// SetChild sets the child for the node. +func (n *ACLNode) SetChild(key string, child *ACLNode) { n.mu.Lock() defer n.mu.Unlock() if n.children == nil { - n.children = make(map[string]*Node) + n.children = make(map[string]*ACLNode) } if child == nil { delete(n.children, key) } else { n.children[key] = child } + n.version++ } -func (n *Node) GetChild(key string) (*Node, bool) { +// GetChildCount returns the number of children for the node. +func (n *ACLNode) GetChildCount() int { n.mu.RLock() defer n.mu.RUnlock() - child, exists := n.children[key] - return child, exists + return len(n.children) } -func (n *Node) DeleteChild(key string) { +// DeleteChild deletes the child for the node. +func (n *ACLNode) DeleteChild(key string) { n.mu.Lock() defer n.mu.Unlock() delete(n.children, key) + n.version++ +} + +// GetRules returns the rules for the node. +func (n *ACLNode) GetRules() []*ACLRule { + n.mu.RLock() + defer n.mu.RUnlock() + return n.rules } // SetRules the rules, terminal flag and depth for the node. // Increments the version counter for repeated operation. -func (n *Node) SetRules(rules []*aclspec.Rule, terminal bool) { +func (n *ACLNode) SetRules(rules []*aclspec.Rule, terminal bool) { n.mu.Lock() defer n.mu.Unlock() if len(rules) > 0 { // pre-sort the rules by specificity - sorted := sortBySpecificity(rules) + sorted := sortRulesBySpecificity(rules) // convert the rules to aclRules - aclRules := make([]*Rule, 0, len(sorted)) + aclRules := make([]*ACLRule, 0, len(sorted)) for _, rule := range sorted { - aclRules = append(aclRules, &Rule{ + aclRules = append(aclRules, &ACLRule{ rule: rule, node: n, fullPattern: ACLJoinPath(n.path, rule.Pattern), }) } n.rules = aclRules + } else { + // Clear rules when empty or nil slice is provided + n.rules = nil } - // set the rules and terminal flag + // set the terminal flag n.terminal = terminal - // increment the version. uint8 overflow will reset it to 0. + // increment the version + n.version++ +} + +// ClearRules clears the rules for the node. +func (n *ACLNode) ClearRules() { + n.mu.Lock() + defer n.mu.Unlock() + n.rules = nil n.version++ } // FindBestRule finds the best matching rule for the given path. -func (n *Node) FindBestRule(path string) (*Rule, error) { +func (n *ACLNode) FindBestRule(path string) (*ACLRule, error) { n.mu.RLock() defer n.mu.RUnlock() if n.rules == nil { - return nil, ErrNoRuleFound + return nil, ErrNoRule } - // find the best matching rule + // find the best matching rule (rules are already sorted by specificity) for _, aclRule := range n.rules { if ok, _ := doublestar.Match(aclRule.fullPattern, path); ok { return aclRule, nil } } - return nil, ErrNoRuleFound + return nil, ErrNoRule +} + +// GetOwner returns the owner of the node. +func (n *ACLNode) GetOwner() string { + n.mu.RLock() + defer n.mu.RUnlock() + return n.owner +} + +// GetVersion returns the version of the node. +func (n *ACLNode) GetVersion() ACLVersion { + n.mu.RLock() + defer n.mu.RUnlock() + return n.version +} + +// GetTerminal returns true if the node is a terminal node. +func (n *ACLNode) GetTerminal() bool { + n.mu.RLock() + defer n.mu.RUnlock() + return n.terminal +} + +// GetDepth returns the depth of the node. +func (n *ACLNode) GetDepth() ACLDepth { + n.mu.RLock() + defer n.mu.RUnlock() + return n.depth } // Equal checks if the node is equal to another node. -func (n *Node) Equal(other *Node) bool { +func (n *ACLNode) Equal(other *ACLNode) bool { n.mu.RLock() defer n.mu.RUnlock() - return n.path == other.path && n.terminal == other.terminal && n.depth == other.depth + return n.path == other.path && n.terminal == other.terminal && n.depth == other.depth && n.version == other.version } -func globSpecificityScore(glob string) int { - // exact +func calculateGlobSpecificity(glob string) int { + // early return for the most specific glob patterns switch glob { case "**": return -100 @@ -147,6 +200,7 @@ func globSpecificityScore(glob string) int { // Use forward slash for glob patterns score := len(glob)*2 + strings.Count(glob, ACLPathSep)*10 + // penalize base score for substr wildcards for i, c := range glob { switch c { case '*': @@ -163,13 +217,14 @@ func globSpecificityScore(glob string) int { return score } -func sortBySpecificity(rules []*aclspec.Rule) []*aclspec.Rule { +func sortRulesBySpecificity(rules []*aclspec.Rule) []*aclspec.Rule { // copy the rules - clone := append([]*aclspec.Rule(nil), rules...) + clone := slices.Clone(rules) - // sort by specificity, descending + // sort by specificity (or priority), descending sort.Slice(clone, func(i, j int) bool { - return globSpecificityScore(clone[i].Pattern) > globSpecificityScore(clone[j].Pattern) + return calculateGlobSpecificity(clone[i].Pattern) > calculateGlobSpecificity(clone[j].Pattern) }) + return clone } diff --git a/internal/server/acl/node_test.go b/internal/server/acl/node_test.go index d2a3d9fb..466ff00c 100644 --- a/internal/server/acl/node_test.go +++ b/internal/server/acl/node_test.go @@ -9,7 +9,7 @@ import ( func TestNodeFindBestRule(t *testing.T) { // Create a node with some rules - node := NewNode("test", false, 1) + node := NewACLNode("some/path", "user1", false, 1) // Create test rules with different patterns rules := []*aclspec.Rule{ @@ -21,19 +21,19 @@ func TestNodeFindBestRule(t *testing.T) { node.SetRules(rules, false) // Test matching with different paths - rule, err := node.FindBestRule("test/file.txt") + rule, err := node.FindBestRule("some/path/file.txt") assert.NoError(t, err) assert.Equal(t, "*.txt", rule.rule.Pattern) - rule, err = node.FindBestRule("test/file.md") + rule, err = node.FindBestRule("some/path/file.md") assert.NoError(t, err) assert.Equal(t, "file.md", rule.rule.Pattern) - rule, err = node.FindBestRule("test/subdir/main.go") + rule, err = node.FindBestRule("some/path/subdir/main.go") assert.NoError(t, err) assert.Equal(t, "**/*.go", rule.rule.Pattern) - rule, err = node.FindBestRule("test/main.go") + rule, err = node.FindBestRule("some/path/main.go") assert.NoError(t, err) assert.Equal(t, "**/*.go", rule.rule.Pattern) @@ -48,10 +48,10 @@ func TestNodeFindBestRule(t *testing.T) { } func TestNodeSetRules(t *testing.T) { - node := NewNode("test", false, 1) + node := NewACLNode("test", "user1", false, 1) // Initial version should be 0 - assert.Equal(t, uint8(0), node.Version()) + assert.Equal(t, ACLVersion(0), node.GetVersion()) // Set rules and check that version increments rules := []*aclspec.Rule{ @@ -61,13 +61,13 @@ func TestNodeSetRules(t *testing.T) { node.SetRules(rules, true) // Version should increment - assert.Equal(t, uint8(1), node.Version()) + assert.Equal(t, ACLVersion(1), node.GetVersion()) // Terminal flag should be set - assert.True(t, node.IsTerminal()) + assert.True(t, node.GetTerminal()) // Rules should be set - assert.Len(t, node.Rules(), 1) + assert.Len(t, node.GetRules(), 1) // Set new rules and check version increments again newRules := []*aclspec.Rule{ @@ -78,21 +78,21 @@ func TestNodeSetRules(t *testing.T) { node.SetRules(newRules, false) // Version should increment - assert.Equal(t, uint8(2), node.Version()) + assert.Equal(t, ACLVersion(2), node.GetVersion()) // Terminal flag should be updated - assert.False(t, node.IsTerminal()) + assert.False(t, node.GetTerminal()) // Rules should be updated - assert.Len(t, node.Rules(), 2) + assert.Len(t, node.GetRules(), 2) } func TestNodeEqual(t *testing.T) { - node1 := NewNode("test", false, 1) - node2 := NewNode("test", false, 1) - node3 := NewNode("different", false, 1) - node4 := NewNode("test", true, 1) - node5 := NewNode("test", false, 2) + node1 := NewACLNode("some/path", "user1", false, 1) + node2 := NewACLNode("some/path", "user1", false, 1) + node3 := NewACLNode("different/path", "user1", false, 1) + node4 := NewACLNode("some/path", "user1", true, 1) + node5 := NewACLNode("some/path", "user1", false, 2) assert.True(t, node1.Equal(node2), "Identical nodes should be equal") assert.False(t, node1.Equal(node3), "Nodes with different paths should not be equal") @@ -111,7 +111,7 @@ func TestRuleSpecificity(t *testing.T) { } // Sort by specificity - sorted := sortBySpecificity(rules) + sorted := sortRulesBySpecificity(rules) // Most specific should come first, least specific last assert.Equal(t, "specific.txt", sorted[0].Pattern) @@ -134,13 +134,15 @@ func TestGlobSpecificityScore(t *testing.T) { } for _, tc := range testCases { - score := globSpecificityScore(tc.pattern) - assert.Equal(t, tc.score, score, "Specificity score for %q should be %d, got %d", tc.pattern, tc.score, score) + t.Run(tc.pattern, func(t *testing.T) { + score := calculateGlobSpecificity(tc.pattern) + assert.Equal(t, tc.score, score, "Specificity score for '%s' should be %d, got %d", tc.pattern, tc.score, score) + }) } } func TestNodeGetChild(t *testing.T) { - node := NewNode("parent", false, 1) + node := NewACLNode("path/to", "user1", false, 1) // Initially, no children child, exists := node.GetChild("child") @@ -148,13 +150,13 @@ func TestNodeGetChild(t *testing.T) { assert.Nil(t, child) // Add a child - childNode := NewNode("parent/child", false, 2) + childNode := NewACLNode("path/to/child", "user1", false, 2) node.SetChild("child", childNode) // Verify child can be retrieved child, exists = node.GetChild("child") assert.True(t, exists) - assert.Equal(t, "parent/child", child.path) + assert.Equal(t, "path/to/child", child.path) // Delete the child node.DeleteChild("child") diff --git a/internal/server/acl/rule.go b/internal/server/acl/rule.go index fbbfb8e2..f1d948c3 100644 --- a/internal/server/acl/rule.go +++ b/internal/server/acl/rule.go @@ -8,27 +8,37 @@ import ( ) var ( - ErrAdminRequired = errors.New("admin access required") - ErrWriteRequired = errors.New("write access required") - ErrReadRequired = errors.New("read access required") + ErrNoAdminAccess = errors.New("no admin access") + ErrNoWriteAccess = errors.New("no write access") + ErrNoReadAccess = errors.New("no read access") ErrDirsNotAllowed = errors.New("directories not allowed") ErrSymlinksNotAllowed = errors.New("symlinks not allowed") ErrFileSizeExceeded = errors.New("file size exceeds limits") ErrInvalidAccessLevel = errors.New("invalid access level") ) -// Rule represents an access control rule for a file or directory in an ACL Node. +// ACLRule represents an access control rule for a file or directory in an ACL Node. // It contains the full pattern of the rule, the rule itself, and the node it applies to -type Rule struct { +type ACLRule struct { fullPattern string // full pattern = full path + glob rule *aclspec.Rule // the rule itself - node *Node // the node this rule applies to + node *ACLNode // the node this rule applies to +} + +// Owner returns the owner of the rule (inherited from the node) +func (r *ACLRule) Owner() string { + return r.node.GetOwner() +} + +// Version returns the version of the rule (inherited from the node)s +func (r *ACLRule) Version() ACLVersion { + return r.node.GetVersion() } // CheckAccess checks if the user has permission to perform the specified action on the node. -func (r *Rule) CheckAccess(user *User, level AccessLevel) error { +func (r *ACLRule) CheckAccess(user *User, level AccessLevel) error { // the rule is owned by the user, so they can do anything - if user.IsOwner { + if r.Owner() == user.ID { return nil } @@ -44,19 +54,19 @@ func (r *Rule) CheckAccess(user *User, level AccessLevel) error { // Use a switch with fallthrough for permission hierarchy switch level { - case AccessWriteACL: + case AccessAdmin: if !isAdmin { - return ErrAdminRequired + return ErrNoAdminAccess } return nil case AccessWrite: if !isWriter { - return ErrWriteRequired + return ErrNoWriteAccess } return nil case AccessRead: if !isReader { - return ErrReadRequired + return ErrNoReadAccess } return nil default: @@ -65,7 +75,7 @@ func (r *Rule) CheckAccess(user *User, level AccessLevel) error { } // CheckLimits checks if the file is within the limits specified by the rule. -func (r *Rule) CheckLimits(info *File) error { +func (r *ACLRule) CheckLimits(info *File) error { limits := r.rule.Limits if limits == nil { diff --git a/internal/server/acl/tree.go b/internal/server/acl/tree.go index a7c09d84..96603704 100644 --- a/internal/server/acl/tree.go +++ b/internal/server/acl/tree.go @@ -11,30 +11,37 @@ import ( var ( ErrInvalidRuleset = errors.New("invalid ruleset") ErrMaxDepthExceeded = errors.New("maximum depth exceeded") - ErrNoRuleFound = errors.New("no rule found") + ErrNoRuleSet = errors.New("no ruleset found") + ErrNoRule = errors.New("no rules available") ) -// Tree stores the ACL rules in a n-ary tree for efficient lookups. -type Tree struct { - root *Node +const ( + rootPath = "/" + noOwner = "" +) + +// ACLTree stores the ACL rules in a n-ary tree for efficient lookups. +type ACLTree struct { + root *ACLNode } -func NewTree() *Tree { - return &Tree{ - root: NewNode(ACLPathSep, false, 0), +// NewACLTree creates a new ACLTree. +func NewACLTree() *ACLTree { + return &ACLTree{ + root: NewACLNode(rootPath, noOwner, false, 0), } } // Add or update a ruleset in the tree. -func (t *Tree) AddRuleSet(ruleset *aclspec.RuleSet) error { +func (t *ACLTree) AddRuleSet(ruleset *aclspec.RuleSet) (*ACLNode, error) { // Validate the ruleset if ruleset == nil { - return fmt.Errorf("%w: ruleset is nil", ErrInvalidRuleset) + return nil, fmt.Errorf("%w: ruleset is nil", ErrInvalidRuleset) } allRules := ruleset.AllRules() if len(allRules) == 0 { - return fmt.Errorf("%w: ruleset is empty", ErrInvalidRuleset) + return nil, fmt.Errorf("%w: ruleset is empty", ErrInvalidRuleset) } // Clean and split the path @@ -42,64 +49,77 @@ func (t *Tree) AddRuleSet(ruleset *aclspec.RuleSet) error { parts := strings.Split(cleanPath, ACLPathSep) pathDepth := strings.Count(cleanPath, ACLPathSep) + // owner is assumed to be the first part of the path. + // but in future we can always bake it as a part of the acl schema + owner := parts[0] + if owner == "" { + return nil, fmt.Errorf("%w: owner is empty", ErrInvalidRuleset) + } + // Check path depth limit (u8) - if pathDepth > 255 { - return ErrMaxDepthExceeded + if pathDepth > ACLMaxDepth { + return nil, ErrMaxDepthExceeded } // Start at the root node current := t.root - currentDepth := current.depth + currentDepth := 0 // Traverse/create the path - for _, part := range parts { - currentDepth++ + for i, part := range parts { + currentDepth = i + 1 // Calculate depth based on current position // Important: We still process terminal nodes to ensure all ACLs are known to the tree // Get or create child node child, exists := current.GetChild(part) if !exists { - // Use forward slashes for paths fullPath := ACLJoinPath(parts[:currentDepth]...) - child = NewNode(fullPath, false, currentDepth) + child = NewACLNode(fullPath, owner, false, ACLDepth(currentDepth)) current.SetChild(part, child) } + current = child } // Set the rules on the final node current.SetRules(allRules, ruleset.Terminal) - return nil + return current, nil } -// Get rule for the given path -func (t *Tree) GetRule(path string) (*Rule, error) { +// GetEffectiveRule returns the most specific rule applicable to the given path. +func (t *ACLTree) GetEffectiveRule(path string) (*ACLRule, error) { + normalizedPath := ACLNormPath(path) - node := t.GetNearestNodeWithRules(path) // O(depth) + node := t.LookupNearestNode(normalizedPath) // O(depth) if node == nil { - return nil, ErrNoRuleFound + return nil, ErrNoRuleSet } - rule, err := node.FindBestRule(path) // O(rules|node) + rule, err := node.FindBestRule(normalizedPath) // O(rules|node) if err != nil { - return nil, err + return nil, err // returns ErrNoRuleFound if no rule is found } return rule, nil } -// GetNearestNodeWithRules returns the nearest node in the tree that has associated rules for the given path. +// LookupNearestNode returns the nearest node in the tree that has associated rules for the given path. // It returns nil if no such node is found. -func (t *Tree) GetNearestNodeWithRules(path string) *Node { - parts := ACLPathSegments(path) +func (t *ACLTree) LookupNearestNode(normalizedPath string) *ACLNode { + parts := ACLPathSegments(normalizedPath) - var candidate *Node + var candidate *ACLNode current := t.root + // candidate only if the root node has rules + if current.GetRules() != nil { + candidate = current + } + for _, part := range parts { // Stop if the current node is terminal. - if current.IsTerminal() { + if current.GetTerminal() { break } @@ -109,7 +129,7 @@ func (t *Tree) GetNearestNodeWithRules(path string) *Node { } current = child - if child.Rules() != nil { + if child.GetRules() != nil { candidate = current } } @@ -118,12 +138,13 @@ func (t *Tree) GetNearestNodeWithRules(path string) *Node { } // GetNode finds the exact node applicable for the given path. -func (t *Tree) GetNode(path string) *Node { - parts := ACLPathSegments(path) +func (t *ACLTree) GetNode(path string) *ACLNode { + normalizedPath := ACLNormPath(path) + parts := ACLPathSegments(normalizedPath) current := t.root for _, part := range parts { - if current.IsTerminal() { + if current.GetTerminal() { break } @@ -138,26 +159,31 @@ func (t *Tree) GetNode(path string) *Node { } // Removes a ruleset at the specified path -func (t *Tree) RemoveRuleSet(path string) bool { - var parent *Node +func (t *ACLTree) RemoveRuleSet(path string) bool { + var parent *ACLNode var lastPart string - parts := ACLPathSegments(path) - current := t.root + normalizedPath := ACLNormPath(path) + parts := ACLPathSegments(normalizedPath) + currentNode := t.root for _, part := range parts { - child, exists := current.GetChild(part) + child, exists := currentNode.GetChild(part) if !exists { return false } - parent = current - current = child + parent = currentNode + currentNode = child lastPart = part } - // Need to lock parent since we're modifying its children - parent.DeleteChild(lastPart) + // clear the rules for the node, but if it has no children, delete the whole node from it's parent + if currentNode.GetChildCount() == 0 { + parent.DeleteChild(lastPart) + } else { + currentNode.ClearRules() + } return true } diff --git a/internal/server/acl/debug.go b/internal/server/acl/tree_debug.go similarity index 93% rename from internal/server/acl/debug.go rename to internal/server/acl/tree_debug.go index b260d226..edf2ec40 100644 --- a/internal/server/acl/debug.go +++ b/internal/server/acl/tree_debug.go @@ -8,17 +8,19 @@ import ( ) // String implements the Stringer interface for PTree -func (t *Tree) String() string { +func (t *ACLTree) String() string { + var sb strings.Builder + if t.root == nil { return "" } - var sb strings.Builder + t.root.buildString(&sb, "", true, true) return sb.String() } // buildString recursively builds the string representation of the tree -func (n *Node) buildString(sb *strings.Builder, prefix string, isLast bool, isRoot bool) { +func (n *ACLNode) buildString(sb *strings.Builder, prefix string, isLast bool, isRoot bool) { n.mu.RLock() defer n.mu.RUnlock() @@ -27,6 +29,7 @@ func (n *Node) buildString(sb *strings.Builder, prefix string, isLast bool, isRo if !isLast { marker = "├── " } + sb.WriteString(prefix) sb.WriteString(marker) } @@ -35,16 +38,19 @@ func (n *Node) buildString(sb *strings.Builder, prefix string, isLast bool, isRo sb.WriteString(filepath.Base(n.path)) sb.WriteString(fmt.Sprintf(" (d:%d, v:%d", n.depth, n.version)) if len(n.rules) > 0 { - sb.WriteString(fmt.Sprintf(", rules:%d", len(n.rules))) // sb.WriteString(fmt.Sprintf(", rules:%d, ptr:%p", len(n.rules), n.rules)) + sb.WriteString(fmt.Sprintf(", rules:%d", len(n.rules))) } + if n.terminal { sb.WriteString(", TERMINAL") } + sb.WriteString(")\n") // Calculate the new prefix for children childPrefix := prefix + if !isRoot { if isLast { childPrefix += " " @@ -77,6 +83,7 @@ func (n *Node) buildString(sb *strings.Builder, prefix string, isLast bool, isRo for k := range n.children { children = append(children, k) } + sort.Strings(children) // Print children diff --git a/internal/server/acl/tree_test.go b/internal/server/acl/tree_test.go index ce4c1b46..4291a8c7 100644 --- a/internal/server/acl/tree_test.go +++ b/internal/server/acl/tree_test.go @@ -7,22 +7,22 @@ import ( "github.com/stretchr/testify/assert" ) -func TestNewTree(t *testing.T) { - tree := NewTree() +func TestNewACLTree(t *testing.T) { + tree := NewACLTree() assert.NotNil(t, tree) assert.NotNil(t, tree.root) // The ACL system uses forward slashes internally on all platforms for glob compatibility. // We explicitly test for "/" rather than pathSep (which is "\" on Windows) because // the ACL system is a platform-independent abstraction layer. - assert.Equal(t, ACLPathSep, tree.root.path) + assert.Equal(t, "/", tree.root.path) assert.Empty(t, tree.root.children) - assert.Empty(t, tree.root.rules) + assert.Nil(t, tree.root.GetRules()) } func TestAddRuleSet(t *testing.T) { - tree := NewTree() + tree := NewACLTree() ruleset := aclspec.NewRuleSet( "test/path", @@ -30,34 +30,34 @@ func TestAddRuleSet(t *testing.T) { aclspec.NewDefaultRule(aclspec.PrivateAccess(), aclspec.DefaultLimits()), ) - err := tree.AddRuleSet(ruleset) + node, err := tree.AddRuleSet(ruleset) // check root node "/" assert.NoError(t, err) - assert.Empty(t, tree.root.rules) - assert.Contains(t, tree.root.children, "test") - assert.Equal(t, tree.root.path, "/") - assert.Equal(t, tree.root.depth, uint8(0)) + assert.NotNil(t, node) + assert.Nil(t, tree.root.GetRules()) + assert.Equal(t, "/", tree.root.path) + assert.Equal(t, ACLDepth(0), tree.root.GetDepth()) // check node "test" child, ok := tree.root.GetChild("test") assert.True(t, ok) assert.NotNil(t, child) - assert.Empty(t, child.rules) - assert.Contains(t, child.children, "path") - assert.Equal(t, child.path, "test") - assert.Equal(t, child.depth, uint8(1)) + assert.Nil(t, child.GetRules()) + assert.Equal(t, "test", child.path) + assert.Equal(t, ACLDepth(1), child.GetDepth()) // check node "path" child, ok = child.GetChild("path") assert.True(t, ok) assert.NotNil(t, child) - assert.Equal(t, child.path, "test/path") - assert.Equal(t, child.depth, uint8(2)) + assert.Equal(t, "test/path", child.path) + assert.Equal(t, ACLDepth(2), child.GetDepth()) + assert.NotNil(t, child.GetRules()) } func TestTreeTraversal(t *testing.T) { - tree := NewTree() + tree := NewACLTree() // Add rulesets with nested paths ruleset1 := aclspec.NewRuleSet( @@ -68,7 +68,7 @@ func TestTreeTraversal(t *testing.T) { ruleset2 := aclspec.NewRuleSet( "parent/child", - aclspec.SetTerminal, + aclspec.UnsetTerminal, // Non-terminal to allow grandchild aclspec.NewRule("*.md", aclspec.PrivateAccess(), aclspec.DefaultLimits()), ) @@ -78,36 +78,36 @@ func TestTreeTraversal(t *testing.T) { aclspec.NewRule("*.go", aclspec.PublicReadAccess(), aclspec.DefaultLimits()), ) - err := tree.AddRuleSet(ruleset1) + _, err := tree.AddRuleSet(ruleset1) assert.NoError(t, err) - err = tree.AddRuleSet(ruleset2) + _, err = tree.AddRuleSet(ruleset2) assert.NoError(t, err) - err = tree.AddRuleSet(ruleset3) + _, err = tree.AddRuleSet(ruleset3) assert.NoError(t, err) // Test finding nearest node with rules for different paths - node := tree.GetNearestNodeWithRules("parent/file.txt") + node := tree.LookupNearestNode("parent/file.txt") assert.Equal(t, "parent", node.path) - node = tree.GetNearestNodeWithRules("parent/child/document.md") + node = tree.LookupNearestNode("parent/child/document.md") assert.Equal(t, "parent/child", node.path) - node = tree.GetNearestNodeWithRules("parent/child/grandchild/main.go") - assert.Equal(t, "parent/child", node.path) + node = tree.LookupNearestNode("parent/child/grandchild/main.go") + assert.Equal(t, "parent/child/grandchild", node.path) - // Test inheritance - terminal nodes (like parent/child) block inheritance from higher levels - node = tree.GetNearestNodeWithRules("parent/child/unknown.txt") + // Test inheritance - terminal nodes (like grandchild) block inheritance from higher levels + node = tree.LookupNearestNode("parent/child/unknown.txt") assert.Equal(t, "parent/child", node.path) // Test path that doesn't exist in the tree - node = tree.GetNearestNodeWithRules("unknown/path") + node = tree.LookupNearestNode("unknown/path") assert.Nil(t, node) } func TestRemoveRuleSet(t *testing.T) { - tree := NewTree() + tree := NewACLTree() // Add rulesets ruleset1 := aclspec.NewRuleSet( @@ -122,10 +122,10 @@ func TestRemoveRuleSet(t *testing.T) { aclspec.NewRule("*.txt", aclspec.PrivateAccess(), aclspec.DefaultLimits()), ) - err := tree.AddRuleSet(ruleset1) + _, err := tree.AddRuleSet(ruleset1) assert.NoError(t, err) - err = tree.AddRuleSet(ruleset2) + _, err = tree.AddRuleSet(ruleset2) assert.NoError(t, err) // Verify both rulesets are in the tree @@ -152,8 +152,342 @@ func TestRemoveRuleSet(t *testing.T) { assert.False(t, removed) } +func TestGetNode(t *testing.T) { + // Test the GetNode method which finds exact nodes for given paths + // This validates precise node location without rule inheritance logic + tree := NewACLTree() + + // Add nested rulesets to create a tree structure + ruleset1 := aclspec.NewRuleSet( + "parent", + aclspec.UnsetTerminal, + aclspec.NewRule("*.txt", aclspec.PublicReadAccess(), aclspec.DefaultLimits()), + ) + + ruleset2 := aclspec.NewRuleSet( + "parent/child", + aclspec.SetTerminal, + aclspec.NewRule("*.md", aclspec.PrivateAccess(), aclspec.DefaultLimits()), + ) + + _, err := tree.AddRuleSet(ruleset1) + assert.NoError(t, err) + + _, err = tree.AddRuleSet(ruleset2) + assert.NoError(t, err) + + // Test getting exact nodes that exist + parentNode := tree.GetNode("parent") + assert.NotNil(t, parentNode, "Should find parent node") + assert.Equal(t, "parent", parentNode.path, "Parent node should have correct path") + + childNode := tree.GetNode("parent/child") + assert.NotNil(t, childNode, "Should find child node") + assert.Equal(t, "parent/child", childNode.path, "Child node should have correct path") + + // Test getting node for path that goes beyond existing nodes + deepNode := tree.GetNode("parent/child/grandchild") + // Should return the deepest existing node (child), not create new nodes + assert.Equal(t, "parent/child", deepNode.path, "Should return deepest existing node") + + // Test getting node for path that doesn't exist at all + nonExistentNode := tree.GetNode("nonexistent/path") + // Should return root node since no path matches + assert.Equal(t, "/", nonExistentNode.path, "Should return root for non-existent paths") + + // Test getting root node + rootNode := tree.GetNode("") + assert.Equal(t, "/", rootNode.path, "Empty path should return root node") +} + +func TestGetNodeWithTerminalNodes(t *testing.T) { + // Test GetNode behavior with terminal nodes + // Terminal nodes allow children to be added but stop traversal during lookups + tree := NewACLTree() + + // Add a terminal node with catch-all rule + terminalRuleset := aclspec.NewRuleSet( + "terminal", + aclspec.SetTerminal, + aclspec.NewRule("**", aclspec.PrivateAccess(), aclspec.DefaultLimits()), + ) + + _, err := tree.AddRuleSet(terminalRuleset) + assert.NoError(t, err) + + // Verify the terminal node was added + terminalNode := tree.GetNode("terminal") + assert.NotNil(t, terminalNode, "Terminal node should exist") + assert.True(t, terminalNode.GetTerminal(), "Node should be marked as terminal") + + // Add a child under the terminal node - this should SUCCEED + // The tree allows all nodes to be added for performance (avoids tree rebuilds) + childRuleset := aclspec.NewRuleSet( + "terminal/child", + aclspec.UnsetTerminal, + aclspec.NewRule("*.md", aclspec.PublicReadAccess(), aclspec.DefaultLimits()), + ) + + _, err = tree.AddRuleSet(childRuleset) + assert.NoError(t, err, "Should allow child rulesets under terminal nodes (tree contains all ACLs)") + + // GetNode should stop at terminal node during lookup traversal + // Even though child exists in tree, GetNode stops at terminal boundary + childNode := tree.GetNode("terminal/child") + assert.Equal(t, "terminal", childNode.path, "GetNode should stop at terminal boundary") + + // When looking for paths beyond terminal, it also stops at the terminal node + deepNode := tree.GetNode("terminal/child/deeper") + assert.Equal(t, "terminal", deepNode.path, "Lookup should stop at terminal node") + + // Verify child actually exists in tree structure by accessing parent's children directly + terminalNode = tree.GetNode("terminal") + actualChild, exists := terminalNode.GetChild("child") + assert.True(t, exists, "Child should exist in tree structure") + assert.Equal(t, "terminal/child", actualChild.path, "Child should have correct path") + assert.False(t, actualChild.GetTerminal(), "Child should not be terminal") + + // AND: LookupNearestNode should also stop at terminal nodes + nearestNode := tree.LookupNearestNode("terminal/child/file.txt") + assert.NotNil(t, nearestNode, "Should find the terminal node") + assert.Equal(t, "terminal", nearestNode.path, "Rule lookup should stop at terminal node") + + // This means child rules are ignored for inheritance even though they exist in the tree + // Test with child path - should get rule from terminal node (** pattern matches everything) + rule, err := tree.GetEffectiveRule("terminal/child/test.md") + assert.NoError(t, err, "Should find rule from terminal node") + assert.Equal(t, "terminal", rule.node.path, "Rule should come from terminal node, not child") + assert.Equal(t, "**", rule.rule.Pattern, "Should use terminal node's catch-all rule") +} + +func TestTerminalNodeValidation(t *testing.T) { + // Test that terminal nodes control inheritance but allow children to be added + // This explicitly tests the correct terminal behavior + tree := NewACLTree() + + // Add a terminal node + terminalRuleset := aclspec.NewRuleSet( + "secure", + aclspec.SetTerminal, + aclspec.NewRule("**", aclspec.PrivateAccess(), aclspec.DefaultLimits()), + ) + + _, err := tree.AddRuleSet(terminalRuleset) + assert.NoError(t, err, "Should be able to add terminal node") + + // Verify it's terminal + node := tree.GetNode("secure") + assert.True(t, node.GetTerminal(), "Node should be marked as terminal") + + // Add direct child - should succeed (tree allows all nodes for performance) + childRuleset := aclspec.NewRuleSet( + "secure/child", + aclspec.UnsetTerminal, + aclspec.NewRule("*.txt", aclspec.PublicReadAccess(), aclspec.DefaultLimits()), + ) + + _, err = tree.AddRuleSet(childRuleset) + assert.NoError(t, err, "Should be able to add child under terminal node (exists in tree)") + + // Add deeper nested child - should also succeed + deepChildRuleset := aclspec.NewRuleSet( + "secure/child/grandchild", + aclspec.UnsetTerminal, + aclspec.NewRule("*.md", aclspec.PublicReadAccess(), aclspec.DefaultLimits()), + ) + + _, err = tree.AddRuleSet(deepChildRuleset) + assert.NoError(t, err, "Should be able to add nested child under terminal node") + + // BUT: Terminal nodes should stop traversal for rule lookups + // All paths under secure/ should resolve to the secure terminal node + + // Direct child path should resolve to terminal parent + nearestNode := tree.LookupNearestNode("secure/child/test.txt") + assert.NotNil(t, nearestNode, "Should find a node") + assert.Equal(t, "secure", nearestNode.path, "Should resolve to terminal parent, not child") + + // Deep child path should also resolve to terminal parent + nearestNode = tree.LookupNearestNode("secure/child/grandchild/test.md") + assert.NotNil(t, nearestNode, "Should find a node") + assert.Equal(t, "secure", nearestNode.path, "Should resolve to terminal parent, not grandchild") + + // Rule lookup should use terminal node's rules + rule, err := tree.GetEffectiveRule("secure/child/test.txt") + assert.NoError(t, err, "Should find rule for child path") + assert.Equal(t, "secure", rule.node.path, "Rule should come from terminal parent") + assert.Equal(t, "**", rule.rule.Pattern, "Should use terminal node's catch-all rule") + + // Non-terminal nodes should work normally + nonTerminalRuleset := aclspec.NewRuleSet( + "open", + aclspec.UnsetTerminal, + aclspec.NewRule("**", aclspec.PublicReadAccess(), aclspec.DefaultLimits()), + ) + + _, err = tree.AddRuleSet(nonTerminalRuleset) + assert.NoError(t, err, "Should be able to add non-terminal node") + + // Add child under non-terminal - should succeed and be accessible + openChildRuleset := aclspec.NewRuleSet( + "open/child", + aclspec.UnsetTerminal, + aclspec.NewRule("*.txt", aclspec.PrivateAccess(), aclspec.DefaultLimits()), + ) + + _, err = tree.AddRuleSet(openChildRuleset) + assert.NoError(t, err, "Should be able to add child under non-terminal node") + + // Child under non-terminal should be accessible for rule lookups + nearestNode = tree.LookupNearestNode("open/child/test.txt") + assert.NotNil(t, nearestNode, "Should find a node") + assert.Equal(t, "open/child", nearestNode.path, "Should find the actual child node, not parent") +} + +func TestConflictingRuleSetsAtSameLevel(t *testing.T) { + // Test what happens when adding multiple rulesets to the same path + // This tests ruleset replacement/overwriting behavior + tree := NewACLTree() + + // Add initial ruleset + initialRuleset := aclspec.NewRuleSet( + "shared", + aclspec.UnsetTerminal, + aclspec.NewRule("*.txt", aclspec.PrivateAccess(), aclspec.DefaultLimits()), + ) + + _, err := tree.AddRuleSet(initialRuleset) + assert.NoError(t, err, "Should be able to add initial ruleset") + + // Verify initial ruleset + node := tree.GetNode("shared") + assert.NotNil(t, node, "Node should exist") + assert.False(t, node.GetTerminal(), "Node should not be terminal initially") + assert.Len(t, node.GetRules(), 1, "Should have 1 rule initially") + + // Add conflicting ruleset at the SAME path with different rules and terminal flag + conflictingRuleset := aclspec.NewRuleSet( + "shared", // Same path! + aclspec.SetTerminal, // Different terminal flag + aclspec.NewRule("*.md", aclspec.PublicReadAccess(), aclspec.DefaultLimits()), // Different rule + aclspec.NewRule("**", aclspec.SharedReadAccess("admin@example.com"), aclspec.DefaultLimits()), // Additional rule + ) + + _, err = tree.AddRuleSet(conflictingRuleset) + assert.NoError(t, err, "Should be able to add conflicting ruleset (overwrites)") + + // Verify the conflicting ruleset completely replaced the original + node = tree.GetNode("shared") + assert.NotNil(t, node, "Node should still exist") + assert.True(t, node.GetTerminal(), "Node should now be terminal (overwritten)") + assert.Len(t, node.GetRules(), 2, "Should have 2 rules from new ruleset") + + // Verify the conflicting ruleset completely replaced the original + // The original *.txt rule with PrivateAccess is gone + // Now we have *.md rule with PublicReadAccess and ** rule with SharedReadAccess + + // Test that *.txt files now match the ** rule (not the original *.txt rule) + rule, err := node.FindBestRule("shared/test.txt") + assert.NoError(t, err, "Should find rule for *.txt files") + assert.Equal(t, "**", rule.rule.Pattern, "Should match the ** rule, not original *.txt rule") + assert.True(t, rule.rule.Access.Read.Contains("admin@example.com"), "Should have admin access (from ** rule)") + assert.False(t, rule.rule.Access.Read.Contains("*"), "Should NOT have public access (original rule is gone)") + + // Test that *.md files match the more specific *.md rule + rule, err = node.FindBestRule("shared/test.md") + assert.NoError(t, err, "Should find rule for *.md files") + assert.Equal(t, "*.md", rule.rule.Pattern, "Should match the specific *.md rule") + assert.True(t, rule.rule.Access.Read.Contains("*"), "Should have public read access from *.md rule") + + // Test that other files match the ** rule + rule, err = node.FindBestRule("shared/anything.xyz") + assert.NoError(t, err, "Should find rule for other files") + assert.Equal(t, "**", rule.rule.Pattern, "Should match the ** rule") + assert.True(t, rule.rule.Access.Read.Contains("admin@example.com"), "Should have admin access from ** rule") + + // Test that the terminal flag now controls inheritance + childRuleset := aclspec.NewRuleSet( + "shared/child", + aclspec.UnsetTerminal, + aclspec.NewRule("*.go", aclspec.PrivateAccess(), aclspec.DefaultLimits()), + ) + + _, err = tree.AddRuleSet(childRuleset) + assert.NoError(t, err, "Should be able to add child under terminal node (tree allows all nodes)") + + // But rule lookups should stop at terminal node + rule, err = tree.GetEffectiveRule("shared/child/test.go") + assert.NoError(t, err, "Should find rule for child path") + assert.Equal(t, "shared", rule.node.path, "Rule should come from terminal parent, not child") + assert.Equal(t, "**", rule.rule.Pattern, "Should use parent's ** rule, not child's *.go rule") +} + +func TestAddRuleSetErrorCases(t *testing.T) { + // Test AddRuleSet with various error conditions + // This improves coverage of edge cases and error handling + tree := NewACLTree() + + // Test with nil ruleset + _, err := tree.AddRuleSet(nil) + assert.Error(t, err, "Should reject nil ruleset") + assert.Contains(t, err.Error(), "ruleset is nil", "Error should indicate nil ruleset") + + // Test with empty ruleset (no rules) + emptyRuleset := &aclspec.RuleSet{ + Path: "test", + Terminal: false, + Rules: []*aclspec.Rule{}, + } + _, err = tree.AddRuleSet(emptyRuleset) + assert.Error(t, err, "Should reject empty ruleset") + assert.Contains(t, err.Error(), "ruleset is empty", "Error should indicate empty ruleset") + + // Test with extremely deep path (path depth > 255) + deepPath := "" + for i := 0; i < 300; i++ { // Create path with > 255 components + deepPath += "a/" + } + deepRuleset := aclspec.NewRuleSet( + deepPath, + aclspec.UnsetTerminal, + aclspec.NewRule("*.txt", aclspec.PrivateAccess(), aclspec.DefaultLimits()), + ) + _, err = tree.AddRuleSet(deepRuleset) + assert.Error(t, err, "Should reject paths that are too deep") + assert.Contains(t, err.Error(), "maximum depth exceeded", "Error should indicate depth limit") +} + +func TestAddRuleSetPathNormalization(t *testing.T) { + // Test that AddRuleSet properly normalizes different path formats + // This ensures consistent path handling across different input formats + tree := NewACLTree() + + // Test with path that has leading/trailing separators + rule := aclspec.NewRule("*.txt", aclspec.PrivateAccess(), aclspec.DefaultLimits()) + + // These should all result in the same normalized path + testPaths := []string{ + "test/path", + "/test/path", + "test/path/", + "/test/path/", + "./test/path", + } + + for i, path := range testPaths { + ruleset := aclspec.NewRuleSet(path, false, rule) + _, err := tree.AddRuleSet(ruleset) + assert.NoError(t, err, "Should accept path format: %s", path) + + // All paths should result in the same node being found + node := tree.GetNode("test/path") + assert.NotNil(t, node, "Should find node for normalized path (test %d)", i) + assert.Equal(t, "test/path", node.path, "Path should be normalized consistently (test %d)", i) + } +} + func TestNestedRuleSetRemoval(t *testing.T) { - tree := NewTree() + tree := NewACLTree() // Add nested rulesets ruleset1 := aclspec.NewRuleSet( @@ -168,38 +502,65 @@ func TestNestedRuleSetRemoval(t *testing.T) { aclspec.NewRule("*.md", aclspec.PrivateAccess(), aclspec.DefaultLimits()), ) - err := tree.AddRuleSet(ruleset1) + _, err := tree.AddRuleSet(ruleset1) assert.NoError(t, err) - err = tree.AddRuleSet(ruleset2) + _, err = tree.AddRuleSet(ruleset2) assert.NoError(t, err) - // Remove parent - should also remove child + // Remove parent - with new behavior, this only clears rules since parent has children removed := tree.RemoveRuleSet("parent") assert.True(t, removed) - // Verify both are gone - _, ok := tree.root.GetChild("parent") - assert.False(t, ok) + // Verify parent node still exists but rules are cleared + parentNode, ok := tree.root.GetChild("parent") + assert.True(t, ok, "Parent node should still exist after removing ruleset") + assert.NotNil(t, parentNode, "Parent node should not be nil") + assert.Nil(t, parentNode.GetRules(), "Parent node rules should be cleared") - // Add the parent ruleset back - err = tree.AddRuleSet(ruleset1) - assert.NoError(t, err) + // Verify child still exists since parent wasn't deleted + childNode, ok := parentNode.GetChild("child") + assert.True(t, ok, "Child node should still exist") + assert.NotNil(t, childNode, "Child node should not be nil") + assert.NotNil(t, childNode.GetRules(), "Child node rules should still exist") - // Add the child ruleset back - err = tree.AddRuleSet(ruleset2) + // Add the parent ruleset back + _, err = tree.AddRuleSet(ruleset1) assert.NoError(t, err) - // Remove just the child + // Remove just the child - this should delete the child node since it has no children removed = tree.RemoveRuleSet("parent/child") assert.True(t, removed) // Verify parent still exists - parentNode, ok := tree.root.GetChild("parent") - assert.True(t, ok) - assert.NotNil(t, parentNode) + parentNode, ok = tree.root.GetChild("parent") + assert.True(t, ok, "Parent node should exist after removing just child") + assert.NotNil(t, parentNode, "Parent node should not be nil") - // Verify child was removed + // Verify child was deleted since it had no children _, ok = parentNode.GetChild("child") - assert.False(t, ok) + assert.False(t, ok, "Child node should be deleted since it had no children") + + // Test removing a leaf node (no children) - should delete the node + leafRuleset := aclspec.NewRuleSet( + "leaf", + aclspec.SetTerminal, + aclspec.NewRule("*.log", aclspec.PrivateAccess(), aclspec.DefaultLimits()), + ) + + _, err = tree.AddRuleSet(leafRuleset) + assert.NoError(t, err) + + // Verify leaf node exists + leafNode, ok := tree.root.GetChild("leaf") + assert.True(t, ok, "Leaf node should exist") + assert.NotNil(t, leafNode, "Leaf node should not be nil") + + // Remove leaf node - should delete it since it has no children + removed = tree.RemoveRuleSet("leaf") + assert.True(t, removed) + + // Verify leaf node was deleted + _, ok = tree.root.GetChild("leaf") + assert.False(t, ok, "Leaf node should be deleted since it had no children") } diff --git a/internal/server/acl/types.go b/internal/server/acl/types.go index 9bcd50ee..74479afc 100644 --- a/internal/server/acl/types.go +++ b/internal/server/acl/types.go @@ -1,8 +1,7 @@ package acl type User struct { - ID string - IsOwner bool + ID string } type File struct { diff --git a/internal/server/datasite/datasite.go b/internal/server/datasite/datasite.go index 0f3f83af..0b2bd9fc 100644 --- a/internal/server/datasite/datasite.go +++ b/internal/server/datasite/datasite.go @@ -15,10 +15,10 @@ import ( type DatasiteService struct { blob *blob.BlobService - acl *acl.AclService + acl *acl.ACLService } -func NewDatasiteService(blobSvc *blob.BlobService, aclSvc *acl.AclService) *DatasiteService { +func NewDatasiteService(blobSvc *blob.BlobService, aclSvc *acl.ACLService) *DatasiteService { return &DatasiteService{ blob: blobSvc, acl: aclSvc, @@ -50,7 +50,11 @@ func (d *DatasiteService) Start(ctx context.Context) error { // Load the ACL rulesets start = time.Now() - d.acl.LoadRuleSets(ruleSets) + for _, ruleSet := range ruleSets { + if _, err := d.acl.AddRuleSet(ruleSet); err != nil { + slog.Warn("ruleset update error", "path", ruleSet.Path, "error", err) + } + } slog.Debug("acl build", "count", len(ruleSets), "took", time.Since(start)) // Warm up the ACL cache @@ -60,7 +64,7 @@ func (d *DatasiteService) Start(ctx context.Context) error { &acl.User{ID: aclspec.Everyone}, &acl.File{Path: blob.Key}, acl.AccessRead, - ); err != nil && errors.Is(err, acl.ErrNoRuleFound) { + ); err != nil && errors.Is(err, acl.ErrNoRule) { slog.Warn("acl cache warm error", "path", blob.Key, "error", err) } } @@ -82,7 +86,7 @@ func (d *DatasiteService) GetView(user string) []*blob.BlobInfo { // Filter blobs based on ACL for _, blob := range blobs { if err := d.acl.CanAccess( - &acl.User{ID: user, IsOwner: IsOwner(blob.Key, user)}, + &acl.User{ID: user}, &acl.File{Path: blob.Key}, acl.AccessRead, ); err == nil { diff --git a/internal/server/datasite/utils.go b/internal/server/datasite/utils.go index 25c1efe2..33126411 100644 --- a/internal/server/datasite/utils.go +++ b/internal/server/datasite/utils.go @@ -2,13 +2,48 @@ package datasite import ( "path/filepath" + "regexp" "strings" + + "github.com/openmined/syftbox/internal/utils" +) + +var ( + PathSep = string(filepath.Separator) + regexDatasitePath = regexp.MustCompile(`^[^\s@]+@[^\s@]+\.[^\s@]+/`) ) -var pathSep = string(filepath.Separator) +// GetOwner returns the owner of the path +func GetOwner(path string) string { + // clean path + path = CleanPath(path) + + // get owner + parts := strings.Split(path, PathSep) + if len(parts) == 0 { + return "" + } + + return parts[0] +} // IsOwner checks if the user is the owner of the path // The underlying assumption here is that owner is the prefix of the path func IsOwner(path string, user string) bool { - return strings.HasPrefix(strings.TrimLeft(filepath.Clean(path), pathSep), user) + path = CleanPath(path) + return strings.HasPrefix(path, user) +} + +// CleanPath returns a path with leading and trailing slashes removed +func CleanPath(path string) string { + return strings.TrimLeft(filepath.Clean(path), PathSep) +} + +// IsValidPath checks if the path is a valid datasite path +func IsValidPath(path string) bool { + return regexDatasitePath.MatchString(path) +} + +func IsValidDatasite(user string) bool { + return utils.IsValidEmail(user) } diff --git a/internal/server/handlers/acl/acl_handler.go b/internal/server/handlers/acl/acl_handler.go new file mode 100644 index 00000000..fee0262a --- /dev/null +++ b/internal/server/handlers/acl/acl_handler.go @@ -0,0 +1,43 @@ +package acl + +import ( + "net/http" + + "github.com/gin-gonic/gin" + "github.com/openmined/syftbox/internal/server/acl" + "github.com/openmined/syftbox/internal/server/handlers/api" +) + +type ACLHandler struct { + aclSvc *acl.ACLService +} + +func NewACLHandler(svc *acl.ACLService) *ACLHandler { + return &ACLHandler{ + aclSvc: svc, + } +} + +func (h *ACLHandler) CheckAccess(ctx *gin.Context) { + var req ACLCheckRequest + if err := ctx.ShouldBindQuery(&req); err != nil { + api.AbortWithError(ctx, http.StatusBadRequest, api.CodeInvalidRequest, err) + return + } + + // Check access using the ACL service + if err := h.aclSvc.CanAccess( + &acl.User{ID: req.User}, + &acl.File{Path: req.Path, Size: req.Size}, + acl.AccessLevel(req.Level), + ); err != nil { + api.AbortWithError(ctx, http.StatusForbidden, api.CodeAccessDenied, err) + return + } + + ctx.PureJSON(http.StatusOK, &ACLCheckResponse{ + User: req.User, + Path: req.Path, + Level: req.Level.String(), + }) +} diff --git a/internal/server/handlers/acl/acl_handler_types.go b/internal/server/handlers/acl/acl_handler_types.go new file mode 100644 index 00000000..7a46cc91 --- /dev/null +++ b/internal/server/handlers/acl/acl_handler_types.go @@ -0,0 +1,16 @@ +package acl + +import "github.com/openmined/syftbox/internal/server/acl" + +type ACLCheckRequest struct { + User string `form:"user" binding:"required"` + Path string `form:"path" binding:"required"` + Size int64 `form:"size"` + Level acl.AccessLevel `form:"level" binding:"required"` +} + +type ACLCheckResponse struct { + User string `json:"user"` + Path string `json:"path"` + Level string `json:"level"` +} diff --git a/internal/server/handlers/api/codes.go b/internal/server/handlers/api/codes.go new file mode 100644 index 00000000..9124f43f --- /dev/null +++ b/internal/server/handlers/api/codes.go @@ -0,0 +1,30 @@ +package api + +const ( + // Generic request/server errors + CodeInvalidRequest = "E_INVALID_REQUEST" // bad or invalid request + CodeRateLimited = "E_RATE_LIMITED" // rate limit exceeded + CodeInternalError = "E_INTERNAL_ERROR" // internal server error + CodeAccessDenied = "E_ACCESS_DENIED" // access denied + + // Auth errors + CodeAuthInvalidCredentials = "E_AUTH_INVALID_CREDENTIALS" // authentication credentials (e.g., token) are invalid, expired, or malformed. + CodeAuthTokenGenerationFailed = "E_AUTH_TOKEN_GENERATION_FAILED" // a failure during the generation of new authentication tokens. + CodeAuthOTPVerificationFailed = "E_AUTH_OTP_VERIFICATION_FAILED" // Email One-Time Password (OTP) verification failed. + CodeAuthTokenRefreshFailed = "E_AUTH_TOKEN_REFRESH_FAILED" // a failure during the attempt to refresh an authentication token. + CodeAuthNotificationFailed = "E_AUTH_NOTIFICATION_FAILED" // a failure in sending an authentication-related notification (e.g., OTP email/SMS). + + // Datasite errors + CodeDatasiteNotFound = "E_DATASITE_NOT_FOUND" // the specified datasite resource could not be found. + CodeDatasiteInvalidPath = "E_DATASITE_INVALID_PATH" // the provided path for a datasite resource is invalid or malformed. + + // Blob errors + CodeBlobNotFound = "E_BLOB_NOT_FOUND" // the specified blob could not be found. + CodeBlobListFailed = "E_BLOB_LIST_OPERATION_FAILED" // a failure during the operation to list blobs. + CodeBlobPutFailed = "E_BLOB_PUT_OPERATION_FAILED" // a failure during the operation to upload/put a blob. + CodeBlobGetFailed = "E_BLOB_GET_OPERATION_FAILED" // a failure during the operation to download/get a blob. + CodeBlobDeleteFailed = "E_BLOB_DELETE_OPERATION_FAILED" // a failure during the operation to delete a blob. + + // ACL errors + CodeACLUpdateFailed = "E_ACL_UPDATE_FAILED" // a failure during the operation to update an ACL. +) diff --git a/internal/server/handlers/api/error.go b/internal/server/handlers/api/error.go new file mode 100644 index 00000000..eb63da25 --- /dev/null +++ b/internal/server/handlers/api/error.go @@ -0,0 +1,12 @@ +package api + +import "fmt" + +type SyftAPIError struct { + Code string `json:"code"` + Message string `json:"error"` +} + +func (e *SyftAPIError) Error() string { + return fmt.Sprintf("syft api error: code=%s, message=%s", e.Code, e.Message) +} diff --git a/internal/server/handlers/api/response.go b/internal/server/handlers/api/response.go new file mode 100644 index 00000000..ad703981 --- /dev/null +++ b/internal/server/handlers/api/response.go @@ -0,0 +1,12 @@ +package api + +import "github.com/gin-gonic/gin" + +func AbortWithError(ctx *gin.Context, status int, code string, err error) { + ctx.Abort() + ctx.Error(err) + ctx.PureJSON(status, SyftAPIError{ + Code: code, + Message: err.Error(), + }) +} diff --git a/internal/server/handlers/auth/auth_handler.go b/internal/server/handlers/auth/auth_handler.go index 66805795..fed65f82 100644 --- a/internal/server/handlers/auth/auth_handler.go +++ b/internal/server/handlers/auth/auth_handler.go @@ -1,12 +1,13 @@ package auth import ( - "errors" "fmt" "net/http" "github.com/gin-gonic/gin" "github.com/openmined/syftbox/internal/server/auth" + "github.com/openmined/syftbox/internal/server/handlers/api" + "github.com/openmined/syftbox/internal/utils" ) type AuthHandler struct { @@ -22,24 +23,17 @@ func New(auth *auth.AuthService) *AuthHandler { func (h *AuthHandler) OTPRequest(ctx *gin.Context) { var req OTPRequest if err := ctx.ShouldBindJSON(&req); err != nil { - ctx.Error(fmt.Errorf("failed to bind json: %w", err)) - ctx.PureJSON(http.StatusBadRequest, gin.H{ - "error": err.Error(), - }) + api.AbortWithError(ctx, http.StatusBadRequest, api.CodeInvalidRequest, fmt.Errorf("failed to bind json: %w", err)) + return + } + + if !utils.IsValidEmail(req.Email) { + api.AbortWithError(ctx, http.StatusBadRequest, api.CodeInvalidRequest, fmt.Errorf("invalid email")) return } if err := h.auth.SendOTP(ctx, req.Email); err != nil { - ctx.Error(fmt.Errorf("failed to send OTP: %w", err)) - if errors.Is(err, auth.ErrInvalidEmail) { - ctx.PureJSON(http.StatusBadRequest, gin.H{ - "error": err.Error(), - }) - } else { - ctx.PureJSON(http.StatusInternalServerError, gin.H{ - "error": err.Error(), - }) - } + api.AbortWithError(ctx, http.StatusInternalServerError, api.CodeAuthNotificationFailed, fmt.Errorf("failed to send OTP: %w", err)) return } @@ -49,19 +43,13 @@ func (h *AuthHandler) OTPRequest(ctx *gin.Context) { func (h *AuthHandler) OTPVerify(ctx *gin.Context) { var req OTPVerifyRequest if err := ctx.ShouldBindJSON(&req); err != nil { - ctx.Error(fmt.Errorf("failed to bind json: %w", err)) - ctx.PureJSON(http.StatusBadRequest, gin.H{ - "error": err.Error(), - }) + api.AbortWithError(ctx, http.StatusBadRequest, api.CodeInvalidRequest, fmt.Errorf("failed to bind json: %w", err)) return } accessToken, refreshToken, err := h.auth.GenerateTokensPair(ctx, req.Email, req.Code) if err != nil { - ctx.Error(fmt.Errorf("failed to generate tokens: %w", err)) - ctx.PureJSON(http.StatusInternalServerError, gin.H{ - "error": err.Error(), - }) + api.AbortWithError(ctx, http.StatusUnauthorized, api.CodeAuthTokenGenerationFailed, err) return } @@ -74,19 +62,13 @@ func (h *AuthHandler) OTPVerify(ctx *gin.Context) { func (h *AuthHandler) Refresh(ctx *gin.Context) { var req RefreshRequest if err := ctx.ShouldBindJSON(&req); err != nil { - ctx.Error(fmt.Errorf("failed to bind json: %w", err)) - ctx.PureJSON(http.StatusBadRequest, gin.H{ - "error": err.Error(), - }) + api.AbortWithError(ctx, http.StatusBadRequest, api.CodeInvalidRequest, fmt.Errorf("failed to bind json: %w", err)) return } accessToken, refreshToken, err := h.auth.RefreshToken(ctx, req.OldRefreshToken) if err != nil { - ctx.Error(fmt.Errorf("failed to refresh token: %w", err)) - ctx.PureJSON(http.StatusInternalServerError, gin.H{ - "error": err.Error(), - }) + api.AbortWithError(ctx, http.StatusUnauthorized, api.CodeAuthTokenRefreshFailed, err) return } diff --git a/internal/server/handlers/blob/blob_handler.go b/internal/server/handlers/blob/blob_handler.go index edaadec9..03da67c0 100644 --- a/internal/server/handlers/blob/blob_handler.go +++ b/internal/server/handlers/blob/blob_handler.go @@ -1,51 +1,55 @@ package blob import ( + "fmt" "net/http" - "regexp" "github.com/gin-gonic/gin" + "github.com/openmined/syftbox/internal/server/acl" "github.com/openmined/syftbox/internal/server/blob" -) - -var ( - regexDatasiteKey = regexp.MustCompile(`^[^\s@]+@[^\s@]+\.[^\s@]+/`) + "github.com/openmined/syftbox/internal/server/datasite" + "github.com/openmined/syftbox/internal/server/handlers/api" ) type BlobHandler struct { blob *blob.BlobService + acl *acl.ACLService } -func New(blob *blob.BlobService) *BlobHandler { - return &BlobHandler{blob: blob} +func New(blob *blob.BlobService, acl *acl.ACLService) *BlobHandler { + return &BlobHandler{blob: blob, acl: acl} } func (h *BlobHandler) UploadMultipart(ctx *gin.Context) { // todo - ctx.PureJSON(http.StatusNotImplemented, gin.H{ - "error": "not implemented", - }) + api.AbortWithError(ctx, http.StatusNotImplemented, api.CodeInvalidRequest, fmt.Errorf("not implemented")) } func (h *BlobHandler) UploadComplete(ctx *gin.Context) { // todo - ctx.PureJSON(http.StatusNotImplemented, gin.H{ - "error": "not implemented", - }) + api.AbortWithError(ctx, http.StatusNotImplemented, api.CodeInvalidRequest, fmt.Errorf("not implemented")) } func (h *BlobHandler) ListObjects(ctx *gin.Context) { res, err := h.blob.Index().List() if err != nil { - ctx.PureJSON(http.StatusInternalServerError, gin.H{ - "error": err.Error(), - }) + api.AbortWithError(ctx, http.StatusInternalServerError, api.CodeBlobListFailed, err) return } - ctx.PureJSON(http.StatusOK, res) + ctx.PureJSON(http.StatusOK, &gin.H{ + "blobs": res, + }) } -func isValidDatasiteKey(key string) bool { - return regexDatasiteKey.MatchString(key) +func (h *BlobHandler) checkPermissions(key string, user string, access acl.AccessLevel) error { + if datasite.IsOwner(key, user) { + return nil + } + + if err := h.acl.CanAccess(&acl.User{ID: user}, &acl.File{Path: key}, access); err != nil { + return err + } + + return nil } diff --git a/internal/server/handlers/blob/blob_handler_delete.go b/internal/server/handlers/blob/blob_handler_delete.go index 2871e169..38f62102 100644 --- a/internal/server/handlers/blob/blob_handler_delete.go +++ b/internal/server/handlers/blob/blob_handler_delete.go @@ -2,56 +2,58 @@ package blob import ( "fmt" + "log/slog" "net/http" "github.com/gin-gonic/gin" + "github.com/openmined/syftbox/internal/aclspec" + "github.com/openmined/syftbox/internal/server/acl" + "github.com/openmined/syftbox/internal/server/datasite" + "github.com/openmined/syftbox/internal/server/handlers/api" ) func (h *BlobHandler) DeleteObjects(ctx *gin.Context) { var req DeleteRequest - if err := ctx.ShouldBindJSON(&req); err != nil { - ctx.Error(fmt.Errorf("failed to bind json: %w", err)) - ctx.PureJSON(http.StatusBadRequest, gin.H{ - "error": err.Error(), - }) - return - } + user := ctx.GetString("user") - if len(req.Keys) == 0 { - ctx.Error(fmt.Errorf("keys cannot be empty")) - ctx.PureJSON(http.StatusBadRequest, gin.H{ - "error": "keys cannot be empty", - }) + if err := ctx.ShouldBindJSON(&req); err != nil { + api.AbortWithError(ctx, http.StatusBadRequest, api.CodeInvalidRequest, fmt.Errorf("failed to bind json: %w", err)) return } deleted := make([]string, 0, len(req.Keys)) - errors := make([]*BlobError, 0) + errors := make([]*BlobAPIError, 0) for _, key := range req.Keys { - if !isValidDatasiteKey(key) { - ctx.Error(fmt.Errorf("invalid datasite path: %s", key)) - errors = append(errors, &BlobError{ - Key: key, - Error: "invalid key", - }) + if !datasite.IsValidPath(key) { + errors = append(errors, NewBlobAPIError(api.CodeDatasiteInvalidPath, "invalid key", key)) continue } + + if err := h.checkPermissions(key, user, acl.AccessWrite); err != nil { + errors = append(errors, NewBlobAPIError(api.CodeAccessDenied, err.Error(), key)) + continue + } + _, err := h.blob.Backend().DeleteObject(ctx.Request.Context(), key) if err != nil { ctx.Error(fmt.Errorf("failed to delete object: %w", err)) - errors = append(errors, &BlobError{ - Key: key, - Error: err.Error(), - }) + errors = append(errors, NewBlobAPIError(api.CodeBlobDeleteFailed, err.Error(), key)) continue } + + if aclspec.IsACLFile(key) { + // don't worry the above permissions check will make sure that the user is admin + ok := h.acl.RemoveRuleSet(key) + if !ok { + slog.Warn("remove ruleset returned false", "key", key) + } + } + deleted = append(deleted, key) } code := http.StatusOK - if len(deleted) == 0 && len(errors) > 0 { - code = http.StatusBadRequest - } else if len(deleted) > 0 && len(errors) > 0 { + if len(deleted) >= 0 && len(errors) >= 0 { code = http.StatusMultiStatus } diff --git a/internal/server/handlers/blob/blob_handler_download_presigned.go b/internal/server/handlers/blob/blob_handler_download_presigned.go index 518807cc..b1346d11 100644 --- a/internal/server/handlers/blob/blob_handler_download_presigned.go +++ b/internal/server/handlers/blob/blob_handler_download_presigned.go @@ -5,60 +5,70 @@ import ( "net/http" "github.com/gin-gonic/gin" + "github.com/openmined/syftbox/internal/server/acl" + "github.com/openmined/syftbox/internal/server/datasite" + "github.com/openmined/syftbox/internal/server/handlers/api" ) func (h *BlobHandler) DownloadObjectsPresigned(ctx *gin.Context) { - var req PresignUrlRequest + var req PresignURLRequest + user := ctx.GetString("user") if err := ctx.ShouldBindJSON(&req); err != nil { - ctx.Error(fmt.Errorf("failed to bind json: %w", err)) - ctx.PureJSON(http.StatusBadRequest, gin.H{ - "error": err.Error(), - }) - return - } - - if len(req.Keys) == 0 { - ctx.Error(fmt.Errorf("keys cannot be empty")) - ctx.PureJSON(http.StatusBadRequest, gin.H{ - "error": "keys cannot be empty", - }) + api.AbortWithError(ctx, http.StatusBadRequest, api.CodeInvalidRequest, fmt.Errorf("failed to bind json: %w", err)) return } - urls := make([]*BlobUrl, 0, len(req.Keys)) - errors := make([]*BlobError, 0) + urls := make([]*BlobURL, 0, len(req.Keys)) + errors := make([]*BlobAPIError, 0) index := h.blob.Index() for _, key := range req.Keys { - if !isValidDatasiteKey(key) { - ctx.Error(fmt.Errorf("invalid datasite path: %s", key)) - errors = append(errors, &BlobError{ - Key: key, - Error: "invalid key", + if !datasite.IsValidPath(key) { + errors = append(errors, &BlobAPIError{ + SyftAPIError: api.SyftAPIError{ + Code: api.CodeDatasiteInvalidPath, + Message: "invalid key", + }, + Key: key, + }) + continue + } + + if err := h.checkPermissions(key, user, acl.AccessRead); err != nil { + errors = append(errors, &BlobAPIError{ + SyftAPIError: api.SyftAPIError{ + Code: api.CodeAccessDenied, + Message: err.Error(), + }, + Key: key, }) continue } _, ok := index.Get(key) if !ok { - ctx.Error(fmt.Errorf("object not found: %s", key)) - errors = append(errors, &BlobError{ - Key: key, - Error: "object not found", + errors = append(errors, &BlobAPIError{ + SyftAPIError: api.SyftAPIError{ + Code: api.CodeBlobNotFound, + Message: "object not found", + }, + Key: key, }) continue } url, err := h.blob.Backend().GetObjectPresigned(ctx, key) if err != nil { - ctx.Error(fmt.Errorf("failed to get object: %w", err)) - errors = append(errors, &BlobError{ - Key: key, - Error: err.Error(), + errors = append(errors, &BlobAPIError{ + SyftAPIError: api.SyftAPIError{ + Code: api.CodeBlobGetFailed, + Message: err.Error(), + }, + Key: key, }) continue } - urls = append(urls, &BlobUrl{ + urls = append(urls, &BlobURL{ Key: key, Url: url, }) @@ -71,7 +81,7 @@ func (h *BlobHandler) DownloadObjectsPresigned(ctx *gin.Context) { code = http.StatusMultiStatus } - ctx.PureJSON(code, &PresignUrlResponse{ + ctx.PureJSON(code, &PresignURLResponse{ URLs: urls, Errors: errors, }) diff --git a/internal/server/handlers/blob/blob_handler_types.go b/internal/server/handlers/blob/blob_handler_types.go index 32108af2..657263f1 100644 --- a/internal/server/handlers/blob/blob_handler_types.go +++ b/internal/server/handlers/blob/blob_handler_types.go @@ -1,13 +1,33 @@ package blob -type BlobUrl struct { +import ( + "fmt" + + "github.com/openmined/syftbox/internal/server/handlers/api" +) + +type BlobURL struct { Key string `json:"key"` Url string `json:"url"` } -type BlobError struct { - Key string `json:"key"` - Error string `json:"error"` +type BlobAPIError struct { + api.SyftAPIError + Key string `json:"key"` +} + +func NewBlobAPIError(code string, message string, key string) *BlobAPIError { + return &BlobAPIError{ + Key: key, + SyftAPIError: api.SyftAPIError{ + Code: code, + Message: message, + }, + } +} + +func (e *BlobAPIError) Error() string { + return fmt.Sprintf("syft api blob error: code=%s, message=%s, key=%s", e.Code, e.Message, e.Key) } type UploadRequest struct { @@ -26,20 +46,20 @@ type UploadResponse struct { LastModified string `json:"lastModified"` } -type PresignUrlRequest struct { - Keys []string `json:"keys" binding:"required"` +type PresignURLRequest struct { + Keys []string `json:"keys" binding:"required,min=1"` } -type PresignUrlResponse struct { - URLs []*BlobUrl `json:"urls"` - Errors []*BlobError `json:"errors"` +type PresignURLResponse struct { + URLs []*BlobURL `json:"urls"` + Errors []*BlobAPIError `json:"errors"` } type DeleteRequest struct { - Keys []string `json:"keys" binding:"required"` + Keys []string `json:"keys" binding:"required,min=1"` } type DeleteResponse struct { - Deleted []string `json:"deleted"` - Errors []*BlobError `json:"errors"` + Deleted []string `json:"deleted"` + Errors []*BlobAPIError `json:"errors"` } diff --git a/internal/server/handlers/blob/blob_handler_upload.go b/internal/server/handlers/blob/blob_handler_upload.go index 8af86525..6f153f0b 100644 --- a/internal/server/handlers/blob/blob_handler_upload.go +++ b/internal/server/handlers/blob/blob_handler_upload.go @@ -6,66 +6,66 @@ import ( "time" "github.com/gin-gonic/gin" + "github.com/openmined/syftbox/internal/aclspec" + "github.com/openmined/syftbox/internal/server/acl" "github.com/openmined/syftbox/internal/server/blob" + "github.com/openmined/syftbox/internal/server/datasite" + "github.com/openmined/syftbox/internal/server/handlers/api" ) func (h *BlobHandler) Upload(ctx *gin.Context) { + user := ctx.GetString("user") + + if key := ctx.Query("key"); aclspec.IsACLFile(key) { + h.UploadACL(ctx) + return + } + var req UploadRequest if err := ctx.ShouldBindQuery(&req); err != nil { - ctx.Error(fmt.Errorf("failed to bind query: %w", err)) - ctx.PureJSON(http.StatusBadRequest, gin.H{ - "error": err.Error(), - }) + api.AbortWithError(ctx, http.StatusBadRequest, api.CodeInvalidRequest, fmt.Errorf("failed to bind query: %w", err)) + return + } + + // todo check if new change using etag + + if !datasite.IsValidPath(req.Key) { + api.AbortWithError(ctx, http.StatusBadRequest, api.CodeDatasiteInvalidPath, fmt.Errorf("invalid key: %s", req.Key)) return } - if !isValidDatasiteKey(req.Key) { - ctx.Error(fmt.Errorf("invalid datasite path: %s", req.Key)) - ctx.PureJSON(http.StatusBadRequest, gin.H{ - "error": "invalid key", - }) + if err := h.checkPermissions(req.Key, user, acl.AccessWrite); err != nil { + api.AbortWithError(ctx, http.StatusForbidden, api.CodeAccessDenied, err) return } // get form file file, err := ctx.FormFile("file") if err != nil { - ctx.Error(fmt.Errorf("failed to get form file: %w", err)) - ctx.PureJSON(http.StatusBadRequest, gin.H{ - "error": fmt.Sprintf("invalid file: %s", err), - }) + api.AbortWithError(ctx, http.StatusBadRequest, api.CodeInvalidRequest, fmt.Errorf("invalid file: %w", err)) return } // check file size if file.Size <= 0 { - ctx.Error(fmt.Errorf("invalid file: size is 0")) - ctx.PureJSON(http.StatusBadRequest, gin.H{ - "error": "invalid file: size is 0", - }) + api.AbortWithError(ctx, http.StatusBadRequest, api.CodeInvalidRequest, fmt.Errorf("invalid file: size is 0")) return } fd, err := file.Open() if err != nil { - ctx.Error(fmt.Errorf("failed to open file: %w", err)) - ctx.PureJSON(http.StatusBadRequest, gin.H{ - "error": fmt.Sprintf("invalid file: %s", err), - }) + api.AbortWithError(ctx, http.StatusBadRequest, api.CodeInvalidRequest, fmt.Errorf("invalid file file: %w", err)) return } - defer fd.Close() + result, err := h.blob.Backend().PutObject(ctx.Request.Context(), &blob.PutObjectParams{ Key: req.Key, Size: file.Size, Body: fd, }) if err != nil { - ctx.Error(fmt.Errorf("failed to put object: %w", err)) - ctx.PureJSON(http.StatusInternalServerError, gin.H{ - "error": "server error: could not persist file", - }) + api.AbortWithError(ctx, http.StatusInternalServerError, api.CodeBlobPutFailed, fmt.Errorf("failed to put object: %w", err)) return } diff --git a/internal/server/handlers/blob/blob_handler_upload_acl.go b/internal/server/handlers/blob/blob_handler_upload_acl.go new file mode 100644 index 00000000..d1c58f8f --- /dev/null +++ b/internal/server/handlers/blob/blob_handler_upload_acl.go @@ -0,0 +1,101 @@ +package blob + +import ( + "bytes" + "fmt" + "io" + "net/http" + "time" + + "github.com/gin-gonic/gin" + "github.com/openmined/syftbox/internal/aclspec" + "github.com/openmined/syftbox/internal/server/acl" + "github.com/openmined/syftbox/internal/server/blob" + "github.com/openmined/syftbox/internal/server/datasite" + "github.com/openmined/syftbox/internal/server/handlers/api" +) + +func (h *BlobHandler) UploadACL(ctx *gin.Context) { + var req UploadRequest + user := ctx.GetString("user") + + if err := ctx.ShouldBindQuery(&req); err != nil { + api.AbortWithError(ctx, http.StatusBadRequest, api.CodeInvalidRequest, fmt.Errorf("failed to bind query: %w", err)) + return + } + + if !(datasite.IsValidPath(req.Key) && aclspec.IsACLFile(req.Key)) { + api.AbortWithError(ctx, http.StatusBadRequest, api.CodeDatasiteInvalidPath, fmt.Errorf("invalid ruleset path: %s", req.Key)) + return + } + + // check if user has admin rights + if err := h.checkPermissions(req.Key, user, acl.AccessAdmin); err != nil { + api.AbortWithError(ctx, http.StatusForbidden, api.CodeAccessDenied, err) + return + } + + // get form file + file, err := ctx.FormFile("file") + if err != nil { + api.AbortWithError(ctx, http.StatusBadRequest, api.CodeInvalidRequest, fmt.Errorf("failed to get form file: %w", err)) + return + } + + // check file size + if file.Size <= 0 { + api.AbortWithError(ctx, http.StatusBadRequest, api.CodeInvalidRequest, fmt.Errorf("invalid file: size is 0")) + return + } + + // open the file + fd, err := file.Open() + if err != nil { + api.AbortWithError(ctx, http.StatusBadRequest, api.CodeInvalidRequest, fmt.Errorf("failed to open file: %w", err)) + return + } + defer fd.Close() + + // read the file into memory, because we need to read it twice + fdBytes, err := io.ReadAll(fd) + if err != nil { + api.AbortWithError(ctx, http.StatusInternalServerError, api.CodeInvalidRequest, fmt.Errorf("failed to read file: %w", err)) + return + } + aclBytesReader := bytes.NewReader(fdBytes) + + // load aclspec + ruleset, err := aclspec.LoadFromReader(req.Key, aclBytesReader) + if err != nil { + api.AbortWithError(ctx, http.StatusBadRequest, api.CodeInvalidRequest, fmt.Errorf("failed to read ruleset: %w", err)) + } + + // upload file to s3 + // because that's always the ground truth + blobBytesReader := bytes.NewReader(fdBytes) + result, err := h.blob.Backend().PutObject(ctx.Request.Context(), &blob.PutObjectParams{ + Key: req.Key, + Size: file.Size, + Body: blobBytesReader, + }) + if err != nil { + api.AbortWithError(ctx, http.StatusInternalServerError, api.CodeBlobPutFailed, fmt.Errorf("failed to put object: %w", err)) + return + } + + // add it to the acl service + if _, err := h.acl.AddRuleSet(ruleset); err != nil { + // if this error happens, there's a pretty serious bug in the acl service + api.AbortWithError(ctx, http.StatusInternalServerError, api.CodeACLUpdateFailed, fmt.Errorf("failed to update ruleset: %w", err)) + return + } + + // response with UploadAccept + ctx.PureJSON(http.StatusOK, &UploadResponse{ + Key: result.Key, + Version: result.Version, + ETag: result.ETag, + Size: result.Size, + LastModified: result.LastModified.Format(time.RFC3339), + }) +} diff --git a/internal/server/handlers/blob/blob_handler_upload_presigned.go b/internal/server/handlers/blob/blob_handler_upload_presigned.go index 7897299b..ac141fe5 100644 --- a/internal/server/handlers/blob/blob_handler_upload_presigned.go +++ b/internal/server/handlers/blob/blob_handler_upload_presigned.go @@ -5,47 +5,57 @@ import ( "net/http" "github.com/gin-gonic/gin" + "github.com/openmined/syftbox/internal/server/acl" + "github.com/openmined/syftbox/internal/server/datasite" + "github.com/openmined/syftbox/internal/server/handlers/api" ) func (h *BlobHandler) UploadPresigned(ctx *gin.Context) { - var req PresignUrlRequest - if err := ctx.ShouldBindJSON(&req); err != nil { - ctx.Error(fmt.Errorf("failed to bind json: %w", err)) - ctx.PureJSON(http.StatusBadRequest, gin.H{ - "error": err.Error(), - }) - return - } + var req PresignURLRequest + user := ctx.GetString("user") - if len(req.Keys) == 0 { - ctx.Error(fmt.Errorf("keys cannot be empty")) - ctx.PureJSON(http.StatusBadRequest, gin.H{ - "error": "keys cannot be empty", - }) + if err := ctx.ShouldBindJSON(&req); err != nil { + api.AbortWithError(ctx, http.StatusBadRequest, api.CodeInvalidRequest, fmt.Errorf("failed to bind json: %w", err)) return } - urls := make([]*BlobUrl, 0, len(req.Keys)) - errors := make([]*BlobError, 0) + urls := make([]*BlobURL, 0, len(req.Keys)) + errors := make([]*BlobAPIError, 0) for _, key := range req.Keys { - if !isValidDatasiteKey(key) { - ctx.Error(fmt.Errorf("invalid datasite path: %s", key)) - errors = append(errors, &BlobError{ - Key: key, - Error: "invalid key", + if !datasite.IsValidPath(key) { + errors = append(errors, &BlobAPIError{ + SyftAPIError: api.SyftAPIError{ + Code: api.CodeDatasiteInvalidPath, + Message: "invalid key", + }, + Key: key, }) continue } + + if err := h.checkPermissions(key, user, acl.AccessWrite); err != nil { + errors = append(errors, &BlobAPIError{ + SyftAPIError: api.SyftAPIError{ + Code: api.CodeAccessDenied, + Message: err.Error(), + }, + Key: key, + }) + continue + } + url, err := h.blob.Backend().PutObjectPresigned(ctx, key) if err != nil { - ctx.Error(fmt.Errorf("failed to put object: %w", err)) - errors = append(errors, &BlobError{ - Key: key, - Error: err.Error(), + errors = append(errors, &BlobAPIError{ + SyftAPIError: api.SyftAPIError{ + Code: api.CodeBlobPutFailed, + Message: err.Error(), + }, + Key: key, }) continue } - urls = append(urls, &BlobUrl{ + urls = append(urls, &BlobURL{ Key: key, Url: url, }) @@ -53,12 +63,12 @@ func (h *BlobHandler) UploadPresigned(ctx *gin.Context) { code := http.StatusOK if len(urls) == 0 && len(errors) > 0 { - code = http.StatusBadRequest + code = http.StatusMultiStatus } else if len(urls) > 0 && len(errors) > 0 { code = http.StatusMultiStatus } - ctx.PureJSON(code, &PresignUrlResponse{ + ctx.PureJSON(code, &PresignURLResponse{ URLs: urls, Errors: errors, }) diff --git a/internal/server/handlers/datasite/datasite_handler.go b/internal/server/handlers/datasite/datasite_handler.go index 4d124bd4..390ea33b 100644 --- a/internal/server/handlers/datasite/datasite_handler.go +++ b/internal/server/handlers/datasite/datasite_handler.go @@ -1,7 +1,6 @@ package datasite import ( - "fmt" "net/http" "github.com/gin-gonic/gin" @@ -20,15 +19,6 @@ func New(svc *datasite.DatasiteService) *DatasiteHandler { func (h *DatasiteHandler) GetView(ctx *gin.Context) { user := ctx.GetString("user") - - if user == "" { - ctx.Error(fmt.Errorf("`user` is required")) - ctx.PureJSON(http.StatusBadRequest, gin.H{ - "error": "user is required", - }) - return - } - ctx.PureJSON(http.StatusOK, gin.H{ "files": h.svc.GetView(user), }) diff --git a/internal/server/handlers/explorer/explorer_handler.go b/internal/server/handlers/explorer/explorer_handler.go index 5f1a7877..f615685c 100644 --- a/internal/server/handlers/explorer/explorer_handler.go +++ b/internal/server/handlers/explorer/explorer_handler.go @@ -19,6 +19,7 @@ import ( "github.com/openmined/syftbox/internal/aclspec" "github.com/openmined/syftbox/internal/server/acl" "github.com/openmined/syftbox/internal/server/blob" + "github.com/openmined/syftbox/internal/server/handlers/api" ) //go:embed index.html.tmpl @@ -29,13 +30,13 @@ var notFoundOfTmpl string type ExplorerHandler struct { blob *blob.BlobService - acl *acl.AclService + acl *acl.ACLService tplIndex *template.Template tpl404 *template.Template } // New creates a new Explorer instance -func New(svc *blob.BlobService, acl *acl.AclService) *ExplorerHandler { +func New(svc *blob.BlobService, acl *acl.ACLService) *ExplorerHandler { funcMap := template.FuncMap{ "basename": filepath.Base, "humanizeSize": func(size int64) string { @@ -57,7 +58,7 @@ func New(svc *blob.BlobService, acl *acl.AclService) *ExplorerHandler { func (e *ExplorerHandler) Handler(c *gin.Context) { path := strings.TrimPrefix(c.Param("filepath"), "/") contents := e.listContents(path) - if contents.IsDir { + if contents.IsDir || contents.EmptyDir() { e.serveDir(c, path, contents) } else { e.serveFile(c, path) @@ -99,6 +100,17 @@ func (e *ExplorerHandler) listContents(prefix string) *directoryContents { } for _, blob := range blobs { + + // check if public readable + if err := e.acl.CanAccess( + &acl.User{ID: aclspec.Everyone}, + &acl.File{Path: blob.Key}, + acl.AccessRead, + ); err != nil { + // don't reveal if the file is private or not + continue + } + if strings.HasPrefix(blob.Key, prefix) { relPath := strings.TrimPrefix(blob.Key, prefix) if relPath == "" { @@ -149,15 +161,14 @@ func (e *ExplorerHandler) serveDir(c *gin.Context, path string, contents *direct // Generate an HTML response c.Header("Content-Type", "text/html; charset=utf-8") if err := e.tplIndex.Execute(c.Writer, data); err != nil { - c.Error(fmt.Errorf("failed to execute template: %w", err)) - c.String(http.StatusInternalServerError, "internal server error") + api.AbortWithError(c, http.StatusInternalServerError, api.CodeInternalError, fmt.Errorf("failed to execute template: %w", err)) } } // Serve a file from S3 func (e *ExplorerHandler) serveFile(c *gin.Context, key string) { if err := e.acl.CanAccess( - &acl.User{ID: aclspec.Everyone, IsOwner: false}, + &acl.User{ID: aclspec.Everyone}, &acl.File{Path: key}, acl.AccessRead, ); err != nil { @@ -180,8 +191,7 @@ func (e *ExplorerHandler) serveFile(c *gin.Context, key string) { // Stream response body directly _, err = io.Copy(c.Writer, resp.Body) if err != nil { - c.Error(fmt.Errorf("failed to stream file: %w", err)) - c.String(http.StatusInternalServerError, "internal server error") + api.AbortWithError(c, http.StatusInternalServerError, api.CodeInternalError, fmt.Errorf("failed to stream file: %w", err)) return } } @@ -198,8 +208,7 @@ func (e *ExplorerHandler) detectContentType(key string) string { func (e *ExplorerHandler) serve404(c *gin.Context, key string) { c.Header("Content-Type", "text/html; charset=utf-8") if err := e.tpl404.Execute(c.Writer, map[string]any{"Key": key}); err != nil { - c.Error(fmt.Errorf("failed to execute template: %w", err)) - c.String(http.StatusInternalServerError, "internal server error") + api.AbortWithError(c, http.StatusInternalServerError, api.CodeInternalError, fmt.Errorf("failed to execute template: %w", err)) } } diff --git a/internal/server/handlers/explorer/explorer_handler_types.go b/internal/server/handlers/explorer/explorer_handler_types.go index bae4ce55..6909a81b 100644 --- a/internal/server/handlers/explorer/explorer_handler_types.go +++ b/internal/server/handlers/explorer/explorer_handler_types.go @@ -15,3 +15,7 @@ type directoryContents struct { Files []*blob.BlobInfo Folders []string } + +func (d *directoryContents) EmptyDir() bool { + return d.IsDir && len(d.Files) == 0 && len(d.Folders) == 0 +} diff --git a/internal/server/handlers/install/install.ps1 b/internal/server/handlers/install/install.ps1 index aa552407..f59d56f3 100644 --- a/internal/server/handlers/install/install.ps1 +++ b/internal/server/handlers/install/install.ps1 @@ -8,7 +8,7 @@ $AskRunClient = $true $RunClient = $false $AppName = "syftbox" -$ArtifactBaseUrl = if ($env:ARTIFACT_BASE_URL) { $env:ARTIFACT_BASE_URL } else { "https://syftboxdev.openmined.org" } +$ArtifactBaseUrl = if ($env:ARTIFACT_BASE_URL) { $env:ARTIFACT_BASE_URL } else { "https://syftbox.net" } $ArtifactDownloadUrl = "$ArtifactBaseUrl/releases" function Write-Error-Exit($message) { diff --git a/internal/server/handlers/install/install.sh b/internal/server/handlers/install/install.sh index 7339de38..9d4d7a2d 100644 --- a/internal/server/handlers/install/install.sh +++ b/internal/server/handlers/install/install.sh @@ -13,7 +13,7 @@ RUN_CLIENT=0 INSTALL_APPS=${INSTALL_APPS:-""} APP_NAME="syftbox" -ARTIFACT_BASE_URL=${ARTIFACT_BASE_URL:-"https://syftboxdev.openmined.org"} +ARTIFACT_BASE_URL=${ARTIFACT_BASE_URL:-"https://syftbox.net"} ARTIFACT_DOWNLOAD_URL="$ARTIFACT_BASE_URL/releases" SYFTBOX_BINARY_PATH="$HOME/.local/bin/syftbox" diff --git a/internal/server/handlers/send/poll.html b/internal/server/handlers/send/poll.html new file mode 100644 index 00000000..9447c485 --- /dev/null +++ b/internal/server/handlers/send/poll.html @@ -0,0 +1,45 @@ + + + + + + + + Processing Request - SyftBox + + + + +

Processing Your Request

+
+

Your request is being processed. This page will automatically refresh in {{ .RefreshInterval }} seconds.

+

If you are not redirected, click here to check the status.

+
+ + + \ No newline at end of file diff --git a/internal/server/handlers/send/send_handler.go b/internal/server/handlers/send/send_handler.go new file mode 100644 index 00000000..18eb6445 --- /dev/null +++ b/internal/server/handlers/send/send_handler.go @@ -0,0 +1,206 @@ +package send + +import ( + "errors" + "fmt" + "io" + "log/slog" + "net/http" + "strconv" + + "github.com/gin-gonic/gin" +) + +// SendHandler handles HTTP requests for sending messages +type SendHandler struct { + service SendServiceInterface +} + +// New creates a new send handler +func New(msgDispatcher MessageDispatcher, msgStore RPCMsgStore) *SendHandler { + service := NewSendService(msgDispatcher, msgStore, nil) + return &SendHandler{service: service} +} + +// SendMsg handles sending a message +func (h *SendHandler) SendMsg(ctx *gin.Context) { + var req MessageRequest + + // Bind query parameters + if err := ctx.ShouldBindQuery(&req); err != nil { + ctx.PureJSON(http.StatusBadRequest, APIError{ + Error: ErrorInvalidRequest, + Message: err.Error(), + }) + return + } + + // Bind headers + if err := ctx.ShouldBindHeader(&req); err != nil { + ctx.PureJSON(http.StatusBadRequest, APIError{ + Error: ErrorInvalidRequest, + Message: err.Error(), + }) + return + } + + // Bind request method + req.Method = ctx.Request.Method + + // Bind headers + req.BindHeaders(ctx) + + // Read request body with size limit + bodyBytes, err := readRequestBody(ctx, h.service.GetConfig().MaxBodySize) + if err != nil { + ctx.PureJSON(http.StatusBadRequest, APIError{ + Error: ErrorInvalidRequest, + Message: err.Error(), + }) + return + } + + result, err := h.service.SendMessage(ctx.Request.Context(), &req, bodyBytes) + if err != nil { + slog.Error("failed to send message", "error", err) + ctx.PureJSON(http.StatusInternalServerError, APIError{ + Error: ErrorInternal, + Message: err.Error(), + }) + return + } + + if result.Response != nil { + ctx.PureJSON(result.Status, APIResponse{ + RequestID: result.RequestID, + Data: result.Response, + }) + return + } + + // add poll url as location header + ctx.Header("Location", result.PollURL) + + // return poll info + ctx.PureJSON(result.Status, APIResponse{ + RequestID: result.RequestID, + Data: PollInfo{ + PollURL: result.PollURL, + }, + Message: "Request has been accepted. Please check back later.", + }) +} + +// PollForResponse handles polling for a response +func (h *SendHandler) PollForResponse(ctx *gin.Context) { + var req PollObjectRequest + if err := ctx.ShouldBindQuery(&req); err != nil { + slog.Error("failed to bind query parameters", "error", err) + ctx.PureJSON(http.StatusBadRequest, APIError{ + Error: ErrorInvalidRequest, + Message: err.Error(), + RequestID: req.RequestID, + }) + return + } + + if err := ctx.ShouldBindHeader(&req); err != nil { + slog.Error("failed to bind headers", "error", err) + ctx.PureJSON(http.StatusBadRequest, APIError{ + Error: ErrorInvalidRequest, + Message: err.Error(), + RequestID: req.RequestID, + }) + return + } + + result, err := h.service.PollForResponse(ctx.Request.Context(), &req) + contentTypeHTML := ctx.Request.Header.Get("Content-Type") == "text/html" + + if err != nil { + if errors.Is(err, ErrPollTimeout) { + // Add poll URL to the Response header + pollURL := h.service.constructPollURL( + req.RequestID, + req.SyftURL, + req.From, + req.AsRaw, + ) + + // calculate refresh interval in seconds + var refreshInterval int + if req.Timeout > 0 { + refreshInterval = req.Timeout / 1000 + } else { + refreshInterval = h.service.GetConfig().DefaultTimeoutMs / 1000 + } + + // add poll url as location header and retry after header + ctx.Header("Location", pollURL) + ctx.Header("Retry-After", strconv.Itoa(refreshInterval)) + + if contentTypeHTML { + // Return a HTML page with a link to the poll URL + // with auto refresh capability + ctx.HTML(http.StatusAccepted, "poll.html", gin.H{ + "PollURL": pollURL, + "BaseURL": ctx.Request.Host, + "RefreshInterval": refreshInterval, // in seconds + }) + return + } else { + ctx.PureJSON(http.StatusAccepted, APIError{ + Error: ErrorTimeout, + Message: "Polling timeout reached. The request may still be processing.", + RequestID: req.RequestID, + }) + return + } + } + + if errors.Is(err, ErrNoRequest) { + ctx.PureJSON(http.StatusNotFound, APIError{ + Error: ErrorNotFound, + Message: "No request found.", + RequestID: req.RequestID, + }) + return + } + + slog.Error("failed to poll for response", "error", err) + ctx.PureJSON(http.StatusInternalServerError, APIError{ + Error: ErrorInternal, + Message: err.Error(), + RequestID: req.RequestID, + }) + return + } + + ctx.PureJSON(result.Status, APIResponse{ + RequestID: result.RequestID, + Data: result.Response, + }) +} + +// readRequestBody reads and validates the request body +func readRequestBody(ctx *gin.Context, maxSize int64) ([]byte, error) { + body := ctx.Request.Body + defer ctx.Request.Body.Close() + + // Read body bytes + bodyBytes, err := io.ReadAll(body) + if err != nil { + return nil, fmt.Errorf("failed to read request body: %w", err) + } + + // Check if body size exceeds maximum allowed size + if maxSize > 0 && int64(len(bodyBytes)) > maxSize { + return nil, fmt.Errorf( + "request body too large: %d bytes (max: %d bytes)", + len(bodyBytes), + maxSize, + ) + } + + return bodyBytes, nil +} diff --git a/internal/server/handlers/send/send_handler_test.go b/internal/server/handlers/send/send_handler_test.go new file mode 100644 index 00000000..3262f860 --- /dev/null +++ b/internal/server/handlers/send/send_handler_test.go @@ -0,0 +1,542 @@ +package send + +import ( + "context" + "embed" + "encoding/json" + "errors" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/gin-gonic/gin" + "github.com/openmined/syftbox/internal/utils" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +//go:embed *.html +var templateFS embed.FS + +// MockSendService implements the SendService interface for testing +type MockSendService struct { + mock.Mock +} + +func (m *MockSendService) SendMessage(ctx context.Context, req *MessageRequest, bodyBytes []byte) (*SendResult, error) { + args := m.Called(ctx, req, bodyBytes) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*SendResult), args.Error(1) +} + +func (m *MockSendService) PollForResponse(ctx context.Context, req *PollObjectRequest) (*PollResult, error) { + args := m.Called(ctx, req) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*PollResult), args.Error(1) +} + +func (m *MockSendService) constructPollURL(requestID string, syftURL utils.SyftBoxURL, from string, asRaw bool) string { + args := m.Called(requestID, syftURL, from, asRaw) + return args.String(0) +} + +func (m *MockSendService) GetConfig() *Config { + args := m.Called() + if args.Get(0) == nil { + return nil + } + return args.Get(0).(*Config) +} + +// Helper function to create a test gin context +func createTestContext( + method string, + url string, + body io.Reader, + query_params map[string]string, + headers map[string]string, +) (*gin.Context, *httptest.ResponseRecorder) { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + + // Build URL with query parameters + if len(query_params) > 0 { + params := make([]string, 0, len(query_params)) + for k, v := range query_params { + params = append(params, k+"="+v) + } + url = url + "?" + strings.Join(params, "&") + } + + req := httptest.NewRequest(method, url, body) + + // Add headers + for k, v := range headers { + req.Header.Set(k, v) + } + + c.Request = req + return c, w +} + +func TestSendHandler_SendMsg_SuccessWithResponse(t *testing.T) { + // Setup + mockService := &MockSendService{} + handler := &SendHandler{service: mockService} + + // Test data + body := `{"key": "value"}` + query_params := map[string]string{ + "x-syft-url": "syft://test@datasite.com/app_data/testapp/rpc/endpoint", + "x-syft-from": "testuser@example.com", + "timeout": "1000", + } + headers := map[string]string{ + "Content-Type": "application/json", + } + + c, w := createTestContext("POST", "/send/msg/", strings.NewReader(body), query_params, headers) + + // Mock expectations + expectedResult := &SendResult{ + Status: http.StatusOK, + RequestID: "test-request-id", + Response: map[string]interface{}{ + "message": "success", + }, + } + + mockService.On("SendMessage", mock.Anything, mock.MatchedBy(func(req *MessageRequest) bool { + return req.From == "testuser@example.com" && req.Method == "POST" + }), []byte(body)).Return(expectedResult, nil) + mockService.On("GetConfig").Return(&Config{MaxBodySize: 4 * 1024 * 1024}) // 4MB + + // Execute + handler.SendMsg(c) + + // Assertions + assert.Equal(t, http.StatusOK, w.Code) + + var response APIResponse + err := json.Unmarshal(w.Body.Bytes(), &response) + assert.NoError(t, err) + assert.Equal(t, "test-request-id", response.RequestID) + assert.NotNil(t, response.Data) + + mockService.AssertExpectations(t) +} + +func TestSendHandler_SendMsg_SuccessWithPolling(t *testing.T) { + // Setup + mockService := &MockSendService{} + handler := &SendHandler{service: mockService} + + // Test data + body := `{"key": "value"}` + query_params := map[string]string{ + "x-syft-url": "syft://test@datasite.com/app_data/testapp/rpc/endpoint", + "x-syft-from": "testuser@example.com", + "timeout": "1000", + } + headers := map[string]string{ + "x-syft-url": "syft://test@datasite.com/app_data/testapp/rpc/endpoint", + "x-syft-from": "testuser@example.com", + "Content-Type": "application/json", + } + + c, w := createTestContext("POST", "/send/msg/", strings.NewReader(body), query_params, headers) + + // Mock expectations + expectedResult := &SendResult{ + Status: http.StatusAccepted, + RequestID: "test-request-id", + PollURL: "/api/v1/send/poll?x-syft-request-id=test-request-id&x-syft-url=syft://test@datasite.com/app_data/testapp/rpc/endpoint&x-syft-from=testuser@example.com&x-syft-raw=false", + } + + mockService.On("SendMessage", mock.Anything, mock.Anything, []byte(body)).Return(expectedResult, nil) + mockService.On("GetConfig").Return(&Config{MaxBodySize: 4 * 1024 * 1024}) // 4MB + + // Execute + handler.SendMsg(c) + + // Assertions + assert.Equal(t, http.StatusAccepted, w.Code) + assert.Equal(t, expectedResult.PollURL, w.Header().Get("Location")) + + var response APIResponse + err := json.Unmarshal(w.Body.Bytes(), &response) + assert.NoError(t, err) + assert.Equal(t, "test-request-id", response.RequestID) + assert.Equal(t, "Request has been accepted. Please check back later.", response.Message) + + // Check PollInfo in response + pollInfo, ok := response.Data.(map[string]interface{}) + assert.True(t, ok) + assert.Equal(t, expectedResult.PollURL, pollInfo["poll_url"]) + + mockService.AssertExpectations(t) +} + +func TestSendHandler_SendMsg_InvalidRequestBinding(t *testing.T) { + // Setup + mockService := &MockSendService{} + handler := &SendHandler{service: mockService} + + // Test data - missing required fields + body := `{"key": "value"}` + query_params := map[string]string{ + "x-syft-url": "syft://test@datasite.com/app_data/testapp/rpc/endpoint", + } + headers := map[string]string{ + "Content-Type": "application/json", + } + + c, w := createTestContext("POST", "/send/msg/", strings.NewReader(body), query_params, headers) + + // Execute + handler.SendMsg(c) + + // Assertions + assert.Equal(t, http.StatusBadRequest, w.Code) + + var response APIError + err := json.Unmarshal(w.Body.Bytes(), &response) + assert.NoError(t, err) + assert.Equal(t, ErrorInvalidRequest, response.Error) + assert.Contains(t, response.Message, "required") +} + +func TestSendHandler_SendMsg_BodyTooLarge(t *testing.T) { + // Setup + mockService := &MockSendService{} + handler := &SendHandler{service: mockService} + + // Test data + largeBody := strings.Repeat("a", 5*1024*1024) // 5MB + query_params := map[string]string{ + "x-syft-url": "syft://test@datasite.com/app_data/testapp/rpc/endpoint", + "x-syft-from": "testuser@example.com", + } + headers := map[string]string{ + "Content-Type": "application/json", + } + + c, w := createTestContext("POST", "/send/msg/", strings.NewReader(largeBody), query_params, headers) + + // Mock expectations + mockService.On("GetConfig").Return(&Config{MaxBodySize: 4 * 1024 * 1024}) // 4MB + + // Execute + handler.SendMsg(c) + + // Assertions + assert.Equal(t, http.StatusBadRequest, w.Code) + + var response APIError + err := json.Unmarshal(w.Body.Bytes(), &response) + assert.NoError(t, err) + assert.Equal(t, ErrorInvalidRequest, response.Error) + assert.Contains(t, response.Message, "too large") +} + +func TestSendHandler_SendMsg_ServiceError(t *testing.T) { + // Setup + mockService := &MockSendService{} + handler := &SendHandler{service: mockService} + + // Test data + body := `{"key": "value"}` + query_params := map[string]string{ + "x-syft-url": "syft://test@datasite.com/app_data/testapp/rpc/endpoint", + "x-syft-from": "testuser@example.com", + } + headers := map[string]string{ + "Content-Type": "application/json", + } + + c, w := createTestContext("POST", "/send/msg/", strings.NewReader(body), query_params, headers) + + // Mock expectations + mockService.On("SendMessage", mock.Anything, mock.Anything, []byte(body)).Return(nil, errors.New("service error")) + mockService.On("GetConfig").Return(&Config{MaxBodySize: 4 * 1024 * 1024}) // 4MB + // Execute + handler.SendMsg(c) + + // Assertions + assert.Equal(t, http.StatusInternalServerError, w.Code) + + var response APIError + err := json.Unmarshal(w.Body.Bytes(), &response) + assert.NoError(t, err) + assert.Equal(t, ErrorInternal, response.Error) + assert.Equal(t, "service error", response.Message) + + mockService.AssertExpectations(t) +} + +func TestSendHandler_PollForResponse_Success(t *testing.T) { + // Setup + mockService := &MockSendService{} + handler := &SendHandler{service: mockService} + + // Test data + query_params := map[string]string{ + "x-syft-request-id": "test-request-id", + "x-syft-url": "syft://test@datasite.com/app_data/testapp/rpc/endpoint", + "x-syft-from": "test-user", + } + headers := map[string]string{ + "Content-Type": "application/json", + } + + c, w := createTestContext("GET", "/send/poll/", nil, query_params, headers) + + // Mock expectations + expectedResult := &PollResult{ + Status: http.StatusOK, + RequestID: "test-request-id", + Response: map[string]interface{}{ + "message": "success", + }, + } + + mockService.On("PollForResponse", mock.Anything, mock.MatchedBy(func(req *PollObjectRequest) bool { + return req.RequestID == "test-request-id" && req.From == "test-user" + })).Return(expectedResult, nil) + + // Execute + handler.PollForResponse(c) + + // Assertions + assert.Equal(t, http.StatusOK, w.Code) + + var response APIResponse + err := json.Unmarshal(w.Body.Bytes(), &response) + assert.NoError(t, err) + assert.Equal(t, "test-request-id", response.RequestID) + assert.NotNil(t, response.Data) + + mockService.AssertExpectations(t) +} + +func TestSendHandler_PollForResponse_TimeoutWithJSON(t *testing.T) { + // Setup + mockService := &MockSendService{} + handler := &SendHandler{service: mockService} + + // Test data + query_params := map[string]string{ + "x-syft-request-id": "test-request-id", + "x-syft-url": "syft://test@datasite.com/app_data/testapp/rpc/endpoint", + "x-syft-from": "test-user", + "timeout": "1000", + } + headers := map[string]string{ + "Content-Type": "application/json", + } + + c, w := createTestContext("GET", "/send/poll/", nil, query_params, headers) + + // Mock expectations + pollURL := "/api/v1/send/poll?x-syft-request-id=test-request-id&x-syft-url=syft://test@datasite.com/app_data/testapp/rpc/endpoint&x-syft-from=test-user&x-syft-raw=false" + + mockService.On("PollForResponse", mock.Anything, mock.Anything).Return(nil, ErrPollTimeout) + mockService.On("constructPollURL", "test-request-id", mock.Anything, "test-user", false).Return(pollURL) + + // Execute + handler.PollForResponse(c) + + // Assertions + assert.Equal(t, http.StatusAccepted, w.Code) + assert.Equal(t, pollURL, w.Header().Get("Location")) + assert.Equal(t, "1", w.Header().Get("Retry-After")) + + var response APIError + err := json.Unmarshal(w.Body.Bytes(), &response) + assert.NoError(t, err) + assert.Equal(t, ErrorTimeout, response.Error) + assert.Equal(t, "test-request-id", response.RequestID) + + mockService.AssertExpectations(t) +} + +func TestSendHandler_PollForResponse_NotFound(t *testing.T) { + // Setup + mockService := &MockSendService{} + handler := &SendHandler{service: mockService} + + // Test data + query_params := map[string]string{ + "x-syft-request-id": "test-request-id", + "x-syft-url": "syft://test@datasite.com/app_data/testapp/rpc/endpoint", + "x-syft-from": "test-user", + } + headers := map[string]string{ + "x-syft-request-id": "test-request-id", + "x-syft-url": "syft://test@datasite.com/app_data/testapp/rpc/endpoint", + "x-syft-from": "test-user", + } + + c, w := createTestContext("GET", "/send/poll/", nil, query_params, headers) + + // Mock expectations + mockService.On("PollForResponse", mock.Anything, mock.Anything).Return(nil, ErrNoRequest) + + // Execute + handler.PollForResponse(c) + + // Assertions + assert.Equal(t, http.StatusNotFound, w.Code) + + var response APIError + err := json.Unmarshal(w.Body.Bytes(), &response) + assert.NoError(t, err) + assert.Equal(t, ErrorNotFound, response.Error) + assert.Equal(t, "test-request-id", response.RequestID) + + mockService.AssertExpectations(t) +} + +func TestSendHandler_PollForResponse_InvalidRequest(t *testing.T) { + // Setup + mockService := &MockSendService{} + handler := &SendHandler{service: mockService} + + // Test data - missing required fields + query_params := map[string]string{ + "x-syft-from": "test-user", + } + headers := map[string]string{ + "Content-Type": "application/json", + // Missing required headers + } + + c, w := createTestContext("GET", "/send/poll/", nil, query_params, headers) + + // Execute + handler.PollForResponse(c) + + // Assertions + assert.Equal(t, http.StatusBadRequest, w.Code) + + var response APIError + err := json.Unmarshal(w.Body.Bytes(), &response) + assert.NoError(t, err) + assert.Equal(t, ErrorInvalidRequest, response.Error) + assert.Contains(t, response.Message, "required") +} + +func TestSendHandler_PollForResponse_ServiceError(t *testing.T) { + // Setup + mockService := &MockSendService{} + handler := &SendHandler{service: mockService} + + // Test data + query_params := map[string]string{ + "x-syft-request-id": "test-request-id", + "x-syft-url": "syft://test@datasite.com/app_data/testapp/rpc/endpoint", + "x-syft-from": "test-user", + } + headers := map[string]string{ + "Content-Type": "application/json", + } + + c, w := createTestContext("GET", "/send/poll/", nil, query_params, headers) + + // Mock expectations + mockService.On("PollForResponse", mock.Anything, mock.Anything).Return(nil, errors.New("service error")) + + // Execute + handler.PollForResponse(c) + + // Assertions + assert.Equal(t, http.StatusInternalServerError, w.Code) + + var response APIError + err := json.Unmarshal(w.Body.Bytes(), &response) + assert.NoError(t, err) + assert.Equal(t, ErrorInternal, response.Error) + assert.Equal(t, "service error", response.Message) + assert.Equal(t, "test-request-id", response.RequestID) + + mockService.AssertExpectations(t) +} + +func TestSendHandler_New(t *testing.T) { + // Setup + mockDispatcher := &MockMessageDispatcher{} + mockStore := &MockRPCMsgStore{} + + // Execute + handler := New(mockDispatcher, mockStore) + + // Assertions + assert.NotNil(t, handler) + assert.NotNil(t, handler.service) +} + +func TestReadRequestBody_Success(t *testing.T) { + // Setup + body := `{"key": "value"}` + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest("POST", "/", strings.NewReader(body)) + + // Execute + result, err := readRequestBody(c, 1024*1024) // 1MB limit + + // Assertions + assert.NoError(t, err) + assert.Equal(t, []byte(body), result) +} + +func TestReadRequestBody_TooLarge(t *testing.T) { + // Setup + largeBody := strings.Repeat("a", 1025) // 1025 bytes + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest("POST", "/", strings.NewReader(largeBody)) + + // Execute + result, err := readRequestBody(c, 1024) // 1KB limit + + // Assertions + assert.Error(t, err) + assert.Nil(t, result) + assert.Contains(t, err.Error(), "too large") +} + +func TestReadRequestBody_ReadError(t *testing.T) { + // Setup + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + + // Create a request with a body that will fail to read + req := httptest.NewRequest("POST", "/", &failingReader{}) + c.Request = req + + // Execute + result, err := readRequestBody(c, 1024) + + // Assertions + assert.Error(t, err) + assert.Nil(t, result) + assert.Contains(t, err.Error(), "failed to read request body") +} + +// failingReader is a reader that always fails +type failingReader struct{} + +func (f *failingReader) Read(p []byte) (n int, err error) { + return 0, errors.New("read error") +} diff --git a/internal/server/handlers/send/send_handler_types.go b/internal/server/handlers/send/send_handler_types.go new file mode 100644 index 00000000..d9fb3e74 --- /dev/null +++ b/internal/server/handlers/send/send_handler_types.go @@ -0,0 +1,114 @@ +package send + +import ( + "context" + "io" + + "github.com/gin-gonic/gin" + "github.com/openmined/syftbox/internal/syftmsg" + "github.com/openmined/syftbox/internal/utils" +) + +type PollStatus string + +const ( + PollStatusPending PollStatus = "pending" + PollStatusComplete PollStatus = "complete" +) + +// Error constants +const ( + ErrorTimeout = "timeout" + ErrorInvalidRequest = "invalid_request" + ErrorInternal = "internal_error" + ErrorNotFound = "not_found" + PollURL = "/api/v1/send/poll?x-syft-request-id=%s&x-syft-url=%s&x-syft-from=%s&x-syft-raw=%t" +) + +// APIError represents a standardized error response +type APIError struct { + Error string `json:"error"` + Message string `json:"message"` + RequestID string `json:"request_id,omitempty"` +} + +// APIResponse represents a standardized success response +type APIResponse struct { + RequestID string `json:"request_id"` + Data interface{} `json:"data,omitempty"` + Message string `json:"message,omitempty"` +} + +type PollInfo struct { + PollURL string `json:"poll_url"` +} + +type Headers map[string]string + +// MessageRequest represents the request for sending a message +type MessageRequest struct { + SyftURL utils.SyftBoxURL `form:"x-syft-url" binding:"required"` // Binds to the syft url using UnmarshalParam + From string `form:"x-syft-from" binding:"required"` // The sender of the message + Timeout int `form:"timeout" binding:"gte=0"` // The timeout for the request + AsRaw bool `form:"x-syft-raw" default:"false"` // If true, the request body will be read and sent as is + Method string // Will be set from request method + Headers Headers // Will be set from request headers +} + +func (h *MessageRequest) BindHeaders(ctx *gin.Context) { + + // TODO: Filter out headers that are not allowed + h.Headers = make(Headers) + for k, v := range ctx.Request.Header { + if len(v) > 0 { + h.Headers[k] = v[0] + } + } + // Bind x-syft-from to Headers + h.Headers["x-syft-from"] = h.From +} + +// PollObjectRequest represents the request for polling +type PollObjectRequest struct { + RequestID string `form:"x-syft-request-id" binding:"required"` + From string `form:"x-syft-from" binding:"required"` + SyftURL utils.SyftBoxURL `form:"x-syft-url" binding:"required"` + Timeout int `form:"timeout,omitempty" binding:"gte=0"` + UserAgent string `form:"user-agent,omitempty"` + AsRaw bool `form:"x-syft-raw" default:"false"` // If true, the request body will be read and sent as is +} + +// SendResult represents the result of a send operation +type SendResult struct { + Status int + RequestID string + PollURL string + Response map[string]interface{} +} + +// PollResult represents the result of a poll operation +type PollResult struct { + Status int + RequestID string + Response map[string]interface{} +} + +// Message store interface for storing and retrieving messages +type RPCMsgStore interface { + StoreMsg(ctx context.Context, path string, msg syftmsg.SyftRPCMessage) error + GetMsg(ctx context.Context, path string) (io.ReadCloser, error) + DeleteMsg(ctx context.Context, path string) error +} + +// Message dispatch interface for dispatching messages to users +type MessageDispatcher interface { + Dispatch(datasite string, msg *syftmsg.Message) bool +} + +// SendServiceInterface defines the interface for the send service +type SendServiceInterface interface { + SendMessage(ctx context.Context, req *MessageRequest, bodyBytes []byte) (*SendResult, error) + PollForResponse(ctx context.Context, req *PollObjectRequest) (*PollResult, error) + constructPollURL(requestID string, syftURL utils.SyftBoxURL, from string, asRaw bool) string + GetConfig() *Config +} diff --git a/internal/server/handlers/send/send_service.go b/internal/server/handlers/send/send_service.go new file mode 100644 index 00000000..75cae99e --- /dev/null +++ b/internal/server/handlers/send/send_service.go @@ -0,0 +1,350 @@ +package send + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "log/slog" + "net/http" + "path" + "time" + + "github.com/openmined/syftbox/internal/server/blob" + "github.com/openmined/syftbox/internal/server/handlers/ws" + "github.com/openmined/syftbox/internal/syftmsg" + "github.com/openmined/syftbox/internal/utils" +) + +var ( + ErrPollTimeout = errors.New("poll timeout") + ErrNoRequest = errors.New("no request found") +) + +type BlobMsgStore struct { + blob *blob.BlobService +} + +func (m *BlobMsgStore) GetMsg(ctx context.Context, path string) (io.ReadCloser, error) { + object, err := m.blob.Backend().GetObject(ctx, path) + if err != nil { + return nil, err + } + return object.Body, nil +} + +func (m *BlobMsgStore) DeleteMsg(ctx context.Context, path string) error { + _, err := m.blob.Backend().DeleteObject(ctx, path) + return err +} + +func (m *BlobMsgStore) StoreMsg(ctx context.Context, path string, msg syftmsg.SyftRPCMessage) error { + msgBytes, err := msg.MarshalJSON() + if err != nil { + return fmt.Errorf("failed to marshal message: %w", err) + } + + _, err = m.blob.Backend().PutObject(ctx, &blob.PutObjectParams{ + Key: path, + ETag: msg.ID.String(), + Body: bytes.NewReader(msgBytes), + Size: int64(len(msgBytes)), + }) + return err +} + +type WSMsgDispatcher struct { + hub *ws.WebsocketHub +} + +func (m *WSMsgDispatcher) Dispatch(datasite string, msg *syftmsg.Message) bool { + return m.hub.SendMessageUser(datasite, msg) +} + +func NewWSMsgDispatcher(hub *ws.WebsocketHub) MessageDispatcher { + return &WSMsgDispatcher{hub: hub} +} + +func NewBlobMsgStore(blob *blob.BlobService) RPCMsgStore { + return &BlobMsgStore{blob: blob} +} + +// SendService handles the business logic for message sending and polling +type SendService struct { + dispatcher MessageDispatcher + store RPCMsgStore + cfg *Config +} + +// Config holds the service configuration +type Config struct { + DefaultTimeoutMs int + MaxTimeoutMs int + MaxBodySize int64 + PollIntervalMs int + RequestChkTimeoutMs int +} + +// NewSendService creates a new send service +func NewSendService(dispatch MessageDispatcher, store RPCMsgStore, cfg *Config) *SendService { + if cfg == nil { + cfg = &Config{ + DefaultTimeoutMs: 1000, // 1000 ms + MaxTimeoutMs: 10000, // 10 seconds + MaxBodySize: 4 << 20, // 4MB + PollIntervalMs: 500, // 500 ms + RequestChkTimeoutMs: 200, // 200 ms + } + } + return &SendService{dispatcher: dispatch, store: store, cfg: cfg} +} + +// SendMessage handles sending a message to a user +func (s *SendService) SendMessage(ctx context.Context, req *MessageRequest, bodyBytes []byte) (*SendResult, error) { + + // Create the HTTP message + + msg := syftmsg.NewHttpMsg( + req.From, + req.SyftURL, + req.Method, + bodyBytes, + req.Headers, + syftmsg.HttpMsgTypeRequest, + ) + + httpMsg := msg.Data.(*syftmsg.HttpMsg) + + // TODO: Check if user has permission to send message to this application + + // Dispatch the message to the user via websocket + if ok := s.dispatcher.Dispatch(req.SyftURL.Datasite, msg); !ok { + // If the message is not sent via websocket, handle it as an offline message + return s.handleOfflineMessage(ctx, req, httpMsg) + } + + // If the message is sent via websocket, handle the response + return s.handleOnlineMessage(ctx, req, httpMsg) +} + +// handleOfflineMessage handles sending a message when the user is offline +func (s *SendService) handleOfflineMessage( + ctx context.Context, + req *MessageRequest, + httpMsg *syftmsg.HttpMsg, +) (*SendResult, error) { + blobPath := path.Join( + req.SyftURL.ToLocalPath(), + fmt.Sprintf("%s.%s", httpMsg.Id, httpMsg.Type), + ) + + // Create the RPC message + rpcMsg, err := syftmsg.NewSyftRPCMessage(*httpMsg) + if err != nil { + return nil, fmt.Errorf("failed to create RPCMsg: %w", err) + } + + // Save the RPC message to blob storage + if err := s.store.StoreMsg(ctx, blobPath, *rpcMsg); err != nil { + return nil, fmt.Errorf("failed to save message to blob storage: %w", err) + } + + slog.Info("saved message to blob storage", "blobPath", blobPath) + return &SendResult{ + Status: http.StatusAccepted, + RequestID: httpMsg.Id, + PollURL: s.constructPollURL(httpMsg.Id, req.SyftURL, req.From, req.AsRaw), + }, nil +} + +// handleOnlineMessage handles sending a message when the user is online +func (s *SendService) handleOnlineMessage( + ctx context.Context, + req *MessageRequest, + httpMsg *syftmsg.HttpMsg, +) (*SendResult, error) { + blobPath := path.Join( + req.SyftURL.ToLocalPath(), + fmt.Sprintf("%s.response", httpMsg.Id), + ) + + timeout := req.Timeout + if timeout <= 0 { + timeout = s.cfg.DefaultTimeoutMs + } + + object, err := s.pollForObject(ctx, blobPath, timeout) + if err != nil { + if errors.Is(err, ErrPollTimeout) { + return &SendResult{ + Status: http.StatusAccepted, + RequestID: httpMsg.Id, + PollURL: s.constructPollURL(httpMsg.Id, req.SyftURL, req.From, req.AsRaw), + }, nil + } + return nil, err + } + + // Read the object + bodyBytes, err := io.ReadAll(object) + if err != nil { + return nil, fmt.Errorf("failed to read object: %w", err) + } + + responseBody, err := unmarshalResponse(bodyBytes, req.AsRaw) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal response: %w", err) + } + + // Clean up in background + go s.cleanReqResponse( + req.SyftURL.Datasite, + req.SyftURL.AppName, + req.SyftURL.Endpoint, + httpMsg.Id, + ) + + return &SendResult{ + Status: http.StatusOK, + RequestID: httpMsg.Id, + Response: responseBody, + }, nil +} + +// PollForResponse handles polling for a response +func (s *SendService) PollForResponse(ctx context.Context, req *PollObjectRequest) (*PollResult, error) { + + // Validate if the corresponding request exists + requestBlobPath := path.Join(req.SyftURL.ToLocalPath(), fmt.Sprintf("%s.request", req.RequestID)) + + _, err := s.pollForObject(ctx, requestBlobPath, s.cfg.RequestChkTimeoutMs) + + if err != nil { + if errors.Is(err, ErrPollTimeout) { + return nil, ErrNoRequest + } + return nil, err + } + + // Check if the corresponding response exists + responseFileName := fmt.Sprintf("%s.response", req.RequestID) + responseBlobPath := path.Join(req.SyftURL.ToLocalPath(), responseFileName) + + timeout := req.Timeout + if timeout <= 0 { + timeout = s.cfg.DefaultTimeoutMs + } + + object, err := s.pollForObject(ctx, responseBlobPath, timeout) + if err != nil { + return nil, err + } + + bodyBytes, err := io.ReadAll(object) + if err != nil { + return nil, fmt.Errorf("failed to read object: %w", err) + } + + responseBody, err := unmarshalResponse(bodyBytes, req.AsRaw) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal response: %w", err) + } + + // Clean up in background + go s.cleanReqResponse( + req.SyftURL.Datasite, + req.SyftURL.AppName, + req.SyftURL.Endpoint, + req.RequestID, + ) + + return &PollResult{ + Status: http.StatusOK, + RequestID: req.RequestID, + Response: responseBody, + }, nil +} + +// pollForObject polls for an object in blob storage +func (s *SendService) pollForObject(ctx context.Context, blobPath string, timeout int) (io.ReadCloser, error) { + startTime := time.Now() + maxTimeout := time.Duration(timeout) * time.Millisecond + + for { + if time.Since(startTime) > maxTimeout { + return nil, ErrPollTimeout + } + + object, err := s.store.GetMsg(ctx, blobPath) + if err != nil { + slog.Error("Failed to get object from backend", "error", err, "blobPath", blobPath) + continue + } + if object != nil { + return object, nil + } + + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-time.After(time.Duration(s.cfg.PollIntervalMs) * time.Millisecond): + continue + } + } +} + +// cleanReqResponse cleans up request and response files +func (s *SendService) cleanReqResponse(sender, appName, appEp, requestID string) { + requestPath := path.Join(sender, "app_data", appName, "rpc", appEp, fmt.Sprintf("%s.request", requestID)) + responsePath := path.Join(sender, "app_data", appName, "rpc", appEp, fmt.Sprintf("%s.response", requestID)) + + if err := s.store.DeleteMsg(context.Background(), requestPath); err != nil { + slog.Error("failed to delete request object", "error", err, "path", requestPath) + } + + if err := s.store.DeleteMsg(context.Background(), responsePath); err != nil { + slog.Error("failed to delete response object", "error", err, "path", responsePath) + } +} + +// constructPollURL constructs the poll URL for a request +func (s *SendService) constructPollURL(requestID string, syftURL utils.SyftBoxURL, from string, asRaw bool) string { + return fmt.Sprintf( + PollURL, + requestID, + syftURL.BaseURL(), + from, + asRaw, + ) +} + +// unmarshalResponse handles the unmarshaling of a response from blob storage +// It expects the response to have a base64 encoded body field that contains JSON +func unmarshalResponse(bodyBytes []byte, asRaw bool) (map[string]interface{}, error) { + // If the request is raw, return the body as bytes + if asRaw { + var bodyJson map[string]interface{} + err := json.Unmarshal(bodyBytes, &bodyJson) + if err != nil { + return nil, fmt.Errorf("failed to marshal response: %w", err) + } + return map[string]interface{}{"message": bodyJson}, nil + } + + // Otherwise, unmarshal it as a SyftRPCMessage + var rpcMsg syftmsg.SyftRPCMessage + err := json.Unmarshal(bodyBytes, &rpcMsg) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal response: %w", err) + } + // decode the body if it is base64 encoded + // return the SyftRPCMessage as a different json representation + return map[string]interface{}{"message": rpcMsg.ToJsonMap()}, nil +} + +// GetConfig returns the service configuration +func (s *SendService) GetConfig() *Config { + return s.cfg +} diff --git a/internal/server/handlers/send/send_service_test.go b/internal/server/handlers/send/send_service_test.go new file mode 100644 index 00000000..b2cf147b --- /dev/null +++ b/internal/server/handlers/send/send_service_test.go @@ -0,0 +1,428 @@ +package send + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "io" + "net/http" + "reflect" + "strings" + "sync" + "testing" + "time" + + "github.com/google/uuid" + "github.com/openmined/syftbox/internal/syftmsg" + "github.com/openmined/syftbox/internal/utils" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +// MockMessageDispatcher implements MessageDispatcher for testing +type MockMessageDispatcher struct { + mock.Mock +} + +func (m *MockMessageDispatcher) Dispatch(datasite string, msg *syftmsg.Message) bool { + args := m.Called(datasite, msg) + return args.Bool(0) +} + +// MockRPCMsgStore implements RPCMsgStore for testing +type MockRPCMsgStore struct { + mock.Mock +} + +func (m *MockRPCMsgStore) StoreMsg(ctx context.Context, path string, msg syftmsg.SyftRPCMessage) error { + args := m.Called(ctx, path, msg) + return args.Error(0) +} + +func (m *MockRPCMsgStore) GetMsg(ctx context.Context, path string) (io.ReadCloser, error) { + args := m.Called(ctx, path) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(io.ReadCloser), args.Error(1) +} + +func (m *MockRPCMsgStore) DeleteMsg(ctx context.Context, path string) error { + args := m.Called(ctx, path) + return args.Error(0) +} + +func TestNewSendService(t *testing.T) { + // Test with custom config + dispatcher := &MockMessageDispatcher{} + store := &MockRPCMsgStore{} + cfg := &Config{ + DefaultTimeoutMs: 2000, + MaxTimeoutMs: 20000, + MaxBodySize: 8 << 20, + PollIntervalMs: 1000, + RequestChkTimeoutMs: 400, + } + + service := NewSendService(dispatcher, store, cfg) + assert.NotNil(t, service) + assert.Equal(t, cfg, service.cfg) + + // Test with nil config (should use defaults) + service = NewSendService(dispatcher, store, nil) + assert.NotNil(t, service) + assert.NotNil(t, service.cfg) + assert.Equal(t, 1000, service.cfg.DefaultTimeoutMs) + assert.Equal(t, 10000, service.cfg.MaxTimeoutMs) + assert.Equal(t, int64(4<<20), service.cfg.MaxBodySize) + assert.Equal(t, 500, service.cfg.PollIntervalMs) + assert.Equal(t, 200, service.cfg.RequestChkTimeoutMs) +} + +func TestSendService_SendMessage_Online(t *testing.T) { + dispatcher := &MockMessageDispatcher{} + store := &MockRPCMsgStore{} + service := NewSendService(dispatcher, store, nil) + + // Create test data + from := "test-user" + syftURL, err := utils.FromSyftURL("syft://test@datasite.com/app_data/testapp/rpc/endpoint") + assert.NoError(t, err) + method := "POST" + body := []byte(`{"key": "value"}`) + headers := map[string]string{ + "Content-Type": "application/json", + } + + // Create request + req := &MessageRequest{ + From: from, + SyftURL: *syftURL, + Method: method, + Headers: headers, + } + + // Create expected message + msg := syftmsg.NewHttpMsg( + from, + *syftURL, + method, + body, + headers, + syftmsg.HttpMsgTypeRequest, + ) + + httpMsg := msg.Data.(*syftmsg.HttpMsg) + + // Set up mock expectations + dispatcher.On("Dispatch", syftURL.Datasite, mock.MatchedBy(func(msg *syftmsg.Message) bool { + httpMsg, ok := msg.Data.(*syftmsg.HttpMsg) + if !ok { + return false + } + return httpMsg.From == from && + httpMsg.Method == method && + httpMsg.Type == syftmsg.HttpMsgTypeRequest && + bytes.Equal(httpMsg.Body, body) && + reflect.DeepEqual(httpMsg.Headers, headers) + })).Return(true) + + // Create response message + responseMsg := &syftmsg.SyftRPCMessage{ + ID: uuid.MustParse(httpMsg.Id), + Sender: "test-datasite", + URL: *syftURL, + Body: []byte(`{"response": "success"}`), + Headers: headers, + Created: time.Now().UTC(), + Expires: time.Now().UTC().Add(24 * time.Hour), + Method: syftmsg.MethodPOST, + StatusCode: syftmsg.StatusOK, + } + + // Mock GetMsg to return the response + responseBytes, err := json.Marshal(responseMsg) + assert.NoError(t, err) + store.On("GetMsg", mock.Anything, mock.Anything).Return(io.NopCloser(bytes.NewReader(responseBytes)), nil) + + // Set up expectations for DeleteMsg calls in cleanReqResponse + wg := &sync.WaitGroup{} + wg.Add(2) + + store.On("DeleteMsg", mock.Anything, mock.MatchedBy(func(path string) bool { + return strings.HasSuffix(path, ".request") + })). + Run(func(args mock.Arguments) { + wg.Done() + }). + Return(nil) + + store.On("DeleteMsg", mock.Anything, mock.MatchedBy(func(path string) bool { + return strings.HasSuffix(path, ".response") + })). + Run(func(args mock.Arguments) { + wg.Done() + }). + Return(nil) + + // Call SendMessage + result, err := service.SendMessage(context.Background(), req, body) + assert.NoError(t, err) + assert.NotNil(t, result) + assert.Equal(t, http.StatusOK, result.Status) + assert.NotEmpty(t, result.RequestID) + assert.NotNil(t, result.Response) + + // Wait for cleanup goroutine to complete + // Use a reasonable timeout to prevent test hanging + done := make(chan struct{}) + go func() { + wg.Wait() + close(done) + }() + + select { + case <-done: + // Cleanup completed successfully + case <-time.After(5 * time.Second): + t.Fatal("Cleanup goroutine timed out") + } + + // Verify mock expectations + dispatcher.AssertExpectations(t) + store.AssertExpectations(t) +} + +func TestSendService_SendMessage_Offline(t *testing.T) { + dispatcher := &MockMessageDispatcher{} + store := &MockRPCMsgStore{} + service := NewSendService(dispatcher, store, nil) + + // Create test data + from := "test-user" + syftURL, err := utils.FromSyftURL("syft://test@datasite.com/app_data/testapp/rpc/endpoint") + assert.NoError(t, err) + + method := "POST" + body := []byte(`{"key": "value"}`) + headers := map[string]string{ + "Content-Type": "application/json", + } + + // Create request + req := &MessageRequest{ + From: from, + SyftURL: *syftURL, + Method: method, + Headers: headers, + } + + // Set up mock expectations + dispatcher.On("Dispatch", syftURL.Datasite, mock.MatchedBy(func(msg *syftmsg.Message) bool { + httpMsg, ok := msg.Data.(*syftmsg.HttpMsg) + if !ok { + return false + } + return httpMsg.From == from && + httpMsg.Method == method && + httpMsg.Type == syftmsg.HttpMsgTypeRequest && + bytes.Equal(httpMsg.Body, body) && + reflect.DeepEqual(httpMsg.Headers, headers) + })).Return(false) + store.On("StoreMsg", mock.Anything, mock.Anything, mock.Anything).Return(nil) + + // Call SendMessage + result, err := service.SendMessage(context.Background(), req, body) + assert.NoError(t, err) + assert.NotNil(t, result) + assert.Equal(t, http.StatusAccepted, result.Status) + assert.NotEmpty(t, result.RequestID) + assert.NotEmpty(t, result.PollURL) + + // Verify mock expectations + dispatcher.AssertExpectations(t) + store.AssertExpectations(t) +} + +func TestSendService_PollForResponse(t *testing.T) { + dispatcher := &MockMessageDispatcher{} + store := &MockRPCMsgStore{} + service := NewSendService(dispatcher, store, nil) + + // Create test data + requestID := uuid.New().String() + from := "test-user" + syftURL, err := utils.FromSyftURL("syft://test@datasite.com/app_data/testapp/rpc/endpoint") + assert.NoError(t, err) + + // Create request + req := &PollObjectRequest{ + RequestID: requestID, + From: from, + SyftURL: *syftURL, + } + + // Create response message + responseMsg := &syftmsg.SyftRPCMessage{ + ID: uuid.New(), + Sender: "test-datasite", + URL: *syftURL, + Body: []byte(`{"response": "success"}`), + Headers: map[string]string{"Content-Type": "application/json"}, + Created: time.Now().UTC(), + Expires: time.Now().UTC().Add(24 * time.Hour), + Method: syftmsg.MethodPOST, + StatusCode: syftmsg.StatusOK, + } + + // Mock GetMsg to return both request and response + requestBytes, err := json.Marshal(responseMsg) + assert.NoError(t, err) + responseBytes, err := json.Marshal(responseMsg) + assert.NoError(t, err) + + store.On("GetMsg", mock.Anything, mock.MatchedBy(func(path string) bool { + return path == syftURL.ToLocalPath()+"/"+requestID+".request" + })).Return(io.NopCloser(bytes.NewReader(requestBytes)), nil) + store.On("GetMsg", mock.Anything, mock.MatchedBy(func(path string) bool { + return path == syftURL.ToLocalPath()+"/"+requestID+".response" + })).Return(io.NopCloser(bytes.NewReader(responseBytes)), nil) + + // Set up expectations for DeleteMsg calls in cleanReqResponse + wg := &sync.WaitGroup{} + wg.Add(2) + + store.On("DeleteMsg", mock.Anything, mock.MatchedBy(func(path string) bool { + return strings.HasSuffix(path, ".request") + })). + Run(func(args mock.Arguments) { + wg.Done() + }). + Return(nil) + + store.On("DeleteMsg", mock.Anything, mock.MatchedBy(func(path string) bool { + return strings.HasSuffix(path, ".response") + })). + Run(func(args mock.Arguments) { + wg.Done() + }). + Return(nil) + + // Call PollForResponse + result, err := service.PollForResponse(context.Background(), req) + assert.NoError(t, err) + assert.NotNil(t, result) + assert.Equal(t, http.StatusOK, result.Status) + assert.Equal(t, requestID, result.RequestID) + assert.NotNil(t, result.Response) + + // Wait for cleanup goroutine to complete + // Use a reasonable timeout to prevent test hanging + done := make(chan struct{}) + go func() { + wg.Wait() + close(done) + }() + + select { + case <-done: + // Cleanup completed successfully + case <-time.After(5 * time.Second): + t.Fatal("Cleanup goroutine timed out") + } + + // Verify mock expectations + store.AssertExpectations(t) +} + +func TestSendService_PollForResponse_NoRequest(t *testing.T) { + dispatcher := &MockMessageDispatcher{} + store := &MockRPCMsgStore{} + service := NewSendService(dispatcher, store, nil) + + // Create test data + requestID := uuid.New().String() + from := "test-user" + syftURL, err := utils.FromSyftURL("syft://test@datasite.com/app_data/testapp/rpc/endpoint") + assert.NoError(t, err) + + // Create request + req := &PollObjectRequest{ + RequestID: requestID, + From: from, + SyftURL: *syftURL, + } + + // Mock GetMsg to return error for request + store.On("GetMsg", mock.Anything, mock.Anything).Return(nil, errors.New("not found")) + + // Call PollForResponse + result, err := service.PollForResponse(context.Background(), req) + assert.Error(t, err) + assert.Equal(t, ErrNoRequest, err) + assert.Nil(t, result) + + // Verify mock expectations + store.AssertExpectations(t) +} + +func TestSendService_PollForResponse_Timeout(t *testing.T) { + dispatcher := &MockMessageDispatcher{} + store := &MockRPCMsgStore{} + cfg := &Config{ + DefaultTimeoutMs: 100, + MaxTimeoutMs: 1000, + MaxBodySize: 4 << 20, + PollIntervalMs: 50, + RequestChkTimeoutMs: 50, + } + service := NewSendService(dispatcher, store, cfg) + + // Create test data + requestID := uuid.New().String() + from := "test-user" + syftURL, err := utils.FromSyftURL("syft://test@datasite.com/app_data/testapp/rpc/endpoint") + assert.NoError(t, err) + + // Create request + req := &PollObjectRequest{ + RequestID: requestID, + From: from, + SyftURL: *syftURL, + Timeout: 100, + } + + // Create response message for request check + responseMsg := &syftmsg.SyftRPCMessage{ + ID: uuid.New(), + Sender: "test-datasite", + URL: *syftURL, + Body: []byte(`{"response": "success"}`), + Headers: map[string]string{"Content-Type": "application/json"}, + Created: time.Now().UTC(), + Expires: time.Now().UTC().Add(24 * time.Hour), + Method: syftmsg.MethodPOST, + StatusCode: syftmsg.StatusOK, + } + + // Mock GetMsg to return request but timeout for response + requestBytes, err := json.Marshal(responseMsg) + assert.NoError(t, err) + + store.On("GetMsg", mock.Anything, mock.MatchedBy(func(path string) bool { + return path == syftURL.ToLocalPath()+"/"+requestID+".request" + })).Return(io.NopCloser(bytes.NewReader(requestBytes)), nil) + store.On("GetMsg", mock.Anything, mock.MatchedBy(func(path string) bool { + return path == syftURL.ToLocalPath()+"/"+requestID+".response" + })).Return(nil, errors.New("not found")) + + // Call PollForResponse + result, err := service.PollForResponse(context.Background(), req) + assert.Error(t, err) + assert.Equal(t, ErrPollTimeout, err) + assert.Nil(t, result) + + // Verify mock expectations + store.AssertExpectations(t) +} diff --git a/internal/server/handlers/ws/ws_hub.go b/internal/server/handlers/ws/ws_hub.go index 717a1be5..a51a1d44 100644 --- a/internal/server/handlers/ws/ws_hub.go +++ b/internal/server/handlers/ws/ws_hub.go @@ -9,7 +9,9 @@ import ( "github.com/coder/websocket" "github.com/gin-gonic/gin" + "github.com/openmined/syftbox/internal/server/handlers/api" "github.com/openmined/syftbox/internal/syftmsg" + "github.com/openmined/syftbox/internal/version" ) const ( @@ -90,31 +92,27 @@ func (h *WebsocketHub) Shutdown(ctx context.Context) { // WebsocketHandler is the handler for the websocket connection // it upgrades the http connection to a websocket and registers the client with the hub func (h *WebsocketHub) WebsocketHandler(ctx *gin.Context) { - if ctx.GetString("user") == "" { - ctx.Status(http.StatusUnauthorized) - slog.Warn("wshub unauthorized", "ip", ctx.ClientIP(), "headers", ctx.Request.Header, "path", ctx.Request.URL) + user := ctx.GetString("user") + if user == "" { + api.AbortWithError(ctx, http.StatusUnauthorized, api.CodeInvalidRequest, fmt.Errorf("user missing")) return } // Upgrade HTTP connection to WebSocket conn, err := websocket.Accept(ctx.Writer, ctx.Request, nil) if err != nil { - e := fmt.Errorf("websocket accept failed: %w", err) - ctx.Error(e) - ctx.PureJSON(http.StatusBadRequest, gin.H{ - "error": e.Error(), - }) + api.AbortWithError(ctx, http.StatusBadRequest, api.CodeInvalidRequest, fmt.Errorf("websocket accept failed: %w", err)) return } conn.SetReadLimit(maxMessageSize) client := NewWebsocketClient(conn, &ClientInfo{ - User: ctx.GetString("user"), + User: user, IPAddr: ctx.ClientIP(), Headers: ctx.Request.Header.Clone(), }) - client.MsgTx <- syftmsg.NewSystemMessage("0.5.0", "ok") + client.MsgTx <- syftmsg.NewSystemMessage(version.Version, "ok") h.register <- client } @@ -137,6 +135,7 @@ func (h *WebsocketHub) SendMessageUser(user string, msg *syftmsg.Message) bool { for _, client := range h.clients { if client.Info.User == user { + slog.Debug("sending message to client", "connId", client.ConnID, "user", user) select { case client.MsgTx <- msg: sent = true diff --git a/internal/server/middlewares/jwtauth.go b/internal/server/middlewares/jwtauth.go index ecf491ac..9f006c2c 100644 --- a/internal/server/middlewares/jwtauth.go +++ b/internal/server/middlewares/jwtauth.go @@ -8,6 +8,7 @@ import ( "github.com/gin-gonic/gin" "github.com/openmined/syftbox/internal/server/auth" // Import your auth package + "github.com/openmined/syftbox/internal/server/handlers/api" "github.com/openmined/syftbox/internal/utils" // Import types for error constants ) @@ -19,18 +20,22 @@ const ( // JWTAuth creates a Gin middleware function that validates access tokens. // It requires the AuthService to access token validation logic and configuration. -func JWTAuth(authService *auth.AuthService) gin.HandlerFunc { +func JWTAuth(authService *auth.AuthService, allowGuest bool) gin.HandlerFunc { if !authService.IsEnabled() { slog.Info("auth middleware disabled") return func(ctx *gin.Context) { // expect user to be an email address user := ctx.Query("user") + + // if user is not set, check if the request has a x-syft-from query parameter + if user == "" { + user = ctx.Query("x-syft-from") + } + + // check if the user is a valid email address if !utils.IsValidEmail(user) { - ctx.Error(fmt.Errorf("invalid email")) - ctx.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{ - "error": "invalid email", - }) + api.AbortWithError(ctx, http.StatusUnauthorized, api.CodeInvalidRequest, fmt.Errorf("invalid email")) return } ctx.Set("user", user) @@ -38,44 +43,46 @@ func JWTAuth(authService *auth.AuthService) gin.HandlerFunc { } } - slog.Info("auth middleware enabled") - return func(ctx *gin.Context) { + // Check for guest access first if allowed + slog.Debug("Checking for guest access", "allowGuest", allowGuest) + if allowGuest { + user := ctx.Query("user") + if user == "" { + user = ctx.Query("x-syft-from") + } + slog.Debug("Attempting to access with user", "user", user) + if user == "guest@syft.org" { + ctx.Set("user", user) + ctx.Next() + return + } + } + + // Proceed with normal JWT authentication authHeaderValue := ctx.GetHeader(authHeader) if authHeaderValue == "" { - ctx.Error(fmt.Errorf("authorization header required")) - ctx.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{ - "error": "authorization header required", - }) + api.AbortWithError(ctx, http.StatusUnauthorized, api.CodeAuthInvalidCredentials, fmt.Errorf("authorization header required")) return } // Check if the header starts with "Bearer " if !strings.HasPrefix(authHeaderValue, bearerPrefix) { - ctx.Error(fmt.Errorf("bearer token required")) - ctx.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{ - "error": "bearer token required", - }) + api.AbortWithError(ctx, http.StatusUnauthorized, api.CodeAuthInvalidCredentials, fmt.Errorf("bearer token required")) return } // Extract the token string tokenString := strings.TrimPrefix(authHeaderValue, bearerPrefix) if tokenString == "" { - ctx.Error(fmt.Errorf("token missing")) - ctx.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{ - "error": "token missing", - }) + api.AbortWithError(ctx, http.StatusUnauthorized, api.CodeAuthInvalidCredentials, fmt.Errorf("token missing")) return } // Validate the token using the method added to AuthService claims, err := authService.ValidateAccessToken(ctx, tokenString) if err != nil { - ctx.Error(err) - ctx.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{ - "error": err.Error(), - }) + api.AbortWithError(ctx, http.StatusUnauthorized, api.CodeAuthInvalidCredentials, err) return } diff --git a/internal/server/middlewares/ratelimiter.go b/internal/server/middlewares/ratelimiter.go index fe659a8b..a7b1aff3 100644 --- a/internal/server/middlewares/ratelimiter.go +++ b/internal/server/middlewares/ratelimiter.go @@ -4,6 +4,7 @@ import ( "net/http" "github.com/gin-gonic/gin" + "github.com/openmined/syftbox/internal/server/handlers/api" "github.com/ulule/limiter/v3" "github.com/ulule/limiter/v3/drivers/store/memory" @@ -21,13 +22,15 @@ func RateLimiter(formattedRate string) gin.HandlerFunc { return mgin.NewMiddleware( limiter, mgin.WithLimitReachedHandler(func(c *gin.Context) { - c.PureJSON(http.StatusTooManyRequests, gin.H{ - "error": "rate limit exceeded", + c.PureJSON(http.StatusTooManyRequests, api.SyftAPIError{ + Code: api.CodeRateLimited, + Message: "rate limit exceeded", }) }), mgin.WithErrorHandler(func(c *gin.Context, err error) { - c.PureJSON(http.StatusInternalServerError, gin.H{ - "error": err.Error(), + c.PureJSON(http.StatusInternalServerError, api.SyftAPIError{ + Code: api.CodeInternalError, + Message: err.Error(), }) }), ) diff --git a/internal/server/routes.go b/internal/server/routes.go index 9236103f..fe0d4ca9 100644 --- a/internal/server/routes.go +++ b/internal/server/routes.go @@ -1,22 +1,30 @@ package server import ( + "embed" "fmt" + "html/template" "net/http" "os" "github.com/gin-gonic/gin" + "github.com/openmined/syftbox/internal/server/handlers/acl" + "github.com/openmined/syftbox/internal/server/handlers/api" "github.com/openmined/syftbox/internal/server/handlers/auth" "github.com/openmined/syftbox/internal/server/handlers/blob" "github.com/openmined/syftbox/internal/server/handlers/datasite" "github.com/openmined/syftbox/internal/server/handlers/explorer" "github.com/openmined/syftbox/internal/server/handlers/install" + "github.com/openmined/syftbox/internal/server/handlers/send" "github.com/openmined/syftbox/internal/server/handlers/ws" "github.com/openmined/syftbox/internal/server/middlewares" "github.com/openmined/syftbox/internal/version" ) +//go:embed handlers/send/*.html +var templateFS embed.FS + func SetupRoutes(svc *Services, hub *ws.WebsocketHub, httpsEnabled bool) http.Handler { r := gin.New() @@ -30,12 +38,18 @@ func SetupRoutes(svc *Services, hub *ws.WebsocketHub, httpsEnabled bool) http.Ha r.Use(middlewares.HSTS()) } + // Load HTML templates from embedded filesystem + tmpl := template.Must(template.ParseFS(templateFS, "handlers/send/*.html")) + r.SetHTMLTemplate(tmpl) + // --------------------------- handlers --------------------------- - blobH := blob.New(svc.Blob) + blobH := blob.New(svc.Blob, svc.ACL) dsH := datasite.New(svc.Datasite) explorerH := explorer.New(svc.Blob, svc.ACL) authH := auth.New(svc.Auth) + aclH := acl.NewACLHandler(svc.ACL) + sendH := send.New(send.NewWSMsgDispatcher(hub), send.NewBlobMsgStore(svc.Blob)) // --------------------------- routes --------------------------- @@ -59,12 +73,15 @@ func SetupRoutes(svc *Services, hub *ws.WebsocketHub, httpsEnabled bool) http.Ha } v1 := r.Group("/api/v1") - v1.Use(middlewares.JWTAuth(svc.Auth)) + + // enable auth middleware with no guest access + v1.Use(middlewares.JWTAuth(svc.Auth, false)) // v1.Use(middlewares.RateLimiter("100-S")) // todo { // blob v1.GET("/blob/list", blobH.ListObjects) v1.PUT("/blob/upload", blobH.Upload) + v1.PUT("/blob/upload/acl", blobH.UploadACL) v1.POST("/blob/upload/presigned", blobH.UploadPresigned) v1.POST("/blob/upload/multipart", blobH.UploadMultipart) v1.POST("/blob/upload/complete", blobH.UploadComplete) @@ -74,19 +91,33 @@ func SetupRoutes(svc *Services, hub *ws.WebsocketHub, httpsEnabled bool) http.Ha // datasite v1.GET("/datasite/view", dsH.GetView) + v1.PUT("/acl", blobH.UploadACL) + v1.GET("/acl/check", aclH.CheckAccess) + // websocket events v1.GET("/events", hub.WebsocketHandler) + + } + + // rpc group with guest access + sendG := r.Group("/api/v1/send") + sendG.Use(middlewares.JWTAuth(svc.Auth, true)) + { + sendG.Any("/msg", sendH.SendMsg) + sendG.GET("/poll", sendH.PollForResponse) } r.NoRoute(func(c *gin.Context) { - c.JSON(http.StatusNotFound, gin.H{ - "error": "not found", + c.JSON(http.StatusNotFound, api.SyftAPIError{ + Code: api.CodeInvalidRequest, + Message: "not found", }) }) r.NoMethod(func(c *gin.Context) { - c.JSON(http.StatusMethodNotAllowed, gin.H{ - "error": "method not allowed", + c.JSON(http.StatusMethodNotAllowed, api.SyftAPIError{ + Code: api.CodeInvalidRequest, + Message: "method not allowed", }) }) diff --git a/internal/server/server.go b/internal/server/server.go index 8cccfede..a5f5325c 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -17,7 +17,6 @@ import ( "github.com/openmined/syftbox/internal/db" "github.com/openmined/syftbox/internal/server/acl" "github.com/openmined/syftbox/internal/server/blob" - "github.com/openmined/syftbox/internal/server/datasite" "github.com/openmined/syftbox/internal/server/handlers/ws" "github.com/openmined/syftbox/internal/syftmsg" "golang.org/x/sync/errgroup" @@ -263,7 +262,7 @@ func (s *Server) checkPermission(user string, path string, access acl.AccessLeve return nil } return s.svc.ACL.CanAccess( - &acl.User{ID: user, IsOwner: datasite.IsOwner(path, user)}, + &acl.User{ID: user}, &acl.File{Path: path}, access, ) diff --git a/internal/server/services.go b/internal/server/services.go index 007a6b5d..025c2c33 100644 --- a/internal/server/services.go +++ b/internal/server/services.go @@ -14,7 +14,7 @@ import ( type Services struct { Blob *blob.BlobService - ACL *acl.AclService + ACL *acl.ACLService Datasite *datasite.DatasiteService Auth *auth.AuthService Email *email.EmailService @@ -28,7 +28,7 @@ func NewServices(config *Config, db *sqlx.DB) (*Services, error) { return nil, err } - aclSvc := acl.NewAclService() + aclSvc := acl.NewACLService() datasiteSvc := datasite.NewDatasiteService(blobSvc, aclSvc) diff --git a/internal/syftmsg/http_msg.go b/internal/syftmsg/http_msg.go new file mode 100644 index 00000000..d4b061ee --- /dev/null +++ b/internal/syftmsg/http_msg.go @@ -0,0 +1,46 @@ +package syftmsg + +import ( + "github.com/google/uuid" + "github.com/openmined/syftbox/internal/utils" +) + +type HttpMsgType string + +const ( + HttpMsgTypeRequest HttpMsgType = "request" + HttpMsgTypeResponse HttpMsgType = "response" +) + +type HttpMsg struct { + From string `json:"from"` + SyftURL utils.SyftBoxURL `json:"syft_url"` + Method string `json:"method"` + Headers map[string]string `json:"headers,omitempty"` + Body []byte `json:"body,omitempty"` + Id string `json:"id,omitempty"` + Type HttpMsgType `json:"type,omitempty"` +} + +func NewHttpMsg( + from string, + syftURL utils.SyftBoxURL, + method string, + body []byte, + headers map[string]string, + msgType HttpMsgType, +) *Message { + return &Message{ + Id: generateID(), + Type: MsgHttp, + Data: &HttpMsg{ + From: from, + SyftURL: syftURL, + Method: method, + Body: body, + Headers: headers, + Id: uuid.New().String(), + Type: msgType, + }, + } +} diff --git a/internal/syftmsg/msg.go b/internal/syftmsg/msg.go index 591fb5c5..0184778a 100644 --- a/internal/syftmsg/msg.go +++ b/internal/syftmsg/msg.go @@ -59,6 +59,12 @@ func (m *Message) UnmarshalJSON(data []byte) error { return err } m.Data = fileDelete + case MsgHttp: + var httpMsg HttpMsg + if err := json.Unmarshal(temp.Data, &httpMsg); err != nil { + return err + } + m.Data = &httpMsg default: return fmt.Errorf("unknown message type: %d", m.Type) } diff --git a/internal/syftmsg/msg_type.go b/internal/syftmsg/msg_type.go index 114939da..e5f2aa80 100644 --- a/internal/syftmsg/msg_type.go +++ b/internal/syftmsg/msg_type.go @@ -11,6 +11,7 @@ const ( MsgFileDelete MsgAck MsgNack + MsgHttp ) func (t MessageType) String() string { @@ -27,6 +28,8 @@ func (t MessageType) String() string { return "ACK" case MsgNack: return "NACK" + case MsgHttp: + return "HTTP" default: return fmt.Sprintf("???(%d)", t) } diff --git a/internal/syftmsg/rpc_msg.go b/internal/syftmsg/rpc_msg.go new file mode 100644 index 00000000..5ad219c2 --- /dev/null +++ b/internal/syftmsg/rpc_msg.go @@ -0,0 +1,263 @@ +package syftmsg + +import ( + "encoding/json" + "fmt" + "time" + + "encoding/base64" + + "github.com/google/uuid" + "github.com/openmined/syftbox/internal/utils" +) + +// ValidationError represents a validation error +type ValidationError struct { + Field string + Message string +} + +func (e *ValidationError) Error() string { + return fmt.Sprintf("validation error in field %s: %s", e.Field, e.Message) +} + +// SyftMethod represents the HTTP method in the Syft protocol +type SyftMethod string + +// IsValid checks if the method is valid +func (m SyftMethod) IsValid() bool { + switch m { + case MethodGET, MethodPOST, MethodPUT, MethodDELETE: + return true + default: + return false + } +} + +// Validate validates the method +func (m SyftMethod) Validate() error { + if m == "" { + return nil + } + if !m.IsValid() { + return &ValidationError{ + Field: "method", + Message: fmt.Sprintf("invalid method: %s", m), + } + } + return nil +} + +// SyftStatus represents the status code in the Syft protocol +type SyftStatus int + +// IsValid checks if the status code is valid +func (s SyftStatus) IsValid() bool { + return s >= 100 && s <= 599 +} + +func (s SyftStatus) isDefined() bool { + return s != 0 +} + +// Validate validates the status code +func (s SyftStatus) Validate() error { + if !s.isDefined() { + return nil + } + if !s.IsValid() { + return &ValidationError{ + Field: "status_code", + Message: fmt.Sprintf("invalid status code: %d", s), + } + } + return nil +} + +const ( + // DefaultMessageExpiry is the default time in seconds before a message expires + // 1 day + DefaultMessageExpiry = 24 * 60 * 60 * time.Second + + // HTTP Methods + MethodGET SyftMethod = "GET" + MethodPOST SyftMethod = "POST" + MethodPUT SyftMethod = "PUT" + MethodDELETE SyftMethod = "DELETE" + + // Status codes + StatusOK SyftStatus = 200 +) + +// SyftMessage represents a base message for Syft protocol communication +type SyftRPCMessage struct { + // ID is the unique identifier of the message + ID uuid.UUID `json:"id"` + + // Sender is the sender of the message + Sender string `json:"sender"` + + // URL is the URL of the message + URL utils.SyftBoxURL `json:"url"` + + // Body is the body of the message in bytes + Body []byte `json:"body,omitempty"` + + // Headers contains additional headers for the message + Headers map[string]string `json:"headers"` + + // Created is the timestamp when the message was created + Created time.Time `json:"created"` + + // Expires is the timestamp when the message expires + Expires time.Time `json:"expires"` + + Method SyftMethod `json:"method,omitempty"` + + StatusCode SyftStatus `json:"status_code,omitempty"` +} + +// NewSyftMessage creates a new SyftMessage with default values +func NewSyftRPCMessage(httpMsg HttpMsg) (*SyftRPCMessage, error) { + + // Timezone is UTC by default for SyftRPC messages + now := time.Now().UTC() + + headers := httpMsg.Headers + if headers == nil { + headers = make(map[string]string) + } + + msg := &SyftRPCMessage{ + ID: uuid.MustParse(httpMsg.Id), + Sender: httpMsg.From, + URL: httpMsg.SyftURL, + Body: httpMsg.Body, + Headers: headers, + Created: now, + Expires: now.Add(time.Duration(DefaultMessageExpiry)), + Method: SyftMethod(httpMsg.Method), + } + + if err := msg.Validate(); err != nil { + return nil, err + } + + return msg, nil +} + +// MarshalJSON implements custom JSON marshaling to handle bytes as base64 +func (m *SyftRPCMessage) MarshalJSON() ([]byte, error) { + type Alias SyftRPCMessage + return json.Marshal(&struct { + *Alias + URL string `json:"url"` + Body string `json:"body,omitempty"` + }{ + Alias: (*Alias)(m), + URL: m.URL.String(), + Body: base64.URLEncoding.EncodeToString(m.Body), + }) +} + +// UnmarshalJSON implements custom JSON unmarshaling +func (m *SyftRPCMessage) UnmarshalJSON(data []byte) error { + type Alias struct { + ID uuid.UUID `json:"id"` + Sender string `json:"sender"` + URL string `json:"url"` + Body string `json:"body,omitempty"` + Headers map[string]string `json:"headers"` + Created time.Time `json:"created"` + Expires time.Time `json:"expires"` + Method SyftMethod `json:"method,omitempty"` + StatusCode SyftStatus `json:"status_code,omitempty"` + } + + var aux Alias + if err := json.Unmarshal(data, &aux); err != nil { + return err + } + + // Parse URL + url, err := utils.FromSyftURL(aux.URL) + if err != nil { + return fmt.Errorf("failed to parse URL: %w", err) + } + + // Set fields + m.ID = aux.ID + m.Sender = aux.Sender + m.URL = *url + m.Headers = aux.Headers + m.Created = aux.Created + m.Expires = aux.Expires + m.Method = aux.Method + m.StatusCode = aux.StatusCode + + // Handle body + if aux.Body != "" { + if body, err := base64.URLEncoding.DecodeString(aux.Body); err == nil { + m.Body = body + } else { + m.Body = []byte(aux.Body) + } + } + + // Validate the message + if err := m.Validate(); err != nil { + return fmt.Errorf("invalid message: %w", err) + } + + return nil +} + +// JSONString returns a properly formatted JSON string with decoded body +func (m *SyftRPCMessage) ToJsonMap() map[string]interface{} { + var bodyContent interface{} + if err := json.Unmarshal(m.Body, &bodyContent); err != nil { + bodyContent = string(m.Body) + } + + return map[string]interface{}{ + "id": m.ID, + "sender": m.Sender, + "url": m.URL.String(), + "headers": m.Headers, + "created": m.Created, + "expires": m.Expires, + "method": m.Method, + "status_code": m.StatusCode, + "body": bodyContent, + } +} + +// Validate validates the message +func (m *SyftRPCMessage) Validate() error { + if m.ID == uuid.Nil { + return &ValidationError{ + Field: "id", + Message: "id cannot be empty", + } + } + if m.Sender == "" { + return &ValidationError{ + Field: "sender", + Message: "sender cannot be empty", + } + } + if err := m.URL.Validate(); err != nil { + return err + } + + // If Method is defined, validate it + if err := m.Method.Validate(); err != nil { + return err + } + + // If StatusCode is defined, validate it + if err := m.StatusCode.Validate(); err != nil { + return err + } + return nil +} diff --git a/internal/syftsdk/sdk.go b/internal/syftsdk/sdk.go index fca8f9ba..76bdd40e 100644 --- a/internal/syftsdk/sdk.go +++ b/internal/syftsdk/sdk.go @@ -6,6 +6,7 @@ import ( "fmt" "log/slog" "os" + "strconv" "strings" "time" @@ -75,8 +76,8 @@ func (s *SyftSDK) Close() { // Authenticate sets the user authentication for API calls and events func (s *SyftSDK) Authenticate(ctx context.Context) error { - if isDevMode(s.config.BaseURL) { - slog.Warn("sdk is in DEV mode, skipping auth") + if isAuthDisabled() || isDevURL(s.config.BaseURL) { + slog.Warn("sdk auth disabled, skipping auth") return nil } @@ -168,9 +169,17 @@ func (s *SyftSDK) setAccessToken(accessToken string) error { return nil } -func isDevMode(baseURL string) bool { +func isAuthDisabled() bool { + authEnabled := os.Getenv("SYFTBOX_AUTH_ENABLED") + enabled, err := strconv.ParseBool(authEnabled) + if err != nil { + return false + } + return !enabled +} + +func isDevURL(baseURL string) bool { return strings.Contains(baseURL, "localhost") || strings.Contains(baseURL, "127.0.0.1") || - strings.Contains(baseURL, "0.0.0.0") || - os.Getenv("SYFTBOX_DEV_MODE") == "true" + strings.Contains(baseURL, "0.0.0.0") } diff --git a/internal/syftsdk/sdk_config.go b/internal/syftsdk/sdk_config.go index 9119347d..f79b53dc 100644 --- a/internal/syftsdk/sdk_config.go +++ b/internal/syftsdk/sdk_config.go @@ -5,7 +5,7 @@ import ( ) const ( - DefaultBaseURL = "https://syftboxdev.openmined.org" + DefaultBaseURL = "https://syftbox.net" ) // SyftSDKConfig is the configuration for the SyftSDK diff --git a/internal/utils/url.go b/internal/utils/url.go new file mode 100644 index 00000000..94c85958 --- /dev/null +++ b/internal/utils/url.go @@ -0,0 +1,253 @@ +package utils + +import ( + "fmt" + "log/slog" + "net/url" + "path/filepath" + "sort" + "strings" +) + +const ( + // URL scheme and components + syftScheme = "syft://" + appDataPath = "app_data" + rpcPath = "rpc" + pathSeparator = "/" +) + +// SyftBoxURL represents a parsed syft:// URL with its components +type SyftBoxURL struct { + Datasite string `json:"datasite"` + AppName string `json:"app_name"` + Endpoint string `json:"endpoint"` + QueryParams map[string]string `json:"query_params"` +} + +// ValidationError represents a validation error with field context +type ValidationError struct { + Field string + Message string +} + +func (e *ValidationError) Error() string { + return fmt.Sprintf("validation error in field '%s': %s", e.Field, e.Message) +} + +// Syft base URL +func (s *SyftBoxURL) BaseURL() string { + endpoint := strings.Trim(s.Endpoint, pathSeparator) + return fmt.Sprintf("%s%s/%s/%s/%s/%s", + syftScheme, s.Datasite, appDataPath, s.AppName, rpcPath, endpoint) +} + +// String returns the string representation of the SyftBoxURL +func (s *SyftBoxURL) String() string { + baseURL := s.BaseURL() + + // Add query parameters if they exist + if len(s.QueryParams) > 0 { + // Sort query params by key lexicographically to ensure consistent ordering + sortedKeys := make([]string, 0, len(s.QueryParams)) + for key := range s.QueryParams { + sortedKeys = append(sortedKeys, key) + } + sort.Strings(sortedKeys) + + queryParams := make([]string, 0, len(s.QueryParams)) + for _, key := range sortedKeys { + value := s.QueryParams[key] + queryParams = append(queryParams, fmt.Sprintf("%s=%s", key, url.QueryEscape(value))) + } + baseURL += "?" + strings.Join(queryParams, "&") + } + + return baseURL +} + +// ToLocalPath converts the SyftBoxURL to a local file system path +func (s *SyftBoxURL) ToLocalPath() string { + endpoint := strings.Trim(s.Endpoint, pathSeparator) + return filepath.ToSlash(filepath.Join(s.Datasite, appDataPath, s.AppName, rpcPath, endpoint)) +} + +// Validate validates the SyftBoxURL fields +func (s *SyftBoxURL) Validate() error { + if s.Datasite == "" { + return &ValidationError{ + Field: "datasite", + Message: "datasite cannot be empty", + } + } + + // Validate datasite follows email pattern + if !IsValidEmail(s.Datasite) { + return &ValidationError{ + Field: "datasite", + Message: "datasite must be a valid email address", + } + } + + if s.AppName == "" { + return &ValidationError{ + Field: "app_name", + Message: "app_name cannot be empty", + } + } + if s.Endpoint == "" { + return &ValidationError{ + Field: "endpoint", + Message: "endpoint cannot be empty", + } + } + + // Validate endpoint doesn't contain spaces or special characters + if strings.ContainsAny(s.Endpoint, " ?&=") { + return &ValidationError{ + Field: "endpoint", + Message: "endpoint cannot contain spaces or special characters (?&=)", + } + } + + return nil +} + +// UnmarshalParam implements gin.UnmarshalParam for automatic query param binding +func (s *SyftBoxURL) UnmarshalParam(param string) error { + slog.Debug("Unmarshalling syft url", "url", param) + parsed, err := FromSyftURL(param) + if err != nil { + slog.Error("Failed to parse syft url", "error", err, "url", param) + return err + } + *s = *parsed + return nil +} + +// NewSyftBoxURL creates a new SyftBoxURL with validation +func NewSyftBoxURL(datasite, appName, endpoint string) (*SyftBoxURL, error) { + syftURL := &SyftBoxURL{ + Datasite: datasite, + AppName: appName, + Endpoint: endpoint, + } + + if err := syftURL.Validate(); err != nil { + return nil, err + } + + return syftURL, nil +} + +// SetQueryParams sets the query parameters +func (s *SyftBoxURL) SetQueryParams(queryParams map[string]string) { + s.QueryParams = queryParams +} + +// parseQueryParams parses and validates query parameters from a URL +func parseQueryParams(rawQuery string) (map[string]string, error) { + if rawQuery == "" { + return nil, nil + } + + queryParams := make(map[string]string) + values, err := url.ParseQuery(rawQuery) + if err != nil { + return nil, fmt.Errorf("failed to parse query parameters: %w", err) + } + + for key, values := range values { + // Validate key doesn't contain spaces or special characters + if strings.ContainsAny(key, " ?&=") { + return nil, fmt.Errorf("query parameter key '%s' cannot contain spaces or special characters (?&=)", key) + } + // Use the first value if multiple values exist + if len(values) > 0 { + queryParams[key] = values[0] + } + } + + return queryParams, nil +} + +// FromSyftURL parses a syft URL string into a SyftBoxURL struct +func FromSyftURL(rawURL string) (*SyftBoxURL, error) { + // Parse the URL using standard library + parsedURL, err := url.Parse(rawURL) + if err != nil { + return nil, fmt.Errorf("failed to parse URL: %w", err) + } + + // Validate scheme + if parsedURL.Scheme != "syft" { + return nil, fmt.Errorf("invalid scheme: expected 'syft', got '%s'", parsedURL.Scheme) + } + + // datasite is the host of the URL + @ + username + datasite := parsedURL.Host + + if datasite == "" { + return nil, fmt.Errorf("invalid syft url: missing datasite (host)") + } + + if parsedURL.User == nil || parsedURL.User.Username() == "" { + return nil, fmt.Errorf("invalid syft url: invalid datasite name") + } + + username := parsedURL.User.Username() + datasite = username + "@" + datasite + + // Split path into components and remove empty strings + path := strings.Trim(parsedURL.Path, pathSeparator) + parts := make([]string, 0) + for _, part := range strings.Split(path, pathSeparator) { + if part != "" { + parts = append(parts, part) + } + } + + // Validate path structure + if len(parts) < 4 { + return nil, fmt.Errorf("invalid path: expected format 'app_data/app_name/rpc/endpoint'") + } + + // Validate expected structure + if parts[0] != appDataPath { + return nil, fmt.Errorf("invalid path: expected '%s' at position 1", appDataPath) + } + + // Find the index of rpc in the path + rpcIndex := -1 + for i, part := range parts { + if part == rpcPath { + rpcIndex = i + break + } + } + + if rpcIndex == -1 { + return nil, fmt.Errorf("invalid path: expected '%s' in path", rpcPath) + } + + // Extract components + appName := strings.Join(parts[1:rpcIndex], pathSeparator) + endpoint := strings.Join(parts[rpcIndex+1:], pathSeparator) + + // Create SyftBoxURL + syftURL, err := NewSyftBoxURL(datasite, appName, endpoint) + if err != nil { + return nil, fmt.Errorf("failed to create syft url from components: %w", err) + } + + // Validate query params + queryParams, err := parseQueryParams(parsedURL.RawQuery) + if err != nil { + return nil, fmt.Errorf("failed to parse query parameters: %w", err) + } + + // Set query params + syftURL.SetQueryParams(queryParams) + + return syftURL, nil +} diff --git a/internal/utils/url_test.go b/internal/utils/url_test.go new file mode 100644 index 00000000..e5121c5c --- /dev/null +++ b/internal/utils/url_test.go @@ -0,0 +1,438 @@ +package utils + +import ( + "testing" +) + +func TestFromSyftURL(t *testing.T) { + tests := []struct { + name string + url string + want *SyftBoxURL + wantErr bool + }{ + { + name: "valid basic url", + url: "syft://user@example.com/app_data/app1/rpc/endpoint1", + want: &SyftBoxURL{ + Datasite: "user@example.com", + AppName: "app1", + Endpoint: "endpoint1", + }, + wantErr: false, + }, + { + name: "valid url with query params", + url: "syft://user@example.com/app_data/app1/rpc/endpoint1?param1=value1¶m2=value2", + want: &SyftBoxURL{ + Datasite: "user@example.com", + AppName: "app1", + Endpoint: "endpoint1", + QueryParams: map[string]string{ + "param1": "value1", + "param2": "value2", + }, + }, + wantErr: false, + }, + { + name: "invalid scheme", + url: "http://user@example.com/app_data/app1/rpc/endpoint1", + want: nil, + wantErr: true, + }, + { + name: "missing app_data", + url: "syft://user@example.com/wrong/app1/rpc/endpoint1", + want: nil, + wantErr: true, + }, + { + name: "missing rpc", + url: "syft://user@example.com/app_data/app1/wrong/endpoint1", + want: nil, + wantErr: true, + }, + { + name: "empty url", + url: "", + want: nil, + wantErr: true, + }, + { + name: "malformed url", + url: "syft:///invalid", + want: nil, + wantErr: true, + }, + { + name: "invalid datasite format", + url: "syft://notanemail/app_data/app1/rpc/endpoint1", + want: nil, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := FromSyftURL(tt.url) + if (err != nil) != tt.wantErr { + t.Errorf("FromSyftURL() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !tt.wantErr { + if got.Datasite != tt.want.Datasite { + t.Errorf("FromSyftURL() Datasite = %v, want %v", got.Datasite, tt.want.Datasite) + } + if got.AppName != tt.want.AppName { + t.Errorf("FromSyftURL() AppName = %v, want %v", got.AppName, tt.want.AppName) + } + if got.Endpoint != tt.want.Endpoint { + t.Errorf("FromSyftURL() Endpoint = %v, want %v", got.Endpoint, tt.want.Endpoint) + } + if len(tt.want.QueryParams) > 0 { + if len(got.QueryParams) != len(tt.want.QueryParams) { + t.Errorf("FromSyftURL() QueryParams length = %v, want %v", len(got.QueryParams), len(tt.want.QueryParams)) + } + for k, v := range tt.want.QueryParams { + if gotVal, exists := got.QueryParams[k]; !exists || gotVal != v { + t.Errorf("FromSyftURL() QueryParams[%s] = %v, want %v", k, gotVal, v) + } + } + } + } + }) + } +} + +func TestFromSyftURL_QueryParamEncoding(t *testing.T) { + tests := []struct { + name string + url string + want *SyftBoxURL + wantErr bool + }{ + { + name: "url with encoded spaces in query param values", + url: "syft://test@example.com/app_data/app1/rpc/endpoint1?param1=value%20with%20spaces¶m2=value2", + want: &SyftBoxURL{ + Datasite: "test@example.com", + AppName: "app1", + Endpoint: "endpoint1", + QueryParams: map[string]string{ + "param1": "value with spaces", + "param2": "value2", + }, + }, + wantErr: false, + }, + { + name: "url with encoded special chars in query param values", + url: "syft://test@example.com/app_data/app1/rpc/endpoint1?param1=value%26with%26chars", + want: &SyftBoxURL{ + Datasite: "test@example.com", + AppName: "app1", + Endpoint: "endpoint1", + QueryParams: map[string]string{ + "param1": "value&with&chars", + }, + }, + wantErr: false, + }, + { + name: "url with spaces in query param keys", + url: "syft://test@example.com/app_data/app1/rpc/endpoint1?param with spaces=value1", + want: nil, + wantErr: true, + }, + { + name: "url with special chars in query param keys", + url: "syft://test@example.com/app_data/app1/rpc/endpoint1?param%26with%26chars=value1", + want: nil, + wantErr: true, + }, + { + name: "url with multiple values for same key", + url: "syft://test@example.com/app_data/app1/rpc/endpoint1?param1=value1¶m1=value2", + want: &SyftBoxURL{ + Datasite: "test@example.com", + AppName: "app1", + Endpoint: "endpoint1", + QueryParams: map[string]string{ + "param1": "value1", + }, + }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := FromSyftURL(tt.url) + if (err != nil) != tt.wantErr { + t.Errorf("FromSyftURL() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !tt.wantErr { + if got.Datasite != tt.want.Datasite { + t.Errorf("FromSyftURL() Datasite = %v, want %v", got.Datasite, tt.want.Datasite) + } + if got.AppName != tt.want.AppName { + t.Errorf("FromSyftURL() AppName = %v, want %v", got.AppName, tt.want.AppName) + } + if got.Endpoint != tt.want.Endpoint { + t.Errorf("FromSyftURL() Endpoint = %v, want %v", got.Endpoint, tt.want.Endpoint) + } + if len(tt.want.QueryParams) > 0 { + if len(got.QueryParams) != len(tt.want.QueryParams) { + t.Errorf("FromSyftURL() QueryParams length = %v, want %v", len(got.QueryParams), len(tt.want.QueryParams)) + } + for k, v := range tt.want.QueryParams { + if gotVal, exists := got.QueryParams[k]; !exists || gotVal != v { + t.Errorf("FromSyftURL() QueryParams[%s] = %v, want %v", k, gotVal, v) + } + } + } + } + }) + } +} + +func TestSyftBoxURL_String(t *testing.T) { + tests := []struct { + name string + url *SyftBoxURL + want string + }{ + { + name: "basic url", + url: &SyftBoxURL{ + Datasite: "user@example.com", + AppName: "app1", + Endpoint: "endpoint1", + }, + want: "syft://user@example.com/app_data/app1/rpc/endpoint1", + }, + { + name: "url with query params", + url: &SyftBoxURL{ + Datasite: "user@example.com", + AppName: "app1", + Endpoint: "endpoint1", + QueryParams: map[string]string{ + "param1": "value1", + "param2": "value2", + }, + }, + want: "syft://user@example.com/app_data/app1/rpc/endpoint1?param1=value1¶m2=value2", + }, + { + name: "url with spaces in query param values", + url: &SyftBoxURL{ + Datasite: "user@example.com", + AppName: "app1", + Endpoint: "endpoint1", + QueryParams: map[string]string{ + "param1": "value with spaces", + }, + }, + want: "syft://user@example.com/app_data/app1/rpc/endpoint1?param1=value+with+spaces", + }, + { + name: "url with special chars in query param values", + url: &SyftBoxURL{ + Datasite: "user@example.com", + AppName: "app1", + Endpoint: "endpoint1", + QueryParams: map[string]string{ + "param1": "value&with&chars", + }, + }, + want: "syft://user@example.com/app_data/app1/rpc/endpoint1?param1=value%26with%26chars", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.url.String(); got != tt.want { + t.Errorf("SyftBoxURL.String() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestSyftBoxURL_ToLocalPath(t *testing.T) { + tests := []struct { + name string + url *SyftBoxURL + want string + }{ + { + name: "basic path", + url: &SyftBoxURL{ + Datasite: "user@example.com", + AppName: "app1", + Endpoint: "endpoint1", + }, + want: "user@example.com/app_data/app1/rpc/endpoint1", + }, + { + name: "path with nested endpoint", + url: &SyftBoxURL{ + Datasite: "user@example.com", + AppName: "app1", + Endpoint: "endpoint1/sub/path", + }, + want: "user@example.com/app_data/app1/rpc/endpoint1/sub/path", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.url.ToLocalPath(); got != tt.want { + t.Errorf("SyftBoxURL.ToLocalPath() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestSyftBoxURL_Validate(t *testing.T) { + tests := []struct { + name string + url *SyftBoxURL + wantErr bool + }{ + { + name: "valid url", + url: &SyftBoxURL{ + Datasite: "user@example.com", + AppName: "app1", + Endpoint: "endpoint1", + }, + wantErr: false, + }, + { + name: "empty datasite", + url: &SyftBoxURL{ + Datasite: "", + AppName: "app1", + Endpoint: "endpoint1", + }, + wantErr: true, + }, + { + name: "empty app name", + url: &SyftBoxURL{ + Datasite: "user@example.com", + AppName: "", + Endpoint: "endpoint1", + }, + wantErr: true, + }, + { + name: "empty endpoint", + url: &SyftBoxURL{ + Datasite: "user@example.com", + AppName: "app1", + Endpoint: "", + }, + wantErr: true, + }, + { + name: "endpoint with spaces", + url: &SyftBoxURL{ + Datasite: "user@example.com", + AppName: "app1", + Endpoint: "endpoint with spaces", + }, + wantErr: true, + }, + { + name: "endpoint with special chars", + url: &SyftBoxURL{ + Datasite: "user@example.com", + AppName: "app1", + Endpoint: "endpoint?with=chars", + }, + wantErr: true, + }, + { + name: "invalid datasite format", + url: &SyftBoxURL{ + Datasite: "notanemail", + AppName: "app1", + Endpoint: "endpoint1", + }, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.url.Validate() + if (err != nil) != tt.wantErr { + t.Errorf("SyftBoxURL.Validate() error = %v, wantErr %v", err, tt.wantErr) + } + if tt.wantErr && err != nil { + validationErr, ok := err.(*ValidationError) + if !ok { + t.Errorf("expected ValidationError, got %T", err) + } + if validationErr.Field == "" { + t.Error("expected non-empty field in ValidationError") + } + } + }) + } +} + +func TestSyftBoxURL_SetQueryParams(t *testing.T) { + tests := []struct { + name string + url *SyftBoxURL + queryParams map[string]string + want map[string]string + }{ + { + name: "set query params", + url: &SyftBoxURL{ + Datasite: "user@example.com", + AppName: "app1", + Endpoint: "endpoint1", + }, + queryParams: map[string]string{ + "param1": "value1", + "param2": "value2", + }, + want: map[string]string{ + "param1": "value1", + "param2": "value2", + }, + }, + { + name: "set empty query params", + url: &SyftBoxURL{ + Datasite: "user@example.com", + AppName: "app1", + Endpoint: "endpoint1", + }, + queryParams: map[string]string{}, + want: map[string]string{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.url.SetQueryParams(tt.queryParams) + if len(tt.url.QueryParams) != len(tt.want) { + t.Errorf("SyftBoxURL.SetQueryParams() length = %v, want %v", len(tt.url.QueryParams), len(tt.want)) + return + } + for k, v := range tt.want { + if gotVal, exists := tt.url.QueryParams[k]; !exists || gotVal != v { + t.Errorf("SyftBoxURL.SetQueryParams()[%s] = %v, want %v", k, gotVal, v) + } + } + }) + } +} diff --git a/justfile b/justfile index c4554826..0669183c 100644 --- a/justfile +++ b/justfile @@ -1,7 +1,3 @@ -SYFTBOX_VERSION := "0.5.0" -BUILD_COMMIT := `git rev-parse --short HEAD` -BUILD_DATE := `date -u +%Y-%m-%dT%H:%M:%SZ` -BUILD_LD_FLAGS := "-s -w" + " -X github.com/openmined/syftbox/internal/version.Version=" + SYFTBOX_VERSION + " -X github.com/openmined/syftbox/internal/version.Revision=" + BUILD_COMMIT + " -X github.com/openmined/syftbox/internal/version.BuildDate=" + BUILD_DATE CLIENT_BUILD_TAGS := "go_json nomsgpack" SERVER_BUILD_TAGS := "sonic avx nomsgpack" @@ -83,6 +79,87 @@ destroy-minio: ssh-minio: docker exec -it syftbox-minio bash +[group('dev-docker')] +run-docker-server: + #!/bin/bash + set -eou pipefail + echo "Building and running SyftBox server with MinIO in Docker..." + cd docker && COMPOSE_BAKE=true docker-compose up -d --build minio server + echo "Server is running at http://localhost:8080" + echo "MinIO console is available at http://localhost:9001" + echo "Run 'cd docker && docker-compose logs -f server' to view server logs" + +[group('dev-docker')] +run-docker-client email *ARGS: + #!/bin/bash + set -eou pipefail + + # Build the client image + docker build -f docker/Dockerfile.client -t syftbox-client . + + # Create clients directory if it doesn't exist + mkdir -p ~/.syftbox/clients + + if [ -z "{{ email }}" ]; then + echo "Usage: just run-docker-client [command]" + echo "Examples:" + echo " just run-docker-client user@example.com login" + echo " just run-docker-client user@example.com daemon" + echo " just run-docker-client user@example.com app list" + exit 1 + fi + + # Sanitize email for container name (replace @ with -at- and . with -dot-) + container_name="syftbox-client-$(echo '{{ email }}' | sed 's/@/-at-/g' | sed 's/\./-dot-/g')" + + # Run the client with email-specific configuration + docker run --rm -it \ + -v ~/.syftbox/clients:/data/clients \ + --network docker_syftbox-network \ + -e SYFTBOX_SERVER_URL=http://syftbox-server:8080 \ + -e SYFTBOX_AUTH_ENABLED=0 \ + --name "$container_name" \ + syftbox-client {{ email }} {{ ARGS }} + +[group('dev-docker')] +run-docker-client-daemon email: + #!/bin/bash + set -eou pipefail + + # Build and run client in daemon mode using docker-compose + cd docker && CLIENT_EMAIL={{ email }} docker-compose -f docker-compose-client.yml up -d --build + echo "Client daemon for {{ email }} is running at http://localhost:7938" + echo "Logs: cd docker && docker-compose -f docker-compose-client.yml logs -f" + +[group('dev-docker')] +stop-docker-client email: + #!/bin/bash + set -eou pipefail + + cd docker && CLIENT_EMAIL={{ email }} docker-compose -f docker-compose-client.yml down + +[group('dev-docker')] +list-docker-clients: + #!/bin/bash + set -eou pipefail + + echo "Available SyftBox clients:" + if [ -d ~/.syftbox/clients ]; then + ls -la ~/.syftbox/clients/ | grep -E '^d' | grep -v '\.$' | awk '{print " - " $NF}' + else + echo " No clients found" + fi + +[group('dev-docker')] +destroy-docker-server: + #!/bin/bash + set -eou pipefail + echo "Stopping and removing SyftBox Docker containers..." + cd docker && docker-compose down -v + echo "Removing Docker images..." + docker rmi syftbox-server syftbox-client 2>/dev/null || true + echo "Docker environment cleaned up" + [group('dev')] test: env -i \ @@ -98,14 +175,21 @@ test: [doc('Needs a platform specific compiler. Example: CC="aarch64-linux-musl-gcc" just build-client-target goos=linux goarch=arm64')] [group('build')] -build-client-target goos=`go env GOOS` goarch=`go env GOARCH`: +build-client-target goos=`go env GOOS` goarch=`go env GOARCH`: version-utils #!/bin/bash set -eou pipefail + # Calculate build variables locally + SYFTBOX_VERSION=$(svu current 2>/dev/null) + echo "SYFTBOX_VERSION: $SYFTBOX_VERSION" + BUILD_COMMIT=$(git rev-parse --short HEAD) + BUILD_DATE=$(date -u +%Y-%m-%dT%H:%M:%SZ) + BUILD_LD_FLAGS="-s -w -X github.com/openmined/syftbox/internal/version.Version=$SYFTBOX_VERSION -X github.com/openmined/syftbox/internal/version.Revision=$BUILD_COMMIT -X github.com/openmined/syftbox/internal/version.BuildDate=$BUILD_DATE" + export GOOS="{{ goos }}" export GOARCH="{{ goarch }}" export CGO_ENABLED=0 - export GO_LDFLAGS="$([ '{{ goos }}' = 'windows' ] && echo '-H windowsgui '){{ BUILD_LD_FLAGS }}" + export GO_LDFLAGS="$([ '{{ goos }}' = 'windows' ] && echo '-H windowsgui ')$BUILD_LD_FLAGS" if [ "{{ goos }}" = "darwin" ]; then echo "Building for darwin. CGO_ENABLED=1" @@ -130,8 +214,10 @@ build-all: goreleaser release --snapshot --clean [group('deploy')] -deploy-client remote="syftbox-yash": build-all +deploy-client remote: build-all + #!/bin/bash echo "Deploying syftbox client to {{ _cyan }}{{ remote }}{{ _nc }}" + rm -rf releases && mkdir releases cp -r .out/syftbox_client_*.{tar.gz,zip} releases/ ssh {{ remote }} "rm -rfv /home/azureuser/releases.new && mkdir -p /home/azureuser/releases.new" @@ -139,14 +225,16 @@ deploy-client remote="syftbox-yash": build-all ssh {{ remote }} "rm -rfv /home/azureuser/releases/ && mv -fv /home/azureuser/releases.new/ /home/azureuser/releases/" [group('deploy')] -deploy-server remote="syftbox-yash": build-server +deploy-server remote: build-server + #!/bin/bash echo "Deploying syftbox server to {{ _cyan }}{{ remote }}{{ _nc }}" + scp .out/syftbox_server_linux_amd64_v1/syftbox_server {{ remote }}:/home/azureuser/syftbox_server_new ssh {{ remote }} "rm -fv /home/azureuser/syftbox_server && mv -fv /home/azureuser/syftbox_server_new /home/azureuser/syftbox_server" ssh {{ remote }} "sudo systemctl restart syftbox" [group('deploy')] -deploy remote="syftbox-yash": (deploy-client remote) (deploy-server remote) +deploy remote: (deploy-client remote) (deploy-server remote) echo "Deployed syftbox client & server to {{ _cyan }}{{ remote }}{{ _nc }}" [group('utils')] @@ -158,3 +246,168 @@ setup-toolchain: [group('utils')] clean: rm -rf .data .out releases certs cover.out + +[group('version')] +bump type: version-utils + #!/bin/bash + set -eou pipefail + + # Version Management Commands + # + # This project uses semantic versioning with svu (https://github.com/caarlos0/svu) + # for automatic version calculation based on git tags. + # + # Workflow: + # 1. Use `just show-version` to see current version and next versions + # 2. Use `just bump type` to update files only (manual commit/tag) + # 3. Use `just release type` to update files, commit, and tag automatically + # 4. Use `just update-version-files version=X.Y.Z` for custom versions + # + # Examples: + # just show-version # Show current and next versions + # just bump patch # Update files to next patch version + # just bump minor # Update files to next minor version + # just bump major # Update files to next major version + # just release patch # Bump, commit, and tag patch version + # just update-version-files version=1.2.3 # Set specific version + + if [ -z "{{ type }}" ]; then + echo -e "{{ _red }}Error: bump type is required{{ _nc }}" + echo "Usage: just bump " + echo "Examples:" + echo " just bump patch" + echo " just bump minor" + echo " just bump major" + exit 1 + fi + + # Validate bump type + if [[ ! "{{ type }}" =~ ^(patch|minor|major)$ ]]; then + echo -e "{{ _red }}Error: Invalid bump type '{{ type }}'{{ _nc }}" + echo "Valid types: patch, minor, major" + exit 1 + fi + + echo -e "{{ _cyan }}Bumping {{ type }} version...{{ _nc }}" + new_version=$(svu {{ type }} | sed 's/^v//') + echo -e "{{ _green }}New version: $new_version{{ _nc }}" + just update-version-files version="$new_version" + echo -e "{{ _green }}Version bumped to $new_version{{ _nc }}" + echo -e "{{ _yellow }}Don't forget to commit and tag:{{ _nc }}" + echo " git add ." + echo " git commit -m \"chore: bump version to $new_version\"" + echo " git tag v$new_version" + +release type: version-utils + #!/bin/bash + set -eou pipefail + + if [ -z "{{ type }}" ]; then + echo -e "{{ _red }}Error: release type is required{{ _nc }}" + echo "Usage: just release " + echo "Examples:" + echo " just release patch" + echo " just release minor" + echo " just release major" + exit 1 + fi + + # Validate release type + if [[ ! "{{ type }}" =~ ^(patch|minor|major)$ ]]; then + echo -e "{{ _red }}Error: Invalid release type '{{ type }}'{{ _nc }}" + echo "Valid types: patch, minor, major" + exit 1 + fi + + echo -e "{{ _cyan }}Releasing {{ type }} version...{{ _nc }}" + new_version=$(svu {{ type }} | sed 's/^v//') + echo -e "{{ _green }}New version: $new_version{{ _nc }}" + just update-version-files version="$new_version" + just commit-and-tag version="$new_version" + echo -e "{{ _green }}✓ Released {{ type }} version $new_version{{ _nc }}" + +[group('version')] +show-version: version-utils + #!/bin/bash + set -eou pipefail + echo -e "{{ _cyan }}Current version information:{{ _nc }}" + + # Try to get current version, handle errors gracefully + current_version=$(svu current 2>/dev/null || echo "No valid version tags found") + echo " SVU current: $current_version" + + # Try to get next versions, handle errors gracefully + next_patch=$(svu patch 2>/dev/null || echo "Error") + next_minor=$(svu minor 2>/dev/null || echo "Error") + next_major=$(svu major 2>/dev/null || echo "Error") + + echo " SVU next patch: $next_patch" + echo " SVU next minor: $next_minor" + echo " SVU next major: $next_major" + echo " Git tags:" + git tag --sort=-version:refname | head -5 + +[group('version')] +commit-and-tag version: + #!/bin/bash + set -eou pipefail + + # Extract version from parameter (handle both "version=0.5.1" and "0.5.1" formats) + version_value="{{ version }}" + if [[ "$version_value" == version=* ]]; then + version_value="${version_value#version=}" + fi + + if [ -z "$version_value" ]; then + echo -e "{{ _red }}Error: version parameter is required{{ _nc }}" + echo "Usage: just commit-and-tag version=1.2.3" + exit 1 + fi + + echo -e "{{ _cyan }}Committing and tagging version $version_value...{{ _nc }}" + + # Check if there are changes to commit + if git diff --quiet && git diff --cached --quiet; then + echo -e "{{ _yellow }}No changes to commit{{ _nc }}" + else + git add . + git commit -m "chore: bump version to $version_value" + echo -e "{{ _green }}✓ Committed changes{{ _nc }}" + fi + + # Create tag + git tag v$version_value + echo -e "{{ _green }}✓ Tagged v$version_value{{ _nc }}" + + echo -e "{{ _green }}Version $version_value has been committed and tagged!{{ _nc }}" + +[group('version')] +update-version-files version: + #!/bin/bash + set -eou pipefail + + # Extract version from parameter (handle both "version=0.5.1" and "0.5.1" formats) + version_value="{{ version }}" + if [[ "$version_value" == version=* ]]; then + version_value="${version_value#version=}" + fi + + if [ -z "$version_value" ]; then + echo -e "{{ _red }}Error: version parameter is required{{ _nc }}" + echo "Usage: just update-version-files version=1.2.3" + exit 1 + fi + + echo -e "{{ _cyan }}Updating version to $version_value in all files...{{ _nc }}" + + # Update goreleaser.yaml + sed -i "s/-X github.com\/openmined\/syftbox\/internal\/version.Version=.*/-X github.com\/openmined\/syftbox\/internal\/version.Version=$version_value/g" .goreleaser.yaml + echo -e "{{ _green }}✓ Updated .goreleaser.yaml{{ _nc }}" + + # Update version.go + sed -i "s/Version = \".*\"/Version = \"$version_value\"/" internal/version/version.go + echo -e "{{ _green }}✓ Updated internal/version/version.go{{ _nc }}" + +[group('version')] +version-utils: + go install github.com/caarlos0/svu@latest \ No newline at end of file