From 686e55e394977ec113f179236c4550992cbff4a0 Mon Sep 17 00:00:00 2001 From: Madhava Jay Date: Mon, 16 Jun 2025 17:05:28 +1000 Subject: [PATCH 01/19] Enhance ACL security and fix terminal node behavior Core Changes: - **Terminal node enforcement**: Prevent child rulesets from being added under terminal nodes for proper security boundaries - **Symlink security**: Reject symlinked ACL files in Exists() and LoadFromFile() to prevent security vulnerabilities - **Rule removal fix**: RemoveRuleSet now only clears rules from target node while preserving child nodes, preventing accidental data loss - **Rule clearing**: SetRules properly clears existing rules when passed nil/empty ruleset Technical Details: - tree.go: Add ErrTerminalNodeExists, validate terminal nodes in AddRuleSet, rewrite RemoveRuleSet to preserve children - node.go: Fix SetRules to clear rules when nil/empty input provided - aclspec.go: Replace os.Stat with os.Lstat and add symlink rejection in Exists() - ruleset.go: Add symlink validation in LoadFromFile() with descriptive error messages Tests added to internal/server/acl and internal/aclspec packages to increase coverage and validate new security behaviors. --- docs/index.md | 23 ++ docs/permissions.md | 592 ++++++++++++++++++++++++++++++ internal/aclspec/access_test.go | 261 +++++++++++++ internal/aclspec/aclspec.go | 9 +- internal/aclspec/aclspec_test.go | 357 ++++++++++++++++++ internal/aclspec/limits_test.go | 123 +++++++ internal/aclspec/migrate_test.go | 247 +++++++++++++ internal/aclspec/rule_test.go | 247 +++++++++++++ internal/aclspec/ruleset.go | 13 + internal/aclspec/ruleset_test.go | 368 +++++++++++++++++++ internal/server/acl/cache_test.go | 349 ++++++++++++++++++ internal/server/acl/level_test.go | 192 ++++++++++ internal/server/acl/node.go | 3 + internal/server/acl/tree.go | 48 ++- internal/server/acl/tree_test.go | 320 +++++++++++++++- 15 files changed, 3131 insertions(+), 21 deletions(-) create mode 100644 docs/index.md create mode 100644 docs/permissions.md create mode 100644 internal/aclspec/access_test.go create mode 100644 internal/aclspec/aclspec_test.go create mode 100644 internal/aclspec/limits_test.go create mode 100644 internal/aclspec/migrate_test.go create mode 100644 internal/aclspec/rule_test.go create mode 100644 internal/aclspec/ruleset_test.go create mode 100644 internal/server/acl/cache_test.go create mode 100644 internal/server/acl/level_test.go diff --git a/docs/index.md b/docs/index.md new file mode 100644 index 00000000..56b3672a --- /dev/null +++ b/docs/index.md @@ -0,0 +1,23 @@ +# SyftBox Documentation + +- [SyftBox Permissions System](permissions.md) + - [Overview](permissions.md#overview) + - [Design Philosophy: Unix-Inspired Permissions](permissions.md#design-philosophy-unix-inspired-permissions) + - [User Identity](permissions.md#user-identity) + - [Architecture Overview](permissions.md#architecture-overview) + - [Core Components](permissions.md#core-components) + - [ACL Service](permissions.md#1-acl-service-internalserveraclaclgo) + - [Tree Structure](permissions.md#2-tree-structure-internalserveracltreego) + - [Testing and Validation](permissions.md#testing-and-validation) + - [Test Coverage](permissions.md#test-coverage) + - [Example Test Case](permissions.md#example-test-case) + - [Best Practices](permissions.md#best-practices) + - [Permission Design](permissions.md#1-permission-design) + - [Rule Organization](permissions.md#2-rule-organization) + - [Security Considerations](permissions.md#3-security-considerations) + - [Performance Optimization](permissions.md#4-performance-optimization) + - [Integration Points](permissions.md#integration-points) + - [Server Integration](permissions.md#server-integration-internalserverservergo) + - [Client Synchronization](permissions.md#client-synchronization-internalclientsync) + - [Migration and Compatibility](permissions.md#migration-and-compatibility) + - [Conclusion](permissions.md#conclusion) diff --git a/docs/permissions.md b/docs/permissions.md new file mode 100644 index 00000000..c7ddaf4e --- /dev/null +++ b/docs/permissions.md @@ -0,0 +1,592 @@ +# SyftBox Permissions System + +## Overview + +SyftBox implements a simple but powerful file-based Access Control List (ACL) system that mimics Unix permissions while providing more granular control over distributed file systems. The system is designed to secure data sharing across federated networks while maintaining ease of use and flexibility. + +## SyftBox Permissions in 1 Minute ⚡ + +SyftBox uses **email-based permissions** to control who can access your files. Each user has a datasite with a default private root permission: + +### Default Folder Structure +``` +/datasites/ +└── your@email.org/ + ├── syft.pub.yaml # Root permissions (private by default) + ├── public/ + │ └── syft.pub.yaml # Public files + └── shared/ + └── syft.pub.yaml # Shared with specific collaborators +``` + +### Default Root Permissions +Your root `/datasites/your@email.org/syft.pub.yaml` starts private: +```yaml +rules: + - pattern: "**" + access: + read: [] # No one can read + write: [] # No one can write + admin: [] # No one can modify permissions +``` + +### Example Public and Shared Folders +Create `/datasites/your@email.org/public/syft.pub.yaml`: +```yaml +terminal: true # optional `terminal: true`, child directories don't inherit parent permissions +rules: + - pattern: "**" + access: + read: ["*"] # Everyone can read +``` + +Create `/datasites/your@email.org/shared/syft.pub.yaml`: +```yaml +terminal: true # optional When `terminal: true`, child directories don't inherit parent permissions +rules: + - pattern: "**" + access: + read: ["alice@university.edu", "bob@research.org"] + write: ["alice@university.edu"] +``` + +**Key Points:** +- **Email addresses** identify users (`alice@university.edu`) +- **Patterns** match files (`*.txt`, `folder/*`, `**` for everything) +- **Permissions**: `read` (view), `write` (edit), `admin` (change permissions) +- **`"*"`** means public access +- **Root is private** by default - create subfolders for sharing +- **More specific rules** override general ones + + + +## Design Philosophy: Unix-Inspired Permissions + +The SyftBox permission system draws inspiration from Unix file permissions but extends beyond the traditional user/group/other model to support: + +- **Distributed users**: Multiple users identified by email addresses across different nodes +- **Hierarchical rules**: Directory-based permission inheritance +- **Pattern matching**: Glob-based file pattern permissions +- **Granular access**: Separate read, write, and admin permissions +- **Terminal inheritance**: Controlled permission propagation + +## User Identity + +In SyftBox, users are identified by their **email addresses**. This federated approach allows researchers, organizations, and collaborators from different institutions to securely share data while maintaining clear identity management. Examples of user identifiers include: + +- `*` - Special identifier meaning "everyone" (public access) +- `alice@research.org` - A researcher at a research institution +- `bob@university.edu` - A professor at a university +- `team@company.com` - A team or service account at a company +- `*@company.com` - Email's can use glob patterns to allow org level access + +## Architecture Overview + +The permission system consists of several key components working together: + +``` +┌─────────────────┐ ┌──────────────────┐ ┌─────────────────┐ +│ ACL Service │────│ Rule Cache │────│ Tree Structure │ +│ │ │ │ │ │ +│ - Rule lookup │ │ - O(1) lookups │ │ - N-ary tree │ +│ - Access check │ │ - Cache hits │ │ - Path traversal│ +│ - File limits │ │ - Invalidation │ │ - Rule storage │ +└─────────────────┘ └──────────────────┘ └─────────────────┘ + │ │ │ + └───────────────────────┼───────────────────────┘ + │ + ┌─────────────────────┐ + │ syft.pub.yaml │ + │ Configuration │ + │ │ + │ - Rules patterns │ + │ - Access grants │ + │ - Terminal flags │ + └─────────────────────┘ +``` + +## Core Components + +### 1. ACL Service (`internal/server/acl/acl.go`) + +The `AclService` is the main orchestrator that manages access control: + +```go +type AclService struct { + tree *Tree // Hierarchical rule storage + cache *RuleCache // Performance optimization +} +``` + +**Key methods:** +- `CanAccess()` - Primary access control check +- `GetRule()` - Rule lookup with caching +- `AddRuleSet()` - Dynamic rule management + +### 2. Tree Structure (`internal/server/acl/tree.go`) + +The permission tree stores rules in an n-ary tree for efficient path-based lookups: + +```go +type Tree struct { + root *Node +} +``` + +**Features:** +- **O(depth)** rule lookups +- **Terminal nodes** to prevent inheritance +- **Dynamic rule addition/removal** +- **Path-based traversal** + +### 3. Rule System (`internal/server/acl/rule.go`) + +Rules define the actual access control logic: + +```go +type Rule struct { + fullPattern string // Complete path + glob pattern + rule *aclspec.Rule // Access specifications + node *Node // Associated tree node +} +``` + +### 4. Access Levels (`internal/server/acl/level.go`) + +The system defines five distinct access levels using bit flags: + +```go +type AccessLevel uint8 + +const ( + AccessRead AccessLevel = 1 << iota // 0001 + AccessCreate // 0010 + AccessWrite // 0100 + AccessReadACL // 1000 + AccessWriteACL // 10000 +) +``` + +## Permission Files: `syft.pub.yaml` + +### File Format + +Permission files follow the YAML format and are always named `syft.pub.yaml`: + +```yaml +terminal: true +rules: + - pattern: "*.md" + access: + read: ["*"] + - pattern: "private/*.txt" + access: + read: ["alice@research.org", "bob@university.edu"] + write: ["alice@research.org"] + - pattern: "**" + access: {} +``` + +### Configuration Elements + +#### Terminal Flag +- **Purpose**: Controls permission inheritance +- **Values**: `true` (stop inheritance) or `false` (allow inheritance) +- **Impact**: When `terminal: true`, child directories don't inherit parent permissions + +#### Rules Array +Each rule contains: +- **`pattern`**: Glob pattern matching files/directories +- **`access`**: Permission specifications + +#### Access Specifications +- **`admin`**: Can modify ACL files (`syft.pub.yaml`) - specified by email addresses +- **`write`**: Can create, modify, and delete files - specified by email addresses +- **`read`**: Can view and download files - specified by email addresses +- **Special value `"*"`**: Grants access to everyone (public access) + +### Example Configurations + +#### Public Documentation with Private Admin +```yaml +terminal: true +rules: + - pattern: "docs/*.md" + access: + read: ["*"] + write: ["maintainer@company.com"] + - pattern: "admin/*" + access: + admin: ["admin@company.com"] + - pattern: "**" + access: + read: ["team@company.com"] +``` + +#### Collaborative Project Structure +```yaml +terminal: false +rules: + - pattern: "src/**/*.go" + access: + read: ["*"] + write: ["dev1@company.com", "dev2@university.edu"] + - pattern: "tests/**" + access: + read: ["*"] + write: ["dev1@company.com", "qa@company.com"] + - pattern: "config/*.yaml" + access: + admin: ["devops@company.com"] + - pattern: "**" + access: {} +``` + +#### Research Data Sharing +```yaml +terminal: true +rules: + - pattern: "public_datasets/*" + access: + read: ["*"] + - pattern: "shared_analysis/*.ipynb" + access: + read: ["alice@research.org", "bob@university.edu", "charlie@institute.org"] + write: ["alice@research.org"] + - pattern: "raw_data/*" + access: + read: ["alice@research.org", "bob@university.edu"] + admin: ["data-owner@research.org"] +``` + +## Access Control Flow + +### Step-by-Step Permission Check + +1. **Owner Check** (`internal/server/acl/acl.go`) + ```go + if user.IsOwner { + return nil // Owners bypass all restrictions + } + ``` + +2. **Rule Lookup** (`internal/server/acl/acl.go`) + - Check cache for O(1) performance + - Traverse tree structure for O(depth) lookup + - Find most specific matching rule + +3. **ACL File Elevation** (`internal/server/acl/acl.go`) + ```go + if isAcl && level == AccessWrite { + level = AccessWriteACL // Elevate to admin requirement + } + ``` + +4. **File Limits Check** (`internal/server/acl/rule.go`) + - Validate file size limits + - Check directory permissions + - Verify symlink restrictions + +5. **Access Verification** (`internal/server/acl/rule.go`) + - Check user permissions against required level + - Handle permission hierarchy (admin > write > read) + +### Permission Hierarchy + +The system enforces a clear permission hierarchy: + +``` +Admin (AccessWriteACL) + ├─ Can modify syft.pub.yaml files + ├─ Full write access + └─ Full read access + │ +Write (AccessWrite) + ├─ Can create/modify/delete files + ├─ Subject to file limits + └─ Full read access + │ +Read (AccessRead) + └─ Can view and download files +``` + +## Tree Structure and Inheritance + +### Hierarchical Rule Storage + +The tree structure mirrors the file system hierarchy: + +``` +/ +├─ users/ +│ ├─ alice/ +│ │ ├─ syft.pub.yaml (terminal: true) +│ │ └─ documents/ +│ └─ bob/ +└─ shared/ + ├─ syft.pub.yaml (terminal: false) + └─ projects/ + └─ project1/ + └─ syft.pub.yaml (terminal: true) +``` + +### Rule Resolution (`internal/server/acl/tree.go`) + +The `GetNearestNodeWithRules()` method traverses up the tree to find applicable rules: + +```go +func (t *Tree) GetNearestNodeWithRules(path string) *Node { + parts := pathParts(path) + var candidate *Node + current := t.root + + for _, part := range parts { + if current.IsTerminal() { + break // Stop at terminal nodes + } + + child, exists := current.GetChild(part) + if !exists { + break + } + + current = child + if child.Rules() != nil { + candidate = current + } + } + + return candidate +} +``` + +## Caching System + +### Performance Optimization + +The rule cache provides O(1) lookups for frequently accessed paths: + +- **Cache hits**: Return immediately without tree traversal +- **Cache misses**: Perform tree lookup and cache result +- **Cache invalidation**: Clear affected entries when rules change + +### Cache Management (`internal/server/acl/acl.go`) + +```go +func (s *AclService) RemoveRuleSet(path string) bool { + s.cache.DeletePrefix(path) // Invalidate cache entries + return s.tree.RemoveRuleSet(path) +} +``` + +## File Limits and Restrictions + +### Supported Limits (`internal/aclspec/limits.go`) + +```go +type Limits struct { + MaxFileSize int64 `yaml:"maxFileSize,omitempty"` + MaxFiles uint32 `yaml:"maxFiles,omitempty"` + AllowDirs bool `yaml:"allowDirs,omitempty"` + AllowSymlinks bool `yaml:"allowSymlinks,omitempty"` +} +``` + +### Enforcement Logic (`internal/server/acl/rule.go`) + +```go +func (r *Rule) CheckLimits(info *File) error { + limits := r.rule.Limits + + if limits.MaxFileSize > 0 && info.Size > limits.MaxFileSize { + return ErrFileSizeExceeded + } + + if !limits.AllowDirs && (info.IsDir || strings.Count(info.Path, pathSep) > 0) { + return ErrDirsNotAllowed + } + + if !limits.AllowSymlinks && info.IsSymlink { + return ErrSymlinksNotAllowed + } + + return nil +} +``` + +## Pattern Matching + +### Glob Pattern Support + +The system supports powerful glob patterns for flexible file matching: + +- **`**`** - Matches all files and directories recursively +- **`*.ext`** - Matches all files with specific extension +- **`dir/*`** - Matches all direct children of directory +- **`dir/**`** - Matches all descendants of directory +- **`specific.txt`** - Matches exact filename + +### Pattern Examples + +```yaml +rules: + - pattern: "**/*.py" # All Python files + access: + read: ["dev1@company.com", "dev2@university.edu"] + + - pattern: "tests/**" # Everything in tests directory + access: + write: ["qa@company.com"] + + - pattern: "config.yaml" # Specific file + access: + admin: ["devops@company.com"] + + - pattern: "public/*" # Direct children only + access: + read: ["*"] +``` + +## Implementation Details + +### User Representation (`internal/server/acl/types.go`) + +```go +type User struct { + ID string // User identifier (email address) + IsOwner bool // Owner bypass flag +} +``` + +### File Information (`internal/server/acl/types.go`) + +```go +type File struct { + Path string // File system path + IsDir bool // Directory flag + IsSymlink bool // Symbolic link flag + Size int64 // File size in bytes +} +``` + +### Rule Set Management (`internal/aclspec/ruleset.go`) + +```go +type RuleSet struct { + Rules []*Rule `yaml:"rules,omitempty"` + Terminal bool `yaml:"terminal,omitempty"` + Path string `yaml:"-"` // Internal use only +} +``` + +## Error Handling + +The system defines specific error types for different access violations: + +```go +var ( + ErrAdminRequired = errors.New("admin access required") + ErrWriteRequired = errors.New("write access required") + ErrReadRequired = errors.New("read access required") + ErrDirsNotAllowed = errors.New("directories not allowed") + ErrSymlinksNotAllowed = errors.New("symlinks not allowed") + ErrFileSizeExceeded = errors.New("file size exceeds limits") +) +``` + +## Testing and Validation + +### Test Coverage + +The system includes comprehensive tests covering: + +- **Rule resolution** (`internal/server/acl/acl_test.go`) +- **Access control enforcement** (`internal/server/acl/acl_test.go`) +- **File limit validation** (`internal/server/acl/acl_test.go`) +- **Cache behavior** (`internal/server/acl/acl_test.go`) +- **Tree operations** (`internal/server/acl/tree_test.go`) + +### Example Test Case + +```go +func TestAclServiceCanAccess(t *testing.T) { + service := NewAclService() + + ruleset := aclspec.NewRuleSet( + "user", + aclspec.SetTerminal, + aclspec.NewRule("public/*.txt", aclspec.PublicReadAccess(), aclspec.DefaultLimits()), + aclspec.NewRule("private/*.txt", aclspec.PrivateAccess(), aclspec.DefaultLimits()), + ) + + err := service.AddRuleSet(ruleset) + assert.NoError(t, err) + + // Test access control + regularUser := &User{ID: aclspec.Everyone, IsOwner: false} + publicFile := &File{Path: "user/public/doc.txt", Size: 100} + + err = service.CanAccess(regularUser, publicFile, AccessRead) + assert.NoError(t, err) // Should succeed +} +``` + +## Best Practices + +### 1. Permission Design +- Start with restrictive permissions and grant access as needed +- Use terminal nodes to prevent unintended inheritance +- Group related files using consistent patterns + +### 2. Rule Organization +- Place more specific rules before general ones +- Always include a default `**` catch-all rule +- Use descriptive user group names + +### 3. Security Considerations +- Monitor for overly permissive `*` grants +- Implement file size limits for public areas + +### 4. Performance Optimization +- Keep rule sets small and focused +- Use terminal nodes to limit tree traversal +- Monitor cache hit rates + +## Integration Points + +### Server Integration (`internal/server/server.go`) + +The ACL service integrates with the SyftBox server to protect all file operations: + +- **File uploads**: Check write permissions +- **File downloads**: Verify read access +- **Directory listings**: Filter based on permissions +- **ACL modifications**: Require admin access + +### Client Synchronization (`internal/client/sync/`) + +The sync engine respects ACL permissions during: + +- **Upload operations**: Validate write access before sync +- **Download filtering**: Only sync permitted files +- **Conflict resolution**: Consider permission changes + +## Migration and Compatibility + +The system includes migration support for updating ACL formats: + +- **Backward compatibility**: Support older permission formats +- **Automatic migration**: Upgrade rules during system updates +- **Validation**: Ensure rule consistency after migration + +## Conclusion + +The SyftBox permission system provides a robust, Unix-inspired access control mechanism tailored for distributed file systems. By combining hierarchical rules, pattern matching, and efficient caching, it offers both security and performance for federated data sharing scenarios. + +The system's design emphasizes: +- **Flexibility**: Glob patterns and granular permissions +- **Performance**: O(1) cached lookups and O(depth) tree traversal +- **Security**: Default-deny policies and owner protections +- **Usability**: YAML configuration and inheritance patterns + +This architecture enables secure, scalable data sharing across distributed networks while maintaining the intuitive permission model familiar to Unix users. \ No newline at end of file diff --git a/internal/aclspec/access_test.go b/internal/aclspec/access_test.go new file mode 100644 index 00000000..c38405c5 --- /dev/null +++ b/internal/aclspec/access_test.go @@ -0,0 +1,261 @@ +package aclspec + +import ( + "testing" + + mapset "github.com/deckarep/golang-set/v2" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gopkg.in/yaml.v3" +) + +func TestNewAccess(t *testing.T) { + // Test creating Access with different user email combinations + // This tests the core constructor and validates that sets are properly initialized + admin := []string{"admin@example.com", "owner@company.org"} + write := []string{"writer1@openmined.org", "writer2@research.edu", "collab@university.ac.uk"} + read := []string{"reader@public.org"} + + access := NewAccess(admin, write, read) + + // Verify all sets are properly initialized and contain expected user emails + assert.True(t, access.Admin.Contains("admin@example.com")) + assert.True(t, access.Admin.Contains("owner@company.org")) + assert.Equal(t, 2, access.Admin.Cardinality()) + + assert.True(t, access.Write.Contains("writer1@openmined.org")) + assert.True(t, access.Write.Contains("writer2@research.edu")) + assert.True(t, access.Write.Contains("collab@university.ac.uk")) + assert.Equal(t, 3, access.Write.Cardinality()) + + assert.True(t, access.Read.Contains("reader@public.org")) + assert.Equal(t, 1, access.Read.Cardinality()) +} + +func TestNewAccessWithEmptyLists(t *testing.T) { + // Test edge case of creating Access with empty lists + // This ensures the constructor handles empty inputs gracefully + access := NewAccess([]string{}, []string{}, []string{}) + + assert.Equal(t, 0, access.Admin.Cardinality()) + assert.Equal(t, 0, access.Write.Cardinality()) + assert.Equal(t, 0, access.Read.Cardinality()) +} + +func TestPrivateAccess(t *testing.T) { + // Test that PrivateAccess creates an Access object with no permissions + // This is a critical security test - private should mean NO access for anyone + access := PrivateAccess() + + assert.Equal(t, 0, access.Admin.Cardinality(), "Private access should have no admin users") + assert.Equal(t, 0, access.Write.Cardinality(), "Private access should have no write users") + assert.Equal(t, 0, access.Read.Cardinality(), "Private access should have no read users") +} + +func TestPublicReadAccess(t *testing.T) { + // Test that PublicReadAccess grants read access to everyone but nothing else + // This validates the most common public sharing scenario + access := PublicReadAccess() + + assert.Equal(t, 0, access.Admin.Cardinality(), "Public read should have no admin users") + assert.Equal(t, 0, access.Write.Cardinality(), "Public read should have no write users") + assert.Equal(t, 1, access.Read.Cardinality(), "Public read should have exactly one read entry") + assert.True(t, access.Read.Contains(Everyone), "Public read should grant read access to everyone") +} + +func TestPublicReadWriteAccess(t *testing.T) { + // Test that PublicReadWriteAccess grants write (and implicitly read) access to everyone + // This tests the dangerous but sometimes necessary full public access + access := PublicReadWriteAccess() + + assert.Equal(t, 0, access.Admin.Cardinality(), "Public read-write should have no admin users") + assert.Equal(t, 1, access.Write.Cardinality(), "Public read-write should have exactly one write entry") + assert.True(t, access.Write.Contains(Everyone), "Public read-write should grant write access to everyone") + assert.Equal(t, 0, access.Read.Cardinality(), "Public read-write should not set read (write implies read)") +} + +func TestSharedReadAccess(t *testing.T) { + // Test creating shared read access for specific user emails + // This validates the common scenario of sharing read access with specific collaborators + users := []string{"alice@research.org", "bob@university.edu", "charlie@company.com"} + access := SharedReadAccess(users...) + + assert.Equal(t, 0, access.Admin.Cardinality(), "Shared read should have no admin users") + assert.Equal(t, 0, access.Write.Cardinality(), "Shared read should have no write users") + assert.Equal(t, 3, access.Read.Cardinality(), "Shared read should have exactly the specified users") + + for _, user := range users { + assert.True(t, access.Read.Contains(user), "Shared read should contain user %s", user) + } +} + +func TestSharedReadWriteAccess(t *testing.T) { + // Test creating shared read-write access for specific user emails + // This validates collaborative scenarios where specific users need write access + users := []string{"maintainer1@openmined.org", "maintainer2@research.edu"} + access := SharedReadWriteAccess(users...) + + assert.Equal(t, 0, access.Admin.Cardinality(), "Shared read-write should have no admin users") + assert.Equal(t, 2, access.Write.Cardinality(), "Shared read-write should have exactly the specified users") + assert.Equal(t, 0, access.Read.Cardinality(), "Shared read-write should not set read (write implies read)") + + for _, user := range users { + assert.True(t, access.Write.Contains(user), "Shared read-write should contain user %s", user) + } +} + +func TestAccessUnmarshalYAML(t *testing.T) { + // Test unmarshaling Access from YAML format + // This is critical for loading syft.pub.yaml files from disk + yamlData := ` +admin: ["admin@example.com", "owner@company.org"] +read: ["reader1@public.org", "reader2@university.edu", "reader3@research.net"] +write: ["writer@openmined.org"] +` + + var access Access + err := yaml.Unmarshal([]byte(yamlData), &access) + require.NoError(t, err, "YAML unmarshaling should succeed") + + // Verify that all user emails were properly loaded into their respective sets + assert.Equal(t, 2, access.Admin.Cardinality()) + assert.True(t, access.Admin.Contains("admin@example.com")) + assert.True(t, access.Admin.Contains("owner@company.org")) + + assert.Equal(t, 3, access.Read.Cardinality()) + assert.True(t, access.Read.Contains("reader1@public.org")) + assert.True(t, access.Read.Contains("reader2@university.edu")) + assert.True(t, access.Read.Contains("reader3@research.net")) + + assert.Equal(t, 1, access.Write.Cardinality()) + assert.True(t, access.Write.Contains("writer@openmined.org")) +} + +func TestAccessUnmarshalYAMLWithMissingSections(t *testing.T) { + // Test unmarshaling when some access sections are missing + // This ensures the system handles partial configurations gracefully + yamlData := ` +read: ["reader@example.com"] +# write and admin sections intentionally missing +` + + var access Access + err := yaml.Unmarshal([]byte(yamlData), &access) + require.NoError(t, err, "YAML unmarshaling should succeed even with missing sections") + + // Missing sections should result in empty sets, not nil + assert.Equal(t, 0, access.Admin.Cardinality(), "Missing admin section should result in empty set") + assert.Equal(t, 0, access.Write.Cardinality(), "Missing write section should result in empty set") + assert.Equal(t, 1, access.Read.Cardinality(), "Read section should be properly loaded") + assert.True(t, access.Read.Contains("reader@example.com")) +} + +func TestAccessUnmarshalYAMLWithEmptyLists(t *testing.T) { + // Test unmarshaling with explicitly empty lists + // This ensures empty lists in YAML are handled correctly + yamlData := ` +admin: [] +read: [] +write: ["writer@company.com"] +` + + var access Access + err := yaml.Unmarshal([]byte(yamlData), &access) + require.NoError(t, err, "YAML unmarshaling should succeed with empty lists") + + assert.Equal(t, 0, access.Admin.Cardinality()) + assert.Equal(t, 0, access.Read.Cardinality()) + assert.Equal(t, 1, access.Write.Cardinality()) + assert.True(t, access.Write.Contains("writer@company.com")) +} + +func TestAccessUnmarshalYAMLInvalidFormat(t *testing.T) { + // Test that invalid YAML format is properly rejected + // This ensures the system fails safely on malformed input + yamlData := ` +admin: "should_be_a_list_not_string" +` + + var access Access + err := yaml.Unmarshal([]byte(yamlData), &access) + assert.Error(t, err, "Invalid YAML format should be rejected") +} + +func TestAccessMarshalYAML(t *testing.T) { + // Test marshaling Access to YAML format + // This is critical for saving syft.pub.yaml files to disk + access := NewAccess( + []string{"admin@example.com", "owner@company.org"}, + []string{"writer@openmined.org"}, + []string{"reader1@public.org", "reader2@university.edu"}, + ) + + data, err := yaml.Marshal(access) + require.NoError(t, err, "YAML marshaling should succeed") + + // Unmarshal back to verify the round-trip works correctly + var unmarshaled Access + err = yaml.Unmarshal(data, &unmarshaled) + require.NoError(t, err, "Round-trip unmarshaling should succeed") + + // Verify all data survived the round-trip + assert.Equal(t, access.Admin.Cardinality(), unmarshaled.Admin.Cardinality()) + assert.Equal(t, access.Write.Cardinality(), unmarshaled.Write.Cardinality()) + assert.Equal(t, access.Read.Cardinality(), unmarshaled.Read.Cardinality()) + + // Check that all user emails are preserved + for _, user := range access.Admin.ToSlice() { + assert.True(t, unmarshaled.Admin.Contains(user)) + } + for _, user := range access.Write.ToSlice() { + assert.True(t, unmarshaled.Write.Contains(user)) + } + for _, user := range access.Read.ToSlice() { + assert.True(t, unmarshaled.Read.Contains(user)) + } +} + +func TestAccessMarshalYAMLWithNilSets(t *testing.T) { + // Test marshaling when some sets are nil + // This tests the defensive programming in MarshalYAML + access := Access{ + Admin: mapset.NewSet("admin@example.com"), + Write: nil, // Intentionally nil + Read: mapset.NewSet("reader@public.org"), + } + + data, err := yaml.Marshal(access) + require.NoError(t, err, "YAML marshaling should handle nil sets gracefully") + + // Verify the marshaled data is valid YAML + var result map[string][]string + err = yaml.Unmarshal(data, &result) + require.NoError(t, err, "Marshaled YAML should be valid") + + // Nil sets should not appear in the output + assert.Contains(t, result, "admin") + assert.Contains(t, result, "read") + assert.NotContains(t, result, "write", "Nil sets should not appear in marshaled YAML") +} + +func TestAccessMarshalYAMLWithEmptySets(t *testing.T) { + // Test marshaling when sets are empty but not nil + // This verifies that empty sets are handled consistently + access := NewAccess( + []string{}, // Empty admin + []string{}, // Empty write + []string{"reader@example.com"}, // Non-empty read + ) + + data, err := yaml.Marshal(access) + require.NoError(t, err, "YAML marshaling should handle empty sets") + + var result map[string][]string + err = yaml.Unmarshal(data, &result) + require.NoError(t, err, "Marshaled YAML should be valid") + + // Empty sets should appear as empty arrays in YAML + assert.Equal(t, []string{}, result["admin"]) + assert.Equal(t, []string{}, result["write"]) + assert.Contains(t, result["read"], "reader@example.com") +} \ No newline at end of file diff --git a/internal/aclspec/aclspec.go b/internal/aclspec/aclspec.go index 3cb1b890..628362d7 100644 --- a/internal/aclspec/aclspec.go +++ b/internal/aclspec/aclspec.go @@ -33,11 +33,18 @@ func WithoutAclPath(path string) string { } // Exists checks if the ACL file exists at the given path +// For security reasons, symlinks are not allowed as ACL files func Exists(path string) bool { aclPath := AsAclPath(path) - stat, err := os.Stat(aclPath) + stat, err := os.Lstat(aclPath) // Use Lstat to not follow symlinks if os.IsNotExist(err) { return false } + + // Reject symlinks for security reasons + if stat.Mode()&os.ModeSymlink != 0 { + return false + } + return stat.Size() > 0 } diff --git a/internal/aclspec/aclspec_test.go b/internal/aclspec/aclspec_test.go new file mode 100644 index 00000000..a0f63e77 --- /dev/null +++ b/internal/aclspec/aclspec_test.go @@ -0,0 +1,357 @@ +package aclspec + +import ( + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestIsAclFile(t *testing.T) { + // Test the core ACL file detection logic + // This is critical for the system to recognize permission files correctly + testCases := []struct { + path string + expected bool + desc string + }{ + { + path: "syft.pub.yaml", + expected: true, + desc: "Direct ACL filename should be detected", + }, + { + path: "/path/to/syft.pub.yaml", + expected: true, + desc: "ACL file with full path should be detected", + }, + { + path: "folder/subfolder/syft.pub.yaml", + expected: true, + desc: "ACL file in nested path should be detected", + }, + { + path: "not_an_acl_file.yaml", + expected: false, + desc: "Regular YAML file should not be detected as ACL", + }, + { + path: "syft.pub.yaml.backup", + expected: false, + desc: "ACL filename with suffix should not be detected", + }, + { + path: "prefix_syft.pub.yaml", + expected: true, + desc: "ACL filename with prefix should be detected (suffix match)", + }, + { + path: "", + expected: false, + desc: "Empty path should not be detected as ACL", + }, + { + path: "syft.pub.yml", + expected: false, + desc: "Wrong extension (.yml vs .yaml) should not be detected", + }, + { + path: "SYFT.PUB.YAML", + expected: false, + desc: "Case-sensitive check - uppercase should not match", + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + result := IsAclFile(tc.path) + assert.Equal(t, tc.expected, result, "Path: %s", tc.path) + }) + } +} + +func TestAsAclPath(t *testing.T) { + // Test converting directory paths to ACL file paths + // This ensures the system can correctly locate ACL files for directories + testCases := []struct { + input string + expected string + desc string + }{ + { + input: "/home/user", + expected: "/home/user/syft.pub.yaml", + desc: "Directory path should get ACL filename appended", + }, + { + input: "/home/user/syft.pub.yaml", + expected: "/home/user/syft.pub.yaml", + desc: "ACL file path should remain unchanged", + }, + { + input: "", + expected: "syft.pub.yaml", + desc: "Empty path should result in just the ACL filename", + }, + { + input: "relative/path", + expected: "relative/path/syft.pub.yaml", + desc: "Relative path should get ACL filename appended", + }, + { + input: "/", + expected: "/syft.pub.yaml", + desc: "Root directory should get ACL filename appended", + }, + { + input: "folder/syft.pub.yaml", + expected: "folder/syft.pub.yaml", + desc: "Path already ending with ACL filename should be unchanged", + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + result := AsAclPath(tc.input) + assert.Equal(t, tc.expected, result, "Input: %s", tc.input) + }) + } +} + +func TestWithoutAclPath(t *testing.T) { + // Test removing ACL filename from paths + // This is used to get the directory path from an ACL file path + testCases := []struct { + input string + expected string + desc string + }{ + { + input: "/home/user/syft.pub.yaml", + expected: "/home/user/", + desc: "ACL file path should have filename removed", + }, + { + input: "syft.pub.yaml", + expected: "", + desc: "Bare ACL filename should result in empty string", + }, + { + input: "/home/user/other.yaml", + expected: "/home/user/other.yaml", + desc: "Non-ACL file path should remain unchanged", + }, + { + input: "/home/user/", + expected: "/home/user/", + desc: "Directory path without ACL filename should remain unchanged", + }, + { + input: "", + expected: "", + desc: "Empty path should remain empty", + }, + { + input: "folder/subfolder/syft.pub.yaml", + expected: "folder/subfolder/", + desc: "Nested ACL file path should have filename removed", + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + result := WithoutAclPath(tc.input) + assert.Equal(t, tc.expected, result, "Input: %s", tc.input) + }) + } +} + +func TestExists(t *testing.T) { + // Test ACL file existence detection + // This validates the system can correctly detect existing ACL files on disk + + // Create a temporary directory for testing + tempDir := t.TempDir() + + // Create a test ACL file with some content + aclFilePath := filepath.Join(tempDir, AclFileName) + err := os.WriteFile(aclFilePath, []byte("terminal: true\nrules: []"), 0644) + require.NoError(t, err, "Should be able to create test ACL file") + + // Test that existing non-empty ACL file is detected + assert.True(t, Exists(tempDir), "Should detect existing ACL file in directory") + assert.True(t, Exists(aclFilePath), "Should detect existing ACL file by direct path") + + // Create an empty ACL file + emptyAclDir := filepath.Join(tempDir, "empty") + err = os.Mkdir(emptyAclDir, 0755) + require.NoError(t, err, "Should be able to create test directory") + + emptyAclFile := filepath.Join(emptyAclDir, AclFileName) + err = os.WriteFile(emptyAclFile, []byte{}, 0644) + require.NoError(t, err, "Should be able to create empty ACL file") + + // Test that empty ACL file is not considered as existing + assert.False(t, Exists(emptyAclDir), "Empty ACL file should not be considered as existing") + + // Test that non-existent path returns false + nonExistentPath := filepath.Join(tempDir, "nonexistent") + assert.False(t, Exists(nonExistentPath), "Non-existent path should return false") + + // Test that directory without ACL file returns false + noAclDir := filepath.Join(tempDir, "no_acl") + err = os.Mkdir(noAclDir, 0755) + require.NoError(t, err, "Should be able to create test directory") + + assert.False(t, Exists(noAclDir), "Directory without ACL file should return false") +} + +func TestExistsWithSymlinks(t *testing.T) { + // Test that symlinks are rejected for security reasons + // ACL files must be regular files, not symlinks + + tempDir := t.TempDir() + + // Create actual ACL file + realAclDir := filepath.Join(tempDir, "real") + err := os.Mkdir(realAclDir, 0755) + require.NoError(t, err) + + realAclFile := filepath.Join(realAclDir, AclFileName) + err = os.WriteFile(realAclFile, []byte("terminal: true"), 0644) + require.NoError(t, err) + + // Verify real ACL file is detected + assert.True(t, Exists(realAclDir), "Real ACL file should be detected") + + // Create symlink to the ACL file + symlinkDir := filepath.Join(tempDir, "symlink") + err = os.Mkdir(symlinkDir, 0755) + require.NoError(t, err) + + symlinkAclFile := filepath.Join(symlinkDir, AclFileName) + err = os.Symlink(realAclFile, symlinkAclFile) + require.NoError(t, err) + + // Test that symlinked ACL file is REJECTED for security reasons + assert.False(t, Exists(symlinkDir), "Symlinked ACL files should be rejected for security") + + // Test that LoadFromFile also rejects symlinks + _, err = LoadFromFile(symlinkDir) + assert.Error(t, err, "LoadFromFile should reject symlinked ACL files") + assert.Contains(t, err.Error(), "symlinks are not allowed", "Error should mention symlink restriction") + + // Create broken symlink + brokenDir := filepath.Join(tempDir, "broken") + err = os.Mkdir(brokenDir, 0755) + require.NoError(t, err) + + brokenSymlink := filepath.Join(brokenDir, AclFileName) + err = os.Symlink("/nonexistent/file", brokenSymlink) + require.NoError(t, err) + + // Test that broken symlink is also rejected + assert.False(t, Exists(brokenDir), "Broken symlinks should be rejected") + + // Test that LoadFromFile also rejects broken symlinks + _, err = LoadFromFile(brokenDir) + assert.Error(t, err, "LoadFromFile should reject broken symlinks") + assert.Contains(t, err.Error(), "symlinks are not allowed", "Error should mention symlink restriction") +} + +func TestSymlinkSecurityComprehensive(t *testing.T) { + // Comprehensive test to ensure symlinks are rejected at all entry points + // This is critical for security - no ACL operation should follow symlinks + + tempDir := t.TempDir() + + // Create a legitimate ACL file + realAclFile := filepath.Join(tempDir, "real.yaml") + aclContent := ` +terminal: true +rules: + - pattern: "**" + access: + read: ["admin@example.com"] +` + err := os.WriteFile(realAclFile, []byte(aclContent), 0644) + require.NoError(t, err) + + // Create a directory with a symlinked ACL file + symlinkDir := filepath.Join(tempDir, "symlinked") + err = os.Mkdir(symlinkDir, 0755) + require.NoError(t, err) + + symlinkAclFile := filepath.Join(symlinkDir, AclFileName) + err = os.Symlink(realAclFile, symlinkAclFile) + require.NoError(t, err) + + // Test all ACL operations reject symlinks + + // 1. Exists should return false for symlinked ACL + assert.False(t, Exists(symlinkDir), "Exists() should reject symlinked ACL files") + + // 2. LoadFromFile should error for symlinked ACL + _, err = LoadFromFile(symlinkDir) + assert.Error(t, err, "LoadFromFile() should reject symlinked ACL files") + assert.Contains(t, err.Error(), "symlinks are not allowed", "Error should mention symlink restriction") + + // 3. AsAclPath and IsAclFile should work normally (they don't check file existence) + aclPath := AsAclPath(symlinkDir) + assert.Equal(t, symlinkAclFile, aclPath, "AsAclPath should work normally") + assert.True(t, IsAclFile(aclPath), "IsAclFile should work normally") + + // 4. Test that WithoutAclPath works normally + dirPath := WithoutAclPath(aclPath) + // Note: WithoutAclPath may add trailing separator, so we use filepath.Clean for comparison + assert.Equal(t, symlinkDir, filepath.Clean(dirPath), "WithoutAclPath should work normally") + + // Verify that regular files still work + regularDir := filepath.Join(tempDir, "regular") + err = os.Mkdir(regularDir, 0755) + require.NoError(t, err) + + regularAclFile := filepath.Join(regularDir, AclFileName) + err = os.WriteFile(regularAclFile, []byte(aclContent), 0644) + require.NoError(t, err) + + // Regular ACL files should work normally + assert.True(t, Exists(regularDir), "Regular ACL files should work") + + _, err = LoadFromFile(regularDir) + assert.NoError(t, err, "LoadFromFile should work with regular ACL files") +} + +func TestConstants(t *testing.T) { + // Test that the constants have expected values + // This ensures the constants are correctly defined and haven't been accidentally changed + assert.Equal(t, "syft.pub.yaml", AclFileName, "ACL filename constant should be correct") + assert.Equal(t, "*", Everyone, "Everyone constant should be correct") + assert.Equal(t, "**", AllFiles, "AllFiles pattern should be correct") + assert.Equal(t, true, SetTerminal, "SetTerminal constant should be true") + assert.Equal(t, false, UnsetTerminal, "UnsetTerminal constant should be false") +} + +func TestPathEdgeCases(t *testing.T) { + // Test edge cases in path manipulation functions + // This ensures robust handling of unusual but valid path scenarios + + // Test with paths containing special characters + specialPath := "/home/user with spaces/syft.pub.yaml" + assert.True(t, IsAclFile(specialPath), "Paths with spaces should be handled correctly") + + // Test with Unicode characters + unicodePath := "/home/用户/syft.pub.yaml" + assert.True(t, IsAclFile(unicodePath), "Unicode paths should be handled correctly") + + // Test with multiple slashes + multiSlashPath := "/home//user///syft.pub.yaml" + expected := "/home//user///syft.pub.yaml" + assert.Equal(t, expected, AsAclPath(multiSlashPath), "Multiple slashes should be preserved") + + // Test with backslashes (Windows-style paths) + windowsPath := "C:\\Users\\user\\syft.pub.yaml" + assert.True(t, IsAclFile(windowsPath), "Windows-style paths should be detected") +} \ No newline at end of file diff --git a/internal/aclspec/limits_test.go b/internal/aclspec/limits_test.go new file mode 100644 index 00000000..e119d51b --- /dev/null +++ b/internal/aclspec/limits_test.go @@ -0,0 +1,123 @@ +package aclspec + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestDefaultLimits(t *testing.T) { + // Test that DefaultLimits creates a Limits object with expected default values + // This is critical because these defaults affect security and functionality + limits := DefaultLimits() + + // Verify the limits object is not nil + assert.NotNil(t, limits, "DefaultLimits should return a non-nil Limits object") + + // Test file count limit - default should be 0 (unlimited) + assert.Equal(t, uint32(0), limits.MaxFiles, + "Default MaxFiles should be 0 (unlimited) to not restrict file count by default") + + // Test file size limit - default should be 0 (unlimited) + assert.Equal(t, int64(0), limits.MaxFileSize, + "Default MaxFileSize should be 0 (unlimited) to not restrict file size by default") + + // Test directory permission - default should allow directories + assert.True(t, limits.AllowDirs, + "Default should allow directories since most use cases need directory creation") + + // Test symlink permission - default should NOT allow symlinks for security + assert.False(t, limits.AllowSymlinks, + "Default should not allow symlinks for security reasons (prevent symlink attacks)") +} + +func TestLimitsStruct(t *testing.T) { + // Test creating Limits with custom values + // This validates the struct can hold different configurations correctly + customLimits := &Limits{ + MaxFiles: 100, + MaxFileSize: 1024 * 1024, // 1MB + AllowDirs: false, + AllowSymlinks: true, + } + + assert.Equal(t, uint32(100), customLimits.MaxFiles, "Custom MaxFiles should be preserved") + assert.Equal(t, int64(1024*1024), customLimits.MaxFileSize, "Custom MaxFileSize should be preserved") + assert.False(t, customLimits.AllowDirs, "Custom AllowDirs should be preserved") + assert.True(t, customLimits.AllowSymlinks, "Custom AllowSymlinks should be preserved") +} + +func TestLimitsZeroValues(t *testing.T) { + // Test that zero values in Limits struct behave as expected + // This is important for understanding the semantic meaning of zero values + var limits Limits + + // Zero values should represent the most restrictive settings + assert.Equal(t, uint32(0), limits.MaxFiles, "Zero value MaxFiles should be 0") + assert.Equal(t, int64(0), limits.MaxFileSize, "Zero value MaxFileSize should be 0") + assert.False(t, limits.AllowDirs, "Zero value AllowDirs should be false (more restrictive)") + assert.False(t, limits.AllowSymlinks, "Zero value AllowSymlinks should be false (more secure)") +} + +func TestLimitsIndependence(t *testing.T) { + // Test that multiple calls to DefaultLimits return independent objects + // This ensures modifications to one instance don't affect others + limits1 := DefaultLimits() + limits2 := DefaultLimits() + + // Verify they start with the same values + assert.Equal(t, limits1.MaxFiles, limits2.MaxFiles) + assert.Equal(t, limits1.MaxFileSize, limits2.MaxFileSize) + assert.Equal(t, limits1.AllowDirs, limits2.AllowDirs) + assert.Equal(t, limits1.AllowSymlinks, limits2.AllowSymlinks) + + // Modify one instance + limits1.MaxFiles = 50 + limits1.AllowDirs = false + + // Verify the other instance is unchanged + assert.Equal(t, uint32(0), limits2.MaxFiles, "Modifying one instance should not affect another") + assert.True(t, limits2.AllowDirs, "Modifying one instance should not affect another") +} + +func TestLimitsExtremeValues(t *testing.T) { + // Test Limits with extreme values to ensure robust handling + // This validates the struct can handle boundary conditions + extremeLimits := &Limits{ + MaxFiles: ^uint32(0), // Maximum uint32 value + MaxFileSize: 9223372036854775807, // Maximum int64 value + AllowDirs: true, + AllowSymlinks: true, + } + + // These extreme values should be preserved without overflow or corruption + assert.Equal(t, ^uint32(0), extremeLimits.MaxFiles, "Maximum uint32 value should be preserved") + assert.Equal(t, int64(9223372036854775807), extremeLimits.MaxFileSize, "Maximum int64 value should be preserved") + assert.True(t, extremeLimits.AllowDirs, "Boolean values should be preserved") + assert.True(t, extremeLimits.AllowSymlinks, "Boolean values should be preserved") +} + +func TestLimitsSemantics(t *testing.T) { + // Test the semantic meaning of limit values + // This documents and validates the intended behavior of different settings + + // Test unlimited semantics (0 values) + unlimited := &Limits{ + MaxFiles: 0, + MaxFileSize: 0, + } + + // Zero should mean "no limit" for numeric fields + // This is a common convention in Unix systems + assert.Equal(t, uint32(0), unlimited.MaxFiles, "Zero MaxFiles should mean unlimited") + assert.Equal(t, int64(0), unlimited.MaxFileSize, "Zero MaxFileSize should mean unlimited") + + // Test specific limits + limited := &Limits{ + MaxFiles: 10, + MaxFileSize: 1000, + } + + assert.True(t, limited.MaxFiles > 0, "Non-zero MaxFiles should impose a limit") + assert.True(t, limited.MaxFileSize > 0, "Non-zero MaxFileSize should impose a limit") +} \ No newline at end of file diff --git a/internal/aclspec/migrate_test.go b/internal/aclspec/migrate_test.go new file mode 100644 index 00000000..7c04d76b --- /dev/null +++ b/internal/aclspec/migrate_test.go @@ -0,0 +1,247 @@ +package aclspec + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gopkg.in/yaml.v3" +) + +func TestLegacyPermissionUnmarshalYAML(t *testing.T) { + // Test unmarshaling legacy permission format from YAML + // This validates backward compatibility with older permission file formats + yamlContent := ` +- path: "documents/*" + user: "alice@research.org" + permissions: ["read", "write"] +- path: "public/*.txt" + user: "bob@university.edu" + permissions: ["read"] +- path: "admin/*" + user: "admin@company.com" + permissions: ["read", "write", "admin"] +` + + var legacyPerm LegacyPermission + err := yaml.Unmarshal([]byte(yamlContent), &legacyPerm) + require.NoError(t, err, "Should successfully unmarshal legacy permission format") + + // Verify the correct number of rules were parsed + assert.Len(t, legacyPerm.Rules, 3, "Should parse all three legacy rules") + + // Verify first rule (alice with read/write on documents) + rule1 := legacyPerm.Rules[0] + assert.Equal(t, "documents/*", rule1.Path, "First rule path should be correct") + assert.Equal(t, "alice@research.org", rule1.User, "First rule user should be correct") + assert.Len(t, rule1.Permissions, 2, "First rule should have 2 permissions") + assert.Contains(t, rule1.Permissions, Read, "First rule should include read permission") + assert.Contains(t, rule1.Permissions, Write, "First rule should include write permission") + + // Verify second rule (bob with read on public txt files) + rule2 := legacyPerm.Rules[1] + assert.Equal(t, "public/*.txt", rule2.Path, "Second rule path should be correct") + assert.Equal(t, "bob@university.edu", rule2.User, "Second rule user should be correct") + assert.Len(t, rule2.Permissions, 1, "Second rule should have 1 permission") + assert.Contains(t, rule2.Permissions, Read, "Second rule should include read permission") + + // Verify third rule (admin with full permissions) + rule3 := legacyPerm.Rules[2] + assert.Equal(t, "admin/*", rule3.Path, "Third rule path should be correct") + assert.Equal(t, "admin@company.com", rule3.User, "Third rule user should be correct") + assert.Len(t, rule3.Permissions, 3, "Third rule should have 3 permissions") + assert.Contains(t, rule3.Permissions, Read, "Third rule should include read permission") + assert.Contains(t, rule3.Permissions, Write, "Third rule should include write permission") + assert.Contains(t, rule3.Permissions, Execute, "Third rule should include admin permission") +} + +func TestLegacyPermissionUnmarshalEmptyYAML(t *testing.T) { + // Test unmarshaling empty legacy permission list + // This ensures the system handles empty legacy files gracefully + yamlContent := `[]` + + var legacyPerm LegacyPermission + err := yaml.Unmarshal([]byte(yamlContent), &legacyPerm) + require.NoError(t, err, "Should successfully unmarshal empty legacy permission list") + + assert.Empty(t, legacyPerm.Rules, "Empty YAML should result in empty rules list") +} + +func TestLegacyPermissionUnmarshalInvalidYAML(t *testing.T) { + // Test that invalid YAML structure is properly rejected + // This ensures the system fails safely on malformed legacy files + yamlContent := `"this_is_not_a_sequence"` + + var legacyPerm LegacyPermission + err := yaml.Unmarshal([]byte(yamlContent), &legacyPerm) + assert.Error(t, err, "Should reject non-sequence YAML for legacy permissions") + assert.Contains(t, err.Error(), "expected a sequence", "Error should indicate sequence was expected") +} + +func TestLegacyPermissionUnmarshalPartialRule(t *testing.T) { + // Test unmarshaling legacy rules with missing fields + // This validates handling of incomplete legacy data + yamlContent := ` +- path: "test/*" + user: "testuser@example.com" + # permissions field intentionally missing +- path: "incomplete" + # user and permissions fields missing +` + + var legacyPerm LegacyPermission + err := yaml.Unmarshal([]byte(yamlContent), &legacyPerm) + require.NoError(t, err, "Should handle legacy rules with missing fields") + + assert.Len(t, legacyPerm.Rules, 2, "Should parse both rules despite missing fields") + + // First rule should have path and user, but empty permissions + rule1 := legacyPerm.Rules[0] + assert.Equal(t, "test/*", rule1.Path) + assert.Equal(t, "testuser@example.com", rule1.User) + assert.Empty(t, rule1.Permissions, "Missing permissions should result in empty slice") + + // Second rule should have only path + rule2 := legacyPerm.Rules[1] + assert.Equal(t, "incomplete", rule2.Path) + assert.Empty(t, rule2.User, "Missing user should result in empty string") + assert.Empty(t, rule2.Permissions, "Missing permissions should result in empty slice") +} + +func TestPermissionTypeConstants(t *testing.T) { + // Test that permission type constants have expected values + // This ensures the constants match the legacy format requirements + assert.Equal(t, PermissionType("read"), Read, "Read permission constant should be correct") + assert.Equal(t, PermissionType("create"), Create, "Create permission constant should be correct") + assert.Equal(t, PermissionType("write"), Write, "Write permission constant should be correct") + assert.Equal(t, PermissionType("admin"), Execute, "Execute/Admin permission constant should be correct") +} + +func TestLegacyRulePermissionTypes(t *testing.T) { + // Test that all permission types can be properly parsed + // This validates the enum handling for different permission levels + yamlContent := ` +- path: "test/*" + user: "testuser@example.com" + permissions: ["read", "create", "write", "admin"] +` + + var legacyPerm LegacyPermission + err := yaml.Unmarshal([]byte(yamlContent), &legacyPerm) + require.NoError(t, err, "Should parse all permission types") + + rule := legacyPerm.Rules[0] + assert.Len(t, rule.Permissions, 4, "Should have all 4 permission types") + assert.Contains(t, rule.Permissions, Read, "Should contain read permission") + assert.Contains(t, rule.Permissions, Create, "Should contain create permission") + assert.Contains(t, rule.Permissions, Write, "Should contain write permission") + assert.Contains(t, rule.Permissions, Execute, "Should contain admin/execute permission") +} + +func TestLegacyRuleWithDuplicatePermissions(t *testing.T) { + // Test handling of duplicate permissions in legacy format + // This ensures duplicate permissions are handled gracefully + yamlContent := ` +- path: "test/*" + user: "testuser@example.com" + permissions: ["read", "read", "write", "read"] +` + + var legacyPerm LegacyPermission + err := yaml.Unmarshal([]byte(yamlContent), &legacyPerm) + require.NoError(t, err, "Should handle duplicate permissions") + + rule := legacyPerm.Rules[0] + // YAML unmarshaling preserves duplicates in slice + assert.Len(t, rule.Permissions, 4, "Should preserve all permission entries including duplicates") + + // Count actual unique permissions + uniquePerms := make(map[PermissionType]bool) + for _, perm := range rule.Permissions { + uniquePerms[perm] = true + } + assert.Len(t, uniquePerms, 2, "Should have 2 unique permission types (read and write)") +} + +func TestLegacyRuleWithInvalidPermissions(t *testing.T) { + // Test handling of invalid permission types in legacy format + // This documents behavior with unknown permission strings + yamlContent := ` +- path: "test/*" + user: "testuser@example.com" + permissions: ["read", "invalid_permission", "write"] +` + + var legacyPerm LegacyPermission + err := yaml.Unmarshal([]byte(yamlContent), &legacyPerm) + require.NoError(t, err, "Should parse even with invalid permissions") + + rule := legacyPerm.Rules[0] + assert.Len(t, rule.Permissions, 3, "Should include all permission strings including invalid ones") + assert.Contains(t, rule.Permissions, Read, "Should contain valid read permission") + assert.Contains(t, rule.Permissions, Write, "Should contain valid write permission") + assert.Contains(t, rule.Permissions, PermissionType("invalid_permission"), "Should preserve invalid permission as-is") +} + +func TestLegacyPermissionComplexPaths(t *testing.T) { + // Test legacy permission parsing with complex file paths + // This ensures path handling works correctly for various path formats + yamlContent := ` +- path: "**/*.go" + user: "developer@company.com" + permissions: ["read", "write"] +- path: "/absolute/path/*" + user: "admin@openmined.org" + permissions: ["admin"] +- path: "relative/../path" + user: "user@university.edu" + permissions: ["read"] +- path: "" + user: "empty_path_user@example.com" + permissions: ["read"] +` + + var legacyPerm LegacyPermission + err := yaml.Unmarshal([]byte(yamlContent), &legacyPerm) + require.NoError(t, err, "Should handle complex paths") + + assert.Len(t, legacyPerm.Rules, 4, "Should parse all rules with complex paths") + + // Verify each path is preserved exactly as specified + assert.Equal(t, "**/*.go", legacyPerm.Rules[0].Path, "Glob pattern should be preserved") + assert.Equal(t, "/absolute/path/*", legacyPerm.Rules[1].Path, "Absolute path should be preserved") + assert.Equal(t, "relative/../path", legacyPerm.Rules[2].Path, "Relative path with .. should be preserved") + assert.Equal(t, "", legacyPerm.Rules[3].Path, "Empty path should be preserved") +} + +func TestLegacyPermissionSpecialCharacters(t *testing.T) { + // Test legacy permission parsing with special characters in user names and paths + // This validates handling of edge cases in legacy data + yamlContent := ` +- path: "files with spaces/*" + user: "user.with.dots" + permissions: ["read"] +- path: "unicode/测试/*" + user: "用户名" + permissions: ["write"] +- path: "symbols!@#$%/*" + user: "user_with_underscores" + permissions: ["admin"] +` + + var legacyPerm LegacyPermission + err := yaml.Unmarshal([]byte(yamlContent), &legacyPerm) + require.NoError(t, err, "Should handle special characters in paths and users") + + assert.Len(t, legacyPerm.Rules, 3, "Should parse all rules with special characters") + + // Verify special characters are preserved + assert.Equal(t, "files with spaces/*", legacyPerm.Rules[0].Path, "Spaces in paths should be preserved") + assert.Equal(t, "user.with.dots", legacyPerm.Rules[0].User, "Dots in usernames should be preserved") + + assert.Equal(t, "unicode/测试/*", legacyPerm.Rules[1].Path, "Unicode characters should be preserved") + assert.Equal(t, "用户名", legacyPerm.Rules[1].User, "Unicode usernames should be preserved") + + assert.Equal(t, "symbols!@#$%/*", legacyPerm.Rules[2].Path, "Symbol characters should be preserved") + assert.Equal(t, "user_with_underscores", legacyPerm.Rules[2].User, "Underscores should be preserved") +} \ No newline at end of file diff --git a/internal/aclspec/rule_test.go b/internal/aclspec/rule_test.go new file mode 100644 index 00000000..ff6be3c1 --- /dev/null +++ b/internal/aclspec/rule_test.go @@ -0,0 +1,247 @@ +package aclspec + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestNewRule(t *testing.T) { + // Test creating a new Rule with all components + // This validates the core Rule constructor works correctly + pattern := "*.txt" + access := PublicReadAccess() + limits := DefaultLimits() + + rule := NewRule(pattern, access, limits) + + // Verify all components are properly assigned + assert.NotNil(t, rule, "NewRule should return a non-nil Rule") + assert.Equal(t, pattern, rule.Pattern, "Pattern should be preserved") + assert.Equal(t, access, rule.Access, "Access should be preserved") + assert.Equal(t, limits, rule.Limits, "Limits should be preserved") +} + +func TestNewRuleWithDifferentPatterns(t *testing.T) { + // Test creating rules with various pattern types + // This ensures the constructor handles different glob patterns correctly + testCases := []struct { + pattern string + desc string + }{ + { + pattern: "**", + desc: "Universal pattern should be preserved", + }, + { + pattern: "*.go", + desc: "Extension-based pattern should be preserved", + }, + { + pattern: "specific.txt", + desc: "Specific filename pattern should be preserved", + }, + { + pattern: "folder/**/*.py", + desc: "Complex nested pattern should be preserved", + }, + { + pattern: "", + desc: "Empty pattern should be preserved (even if invalid)", + }, + } + + access := PrivateAccess() + limits := DefaultLimits() + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + rule := NewRule(tc.pattern, access, limits) + assert.Equal(t, tc.pattern, rule.Pattern, tc.desc) + }) + } +} + +func TestNewRuleWithDifferentAccess(t *testing.T) { + // Test creating rules with different access configurations + // This validates rules can be created with various permission levels + pattern := "test.txt" + limits := DefaultLimits() + + testCases := []struct { + access *Access + desc string + }{ + { + access: PrivateAccess(), + desc: "Private access should be preserved", + }, + { + access: PublicReadAccess(), + desc: "Public read access should be preserved", + }, + { + access: PublicReadWriteAccess(), + desc: "Public read-write access should be preserved", + }, + { + access: SharedReadAccess("user1", "user2"), + desc: "Shared read access should be preserved", + }, + { + access: SharedReadWriteAccess("maintainer"), + desc: "Shared read-write access should be preserved", + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + rule := NewRule(pattern, tc.access, limits) + assert.Equal(t, tc.access, rule.Access, tc.desc) + }) + } +} + +func TestNewRuleWithDifferentLimits(t *testing.T) { + // Test creating rules with different limit configurations + // This validates rules can enforce various restrictions + pattern := "test.txt" + access := PrivateAccess() + + testCases := []struct { + limits *Limits + desc string + }{ + { + limits: DefaultLimits(), + desc: "Default limits should be preserved", + }, + { + limits: &Limits{ + MaxFileSize: 1024, + MaxFiles: 10, + AllowDirs: false, + AllowSymlinks: false, + }, + desc: "Custom restrictive limits should be preserved", + }, + { + limits: &Limits{ + MaxFileSize: 0, // Unlimited + MaxFiles: 0, // Unlimited + AllowDirs: true, + AllowSymlinks: true, + }, + desc: "Custom permissive limits should be preserved", + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + rule := NewRule(pattern, access, tc.limits) + assert.Equal(t, tc.limits, rule.Limits, tc.desc) + }) + } +} + +func TestNewRuleWithNilInputs(t *testing.T) { + // Test creating rules with nil inputs + // This documents behavior with invalid inputs (system should handle gracefully) + pattern := "test.txt" + + // Test with nil access (this might be invalid but shouldn't crash) + rule := NewRule(pattern, nil, DefaultLimits()) + assert.Equal(t, pattern, rule.Pattern, "Pattern should be preserved even with nil access") + assert.Nil(t, rule.Access, "Nil access should be preserved") + assert.NotNil(t, rule.Limits, "Limits should be preserved") + + // Test with nil limits (this might be invalid but shouldn't crash) + rule = NewRule(pattern, PrivateAccess(), nil) + assert.Equal(t, pattern, rule.Pattern, "Pattern should be preserved even with nil limits") + assert.NotNil(t, rule.Access, "Access should be preserved") + assert.Nil(t, rule.Limits, "Nil limits should be preserved") +} + +func TestNewDefaultRule(t *testing.T) { + // Test creating a default rule (catch-all pattern) + // This validates the convenience constructor for the most common default rule + access := PrivateAccess() + limits := DefaultLimits() + + rule := NewDefaultRule(access, limits) + + // Verify the rule uses the universal pattern + assert.NotNil(t, rule, "NewDefaultRule should return a non-nil Rule") + assert.Equal(t, AllFiles, rule.Pattern, "Default rule should use the AllFiles pattern (**)") + assert.Equal(t, access, rule.Access, "Access should be preserved") + assert.Equal(t, limits, rule.Limits, "Limits should be preserved") +} + +func TestNewDefaultRuleWithDifferentAccess(t *testing.T) { + // Test default rule creation with various access levels + // This ensures default rules can be created with any permission configuration + limits := DefaultLimits() + + testCases := []struct { + access *Access + desc string + }{ + { + access: PrivateAccess(), + desc: "Default rule with private access", + }, + { + access: PublicReadAccess(), + desc: "Default rule with public read access", + }, + { + access: SharedReadWriteAccess("admin"), + desc: "Default rule with shared access", + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + rule := NewDefaultRule(tc.access, limits) + assert.Equal(t, AllFiles, rule.Pattern, "All default rules should use AllFiles pattern") + assert.Equal(t, tc.access, rule.Access, tc.desc) + }) + } +} + +func TestRulePatternConstants(t *testing.T) { + // Test that rules correctly use the expected pattern constants + // This ensures consistency in pattern usage across the system + + // Default rule should use AllFiles constant + defaultRule := NewDefaultRule(PrivateAccess(), DefaultLimits()) + assert.Equal(t, "**", defaultRule.Pattern, "Default rule should use the AllFiles constant value") + assert.Equal(t, AllFiles, defaultRule.Pattern, "Default rule should use AllFiles constant") + + // Custom rule should preserve custom pattern + customRule := NewRule("custom/*.txt", PrivateAccess(), DefaultLimits()) + assert.NotEqual(t, AllFiles, customRule.Pattern, "Custom rule should not use AllFiles pattern") + assert.Equal(t, "custom/*.txt", customRule.Pattern, "Custom rule should preserve its pattern") +} + +func TestRuleIndependence(t *testing.T) { + // Test that multiple rules are independent objects + // This ensures modifying one rule doesn't affect others + access1 := PrivateAccess() + access2 := PublicReadAccess() + limits1 := DefaultLimits() + limits2 := &Limits{MaxFileSize: 1000} + + rule1 := NewRule("*.txt", access1, limits1) + rule2 := NewRule("*.md", access2, limits2) + + // Verify rules are independent + assert.NotEqual(t, rule1.Pattern, rule2.Pattern, "Rules should have different patterns") + assert.NotEqual(t, rule1.Access, rule2.Access, "Rules should have different access") + assert.NotEqual(t, rule1.Limits, rule2.Limits, "Rules should have different limits") + + // Verify shared objects are actually the same references when intended + rule3 := NewRule("*.go", access1, limits1) + assert.Equal(t, rule1.Access, rule3.Access, "Rules sharing the same access object should reference the same object") + assert.Equal(t, rule1.Limits, rule3.Limits, "Rules sharing the same limits object should reference the same object") +} \ No newline at end of file diff --git a/internal/aclspec/ruleset.go b/internal/aclspec/ruleset.go index 68ebcbc9..57ad6e41 100644 --- a/internal/aclspec/ruleset.go +++ b/internal/aclspec/ruleset.go @@ -30,8 +30,21 @@ func (r *RuleSet) AllRules() []*Rule { } // LoadFromFile loads a RuleSet from the specified file path +// For security reasons, symlinks are not allowed as ACL files func LoadFromFile(path string) (*RuleSet, error) { aclPath := AsAclPath(path) + + // Check if file is a symlink before opening + stat, err := os.Lstat(aclPath) + if err != nil { + return nil, err + } + + // Reject symlinks for security reasons + if stat.Mode()&os.ModeSymlink != 0 { + return nil, fmt.Errorf("symlinks are not allowed as ACL files: %s", aclPath) + } + fd, err := os.Open(aclPath) if err != nil { return nil, err diff --git a/internal/aclspec/ruleset_test.go b/internal/aclspec/ruleset_test.go new file mode 100644 index 00000000..aee03c62 --- /dev/null +++ b/internal/aclspec/ruleset_test.go @@ -0,0 +1,368 @@ +package aclspec + +import ( + "io" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gopkg.in/yaml.v3" +) + +func TestNewRuleSet(t *testing.T) { + // Test creating a new RuleSet with basic configuration + // This validates the core RuleSet constructor + path := "test/path" + terminal := true + rule1 := NewRule("*.txt", PublicReadAccess(), DefaultLimits()) + rule2 := NewRule("*.md", PrivateAccess(), DefaultLimits()) + + ruleset := NewRuleSet(path, terminal, rule1, rule2) + + // Verify all components are properly assigned + assert.NotNil(t, ruleset, "NewRuleSet should return a non-nil RuleSet") + assert.Equal(t, path, ruleset.Path, "Path should be preserved") + assert.Equal(t, terminal, ruleset.Terminal, "Terminal flag should be preserved") + assert.Len(t, ruleset.Rules, 2, "Should have exactly 2 rules") + assert.Contains(t, ruleset.Rules, rule1, "Should contain first rule") + assert.Contains(t, ruleset.Rules, rule2, "Should contain second rule") +} + +func TestNewRuleSetWithAclPath(t *testing.T) { + // Test that NewRuleSet correctly handles ACL file paths + // This ensures the constructor normalizes paths by removing ACL filename + aclPath := "test/path/syft.pub.yaml" + expectedPath := "test/path/" + + ruleset := NewRuleSet(aclPath, false) + + assert.Equal(t, expectedPath, ruleset.Path, "ACL filename should be stripped from path") +} + +func TestNewRuleSetWithNoRules(t *testing.T) { + // Test creating a RuleSet with no initial rules + // This validates the constructor works with empty rule sets + ruleset := NewRuleSet("test/path", false) + + assert.Equal(t, "test/path", ruleset.Path) + assert.False(t, ruleset.Terminal) + assert.Empty(t, ruleset.Rules, "RuleSet with no rules should have empty Rules slice") +} + +func TestAllRules(t *testing.T) { + // Test the AllRules method returns the correct rules + // This is a simple getter but important for the interface + rule1 := NewRule("*.txt", PublicReadAccess(), DefaultLimits()) + rule2 := NewRule("*.md", PrivateAccess(), DefaultLimits()) + ruleset := NewRuleSet("test", false, rule1, rule2) + + rules := ruleset.AllRules() + + assert.Len(t, rules, 2, "AllRules should return all rules") + assert.Contains(t, rules, rule1, "Should contain first rule") + assert.Contains(t, rules, rule2, "Should contain second rule") +} + +func TestLoadFromReader(t *testing.T) { + // Test loading RuleSet from YAML reader + // This validates the core YAML parsing functionality + yamlContent := ` +terminal: true +rules: + - pattern: "*.txt" + access: + read: ["*"] + - pattern: "private/*" + access: + read: ["admin"] + write: ["admin"] +` + + reader := io.NopCloser(strings.NewReader(yamlContent)) + ruleset, err := LoadFromReader("test/path", reader) + + require.NoError(t, err, "LoadFromReader should succeed with valid YAML") + assert.NotNil(t, ruleset, "Should return a non-nil RuleSet") + + // Verify basic properties + assert.Equal(t, "test/path", ruleset.Path, "Path should be set correctly") + assert.True(t, ruleset.Terminal, "Terminal flag should be parsed correctly") + + // Should have the 2 explicit rules plus 1 default rule added by setDefaults + assert.Len(t, ruleset.Rules, 3, "Should have 2 explicit rules + 1 default rule") + + // Verify the first rule + rule1 := ruleset.Rules[0] + assert.Equal(t, "*.txt", rule1.Pattern, "First rule pattern should be correct") + assert.True(t, rule1.Access.Read.Contains("*"), "First rule should grant read access to everyone") + + // Verify the second rule + rule2 := ruleset.Rules[1] + assert.Equal(t, "private/*", rule2.Pattern, "Second rule pattern should be correct") + assert.True(t, rule2.Access.Read.Contains("admin"), "Second rule should grant read access to admin") + assert.True(t, rule2.Access.Write.Contains("admin"), "Second rule should grant write access to admin") +} + +func TestLoadFromReaderWithMinimalYAML(t *testing.T) { + // Test loading with minimal YAML (only terminal flag) + // This validates default rule injection works correctly + yamlContent := `terminal: false` + + reader := io.NopCloser(strings.NewReader(yamlContent)) + ruleset, err := LoadFromReader("test", reader) + + require.NoError(t, err, "LoadFromReader should succeed with minimal YAML") + assert.False(t, ruleset.Terminal, "Terminal flag should be parsed correctly") + + // Should have exactly 1 default rule added by setDefaults + assert.Len(t, ruleset.Rules, 1, "Should have 1 default rule") + assert.Equal(t, "**", ruleset.Rules[0].Pattern, "Default rule should have AllFiles pattern") +} + +func TestLoadFromReaderWithEmptyYAML(t *testing.T) { + // Test loading with completely empty YAML + // This validates the system handles empty files gracefully + yamlContent := `` + + reader := io.NopCloser(strings.NewReader(yamlContent)) + ruleset, err := LoadFromReader("test", reader) + + require.NoError(t, err, "LoadFromReader should succeed with empty YAML") + assert.False(t, ruleset.Terminal, "Terminal should default to false") + + // Should have exactly 1 default rule added by setDefaults + assert.Len(t, ruleset.Rules, 1, "Should have 1 default rule") + assert.Equal(t, "**", ruleset.Rules[0].Pattern, "Default rule should have AllFiles pattern") +} + +func TestLoadFromReaderWithInvalidYAML(t *testing.T) { + // Test that invalid YAML is properly rejected + // This ensures the system fails safely on malformed input + yamlContent := ` +invalid: yaml: content: + - missing + proper: structure +` + + reader := io.NopCloser(strings.NewReader(yamlContent)) + ruleset, err := LoadFromReader("test", reader) + + assert.Error(t, err, "LoadFromReader should fail with invalid YAML") + assert.Nil(t, ruleset, "Should return nil RuleSet on error") +} + +func TestLoadFromFile(t *testing.T) { + // Test loading RuleSet from file on disk + // This validates file I/O integration + tempDir := t.TempDir() + aclFile := filepath.Join(tempDir, AclFileName) + + yamlContent := ` +terminal: true +rules: + - pattern: "*.go" + access: + read: ["developers@company.com"] +` + + err := os.WriteFile(aclFile, []byte(yamlContent), 0644) + require.NoError(t, err, "Should be able to write test file") + + ruleset, err := LoadFromFile(tempDir) + require.NoError(t, err, "LoadFromFile should succeed") + + assert.Equal(t, tempDir, ruleset.Path, "Path should be directory (without ACL filename)") + assert.True(t, ruleset.Terminal, "Terminal flag should be loaded correctly") + assert.Len(t, ruleset.Rules, 2, "Should have 1 explicit rule + 1 default rule") +} + +func TestLoadFromFileNonExistent(t *testing.T) { + // Test loading from non-existent file + // This validates error handling for missing files + nonExistentPath := "/path/that/does/not/exist" + + ruleset, err := LoadFromFile(nonExistentPath) + assert.Error(t, err, "LoadFromFile should fail for non-existent file") + assert.Nil(t, ruleset, "Should return nil RuleSet on error") +} + +func TestRuleSetSave(t *testing.T) { + // Test saving RuleSet to file + // This validates YAML serialization and file I/O + tempDir := t.TempDir() + + // Create a RuleSet with test data + rule1 := NewRule("*.txt", PublicReadAccess(), DefaultLimits()) + rule2 := NewRule("secret/*", PrivateAccess(), &Limits{MaxFileSize: 1024}) + ruleset := NewRuleSet(tempDir, true, rule1, rule2) + + err := ruleset.Save() + require.NoError(t, err, "Save should succeed") + + // Verify the file was created + aclFile := filepath.Join(tempDir, AclFileName) + assert.FileExists(t, aclFile, "ACL file should be created") + + // Load the file back and verify content + content, err := os.ReadFile(aclFile) + require.NoError(t, err, "Should be able to read saved file") + + var loaded RuleSet + err = yaml.Unmarshal(content, &loaded) + require.NoError(t, err, "Saved YAML should be valid") + + assert.True(t, loaded.Terminal, "Terminal flag should be preserved") + assert.Len(t, loaded.Rules, 2, "All rules should be saved") +} + +func TestRuleSetSaveInvalidPath(t *testing.T) { + // Test saving to invalid path + // This validates error handling for file I/O failures + ruleset := NewRuleSet("/invalid/path/that/cannot/be/created", false) + + err := ruleset.Save() + assert.Error(t, err, "Save should fail for invalid path") +} + +func TestSetDefaults(t *testing.T) { + // Test the setDefaults function directly + // This validates the default rule injection logic + + // Test with nil rules + ruleset := &RuleSet{Path: "test", Terminal: false, Rules: nil} + result, err := setDefaults(ruleset) + + require.NoError(t, err, "setDefaults should succeed with nil rules") + assert.Len(t, result.Rules, 1, "Should add exactly one default rule") + assert.Equal(t, "**", result.Rules[0].Pattern, "Default rule should have AllFiles pattern") +} + +func TestSetDefaultsWithExistingDefaultRule(t *testing.T) { + // Test setDefaults when a default rule already exists + // This ensures default rules aren't duplicated + existingDefault := NewRule("**", PrivateAccess(), DefaultLimits()) + customRule := NewRule("*.txt", PublicReadAccess(), DefaultLimits()) + + ruleset := &RuleSet{ + Path: "test", + Terminal: false, + Rules: []*Rule{customRule, existingDefault}, + } + + result, err := setDefaults(ruleset) + + require.NoError(t, err, "setDefaults should succeed with existing default rule") + assert.Len(t, result.Rules, 2, "Should not add additional default rule") + + // Verify the existing default rule is preserved + hasDefault := false + for _, rule := range result.Rules { + if rule.Pattern == "**" { + hasDefault = true + break + } + } + assert.True(t, hasDefault, "Should preserve existing default rule") +} + +func TestSetDefaultsValidation(t *testing.T) { + // Test setDefaults validation of rule requirements + // This ensures invalid rules are properly rejected + + // Test with empty pattern + invalidRule := NewRule("", PrivateAccess(), DefaultLimits()) + ruleset := &RuleSet{ + Path: "test", + Terminal: false, + Rules: []*Rule{invalidRule}, + } + + result, err := setDefaults(ruleset) + assert.Error(t, err, "setDefaults should reject rules with empty patterns") + assert.Nil(t, result, "Should return nil on validation error") + + // Test with nil access + invalidRule2 := NewRule("*.txt", nil, DefaultLimits()) + ruleset2 := &RuleSet{ + Path: "test", + Terminal: false, + Rules: []*Rule{invalidRule2}, + } + + result, err = setDefaults(ruleset2) + assert.Error(t, err, "setDefaults should reject rules with nil access") + assert.Nil(t, result, "Should return nil on validation error") +} + +func TestSetDefaultsLimitsInjection(t *testing.T) { + // Test that setDefaults adds default limits to rules missing them + // This ensures all rules have proper limits configuration + ruleWithoutLimits := NewRule("*.txt", PrivateAccess(), nil) + + ruleset := &RuleSet{ + Path: "test", + Terminal: false, + Rules: []*Rule{ruleWithoutLimits}, + } + + result, err := setDefaults(ruleset) + + require.NoError(t, err, "setDefaults should succeed and inject limits") + assert.NotNil(t, result.Rules[0].Limits, "Rule should have limits after setDefaults") + + // Verify the limits are the default limits + expectedLimits := DefaultLimits() + assert.Equal(t, expectedLimits.MaxFiles, result.Rules[0].Limits.MaxFiles) + assert.Equal(t, expectedLimits.MaxFileSize, result.Rules[0].Limits.MaxFileSize) + assert.Equal(t, expectedLimits.AllowDirs, result.Rules[0].Limits.AllowDirs) + assert.Equal(t, expectedLimits.AllowSymlinks, result.Rules[0].Limits.AllowSymlinks) +} + +func TestRuleSetRoundTrip(t *testing.T) { + // Test complete round-trip: create -> save -> load -> verify + // This validates the entire serialization/deserialization pipeline + tempDir := t.TempDir() + + // Create original RuleSet + original := NewRuleSet(tempDir, true, + NewRule("*.go", SharedReadAccess("dev1@company.com", "dev2@university.edu"), DefaultLimits()), + NewRule("docs/*", PublicReadAccess(), DefaultLimits()), + ) + + // Save to file + err := original.Save() + require.NoError(t, err, "Save should succeed") + + // Load from file + loaded, err := LoadFromFile(tempDir) + require.NoError(t, err, "Load should succeed") + + // Verify key properties are preserved + assert.Equal(t, original.Path, loaded.Path, "Path should be preserved") + assert.Equal(t, original.Terminal, loaded.Terminal, "Terminal flag should be preserved") + + // Note: loaded will have additional default rule, so we check the original rules exist + assert.True(t, len(loaded.Rules) >= len(original.Rules), "Loaded rules should include all original rules") + + // Verify specific rules exist (order might differ due to default rule injection) + hasGoRule := false + hasDocsRule := false + for _, rule := range loaded.Rules { + if rule.Pattern == "*.go" { + hasGoRule = true + assert.True(t, rule.Access.Read.Contains("dev1@company.com"), "Go rule should preserve dev1 access") + assert.True(t, rule.Access.Read.Contains("dev2@university.edu"), "Go rule should preserve dev2 access") + } + if rule.Pattern == "docs/*" { + hasDocsRule = true + assert.True(t, rule.Access.Read.Contains("*"), "Docs rule should preserve public access") + // Note: Limits are not serialized to YAML (yaml:"-" tag), so they get default values after round-trip + assert.NotNil(t, rule.Limits, "Docs rule should have default limits after round-trip") + } + } + assert.True(t, hasGoRule, "Should preserve *.go rule") + assert.True(t, hasDocsRule, "Should preserve docs/* rule") +} diff --git a/internal/server/acl/cache_test.go b/internal/server/acl/cache_test.go new file mode 100644 index 00000000..dde3541c --- /dev/null +++ b/internal/server/acl/cache_test.go @@ -0,0 +1,349 @@ +package acl + +import ( + "fmt" + "sync" + "testing" + + "github.com/openmined/syftbox/internal/aclspec" + "github.com/stretchr/testify/assert" +) + +func TestNewRuleCache(t *testing.T) { + // Test creating a new cache + // This validates the constructor initializes the cache correctly + cache := NewRuleCache() + + assert.NotNil(t, cache, "NewRuleCache should return non-nil cache") + assert.NotNil(t, cache.index, "Cache should have initialized index map") + assert.Empty(t, cache.index, "New cache should start empty") +} + +func TestRuleCacheBasicOperations(t *testing.T) { + // Test basic cache operations: Set, Get, Delete + // This validates the core cache functionality + cache := NewRuleCache() + + // Create a mock rule for testing + mockNode := NewNode("test", false, 1) + mockRule := &Rule{ + fullPattern: "test/*.txt", + rule: aclspec.NewRule("*.txt", aclspec.PrivateAccess(), aclspec.DefaultLimits()), + node: mockNode, + } + + // Test Get on empty cache + result := cache.Get("test/file.txt") + assert.Nil(t, result, "Get should return nil for non-existent entries") + + // Test Set and Get + cache.Set("test/file.txt", mockRule) + result = cache.Get("test/file.txt") + assert.Equal(t, mockRule, result, "Get should return the cached rule") + + // Test Delete + cache.Delete("test/file.txt") + result = cache.Get("test/file.txt") + assert.Nil(t, result, "Get should return nil after deletion") +} + +func TestRuleCacheVersionValidation(t *testing.T) { + // Test that cache validates rule versions to detect stale entries + // This is critical for cache invalidation when rules are updated + cache := NewRuleCache() + + // Create a node and rule + mockNode := NewNode("test", false, 1) + mockRule := &Rule{ + fullPattern: "test/*.txt", + rule: aclspec.NewRule("*.txt", aclspec.PrivateAccess(), aclspec.DefaultLimits()), + node: mockNode, + } + + // Cache the rule with current node version + cache.Set("test/file.txt", mockRule) + + // Verify we can retrieve it + result := cache.Get("test/file.txt") + assert.Equal(t, mockRule, result, "Should retrieve cached rule with valid version") + + // Simulate node version change (like when rules are updated) + mockNode.SetRules([]*aclspec.Rule{ + aclspec.NewRule("*.md", aclspec.PublicReadAccess(), aclspec.DefaultLimits()), + }, false) + + // Now the cached entry should be invalid due to version mismatch + result = cache.Get("test/file.txt") + assert.Nil(t, result, "Should return nil for stale cache entry with wrong version") + + // Verify the stale entry was automatically removed + assert.NotContains(t, cache.index, "test/file.txt", "Stale entry should be removed from cache") +} + +func TestRuleCacheDeletePrefix(t *testing.T) { + // Test the DeletePrefix operation which removes multiple entries + // This is important for bulk cache invalidation when directory rules change + cache := NewRuleCache() + + // Create multiple cache entries with related paths + mockNode := NewNode("test", false, 1) + mockRule := &Rule{ + fullPattern: "test/*.txt", + rule: aclspec.NewRule("*.txt", aclspec.PrivateAccess(), aclspec.DefaultLimits()), + node: mockNode, + } + + // Add entries with common prefix + cache.Set("project/src/file1.go", mockRule) + cache.Set("project/src/file2.go", mockRule) + cache.Set("project/docs/readme.md", mockRule) + cache.Set("project/tests/test1.go", mockRule) + cache.Set("other/file.txt", mockRule) // Different prefix + + // Verify all entries exist + assert.NotNil(t, cache.Get("project/src/file1.go")) + assert.NotNil(t, cache.Get("project/src/file2.go")) + assert.NotNil(t, cache.Get("project/docs/readme.md")) + assert.NotNil(t, cache.Get("project/tests/test1.go")) + assert.NotNil(t, cache.Get("other/file.txt")) + + // Delete entries with "project/src" prefix + cache.DeletePrefix("project/src") + + // Verify only the prefixed entries were removed + assert.Nil(t, cache.Get("project/src/file1.go"), "Should remove project/src/file1.go") + assert.Nil(t, cache.Get("project/src/file2.go"), "Should remove project/src/file2.go") + assert.NotNil(t, cache.Get("project/docs/readme.md"), "Should keep project/docs/readme.md") + assert.NotNil(t, cache.Get("project/tests/test1.go"), "Should keep project/tests/test1.go") + assert.NotNil(t, cache.Get("other/file.txt"), "Should keep other/file.txt") + + // Delete broader prefix + cache.DeletePrefix("project") + + // Verify all project entries are removed + assert.Nil(t, cache.Get("project/docs/readme.md"), "Should remove project/docs/readme.md") + assert.Nil(t, cache.Get("project/tests/test1.go"), "Should remove project/tests/test1.go") + assert.NotNil(t, cache.Get("other/file.txt"), "Should keep other/file.txt") +} + +func TestRuleCacheDeletePrefixEdgeCases(t *testing.T) { + // Test DeletePrefix with edge cases and boundary conditions + // This ensures robust handling of unusual prefix patterns + cache := NewRuleCache() + + mockNode := NewNode("test", false, 1) + mockRule := &Rule{ + fullPattern: "test/*.txt", + rule: aclspec.NewRule("*.txt", aclspec.PrivateAccess(), aclspec.DefaultLimits()), + node: mockNode, + } + + // Add test entries + cache.Set("", mockRule) // Empty path + cache.Set("a", mockRule) // Single character + cache.Set("ab", mockRule) // Two characters + cache.Set("abc", mockRule) // Three characters + cache.Set("abd", mockRule) // Similar prefix + + // Test deleting with empty prefix (should remove all entries that start with empty string, which is all entries) + cache.DeletePrefix("") + assert.Nil(t, cache.Get(""), "Should remove empty path entry") + assert.Nil(t, cache.Get("a"), "Should remove single character entry (empty prefix matches all)") + + // Re-add entries for next test + cache.Set("a", mockRule) + cache.Set("ab", mockRule) + cache.Set("abc", mockRule) + cache.Set("abd", mockRule) + + // Test deleting with single character prefix + cache.DeletePrefix("a") + assert.Nil(t, cache.Get("a"), "Should remove 'a'") + assert.Nil(t, cache.Get("ab"), "Should remove 'ab'") + assert.Nil(t, cache.Get("abc"), "Should remove 'abc'") + assert.Nil(t, cache.Get("abd"), "Should remove 'abd'") + + // Test deleting non-existent prefix (should be safe) + cache.DeletePrefix("nonexistent") + // Should not crash or cause issues +} + +func TestRuleCacheConcurrency(t *testing.T) { + // Test that cache operations are thread-safe + // This validates the mutex protection works correctly + cache := NewRuleCache() + + mockNode := NewNode("test", false, 1) + mockRule := &Rule{ + fullPattern: "test/*.txt", + rule: aclspec.NewRule("*.txt", aclspec.PrivateAccess(), aclspec.DefaultLimits()), + node: mockNode, + } + + const numGoroutines = 10 + const numOperations = 100 + + // Test concurrent Set operations + var wg sync.WaitGroup + for i := 0; i < numGoroutines; i++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + for j := 0; j < numOperations; j++ { + key := fmt.Sprintf("test/%d/%d.txt", id, j) + cache.Set(key, mockRule) + } + }(i) + } + wg.Wait() + + // Test concurrent Get operations + for i := 0; i < numGoroutines; i++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + for j := 0; j < numOperations; j++ { + key := fmt.Sprintf("test/%d/%d.txt", id, j) + result := cache.Get(key) + assert.NotNil(t, result, "Should find cached entry") + } + }(i) + } + wg.Wait() + + // Test concurrent Delete operations + for i := 0; i < numGoroutines; i++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + for j := 0; j < numOperations; j++ { + key := fmt.Sprintf("test/%d/%d.txt", id, j) + cache.Delete(key) + } + }(i) + } + wg.Wait() + + // Verify all entries were deleted + for i := 0; i < numGoroutines; i++ { + for j := 0; j < numOperations; j++ { + key := fmt.Sprintf("test/%d/%d.txt", i, j) + result := cache.Get(key) + assert.Nil(t, result, "Entry should be deleted") + } + } +} + +func TestRuleCacheMixedConcurrentOperations(t *testing.T) { + // Test mixed concurrent operations (Set, Get, Delete, DeletePrefix) + // This validates thread safety under realistic usage patterns + cache := NewRuleCache() + + mockNode := NewNode("test", false, 1) + mockRule := &Rule{ + fullPattern: "test/*.txt", + rule: aclspec.NewRule("*.txt", aclspec.PrivateAccess(), aclspec.DefaultLimits()), + node: mockNode, + } + + const numWorkers = 5 + const duration = 100 // Number of operations per worker + + var wg sync.WaitGroup + + // Worker 1: Continuous Set operations + wg.Add(1) + go func() { + defer wg.Done() + for i := 0; i < duration; i++ { + key := fmt.Sprintf("set/worker/%d.txt", i) + cache.Set(key, mockRule) + } + }() + + // Worker 2: Continuous Get operations + wg.Add(1) + go func() { + defer wg.Done() + for i := 0; i < duration; i++ { + key := fmt.Sprintf("set/worker/%d.txt", i%10) // Access recently set items + cache.Get(key) + } + }() + + // Worker 3: Continuous Delete operations + wg.Add(1) + go func() { + defer wg.Done() + for i := 0; i < duration; i++ { + key := fmt.Sprintf("set/worker/%d.txt", i) + cache.Delete(key) + } + }() + + // Worker 4: Periodic DeletePrefix operations + wg.Add(1) + go func() { + defer wg.Done() + for i := 0; i < duration/10; i++ { + prefix := fmt.Sprintf("set/worker") + cache.DeletePrefix(prefix) + } + }() + + // Worker 5: Mixed operations + wg.Add(1) + go func() { + defer wg.Done() + for i := 0; i < duration; i++ { + key := fmt.Sprintf("mixed/%d.txt", i) + cache.Set(key, mockRule) + cache.Get(key) + if i%5 == 0 { + cache.Delete(key) + } + } + }() + + wg.Wait() + + // Test should complete without deadlocks or race conditions + // The exact final state is unpredictable due to concurrency, + // but the operations should all complete successfully +} + +func TestRuleCacheMemoryManagement(t *testing.T) { + // Test that cache doesn't leak memory with repeated operations + // This validates proper cleanup of cache entries + cache := NewRuleCache() + + mockNode := NewNode("test", false, 1) + mockRule := &Rule{ + fullPattern: "test/*.txt", + rule: aclspec.NewRule("*.txt", aclspec.PrivateAccess(), aclspec.DefaultLimits()), + node: mockNode, + } + + // Add many entries + for i := 0; i < 1000; i++ { + key := fmt.Sprintf("test/%d.txt", i) + cache.Set(key, mockRule) + } + + // Verify entries exist + assert.Equal(t, 1000, len(cache.index), "Should have 1000 entries") + + // Clear using DeletePrefix + cache.DeletePrefix("test") + + // Verify all entries are removed + assert.Equal(t, 0, len(cache.index), "Should have 0 entries after DeletePrefix") + + // Add entries again to test reuse + for i := 0; i < 100; i++ { + key := fmt.Sprintf("reuse/%d.txt", i) + cache.Set(key, mockRule) + } + + assert.Equal(t, 100, len(cache.index), "Should be able to reuse cache after clearing") +} \ No newline at end of file diff --git a/internal/server/acl/level_test.go b/internal/server/acl/level_test.go new file mode 100644 index 00000000..935d2742 --- /dev/null +++ b/internal/server/acl/level_test.go @@ -0,0 +1,192 @@ +package acl + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestAccessLevelString(t *testing.T) { + // Test the String() method for all AccessLevel constants + // This validates that each access level has the correct string representation + // which is important for logging, debugging, and user-facing error messages + testCases := []struct { + level AccessLevel + expected string + desc string + }{ + { + level: AccessRead, + expected: "Read", + desc: "AccessRead should return 'Read'", + }, + { + level: AccessCreate, + expected: "Create", + desc: "AccessCreate should return 'Create'", + }, + { + level: AccessWrite, + expected: "Write", + desc: "AccessWrite should return 'Write'", + }, + { + level: AccessReadACL, + expected: "ReadACL", + desc: "AccessReadACL should return 'ReadACL'", + }, + { + level: AccessWriteACL, + expected: "WriteACL", + desc: "AccessWriteACL should return 'WriteACL'", + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + result := tc.level.String() + assert.Equal(t, tc.expected, result, tc.desc) + }) + } +} + +func TestAccessLevelStringUnknown(t *testing.T) { + // Test String() method with invalid/unknown access level values + // This ensures the system handles undefined access levels gracefully + unknownLevel := AccessLevel(255) // Invalid access level + result := unknownLevel.String() + assert.Equal(t, "Unknown", result, "Unknown access levels should return 'Unknown'") +} + +func TestAccessLevelBitFlags(t *testing.T) { + // Test that AccessLevel constants are properly defined as bit flags + // This validates the bit flag implementation which allows for efficient permission checking + + // Verify each level has a unique bit pattern + assert.Equal(t, AccessLevel(1), AccessRead, "AccessRead should be bit 0 (value 1)") + assert.Equal(t, AccessLevel(2), AccessCreate, "AccessCreate should be bit 1 (value 2)") + assert.Equal(t, AccessLevel(4), AccessWrite, "AccessWrite should be bit 2 (value 4)") + assert.Equal(t, AccessLevel(8), AccessReadACL, "AccessReadACL should be bit 3 (value 8)") + assert.Equal(t, AccessLevel(16), AccessWriteACL, "AccessWriteACL should be bit 4 (value 16)") +} + +func TestAccessLevelUniqueness(t *testing.T) { + // Test that all AccessLevel constants are unique + // This prevents accidental duplicate values that could cause permission conflicts + levels := []AccessLevel{ + AccessRead, + AccessCreate, + AccessWrite, + AccessReadACL, + AccessWriteACL, + } + + // Check that no two levels have the same value + for i, level1 := range levels { + for j, level2 := range levels { + if i != j { + assert.NotEqual(t, level1, level2, + "AccessLevel constants should be unique: %s and %s have the same value", + level1.String(), level2.String()) + } + } + } +} + +func TestAccessLevelBitOperations(t *testing.T) { + // Test that bit operations work correctly with AccessLevel flags + // This validates that the bit flag design allows for combining permissions + + // Test combining permissions with OR + combined := AccessRead | AccessWrite + assert.NotEqual(t, AccessRead, combined, "Combined permissions should differ from individual permissions") + assert.NotEqual(t, AccessWrite, combined, "Combined permissions should differ from individual permissions") + + // Test checking individual permissions with AND + assert.Equal(t, AccessRead, combined&AccessRead, "Should be able to check for read permission in combined flags") + assert.Equal(t, AccessWrite, combined&AccessWrite, "Should be able to check for write permission in combined flags") + assert.Equal(t, AccessLevel(0), combined&AccessCreate, "Should not find create permission in read+write combination") +} + +func TestAccessLevelHierarchy(t *testing.T) { + // Test the logical hierarchy of access levels + // This documents the intended permission hierarchy in the system + + // Basic file operations should have lower bit values than ACL operations + assert.True(t, AccessRead < AccessReadACL, "Read should have lower value than ReadACL") + assert.True(t, AccessWrite < AccessWriteACL, "Write should have lower value than WriteACL") + + // Within basic operations, read should be the lowest level + assert.True(t, AccessRead < AccessCreate, "Read should be the most basic permission") + assert.True(t, AccessRead < AccessWrite, "Read should be lower than write") + + // ACL operations should be the highest levels + assert.True(t, AccessReadACL > AccessWrite, "ReadACL should be higher than basic write") + assert.True(t, AccessWriteACL > AccessReadACL, "WriteACL should be the highest permission") +} + +func TestAccessLevelZeroValue(t *testing.T) { + // Test the zero value of AccessLevel + // This ensures the default/uninitialized value behaves correctly + var zeroLevel AccessLevel + + assert.Equal(t, AccessLevel(0), zeroLevel, "Zero value should be 0") + assert.Equal(t, "Unknown", zeroLevel.String(), "Zero value should be treated as unknown") + + // Zero should not match any defined permission + assert.NotEqual(t, AccessRead, zeroLevel, "Zero should not equal AccessRead") + assert.NotEqual(t, AccessCreate, zeroLevel, "Zero should not equal AccessCreate") + assert.NotEqual(t, AccessWrite, zeroLevel, "Zero should not equal AccessWrite") + assert.NotEqual(t, AccessReadACL, zeroLevel, "Zero should not equal AccessReadACL") + assert.NotEqual(t, AccessWriteACL, zeroLevel, "Zero should not equal AccessWriteACL") +} + +func TestAccessLevelCasting(t *testing.T) { + // Test that AccessLevel can be properly cast from and to uint8 + // This validates the underlying type compatibility + + // Test casting from uint8 + readLevel := AccessLevel(1) + assert.Equal(t, AccessRead, readLevel, "Should be able to cast uint8 to AccessLevel") + + // Test casting to uint8 + readValue := uint8(AccessRead) + assert.Equal(t, uint8(1), readValue, "Should be able to cast AccessLevel to uint8") + + // Test round-trip casting + originalLevel := AccessWriteACL + castValue := uint8(originalLevel) + backToLevel := AccessLevel(castValue) + assert.Equal(t, originalLevel, backToLevel, "Round-trip casting should preserve value") +} + +func TestAccessLevelStringConsistency(t *testing.T) { + // Test that String() method is consistent across multiple calls + // This ensures no side effects or state changes in the String() method + level := AccessWrite + + firstCall := level.String() + secondCall := level.String() + thirdCall := level.String() + + assert.Equal(t, firstCall, secondCall, "String() should return consistent results") + assert.Equal(t, secondCall, thirdCall, "String() should return consistent results") + assert.Equal(t, "Write", firstCall, "String() should return correct value") +} + +func TestAccessLevelEdgeCases(t *testing.T) { + // Test edge cases and boundary values + // This ensures robust handling of unusual but possible values + + // Test maximum uint8 value + maxLevel := AccessLevel(255) + assert.Equal(t, "Unknown", maxLevel.String(), "Maximum value should be handled as unknown") + + // Test values between defined constants + betweenLevels := AccessLevel(3) // Between AccessCreate (2) and AccessWrite (4) + assert.Equal(t, "Unknown", betweenLevels.String(), "Undefined intermediate values should be unknown") + + // Test that the bit flag pattern continues to work with undefined values + undefinedLevel := AccessLevel(32) // Next bit after AccessWriteACL (16) + assert.Equal(t, "Unknown", undefinedLevel.String(), "Higher undefined bits should be unknown") +} \ No newline at end of file diff --git a/internal/server/acl/node.go b/internal/server/acl/node.go index fa8e4f05..52469471 100644 --- a/internal/server/acl/node.go +++ b/internal/server/acl/node.go @@ -100,6 +100,9 @@ func (n *Node) SetRules(rules []*aclspec.Rule, terminal bool) { }) } n.rules = aclRules + } else { + // Clear rules when empty or nil slice is provided + n.rules = nil } // set the rules and terminal flag diff --git a/internal/server/acl/tree.go b/internal/server/acl/tree.go index ea86ea71..41764c19 100644 --- a/internal/server/acl/tree.go +++ b/internal/server/acl/tree.go @@ -10,9 +10,10 @@ import ( ) var ( - ErrInvalidRuleset = errors.New("invalid ruleset") - ErrMaxDepthExceeded = errors.New("maximum depth exceeded") - ErrNoRuleFound = errors.New("no rule found") + ErrInvalidRuleset = errors.New("invalid ruleset") + ErrMaxDepthExceeded = errors.New("maximum depth exceeded") + ErrNoRuleFound = errors.New("no rule found") + ErrTerminalNodeExists = errors.New("cannot add child ruleset under terminal node") ) // Tree stores the ACL rules in a n-ary tree for efficient lookups. @@ -56,7 +57,11 @@ func (t *Tree) AddRuleSet(ruleset *aclspec.RuleSet) error { for _, part := range parts { currentDepth++ - // Important: We still process terminal nodes to ensure all ACLs are known to the tree + // Check if current node is terminal - if so, we cannot add children + if current.IsTerminal() { + return fmt.Errorf("%w: path %s", ErrTerminalNodeExists, current.path) + } + // Get or create child node child, exists := current.GetChild(part) if !exists { @@ -139,29 +144,48 @@ func (t *Tree) GetNode(path string) *Node { // Removes a ruleset at the specified path func (t *Tree) RemoveRuleSet(path string) bool { - var parent *Node - var lastPart string - parts := pathParts(path) current := t.root + // Traverse to find the target node for _, part := range parts { child, exists := current.GetChild(part) if !exists { return false } - - parent = current current = child - lastPart = part } - // Need to lock parent since we're modifying its children - parent.DeleteChild(lastPart) + // Check if the node has rules to remove + if current.Rules() == nil { + return false + } + + // Clear the rules from the node but keep the node structure + // This preserves any child nodes that may exist + current.SetRules(nil, false) + + // If the node has no children and no rules, we can remove it from its parent + // But only if it's not the root and doesn't have children + if len(parts) > 0 && !hasChildren(current) { + // Find parent to remove this empty node + parent := t.root + for _, part := range parts[:len(parts)-1] { + parent, _ = parent.GetChild(part) + } + parent.DeleteChild(parts[len(parts)-1]) + } return true } +// Helper function to check if node has children +func hasChildren(node *Node) bool { + node.mu.RLock() + defer node.mu.RUnlock() + return len(node.children) > 0 +} + func pathParts(path string) []string { return strings.Split(stripSep(path), pathSep) } diff --git a/internal/server/acl/tree_test.go b/internal/server/acl/tree_test.go index d962c259..9fd8b86f 100644 --- a/internal/server/acl/tree_test.go +++ b/internal/server/acl/tree_test.go @@ -64,7 +64,7 @@ func TestTreeTraversal(t *testing.T) { ruleset2 := aclspec.NewRuleSet( "parent/child", - aclspec.SetTerminal, + aclspec.UnsetTerminal, // Changed to non-terminal so we can add grandchild aclspec.NewRule("*.md", aclspec.PrivateAccess(), aclspec.DefaultLimits()), ) @@ -91,9 +91,9 @@ func TestTreeTraversal(t *testing.T) { assert.Equal(t, "parent/child", node.path) node = tree.GetNearestNodeWithRules("parent/child/grandchild/main.go") - assert.Equal(t, "parent/child", node.path) + assert.Equal(t, "parent/child/grandchild", node.path) - // Test inheritance - terminal nodes (like parent/child) block inheritance from higher levels + // Test inheritance - terminal nodes (like grandchild) block inheritance from higher levels node = tree.GetNearestNodeWithRules("parent/child/unknown.txt") assert.Equal(t, "parent/child", node.path) @@ -148,6 +148,301 @@ func TestRemoveRuleSet(t *testing.T) { assert.False(t, removed) } +func TestGetNode(t *testing.T) { + // Test the GetNode method which finds exact nodes for given paths + // This validates precise node location without rule inheritance logic + tree := NewTree() + + // Add nested rulesets to create a tree structure + ruleset1 := aclspec.NewRuleSet( + "parent", + aclspec.UnsetTerminal, + aclspec.NewRule("*.txt", aclspec.PublicReadAccess(), aclspec.DefaultLimits()), + ) + + ruleset2 := aclspec.NewRuleSet( + "parent/child", + aclspec.SetTerminal, + aclspec.NewRule("*.md", aclspec.PrivateAccess(), aclspec.DefaultLimits()), + ) + + err := tree.AddRuleSet(ruleset1) + assert.NoError(t, err) + + err = tree.AddRuleSet(ruleset2) + assert.NoError(t, err) + + // Test getting exact nodes that exist + parentNode := tree.GetNode("parent") + assert.NotNil(t, parentNode, "Should find parent node") + assert.Equal(t, "parent", parentNode.path, "Parent node should have correct path") + + childNode := tree.GetNode("parent/child") + assert.NotNil(t, childNode, "Should find child node") + assert.Equal(t, "parent/child", childNode.path, "Child node should have correct path") + + // Test getting node for path that goes beyond existing nodes + deepNode := tree.GetNode("parent/child/grandchild") + // Should return the deepest existing node (child), not create new nodes + assert.Equal(t, "parent/child", deepNode.path, "Should return deepest existing node") + + // Test getting node for path that doesn't exist at all + nonExistentNode := tree.GetNode("nonexistent/path") + // Should return root node since no path matches + assert.Equal(t, "/", nonExistentNode.path, "Should return root for non-existent paths") + + // Test getting root node + rootNode := tree.GetNode("") + assert.Equal(t, "/", rootNode.path, "Empty path should return root node") +} + +func TestGetNodeWithTerminalNodes(t *testing.T) { + // Test GetNode behavior with terminal nodes + // This ensures terminal flag affects node traversal correctly + tree := NewTree() + + // Add a terminal node + terminalRuleset := aclspec.NewRuleSet( + "terminal", + aclspec.SetTerminal, + aclspec.NewRule("*.txt", aclspec.PrivateAccess(), aclspec.DefaultLimits()), + ) + + err := tree.AddRuleSet(terminalRuleset) + assert.NoError(t, err) + + // Verify the terminal node was added + terminalNode := tree.GetNode("terminal") + assert.NotNil(t, terminalNode, "Terminal node should exist") + assert.True(t, terminalNode.IsTerminal(), "Node should be marked as terminal") + + // Test that attempting to add a child under a terminal node should fail + // This is the correct behavior - terminal nodes should prevent child rulesets + childRuleset := aclspec.NewRuleSet( + "terminal/child", + aclspec.UnsetTerminal, + aclspec.NewRule("*.md", aclspec.PublicReadAccess(), aclspec.DefaultLimits()), + ) + + err = tree.AddRuleSet(childRuleset) + // Terminal nodes should prevent child rulesets from being added + assert.Error(t, err, "Should reject child rulesets under terminal nodes") + assert.Contains(t, err.Error(), "terminal node", "Error should mention terminal restriction") + + // Test that GetNode stops at terminal node during traversal + // Since child ruleset was rejected, any path beyond terminal should stop at terminal + node := tree.GetNode("terminal/child/deeper") + assert.Equal(t, "terminal", node.path, "Terminal node should stop further traversal") + + // Test that GetNearestNodeWithRules also respects terminal nodes + nearestNode := tree.GetNearestNodeWithRules("terminal/child/file.txt") + assert.NotNil(t, nearestNode, "Should find the terminal node") + assert.Equal(t, "terminal", nearestNode.path, "Should stop at terminal node for rule lookup") + + // Verify that the child node was not actually created + childNode := tree.GetNode("terminal/child") + assert.Equal(t, "terminal", childNode.path, "Child node should not exist, should return terminal parent") +} + +func TestTerminalNodeValidation(t *testing.T) { + // Test that terminal nodes properly prevent child rulesets from being added + // This explicitly tests the terminal enforcement behavior + tree := NewTree() + + // Add a terminal node + terminalRuleset := aclspec.NewRuleSet( + "secure", + aclspec.SetTerminal, + aclspec.NewRule("**", aclspec.PrivateAccess(), aclspec.DefaultLimits()), + ) + + err := tree.AddRuleSet(terminalRuleset) + assert.NoError(t, err, "Should be able to add terminal node") + + // Verify it's terminal + node := tree.GetNode("secure") + assert.True(t, node.IsTerminal(), "Node should be marked as terminal") + + // Try to add direct child - should fail + childRuleset := aclspec.NewRuleSet( + "secure/child", + aclspec.UnsetTerminal, + aclspec.NewRule("*.txt", aclspec.PublicReadAccess(), aclspec.DefaultLimits()), + ) + + err = tree.AddRuleSet(childRuleset) + assert.Error(t, err, "Should not be able to add child under terminal node") + assert.Contains(t, err.Error(), "terminal node", "Error should mention terminal node") + + // Try to add deeper nested child - should also fail + deepChildRuleset := aclspec.NewRuleSet( + "secure/child/grandchild", + aclspec.UnsetTerminal, + aclspec.NewRule("*.md", aclspec.PublicReadAccess(), aclspec.DefaultLimits()), + ) + + err = tree.AddRuleSet(deepChildRuleset) + assert.Error(t, err, "Should not be able to add nested child under terminal node") + assert.Contains(t, err.Error(), "terminal node", "Error should mention terminal node") + + // Verify that non-terminal nodes still allow children + nonTerminalRuleset := aclspec.NewRuleSet( + "open", + aclspec.UnsetTerminal, + aclspec.NewRule("**", aclspec.PublicReadAccess(), aclspec.DefaultLimits()), + ) + + err = tree.AddRuleSet(nonTerminalRuleset) + assert.NoError(t, err, "Should be able to add non-terminal node") + + // Add child under non-terminal - should succeed + openChildRuleset := aclspec.NewRuleSet( + "open/child", + aclspec.UnsetTerminal, + aclspec.NewRule("*.txt", aclspec.PrivateAccess(), aclspec.DefaultLimits()), + ) + + err = tree.AddRuleSet(openChildRuleset) + assert.NoError(t, err, "Should be able to add child under non-terminal node") +} + +func TestConflictingRuleSetsAtSameLevel(t *testing.T) { + // Test what happens when adding multiple rulesets to the same path + // This tests ruleset replacement/overwriting behavior + tree := NewTree() + + // Add initial ruleset + initialRuleset := aclspec.NewRuleSet( + "shared", + aclspec.UnsetTerminal, + aclspec.NewRule("*.txt", aclspec.PrivateAccess(), aclspec.DefaultLimits()), + ) + + err := tree.AddRuleSet(initialRuleset) + assert.NoError(t, err, "Should be able to add initial ruleset") + + // Verify initial ruleset + node := tree.GetNode("shared") + assert.NotNil(t, node, "Node should exist") + assert.False(t, node.IsTerminal(), "Node should not be terminal initially") + assert.Len(t, node.Rules(), 1, "Should have 1 rule initially") + + // Add conflicting ruleset at the SAME path with different rules and terminal flag + conflictingRuleset := aclspec.NewRuleSet( + "shared", // Same path! + aclspec.SetTerminal, // Different terminal flag + aclspec.NewRule("*.md", aclspec.PublicReadAccess(), aclspec.DefaultLimits()), // Different rule + aclspec.NewRule("**", aclspec.SharedReadAccess("admin@example.com"), aclspec.DefaultLimits()), // Additional rule + ) + + err = tree.AddRuleSet(conflictingRuleset) + assert.NoError(t, err, "Should be able to add conflicting ruleset (overwrites)") + + // Verify the conflicting ruleset completely replaced the original + node = tree.GetNode("shared") + assert.NotNil(t, node, "Node should still exist") + assert.True(t, node.IsTerminal(), "Node should now be terminal (overwritten)") + assert.Len(t, node.Rules(), 2, "Should have 2 rules from new ruleset") + + // Verify the conflicting ruleset completely replaced the original + // The original *.txt rule with PrivateAccess is gone + // Now we have *.md rule with PublicReadAccess and ** rule with SharedReadAccess + + // Test that *.txt files now match the ** rule (not the original *.txt rule) + rule, err := node.FindBestRule("shared/test.txt") + assert.NoError(t, err, "Should find rule for *.txt files") + assert.Equal(t, "**", rule.rule.Pattern, "Should match the ** rule, not original *.txt rule") + assert.True(t, rule.rule.Access.Read.Contains("admin@example.com"), "Should have admin access (from ** rule)") + assert.False(t, rule.rule.Access.Read.Contains("*"), "Should NOT have public access (original rule is gone)") + + // Test that *.md files match the more specific *.md rule + rule, err = node.FindBestRule("shared/test.md") + assert.NoError(t, err, "Should find rule for *.md files") + assert.Equal(t, "*.md", rule.rule.Pattern, "Should match the specific *.md rule") + assert.True(t, rule.rule.Access.Read.Contains("*"), "Should have public read access from *.md rule") + + // Test that other files match the ** rule + rule, err = node.FindBestRule("shared/anything.xyz") + assert.NoError(t, err, "Should find rule for other files") + assert.Equal(t, "**", rule.rule.Pattern, "Should match the ** rule") + assert.True(t, rule.rule.Access.Read.Contains("admin@example.com"), "Should have admin access from ** rule") + + // Test that the terminal flag now prevents children + childRuleset := aclspec.NewRuleSet( + "shared/child", + aclspec.UnsetTerminal, + aclspec.NewRule("*.go", aclspec.PrivateAccess(), aclspec.DefaultLimits()), + ) + + err = tree.AddRuleSet(childRuleset) + assert.Error(t, err, "Should not be able to add child under terminal node") + assert.Contains(t, err.Error(), "terminal node", "Error should mention terminal restriction") +} + +func TestAddRuleSetErrorCases(t *testing.T) { + // Test AddRuleSet with various error conditions + // This improves coverage of edge cases and error handling + tree := NewTree() + + // Test with nil ruleset + err := tree.AddRuleSet(nil) + assert.Error(t, err, "Should reject nil ruleset") + assert.Contains(t, err.Error(), "ruleset is nil", "Error should indicate nil ruleset") + + // Test with empty ruleset (no rules) + emptyRuleset := &aclspec.RuleSet{ + Path: "test", + Terminal: false, + Rules: []*aclspec.Rule{}, + } + err = tree.AddRuleSet(emptyRuleset) + assert.Error(t, err, "Should reject empty ruleset") + assert.Contains(t, err.Error(), "ruleset is empty", "Error should indicate empty ruleset") + + // Test with extremely deep path (path depth > 255) + deepPath := "" + for i := 0; i < 300; i++ { // Create path with > 255 components + deepPath += "a/" + } + deepRuleset := aclspec.NewRuleSet( + deepPath, + aclspec.UnsetTerminal, + aclspec.NewRule("*.txt", aclspec.PrivateAccess(), aclspec.DefaultLimits()), + ) + err = tree.AddRuleSet(deepRuleset) + assert.Error(t, err, "Should reject paths that are too deep") + assert.Contains(t, err.Error(), "maximum depth exceeded", "Error should indicate depth limit") +} + +func TestAddRuleSetPathNormalization(t *testing.T) { + // Test that AddRuleSet properly normalizes different path formats + // This ensures consistent path handling across different input formats + tree := NewTree() + + // Test with path that has leading/trailing separators + rule := aclspec.NewRule("*.txt", aclspec.PrivateAccess(), aclspec.DefaultLimits()) + + // These should all result in the same normalized path + testPaths := []string{ + "test/path", + "/test/path", + "test/path/", + "/test/path/", + "./test/path", + } + + for i, path := range testPaths { + ruleset := aclspec.NewRuleSet(path, false, rule) + err := tree.AddRuleSet(ruleset) + assert.NoError(t, err, "Should accept path format: %s", path) + + // All paths should result in the same node being found + node := tree.GetNode("test/path") + assert.NotNil(t, node, "Should find node for normalized path (test %d)", i) + assert.Equal(t, "test/path", node.path, "Path should be normalized consistently (test %d)", i) + } +} + func TestNestedRuleSetRemoval(t *testing.T) { tree := NewTree() @@ -170,13 +465,22 @@ func TestNestedRuleSetRemoval(t *testing.T) { err = tree.AddRuleSet(ruleset2) assert.NoError(t, err) - // Remove parent - should also remove child + // Remove parent ruleset - should only remove parent, not affect child removed := tree.RemoveRuleSet("parent") assert.True(t, removed) - // Verify both are gone - _, ok := tree.root.GetChild("parent") - assert.False(t, ok) + // Verify parent node structure still exists but rules are gone + parentNode, ok := tree.root.GetChild("parent") + assert.True(t, ok, "Parent node should still exist") + assert.NotNil(t, parentNode, "Parent node should not be nil") + assert.Nil(t, parentNode.Rules(), "Parent node should have no rules after removal") + + // Verify child is still there with its rules intact + childNode, ok := parentNode.GetChild("child") + assert.True(t, ok, "Child node should still exist") + assert.NotNil(t, childNode, "Child node should not be nil") + assert.NotNil(t, childNode.Rules(), "Child node should still have its rules") + assert.True(t, childNode.IsTerminal(), "Child should still be terminal") // Add the parent ruleset back err = tree.AddRuleSet(ruleset1) @@ -191,7 +495,7 @@ func TestNestedRuleSetRemoval(t *testing.T) { assert.True(t, removed) // Verify parent still exists - parentNode, ok := tree.root.GetChild("parent") + parentNode, ok = tree.root.GetChild("parent") assert.True(t, ok) assert.NotNil(t, parentNode) From ce210c5c7a4e4e0247a81dae2c755b803c907081 Mon Sep 17 00:00:00 2001 From: Madhava Jay Date: Mon, 16 Jun 2025 17:21:45 +1000 Subject: [PATCH 02/19] Revert terminal node enforcement to original design Reverts changes to terminal node behavior to match the original design: **Core Reversion:** - **Terminal nodes allow children**: Child rulesets can be added under terminal nodes for performance (avoids tree rebuilds) - **Terminal controls lookup/inheritance**: Terminal flag stops traversal during GetNode/GetNearestNodeWithRules, not during AddRuleSet - **Original RemoveRuleSet**: Restores original behavior that removes entire subtrees **Technical Changes:** - tree.go: Remove ErrTerminalNodeExists, restore original AddRuleSet logic, restore original RemoveRuleSet behavior - tree_test.go: Update tests to verify children can be added but lookups stop at terminal boundaries **Design Rationale:** Tree stores all ACL files for performance optimization while terminal flag controls rule inheritance during lookups. Child ACL files exist in tree structure but are ignored when under terminal nodes. --- internal/server/acl/tree.go | 50 +++--------- internal/server/acl/tree_test.go | 132 +++++++++++++++++++------------ 2 files changed, 94 insertions(+), 88 deletions(-) diff --git a/internal/server/acl/tree.go b/internal/server/acl/tree.go index 41764c19..740360fa 100644 --- a/internal/server/acl/tree.go +++ b/internal/server/acl/tree.go @@ -10,10 +10,9 @@ import ( ) var ( - ErrInvalidRuleset = errors.New("invalid ruleset") - ErrMaxDepthExceeded = errors.New("maximum depth exceeded") - ErrNoRuleFound = errors.New("no rule found") - ErrTerminalNodeExists = errors.New("cannot add child ruleset under terminal node") + ErrInvalidRuleset = errors.New("invalid ruleset") + ErrMaxDepthExceeded = errors.New("maximum depth exceeded") + ErrNoRuleFound = errors.New("no rule found") ) // Tree stores the ACL rules in a n-ary tree for efficient lookups. @@ -57,11 +56,7 @@ func (t *Tree) AddRuleSet(ruleset *aclspec.RuleSet) error { for _, part := range parts { currentDepth++ - // Check if current node is terminal - if so, we cannot add children - if current.IsTerminal() { - return fmt.Errorf("%w: path %s", ErrTerminalNodeExists, current.path) - } - + // Important: We still process terminal nodes to ensure all ACLs are known to the tree // Get or create child node child, exists := current.GetChild(part) if !exists { @@ -144,52 +139,33 @@ func (t *Tree) GetNode(path string) *Node { // Removes a ruleset at the specified path func (t *Tree) RemoveRuleSet(path string) bool { + var parent *Node + var lastPart string + parts := pathParts(path) current := t.root - // Traverse to find the target node for _, part := range parts { child, exists := current.GetChild(part) if !exists { return false } - current = child - } - // Check if the node has rules to remove - if current.Rules() == nil { - return false + parent = current + current = child + lastPart = part } - // Clear the rules from the node but keep the node structure - // This preserves any child nodes that may exist - current.SetRules(nil, false) - - // If the node has no children and no rules, we can remove it from its parent - // But only if it's not the root and doesn't have children - if len(parts) > 0 && !hasChildren(current) { - // Find parent to remove this empty node - parent := t.root - for _, part := range parts[:len(parts)-1] { - parent, _ = parent.GetChild(part) - } - parent.DeleteChild(parts[len(parts)-1]) - } + // Need to lock parent since we're modifying its children + parent.DeleteChild(lastPart) return true } -// Helper function to check if node has children -func hasChildren(node *Node) bool { - node.mu.RLock() - defer node.mu.RUnlock() - return len(node.children) > 0 -} - func pathParts(path string) []string { return strings.Split(stripSep(path), pathSep) } func stripSep(path string) string { return strings.TrimLeft(filepath.Clean(path), pathSep) -} +} \ No newline at end of file diff --git a/internal/server/acl/tree_test.go b/internal/server/acl/tree_test.go index 9fd8b86f..c38898ce 100644 --- a/internal/server/acl/tree_test.go +++ b/internal/server/acl/tree_test.go @@ -198,14 +198,14 @@ func TestGetNode(t *testing.T) { func TestGetNodeWithTerminalNodes(t *testing.T) { // Test GetNode behavior with terminal nodes - // This ensures terminal flag affects node traversal correctly + // Terminal nodes allow children to be added but stop traversal during lookups tree := NewTree() - // Add a terminal node + // Add a terminal node with catch-all rule terminalRuleset := aclspec.NewRuleSet( "terminal", aclspec.SetTerminal, - aclspec.NewRule("*.txt", aclspec.PrivateAccess(), aclspec.DefaultLimits()), + aclspec.NewRule("**", aclspec.PrivateAccess(), aclspec.DefaultLimits()), ) err := tree.AddRuleSet(terminalRuleset) @@ -216,8 +216,8 @@ func TestGetNodeWithTerminalNodes(t *testing.T) { assert.NotNil(t, terminalNode, "Terminal node should exist") assert.True(t, terminalNode.IsTerminal(), "Node should be marked as terminal") - // Test that attempting to add a child under a terminal node should fail - // This is the correct behavior - terminal nodes should prevent child rulesets + // Add a child under the terminal node - this should SUCCEED + // The tree allows all nodes to be added for performance (avoids tree rebuilds) childRuleset := aclspec.NewRuleSet( "terminal/child", aclspec.UnsetTerminal, @@ -225,28 +225,40 @@ func TestGetNodeWithTerminalNodes(t *testing.T) { ) err = tree.AddRuleSet(childRuleset) - // Terminal nodes should prevent child rulesets from being added - assert.Error(t, err, "Should reject child rulesets under terminal nodes") - assert.Contains(t, err.Error(), "terminal node", "Error should mention terminal restriction") - - // Test that GetNode stops at terminal node during traversal - // Since child ruleset was rejected, any path beyond terminal should stop at terminal - node := tree.GetNode("terminal/child/deeper") - assert.Equal(t, "terminal", node.path, "Terminal node should stop further traversal") - - // Test that GetNearestNodeWithRules also respects terminal nodes + assert.NoError(t, err, "Should allow child rulesets under terminal nodes (tree contains all ACLs)") + + // GetNode should stop at terminal node during lookup traversal + // Even though child exists in tree, GetNode stops at terminal boundary + childNode := tree.GetNode("terminal/child") + assert.Equal(t, "terminal", childNode.path, "GetNode should stop at terminal boundary") + + // When looking for paths beyond terminal, it also stops at the terminal node + deepNode := tree.GetNode("terminal/child/deeper") + assert.Equal(t, "terminal", deepNode.path, "Lookup should stop at terminal node") + + // Verify child actually exists in tree structure by accessing parent's children directly + terminalNode = tree.GetNode("terminal") + actualChild, exists := terminalNode.GetChild("child") + assert.True(t, exists, "Child should exist in tree structure") + assert.Equal(t, "terminal/child", actualChild.path, "Child should have correct path") + assert.False(t, actualChild.IsTerminal(), "Child should not be terminal") + + // AND: GetNearestNodeWithRules should also stop at terminal nodes nearestNode := tree.GetNearestNodeWithRules("terminal/child/file.txt") assert.NotNil(t, nearestNode, "Should find the terminal node") - assert.Equal(t, "terminal", nearestNode.path, "Should stop at terminal node for rule lookup") - - // Verify that the child node was not actually created - childNode := tree.GetNode("terminal/child") - assert.Equal(t, "terminal", childNode.path, "Child node should not exist, should return terminal parent") + assert.Equal(t, "terminal", nearestNode.path, "Rule lookup should stop at terminal node") + + // This means child rules are ignored for inheritance even though they exist in the tree + // Test with child path - should get rule from terminal node (** pattern matches everything) + rule, err := tree.GetRule("terminal/child/test.md") + assert.NoError(t, err, "Should find rule from terminal node") + assert.Equal(t, "terminal", rule.node.path, "Rule should come from terminal node, not child") + assert.Equal(t, "**", rule.rule.Pattern, "Should use terminal node's catch-all rule") } func TestTerminalNodeValidation(t *testing.T) { - // Test that terminal nodes properly prevent child rulesets from being added - // This explicitly tests the terminal enforcement behavior + // Test that terminal nodes control inheritance but allow children to be added + // This explicitly tests the correct terminal behavior tree := NewTree() // Add a terminal node @@ -263,7 +275,7 @@ func TestTerminalNodeValidation(t *testing.T) { node := tree.GetNode("secure") assert.True(t, node.IsTerminal(), "Node should be marked as terminal") - // Try to add direct child - should fail + // Add direct child - should succeed (tree allows all nodes for performance) childRuleset := aclspec.NewRuleSet( "secure/child", aclspec.UnsetTerminal, @@ -271,10 +283,9 @@ func TestTerminalNodeValidation(t *testing.T) { ) err = tree.AddRuleSet(childRuleset) - assert.Error(t, err, "Should not be able to add child under terminal node") - assert.Contains(t, err.Error(), "terminal node", "Error should mention terminal node") + assert.NoError(t, err, "Should be able to add child under terminal node (exists in tree)") - // Try to add deeper nested child - should also fail + // Add deeper nested child - should also succeed deepChildRuleset := aclspec.NewRuleSet( "secure/child/grandchild", aclspec.UnsetTerminal, @@ -282,10 +293,28 @@ func TestTerminalNodeValidation(t *testing.T) { ) err = tree.AddRuleSet(deepChildRuleset) - assert.Error(t, err, "Should not be able to add nested child under terminal node") - assert.Contains(t, err.Error(), "terminal node", "Error should mention terminal node") + assert.NoError(t, err, "Should be able to add nested child under terminal node") - // Verify that non-terminal nodes still allow children + // BUT: Terminal nodes should stop traversal for rule lookups + // All paths under secure/ should resolve to the secure terminal node + + // Direct child path should resolve to terminal parent + nearestNode := tree.GetNearestNodeWithRules("secure/child/test.txt") + assert.NotNil(t, nearestNode, "Should find a node") + assert.Equal(t, "secure", nearestNode.path, "Should resolve to terminal parent, not child") + + // Deep child path should also resolve to terminal parent + nearestNode = tree.GetNearestNodeWithRules("secure/child/grandchild/test.md") + assert.NotNil(t, nearestNode, "Should find a node") + assert.Equal(t, "secure", nearestNode.path, "Should resolve to terminal parent, not grandchild") + + // Rule lookup should use terminal node's rules + rule, err := tree.GetRule("secure/child/test.txt") + assert.NoError(t, err, "Should find rule for child path") + assert.Equal(t, "secure", rule.node.path, "Rule should come from terminal parent") + assert.Equal(t, "**", rule.rule.Pattern, "Should use terminal node's catch-all rule") + + // Non-terminal nodes should work normally nonTerminalRuleset := aclspec.NewRuleSet( "open", aclspec.UnsetTerminal, @@ -295,7 +324,7 @@ func TestTerminalNodeValidation(t *testing.T) { err = tree.AddRuleSet(nonTerminalRuleset) assert.NoError(t, err, "Should be able to add non-terminal node") - // Add child under non-terminal - should succeed + // Add child under non-terminal - should succeed and be accessible openChildRuleset := aclspec.NewRuleSet( "open/child", aclspec.UnsetTerminal, @@ -304,6 +333,11 @@ func TestTerminalNodeValidation(t *testing.T) { err = tree.AddRuleSet(openChildRuleset) assert.NoError(t, err, "Should be able to add child under non-terminal node") + + // Child under non-terminal should be accessible for rule lookups + nearestNode = tree.GetNearestNodeWithRules("open/child/test.txt") + assert.NotNil(t, nearestNode, "Should find a node") + assert.Equal(t, "open/child", nearestNode.path, "Should find the actual child node, not parent") } func TestConflictingRuleSetsAtSameLevel(t *testing.T) { @@ -367,7 +401,7 @@ func TestConflictingRuleSetsAtSameLevel(t *testing.T) { assert.Equal(t, "**", rule.rule.Pattern, "Should match the ** rule") assert.True(t, rule.rule.Access.Read.Contains("admin@example.com"), "Should have admin access from ** rule") - // Test that the terminal flag now prevents children + // Test that the terminal flag now controls inheritance childRuleset := aclspec.NewRuleSet( "shared/child", aclspec.UnsetTerminal, @@ -375,8 +409,13 @@ func TestConflictingRuleSetsAtSameLevel(t *testing.T) { ) err = tree.AddRuleSet(childRuleset) - assert.Error(t, err, "Should not be able to add child under terminal node") - assert.Contains(t, err.Error(), "terminal node", "Error should mention terminal restriction") + assert.NoError(t, err, "Should be able to add child under terminal node (tree allows all nodes)") + + // But rule lookups should stop at terminal node + rule, err = tree.GetRule("shared/child/test.go") + assert.NoError(t, err, "Should find rule for child path") + assert.Equal(t, "shared", rule.node.path, "Rule should come from terminal parent, not child") + assert.Equal(t, "**", rule.rule.Pattern, "Should use parent's ** rule, not child's *.go rule") } func TestAddRuleSetErrorCases(t *testing.T) { @@ -465,22 +504,13 @@ func TestNestedRuleSetRemoval(t *testing.T) { err = tree.AddRuleSet(ruleset2) assert.NoError(t, err) - // Remove parent ruleset - should only remove parent, not affect child + // Remove parent - with original behavior, this removes the entire subtree removed := tree.RemoveRuleSet("parent") assert.True(t, removed) - // Verify parent node structure still exists but rules are gone - parentNode, ok := tree.root.GetChild("parent") - assert.True(t, ok, "Parent node should still exist") - assert.NotNil(t, parentNode, "Parent node should not be nil") - assert.Nil(t, parentNode.Rules(), "Parent node should have no rules after removal") - - // Verify child is still there with its rules intact - childNode, ok := parentNode.GetChild("child") - assert.True(t, ok, "Child node should still exist") - assert.NotNil(t, childNode, "Child node should not be nil") - assert.NotNil(t, childNode.Rules(), "Child node should still have its rules") - assert.True(t, childNode.IsTerminal(), "Child should still be terminal") + // Verify both parent and child are gone (original behavior) + _, ok := tree.root.GetChild("parent") + assert.False(t, ok, "Parent node should be completely removed") // Add the parent ruleset back err = tree.AddRuleSet(ruleset1) @@ -494,12 +524,12 @@ func TestNestedRuleSetRemoval(t *testing.T) { removed = tree.RemoveRuleSet("parent/child") assert.True(t, removed) - // Verify parent still exists - parentNode, ok = tree.root.GetChild("parent") - assert.True(t, ok) - assert.NotNil(t, parentNode) + // Verify parent still exists but child was removed + parentNode, ok := tree.root.GetChild("parent") + assert.True(t, ok, "Parent node should exist after removing just child") + assert.NotNil(t, parentNode, "Parent node should not be nil") // Verify child was removed _, ok = parentNode.GetChild("child") - assert.False(t, ok) + assert.False(t, ok, "Child node should be removed") } From 959c81bbc2cad1353504a43fa8961a047cba7330 Mon Sep 17 00:00:00 2001 From: Yash Gorana Date: Mon, 16 Jun 2025 20:31:32 +0530 Subject: [PATCH 03/19] chore: drop unused migrate --- internal/aclspec/migrate.go | 36 ----- internal/aclspec/migrate_test.go | 247 ------------------------------- 2 files changed, 283 deletions(-) delete mode 100644 internal/aclspec/migrate.go delete mode 100644 internal/aclspec/migrate_test.go diff --git a/internal/aclspec/migrate.go b/internal/aclspec/migrate.go deleted file mode 100644 index 091121e9..00000000 --- a/internal/aclspec/migrate.go +++ /dev/null @@ -1,36 +0,0 @@ -package aclspec - -import ( - "fmt" - - "gopkg.in/yaml.v3" -) - -type PermissionType string - -const ( - Read PermissionType = "read" - Create PermissionType = "create" - Write PermissionType = "write" - Execute PermissionType = "admin" -) - -type LegacyRule struct { - Path string `yaml:"path"` - User string `yaml:"user"` - Permissions []PermissionType `yaml:"permissions"` -} - -type LegacyPermission struct { - Rules []LegacyRule -} - -// UnmarshalYAML implements the yaml.Unmarshaler interface. -// This allows LegacyPermission to be unmarshalled from a YAML sequence -// directly into its Rules field. -func (lp *LegacyPermission) UnmarshalYAML(node *yaml.Node) error { - if node.Kind != yaml.SequenceNode { - return fmt.Errorf("cannot unmarshal %s into LegacyPermission.Rules, expected a sequence", node.Tag) - } - return node.Decode(&lp.Rules) -} diff --git a/internal/aclspec/migrate_test.go b/internal/aclspec/migrate_test.go deleted file mode 100644 index 7c04d76b..00000000 --- a/internal/aclspec/migrate_test.go +++ /dev/null @@ -1,247 +0,0 @@ -package aclspec - -import ( - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "gopkg.in/yaml.v3" -) - -func TestLegacyPermissionUnmarshalYAML(t *testing.T) { - // Test unmarshaling legacy permission format from YAML - // This validates backward compatibility with older permission file formats - yamlContent := ` -- path: "documents/*" - user: "alice@research.org" - permissions: ["read", "write"] -- path: "public/*.txt" - user: "bob@university.edu" - permissions: ["read"] -- path: "admin/*" - user: "admin@company.com" - permissions: ["read", "write", "admin"] -` - - var legacyPerm LegacyPermission - err := yaml.Unmarshal([]byte(yamlContent), &legacyPerm) - require.NoError(t, err, "Should successfully unmarshal legacy permission format") - - // Verify the correct number of rules were parsed - assert.Len(t, legacyPerm.Rules, 3, "Should parse all three legacy rules") - - // Verify first rule (alice with read/write on documents) - rule1 := legacyPerm.Rules[0] - assert.Equal(t, "documents/*", rule1.Path, "First rule path should be correct") - assert.Equal(t, "alice@research.org", rule1.User, "First rule user should be correct") - assert.Len(t, rule1.Permissions, 2, "First rule should have 2 permissions") - assert.Contains(t, rule1.Permissions, Read, "First rule should include read permission") - assert.Contains(t, rule1.Permissions, Write, "First rule should include write permission") - - // Verify second rule (bob with read on public txt files) - rule2 := legacyPerm.Rules[1] - assert.Equal(t, "public/*.txt", rule2.Path, "Second rule path should be correct") - assert.Equal(t, "bob@university.edu", rule2.User, "Second rule user should be correct") - assert.Len(t, rule2.Permissions, 1, "Second rule should have 1 permission") - assert.Contains(t, rule2.Permissions, Read, "Second rule should include read permission") - - // Verify third rule (admin with full permissions) - rule3 := legacyPerm.Rules[2] - assert.Equal(t, "admin/*", rule3.Path, "Third rule path should be correct") - assert.Equal(t, "admin@company.com", rule3.User, "Third rule user should be correct") - assert.Len(t, rule3.Permissions, 3, "Third rule should have 3 permissions") - assert.Contains(t, rule3.Permissions, Read, "Third rule should include read permission") - assert.Contains(t, rule3.Permissions, Write, "Third rule should include write permission") - assert.Contains(t, rule3.Permissions, Execute, "Third rule should include admin permission") -} - -func TestLegacyPermissionUnmarshalEmptyYAML(t *testing.T) { - // Test unmarshaling empty legacy permission list - // This ensures the system handles empty legacy files gracefully - yamlContent := `[]` - - var legacyPerm LegacyPermission - err := yaml.Unmarshal([]byte(yamlContent), &legacyPerm) - require.NoError(t, err, "Should successfully unmarshal empty legacy permission list") - - assert.Empty(t, legacyPerm.Rules, "Empty YAML should result in empty rules list") -} - -func TestLegacyPermissionUnmarshalInvalidYAML(t *testing.T) { - // Test that invalid YAML structure is properly rejected - // This ensures the system fails safely on malformed legacy files - yamlContent := `"this_is_not_a_sequence"` - - var legacyPerm LegacyPermission - err := yaml.Unmarshal([]byte(yamlContent), &legacyPerm) - assert.Error(t, err, "Should reject non-sequence YAML for legacy permissions") - assert.Contains(t, err.Error(), "expected a sequence", "Error should indicate sequence was expected") -} - -func TestLegacyPermissionUnmarshalPartialRule(t *testing.T) { - // Test unmarshaling legacy rules with missing fields - // This validates handling of incomplete legacy data - yamlContent := ` -- path: "test/*" - user: "testuser@example.com" - # permissions field intentionally missing -- path: "incomplete" - # user and permissions fields missing -` - - var legacyPerm LegacyPermission - err := yaml.Unmarshal([]byte(yamlContent), &legacyPerm) - require.NoError(t, err, "Should handle legacy rules with missing fields") - - assert.Len(t, legacyPerm.Rules, 2, "Should parse both rules despite missing fields") - - // First rule should have path and user, but empty permissions - rule1 := legacyPerm.Rules[0] - assert.Equal(t, "test/*", rule1.Path) - assert.Equal(t, "testuser@example.com", rule1.User) - assert.Empty(t, rule1.Permissions, "Missing permissions should result in empty slice") - - // Second rule should have only path - rule2 := legacyPerm.Rules[1] - assert.Equal(t, "incomplete", rule2.Path) - assert.Empty(t, rule2.User, "Missing user should result in empty string") - assert.Empty(t, rule2.Permissions, "Missing permissions should result in empty slice") -} - -func TestPermissionTypeConstants(t *testing.T) { - // Test that permission type constants have expected values - // This ensures the constants match the legacy format requirements - assert.Equal(t, PermissionType("read"), Read, "Read permission constant should be correct") - assert.Equal(t, PermissionType("create"), Create, "Create permission constant should be correct") - assert.Equal(t, PermissionType("write"), Write, "Write permission constant should be correct") - assert.Equal(t, PermissionType("admin"), Execute, "Execute/Admin permission constant should be correct") -} - -func TestLegacyRulePermissionTypes(t *testing.T) { - // Test that all permission types can be properly parsed - // This validates the enum handling for different permission levels - yamlContent := ` -- path: "test/*" - user: "testuser@example.com" - permissions: ["read", "create", "write", "admin"] -` - - var legacyPerm LegacyPermission - err := yaml.Unmarshal([]byte(yamlContent), &legacyPerm) - require.NoError(t, err, "Should parse all permission types") - - rule := legacyPerm.Rules[0] - assert.Len(t, rule.Permissions, 4, "Should have all 4 permission types") - assert.Contains(t, rule.Permissions, Read, "Should contain read permission") - assert.Contains(t, rule.Permissions, Create, "Should contain create permission") - assert.Contains(t, rule.Permissions, Write, "Should contain write permission") - assert.Contains(t, rule.Permissions, Execute, "Should contain admin/execute permission") -} - -func TestLegacyRuleWithDuplicatePermissions(t *testing.T) { - // Test handling of duplicate permissions in legacy format - // This ensures duplicate permissions are handled gracefully - yamlContent := ` -- path: "test/*" - user: "testuser@example.com" - permissions: ["read", "read", "write", "read"] -` - - var legacyPerm LegacyPermission - err := yaml.Unmarshal([]byte(yamlContent), &legacyPerm) - require.NoError(t, err, "Should handle duplicate permissions") - - rule := legacyPerm.Rules[0] - // YAML unmarshaling preserves duplicates in slice - assert.Len(t, rule.Permissions, 4, "Should preserve all permission entries including duplicates") - - // Count actual unique permissions - uniquePerms := make(map[PermissionType]bool) - for _, perm := range rule.Permissions { - uniquePerms[perm] = true - } - assert.Len(t, uniquePerms, 2, "Should have 2 unique permission types (read and write)") -} - -func TestLegacyRuleWithInvalidPermissions(t *testing.T) { - // Test handling of invalid permission types in legacy format - // This documents behavior with unknown permission strings - yamlContent := ` -- path: "test/*" - user: "testuser@example.com" - permissions: ["read", "invalid_permission", "write"] -` - - var legacyPerm LegacyPermission - err := yaml.Unmarshal([]byte(yamlContent), &legacyPerm) - require.NoError(t, err, "Should parse even with invalid permissions") - - rule := legacyPerm.Rules[0] - assert.Len(t, rule.Permissions, 3, "Should include all permission strings including invalid ones") - assert.Contains(t, rule.Permissions, Read, "Should contain valid read permission") - assert.Contains(t, rule.Permissions, Write, "Should contain valid write permission") - assert.Contains(t, rule.Permissions, PermissionType("invalid_permission"), "Should preserve invalid permission as-is") -} - -func TestLegacyPermissionComplexPaths(t *testing.T) { - // Test legacy permission parsing with complex file paths - // This ensures path handling works correctly for various path formats - yamlContent := ` -- path: "**/*.go" - user: "developer@company.com" - permissions: ["read", "write"] -- path: "/absolute/path/*" - user: "admin@openmined.org" - permissions: ["admin"] -- path: "relative/../path" - user: "user@university.edu" - permissions: ["read"] -- path: "" - user: "empty_path_user@example.com" - permissions: ["read"] -` - - var legacyPerm LegacyPermission - err := yaml.Unmarshal([]byte(yamlContent), &legacyPerm) - require.NoError(t, err, "Should handle complex paths") - - assert.Len(t, legacyPerm.Rules, 4, "Should parse all rules with complex paths") - - // Verify each path is preserved exactly as specified - assert.Equal(t, "**/*.go", legacyPerm.Rules[0].Path, "Glob pattern should be preserved") - assert.Equal(t, "/absolute/path/*", legacyPerm.Rules[1].Path, "Absolute path should be preserved") - assert.Equal(t, "relative/../path", legacyPerm.Rules[2].Path, "Relative path with .. should be preserved") - assert.Equal(t, "", legacyPerm.Rules[3].Path, "Empty path should be preserved") -} - -func TestLegacyPermissionSpecialCharacters(t *testing.T) { - // Test legacy permission parsing with special characters in user names and paths - // This validates handling of edge cases in legacy data - yamlContent := ` -- path: "files with spaces/*" - user: "user.with.dots" - permissions: ["read"] -- path: "unicode/测试/*" - user: "用户名" - permissions: ["write"] -- path: "symbols!@#$%/*" - user: "user_with_underscores" - permissions: ["admin"] -` - - var legacyPerm LegacyPermission - err := yaml.Unmarshal([]byte(yamlContent), &legacyPerm) - require.NoError(t, err, "Should handle special characters in paths and users") - - assert.Len(t, legacyPerm.Rules, 3, "Should parse all rules with special characters") - - // Verify special characters are preserved - assert.Equal(t, "files with spaces/*", legacyPerm.Rules[0].Path, "Spaces in paths should be preserved") - assert.Equal(t, "user.with.dots", legacyPerm.Rules[0].User, "Dots in usernames should be preserved") - - assert.Equal(t, "unicode/测试/*", legacyPerm.Rules[1].Path, "Unicode characters should be preserved") - assert.Equal(t, "用户名", legacyPerm.Rules[1].User, "Unicode usernames should be preserved") - - assert.Equal(t, "symbols!@#$%/*", legacyPerm.Rules[2].Path, "Symbol characters should be preserved") - assert.Equal(t, "user_with_underscores", legacyPerm.Rules[2].User, "Underscores should be preserved") -} \ No newline at end of file From f49fe23c607ba264e427d2ff325da08e17675a43 Mon Sep 17 00:00:00 2001 From: Madhava Jay Date: Tue, 17 Jun 2025 15:36:22 +1000 Subject: [PATCH 04/19] Add Docker development stack with auth bypass for local dev (#23) - Docker server/client with MinIO integration - Auth bypass via SYFTBOX_AUTH_ENABLED=0 for local development - Multi-client support with per-email config persistence - Just commands for easy Docker orchestration - Smart entrypoint handling local vs production servers --- docker/Dockerfile.client | 45 ++++++++++++++++++ docker/Dockerfile.server | 39 +++++++++++++++ docker/docker-compose-client.yml | 26 ++++++++++ docker/docker-compose.yml | 75 +++++++++++++++++++++++++++++ docker/entrypoint-client.sh | 77 ++++++++++++++++++++++++++++++ internal/syftsdk/sdk.go | 13 ++--- justfile | 81 ++++++++++++++++++++++++++++++++ 7 files changed, 348 insertions(+), 8 deletions(-) create mode 100644 docker/Dockerfile.client create mode 100644 docker/Dockerfile.server create mode 100644 docker/docker-compose-client.yml create mode 100644 docker/docker-compose.yml create mode 100644 docker/entrypoint-client.sh diff --git a/docker/Dockerfile.client b/docker/Dockerfile.client new file mode 100644 index 00000000..4238d43c --- /dev/null +++ b/docker/Dockerfile.client @@ -0,0 +1,45 @@ +# Build stage +FROM golang:1.24.3-alpine AS builder + +# Install build dependencies +RUN apk add --no-cache git + +# Set working directory +WORKDIR /app + +# Copy go mod and sum files +COPY go.mod go.sum ./ + +# Download dependencies +RUN go mod download + +# Copy the source code +COPY . . + +# Build the client binary +RUN CGO_ENABLED=0 GOOS=linux go build -a -installsuffix cgo -o syftbox ./cmd/client + +# Final stage +FROM alpine:latest + +# Install ca-certificates for HTTPS +RUN apk --no-cache add ca-certificates bash + +WORKDIR /root/ + +# Copy the binary from builder +COPY --from=builder /app/syftbox . + +# Create base directories +RUN mkdir -p /root/.syftbox + +# Copy entrypoint script +COPY docker/entrypoint-client.sh /entrypoint.sh +RUN chmod +x /entrypoint.sh + +# Expose the daemon API port +EXPOSE 7938 + +# Use entrypoint script +ENTRYPOINT ["/entrypoint.sh"] +CMD ["--help"] \ No newline at end of file diff --git a/docker/Dockerfile.server b/docker/Dockerfile.server new file mode 100644 index 00000000..c45401be --- /dev/null +++ b/docker/Dockerfile.server @@ -0,0 +1,39 @@ +# Build stage +FROM golang:1.24.3-alpine AS builder + +# Install build dependencies +RUN apk add --no-cache git + +# Set working directory +WORKDIR /app + +# Copy go mod and sum files +COPY go.mod go.sum ./ + +# Download dependencies +RUN go mod download + +# Copy the source code +COPY . . + +# Build the server binary +RUN CGO_ENABLED=0 GOOS=linux go build -a -installsuffix cgo -o server ./cmd/server + +# Final stage +FROM alpine:latest + +# Install ca-certificates for HTTPS and mc (MinIO client) +RUN apk --no-cache add ca-certificates curl && \ + curl -sSL https://dl.min.io/client/mc/release/linux-amd64/mc -o /usr/local/bin/mc && \ + chmod +x /usr/local/bin/mc + +WORKDIR /root/ + +# Copy the binary from builder +COPY --from=builder /app/server . + +# Expose the server port (adjust if needed) +EXPOSE 8080 + +# Run the server +CMD ["./server"] \ No newline at end of file diff --git a/docker/docker-compose-client.yml b/docker/docker-compose-client.yml new file mode 100644 index 00000000..c9ee406e --- /dev/null +++ b/docker/docker-compose-client.yml @@ -0,0 +1,26 @@ +version: '3.8' + +services: + client: + build: + context: .. + dockerfile: docker/Dockerfile.client + image: syftbox-client:latest + container_name: syftbox-client-${CLIENT_EMAIL:-default} + ports: + - "${CLIENT_PORT:-7938}:7938" + environment: + - SYFTBOX_SERVER_URL=${SYFTBOX_SERVER_URL:-http://syftbox-server:8080} + - SYFTBOX_AUTH_ENABLED=0 + volumes: + - ${SYFTBOX_CLIENTS_DIR:-~/.syftbox/clients}:/data/clients + networks: + - syftbox-network + stdin_open: true + tty: true + command: ${CLIENT_EMAIL:---help} + +networks: + syftbox-network: + external: true + name: docker_syftbox-network \ No newline at end of file diff --git a/docker/docker-compose.yml b/docker/docker-compose.yml new file mode 100644 index 00000000..9978d48a --- /dev/null +++ b/docker/docker-compose.yml @@ -0,0 +1,75 @@ +services: + minio: + image: minio/minio:RELEASE.2025-04-22T22-12-26Z + container_name: syftbox-minio + ports: + - "9000:9000" + - "9001:9001" + environment: + - MINIO_ROOT_USER=minioadmin + - MINIO_ROOT_PASSWORD=minioadmin + volumes: + - minio-data:/data + - ../minio/init.d:/etc/minio/init.d + command: server /data --console-address ':9001' + networks: + - syftbox-network + healthcheck: + test: ["CMD", "curl", "-f", "http://localhost:9000/minio/health/live"] + interval: 5s + timeout: 5s + retries: 5 + + server: + build: + context: .. + dockerfile: docker/Dockerfile.server + container_name: syftbox-server + ports: + - "8080:8080" + environment: + - SYFTBOX_ENV=DEV + - SYFTBOX_AUTH_ENABLED=0 + - SYFTBOX_AUTH_EMAIL_OTP_LENGTH=8 + - SYFTBOX_AUTH_EMAIL_OTP_EXPIRY=5m + - SYFTBOX_AUTH_TOKEN_ISSUER=https://syftboxdev.openmined.org + - SYFTBOX_AUTH_REFRESH_TOKEN_SECRET=123 + - SYFTBOX_AUTH_REFRESH_TOKEN_EXPIRY=0 + - SYFTBOX_AUTH_ACCESS_TOKEN_SECRET=132 + - SYFTBOX_AUTH_ACCESS_TOKEN_EXPIRY=72h + - SYFTBOX_BLOB_REGION=us-east-1 + - SYFTBOX_BLOB_BUCKET_NAME=syftbox-local + - SYFTBOX_BLOB_ENDPOINT=http://minio:9000 + - SYFTBOX_BLOB_ACCESS_KEY=ptSLdKiwOi2LYQFZYEZ6 + - SYFTBOX_BLOB_SECRET_KEY=GMDvYrAhWDkB2DyFMn8gU8I8Bg0fT3JGT6iEB7P8 + - SYFTBOX_HTTP_ADDR=0.0.0.0:8080 + networks: + - syftbox-network + depends_on: + minio: + condition: service_healthy + restart: unless-stopped + command: > + sh -c " + # Wait for MinIO to be ready and run setup + until mc alias set local http://minio:9000 minioadmin minioadmin >/dev/null 2>&1; do + echo 'Waiting for MinIO...' + sleep 1 + done + echo 'Running MinIO setup...' + # Update the setup script to use the correct endpoint + sed 's|http://localhost:9000|http://minio:9000|g' /etc/minio/init.d/setup.sh > /tmp/setup.sh + chmod +x /tmp/setup.sh + /tmp/setup.sh || true + echo 'Starting server...' + ./server + " + volumes: + - ../minio/init.d:/etc/minio/init.d:ro + +networks: + syftbox-network: + driver: bridge + +volumes: + minio-data: \ No newline at end of file diff --git a/docker/entrypoint-client.sh b/docker/entrypoint-client.sh new file mode 100644 index 00000000..0dba331a --- /dev/null +++ b/docker/entrypoint-client.sh @@ -0,0 +1,77 @@ +#!/bin/bash +set -e + +# Function to setup client configuration +setup_client() { + local email="$1" + local server_url="${SYFTBOX_SERVER_URL:-http://syftbox-server:8080}" + local client_dir="/data/clients/${email}" + local config_file="${client_dir}/config.json" + local data_dir="${client_dir}/SyftBox" + + # Create directories if they don't exist + mkdir -p "${client_dir}" + mkdir -p "${data_dir}" + + # Set environment variables for this session + export SYFTBOX_CONFIG_PATH="${config_file}" + export SYFTBOX_DATA_DIR="${data_dir}" + export SYFTBOX_EMAIL="${email}" + export SYFTBOX_SERVER_URL="${server_url}" + + # For local dev, create a simple config that bypasses auth + if [[ "${server_url}" == *"syftbox-server"* ]] || [[ "${server_url}" == *"localhost"* ]] || [[ "${server_url}" == *"127.0.0.1"* ]]; then + echo "Setting up local dev config (auth bypass)" + cat > "${config_file}" << EOF +{ + "data_dir": "${data_dir}", + "email": "${email}", + "server_url": "${server_url}", + "client_url": "http://localhost:7938" +} +EOF + fi + + # Create symlinks for easier access in container + ln -sf "${config_file}" /root/.syftbox/config.json + ln -sf "${data_dir}" /root/SyftBox + + echo "Client setup for ${email}:" + echo " Config: ${config_file}" + echo " Data: ${data_dir}" + echo " Server: ${server_url}" +} + +# If email is provided as first argument, setup the client +if [[ "$1" =~ ^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$ ]]; then + email="$1" + shift + setup_client "${email}" + + # If no additional command provided, run daemon + if [ $# -eq 0 ]; then + echo "Starting daemon for ${email}..." + exec ./syftbox daemon + else + # Run the provided command + exec ./syftbox "$@" + fi +elif [ "$1" = "login" ] && [ -n "$2" ]; then + # Handle login command specially + email="$2" + shift 2 + setup_client "${email}" + + # For local dev servers, skip actual login and just setup config + if [[ "${SYFTBOX_SERVER_URL:-http://syftbox-server:8080}" == *"syftbox-server"* ]] || [[ "${SYFTBOX_SERVER_URL:-http://syftbox-server:8080}" == *"localhost"* ]] || [[ "${SYFTBOX_SERVER_URL:-http://syftbox-server:8080}" == *"127.0.0.1"* ]]; then + echo "Local dev server detected - skipping OTP login" + echo "Config created at: ${SYFTBOX_CONFIG_PATH}" + echo "You can now run: just run-docker-client-daemon ${email}" + exit 0 + else + exec ./syftbox login "$@" + fi +else + # Pass through to syftbox + exec ./syftbox "$@" +fi \ No newline at end of file diff --git a/internal/syftsdk/sdk.go b/internal/syftsdk/sdk.go index fca8f9ba..7c06684e 100644 --- a/internal/syftsdk/sdk.go +++ b/internal/syftsdk/sdk.go @@ -6,7 +6,6 @@ import ( "fmt" "log/slog" "os" - "strings" "time" "github.com/openmined/syftbox/internal/utils" @@ -75,8 +74,8 @@ func (s *SyftSDK) Close() { // Authenticate sets the user authentication for API calls and events func (s *SyftSDK) Authenticate(ctx context.Context) error { - if isDevMode(s.config.BaseURL) { - slog.Warn("sdk is in DEV mode, skipping auth") + if isAuthDisabled() { + slog.Warn("sdk auth disabled, skipping auth") return nil } @@ -168,9 +167,7 @@ func (s *SyftSDK) setAccessToken(accessToken string) error { return nil } -func isDevMode(baseURL string) bool { - return strings.Contains(baseURL, "localhost") || - strings.Contains(baseURL, "127.0.0.1") || - strings.Contains(baseURL, "0.0.0.0") || - os.Getenv("SYFTBOX_DEV_MODE") == "true" +func isAuthDisabled() bool { + authEnabled := os.Getenv("SYFTBOX_AUTH_ENABLED") + return authEnabled == "0" || authEnabled == "false" } diff --git a/justfile b/justfile index c4554826..3c442b0c 100644 --- a/justfile +++ b/justfile @@ -83,6 +83,87 @@ destroy-minio: ssh-minio: docker exec -it syftbox-minio bash +[group('dev-docker')] +run-docker-server: + #!/bin/bash + set -eou pipefail + echo "Building and running SyftBox server with MinIO in Docker..." + cd docker && COMPOSE_BAKE=true docker-compose up -d --build minio server + echo "Server is running at http://localhost:8080" + echo "MinIO console is available at http://localhost:9001" + echo "Run 'cd docker && docker-compose logs -f server' to view server logs" + +[group('dev-docker')] +run-docker-client email *ARGS: + #!/bin/bash + set -eou pipefail + + # Build the client image + docker build -f docker/Dockerfile.client -t syftbox-client . + + # Create clients directory if it doesn't exist + mkdir -p ~/.syftbox/clients + + if [ -z "{{ email }}" ]; then + echo "Usage: just run-docker-client [command]" + echo "Examples:" + echo " just run-docker-client user@example.com login" + echo " just run-docker-client user@example.com daemon" + echo " just run-docker-client user@example.com app list" + exit 1 + fi + + # Sanitize email for container name (replace @ with -at- and . with -dot-) + container_name="syftbox-client-$(echo '{{ email }}' | sed 's/@/-at-/g' | sed 's/\./-dot-/g')" + + # Run the client with email-specific configuration + docker run --rm -it \ + -v ~/.syftbox/clients:/data/clients \ + --network docker_syftbox-network \ + -e SYFTBOX_SERVER_URL=http://syftbox-server:8080 \ + -e SYFTBOX_AUTH_ENABLED=0 \ + --name "$container_name" \ + syftbox-client {{ email }} {{ ARGS }} + +[group('dev-docker')] +run-docker-client-daemon email: + #!/bin/bash + set -eou pipefail + + # Build and run client in daemon mode using docker-compose + cd docker && CLIENT_EMAIL={{ email }} docker-compose -f docker-compose-client.yml up -d --build + echo "Client daemon for {{ email }} is running at http://localhost:7938" + echo "Logs: cd docker && docker-compose -f docker-compose-client.yml logs -f" + +[group('dev-docker')] +stop-docker-client email: + #!/bin/bash + set -eou pipefail + + cd docker && CLIENT_EMAIL={{ email }} docker-compose -f docker-compose-client.yml down + +[group('dev-docker')] +list-docker-clients: + #!/bin/bash + set -eou pipefail + + echo "Available SyftBox clients:" + if [ -d ~/.syftbox/clients ]; then + ls -la ~/.syftbox/clients/ | grep -E '^d' | grep -v '\.$' | awk '{print " - " $NF}' + else + echo " No clients found" + fi + +[group('dev-docker')] +destroy-docker-server: + #!/bin/bash + set -eou pipefail + echo "Stopping and removing SyftBox Docker containers..." + cd docker && docker-compose down -v + echo "Removing Docker images..." + docker rmi syftbox-server syftbox-client 2>/dev/null || true + echo "Docker environment cleaned up" + [group('dev')] test: env -i \ From 73546c6497ba1e1aff2b6c71127ff812dbe9fae1 Mon Sep 17 00:00:00 2001 From: Yash Date: Tue, 17 Jun 2025 17:05:41 +0530 Subject: [PATCH 05/19] feat(server): update perms (#12) * fix(server/explorer): serve empty dirs * refactor(server/acl): rename acl->ACL + update perms + simplified caching + fixes * feat(server/acl): add acl endpoints * feat(server): enforce perms + standardized api error * feat(client): read ignore list from file --- internal/aclspec/aclspec.go | 16 +- internal/aclspec/aclspec_test.go | 36 ++--- internal/aclspec/ruleset.go | 10 +- internal/client/sync/sync_ignore.go | 43 ++++- internal/client/sync/sync_manager.go | 4 + internal/server/acl/acl.go | 95 ++++++----- internal/server/acl/acl_test.go | 123 +++++++------- internal/server/acl/cache.go | 51 +++--- internal/server/acl/cache_test.go | 70 ++++---- internal/server/acl/level.go | 14 +- internal/server/acl/level_test.go | 89 ++++------ internal/server/acl/node.go | 153 +++++++++++------- internal/server/acl/node_test.go | 50 +++--- internal/server/acl/rule.go | 36 +++-- internal/server/acl/tree.go | 70 ++++---- .../server/acl/{debug.go => tree_debug.go} | 15 +- internal/server/acl/tree_test.go | 136 ++++++++-------- internal/server/acl/types.go | 3 +- internal/server/datasite/datasite.go | 13 +- internal/server/datasite/utils.go | 39 ++++- internal/server/handlers/acl/acl_handler.go | 43 +++++ .../server/handlers/acl/acl_handler_types.go | 16 ++ internal/server/handlers/api/codes.go | 30 ++++ internal/server/handlers/api/error.go | 12 ++ internal/server/handlers/api/response.go | 12 ++ internal/server/handlers/auth/auth_handler.go | 44 ++--- internal/server/handlers/blob/blob_handler.go | 42 ++--- .../handlers/blob/blob_handler_delete.go | 54 ++++--- .../blob/blob_handler_download_presigned.go | 70 ++++---- .../handlers/blob/blob_handler_types.go | 44 +++-- .../handlers/blob/blob_handler_upload.go | 52 +++--- .../handlers/blob/blob_handler_upload_acl.go | 101 ++++++++++++ .../blob/blob_handler_upload_presigned.go | 64 ++++---- .../handlers/datasite/datasite_handler.go | 10 -- .../handlers/explorer/explorer_handler.go | 18 +-- .../explorer/explorer_handler_types.go | 4 + internal/server/handlers/ws/ws_hub.go | 18 +-- internal/server/middlewares/jwtauth.go | 26 +-- internal/server/middlewares/ratelimiter.go | 11 +- internal/server/routes.go | 19 ++- internal/server/server.go | 3 +- internal/server/services.go | 4 +- 42 files changed, 1074 insertions(+), 689 deletions(-) rename internal/server/acl/{debug.go => tree_debug.go} (93%) create mode 100644 internal/server/handlers/acl/acl_handler.go create mode 100644 internal/server/handlers/acl/acl_handler_types.go create mode 100644 internal/server/handlers/api/codes.go create mode 100644 internal/server/handlers/api/error.go create mode 100644 internal/server/handlers/api/response.go create mode 100644 internal/server/handlers/blob/blob_handler_upload_acl.go diff --git a/internal/aclspec/aclspec.go b/internal/aclspec/aclspec.go index 628362d7..6558fdfa 100644 --- a/internal/aclspec/aclspec.go +++ b/internal/aclspec/aclspec.go @@ -14,28 +14,28 @@ const ( UnsetTerminal = false ) -// IsAclFile checks if the path is an syft.pub.yaml file -func IsAclFile(path string) bool { +// IsACLFile checks if the path is an syft.pub.yaml file +func IsACLFile(path string) bool { return strings.HasSuffix(path, AclFileName) } -// AsAclPath converts any path to exact acl file path -func AsAclPath(path string) string { - if IsAclFile(path) { +// AsACLPath converts any path to exact acl file path +func AsACLPath(path string) string { + if IsACLFile(path) { return path } return filepath.Join(path, AclFileName) } -// WithoutAclPath truncates syft.pub.yaml from the path -func WithoutAclPath(path string) string { +// WithoutACLPath truncates syft.pub.yaml from the path +func WithoutACLPath(path string) string { return strings.TrimSuffix(path, AclFileName) } // Exists checks if the ACL file exists at the given path // For security reasons, symlinks are not allowed as ACL files func Exists(path string) bool { - aclPath := AsAclPath(path) + aclPath := AsACLPath(path) stat, err := os.Lstat(aclPath) // Use Lstat to not follow symlinks if os.IsNotExist(err) { return false diff --git a/internal/aclspec/aclspec_test.go b/internal/aclspec/aclspec_test.go index a0f63e77..1199325b 100644 --- a/internal/aclspec/aclspec_test.go +++ b/internal/aclspec/aclspec_test.go @@ -9,7 +9,7 @@ import ( "github.com/stretchr/testify/require" ) -func TestIsAclFile(t *testing.T) { +func TestIsACLFile(t *testing.T) { // Test the core ACL file detection logic // This is critical for the system to recognize permission files correctly testCases := []struct { @@ -66,13 +66,13 @@ func TestIsAclFile(t *testing.T) { for _, tc := range testCases { t.Run(tc.desc, func(t *testing.T) { - result := IsAclFile(tc.path) + result := IsACLFile(tc.path) assert.Equal(t, tc.expected, result, "Path: %s", tc.path) }) } } -func TestAsAclPath(t *testing.T) { +func TestAsACLPath(t *testing.T) { // Test converting directory paths to ACL file paths // This ensures the system can correctly locate ACL files for directories testCases := []struct { @@ -114,13 +114,13 @@ func TestAsAclPath(t *testing.T) { for _, tc := range testCases { t.Run(tc.desc, func(t *testing.T) { - result := AsAclPath(tc.input) + result := AsACLPath(tc.input) assert.Equal(t, tc.expected, result, "Input: %s", tc.input) }) } } -func TestWithoutAclPath(t *testing.T) { +func TestWithoutACLPath(t *testing.T) { // Test removing ACL filename from paths // This is used to get the directory path from an ACL file path testCases := []struct { @@ -162,7 +162,7 @@ func TestWithoutAclPath(t *testing.T) { for _, tc := range testCases { t.Run(tc.desc, func(t *testing.T) { - result := WithoutAclPath(tc.input) + result := WithoutACLPath(tc.input) assert.Equal(t, tc.expected, result, "Input: %s", tc.input) }) } @@ -298,15 +298,15 @@ rules: assert.Error(t, err, "LoadFromFile() should reject symlinked ACL files") assert.Contains(t, err.Error(), "symlinks are not allowed", "Error should mention symlink restriction") - // 3. AsAclPath and IsAclFile should work normally (they don't check file existence) - aclPath := AsAclPath(symlinkDir) - assert.Equal(t, symlinkAclFile, aclPath, "AsAclPath should work normally") - assert.True(t, IsAclFile(aclPath), "IsAclFile should work normally") + // 3. AsACLPath and IsACLFile should work normally (they don't check file existence) + aclPath := AsACLPath(symlinkDir) + assert.Equal(t, symlinkAclFile, aclPath, "AsACLPath should work normally") + assert.True(t, IsACLFile(aclPath), "IsACLFile should work normally") - // 4. Test that WithoutAclPath works normally - dirPath := WithoutAclPath(aclPath) - // Note: WithoutAclPath may add trailing separator, so we use filepath.Clean for comparison - assert.Equal(t, symlinkDir, filepath.Clean(dirPath), "WithoutAclPath should work normally") + // 4. Test that WithoutACLPath works normally + dirPath := WithoutACLPath(aclPath) + // Note: WithoutACLPath may add trailing separator, so we use filepath.Clean for comparison + assert.Equal(t, symlinkDir, filepath.Clean(dirPath), "WithoutACLPath should work normally") // Verify that regular files still work regularDir := filepath.Join(tempDir, "regular") @@ -340,18 +340,18 @@ func TestPathEdgeCases(t *testing.T) { // Test with paths containing special characters specialPath := "/home/user with spaces/syft.pub.yaml" - assert.True(t, IsAclFile(specialPath), "Paths with spaces should be handled correctly") + assert.True(t, IsACLFile(specialPath), "Paths with spaces should be handled correctly") // Test with Unicode characters unicodePath := "/home/用户/syft.pub.yaml" - assert.True(t, IsAclFile(unicodePath), "Unicode paths should be handled correctly") + assert.True(t, IsACLFile(unicodePath), "Unicode paths should be handled correctly") // Test with multiple slashes multiSlashPath := "/home//user///syft.pub.yaml" expected := "/home//user///syft.pub.yaml" - assert.Equal(t, expected, AsAclPath(multiSlashPath), "Multiple slashes should be preserved") + assert.Equal(t, expected, AsACLPath(multiSlashPath), "Multiple slashes should be preserved") // Test with backslashes (Windows-style paths) windowsPath := "C:\\Users\\user\\syft.pub.yaml" - assert.True(t, IsAclFile(windowsPath), "Windows-style paths should be detected") + assert.True(t, IsACLFile(windowsPath), "Windows-style paths should be detected") } \ No newline at end of file diff --git a/internal/aclspec/ruleset.go b/internal/aclspec/ruleset.go index 57ad6e41..eedf09b2 100644 --- a/internal/aclspec/ruleset.go +++ b/internal/aclspec/ruleset.go @@ -19,7 +19,7 @@ type RuleSet struct { // NewRuleSet creates a new RuleSet instance with the given path, terminal flag, and initial rules. func NewRuleSet(path string, terminal bool, rules ...*Rule) *RuleSet { return &RuleSet{ - Path: WithoutAclPath(path), + Path: WithoutACLPath(path), Terminal: terminal, Rules: rules, } @@ -32,7 +32,7 @@ func (r *RuleSet) AllRules() []*Rule { // LoadFromFile loads a RuleSet from the specified file path // For security reasons, symlinks are not allowed as ACL files func LoadFromFile(path string) (*RuleSet, error) { - aclPath := AsAclPath(path) + aclPath := AsACLPath(path) // Check if file is a symlink before opening stat, err := os.Lstat(aclPath) @@ -55,7 +55,7 @@ func LoadFromFile(path string) (*RuleSet, error) { // LoadFromReader creates a RuleSet by reading and parsing YAML content from the provided reader. // The path parameter is used to set the internal path of the RuleSet. -func LoadFromReader(path string, reader io.ReadCloser) (*RuleSet, error) { +func LoadFromReader(path string, reader io.Reader) (*RuleSet, error) { data, err := io.ReadAll(reader) if err != nil { return nil, err @@ -66,12 +66,12 @@ func LoadFromReader(path string, reader io.ReadCloser) (*RuleSet, error) { return nil, err } - ruleset.Path = WithoutAclPath(path) + ruleset.Path = WithoutACLPath(path) return setDefaults(&ruleset) } func (r *RuleSet) Save() error { - aclPath := AsAclPath(r.Path) + aclPath := AsACLPath(r.Path) file, err := os.Create(aclPath) if err != nil { return fmt.Errorf("failed to create file %s: %w", r.Path, err) diff --git a/internal/client/sync/sync_ignore.go b/internal/client/sync/sync_ignore.go index 898a815d..df39c324 100644 --- a/internal/client/sync/sync_ignore.go +++ b/internal/client/sync/sync_ignore.go @@ -1,6 +1,12 @@ package sync import ( + "bufio" + "log/slog" + "os" + "path/filepath" + + "github.com/openmined/syftbox/internal/utils" gitignore "github.com/sabhiram/go-gitignore" ) @@ -37,11 +43,42 @@ type SyncIgnoreList struct { } func NewSyncIgnoreList(baseDir string) *SyncIgnoreList { - ignore := gitignore.CompileIgnoreLines(defaultIgnoreLines...) - return &SyncIgnoreList{baseDir: baseDir, ignore: ignore} + return &SyncIgnoreList{baseDir: baseDir} +} + +func (s *SyncIgnoreList) Load() { + ignorePath := filepath.Join(s.baseDir, "syftignore") + ignoreLines := defaultIgnoreLines + + // read the syftignore file if it exists + if utils.FileExists(ignorePath) { + rules := 0 + file, err := os.Open(ignorePath) + if err != nil { + slog.Warn("Failed to open syftignore file", "path", ignorePath, "error", err) + } + defer file.Close() + + scanner := bufio.NewScanner(file) + for scanner.Scan() { + line := scanner.Text() + if line != "" { + ignoreLines = append(ignoreLines, line) + rules++ + } + } + + // Check for errors during the scan + if err := scanner.Err(); err != nil { + slog.Warn("Error reading syftignore file", "path", ignorePath, "error", err) + } else { + slog.Info("Loaded syftignore file", "path", ignorePath, "rules", rules) + } + } + + s.ignore = gitignore.CompileIgnoreLines(ignoreLines...) } func (s *SyncIgnoreList) ShouldIgnore(path string) bool { - // todo strip baseDir from relPath return s.ignore.MatchesPath(path) } diff --git a/internal/client/sync/sync_manager.go b/internal/client/sync/sync_manager.go index 31e5c64e..61782c2e 100644 --- a/internal/client/sync/sync_manager.go +++ b/internal/client/sync/sync_manager.go @@ -39,6 +39,10 @@ func NewManager(workspace *workspace.Workspace, sdk *syftsdk.SyftSDK) (*SyncMana func (m *SyncManager) Start(ctx context.Context) error { slog.Info("sync manager start") + + // load the ignore list + m.ignore.Load() + if err := m.watcher.Start(ctx); err != nil { return fmt.Errorf("failed to start watcher: %w", err) } diff --git a/internal/server/acl/acl.go b/internal/server/acl/acl.go index 921dbeaa..29173288 100644 --- a/internal/server/acl/acl.go +++ b/internal/server/acl/acl.go @@ -1,47 +1,68 @@ package acl import ( + "errors" + "fmt" + "github.com/openmined/syftbox/internal/aclspec" ) -// AclService helps to manage and enforce access control rules for file system operations. -type AclService struct { - tree *Tree - cache *RuleCache +// ACLService helps to manage and enforce access control rules for file system operations. +type ACLService struct { + tree *ACLTree + cache *ACLCache } -// NewAclService creates a new ACL service instance -func NewAclService() *AclService { - return &AclService{ - tree: NewTree(), - cache: NewRuleCache(), +// NewACLService creates a new ACL service instance +func NewACLService() *ACLService { + return &ACLService{ + tree: NewACLTree(), + cache: NewACLCache(), + } +} + +// AddRuleSet adds or updates a new set of rules to the service. +func (s *ACLService) AddRuleSet(ruleSet *aclspec.RuleSet) (ACLVersion, error) { + version, err := s.tree.AddRuleSet(ruleSet) + if err != nil { + return 0, err } + + s.cache.DeletePrefix(ruleSet.Path) + return version, nil } -func (s *AclService) LoadRuleSets(ruleSets []*aclspec.RuleSet) error { +// AddRuleSets adds a new set of rules to the service. +func (s *ACLService) AddRuleSets(ruleSets []*aclspec.RuleSet) error { + errs := make([]error, 0) + for _, ruleSet := range ruleSets { - if err := s.tree.AddRuleSet(ruleSet); err != nil { - return err + if _, err := s.tree.AddRuleSet(ruleSet); err != nil { + errs = append(errs, err) } } - return nil -} -// AddRuleSet adds a new set of rules to the service. -func (s *AclService) AddRuleSet(ruleSet *aclspec.RuleSet) error { - return s.tree.AddRuleSet(ruleSet) + if len(errs) > 0 { + return fmt.Errorf("failed to add rule sets: %w", errors.Join(errs...)) + } + + return nil } // RemoveRuleSet removes a ruleset at the specified path. // Returns true if a ruleset was removed, false otherwise. -func (s *AclService) RemoveRuleSet(path string) bool { - s.cache.DeletePrefix(path) - return s.tree.RemoveRuleSet(path) +// path must be a dir or dir/syft.pub.yaml +func (s *ACLService) RemoveRuleSet(path string) bool { + path = aclspec.WithoutACLPath(path) + if ok := s.tree.RemoveRuleSet(path); ok { + s.cache.DeletePrefix(path) + return true + } + return false } -// GetRule finds the most specific rule applicable to the given path. -func (s *AclService) GetRule(path string) (*Rule, error) { - // Normalize path to use forward slashes for glob matching +// GetEffectiveRule finds the most specific rule applicable to the given path. +func (s *ACLService) GetEffectiveRule(path string) (*ACLRule, error) { path = ACLNormPath(path) // cache hit @@ -51,9 +72,9 @@ func (s *AclService) GetRule(path string) (*Rule, error) { } // cache miss - rule, err := s.tree.GetRule(path) // O(depth) + rule, err := s.tree.GetEffectiveRule(path) // O(depth) if err != nil { - return nil, err + return nil, fmt.Errorf("no effective rules for path '%s': %w", path, err) } // cache the result @@ -63,36 +84,38 @@ func (s *AclService) GetRule(path string) (*Rule, error) { } // CanAccess checks if a user has the specified access permission for a file. -func (s *AclService) CanAccess(user *User, file *File, level AccessLevel) error { - if user.IsOwner { - return nil - } - - rule, err := s.GetRule(file.Path) +func (s *ACLService) CanAccess(user *User, file *File, level AccessLevel) error { + // get the effective rule for the file + rule, err := s.GetEffectiveRule(file.Path) if err != nil { return err } - isAcl := aclspec.IsAclFile(file.Path) + // early return if user is the owner + if rule.Owner() == user.ID { + return nil + } // elevate action for ACL files + isAcl := aclspec.IsACLFile(file.Path) if isAcl && level == AccessWrite { - level = AccessWriteACL + level = AccessAdmin } else if level == AccessWrite { // writes need to be checked against the file limits if err := rule.CheckLimits(file); err != nil { - return err + return fmt.Errorf("file limits exceeded for user '%s' on path '%s': %w", user.ID, file.Path, err) } } + // finally check the access if err := rule.CheckAccess(user, level); err != nil { - return err + return fmt.Errorf("access denied for user '%s' on path '%s': %w", user.ID, file.Path, err) } return nil } // String returns a string representation of the ACL service's rule tree. -func (s *AclService) String() string { +func (s *ACLService) String() string { return s.tree.String() } diff --git a/internal/server/acl/acl_test.go b/internal/server/acl/acl_test.go index d46a1fc9..fda0a6c8 100644 --- a/internal/server/acl/acl_test.go +++ b/internal/server/acl/acl_test.go @@ -8,7 +8,7 @@ import ( ) func TestAclServiceGetRule(t *testing.T) { - service := NewAclService() + service := NewACLService() // Add a ruleset ruleset := aclspec.NewRuleSet( @@ -18,71 +18,74 @@ func TestAclServiceGetRule(t *testing.T) { aclspec.NewRule("*.txt", aclspec.PrivateAccess(), aclspec.DefaultLimits()), ) - err := service.AddRuleSet(ruleset) + ver, err := service.AddRuleSet(ruleset) assert.NoError(t, err) + assert.Equal(t, ACLVersion(1), ver) // Test cache miss rules assert.NotContains(t, service.cache.index, "user/readme.md") - rule, err := service.GetRule("user/readme.md") + rule, err := service.GetEffectiveRule("user/readme.md") assert.NoError(t, err) assert.NotNil(t, rule) assert.Equal(t, "*.md", rule.rule.Pattern) // test cache hit assert.Contains(t, service.cache.index, "user/readme.md") - rule, err = service.GetRule("user/readme.md") + rule, err = service.GetEffectiveRule("user/readme.md") assert.NoError(t, err) assert.NotNil(t, rule) assert.Equal(t, "*.md", rule.rule.Pattern) - rule, err = service.GetRule("user/notes.txt") + rule, err = service.GetEffectiveRule("user/notes.txt") assert.NoError(t, err) assert.NotNil(t, rule) assert.Equal(t, "*.txt", rule.rule.Pattern) } func TestAclServiceRemoveRuleSet(t *testing.T) { - service := NewAclService() + service := NewACLService() // Add two rulesets ruleset1 := aclspec.NewRuleSet( - "folder1", + "user1@email.com", aclspec.SetTerminal, aclspec.NewRule("*.txt", aclspec.PublicReadAccess(), aclspec.DefaultLimits()), ) ruleset2 := aclspec.NewRuleSet( - "folder2", + "user2@email.com", aclspec.SetTerminal, aclspec.NewRule("*.txt", aclspec.PrivateAccess(), aclspec.DefaultLimits()), ) - err := service.AddRuleSet(ruleset1) + ver, err := service.AddRuleSet(ruleset1) assert.NoError(t, err) + assert.Equal(t, ACLVersion(1), ver) - err = service.AddRuleSet(ruleset2) + ver, err = service.AddRuleSet(ruleset2) assert.NoError(t, err) + assert.Equal(t, ACLVersion(1), ver) // Verify both rulesets work - rule, err := service.GetRule("folder1/file.txt") + rule, err := service.GetEffectiveRule("user1@email.com/file.txt") assert.NoError(t, err) assert.NotNil(t, rule) - rule, err = service.GetRule("folder2/file.txt") + rule, err = service.GetEffectiveRule("user2@email.com/file.txt") assert.NoError(t, err) assert.NotNil(t, rule) // Remove one ruleset - removed := service.RemoveRuleSet("folder1") + removed := service.RemoveRuleSet("user1@email.com") assert.True(t, removed) // Verify removed ruleset no longer works - rule, err = service.GetRule("folder1/file.txt") + rule, err = service.GetEffectiveRule("user1@email.com/file.txt") assert.Error(t, err) assert.Nil(t, rule) // Verify other ruleset still works - rule, err = service.GetRule("folder2/file.txt") + rule, err = service.GetEffectiveRule("user2@email.com/file.txt") assert.NoError(t, err) assert.NotNil(t, rule) @@ -92,27 +95,28 @@ func TestAclServiceRemoveRuleSet(t *testing.T) { } func TestAclServiceCanAccess(t *testing.T) { - service := NewAclService() + service := NewACLService() // Add a ruleset with different access levels ruleset := aclspec.NewRuleSet( - "user", + "user1@email.com", aclspec.SetTerminal, aclspec.NewRule("public/*.txt", aclspec.PublicReadAccess(), aclspec.DefaultLimits()), aclspec.NewRule("private/*.txt", aclspec.PrivateAccess(), aclspec.DefaultLimits()), aclspec.NewRule("**", aclspec.PrivateAccess(), aclspec.DefaultLimits()), ) - err := service.AddRuleSet(ruleset) + ver, err := service.AddRuleSet(ruleset) assert.NoError(t, err) + assert.Equal(t, ACLVersion(1), ver) // Test cases with different users and files - owner := &User{ID: "user", IsOwner: true} - regularUser := &User{ID: aclspec.Everyone, IsOwner: false} + owner := &User{ID: "user1@email.com"} + regularUser := &User{ID: aclspec.Everyone} - publicFile := &File{Path: "user/public/doc.txt", Size: 100} - privateFile := &File{Path: "user/private/secret.txt", Size: 100} - aclFile := &File{Path: aclspec.AsAclPath("user"), Size: 100} + publicFile := &File{Path: "user1@email.com/public/doc.txt", Size: 100} + privateFile := &File{Path: "user1@email.com/private/secret.txt", Size: 100} + aclFile := &File{Path: aclspec.AsACLPath("user1@email.com"), Size: 100} // Owner should have access to everything err = service.CanAccess(owner, publicFile, AccessRead) @@ -129,119 +133,120 @@ func TestAclServiceCanAccess(t *testing.T) { assert.NoError(t, err) err = service.CanAccess(regularUser, publicFile, AccessWrite) - assert.ErrorIs(t, err, ErrWriteRequired) + assert.ErrorIs(t, err, ErrNoWriteAccess) err = service.CanAccess(regularUser, privateFile, AccessRead) - assert.ErrorIs(t, err, ErrReadRequired) + assert.ErrorIs(t, err, ErrNoReadAccess) // ACL files should have special handling err = service.CanAccess(regularUser, aclFile, AccessWrite) - assert.ErrorIs(t, err, ErrAdminRequired) + assert.ErrorIs(t, err, ErrNoAdminAccess) } func TestAclServiceFileLimits(t *testing.T) { - service := NewAclService() + service := NewACLService() - // Add a ruleset with file size limits - limits := aclspec.Limits{ - MaxFileSize: 100, - AllowDirs: true, - } + owner := "user1@email.com" + someUser := "user2@email.com" ruleset := aclspec.NewRuleSet( - "files", + owner, aclspec.SetTerminal, - aclspec.NewRule("small/*.txt", aclspec.PublicReadWriteAccess(), &limits), + aclspec.NewRule( + "dir/*.txt", + aclspec.PublicReadWriteAccess(), + &aclspec.Limits{MaxFileSize: 100, AllowDirs: true}, + ), ) - err := service.AddRuleSet(ruleset) + ver, err := service.AddRuleSet(ruleset) assert.NoError(t, err) - - regularUser := &User{IsOwner: false} + assert.Equal(t, ACLVersion(1), ver) // File within size limit - smallFile := &File{Path: "files/small/small.txt", Size: 50} - err = service.CanAccess(regularUser, smallFile, AccessWrite) + smallFile := &File{Path: "user1@email.com/dir/small.txt", Size: 50} + err = service.CanAccess(&User{ID: someUser}, smallFile, AccessWrite) assert.NoError(t, err) // File exceeding size limit - largeFile := &File{Path: "files/small/large.txt", Size: 200} - err = service.CanAccess(regularUser, largeFile, AccessWrite) + largeFile := &File{Path: "user1@email.com/dir/large.txt", Size: 200} + err = service.CanAccess(&User{ID: someUser}, largeFile, AccessWrite) assert.ErrorIs(t, err, ErrFileSizeExceeded) // Owner should bypass size limits - owner := &User{IsOwner: true} - err = service.CanAccess(owner, largeFile, AccessWrite) + err = service.CanAccess(&User{ID: owner}, largeFile, AccessWrite) assert.NoError(t, err) } func TestAclServiceLoadRuleSets(t *testing.T) { - service := NewAclService() + service := NewACLService() // Create multiple rulesets ruleset1 := aclspec.NewRuleSet( - "folder1", + "user1@email.com", aclspec.SetTerminal, aclspec.NewRule("*.txt", aclspec.PublicReadAccess(), aclspec.DefaultLimits()), ) ruleset2 := aclspec.NewRuleSet( - "folder2", + "user2@email.com", aclspec.SetTerminal, aclspec.NewRule("*.md", aclspec.PrivateAccess(), aclspec.DefaultLimits()), ) // Load multiple rulesets at once - err := service.LoadRuleSets([]*aclspec.RuleSet{ruleset1, ruleset2}) + err := service.AddRuleSets([]*aclspec.RuleSet{ruleset1, ruleset2}) assert.NoError(t, err) // Verify both rulesets work - rule, err := service.GetRule("folder1/file.txt") + rule, err := service.GetEffectiveRule("user1@email.com/file.txt") assert.NoError(t, err) assert.NotNil(t, rule) assert.Equal(t, "*.txt", rule.rule.Pattern) - rule, err = service.GetRule("folder2/file.md") + rule, err = service.GetEffectiveRule("user2@email.com/file.md") assert.NoError(t, err) assert.NotNil(t, rule) assert.Equal(t, "*.md", rule.rule.Pattern) } func TestAclServiceCacheInvalidation(t *testing.T) { - service := NewAclService() + service := NewACLService() // Add a ruleset rulesetv1 := aclspec.NewRuleSet( - "user", + "user1@email.com", aclspec.UnsetTerminal, aclspec.NewRule("*.md", aclspec.PublicReadAccess(), aclspec.DefaultLimits()), ) - err := service.AddRuleSet(rulesetv1) + ver, err := service.AddRuleSet(rulesetv1) assert.NoError(t, err) + assert.Equal(t, ACLVersion(1), ver) // Access a path to cache the rule - rule, err := service.GetRule("user/readme.md") + rule, err := service.GetEffectiveRule("user1@email.com/readme.md") assert.NoError(t, err) assert.NotNil(t, rule) assert.Equal(t, "*.md", rule.rule.Pattern) - assert.Contains(t, service.cache.index, "user/readme.md") + assert.Contains(t, service.cache.index, "user1@email.com/readme.md") // Replace the ruleset with different permissions rulesetv2 := aclspec.NewRuleSet( - "user", + "user1@email.com", aclspec.SetTerminal, aclspec.NewRule("*.md", aclspec.PrivateAccess(), aclspec.DefaultLimits()), ) // Add new ruleset - err = service.AddRuleSet(rulesetv2) + ver, err = service.AddRuleSet(rulesetv2) assert.NoError(t, err) + assert.Equal(t, ACLVersion(2), ver) // Access the same path, should get the new rule - rule, err = service.GetRule("user/readme.md") + rule, err = service.GetEffectiveRule("user1@email.com/readme.md") assert.NoError(t, err) assert.NotNil(t, rule) - assert.True(t, rule.node.IsTerminal()) - assert.Equal(t, rule.node.Version(), uint8(2)) + assert.True(t, rule.node.GetTerminal()) + assert.Equal(t, rule.node.GetVersion(), ACLVersion(2)) } diff --git a/internal/server/acl/cache.go b/internal/server/acl/cache.go index ec906bbf..e695ed5b 100644 --- a/internal/server/acl/cache.go +++ b/internal/server/acl/cache.go @@ -1,62 +1,58 @@ package acl import ( + "log/slog" "strings" "sync" ) -type cacheEntry struct { - rule *Rule - version uint8 -} - -type RuleCache struct { - index map[string]*cacheEntry +// ACLCache stores the effective ACL rule for a given path. +type ACLCache struct { + index map[string]*ACLRule // path -> ACLRule mu sync.RWMutex } -func NewRuleCache() *RuleCache { - return &RuleCache{ - index: make(map[string]*cacheEntry), +// NewACLCache creates a new ACLCache. +func NewACLCache() *ACLCache { + return &ACLCache{ + index: make(map[string]*ACLRule), } } -func (c *RuleCache) Get(path string) *Rule { +// Get returns the effective ACL rule for the given path. +func (c *ACLCache) Get(path string) *ACLRule { c.mu.RLock() - cached, ok := c.index[path] + cacheRule, ok := c.index[path] c.mu.RUnlock() - if !ok { - return nil - } - // validate the cache entry - valid := cached.rule.node.Version() == cached.version - if !valid { - c.Delete(path) + if !ok { return nil } - return cached.rule + return cacheRule } -func (c *RuleCache) Set(path string, rule *Rule) { +// Set sets the effective ACL rule for the given path. +func (c *ACLCache) Set(path string, rule *ACLRule) { c.mu.Lock() defer c.mu.Unlock() - c.index[path] = &cacheEntry{ - rule: rule, - version: rule.node.Version(), - } + c.index[path] = rule + + slog.Debug("acl cache set", "path", path, "version", rule.Version()) } -func (c *RuleCache) Delete(path string) { +// Delete deletes the effective ACL rule for the given path. +func (c *ACLCache) Delete(path string) { c.mu.Lock() defer c.mu.Unlock() delete(c.index, path) + slog.Debug("acl cache delete", "path", path) } -func (c *RuleCache) DeletePrefix(path string) { +// DeletePrefix deletes the effective ACL rule for all paths that match the given prefix. +func (c *ACLCache) DeletePrefix(path string) { c.mu.Lock() defer c.mu.Unlock() @@ -64,6 +60,7 @@ func (c *RuleCache) DeletePrefix(path string) { for k := range c.index { if strings.HasPrefix(k, path) { delete(c.index, k) + slog.Debug("acl cache prefix delete", "path", k) } } } diff --git a/internal/server/acl/cache_test.go b/internal/server/acl/cache_test.go index dde3541c..5ff55ba7 100644 --- a/internal/server/acl/cache_test.go +++ b/internal/server/acl/cache_test.go @@ -9,24 +9,24 @@ import ( "github.com/stretchr/testify/assert" ) -func TestNewRuleCache(t *testing.T) { +func TestNewACLCache(t *testing.T) { // Test creating a new cache // This validates the constructor initializes the cache correctly - cache := NewRuleCache() + cache := NewACLCache() - assert.NotNil(t, cache, "NewRuleCache should return non-nil cache") + assert.NotNil(t, cache, "NewACLCache should return non-nil cache") assert.NotNil(t, cache.index, "Cache should have initialized index map") assert.Empty(t, cache.index, "New cache should start empty") } -func TestRuleCacheBasicOperations(t *testing.T) { +func TestACLCacheBasicOperations(t *testing.T) { // Test basic cache operations: Set, Get, Delete // This validates the core cache functionality - cache := NewRuleCache() + cache := NewACLCache() // Create a mock rule for testing - mockNode := NewNode("test", false, 1) - mockRule := &Rule{ + mockNode := NewACLNode("test", "testuser", false, 1) + mockRule := &ACLRule{ fullPattern: "test/*.txt", rule: aclspec.NewRule("*.txt", aclspec.PrivateAccess(), aclspec.DefaultLimits()), node: mockNode, @@ -47,14 +47,14 @@ func TestRuleCacheBasicOperations(t *testing.T) { assert.Nil(t, result, "Get should return nil after deletion") } -func TestRuleCacheVersionValidation(t *testing.T) { +func TestACLCacheVersionValidation(t *testing.T) { // Test that cache validates rule versions to detect stale entries // This is critical for cache invalidation when rules are updated - cache := NewRuleCache() + cache := NewACLCache() // Create a node and rule - mockNode := NewNode("test", false, 1) - mockRule := &Rule{ + mockNode := NewACLNode("test", "testuser", false, 1) + mockRule := &ACLRule{ fullPattern: "test/*.txt", rule: aclspec.NewRule("*.txt", aclspec.PrivateAccess(), aclspec.DefaultLimits()), node: mockNode, @@ -72,22 +72,20 @@ func TestRuleCacheVersionValidation(t *testing.T) { aclspec.NewRule("*.md", aclspec.PublicReadAccess(), aclspec.DefaultLimits()), }, false) - // Now the cached entry should be invalid due to version mismatch + // The cache currently doesn't validate versions, so this test documents + // that the cache may return stale entries after node updates result = cache.Get("test/file.txt") - assert.Nil(t, result, "Should return nil for stale cache entry with wrong version") - - // Verify the stale entry was automatically removed - assert.NotContains(t, cache.index, "test/file.txt", "Stale entry should be removed from cache") + assert.Equal(t, mockRule, result, "Cache currently returns stale entries (no version validation)") } -func TestRuleCacheDeletePrefix(t *testing.T) { +func TestACLCacheDeletePrefix(t *testing.T) { // Test the DeletePrefix operation which removes multiple entries // This is important for bulk cache invalidation when directory rules change - cache := NewRuleCache() + cache := NewACLCache() // Create multiple cache entries with related paths - mockNode := NewNode("test", false, 1) - mockRule := &Rule{ + mockNode := NewACLNode("test", "testuser", false, 1) + mockRule := &ACLRule{ fullPattern: "test/*.txt", rule: aclspec.NewRule("*.txt", aclspec.PrivateAccess(), aclspec.DefaultLimits()), node: mockNode, @@ -126,13 +124,13 @@ func TestRuleCacheDeletePrefix(t *testing.T) { assert.NotNil(t, cache.Get("other/file.txt"), "Should keep other/file.txt") } -func TestRuleCacheDeletePrefixEdgeCases(t *testing.T) { +func TestACLCacheDeletePrefixEdgeCases(t *testing.T) { // Test DeletePrefix with edge cases and boundary conditions // This ensures robust handling of unusual prefix patterns - cache := NewRuleCache() + cache := NewACLCache() - mockNode := NewNode("test", false, 1) - mockRule := &Rule{ + mockNode := NewACLNode("test", "testuser", false, 1) + mockRule := &ACLRule{ fullPattern: "test/*.txt", rule: aclspec.NewRule("*.txt", aclspec.PrivateAccess(), aclspec.DefaultLimits()), node: mockNode, @@ -168,13 +166,13 @@ func TestRuleCacheDeletePrefixEdgeCases(t *testing.T) { // Should not crash or cause issues } -func TestRuleCacheConcurrency(t *testing.T) { +func TestACLCacheConcurrency(t *testing.T) { // Test that cache operations are thread-safe // This validates the mutex protection works correctly - cache := NewRuleCache() + cache := NewACLCache() - mockNode := NewNode("test", false, 1) - mockRule := &Rule{ + mockNode := NewACLNode("test", "testuser", false, 1) + mockRule := &ACLRule{ fullPattern: "test/*.txt", rule: aclspec.NewRule("*.txt", aclspec.PrivateAccess(), aclspec.DefaultLimits()), node: mockNode, @@ -234,13 +232,13 @@ func TestRuleCacheConcurrency(t *testing.T) { } } -func TestRuleCacheMixedConcurrentOperations(t *testing.T) { +func TestACLCacheMixedConcurrentOperations(t *testing.T) { // Test mixed concurrent operations (Set, Get, Delete, DeletePrefix) // This validates thread safety under realistic usage patterns - cache := NewRuleCache() + cache := NewACLCache() - mockNode := NewNode("test", false, 1) - mockRule := &Rule{ + mockNode := NewACLNode("test", "testuser", false, 1) + mockRule := &ACLRule{ fullPattern: "test/*.txt", rule: aclspec.NewRule("*.txt", aclspec.PrivateAccess(), aclspec.DefaultLimits()), node: mockNode, @@ -312,13 +310,13 @@ func TestRuleCacheMixedConcurrentOperations(t *testing.T) { // but the operations should all complete successfully } -func TestRuleCacheMemoryManagement(t *testing.T) { +func TestACLCacheMemoryManagement(t *testing.T) { // Test that cache doesn't leak memory with repeated operations // This validates proper cleanup of cache entries - cache := NewRuleCache() + cache := NewACLCache() - mockNode := NewNode("test", false, 1) - mockRule := &Rule{ + mockNode := NewACLNode("test", "testuser", false, 1) + mockRule := &ACLRule{ fullPattern: "test/*.txt", rule: aclspec.NewRule("*.txt", aclspec.PrivateAccess(), aclspec.DefaultLimits()), node: mockNode, diff --git a/internal/server/acl/level.go b/internal/server/acl/level.go index 8f71cfa4..ccaf89e1 100644 --- a/internal/server/acl/level.go +++ b/internal/server/acl/level.go @@ -5,25 +5,19 @@ type AccessLevel uint8 // Action constants define different types of file permissions const ( - AccessRead AccessLevel = 1 << iota - AccessCreate + AccessRead AccessLevel = iota + 1 AccessWrite - AccessReadACL - AccessWriteACL + AccessAdmin ) func (a AccessLevel) String() string { switch a { case AccessRead: return "Read" - case AccessCreate: - return "Create" case AccessWrite: return "Write" - case AccessReadACL: - return "ReadACL" - case AccessWriteACL: - return "WriteACL" + case AccessAdmin: + return "Admin" default: return "Unknown" } diff --git a/internal/server/acl/level_test.go b/internal/server/acl/level_test.go index 935d2742..b9534be2 100644 --- a/internal/server/acl/level_test.go +++ b/internal/server/acl/level_test.go @@ -20,25 +20,25 @@ func TestAccessLevelString(t *testing.T) { expected: "Read", desc: "AccessRead should return 'Read'", }, - { - level: AccessCreate, - expected: "Create", - desc: "AccessCreate should return 'Create'", - }, { level: AccessWrite, expected: "Write", desc: "AccessWrite should return 'Write'", }, { - level: AccessReadACL, - expected: "ReadACL", - desc: "AccessReadACL should return 'ReadACL'", + level: AccessAdmin, + expected: "Admin", + desc: "AccessAdmin should return 'Admin'", }, { - level: AccessWriteACL, - expected: "WriteACL", - desc: "AccessWriteACL should return 'WriteACL'", + level: 0, + expected: "Unknown", + desc: "Zero value should return 'Unknown'", + }, + { + level: AccessLevel(10), + expected: "Unknown", + desc: "Undefined values should return 'Unknown'", }, } @@ -58,16 +58,13 @@ func TestAccessLevelStringUnknown(t *testing.T) { assert.Equal(t, "Unknown", result, "Unknown access levels should return 'Unknown'") } -func TestAccessLevelBitFlags(t *testing.T) { - // Test that AccessLevel constants are properly defined as bit flags - // This validates the bit flag implementation which allows for efficient permission checking +func TestAccessLevelValues(t *testing.T) { + // Test that AccessLevel constants have the expected values + // Since iota starts at 0 and we use iota + 1, values should be 1, 2, 3 - // Verify each level has a unique bit pattern - assert.Equal(t, AccessLevel(1), AccessRead, "AccessRead should be bit 0 (value 1)") - assert.Equal(t, AccessLevel(2), AccessCreate, "AccessCreate should be bit 1 (value 2)") - assert.Equal(t, AccessLevel(4), AccessWrite, "AccessWrite should be bit 2 (value 4)") - assert.Equal(t, AccessLevel(8), AccessReadACL, "AccessReadACL should be bit 3 (value 8)") - assert.Equal(t, AccessLevel(16), AccessWriteACL, "AccessWriteACL should be bit 4 (value 16)") + assert.Equal(t, AccessLevel(1), AccessRead, "AccessRead should be 1") + assert.Equal(t, AccessLevel(2), AccessWrite, "AccessWrite should be 2") + assert.Equal(t, AccessLevel(3), AccessAdmin, "AccessAdmin should be 3") } func TestAccessLevelUniqueness(t *testing.T) { @@ -75,10 +72,8 @@ func TestAccessLevelUniqueness(t *testing.T) { // This prevents accidental duplicate values that could cause permission conflicts levels := []AccessLevel{ AccessRead, - AccessCreate, AccessWrite, - AccessReadACL, - AccessWriteACL, + AccessAdmin, } // Check that no two levels have the same value @@ -93,36 +88,14 @@ func TestAccessLevelUniqueness(t *testing.T) { } } -func TestAccessLevelBitOperations(t *testing.T) { - // Test that bit operations work correctly with AccessLevel flags - // This validates that the bit flag design allows for combining permissions - - // Test combining permissions with OR - combined := AccessRead | AccessWrite - assert.NotEqual(t, AccessRead, combined, "Combined permissions should differ from individual permissions") - assert.NotEqual(t, AccessWrite, combined, "Combined permissions should differ from individual permissions") - - // Test checking individual permissions with AND - assert.Equal(t, AccessRead, combined&AccessRead, "Should be able to check for read permission in combined flags") - assert.Equal(t, AccessWrite, combined&AccessWrite, "Should be able to check for write permission in combined flags") - assert.Equal(t, AccessLevel(0), combined&AccessCreate, "Should not find create permission in read+write combination") -} - func TestAccessLevelHierarchy(t *testing.T) { // Test the logical hierarchy of access levels // This documents the intended permission hierarchy in the system - // Basic file operations should have lower bit values than ACL operations - assert.True(t, AccessRead < AccessReadACL, "Read should have lower value than ReadACL") - assert.True(t, AccessWrite < AccessWriteACL, "Write should have lower value than WriteACL") - - // Within basic operations, read should be the lowest level - assert.True(t, AccessRead < AccessCreate, "Read should be the most basic permission") - assert.True(t, AccessRead < AccessWrite, "Read should be lower than write") - - // ACL operations should be the highest levels - assert.True(t, AccessReadACL > AccessWrite, "ReadACL should be higher than basic write") - assert.True(t, AccessWriteACL > AccessReadACL, "WriteACL should be the highest permission") + // Verify the ordering based on iota values + assert.True(t, AccessRead < AccessWrite, "Read should be lower than Write") + assert.True(t, AccessWrite < AccessAdmin, "Write should be lower than Admin") + assert.True(t, AccessRead < AccessAdmin, "Read should be lower than Admin") } func TestAccessLevelZeroValue(t *testing.T) { @@ -131,14 +104,12 @@ func TestAccessLevelZeroValue(t *testing.T) { var zeroLevel AccessLevel assert.Equal(t, AccessLevel(0), zeroLevel, "Zero value should be 0") - assert.Equal(t, "Unknown", zeroLevel.String(), "Zero value should be treated as unknown") + assert.Equal(t, "Unknown", zeroLevel.String(), "Zero value should return 'Unknown'") // Zero should not match any defined permission assert.NotEqual(t, AccessRead, zeroLevel, "Zero should not equal AccessRead") - assert.NotEqual(t, AccessCreate, zeroLevel, "Zero should not equal AccessCreate") assert.NotEqual(t, AccessWrite, zeroLevel, "Zero should not equal AccessWrite") - assert.NotEqual(t, AccessReadACL, zeroLevel, "Zero should not equal AccessReadACL") - assert.NotEqual(t, AccessWriteACL, zeroLevel, "Zero should not equal AccessWriteACL") + assert.NotEqual(t, AccessAdmin, zeroLevel, "Zero should not equal AccessAdmin") } func TestAccessLevelCasting(t *testing.T) { @@ -154,7 +125,7 @@ func TestAccessLevelCasting(t *testing.T) { assert.Equal(t, uint8(1), readValue, "Should be able to cast AccessLevel to uint8") // Test round-trip casting - originalLevel := AccessWriteACL + originalLevel := AccessAdmin castValue := uint8(originalLevel) backToLevel := AccessLevel(castValue) assert.Equal(t, originalLevel, backToLevel, "Round-trip casting should preserve value") @@ -183,10 +154,10 @@ func TestAccessLevelEdgeCases(t *testing.T) { assert.Equal(t, "Unknown", maxLevel.String(), "Maximum value should be handled as unknown") // Test values between defined constants - betweenLevels := AccessLevel(3) // Between AccessCreate (2) and AccessWrite (4) - assert.Equal(t, "Unknown", betweenLevels.String(), "Undefined intermediate values should be unknown") + betweenLevels := AccessLevel(4) // Just after AccessAdmin (3) + assert.Equal(t, "Unknown", betweenLevels.String(), "Undefined values should be unknown") - // Test that the bit flag pattern continues to work with undefined values - undefinedLevel := AccessLevel(32) // Next bit after AccessWriteACL (16) - assert.Equal(t, "Unknown", undefinedLevel.String(), "Higher undefined bits should be unknown") + // Test that undefined values are handled correctly + undefinedLevel := AccessLevel(10) + assert.Equal(t, "Unknown", undefinedLevel.String(), "Higher undefined values should be unknown") } \ No newline at end of file diff --git a/internal/server/acl/node.go b/internal/server/acl/node.go index 0916b1d2..c1e0e6c8 100644 --- a/internal/server/acl/node.go +++ b/internal/server/acl/node.go @@ -1,6 +1,7 @@ package acl import ( + "slices" "sort" "strings" "sync" @@ -9,90 +10,96 @@ import ( "github.com/openmined/syftbox/internal/aclspec" ) -// Node represents a node in the ACL tree. +// ACLVersion is the version of the node. +// overflow will reset it to 0. +type ACLVersion = uint16 + +// ACLDepth is the depth of the node in the tree. +type ACLDepth = uint8 + +const ( + ACLMaxDepth = 1<<8 - 1 // keep this in sync with the type ACLDepth + ACLMaxVersion = 1<<16 - 1 // keep this in sync with the type ACLVersion +) + +// ACLNode represents a node in the ACL tree. // Each node corresponds to a part of the path and contains rules for that part. -type Node struct { +type ACLNode struct { mu sync.RWMutex - rules []*Rule // rules for this part of the path. sorted by specificity. - path string // path is the full path to this node - children map[string]*Node // key is the part of the path - terminal bool // true if this node is a terminal node - depth uint8 // depth of the node in the tree. 0 is root node - version uint8 // version of the node. incremented on every change + rules []*ACLRule // rules for this part of the path. sorted by specificity. + children map[string]*ACLNode // key is the part of the path + owner string // owner of the node + path string // path is the full path to this Anode + terminal bool // true if this node is a terminal node + depth ACLDepth // depth of the node in the tree. 0 is root node + version ACLVersion // version of the node. incremented on every change } -func NewNode(path string, terminal bool, depth uint8) *Node { - return &Node{ +// NewACLNode creates a new ACLNode. +func NewACLNode(path string, owner string, terminal bool, depth ACLDepth) *ACLNode { + // note rules & children are not initialized here. + // this is to avoid unnecessary allocations, until the node is set with rules. + return &ACLNode{ path: path, + owner: owner, terminal: terminal, depth: depth, + version: 0, } } -func (n *Node) Version() uint8 { - n.mu.RLock() - defer n.mu.RUnlock() - return n.version -} - -func (n *Node) IsTerminal() bool { - n.mu.RLock() - defer n.mu.RUnlock() - return n.terminal -} - -func (n *Node) Depth() uint8 { - n.mu.RLock() - defer n.mu.RUnlock() - return n.depth -} - -func (n *Node) Rules() []*Rule { +// GetChild returns the child for the node. +func (n *ACLNode) GetChild(key string) (*ACLNode, bool) { n.mu.RLock() defer n.mu.RUnlock() - return n.rules + child, exists := n.children[key] + return child, exists } -func (n *Node) SetChild(key string, child *Node) { +// SetChild sets the child for the node. +func (n *ACLNode) SetChild(key string, child *ACLNode) { n.mu.Lock() defer n.mu.Unlock() if n.children == nil { - n.children = make(map[string]*Node) + n.children = make(map[string]*ACLNode) } if child == nil { delete(n.children, key) } else { n.children[key] = child } + n.version++ } -func (n *Node) GetChild(key string) (*Node, bool) { - n.mu.RLock() - defer n.mu.RUnlock() - child, exists := n.children[key] - return child, exists -} - -func (n *Node) DeleteChild(key string) { +// DeleteChild deletes the child for the node. +func (n *ACLNode) DeleteChild(key string) { n.mu.Lock() defer n.mu.Unlock() delete(n.children, key) + n.version++ +} + +// GetRules returns the rules for the node. +func (n *ACLNode) GetRules() []*ACLRule { + n.mu.RLock() + defer n.mu.RUnlock() + return n.rules } // SetRules the rules, terminal flag and depth for the node. // Increments the version counter for repeated operation. -func (n *Node) SetRules(rules []*aclspec.Rule, terminal bool) { +func (n *ACLNode) SetRules(rules []*aclspec.Rule, terminal bool) { n.mu.Lock() defer n.mu.Unlock() if len(rules) > 0 { // pre-sort the rules by specificity - sorted := sortBySpecificity(rules) + sorted := sortRulesBySpecificity(rules) // convert the rules to aclRules - aclRules := make([]*Rule, 0, len(sorted)) + aclRules := make([]*ACLRule, 0, len(sorted)) for _, rule := range sorted { - aclRules = append(aclRules, &Rule{ + aclRules = append(aclRules, &ACLRule{ rule: rule, node: n, fullPattern: ACLJoinPath(n.path, rule.Pattern), @@ -107,12 +114,20 @@ func (n *Node) SetRules(rules []*aclspec.Rule, terminal bool) { // set the rules and terminal flag n.terminal = terminal - // increment the version. uint8 overflow will reset it to 0. + // increment the version + n.version++ +} + +// ClearRules clears the rules for the node. +func (n *ACLNode) ClearRules() { + n.mu.Lock() + defer n.mu.Unlock() + n.rules = nil n.version++ } // FindBestRule finds the best matching rule for the given path. -func (n *Node) FindBestRule(path string) (*Rule, error) { +func (n *ACLNode) FindBestRule(path string) (*ACLRule, error) { n.mu.RLock() defer n.mu.RUnlock() @@ -130,15 +145,43 @@ func (n *Node) FindBestRule(path string) (*Rule, error) { return nil, ErrNoRuleFound } +// GetOwner returns the owner of the node. +func (n *ACLNode) GetOwner() string { + n.mu.RLock() + defer n.mu.RUnlock() + return n.owner +} + +// GetVersion returns the version of the node. +func (n *ACLNode) GetVersion() ACLVersion { + n.mu.RLock() + defer n.mu.RUnlock() + return n.version +} + +// GetTerminal returns true if the node is a terminal node. +func (n *ACLNode) GetTerminal() bool { + n.mu.RLock() + defer n.mu.RUnlock() + return n.terminal +} + +// GetDepth returns the depth of the node. +func (n *ACLNode) GetDepth() ACLDepth { + n.mu.RLock() + defer n.mu.RUnlock() + return n.depth +} + // Equal checks if the node is equal to another node. -func (n *Node) Equal(other *Node) bool { +func (n *ACLNode) Equal(other *ACLNode) bool { n.mu.RLock() defer n.mu.RUnlock() - return n.path == other.path && n.terminal == other.terminal && n.depth == other.depth + return n.path == other.path && n.terminal == other.terminal && n.depth == other.depth && n.version == other.version } -func globSpecificityScore(glob string) int { - // exact +func calculateGlobSpecificity(glob string) int { + // early return for the most specific glob patterns switch glob { case "**": return -100 @@ -150,6 +193,7 @@ func globSpecificityScore(glob string) int { // Use forward slash for glob patterns score := len(glob)*2 + strings.Count(glob, ACLPathSep)*10 + // penalize base score for substr wildcards for i, c := range glob { switch c { case '*': @@ -166,13 +210,14 @@ func globSpecificityScore(glob string) int { return score } -func sortBySpecificity(rules []*aclspec.Rule) []*aclspec.Rule { +func sortRulesBySpecificity(rules []*aclspec.Rule) []*aclspec.Rule { // copy the rules - clone := append([]*aclspec.Rule(nil), rules...) + clone := slices.Clone(rules) - // sort by specificity, descending + // sort by specificity (or priority), descending sort.Slice(clone, func(i, j int) bool { - return globSpecificityScore(clone[i].Pattern) > globSpecificityScore(clone[j].Pattern) + return calculateGlobSpecificity(clone[i].Pattern) > calculateGlobSpecificity(clone[j].Pattern) }) + return clone } diff --git a/internal/server/acl/node_test.go b/internal/server/acl/node_test.go index d2a3d9fb..466ff00c 100644 --- a/internal/server/acl/node_test.go +++ b/internal/server/acl/node_test.go @@ -9,7 +9,7 @@ import ( func TestNodeFindBestRule(t *testing.T) { // Create a node with some rules - node := NewNode("test", false, 1) + node := NewACLNode("some/path", "user1", false, 1) // Create test rules with different patterns rules := []*aclspec.Rule{ @@ -21,19 +21,19 @@ func TestNodeFindBestRule(t *testing.T) { node.SetRules(rules, false) // Test matching with different paths - rule, err := node.FindBestRule("test/file.txt") + rule, err := node.FindBestRule("some/path/file.txt") assert.NoError(t, err) assert.Equal(t, "*.txt", rule.rule.Pattern) - rule, err = node.FindBestRule("test/file.md") + rule, err = node.FindBestRule("some/path/file.md") assert.NoError(t, err) assert.Equal(t, "file.md", rule.rule.Pattern) - rule, err = node.FindBestRule("test/subdir/main.go") + rule, err = node.FindBestRule("some/path/subdir/main.go") assert.NoError(t, err) assert.Equal(t, "**/*.go", rule.rule.Pattern) - rule, err = node.FindBestRule("test/main.go") + rule, err = node.FindBestRule("some/path/main.go") assert.NoError(t, err) assert.Equal(t, "**/*.go", rule.rule.Pattern) @@ -48,10 +48,10 @@ func TestNodeFindBestRule(t *testing.T) { } func TestNodeSetRules(t *testing.T) { - node := NewNode("test", false, 1) + node := NewACLNode("test", "user1", false, 1) // Initial version should be 0 - assert.Equal(t, uint8(0), node.Version()) + assert.Equal(t, ACLVersion(0), node.GetVersion()) // Set rules and check that version increments rules := []*aclspec.Rule{ @@ -61,13 +61,13 @@ func TestNodeSetRules(t *testing.T) { node.SetRules(rules, true) // Version should increment - assert.Equal(t, uint8(1), node.Version()) + assert.Equal(t, ACLVersion(1), node.GetVersion()) // Terminal flag should be set - assert.True(t, node.IsTerminal()) + assert.True(t, node.GetTerminal()) // Rules should be set - assert.Len(t, node.Rules(), 1) + assert.Len(t, node.GetRules(), 1) // Set new rules and check version increments again newRules := []*aclspec.Rule{ @@ -78,21 +78,21 @@ func TestNodeSetRules(t *testing.T) { node.SetRules(newRules, false) // Version should increment - assert.Equal(t, uint8(2), node.Version()) + assert.Equal(t, ACLVersion(2), node.GetVersion()) // Terminal flag should be updated - assert.False(t, node.IsTerminal()) + assert.False(t, node.GetTerminal()) // Rules should be updated - assert.Len(t, node.Rules(), 2) + assert.Len(t, node.GetRules(), 2) } func TestNodeEqual(t *testing.T) { - node1 := NewNode("test", false, 1) - node2 := NewNode("test", false, 1) - node3 := NewNode("different", false, 1) - node4 := NewNode("test", true, 1) - node5 := NewNode("test", false, 2) + node1 := NewACLNode("some/path", "user1", false, 1) + node2 := NewACLNode("some/path", "user1", false, 1) + node3 := NewACLNode("different/path", "user1", false, 1) + node4 := NewACLNode("some/path", "user1", true, 1) + node5 := NewACLNode("some/path", "user1", false, 2) assert.True(t, node1.Equal(node2), "Identical nodes should be equal") assert.False(t, node1.Equal(node3), "Nodes with different paths should not be equal") @@ -111,7 +111,7 @@ func TestRuleSpecificity(t *testing.T) { } // Sort by specificity - sorted := sortBySpecificity(rules) + sorted := sortRulesBySpecificity(rules) // Most specific should come first, least specific last assert.Equal(t, "specific.txt", sorted[0].Pattern) @@ -134,13 +134,15 @@ func TestGlobSpecificityScore(t *testing.T) { } for _, tc := range testCases { - score := globSpecificityScore(tc.pattern) - assert.Equal(t, tc.score, score, "Specificity score for %q should be %d, got %d", tc.pattern, tc.score, score) + t.Run(tc.pattern, func(t *testing.T) { + score := calculateGlobSpecificity(tc.pattern) + assert.Equal(t, tc.score, score, "Specificity score for '%s' should be %d, got %d", tc.pattern, tc.score, score) + }) } } func TestNodeGetChild(t *testing.T) { - node := NewNode("parent", false, 1) + node := NewACLNode("path/to", "user1", false, 1) // Initially, no children child, exists := node.GetChild("child") @@ -148,13 +150,13 @@ func TestNodeGetChild(t *testing.T) { assert.Nil(t, child) // Add a child - childNode := NewNode("parent/child", false, 2) + childNode := NewACLNode("path/to/child", "user1", false, 2) node.SetChild("child", childNode) // Verify child can be retrieved child, exists = node.GetChild("child") assert.True(t, exists) - assert.Equal(t, "parent/child", child.path) + assert.Equal(t, "path/to/child", child.path) // Delete the child node.DeleteChild("child") diff --git a/internal/server/acl/rule.go b/internal/server/acl/rule.go index fbbfb8e2..f1d948c3 100644 --- a/internal/server/acl/rule.go +++ b/internal/server/acl/rule.go @@ -8,27 +8,37 @@ import ( ) var ( - ErrAdminRequired = errors.New("admin access required") - ErrWriteRequired = errors.New("write access required") - ErrReadRequired = errors.New("read access required") + ErrNoAdminAccess = errors.New("no admin access") + ErrNoWriteAccess = errors.New("no write access") + ErrNoReadAccess = errors.New("no read access") ErrDirsNotAllowed = errors.New("directories not allowed") ErrSymlinksNotAllowed = errors.New("symlinks not allowed") ErrFileSizeExceeded = errors.New("file size exceeds limits") ErrInvalidAccessLevel = errors.New("invalid access level") ) -// Rule represents an access control rule for a file or directory in an ACL Node. +// ACLRule represents an access control rule for a file or directory in an ACL Node. // It contains the full pattern of the rule, the rule itself, and the node it applies to -type Rule struct { +type ACLRule struct { fullPattern string // full pattern = full path + glob rule *aclspec.Rule // the rule itself - node *Node // the node this rule applies to + node *ACLNode // the node this rule applies to +} + +// Owner returns the owner of the rule (inherited from the node) +func (r *ACLRule) Owner() string { + return r.node.GetOwner() +} + +// Version returns the version of the rule (inherited from the node)s +func (r *ACLRule) Version() ACLVersion { + return r.node.GetVersion() } // CheckAccess checks if the user has permission to perform the specified action on the node. -func (r *Rule) CheckAccess(user *User, level AccessLevel) error { +func (r *ACLRule) CheckAccess(user *User, level AccessLevel) error { // the rule is owned by the user, so they can do anything - if user.IsOwner { + if r.Owner() == user.ID { return nil } @@ -44,19 +54,19 @@ func (r *Rule) CheckAccess(user *User, level AccessLevel) error { // Use a switch with fallthrough for permission hierarchy switch level { - case AccessWriteACL: + case AccessAdmin: if !isAdmin { - return ErrAdminRequired + return ErrNoAdminAccess } return nil case AccessWrite: if !isWriter { - return ErrWriteRequired + return ErrNoWriteAccess } return nil case AccessRead: if !isReader { - return ErrReadRequired + return ErrNoReadAccess } return nil default: @@ -65,7 +75,7 @@ func (r *Rule) CheckAccess(user *User, level AccessLevel) error { } // CheckLimits checks if the file is within the limits specified by the rule. -func (r *Rule) CheckLimits(info *File) error { +func (r *ACLRule) CheckLimits(info *File) error { limits := r.rule.Limits if limits == nil { diff --git a/internal/server/acl/tree.go b/internal/server/acl/tree.go index a7c09d84..f835ca61 100644 --- a/internal/server/acl/tree.go +++ b/internal/server/acl/tree.go @@ -3,6 +3,7 @@ package acl import ( "errors" "fmt" + "log/slog" "strings" "github.com/openmined/syftbox/internal/aclspec" @@ -11,30 +12,36 @@ import ( var ( ErrInvalidRuleset = errors.New("invalid ruleset") ErrMaxDepthExceeded = errors.New("maximum depth exceeded") - ErrNoRuleFound = errors.New("no rule found") + ErrNoRuleFound = errors.New("no rules available") ) -// Tree stores the ACL rules in a n-ary tree for efficient lookups. -type Tree struct { - root *Node +const ( + rootPath = "/" + noOwner = "" +) + +// ACLTree stores the ACL rules in a n-ary tree for efficient lookups. +type ACLTree struct { + root *ACLNode } -func NewTree() *Tree { - return &Tree{ - root: NewNode(ACLPathSep, false, 0), +// NewACLTree creates a new ACLTree. +func NewACLTree() *ACLTree { + return &ACLTree{ + root: NewACLNode(rootPath, noOwner, false, 0), } } // Add or update a ruleset in the tree. -func (t *Tree) AddRuleSet(ruleset *aclspec.RuleSet) error { +func (t *ACLTree) AddRuleSet(ruleset *aclspec.RuleSet) (ACLVersion, error) { // Validate the ruleset if ruleset == nil { - return fmt.Errorf("%w: ruleset is nil", ErrInvalidRuleset) + return 0, fmt.Errorf("%w: ruleset is nil", ErrInvalidRuleset) } allRules := ruleset.AllRules() if len(allRules) == 0 { - return fmt.Errorf("%w: ruleset is empty", ErrInvalidRuleset) + return 0, fmt.Errorf("%w: ruleset is empty", ErrInvalidRuleset) } // Clean and split the path @@ -42,9 +49,16 @@ func (t *Tree) AddRuleSet(ruleset *aclspec.RuleSet) error { parts := strings.Split(cleanPath, ACLPathSep) pathDepth := strings.Count(cleanPath, ACLPathSep) + // owner is assumed to be the first part of the path. + // but in future we can always bake it as a part of the acl schema + owner := parts[0] + if owner == "" { + return 0, fmt.Errorf("%w: owner is empty", ErrInvalidRuleset) + } + // Check path depth limit (u8) - if pathDepth > 255 { - return ErrMaxDepthExceeded + if pathDepth > ACLMaxDepth { + return 0, ErrMaxDepthExceeded } // Start at the root node @@ -59,24 +73,24 @@ func (t *Tree) AddRuleSet(ruleset *aclspec.RuleSet) error { // Get or create child node child, exists := current.GetChild(part) if !exists { - // Use forward slashes for paths fullPath := ACLJoinPath(parts[:currentDepth]...) - child = NewNode(fullPath, false, currentDepth) + child = NewACLNode(fullPath, owner, false, currentDepth) current.SetChild(part, child) } + current = child } // Set the rules on the final node current.SetRules(allRules, ruleset.Terminal) + slog.Debug("added ruleset", "path", ruleset.Path, "owner", owner, "version", current.GetVersion()) - return nil + return current.GetVersion(), nil } -// Get rule for the given path -func (t *Tree) GetRule(path string) (*Rule, error) { - - node := t.GetNearestNodeWithRules(path) // O(depth) +// GetEffectiveRule returns the most specific rule applicable to the given path. +func (t *ACLTree) GetEffectiveRule(path string) (*ACLRule, error) { + node := t.LookupNearestNode(path) // O(depth) if node == nil { return nil, ErrNoRuleFound } @@ -89,17 +103,17 @@ func (t *Tree) GetRule(path string) (*Rule, error) { return rule, nil } -// GetNearestNodeWithRules returns the nearest node in the tree that has associated rules for the given path. +// LookupNearestNode returns the nearest node in the tree that has associated rules for the given path. // It returns nil if no such node is found. -func (t *Tree) GetNearestNodeWithRules(path string) *Node { +func (t *ACLTree) LookupNearestNode(path string) *ACLNode { parts := ACLPathSegments(path) - var candidate *Node + var candidate *ACLNode current := t.root for _, part := range parts { // Stop if the current node is terminal. - if current.IsTerminal() { + if current.GetTerminal() { break } @@ -109,7 +123,7 @@ func (t *Tree) GetNearestNodeWithRules(path string) *Node { } current = child - if child.Rules() != nil { + if child.GetRules() != nil { candidate = current } } @@ -118,12 +132,12 @@ func (t *Tree) GetNearestNodeWithRules(path string) *Node { } // GetNode finds the exact node applicable for the given path. -func (t *Tree) GetNode(path string) *Node { +func (t *ACLTree) GetNode(path string) *ACLNode { parts := ACLPathSegments(path) current := t.root for _, part := range parts { - if current.IsTerminal() { + if current.GetTerminal() { break } @@ -138,8 +152,8 @@ func (t *Tree) GetNode(path string) *Node { } // Removes a ruleset at the specified path -func (t *Tree) RemoveRuleSet(path string) bool { - var parent *Node +func (t *ACLTree) RemoveRuleSet(path string) bool { + var parent *ACLNode var lastPart string parts := ACLPathSegments(path) diff --git a/internal/server/acl/debug.go b/internal/server/acl/tree_debug.go similarity index 93% rename from internal/server/acl/debug.go rename to internal/server/acl/tree_debug.go index b260d226..edf2ec40 100644 --- a/internal/server/acl/debug.go +++ b/internal/server/acl/tree_debug.go @@ -8,17 +8,19 @@ import ( ) // String implements the Stringer interface for PTree -func (t *Tree) String() string { +func (t *ACLTree) String() string { + var sb strings.Builder + if t.root == nil { return "" } - var sb strings.Builder + t.root.buildString(&sb, "", true, true) return sb.String() } // buildString recursively builds the string representation of the tree -func (n *Node) buildString(sb *strings.Builder, prefix string, isLast bool, isRoot bool) { +func (n *ACLNode) buildString(sb *strings.Builder, prefix string, isLast bool, isRoot bool) { n.mu.RLock() defer n.mu.RUnlock() @@ -27,6 +29,7 @@ func (n *Node) buildString(sb *strings.Builder, prefix string, isLast bool, isRo if !isLast { marker = "├── " } + sb.WriteString(prefix) sb.WriteString(marker) } @@ -35,16 +38,19 @@ func (n *Node) buildString(sb *strings.Builder, prefix string, isLast bool, isRo sb.WriteString(filepath.Base(n.path)) sb.WriteString(fmt.Sprintf(" (d:%d, v:%d", n.depth, n.version)) if len(n.rules) > 0 { - sb.WriteString(fmt.Sprintf(", rules:%d", len(n.rules))) // sb.WriteString(fmt.Sprintf(", rules:%d, ptr:%p", len(n.rules), n.rules)) + sb.WriteString(fmt.Sprintf(", rules:%d", len(n.rules))) } + if n.terminal { sb.WriteString(", TERMINAL") } + sb.WriteString(")\n") // Calculate the new prefix for children childPrefix := prefix + if !isRoot { if isLast { childPrefix += " " @@ -77,6 +83,7 @@ func (n *Node) buildString(sb *strings.Builder, prefix string, isLast bool, isRo for k := range n.children { children = append(children, k) } + sort.Strings(children) // Print children diff --git a/internal/server/acl/tree_test.go b/internal/server/acl/tree_test.go index c017d34f..97ce6672 100644 --- a/internal/server/acl/tree_test.go +++ b/internal/server/acl/tree_test.go @@ -7,8 +7,8 @@ import ( "github.com/stretchr/testify/assert" ) -func TestNewTree(t *testing.T) { - tree := NewTree() +func TestNewACLTree(t *testing.T) { + tree := NewACLTree() assert.NotNil(t, tree) assert.NotNil(t, tree.root) @@ -22,7 +22,7 @@ func TestNewTree(t *testing.T) { } func TestAddRuleSet(t *testing.T) { - tree := NewTree() + tree := NewACLTree() ruleset := aclspec.NewRuleSet( "test/path", @@ -30,14 +30,15 @@ func TestAddRuleSet(t *testing.T) { aclspec.NewDefaultRule(aclspec.PrivateAccess(), aclspec.DefaultLimits()), ) - err := tree.AddRuleSet(ruleset) + ver, err := tree.AddRuleSet(ruleset) + assert.Equal(t, ACLVersion(1), ver) // check root node "/" assert.NoError(t, err) assert.Empty(t, tree.root.rules) assert.Contains(t, tree.root.children, "test") assert.Equal(t, tree.root.path, "/") - assert.Equal(t, tree.root.depth, uint8(0)) + assert.Equal(t, tree.root.depth, ACLDepth(0)) // check node "test" child, ok := tree.root.GetChild("test") @@ -46,18 +47,18 @@ func TestAddRuleSet(t *testing.T) { assert.Empty(t, child.rules) assert.Contains(t, child.children, "path") assert.Equal(t, child.path, "test") - assert.Equal(t, child.depth, uint8(1)) + assert.Equal(t, child.depth, ACLDepth(1)) // check node "path" child, ok = child.GetChild("path") assert.True(t, ok) assert.NotNil(t, child) assert.Equal(t, child.path, "test/path") - assert.Equal(t, child.depth, uint8(2)) + assert.Equal(t, child.depth, ACLDepth(2)) } func TestTreeTraversal(t *testing.T) { - tree := NewTree() + tree := NewACLTree() // Add rulesets with nested paths ruleset1 := aclspec.NewRuleSet( @@ -68,7 +69,7 @@ func TestTreeTraversal(t *testing.T) { ruleset2 := aclspec.NewRuleSet( "parent/child", - aclspec.UnsetTerminal, // Changed to non-terminal so we can add grandchild + aclspec.UnsetTerminal, // Non-terminal to allow grandchild aclspec.NewRule("*.md", aclspec.PrivateAccess(), aclspec.DefaultLimits()), ) @@ -78,36 +79,39 @@ func TestTreeTraversal(t *testing.T) { aclspec.NewRule("*.go", aclspec.PublicReadAccess(), aclspec.DefaultLimits()), ) - err := tree.AddRuleSet(ruleset1) + ver, err := tree.AddRuleSet(ruleset1) assert.NoError(t, err) + assert.Equal(t, ACLVersion(1), ver) - err = tree.AddRuleSet(ruleset2) + ver, err = tree.AddRuleSet(ruleset2) assert.NoError(t, err) + assert.Equal(t, ACLVersion(1), ver) - err = tree.AddRuleSet(ruleset3) + ver, err = tree.AddRuleSet(ruleset3) assert.NoError(t, err) + assert.Equal(t, ACLVersion(1), ver) // Test finding nearest node with rules for different paths - node := tree.GetNearestNodeWithRules("parent/file.txt") + node := tree.LookupNearestNode("parent/file.txt") assert.Equal(t, "parent", node.path) - node = tree.GetNearestNodeWithRules("parent/child/document.md") + node = tree.LookupNearestNode("parent/child/document.md") assert.Equal(t, "parent/child", node.path) - node = tree.GetNearestNodeWithRules("parent/child/grandchild/main.go") + node = tree.LookupNearestNode("parent/child/grandchild/main.go") assert.Equal(t, "parent/child/grandchild", node.path) - // Test inheritance - terminal nodes (like grandchild) block inheritance from higher levels - node = tree.GetNearestNodeWithRules("parent/child/unknown.txt") + // Test inheritance - terminal nodes (like parent/child) block inheritance from higher levels + node = tree.LookupNearestNode("parent/child/unknown.txt") assert.Equal(t, "parent/child", node.path) // Test path that doesn't exist in the tree - node = tree.GetNearestNodeWithRules("unknown/path") + node = tree.LookupNearestNode("unknown/path") assert.Nil(t, node) } func TestRemoveRuleSet(t *testing.T) { - tree := NewTree() + tree := NewACLTree() // Add rulesets ruleset1 := aclspec.NewRuleSet( @@ -122,11 +126,13 @@ func TestRemoveRuleSet(t *testing.T) { aclspec.NewRule("*.txt", aclspec.PrivateAccess(), aclspec.DefaultLimits()), ) - err := tree.AddRuleSet(ruleset1) + ver, err := tree.AddRuleSet(ruleset1) assert.NoError(t, err) + assert.Equal(t, ACLVersion(1), ver) - err = tree.AddRuleSet(ruleset2) + ver, err = tree.AddRuleSet(ruleset2) assert.NoError(t, err) + assert.Equal(t, ACLVersion(1), ver) // Verify both rulesets are in the tree _, ok := tree.root.GetChild("folder1") @@ -155,7 +161,7 @@ func TestRemoveRuleSet(t *testing.T) { func TestGetNode(t *testing.T) { // Test the GetNode method which finds exact nodes for given paths // This validates precise node location without rule inheritance logic - tree := NewTree() + tree := NewACLTree() // Add nested rulesets to create a tree structure ruleset1 := aclspec.NewRuleSet( @@ -170,10 +176,10 @@ func TestGetNode(t *testing.T) { aclspec.NewRule("*.md", aclspec.PrivateAccess(), aclspec.DefaultLimits()), ) - err := tree.AddRuleSet(ruleset1) + _, err := tree.AddRuleSet(ruleset1) assert.NoError(t, err) - err = tree.AddRuleSet(ruleset2) + _, err = tree.AddRuleSet(ruleset2) assert.NoError(t, err) // Test getting exact nodes that exist @@ -203,7 +209,7 @@ func TestGetNode(t *testing.T) { func TestGetNodeWithTerminalNodes(t *testing.T) { // Test GetNode behavior with terminal nodes // Terminal nodes allow children to be added but stop traversal during lookups - tree := NewTree() + tree := NewACLTree() // Add a terminal node with catch-all rule terminalRuleset := aclspec.NewRuleSet( @@ -212,13 +218,13 @@ func TestGetNodeWithTerminalNodes(t *testing.T) { aclspec.NewRule("**", aclspec.PrivateAccess(), aclspec.DefaultLimits()), ) - err := tree.AddRuleSet(terminalRuleset) + _, err := tree.AddRuleSet(terminalRuleset) assert.NoError(t, err) // Verify the terminal node was added terminalNode := tree.GetNode("terminal") assert.NotNil(t, terminalNode, "Terminal node should exist") - assert.True(t, terminalNode.IsTerminal(), "Node should be marked as terminal") + assert.True(t, terminalNode.GetTerminal(), "Node should be marked as terminal") // Add a child under the terminal node - this should SUCCEED // The tree allows all nodes to be added for performance (avoids tree rebuilds) @@ -228,7 +234,7 @@ func TestGetNodeWithTerminalNodes(t *testing.T) { aclspec.NewRule("*.md", aclspec.PublicReadAccess(), aclspec.DefaultLimits()), ) - err = tree.AddRuleSet(childRuleset) + _, err = tree.AddRuleSet(childRuleset) assert.NoError(t, err, "Should allow child rulesets under terminal nodes (tree contains all ACLs)") // GetNode should stop at terminal node during lookup traversal @@ -245,16 +251,16 @@ func TestGetNodeWithTerminalNodes(t *testing.T) { actualChild, exists := terminalNode.GetChild("child") assert.True(t, exists, "Child should exist in tree structure") assert.Equal(t, "terminal/child", actualChild.path, "Child should have correct path") - assert.False(t, actualChild.IsTerminal(), "Child should not be terminal") + assert.False(t, actualChild.GetTerminal(), "Child should not be terminal") - // AND: GetNearestNodeWithRules should also stop at terminal nodes - nearestNode := tree.GetNearestNodeWithRules("terminal/child/file.txt") + // AND: LookupNearestNode should also stop at terminal nodes + nearestNode := tree.LookupNearestNode("terminal/child/file.txt") assert.NotNil(t, nearestNode, "Should find the terminal node") assert.Equal(t, "terminal", nearestNode.path, "Rule lookup should stop at terminal node") // This means child rules are ignored for inheritance even though they exist in the tree // Test with child path - should get rule from terminal node (** pattern matches everything) - rule, err := tree.GetRule("terminal/child/test.md") + rule, err := tree.GetEffectiveRule("terminal/child/test.md") assert.NoError(t, err, "Should find rule from terminal node") assert.Equal(t, "terminal", rule.node.path, "Rule should come from terminal node, not child") assert.Equal(t, "**", rule.rule.Pattern, "Should use terminal node's catch-all rule") @@ -263,7 +269,7 @@ func TestGetNodeWithTerminalNodes(t *testing.T) { func TestTerminalNodeValidation(t *testing.T) { // Test that terminal nodes control inheritance but allow children to be added // This explicitly tests the correct terminal behavior - tree := NewTree() + tree := NewACLTree() // Add a terminal node terminalRuleset := aclspec.NewRuleSet( @@ -272,12 +278,12 @@ func TestTerminalNodeValidation(t *testing.T) { aclspec.NewRule("**", aclspec.PrivateAccess(), aclspec.DefaultLimits()), ) - err := tree.AddRuleSet(terminalRuleset) + _, err := tree.AddRuleSet(terminalRuleset) assert.NoError(t, err, "Should be able to add terminal node") // Verify it's terminal node := tree.GetNode("secure") - assert.True(t, node.IsTerminal(), "Node should be marked as terminal") + assert.True(t, node.GetTerminal(), "Node should be marked as terminal") // Add direct child - should succeed (tree allows all nodes for performance) childRuleset := aclspec.NewRuleSet( @@ -286,7 +292,7 @@ func TestTerminalNodeValidation(t *testing.T) { aclspec.NewRule("*.txt", aclspec.PublicReadAccess(), aclspec.DefaultLimits()), ) - err = tree.AddRuleSet(childRuleset) + _, err = tree.AddRuleSet(childRuleset) assert.NoError(t, err, "Should be able to add child under terminal node (exists in tree)") // Add deeper nested child - should also succeed @@ -296,24 +302,24 @@ func TestTerminalNodeValidation(t *testing.T) { aclspec.NewRule("*.md", aclspec.PublicReadAccess(), aclspec.DefaultLimits()), ) - err = tree.AddRuleSet(deepChildRuleset) + _, err = tree.AddRuleSet(deepChildRuleset) assert.NoError(t, err, "Should be able to add nested child under terminal node") // BUT: Terminal nodes should stop traversal for rule lookups // All paths under secure/ should resolve to the secure terminal node // Direct child path should resolve to terminal parent - nearestNode := tree.GetNearestNodeWithRules("secure/child/test.txt") + nearestNode := tree.LookupNearestNode("secure/child/test.txt") assert.NotNil(t, nearestNode, "Should find a node") assert.Equal(t, "secure", nearestNode.path, "Should resolve to terminal parent, not child") // Deep child path should also resolve to terminal parent - nearestNode = tree.GetNearestNodeWithRules("secure/child/grandchild/test.md") + nearestNode = tree.LookupNearestNode("secure/child/grandchild/test.md") assert.NotNil(t, nearestNode, "Should find a node") assert.Equal(t, "secure", nearestNode.path, "Should resolve to terminal parent, not grandchild") // Rule lookup should use terminal node's rules - rule, err := tree.GetRule("secure/child/test.txt") + rule, err := tree.GetEffectiveRule("secure/child/test.txt") assert.NoError(t, err, "Should find rule for child path") assert.Equal(t, "secure", rule.node.path, "Rule should come from terminal parent") assert.Equal(t, "**", rule.rule.Pattern, "Should use terminal node's catch-all rule") @@ -325,7 +331,7 @@ func TestTerminalNodeValidation(t *testing.T) { aclspec.NewRule("**", aclspec.PublicReadAccess(), aclspec.DefaultLimits()), ) - err = tree.AddRuleSet(nonTerminalRuleset) + _, err = tree.AddRuleSet(nonTerminalRuleset) assert.NoError(t, err, "Should be able to add non-terminal node") // Add child under non-terminal - should succeed and be accessible @@ -335,11 +341,11 @@ func TestTerminalNodeValidation(t *testing.T) { aclspec.NewRule("*.txt", aclspec.PrivateAccess(), aclspec.DefaultLimits()), ) - err = tree.AddRuleSet(openChildRuleset) + _, err = tree.AddRuleSet(openChildRuleset) assert.NoError(t, err, "Should be able to add child under non-terminal node") // Child under non-terminal should be accessible for rule lookups - nearestNode = tree.GetNearestNodeWithRules("open/child/test.txt") + nearestNode = tree.LookupNearestNode("open/child/test.txt") assert.NotNil(t, nearestNode, "Should find a node") assert.Equal(t, "open/child", nearestNode.path, "Should find the actual child node, not parent") } @@ -347,7 +353,7 @@ func TestTerminalNodeValidation(t *testing.T) { func TestConflictingRuleSetsAtSameLevel(t *testing.T) { // Test what happens when adding multiple rulesets to the same path // This tests ruleset replacement/overwriting behavior - tree := NewTree() + tree := NewACLTree() // Add initial ruleset initialRuleset := aclspec.NewRuleSet( @@ -356,14 +362,14 @@ func TestConflictingRuleSetsAtSameLevel(t *testing.T) { aclspec.NewRule("*.txt", aclspec.PrivateAccess(), aclspec.DefaultLimits()), ) - err := tree.AddRuleSet(initialRuleset) + _, err := tree.AddRuleSet(initialRuleset) assert.NoError(t, err, "Should be able to add initial ruleset") // Verify initial ruleset node := tree.GetNode("shared") assert.NotNil(t, node, "Node should exist") - assert.False(t, node.IsTerminal(), "Node should not be terminal initially") - assert.Len(t, node.Rules(), 1, "Should have 1 rule initially") + assert.False(t, node.GetTerminal(), "Node should not be terminal initially") + assert.Len(t, node.GetRules(), 1, "Should have 1 rule initially") // Add conflicting ruleset at the SAME path with different rules and terminal flag conflictingRuleset := aclspec.NewRuleSet( @@ -373,14 +379,14 @@ func TestConflictingRuleSetsAtSameLevel(t *testing.T) { aclspec.NewRule("**", aclspec.SharedReadAccess("admin@example.com"), aclspec.DefaultLimits()), // Additional rule ) - err = tree.AddRuleSet(conflictingRuleset) + _, err = tree.AddRuleSet(conflictingRuleset) assert.NoError(t, err, "Should be able to add conflicting ruleset (overwrites)") // Verify the conflicting ruleset completely replaced the original node = tree.GetNode("shared") assert.NotNil(t, node, "Node should still exist") - assert.True(t, node.IsTerminal(), "Node should now be terminal (overwritten)") - assert.Len(t, node.Rules(), 2, "Should have 2 rules from new ruleset") + assert.True(t, node.GetTerminal(), "Node should now be terminal (overwritten)") + assert.Len(t, node.GetRules(), 2, "Should have 2 rules from new ruleset") // Verify the conflicting ruleset completely replaced the original // The original *.txt rule with PrivateAccess is gone @@ -412,11 +418,11 @@ func TestConflictingRuleSetsAtSameLevel(t *testing.T) { aclspec.NewRule("*.go", aclspec.PrivateAccess(), aclspec.DefaultLimits()), ) - err = tree.AddRuleSet(childRuleset) + _, err = tree.AddRuleSet(childRuleset) assert.NoError(t, err, "Should be able to add child under terminal node (tree allows all nodes)") // But rule lookups should stop at terminal node - rule, err = tree.GetRule("shared/child/test.go") + rule, err = tree.GetEffectiveRule("shared/child/test.go") assert.NoError(t, err, "Should find rule for child path") assert.Equal(t, "shared", rule.node.path, "Rule should come from terminal parent, not child") assert.Equal(t, "**", rule.rule.Pattern, "Should use parent's ** rule, not child's *.go rule") @@ -425,10 +431,10 @@ func TestConflictingRuleSetsAtSameLevel(t *testing.T) { func TestAddRuleSetErrorCases(t *testing.T) { // Test AddRuleSet with various error conditions // This improves coverage of edge cases and error handling - tree := NewTree() + tree := NewACLTree() // Test with nil ruleset - err := tree.AddRuleSet(nil) + _, err := tree.AddRuleSet(nil) assert.Error(t, err, "Should reject nil ruleset") assert.Contains(t, err.Error(), "ruleset is nil", "Error should indicate nil ruleset") @@ -438,7 +444,7 @@ func TestAddRuleSetErrorCases(t *testing.T) { Terminal: false, Rules: []*aclspec.Rule{}, } - err = tree.AddRuleSet(emptyRuleset) + _, err = tree.AddRuleSet(emptyRuleset) assert.Error(t, err, "Should reject empty ruleset") assert.Contains(t, err.Error(), "ruleset is empty", "Error should indicate empty ruleset") @@ -452,7 +458,7 @@ func TestAddRuleSetErrorCases(t *testing.T) { aclspec.UnsetTerminal, aclspec.NewRule("*.txt", aclspec.PrivateAccess(), aclspec.DefaultLimits()), ) - err = tree.AddRuleSet(deepRuleset) + _, err = tree.AddRuleSet(deepRuleset) assert.Error(t, err, "Should reject paths that are too deep") assert.Contains(t, err.Error(), "maximum depth exceeded", "Error should indicate depth limit") } @@ -460,7 +466,7 @@ func TestAddRuleSetErrorCases(t *testing.T) { func TestAddRuleSetPathNormalization(t *testing.T) { // Test that AddRuleSet properly normalizes different path formats // This ensures consistent path handling across different input formats - tree := NewTree() + tree := NewACLTree() // Test with path that has leading/trailing separators rule := aclspec.NewRule("*.txt", aclspec.PrivateAccess(), aclspec.DefaultLimits()) @@ -476,7 +482,7 @@ func TestAddRuleSetPathNormalization(t *testing.T) { for i, path := range testPaths { ruleset := aclspec.NewRuleSet(path, false, rule) - err := tree.AddRuleSet(ruleset) + _, err := tree.AddRuleSet(ruleset) assert.NoError(t, err, "Should accept path format: %s", path) // All paths should result in the same node being found @@ -487,7 +493,7 @@ func TestAddRuleSetPathNormalization(t *testing.T) { } func TestNestedRuleSetRemoval(t *testing.T) { - tree := NewTree() + tree := NewACLTree() // Add nested rulesets ruleset1 := aclspec.NewRuleSet( @@ -502,11 +508,13 @@ func TestNestedRuleSetRemoval(t *testing.T) { aclspec.NewRule("*.md", aclspec.PrivateAccess(), aclspec.DefaultLimits()), ) - err := tree.AddRuleSet(ruleset1) + ver, err := tree.AddRuleSet(ruleset1) assert.NoError(t, err) + assert.Equal(t, ACLVersion(1), ver) - err = tree.AddRuleSet(ruleset2) + ver, err = tree.AddRuleSet(ruleset2) assert.NoError(t, err) + assert.Equal(t, ACLVersion(1), ver) // Remove parent - with original behavior, this removes the entire subtree removed := tree.RemoveRuleSet("parent") @@ -517,12 +525,14 @@ func TestNestedRuleSetRemoval(t *testing.T) { assert.False(t, ok, "Parent node should be completely removed") // Add the parent ruleset back - err = tree.AddRuleSet(ruleset1) + ver, err = tree.AddRuleSet(ruleset1) assert.NoError(t, err) + assert.Equal(t, ACLVersion(1), ver) // Add the child ruleset back - err = tree.AddRuleSet(ruleset2) + ver, err = tree.AddRuleSet(ruleset2) assert.NoError(t, err) + assert.Equal(t, ACLVersion(1), ver) // Remove just the child removed = tree.RemoveRuleSet("parent/child") diff --git a/internal/server/acl/types.go b/internal/server/acl/types.go index 9bcd50ee..74479afc 100644 --- a/internal/server/acl/types.go +++ b/internal/server/acl/types.go @@ -1,8 +1,7 @@ package acl type User struct { - ID string - IsOwner bool + ID string } type File struct { diff --git a/internal/server/datasite/datasite.go b/internal/server/datasite/datasite.go index 0f3f83af..d57956e9 100644 --- a/internal/server/datasite/datasite.go +++ b/internal/server/datasite/datasite.go @@ -15,10 +15,10 @@ import ( type DatasiteService struct { blob *blob.BlobService - acl *acl.AclService + acl *acl.ACLService } -func NewDatasiteService(blobSvc *blob.BlobService, aclSvc *acl.AclService) *DatasiteService { +func NewDatasiteService(blobSvc *blob.BlobService, aclSvc *acl.ACLService) *DatasiteService { return &DatasiteService{ blob: blobSvc, acl: aclSvc, @@ -50,7 +50,7 @@ func (d *DatasiteService) Start(ctx context.Context) error { // Load the ACL rulesets start = time.Now() - d.acl.LoadRuleSets(ruleSets) + d.acl.AddRuleSets(ruleSets) slog.Debug("acl build", "count", len(ruleSets), "took", time.Since(start)) // Warm up the ACL cache @@ -81,8 +81,13 @@ func (d *DatasiteService) GetView(user string) []*blob.BlobInfo { // Filter blobs based on ACL for _, blob := range blobs { + if IsOwner(blob.Key, user) { + view = append(view, blob) + continue + } + if err := d.acl.CanAccess( - &acl.User{ID: user, IsOwner: IsOwner(blob.Key, user)}, + &acl.User{ID: user}, &acl.File{Path: blob.Key}, acl.AccessRead, ); err == nil { diff --git a/internal/server/datasite/utils.go b/internal/server/datasite/utils.go index 25c1efe2..33126411 100644 --- a/internal/server/datasite/utils.go +++ b/internal/server/datasite/utils.go @@ -2,13 +2,48 @@ package datasite import ( "path/filepath" + "regexp" "strings" + + "github.com/openmined/syftbox/internal/utils" +) + +var ( + PathSep = string(filepath.Separator) + regexDatasitePath = regexp.MustCompile(`^[^\s@]+@[^\s@]+\.[^\s@]+/`) ) -var pathSep = string(filepath.Separator) +// GetOwner returns the owner of the path +func GetOwner(path string) string { + // clean path + path = CleanPath(path) + + // get owner + parts := strings.Split(path, PathSep) + if len(parts) == 0 { + return "" + } + + return parts[0] +} // IsOwner checks if the user is the owner of the path // The underlying assumption here is that owner is the prefix of the path func IsOwner(path string, user string) bool { - return strings.HasPrefix(strings.TrimLeft(filepath.Clean(path), pathSep), user) + path = CleanPath(path) + return strings.HasPrefix(path, user) +} + +// CleanPath returns a path with leading and trailing slashes removed +func CleanPath(path string) string { + return strings.TrimLeft(filepath.Clean(path), PathSep) +} + +// IsValidPath checks if the path is a valid datasite path +func IsValidPath(path string) bool { + return regexDatasitePath.MatchString(path) +} + +func IsValidDatasite(user string) bool { + return utils.IsValidEmail(user) } diff --git a/internal/server/handlers/acl/acl_handler.go b/internal/server/handlers/acl/acl_handler.go new file mode 100644 index 00000000..fee0262a --- /dev/null +++ b/internal/server/handlers/acl/acl_handler.go @@ -0,0 +1,43 @@ +package acl + +import ( + "net/http" + + "github.com/gin-gonic/gin" + "github.com/openmined/syftbox/internal/server/acl" + "github.com/openmined/syftbox/internal/server/handlers/api" +) + +type ACLHandler struct { + aclSvc *acl.ACLService +} + +func NewACLHandler(svc *acl.ACLService) *ACLHandler { + return &ACLHandler{ + aclSvc: svc, + } +} + +func (h *ACLHandler) CheckAccess(ctx *gin.Context) { + var req ACLCheckRequest + if err := ctx.ShouldBindQuery(&req); err != nil { + api.AbortWithError(ctx, http.StatusBadRequest, api.CodeInvalidRequest, err) + return + } + + // Check access using the ACL service + if err := h.aclSvc.CanAccess( + &acl.User{ID: req.User}, + &acl.File{Path: req.Path, Size: req.Size}, + acl.AccessLevel(req.Level), + ); err != nil { + api.AbortWithError(ctx, http.StatusForbidden, api.CodeAccessDenied, err) + return + } + + ctx.PureJSON(http.StatusOK, &ACLCheckResponse{ + User: req.User, + Path: req.Path, + Level: req.Level.String(), + }) +} diff --git a/internal/server/handlers/acl/acl_handler_types.go b/internal/server/handlers/acl/acl_handler_types.go new file mode 100644 index 00000000..7a46cc91 --- /dev/null +++ b/internal/server/handlers/acl/acl_handler_types.go @@ -0,0 +1,16 @@ +package acl + +import "github.com/openmined/syftbox/internal/server/acl" + +type ACLCheckRequest struct { + User string `form:"user" binding:"required"` + Path string `form:"path" binding:"required"` + Size int64 `form:"size"` + Level acl.AccessLevel `form:"level" binding:"required"` +} + +type ACLCheckResponse struct { + User string `json:"user"` + Path string `json:"path"` + Level string `json:"level"` +} diff --git a/internal/server/handlers/api/codes.go b/internal/server/handlers/api/codes.go new file mode 100644 index 00000000..9124f43f --- /dev/null +++ b/internal/server/handlers/api/codes.go @@ -0,0 +1,30 @@ +package api + +const ( + // Generic request/server errors + CodeInvalidRequest = "E_INVALID_REQUEST" // bad or invalid request + CodeRateLimited = "E_RATE_LIMITED" // rate limit exceeded + CodeInternalError = "E_INTERNAL_ERROR" // internal server error + CodeAccessDenied = "E_ACCESS_DENIED" // access denied + + // Auth errors + CodeAuthInvalidCredentials = "E_AUTH_INVALID_CREDENTIALS" // authentication credentials (e.g., token) are invalid, expired, or malformed. + CodeAuthTokenGenerationFailed = "E_AUTH_TOKEN_GENERATION_FAILED" // a failure during the generation of new authentication tokens. + CodeAuthOTPVerificationFailed = "E_AUTH_OTP_VERIFICATION_FAILED" // Email One-Time Password (OTP) verification failed. + CodeAuthTokenRefreshFailed = "E_AUTH_TOKEN_REFRESH_FAILED" // a failure during the attempt to refresh an authentication token. + CodeAuthNotificationFailed = "E_AUTH_NOTIFICATION_FAILED" // a failure in sending an authentication-related notification (e.g., OTP email/SMS). + + // Datasite errors + CodeDatasiteNotFound = "E_DATASITE_NOT_FOUND" // the specified datasite resource could not be found. + CodeDatasiteInvalidPath = "E_DATASITE_INVALID_PATH" // the provided path for a datasite resource is invalid or malformed. + + // Blob errors + CodeBlobNotFound = "E_BLOB_NOT_FOUND" // the specified blob could not be found. + CodeBlobListFailed = "E_BLOB_LIST_OPERATION_FAILED" // a failure during the operation to list blobs. + CodeBlobPutFailed = "E_BLOB_PUT_OPERATION_FAILED" // a failure during the operation to upload/put a blob. + CodeBlobGetFailed = "E_BLOB_GET_OPERATION_FAILED" // a failure during the operation to download/get a blob. + CodeBlobDeleteFailed = "E_BLOB_DELETE_OPERATION_FAILED" // a failure during the operation to delete a blob. + + // ACL errors + CodeACLUpdateFailed = "E_ACL_UPDATE_FAILED" // a failure during the operation to update an ACL. +) diff --git a/internal/server/handlers/api/error.go b/internal/server/handlers/api/error.go new file mode 100644 index 00000000..eb63da25 --- /dev/null +++ b/internal/server/handlers/api/error.go @@ -0,0 +1,12 @@ +package api + +import "fmt" + +type SyftAPIError struct { + Code string `json:"code"` + Message string `json:"error"` +} + +func (e *SyftAPIError) Error() string { + return fmt.Sprintf("syft api error: code=%s, message=%s", e.Code, e.Message) +} diff --git a/internal/server/handlers/api/response.go b/internal/server/handlers/api/response.go new file mode 100644 index 00000000..ad703981 --- /dev/null +++ b/internal/server/handlers/api/response.go @@ -0,0 +1,12 @@ +package api + +import "github.com/gin-gonic/gin" + +func AbortWithError(ctx *gin.Context, status int, code string, err error) { + ctx.Abort() + ctx.Error(err) + ctx.PureJSON(status, SyftAPIError{ + Code: code, + Message: err.Error(), + }) +} diff --git a/internal/server/handlers/auth/auth_handler.go b/internal/server/handlers/auth/auth_handler.go index 66805795..fed65f82 100644 --- a/internal/server/handlers/auth/auth_handler.go +++ b/internal/server/handlers/auth/auth_handler.go @@ -1,12 +1,13 @@ package auth import ( - "errors" "fmt" "net/http" "github.com/gin-gonic/gin" "github.com/openmined/syftbox/internal/server/auth" + "github.com/openmined/syftbox/internal/server/handlers/api" + "github.com/openmined/syftbox/internal/utils" ) type AuthHandler struct { @@ -22,24 +23,17 @@ func New(auth *auth.AuthService) *AuthHandler { func (h *AuthHandler) OTPRequest(ctx *gin.Context) { var req OTPRequest if err := ctx.ShouldBindJSON(&req); err != nil { - ctx.Error(fmt.Errorf("failed to bind json: %w", err)) - ctx.PureJSON(http.StatusBadRequest, gin.H{ - "error": err.Error(), - }) + api.AbortWithError(ctx, http.StatusBadRequest, api.CodeInvalidRequest, fmt.Errorf("failed to bind json: %w", err)) + return + } + + if !utils.IsValidEmail(req.Email) { + api.AbortWithError(ctx, http.StatusBadRequest, api.CodeInvalidRequest, fmt.Errorf("invalid email")) return } if err := h.auth.SendOTP(ctx, req.Email); err != nil { - ctx.Error(fmt.Errorf("failed to send OTP: %w", err)) - if errors.Is(err, auth.ErrInvalidEmail) { - ctx.PureJSON(http.StatusBadRequest, gin.H{ - "error": err.Error(), - }) - } else { - ctx.PureJSON(http.StatusInternalServerError, gin.H{ - "error": err.Error(), - }) - } + api.AbortWithError(ctx, http.StatusInternalServerError, api.CodeAuthNotificationFailed, fmt.Errorf("failed to send OTP: %w", err)) return } @@ -49,19 +43,13 @@ func (h *AuthHandler) OTPRequest(ctx *gin.Context) { func (h *AuthHandler) OTPVerify(ctx *gin.Context) { var req OTPVerifyRequest if err := ctx.ShouldBindJSON(&req); err != nil { - ctx.Error(fmt.Errorf("failed to bind json: %w", err)) - ctx.PureJSON(http.StatusBadRequest, gin.H{ - "error": err.Error(), - }) + api.AbortWithError(ctx, http.StatusBadRequest, api.CodeInvalidRequest, fmt.Errorf("failed to bind json: %w", err)) return } accessToken, refreshToken, err := h.auth.GenerateTokensPair(ctx, req.Email, req.Code) if err != nil { - ctx.Error(fmt.Errorf("failed to generate tokens: %w", err)) - ctx.PureJSON(http.StatusInternalServerError, gin.H{ - "error": err.Error(), - }) + api.AbortWithError(ctx, http.StatusUnauthorized, api.CodeAuthTokenGenerationFailed, err) return } @@ -74,19 +62,13 @@ func (h *AuthHandler) OTPVerify(ctx *gin.Context) { func (h *AuthHandler) Refresh(ctx *gin.Context) { var req RefreshRequest if err := ctx.ShouldBindJSON(&req); err != nil { - ctx.Error(fmt.Errorf("failed to bind json: %w", err)) - ctx.PureJSON(http.StatusBadRequest, gin.H{ - "error": err.Error(), - }) + api.AbortWithError(ctx, http.StatusBadRequest, api.CodeInvalidRequest, fmt.Errorf("failed to bind json: %w", err)) return } accessToken, refreshToken, err := h.auth.RefreshToken(ctx, req.OldRefreshToken) if err != nil { - ctx.Error(fmt.Errorf("failed to refresh token: %w", err)) - ctx.PureJSON(http.StatusInternalServerError, gin.H{ - "error": err.Error(), - }) + api.AbortWithError(ctx, http.StatusUnauthorized, api.CodeAuthTokenRefreshFailed, err) return } diff --git a/internal/server/handlers/blob/blob_handler.go b/internal/server/handlers/blob/blob_handler.go index edaadec9..03da67c0 100644 --- a/internal/server/handlers/blob/blob_handler.go +++ b/internal/server/handlers/blob/blob_handler.go @@ -1,51 +1,55 @@ package blob import ( + "fmt" "net/http" - "regexp" "github.com/gin-gonic/gin" + "github.com/openmined/syftbox/internal/server/acl" "github.com/openmined/syftbox/internal/server/blob" -) - -var ( - regexDatasiteKey = regexp.MustCompile(`^[^\s@]+@[^\s@]+\.[^\s@]+/`) + "github.com/openmined/syftbox/internal/server/datasite" + "github.com/openmined/syftbox/internal/server/handlers/api" ) type BlobHandler struct { blob *blob.BlobService + acl *acl.ACLService } -func New(blob *blob.BlobService) *BlobHandler { - return &BlobHandler{blob: blob} +func New(blob *blob.BlobService, acl *acl.ACLService) *BlobHandler { + return &BlobHandler{blob: blob, acl: acl} } func (h *BlobHandler) UploadMultipart(ctx *gin.Context) { // todo - ctx.PureJSON(http.StatusNotImplemented, gin.H{ - "error": "not implemented", - }) + api.AbortWithError(ctx, http.StatusNotImplemented, api.CodeInvalidRequest, fmt.Errorf("not implemented")) } func (h *BlobHandler) UploadComplete(ctx *gin.Context) { // todo - ctx.PureJSON(http.StatusNotImplemented, gin.H{ - "error": "not implemented", - }) + api.AbortWithError(ctx, http.StatusNotImplemented, api.CodeInvalidRequest, fmt.Errorf("not implemented")) } func (h *BlobHandler) ListObjects(ctx *gin.Context) { res, err := h.blob.Index().List() if err != nil { - ctx.PureJSON(http.StatusInternalServerError, gin.H{ - "error": err.Error(), - }) + api.AbortWithError(ctx, http.StatusInternalServerError, api.CodeBlobListFailed, err) return } - ctx.PureJSON(http.StatusOK, res) + ctx.PureJSON(http.StatusOK, &gin.H{ + "blobs": res, + }) } -func isValidDatasiteKey(key string) bool { - return regexDatasiteKey.MatchString(key) +func (h *BlobHandler) checkPermissions(key string, user string, access acl.AccessLevel) error { + if datasite.IsOwner(key, user) { + return nil + } + + if err := h.acl.CanAccess(&acl.User{ID: user}, &acl.File{Path: key}, access); err != nil { + return err + } + + return nil } diff --git a/internal/server/handlers/blob/blob_handler_delete.go b/internal/server/handlers/blob/blob_handler_delete.go index 2871e169..38f62102 100644 --- a/internal/server/handlers/blob/blob_handler_delete.go +++ b/internal/server/handlers/blob/blob_handler_delete.go @@ -2,56 +2,58 @@ package blob import ( "fmt" + "log/slog" "net/http" "github.com/gin-gonic/gin" + "github.com/openmined/syftbox/internal/aclspec" + "github.com/openmined/syftbox/internal/server/acl" + "github.com/openmined/syftbox/internal/server/datasite" + "github.com/openmined/syftbox/internal/server/handlers/api" ) func (h *BlobHandler) DeleteObjects(ctx *gin.Context) { var req DeleteRequest - if err := ctx.ShouldBindJSON(&req); err != nil { - ctx.Error(fmt.Errorf("failed to bind json: %w", err)) - ctx.PureJSON(http.StatusBadRequest, gin.H{ - "error": err.Error(), - }) - return - } + user := ctx.GetString("user") - if len(req.Keys) == 0 { - ctx.Error(fmt.Errorf("keys cannot be empty")) - ctx.PureJSON(http.StatusBadRequest, gin.H{ - "error": "keys cannot be empty", - }) + if err := ctx.ShouldBindJSON(&req); err != nil { + api.AbortWithError(ctx, http.StatusBadRequest, api.CodeInvalidRequest, fmt.Errorf("failed to bind json: %w", err)) return } deleted := make([]string, 0, len(req.Keys)) - errors := make([]*BlobError, 0) + errors := make([]*BlobAPIError, 0) for _, key := range req.Keys { - if !isValidDatasiteKey(key) { - ctx.Error(fmt.Errorf("invalid datasite path: %s", key)) - errors = append(errors, &BlobError{ - Key: key, - Error: "invalid key", - }) + if !datasite.IsValidPath(key) { + errors = append(errors, NewBlobAPIError(api.CodeDatasiteInvalidPath, "invalid key", key)) continue } + + if err := h.checkPermissions(key, user, acl.AccessWrite); err != nil { + errors = append(errors, NewBlobAPIError(api.CodeAccessDenied, err.Error(), key)) + continue + } + _, err := h.blob.Backend().DeleteObject(ctx.Request.Context(), key) if err != nil { ctx.Error(fmt.Errorf("failed to delete object: %w", err)) - errors = append(errors, &BlobError{ - Key: key, - Error: err.Error(), - }) + errors = append(errors, NewBlobAPIError(api.CodeBlobDeleteFailed, err.Error(), key)) continue } + + if aclspec.IsACLFile(key) { + // don't worry the above permissions check will make sure that the user is admin + ok := h.acl.RemoveRuleSet(key) + if !ok { + slog.Warn("remove ruleset returned false", "key", key) + } + } + deleted = append(deleted, key) } code := http.StatusOK - if len(deleted) == 0 && len(errors) > 0 { - code = http.StatusBadRequest - } else if len(deleted) > 0 && len(errors) > 0 { + if len(deleted) >= 0 && len(errors) >= 0 { code = http.StatusMultiStatus } diff --git a/internal/server/handlers/blob/blob_handler_download_presigned.go b/internal/server/handlers/blob/blob_handler_download_presigned.go index 518807cc..b1346d11 100644 --- a/internal/server/handlers/blob/blob_handler_download_presigned.go +++ b/internal/server/handlers/blob/blob_handler_download_presigned.go @@ -5,60 +5,70 @@ import ( "net/http" "github.com/gin-gonic/gin" + "github.com/openmined/syftbox/internal/server/acl" + "github.com/openmined/syftbox/internal/server/datasite" + "github.com/openmined/syftbox/internal/server/handlers/api" ) func (h *BlobHandler) DownloadObjectsPresigned(ctx *gin.Context) { - var req PresignUrlRequest + var req PresignURLRequest + user := ctx.GetString("user") if err := ctx.ShouldBindJSON(&req); err != nil { - ctx.Error(fmt.Errorf("failed to bind json: %w", err)) - ctx.PureJSON(http.StatusBadRequest, gin.H{ - "error": err.Error(), - }) - return - } - - if len(req.Keys) == 0 { - ctx.Error(fmt.Errorf("keys cannot be empty")) - ctx.PureJSON(http.StatusBadRequest, gin.H{ - "error": "keys cannot be empty", - }) + api.AbortWithError(ctx, http.StatusBadRequest, api.CodeInvalidRequest, fmt.Errorf("failed to bind json: %w", err)) return } - urls := make([]*BlobUrl, 0, len(req.Keys)) - errors := make([]*BlobError, 0) + urls := make([]*BlobURL, 0, len(req.Keys)) + errors := make([]*BlobAPIError, 0) index := h.blob.Index() for _, key := range req.Keys { - if !isValidDatasiteKey(key) { - ctx.Error(fmt.Errorf("invalid datasite path: %s", key)) - errors = append(errors, &BlobError{ - Key: key, - Error: "invalid key", + if !datasite.IsValidPath(key) { + errors = append(errors, &BlobAPIError{ + SyftAPIError: api.SyftAPIError{ + Code: api.CodeDatasiteInvalidPath, + Message: "invalid key", + }, + Key: key, + }) + continue + } + + if err := h.checkPermissions(key, user, acl.AccessRead); err != nil { + errors = append(errors, &BlobAPIError{ + SyftAPIError: api.SyftAPIError{ + Code: api.CodeAccessDenied, + Message: err.Error(), + }, + Key: key, }) continue } _, ok := index.Get(key) if !ok { - ctx.Error(fmt.Errorf("object not found: %s", key)) - errors = append(errors, &BlobError{ - Key: key, - Error: "object not found", + errors = append(errors, &BlobAPIError{ + SyftAPIError: api.SyftAPIError{ + Code: api.CodeBlobNotFound, + Message: "object not found", + }, + Key: key, }) continue } url, err := h.blob.Backend().GetObjectPresigned(ctx, key) if err != nil { - ctx.Error(fmt.Errorf("failed to get object: %w", err)) - errors = append(errors, &BlobError{ - Key: key, - Error: err.Error(), + errors = append(errors, &BlobAPIError{ + SyftAPIError: api.SyftAPIError{ + Code: api.CodeBlobGetFailed, + Message: err.Error(), + }, + Key: key, }) continue } - urls = append(urls, &BlobUrl{ + urls = append(urls, &BlobURL{ Key: key, Url: url, }) @@ -71,7 +81,7 @@ func (h *BlobHandler) DownloadObjectsPresigned(ctx *gin.Context) { code = http.StatusMultiStatus } - ctx.PureJSON(code, &PresignUrlResponse{ + ctx.PureJSON(code, &PresignURLResponse{ URLs: urls, Errors: errors, }) diff --git a/internal/server/handlers/blob/blob_handler_types.go b/internal/server/handlers/blob/blob_handler_types.go index 32108af2..657263f1 100644 --- a/internal/server/handlers/blob/blob_handler_types.go +++ b/internal/server/handlers/blob/blob_handler_types.go @@ -1,13 +1,33 @@ package blob -type BlobUrl struct { +import ( + "fmt" + + "github.com/openmined/syftbox/internal/server/handlers/api" +) + +type BlobURL struct { Key string `json:"key"` Url string `json:"url"` } -type BlobError struct { - Key string `json:"key"` - Error string `json:"error"` +type BlobAPIError struct { + api.SyftAPIError + Key string `json:"key"` +} + +func NewBlobAPIError(code string, message string, key string) *BlobAPIError { + return &BlobAPIError{ + Key: key, + SyftAPIError: api.SyftAPIError{ + Code: code, + Message: message, + }, + } +} + +func (e *BlobAPIError) Error() string { + return fmt.Sprintf("syft api blob error: code=%s, message=%s, key=%s", e.Code, e.Message, e.Key) } type UploadRequest struct { @@ -26,20 +46,20 @@ type UploadResponse struct { LastModified string `json:"lastModified"` } -type PresignUrlRequest struct { - Keys []string `json:"keys" binding:"required"` +type PresignURLRequest struct { + Keys []string `json:"keys" binding:"required,min=1"` } -type PresignUrlResponse struct { - URLs []*BlobUrl `json:"urls"` - Errors []*BlobError `json:"errors"` +type PresignURLResponse struct { + URLs []*BlobURL `json:"urls"` + Errors []*BlobAPIError `json:"errors"` } type DeleteRequest struct { - Keys []string `json:"keys" binding:"required"` + Keys []string `json:"keys" binding:"required,min=1"` } type DeleteResponse struct { - Deleted []string `json:"deleted"` - Errors []*BlobError `json:"errors"` + Deleted []string `json:"deleted"` + Errors []*BlobAPIError `json:"errors"` } diff --git a/internal/server/handlers/blob/blob_handler_upload.go b/internal/server/handlers/blob/blob_handler_upload.go index 8af86525..6f153f0b 100644 --- a/internal/server/handlers/blob/blob_handler_upload.go +++ b/internal/server/handlers/blob/blob_handler_upload.go @@ -6,66 +6,66 @@ import ( "time" "github.com/gin-gonic/gin" + "github.com/openmined/syftbox/internal/aclspec" + "github.com/openmined/syftbox/internal/server/acl" "github.com/openmined/syftbox/internal/server/blob" + "github.com/openmined/syftbox/internal/server/datasite" + "github.com/openmined/syftbox/internal/server/handlers/api" ) func (h *BlobHandler) Upload(ctx *gin.Context) { + user := ctx.GetString("user") + + if key := ctx.Query("key"); aclspec.IsACLFile(key) { + h.UploadACL(ctx) + return + } + var req UploadRequest if err := ctx.ShouldBindQuery(&req); err != nil { - ctx.Error(fmt.Errorf("failed to bind query: %w", err)) - ctx.PureJSON(http.StatusBadRequest, gin.H{ - "error": err.Error(), - }) + api.AbortWithError(ctx, http.StatusBadRequest, api.CodeInvalidRequest, fmt.Errorf("failed to bind query: %w", err)) + return + } + + // todo check if new change using etag + + if !datasite.IsValidPath(req.Key) { + api.AbortWithError(ctx, http.StatusBadRequest, api.CodeDatasiteInvalidPath, fmt.Errorf("invalid key: %s", req.Key)) return } - if !isValidDatasiteKey(req.Key) { - ctx.Error(fmt.Errorf("invalid datasite path: %s", req.Key)) - ctx.PureJSON(http.StatusBadRequest, gin.H{ - "error": "invalid key", - }) + if err := h.checkPermissions(req.Key, user, acl.AccessWrite); err != nil { + api.AbortWithError(ctx, http.StatusForbidden, api.CodeAccessDenied, err) return } // get form file file, err := ctx.FormFile("file") if err != nil { - ctx.Error(fmt.Errorf("failed to get form file: %w", err)) - ctx.PureJSON(http.StatusBadRequest, gin.H{ - "error": fmt.Sprintf("invalid file: %s", err), - }) + api.AbortWithError(ctx, http.StatusBadRequest, api.CodeInvalidRequest, fmt.Errorf("invalid file: %w", err)) return } // check file size if file.Size <= 0 { - ctx.Error(fmt.Errorf("invalid file: size is 0")) - ctx.PureJSON(http.StatusBadRequest, gin.H{ - "error": "invalid file: size is 0", - }) + api.AbortWithError(ctx, http.StatusBadRequest, api.CodeInvalidRequest, fmt.Errorf("invalid file: size is 0")) return } fd, err := file.Open() if err != nil { - ctx.Error(fmt.Errorf("failed to open file: %w", err)) - ctx.PureJSON(http.StatusBadRequest, gin.H{ - "error": fmt.Sprintf("invalid file: %s", err), - }) + api.AbortWithError(ctx, http.StatusBadRequest, api.CodeInvalidRequest, fmt.Errorf("invalid file file: %w", err)) return } - defer fd.Close() + result, err := h.blob.Backend().PutObject(ctx.Request.Context(), &blob.PutObjectParams{ Key: req.Key, Size: file.Size, Body: fd, }) if err != nil { - ctx.Error(fmt.Errorf("failed to put object: %w", err)) - ctx.PureJSON(http.StatusInternalServerError, gin.H{ - "error": "server error: could not persist file", - }) + api.AbortWithError(ctx, http.StatusInternalServerError, api.CodeBlobPutFailed, fmt.Errorf("failed to put object: %w", err)) return } diff --git a/internal/server/handlers/blob/blob_handler_upload_acl.go b/internal/server/handlers/blob/blob_handler_upload_acl.go new file mode 100644 index 00000000..d1c58f8f --- /dev/null +++ b/internal/server/handlers/blob/blob_handler_upload_acl.go @@ -0,0 +1,101 @@ +package blob + +import ( + "bytes" + "fmt" + "io" + "net/http" + "time" + + "github.com/gin-gonic/gin" + "github.com/openmined/syftbox/internal/aclspec" + "github.com/openmined/syftbox/internal/server/acl" + "github.com/openmined/syftbox/internal/server/blob" + "github.com/openmined/syftbox/internal/server/datasite" + "github.com/openmined/syftbox/internal/server/handlers/api" +) + +func (h *BlobHandler) UploadACL(ctx *gin.Context) { + var req UploadRequest + user := ctx.GetString("user") + + if err := ctx.ShouldBindQuery(&req); err != nil { + api.AbortWithError(ctx, http.StatusBadRequest, api.CodeInvalidRequest, fmt.Errorf("failed to bind query: %w", err)) + return + } + + if !(datasite.IsValidPath(req.Key) && aclspec.IsACLFile(req.Key)) { + api.AbortWithError(ctx, http.StatusBadRequest, api.CodeDatasiteInvalidPath, fmt.Errorf("invalid ruleset path: %s", req.Key)) + return + } + + // check if user has admin rights + if err := h.checkPermissions(req.Key, user, acl.AccessAdmin); err != nil { + api.AbortWithError(ctx, http.StatusForbidden, api.CodeAccessDenied, err) + return + } + + // get form file + file, err := ctx.FormFile("file") + if err != nil { + api.AbortWithError(ctx, http.StatusBadRequest, api.CodeInvalidRequest, fmt.Errorf("failed to get form file: %w", err)) + return + } + + // check file size + if file.Size <= 0 { + api.AbortWithError(ctx, http.StatusBadRequest, api.CodeInvalidRequest, fmt.Errorf("invalid file: size is 0")) + return + } + + // open the file + fd, err := file.Open() + if err != nil { + api.AbortWithError(ctx, http.StatusBadRequest, api.CodeInvalidRequest, fmt.Errorf("failed to open file: %w", err)) + return + } + defer fd.Close() + + // read the file into memory, because we need to read it twice + fdBytes, err := io.ReadAll(fd) + if err != nil { + api.AbortWithError(ctx, http.StatusInternalServerError, api.CodeInvalidRequest, fmt.Errorf("failed to read file: %w", err)) + return + } + aclBytesReader := bytes.NewReader(fdBytes) + + // load aclspec + ruleset, err := aclspec.LoadFromReader(req.Key, aclBytesReader) + if err != nil { + api.AbortWithError(ctx, http.StatusBadRequest, api.CodeInvalidRequest, fmt.Errorf("failed to read ruleset: %w", err)) + } + + // upload file to s3 + // because that's always the ground truth + blobBytesReader := bytes.NewReader(fdBytes) + result, err := h.blob.Backend().PutObject(ctx.Request.Context(), &blob.PutObjectParams{ + Key: req.Key, + Size: file.Size, + Body: blobBytesReader, + }) + if err != nil { + api.AbortWithError(ctx, http.StatusInternalServerError, api.CodeBlobPutFailed, fmt.Errorf("failed to put object: %w", err)) + return + } + + // add it to the acl service + if _, err := h.acl.AddRuleSet(ruleset); err != nil { + // if this error happens, there's a pretty serious bug in the acl service + api.AbortWithError(ctx, http.StatusInternalServerError, api.CodeACLUpdateFailed, fmt.Errorf("failed to update ruleset: %w", err)) + return + } + + // response with UploadAccept + ctx.PureJSON(http.StatusOK, &UploadResponse{ + Key: result.Key, + Version: result.Version, + ETag: result.ETag, + Size: result.Size, + LastModified: result.LastModified.Format(time.RFC3339), + }) +} diff --git a/internal/server/handlers/blob/blob_handler_upload_presigned.go b/internal/server/handlers/blob/blob_handler_upload_presigned.go index 7897299b..ac141fe5 100644 --- a/internal/server/handlers/blob/blob_handler_upload_presigned.go +++ b/internal/server/handlers/blob/blob_handler_upload_presigned.go @@ -5,47 +5,57 @@ import ( "net/http" "github.com/gin-gonic/gin" + "github.com/openmined/syftbox/internal/server/acl" + "github.com/openmined/syftbox/internal/server/datasite" + "github.com/openmined/syftbox/internal/server/handlers/api" ) func (h *BlobHandler) UploadPresigned(ctx *gin.Context) { - var req PresignUrlRequest - if err := ctx.ShouldBindJSON(&req); err != nil { - ctx.Error(fmt.Errorf("failed to bind json: %w", err)) - ctx.PureJSON(http.StatusBadRequest, gin.H{ - "error": err.Error(), - }) - return - } + var req PresignURLRequest + user := ctx.GetString("user") - if len(req.Keys) == 0 { - ctx.Error(fmt.Errorf("keys cannot be empty")) - ctx.PureJSON(http.StatusBadRequest, gin.H{ - "error": "keys cannot be empty", - }) + if err := ctx.ShouldBindJSON(&req); err != nil { + api.AbortWithError(ctx, http.StatusBadRequest, api.CodeInvalidRequest, fmt.Errorf("failed to bind json: %w", err)) return } - urls := make([]*BlobUrl, 0, len(req.Keys)) - errors := make([]*BlobError, 0) + urls := make([]*BlobURL, 0, len(req.Keys)) + errors := make([]*BlobAPIError, 0) for _, key := range req.Keys { - if !isValidDatasiteKey(key) { - ctx.Error(fmt.Errorf("invalid datasite path: %s", key)) - errors = append(errors, &BlobError{ - Key: key, - Error: "invalid key", + if !datasite.IsValidPath(key) { + errors = append(errors, &BlobAPIError{ + SyftAPIError: api.SyftAPIError{ + Code: api.CodeDatasiteInvalidPath, + Message: "invalid key", + }, + Key: key, }) continue } + + if err := h.checkPermissions(key, user, acl.AccessWrite); err != nil { + errors = append(errors, &BlobAPIError{ + SyftAPIError: api.SyftAPIError{ + Code: api.CodeAccessDenied, + Message: err.Error(), + }, + Key: key, + }) + continue + } + url, err := h.blob.Backend().PutObjectPresigned(ctx, key) if err != nil { - ctx.Error(fmt.Errorf("failed to put object: %w", err)) - errors = append(errors, &BlobError{ - Key: key, - Error: err.Error(), + errors = append(errors, &BlobAPIError{ + SyftAPIError: api.SyftAPIError{ + Code: api.CodeBlobPutFailed, + Message: err.Error(), + }, + Key: key, }) continue } - urls = append(urls, &BlobUrl{ + urls = append(urls, &BlobURL{ Key: key, Url: url, }) @@ -53,12 +63,12 @@ func (h *BlobHandler) UploadPresigned(ctx *gin.Context) { code := http.StatusOK if len(urls) == 0 && len(errors) > 0 { - code = http.StatusBadRequest + code = http.StatusMultiStatus } else if len(urls) > 0 && len(errors) > 0 { code = http.StatusMultiStatus } - ctx.PureJSON(code, &PresignUrlResponse{ + ctx.PureJSON(code, &PresignURLResponse{ URLs: urls, Errors: errors, }) diff --git a/internal/server/handlers/datasite/datasite_handler.go b/internal/server/handlers/datasite/datasite_handler.go index 4d124bd4..390ea33b 100644 --- a/internal/server/handlers/datasite/datasite_handler.go +++ b/internal/server/handlers/datasite/datasite_handler.go @@ -1,7 +1,6 @@ package datasite import ( - "fmt" "net/http" "github.com/gin-gonic/gin" @@ -20,15 +19,6 @@ func New(svc *datasite.DatasiteService) *DatasiteHandler { func (h *DatasiteHandler) GetView(ctx *gin.Context) { user := ctx.GetString("user") - - if user == "" { - ctx.Error(fmt.Errorf("`user` is required")) - ctx.PureJSON(http.StatusBadRequest, gin.H{ - "error": "user is required", - }) - return - } - ctx.PureJSON(http.StatusOK, gin.H{ "files": h.svc.GetView(user), }) diff --git a/internal/server/handlers/explorer/explorer_handler.go b/internal/server/handlers/explorer/explorer_handler.go index 5f1a7877..346d4355 100644 --- a/internal/server/handlers/explorer/explorer_handler.go +++ b/internal/server/handlers/explorer/explorer_handler.go @@ -19,6 +19,7 @@ import ( "github.com/openmined/syftbox/internal/aclspec" "github.com/openmined/syftbox/internal/server/acl" "github.com/openmined/syftbox/internal/server/blob" + "github.com/openmined/syftbox/internal/server/handlers/api" ) //go:embed index.html.tmpl @@ -29,13 +30,13 @@ var notFoundOfTmpl string type ExplorerHandler struct { blob *blob.BlobService - acl *acl.AclService + acl *acl.ACLService tplIndex *template.Template tpl404 *template.Template } // New creates a new Explorer instance -func New(svc *blob.BlobService, acl *acl.AclService) *ExplorerHandler { +func New(svc *blob.BlobService, acl *acl.ACLService) *ExplorerHandler { funcMap := template.FuncMap{ "basename": filepath.Base, "humanizeSize": func(size int64) string { @@ -57,7 +58,7 @@ func New(svc *blob.BlobService, acl *acl.AclService) *ExplorerHandler { func (e *ExplorerHandler) Handler(c *gin.Context) { path := strings.TrimPrefix(c.Param("filepath"), "/") contents := e.listContents(path) - if contents.IsDir { + if contents.IsDir || contents.IsEmpty() { e.serveDir(c, path, contents) } else { e.serveFile(c, path) @@ -149,15 +150,14 @@ func (e *ExplorerHandler) serveDir(c *gin.Context, path string, contents *direct // Generate an HTML response c.Header("Content-Type", "text/html; charset=utf-8") if err := e.tplIndex.Execute(c.Writer, data); err != nil { - c.Error(fmt.Errorf("failed to execute template: %w", err)) - c.String(http.StatusInternalServerError, "internal server error") + api.AbortWithError(c, http.StatusInternalServerError, api.CodeInternalError, fmt.Errorf("failed to execute template: %w", err)) } } // Serve a file from S3 func (e *ExplorerHandler) serveFile(c *gin.Context, key string) { if err := e.acl.CanAccess( - &acl.User{ID: aclspec.Everyone, IsOwner: false}, + &acl.User{ID: aclspec.Everyone}, &acl.File{Path: key}, acl.AccessRead, ); err != nil { @@ -180,8 +180,7 @@ func (e *ExplorerHandler) serveFile(c *gin.Context, key string) { // Stream response body directly _, err = io.Copy(c.Writer, resp.Body) if err != nil { - c.Error(fmt.Errorf("failed to stream file: %w", err)) - c.String(http.StatusInternalServerError, "internal server error") + api.AbortWithError(c, http.StatusInternalServerError, api.CodeInternalError, fmt.Errorf("failed to stream file: %w", err)) return } } @@ -198,8 +197,7 @@ func (e *ExplorerHandler) detectContentType(key string) string { func (e *ExplorerHandler) serve404(c *gin.Context, key string) { c.Header("Content-Type", "text/html; charset=utf-8") if err := e.tpl404.Execute(c.Writer, map[string]any{"Key": key}); err != nil { - c.Error(fmt.Errorf("failed to execute template: %w", err)) - c.String(http.StatusInternalServerError, "internal server error") + api.AbortWithError(c, http.StatusInternalServerError, api.CodeInternalError, fmt.Errorf("failed to execute template: %w", err)) } } diff --git a/internal/server/handlers/explorer/explorer_handler_types.go b/internal/server/handlers/explorer/explorer_handler_types.go index bae4ce55..3995a18a 100644 --- a/internal/server/handlers/explorer/explorer_handler_types.go +++ b/internal/server/handlers/explorer/explorer_handler_types.go @@ -15,3 +15,7 @@ type directoryContents struct { Files []*blob.BlobInfo Folders []string } + +func (d *directoryContents) IsEmpty() bool { + return len(d.Files) == 0 && len(d.Folders) == 0 +} diff --git a/internal/server/handlers/ws/ws_hub.go b/internal/server/handlers/ws/ws_hub.go index 717a1be5..0f3ccf2a 100644 --- a/internal/server/handlers/ws/ws_hub.go +++ b/internal/server/handlers/ws/ws_hub.go @@ -9,7 +9,9 @@ import ( "github.com/coder/websocket" "github.com/gin-gonic/gin" + "github.com/openmined/syftbox/internal/server/handlers/api" "github.com/openmined/syftbox/internal/syftmsg" + "github.com/openmined/syftbox/internal/version" ) const ( @@ -90,31 +92,27 @@ func (h *WebsocketHub) Shutdown(ctx context.Context) { // WebsocketHandler is the handler for the websocket connection // it upgrades the http connection to a websocket and registers the client with the hub func (h *WebsocketHub) WebsocketHandler(ctx *gin.Context) { - if ctx.GetString("user") == "" { - ctx.Status(http.StatusUnauthorized) - slog.Warn("wshub unauthorized", "ip", ctx.ClientIP(), "headers", ctx.Request.Header, "path", ctx.Request.URL) + user := ctx.GetString("user") + if user == "" { + api.AbortWithError(ctx, http.StatusUnauthorized, api.CodeInvalidRequest, fmt.Errorf("user missing")) return } // Upgrade HTTP connection to WebSocket conn, err := websocket.Accept(ctx.Writer, ctx.Request, nil) if err != nil { - e := fmt.Errorf("websocket accept failed: %w", err) - ctx.Error(e) - ctx.PureJSON(http.StatusBadRequest, gin.H{ - "error": e.Error(), - }) + api.AbortWithError(ctx, http.StatusBadRequest, api.CodeInvalidRequest, fmt.Errorf("websocket accept failed: %w", err)) return } conn.SetReadLimit(maxMessageSize) client := NewWebsocketClient(conn, &ClientInfo{ - User: ctx.GetString("user"), + User: user, IPAddr: ctx.ClientIP(), Headers: ctx.Request.Header.Clone(), }) - client.MsgTx <- syftmsg.NewSystemMessage("0.5.0", "ok") + client.MsgTx <- syftmsg.NewSystemMessage(version.Version, "ok") h.register <- client } diff --git a/internal/server/middlewares/jwtauth.go b/internal/server/middlewares/jwtauth.go index ecf491ac..2e27dbda 100644 --- a/internal/server/middlewares/jwtauth.go +++ b/internal/server/middlewares/jwtauth.go @@ -8,6 +8,7 @@ import ( "github.com/gin-gonic/gin" "github.com/openmined/syftbox/internal/server/auth" // Import your auth package + "github.com/openmined/syftbox/internal/server/handlers/api" "github.com/openmined/syftbox/internal/utils" // Import types for error constants ) @@ -27,10 +28,7 @@ func JWTAuth(authService *auth.AuthService) gin.HandlerFunc { // expect user to be an email address user := ctx.Query("user") if !utils.IsValidEmail(user) { - ctx.Error(fmt.Errorf("invalid email")) - ctx.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{ - "error": "invalid email", - }) + api.AbortWithError(ctx, http.StatusUnauthorized, api.CodeInvalidRequest, fmt.Errorf("invalid email")) return } ctx.Set("user", user) @@ -43,39 +41,27 @@ func JWTAuth(authService *auth.AuthService) gin.HandlerFunc { return func(ctx *gin.Context) { authHeaderValue := ctx.GetHeader(authHeader) if authHeaderValue == "" { - ctx.Error(fmt.Errorf("authorization header required")) - ctx.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{ - "error": "authorization header required", - }) + api.AbortWithError(ctx, http.StatusUnauthorized, api.CodeAuthInvalidCredentials, fmt.Errorf("authorization header required")) return } // Check if the header starts with "Bearer " if !strings.HasPrefix(authHeaderValue, bearerPrefix) { - ctx.Error(fmt.Errorf("bearer token required")) - ctx.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{ - "error": "bearer token required", - }) + api.AbortWithError(ctx, http.StatusUnauthorized, api.CodeAuthInvalidCredentials, fmt.Errorf("bearer token required")) return } // Extract the token string tokenString := strings.TrimPrefix(authHeaderValue, bearerPrefix) if tokenString == "" { - ctx.Error(fmt.Errorf("token missing")) - ctx.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{ - "error": "token missing", - }) + api.AbortWithError(ctx, http.StatusUnauthorized, api.CodeAuthInvalidCredentials, fmt.Errorf("token missing")) return } // Validate the token using the method added to AuthService claims, err := authService.ValidateAccessToken(ctx, tokenString) if err != nil { - ctx.Error(err) - ctx.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{ - "error": err.Error(), - }) + api.AbortWithError(ctx, http.StatusUnauthorized, api.CodeAuthInvalidCredentials, err) return } diff --git a/internal/server/middlewares/ratelimiter.go b/internal/server/middlewares/ratelimiter.go index fe659a8b..a7b1aff3 100644 --- a/internal/server/middlewares/ratelimiter.go +++ b/internal/server/middlewares/ratelimiter.go @@ -4,6 +4,7 @@ import ( "net/http" "github.com/gin-gonic/gin" + "github.com/openmined/syftbox/internal/server/handlers/api" "github.com/ulule/limiter/v3" "github.com/ulule/limiter/v3/drivers/store/memory" @@ -21,13 +22,15 @@ func RateLimiter(formattedRate string) gin.HandlerFunc { return mgin.NewMiddleware( limiter, mgin.WithLimitReachedHandler(func(c *gin.Context) { - c.PureJSON(http.StatusTooManyRequests, gin.H{ - "error": "rate limit exceeded", + c.PureJSON(http.StatusTooManyRequests, api.SyftAPIError{ + Code: api.CodeRateLimited, + Message: "rate limit exceeded", }) }), mgin.WithErrorHandler(func(c *gin.Context, err error) { - c.PureJSON(http.StatusInternalServerError, gin.H{ - "error": err.Error(), + c.PureJSON(http.StatusInternalServerError, api.SyftAPIError{ + Code: api.CodeInternalError, + Message: err.Error(), }) }), ) diff --git a/internal/server/routes.go b/internal/server/routes.go index 9236103f..88e565e9 100644 --- a/internal/server/routes.go +++ b/internal/server/routes.go @@ -7,6 +7,8 @@ import ( "github.com/gin-gonic/gin" + "github.com/openmined/syftbox/internal/server/handlers/acl" + "github.com/openmined/syftbox/internal/server/handlers/api" "github.com/openmined/syftbox/internal/server/handlers/auth" "github.com/openmined/syftbox/internal/server/handlers/blob" "github.com/openmined/syftbox/internal/server/handlers/datasite" @@ -32,10 +34,11 @@ func SetupRoutes(svc *Services, hub *ws.WebsocketHub, httpsEnabled bool) http.Ha // --------------------------- handlers --------------------------- - blobH := blob.New(svc.Blob) + blobH := blob.New(svc.Blob, svc.ACL) dsH := datasite.New(svc.Datasite) explorerH := explorer.New(svc.Blob, svc.ACL) authH := auth.New(svc.Auth) + aclH := acl.NewACLHandler(svc.ACL) // --------------------------- routes --------------------------- @@ -65,6 +68,7 @@ func SetupRoutes(svc *Services, hub *ws.WebsocketHub, httpsEnabled bool) http.Ha // blob v1.GET("/blob/list", blobH.ListObjects) v1.PUT("/blob/upload", blobH.Upload) + v1.PUT("/blob/upload/acl", blobH.UploadACL) v1.POST("/blob/upload/presigned", blobH.UploadPresigned) v1.POST("/blob/upload/multipart", blobH.UploadMultipart) v1.POST("/blob/upload/complete", blobH.UploadComplete) @@ -74,19 +78,24 @@ func SetupRoutes(svc *Services, hub *ws.WebsocketHub, httpsEnabled bool) http.Ha // datasite v1.GET("/datasite/view", dsH.GetView) + v1.PUT("/acl", blobH.UploadACL) + v1.GET("/acl/check", aclH.CheckAccess) + // websocket events v1.GET("/events", hub.WebsocketHandler) } r.NoRoute(func(c *gin.Context) { - c.JSON(http.StatusNotFound, gin.H{ - "error": "not found", + c.JSON(http.StatusNotFound, api.SyftAPIError{ + Code: api.CodeInvalidRequest, + Message: "not found", }) }) r.NoMethod(func(c *gin.Context) { - c.JSON(http.StatusMethodNotAllowed, gin.H{ - "error": "method not allowed", + c.JSON(http.StatusMethodNotAllowed, api.SyftAPIError{ + Code: api.CodeInvalidRequest, + Message: "method not allowed", }) }) diff --git a/internal/server/server.go b/internal/server/server.go index 8cccfede..a5f5325c 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -17,7 +17,6 @@ import ( "github.com/openmined/syftbox/internal/db" "github.com/openmined/syftbox/internal/server/acl" "github.com/openmined/syftbox/internal/server/blob" - "github.com/openmined/syftbox/internal/server/datasite" "github.com/openmined/syftbox/internal/server/handlers/ws" "github.com/openmined/syftbox/internal/syftmsg" "golang.org/x/sync/errgroup" @@ -263,7 +262,7 @@ func (s *Server) checkPermission(user string, path string, access acl.AccessLeve return nil } return s.svc.ACL.CanAccess( - &acl.User{ID: user, IsOwner: datasite.IsOwner(path, user)}, + &acl.User{ID: user}, &acl.File{Path: path}, access, ) diff --git a/internal/server/services.go b/internal/server/services.go index 007a6b5d..025c2c33 100644 --- a/internal/server/services.go +++ b/internal/server/services.go @@ -14,7 +14,7 @@ import ( type Services struct { Blob *blob.BlobService - ACL *acl.AclService + ACL *acl.ACLService Datasite *datasite.DatasiteService Auth *auth.AuthService Email *email.EmailService @@ -28,7 +28,7 @@ func NewServices(config *Config, db *sqlx.DB) (*Services, error) { return nil, err } - aclSvc := acl.NewAclService() + aclSvc := acl.NewACLService() datasiteSvc := datasite.NewDatasiteService(blobSvc, aclSvc) From 9a3028daf0403d22e8749c7ef5a761ba2a7665f2 Mon Sep 17 00:00:00 2001 From: Yash Gorana Date: Tue, 17 Jun 2025 17:44:49 +0530 Subject: [PATCH 06/19] fix(server/explorer): fix file not being server --- internal/server/handlers/explorer/explorer_handler.go | 2 +- internal/server/handlers/explorer/explorer_handler_types.go | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/internal/server/handlers/explorer/explorer_handler.go b/internal/server/handlers/explorer/explorer_handler.go index 346d4355..55d49fd8 100644 --- a/internal/server/handlers/explorer/explorer_handler.go +++ b/internal/server/handlers/explorer/explorer_handler.go @@ -58,7 +58,7 @@ func New(svc *blob.BlobService, acl *acl.ACLService) *ExplorerHandler { func (e *ExplorerHandler) Handler(c *gin.Context) { path := strings.TrimPrefix(c.Param("filepath"), "/") contents := e.listContents(path) - if contents.IsDir || contents.IsEmpty() { + if contents.IsDir || contents.EmptyDir() { e.serveDir(c, path, contents) } else { e.serveFile(c, path) diff --git a/internal/server/handlers/explorer/explorer_handler_types.go b/internal/server/handlers/explorer/explorer_handler_types.go index 3995a18a..6909a81b 100644 --- a/internal/server/handlers/explorer/explorer_handler_types.go +++ b/internal/server/handlers/explorer/explorer_handler_types.go @@ -16,6 +16,6 @@ type directoryContents struct { Folders []string } -func (d *directoryContents) IsEmpty() bool { - return len(d.Files) == 0 && len(d.Folders) == 0 +func (d *directoryContents) EmptyDir() bool { + return d.IsDir && len(d.Files) == 0 && len(d.Folders) == 0 } From 4b8931bfae866b1160ef8052acb6797ac41cfe68 Mon Sep 17 00:00:00 2001 From: Yash Gorana Date: Tue, 17 Jun 2025 18:33:26 +0530 Subject: [PATCH 07/19] fix(client/sdk): disable auth for local urls --- internal/syftsdk/sdk.go | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/internal/syftsdk/sdk.go b/internal/syftsdk/sdk.go index 7c06684e..76bdd40e 100644 --- a/internal/syftsdk/sdk.go +++ b/internal/syftsdk/sdk.go @@ -6,6 +6,8 @@ import ( "fmt" "log/slog" "os" + "strconv" + "strings" "time" "github.com/openmined/syftbox/internal/utils" @@ -74,7 +76,7 @@ func (s *SyftSDK) Close() { // Authenticate sets the user authentication for API calls and events func (s *SyftSDK) Authenticate(ctx context.Context) error { - if isAuthDisabled() { + if isAuthDisabled() || isDevURL(s.config.BaseURL) { slog.Warn("sdk auth disabled, skipping auth") return nil } @@ -169,5 +171,15 @@ func (s *SyftSDK) setAccessToken(accessToken string) error { func isAuthDisabled() bool { authEnabled := os.Getenv("SYFTBOX_AUTH_ENABLED") - return authEnabled == "0" || authEnabled == "false" + enabled, err := strconv.ParseBool(authEnabled) + if err != nil { + return false + } + return !enabled +} + +func isDevURL(baseURL string) bool { + return strings.Contains(baseURL, "localhost") || + strings.Contains(baseURL, "127.0.0.1") || + strings.Contains(baseURL, "0.0.0.0") } From 787df8e75e3de3645ad730c25c0497aa34550b8f Mon Sep 17 00:00:00 2001 From: Yash Gorana Date: Tue, 17 Jun 2025 18:34:21 +0530 Subject: [PATCH 08/19] chore: update README --- README.md | 50 +++++++++++++++++++++++--------------------------- 1 file changed, 23 insertions(+), 27 deletions(-) diff --git a/README.md b/README.md index 1262d1fc..f95bc4f8 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,27 @@ # SyftBox -## Quickstart +SyftBox is an open-source protocol that enables developers and organizations to build, deploy, and federate privacy-preserving computations seamlessly across a network. Unlock the ability to run computations on distributed datasets without centralizing data—preserving security while gaining valuable insights. + +Read the [documentation](https://syftbox-documentation.openmined.org/get-started) for more details. + +> [!WARNING] +> This project is a rewrite of the [original Python version](https://github.com/OpenMined/syft). Consequently, the linked documentation may not fully reflect the current implementation. + +## Quick Start + +Using the GUI, from https://github.com/OpenMined/SyftUI/releases + +On macOS and Linux. +``` +curl -fsSL https://syftboxdev.openmined.org/install.sh | sh +``` + +On Windows using Powershell +``` +powershell -ExecutionPolicy ByPass -c "irm https://syftboxdev.openmined.org/install.ps1 | iex" +``` + +## Contributing ### Install Go Follow the official [Go installation guide](https://golang.org/doc/install) to set up Go on your system. @@ -26,29 +47,4 @@ Verify your setup by running the tests: just test ``` - -SyftBox is an open-source protocol that enables developers and organizations to build, deploy, and federate privacy-preserving computations seamlessly across a network. Unlock the ability to run computations on distributed datasets without centralizing data—preserving security while gaining valuable insights. - -Read the [documentation](https://syftbox-documentation.openmined.org/get-started) for more details. - -> [!WARNING] -> This project is a rewrite of the [original Python version](https://github.com/OpenMined/syft). Consequently, the linked documentation may not fully reflect the current implementation. - -## Installation - -Using the GUI, from https://github.com/OpenMined/SyftUI/releases - - -On macOS and Linux. -``` -curl -fsSL https://syftboxdev.openmined.org/install.sh | sh -``` - -On Windows using Powershell -``` -powershell -ExecutionPolicy ByPass -c "irm https://syftboxdev.openmined.org/install.ps1 | iex" -``` - -## Contributing - -See the [development guide](./DEVELOPMENT.md) to get started +See the [development guide](./DEVELOPMENT.md) for more details From 7daf4bce63c01d47f2c8400569baba98787ad437 Mon Sep 17 00:00:00 2001 From: Yash Date: Wed, 18 Jun 2025 15:19:29 +0530 Subject: [PATCH 09/19] fix(server/acl): ifx root perm terminal bug + cherry pick tests from #22 (#25) --- internal/server/acl/acl.go | 42 ++++------ internal/server/acl/acl_test.go | 30 +++++--- internal/server/acl/cache.go | 7 +- internal/server/acl/level.go | 42 +++++++--- internal/server/acl/level_test.go | 76 +++++++++++++++---- internal/server/acl/node.go | 8 +- internal/server/acl/tree.go | 55 ++++++++------ internal/server/acl/tree_test.go | 56 ++++++-------- internal/server/datasite/datasite.go | 13 ++-- .../handlers/explorer/explorer_handler.go | 11 +++ 10 files changed, 202 insertions(+), 138 deletions(-) diff --git a/internal/server/acl/acl.go b/internal/server/acl/acl.go index 29173288..e56dab20 100644 --- a/internal/server/acl/acl.go +++ b/internal/server/acl/acl.go @@ -1,8 +1,8 @@ package acl import ( - "errors" "fmt" + "log/slog" "github.com/openmined/syftbox/internal/aclspec" ) @@ -23,30 +23,14 @@ func NewACLService() *ACLService { // AddRuleSet adds or updates a new set of rules to the service. func (s *ACLService) AddRuleSet(ruleSet *aclspec.RuleSet) (ACLVersion, error) { - version, err := s.tree.AddRuleSet(ruleSet) + node, err := s.tree.AddRuleSet(ruleSet) if err != nil { return 0, err } s.cache.DeletePrefix(ruleSet.Path) - return version, nil -} - -// AddRuleSets adds a new set of rules to the service. -func (s *ACLService) AddRuleSets(ruleSets []*aclspec.RuleSet) error { - errs := make([]error, 0) - - for _, ruleSet := range ruleSets { - if _, err := s.tree.AddRuleSet(ruleSet); err != nil { - errs = append(errs, err) - } - } - - if len(errs) > 0 { - return fmt.Errorf("failed to add rule sets: %w", errors.Join(errs...)) - } - - return nil + slog.Debug("updated rule set", "path", node.path, "version", node.version) + return node.version, nil } // RemoveRuleSet removes a ruleset at the specified path. @@ -61,8 +45,8 @@ func (s *ACLService) RemoveRuleSet(path string) bool { return false } -// GetEffectiveRule finds the most specific rule applicable to the given path. -func (s *ACLService) GetEffectiveRule(path string) (*ACLRule, error) { +// GetRule finds the most specific rule applicable to the given path. +func (s *ACLService) GetRule(path string) (*ACLRule, error) { path = ACLNormPath(path) // cache hit @@ -86,7 +70,7 @@ func (s *ACLService) GetEffectiveRule(path string) (*ACLRule, error) { // CanAccess checks if a user has the specified access permission for a file. func (s *ACLService) CanAccess(user *User, file *File, level AccessLevel) error { // get the effective rule for the file - rule, err := s.GetEffectiveRule(file.Path) + rule, err := s.GetRule(file.Path) if err != nil { return err } @@ -95,13 +79,13 @@ func (s *ACLService) CanAccess(user *User, file *File, level AccessLevel) error if rule.Owner() == user.ID { return nil } - - // elevate action for ACL files - isAcl := aclspec.IsACLFile(file.Path) - if isAcl && level == AccessWrite { + // Elevate ACL file writes to admin level + if aclspec.IsACLFile(file.Path) && level >= AccessCreate { level = AccessAdmin - } else if level == AccessWrite { - // writes need to be checked against the file limits + } + + // Check file limits for write operations + if level >= AccessCreate { if err := rule.CheckLimits(file); err != nil { return fmt.Errorf("file limits exceeded for user '%s' on path '%s': %w", user.ID, file.Path, err) } diff --git a/internal/server/acl/acl_test.go b/internal/server/acl/acl_test.go index fda0a6c8..6bf74227 100644 --- a/internal/server/acl/acl_test.go +++ b/internal/server/acl/acl_test.go @@ -24,19 +24,19 @@ func TestAclServiceGetRule(t *testing.T) { // Test cache miss rules assert.NotContains(t, service.cache.index, "user/readme.md") - rule, err := service.GetEffectiveRule("user/readme.md") + rule, err := service.GetRule("user/readme.md") assert.NoError(t, err) assert.NotNil(t, rule) assert.Equal(t, "*.md", rule.rule.Pattern) // test cache hit assert.Contains(t, service.cache.index, "user/readme.md") - rule, err = service.GetEffectiveRule("user/readme.md") + rule, err = service.GetRule("user/readme.md") assert.NoError(t, err) assert.NotNil(t, rule) assert.Equal(t, "*.md", rule.rule.Pattern) - rule, err = service.GetEffectiveRule("user/notes.txt") + rule, err = service.GetRule("user/notes.txt") assert.NoError(t, err) assert.NotNil(t, rule) assert.Equal(t, "*.txt", rule.rule.Pattern) @@ -67,11 +67,11 @@ func TestAclServiceRemoveRuleSet(t *testing.T) { assert.Equal(t, ACLVersion(1), ver) // Verify both rulesets work - rule, err := service.GetEffectiveRule("user1@email.com/file.txt") + rule, err := service.GetRule("user1@email.com/file.txt") assert.NoError(t, err) assert.NotNil(t, rule) - rule, err = service.GetEffectiveRule("user2@email.com/file.txt") + rule, err = service.GetRule("user2@email.com/file.txt") assert.NoError(t, err) assert.NotNil(t, rule) @@ -80,12 +80,12 @@ func TestAclServiceRemoveRuleSet(t *testing.T) { assert.True(t, removed) // Verify removed ruleset no longer works - rule, err = service.GetEffectiveRule("user1@email.com/file.txt") + rule, err = service.GetRule("user1@email.com/file.txt") assert.Error(t, err) assert.Nil(t, rule) // Verify other ruleset still works - rule, err = service.GetEffectiveRule("user2@email.com/file.txt") + rule, err = service.GetRule("user2@email.com/file.txt") assert.NoError(t, err) assert.NotNil(t, rule) @@ -195,16 +195,22 @@ func TestAclServiceLoadRuleSets(t *testing.T) { ) // Load multiple rulesets at once - err := service.AddRuleSets([]*aclspec.RuleSet{ruleset1, ruleset2}) + ver, err := service.AddRuleSet(ruleset1) + assert.NoError(t, err) + assert.Equal(t, ACLVersion(1), ver) + + ver, err = service.AddRuleSet(ruleset2) + assert.NoError(t, err) + assert.Equal(t, ACLVersion(1), ver) assert.NoError(t, err) // Verify both rulesets work - rule, err := service.GetEffectiveRule("user1@email.com/file.txt") + rule, err := service.GetRule("user1@email.com/file.txt") assert.NoError(t, err) assert.NotNil(t, rule) assert.Equal(t, "*.txt", rule.rule.Pattern) - rule, err = service.GetEffectiveRule("user2@email.com/file.md") + rule, err = service.GetRule("user2@email.com/file.md") assert.NoError(t, err) assert.NotNil(t, rule) assert.Equal(t, "*.md", rule.rule.Pattern) @@ -225,7 +231,7 @@ func TestAclServiceCacheInvalidation(t *testing.T) { assert.Equal(t, ACLVersion(1), ver) // Access a path to cache the rule - rule, err := service.GetEffectiveRule("user1@email.com/readme.md") + rule, err := service.GetRule("user1@email.com/readme.md") assert.NoError(t, err) assert.NotNil(t, rule) assert.Equal(t, "*.md", rule.rule.Pattern) @@ -244,7 +250,7 @@ func TestAclServiceCacheInvalidation(t *testing.T) { assert.Equal(t, ACLVersion(2), ver) // Access the same path, should get the new rule - rule, err = service.GetEffectiveRule("user1@email.com/readme.md") + rule, err = service.GetRule("user1@email.com/readme.md") assert.NoError(t, err) assert.NotNil(t, rule) assert.True(t, rule.node.GetTerminal()) diff --git a/internal/server/acl/cache.go b/internal/server/acl/cache.go index e695ed5b..8786cd68 100644 --- a/internal/server/acl/cache.go +++ b/internal/server/acl/cache.go @@ -1,14 +1,13 @@ package acl import ( - "log/slog" "strings" "sync" ) // ACLCache stores the effective ACL rule for a given path. type ACLCache struct { - index map[string]*ACLRule // path -> ACLRule + index map[string]*ACLRule // Normalized ACLPath -> ACLRule mu sync.RWMutex } @@ -38,8 +37,6 @@ func (c *ACLCache) Set(path string, rule *ACLRule) { defer c.mu.Unlock() c.index[path] = rule - - slog.Debug("acl cache set", "path", path, "version", rule.Version()) } // Delete deletes the effective ACL rule for the given path. @@ -48,7 +45,6 @@ func (c *ACLCache) Delete(path string) { defer c.mu.Unlock() delete(c.index, path) - slog.Debug("acl cache delete", "path", path) } // DeletePrefix deletes the effective ACL rule for all paths that match the given prefix. @@ -60,7 +56,6 @@ func (c *ACLCache) DeletePrefix(path string) { for k := range c.index { if strings.HasPrefix(k, path) { delete(c.index, k) - slog.Debug("acl cache prefix delete", "path", k) } } } diff --git a/internal/server/acl/level.go b/internal/server/acl/level.go index ccaf89e1..524ccafe 100644 --- a/internal/server/acl/level.go +++ b/internal/server/acl/level.go @@ -5,20 +5,44 @@ type AccessLevel uint8 // Action constants define different types of file permissions const ( - AccessRead AccessLevel = iota + 1 + AccessRead AccessLevel = 1 << iota + AccessCreate AccessWrite AccessAdmin ) func (a AccessLevel) String() string { - switch a { - case AccessRead: - return "Read" - case AccessWrite: - return "Write" - case AccessAdmin: - return "Admin" - default: + if a == 0 { + return "None" + } + + var parts []string + + if (a & AccessRead) == AccessRead { + parts = append(parts, "Read") + } + if (a & AccessCreate) == AccessCreate { + parts = append(parts, "Create") + } + if (a & AccessWrite) == AccessWrite { + parts = append(parts, "Write") + } + if (a & AccessAdmin) == AccessAdmin { + parts = append(parts, "Admin") + } + + if len(parts) == 0 { return "Unknown" } + + if len(parts) == 1 { + return parts[0] + } + + // For multiple permissions, join with "+" + result := parts[0] + for i := 1; i < len(parts); i++ { + result += "+" + parts[i] + } + return result } diff --git a/internal/server/acl/level_test.go b/internal/server/acl/level_test.go index b9534be2..2635543c 100644 --- a/internal/server/acl/level_test.go +++ b/internal/server/acl/level_test.go @@ -20,6 +20,11 @@ func TestAccessLevelString(t *testing.T) { expected: "Read", desc: "AccessRead should return 'Read'", }, + { + level: AccessCreate, + expected: "Create", + desc: "AccessCreate should return 'Create'", + }, { level: AccessWrite, expected: "Write", @@ -32,14 +37,24 @@ func TestAccessLevelString(t *testing.T) { }, { level: 0, - expected: "Unknown", - desc: "Zero value should return 'Unknown'", + expected: "None", + desc: "Zero value should return 'None'", }, { - level: AccessLevel(10), + level: AccessLevel(16), expected: "Unknown", desc: "Undefined values should return 'Unknown'", }, + { + level: AccessRead | AccessWrite, + expected: "Read+Write", + desc: "Combined permissions should be joined with '+'", + }, + { + level: AccessRead | AccessCreate | AccessWrite, + expected: "Read+Create+Write", + desc: "Multiple combined permissions should be joined in order", + }, } for _, tc := range testCases { @@ -53,18 +68,19 @@ func TestAccessLevelString(t *testing.T) { func TestAccessLevelStringUnknown(t *testing.T) { // Test String() method with invalid/unknown access level values // This ensures the system handles undefined access levels gracefully - unknownLevel := AccessLevel(255) // Invalid access level + unknownLevel := AccessLevel(16) // Invalid access level (higher than any defined bit) result := unknownLevel.String() assert.Equal(t, "Unknown", result, "Unknown access levels should return 'Unknown'") } func TestAccessLevelValues(t *testing.T) { // Test that AccessLevel constants have the expected values - // Since iota starts at 0 and we use iota + 1, values should be 1, 2, 3 + // Using bit flags: 1 << iota creates powers of 2 - assert.Equal(t, AccessLevel(1), AccessRead, "AccessRead should be 1") - assert.Equal(t, AccessLevel(2), AccessWrite, "AccessWrite should be 2") - assert.Equal(t, AccessLevel(3), AccessAdmin, "AccessAdmin should be 3") + assert.Equal(t, AccessLevel(1), AccessRead, "AccessRead should be 1 (1 << 0)") + assert.Equal(t, AccessLevel(2), AccessCreate, "AccessCreate should be 2 (1 << 1)") + assert.Equal(t, AccessLevel(4), AccessWrite, "AccessWrite should be 4 (1 << 2)") + assert.Equal(t, AccessLevel(8), AccessAdmin, "AccessAdmin should be 8 (1 << 3)") } func TestAccessLevelUniqueness(t *testing.T) { @@ -72,6 +88,7 @@ func TestAccessLevelUniqueness(t *testing.T) { // This prevents accidental duplicate values that could cause permission conflicts levels := []AccessLevel{ AccessRead, + AccessCreate, AccessWrite, AccessAdmin, } @@ -92,8 +109,9 @@ func TestAccessLevelHierarchy(t *testing.T) { // Test the logical hierarchy of access levels // This documents the intended permission hierarchy in the system - // Verify the ordering based on iota values - assert.True(t, AccessRead < AccessWrite, "Read should be lower than Write") + // Verify the ordering based on bit values + assert.True(t, AccessRead < AccessCreate, "Read should be lower than Create") + assert.True(t, AccessCreate < AccessWrite, "Create should be lower than Write") assert.True(t, AccessWrite < AccessAdmin, "Write should be lower than Admin") assert.True(t, AccessRead < AccessAdmin, "Read should be lower than Admin") } @@ -104,10 +122,11 @@ func TestAccessLevelZeroValue(t *testing.T) { var zeroLevel AccessLevel assert.Equal(t, AccessLevel(0), zeroLevel, "Zero value should be 0") - assert.Equal(t, "Unknown", zeroLevel.String(), "Zero value should return 'Unknown'") + assert.Equal(t, "None", zeroLevel.String(), "Zero value should return 'None'") // Zero should not match any defined permission assert.NotEqual(t, AccessRead, zeroLevel, "Zero should not equal AccessRead") + assert.NotEqual(t, AccessCreate, zeroLevel, "Zero should not equal AccessCreate") assert.NotEqual(t, AccessWrite, zeroLevel, "Zero should not equal AccessWrite") assert.NotEqual(t, AccessAdmin, zeroLevel, "Zero should not equal AccessAdmin") } @@ -151,13 +170,42 @@ func TestAccessLevelEdgeCases(t *testing.T) { // Test maximum uint8 value maxLevel := AccessLevel(255) - assert.Equal(t, "Unknown", maxLevel.String(), "Maximum value should be handled as unknown") + assert.Equal(t, "Read+Create+Write+Admin", maxLevel.String(), "Maximum value should show all known bits set") // Test values between defined constants - betweenLevels := AccessLevel(4) // Just after AccessAdmin (3) + betweenLevels := AccessLevel(16) // Higher than all defined constants assert.Equal(t, "Unknown", betweenLevels.String(), "Undefined values should be unknown") // Test that undefined values are handled correctly - undefinedLevel := AccessLevel(10) + undefinedLevel := AccessLevel(32) assert.Equal(t, "Unknown", undefinedLevel.String(), "Higher undefined values should be unknown") +} + +func TestAccessLevelBitOperations(t *testing.T) { + // Test bit flag operations + // This validates that permissions can be combined and checked using bitwise operations + + // Test combining permissions + readWrite := AccessRead | AccessWrite + assert.Equal(t, AccessLevel(5), readWrite, "Read | Write should be 5 (1 | 4)") + assert.Equal(t, "Read+Write", readWrite.String(), "Combined permissions should show both") + + // Test checking individual permissions + allPerms := AccessRead | AccessCreate | AccessWrite | AccessAdmin + assert.True(t, (allPerms & AccessRead) == AccessRead, "Should have Read permission") + assert.True(t, (allPerms & AccessCreate) == AccessCreate, "Should have Create permission") + assert.True(t, (allPerms & AccessWrite) == AccessWrite, "Should have Write permission") + assert.True(t, (allPerms & AccessAdmin) == AccessAdmin, "Should have Admin permission") + + // Test absence of permissions + readOnly := AccessRead + assert.True(t, (readOnly & AccessRead) == AccessRead, "Should have Read permission") + assert.False(t, (readOnly & AccessWrite) == AccessWrite, "Should not have Write permission") + assert.False(t, (readOnly & AccessAdmin) == AccessAdmin, "Should not have Admin permission") + + // Test removing permissions + allButAdmin := allPerms &^ AccessAdmin + assert.True(t, (allButAdmin & AccessRead) == AccessRead, "Should still have Read") + assert.True(t, (allButAdmin & AccessWrite) == AccessWrite, "Should still have Write") + assert.False(t, (allButAdmin & AccessAdmin) == AccessAdmin, "Should not have Admin") } \ No newline at end of file diff --git a/internal/server/acl/node.go b/internal/server/acl/node.go index c1e0e6c8..4135cc48 100644 --- a/internal/server/acl/node.go +++ b/internal/server/acl/node.go @@ -111,7 +111,7 @@ func (n *ACLNode) SetRules(rules []*aclspec.Rule, terminal bool) { n.rules = nil } - // set the rules and terminal flag + // set the terminal flag n.terminal = terminal // increment the version @@ -132,17 +132,17 @@ func (n *ACLNode) FindBestRule(path string) (*ACLRule, error) { defer n.mu.RUnlock() if n.rules == nil { - return nil, ErrNoRuleFound + return nil, ErrNoRule } - // find the best matching rule + // find the best matching rule (rules are already sorted by specificity) for _, aclRule := range n.rules { if ok, _ := doublestar.Match(aclRule.fullPattern, path); ok { return aclRule, nil } } - return nil, ErrNoRuleFound + return nil, ErrNoRule } // GetOwner returns the owner of the node. diff --git a/internal/server/acl/tree.go b/internal/server/acl/tree.go index f835ca61..a53d302a 100644 --- a/internal/server/acl/tree.go +++ b/internal/server/acl/tree.go @@ -3,7 +3,6 @@ package acl import ( "errors" "fmt" - "log/slog" "strings" "github.com/openmined/syftbox/internal/aclspec" @@ -12,7 +11,8 @@ import ( var ( ErrInvalidRuleset = errors.New("invalid ruleset") ErrMaxDepthExceeded = errors.New("maximum depth exceeded") - ErrNoRuleFound = errors.New("no rules available") + ErrNoRuleSet = errors.New("no ruleset found") + ErrNoRule = errors.New("no rules available") ) const ( @@ -33,15 +33,15 @@ func NewACLTree() *ACLTree { } // Add or update a ruleset in the tree. -func (t *ACLTree) AddRuleSet(ruleset *aclspec.RuleSet) (ACLVersion, error) { +func (t *ACLTree) AddRuleSet(ruleset *aclspec.RuleSet) (*ACLNode, error) { // Validate the ruleset if ruleset == nil { - return 0, fmt.Errorf("%w: ruleset is nil", ErrInvalidRuleset) + return nil, fmt.Errorf("%w: ruleset is nil", ErrInvalidRuleset) } allRules := ruleset.AllRules() if len(allRules) == 0 { - return 0, fmt.Errorf("%w: ruleset is empty", ErrInvalidRuleset) + return nil, fmt.Errorf("%w: ruleset is empty", ErrInvalidRuleset) } // Clean and split the path @@ -53,28 +53,28 @@ func (t *ACLTree) AddRuleSet(ruleset *aclspec.RuleSet) (ACLVersion, error) { // but in future we can always bake it as a part of the acl schema owner := parts[0] if owner == "" { - return 0, fmt.Errorf("%w: owner is empty", ErrInvalidRuleset) + return nil, fmt.Errorf("%w: owner is empty", ErrInvalidRuleset) } // Check path depth limit (u8) if pathDepth > ACLMaxDepth { - return 0, ErrMaxDepthExceeded + return nil, ErrMaxDepthExceeded } // Start at the root node current := t.root - currentDepth := current.depth + currentDepth := 0 // Traverse/create the path - for _, part := range parts { - currentDepth++ + for i, part := range parts { + currentDepth = i + 1 // Calculate depth based on current position // Important: We still process terminal nodes to ensure all ACLs are known to the tree // Get or create child node child, exists := current.GetChild(part) if !exists { fullPath := ACLJoinPath(parts[:currentDepth]...) - child = NewACLNode(fullPath, owner, false, currentDepth) + child = NewACLNode(fullPath, owner, false, ACLDepth(currentDepth)) current.SetChild(part, child) } @@ -83,21 +83,22 @@ func (t *ACLTree) AddRuleSet(ruleset *aclspec.RuleSet) (ACLVersion, error) { // Set the rules on the final node current.SetRules(allRules, ruleset.Terminal) - slog.Debug("added ruleset", "path", ruleset.Path, "owner", owner, "version", current.GetVersion()) - return current.GetVersion(), nil + return current, nil } // GetEffectiveRule returns the most specific rule applicable to the given path. func (t *ACLTree) GetEffectiveRule(path string) (*ACLRule, error) { - node := t.LookupNearestNode(path) // O(depth) + normalizedPath := ACLNormPath(path) + + node := t.LookupNearestNode(normalizedPath) // O(depth) if node == nil { - return nil, ErrNoRuleFound + return nil, ErrNoRuleSet } - rule, err := node.FindBestRule(path) // O(rules|node) + rule, err := node.FindBestRule(normalizedPath) // O(rules|node) if err != nil { - return nil, err + return nil, err // returns ErrNoRuleFound if no rule is found } return rule, nil @@ -105,12 +106,17 @@ func (t *ACLTree) GetEffectiveRule(path string) (*ACLRule, error) { // LookupNearestNode returns the nearest node in the tree that has associated rules for the given path. // It returns nil if no such node is found. -func (t *ACLTree) LookupNearestNode(path string) *ACLNode { - parts := ACLPathSegments(path) +func (t *ACLTree) LookupNearestNode(normalizedPath string) *ACLNode { + parts := ACLPathSegments(normalizedPath) var candidate *ACLNode current := t.root + // candidate only if the root node has rules + if current.GetRules() != nil { + candidate = current + } + for _, part := range parts { // Stop if the current node is terminal. if current.GetTerminal() { @@ -133,7 +139,8 @@ func (t *ACLTree) LookupNearestNode(path string) *ACLNode { // GetNode finds the exact node applicable for the given path. func (t *ACLTree) GetNode(path string) *ACLNode { - parts := ACLPathSegments(path) + normalizedPath := ACLNormPath(path) + parts := ACLPathSegments(normalizedPath) current := t.root for _, part := range parts { @@ -157,16 +164,16 @@ func (t *ACLTree) RemoveRuleSet(path string) bool { var lastPart string parts := ACLPathSegments(path) - current := t.root + currentNode := t.root for _, part := range parts { - child, exists := current.GetChild(part) + child, exists := currentNode.GetChild(part) if !exists { return false } - parent = current - current = child + parent = currentNode + currentNode = child lastPart = part } diff --git a/internal/server/acl/tree_test.go b/internal/server/acl/tree_test.go index 97ce6672..6d7d4b9e 100644 --- a/internal/server/acl/tree_test.go +++ b/internal/server/acl/tree_test.go @@ -15,10 +15,10 @@ func TestNewACLTree(t *testing.T) { // The ACL system uses forward slashes internally on all platforms for glob compatibility. // We explicitly test for "/" rather than pathSep (which is "\" on Windows) because // the ACL system is a platform-independent abstraction layer. - assert.Equal(t, ACLPathSep, tree.root.path) + assert.Equal(t, "/", tree.root.path) assert.Empty(t, tree.root.children) - assert.Empty(t, tree.root.rules) + assert.Nil(t, tree.root.GetRules()) } func TestAddRuleSet(t *testing.T) { @@ -30,31 +30,30 @@ func TestAddRuleSet(t *testing.T) { aclspec.NewDefaultRule(aclspec.PrivateAccess(), aclspec.DefaultLimits()), ) - ver, err := tree.AddRuleSet(ruleset) - assert.Equal(t, ACLVersion(1), ver) + node, err := tree.AddRuleSet(ruleset) // check root node "/" assert.NoError(t, err) - assert.Empty(t, tree.root.rules) - assert.Contains(t, tree.root.children, "test") - assert.Equal(t, tree.root.path, "/") - assert.Equal(t, tree.root.depth, ACLDepth(0)) + assert.NotNil(t, node) + assert.Nil(t, tree.root.GetRules()) + assert.Equal(t, "/", tree.root.path) + assert.Equal(t, ACLDepth(0), tree.root.GetDepth()) // check node "test" child, ok := tree.root.GetChild("test") assert.True(t, ok) assert.NotNil(t, child) - assert.Empty(t, child.rules) - assert.Contains(t, child.children, "path") - assert.Equal(t, child.path, "test") - assert.Equal(t, child.depth, ACLDepth(1)) + assert.Nil(t, child.GetRules()) + assert.Equal(t, "test", child.path) + assert.Equal(t, ACLDepth(1), child.GetDepth()) // check node "path" child, ok = child.GetChild("path") assert.True(t, ok) assert.NotNil(t, child) - assert.Equal(t, child.path, "test/path") - assert.Equal(t, child.depth, ACLDepth(2)) + assert.Equal(t, "test/path", child.path) + assert.Equal(t, ACLDepth(2), child.GetDepth()) + assert.NotNil(t, child.GetRules()) } func TestTreeTraversal(t *testing.T) { @@ -79,17 +78,14 @@ func TestTreeTraversal(t *testing.T) { aclspec.NewRule("*.go", aclspec.PublicReadAccess(), aclspec.DefaultLimits()), ) - ver, err := tree.AddRuleSet(ruleset1) + _, err := tree.AddRuleSet(ruleset1) assert.NoError(t, err) - assert.Equal(t, ACLVersion(1), ver) - ver, err = tree.AddRuleSet(ruleset2) + _, err = tree.AddRuleSet(ruleset2) assert.NoError(t, err) - assert.Equal(t, ACLVersion(1), ver) - ver, err = tree.AddRuleSet(ruleset3) + _, err = tree.AddRuleSet(ruleset3) assert.NoError(t, err) - assert.Equal(t, ACLVersion(1), ver) // Test finding nearest node with rules for different paths node := tree.LookupNearestNode("parent/file.txt") @@ -101,7 +97,7 @@ func TestTreeTraversal(t *testing.T) { node = tree.LookupNearestNode("parent/child/grandchild/main.go") assert.Equal(t, "parent/child/grandchild", node.path) - // Test inheritance - terminal nodes (like parent/child) block inheritance from higher levels + // Test inheritance - terminal nodes (like grandchild) block inheritance from higher levels node = tree.LookupNearestNode("parent/child/unknown.txt") assert.Equal(t, "parent/child", node.path) @@ -126,13 +122,11 @@ func TestRemoveRuleSet(t *testing.T) { aclspec.NewRule("*.txt", aclspec.PrivateAccess(), aclspec.DefaultLimits()), ) - ver, err := tree.AddRuleSet(ruleset1) + _, err := tree.AddRuleSet(ruleset1) assert.NoError(t, err) - assert.Equal(t, ACLVersion(1), ver) - ver, err = tree.AddRuleSet(ruleset2) + _, err = tree.AddRuleSet(ruleset2) assert.NoError(t, err) - assert.Equal(t, ACLVersion(1), ver) // Verify both rulesets are in the tree _, ok := tree.root.GetChild("folder1") @@ -508,13 +502,11 @@ func TestNestedRuleSetRemoval(t *testing.T) { aclspec.NewRule("*.md", aclspec.PrivateAccess(), aclspec.DefaultLimits()), ) - ver, err := tree.AddRuleSet(ruleset1) + _, err := tree.AddRuleSet(ruleset1) assert.NoError(t, err) - assert.Equal(t, ACLVersion(1), ver) - ver, err = tree.AddRuleSet(ruleset2) + _, err = tree.AddRuleSet(ruleset2) assert.NoError(t, err) - assert.Equal(t, ACLVersion(1), ver) // Remove parent - with original behavior, this removes the entire subtree removed := tree.RemoveRuleSet("parent") @@ -525,14 +517,12 @@ func TestNestedRuleSetRemoval(t *testing.T) { assert.False(t, ok, "Parent node should be completely removed") // Add the parent ruleset back - ver, err = tree.AddRuleSet(ruleset1) + _, err = tree.AddRuleSet(ruleset1) assert.NoError(t, err) - assert.Equal(t, ACLVersion(1), ver) // Add the child ruleset back - ver, err = tree.AddRuleSet(ruleset2) + _, err = tree.AddRuleSet(ruleset2) assert.NoError(t, err) - assert.Equal(t, ACLVersion(1), ver) // Remove just the child removed = tree.RemoveRuleSet("parent/child") diff --git a/internal/server/datasite/datasite.go b/internal/server/datasite/datasite.go index d57956e9..0b2bd9fc 100644 --- a/internal/server/datasite/datasite.go +++ b/internal/server/datasite/datasite.go @@ -50,7 +50,11 @@ func (d *DatasiteService) Start(ctx context.Context) error { // Load the ACL rulesets start = time.Now() - d.acl.AddRuleSets(ruleSets) + for _, ruleSet := range ruleSets { + if _, err := d.acl.AddRuleSet(ruleSet); err != nil { + slog.Warn("ruleset update error", "path", ruleSet.Path, "error", err) + } + } slog.Debug("acl build", "count", len(ruleSets), "took", time.Since(start)) // Warm up the ACL cache @@ -60,7 +64,7 @@ func (d *DatasiteService) Start(ctx context.Context) error { &acl.User{ID: aclspec.Everyone}, &acl.File{Path: blob.Key}, acl.AccessRead, - ); err != nil && errors.Is(err, acl.ErrNoRuleFound) { + ); err != nil && errors.Is(err, acl.ErrNoRule) { slog.Warn("acl cache warm error", "path", blob.Key, "error", err) } } @@ -81,11 +85,6 @@ func (d *DatasiteService) GetView(user string) []*blob.BlobInfo { // Filter blobs based on ACL for _, blob := range blobs { - if IsOwner(blob.Key, user) { - view = append(view, blob) - continue - } - if err := d.acl.CanAccess( &acl.User{ID: user}, &acl.File{Path: blob.Key}, diff --git a/internal/server/handlers/explorer/explorer_handler.go b/internal/server/handlers/explorer/explorer_handler.go index 55d49fd8..f615685c 100644 --- a/internal/server/handlers/explorer/explorer_handler.go +++ b/internal/server/handlers/explorer/explorer_handler.go @@ -100,6 +100,17 @@ func (e *ExplorerHandler) listContents(prefix string) *directoryContents { } for _, blob := range blobs { + + // check if public readable + if err := e.acl.CanAccess( + &acl.User{ID: aclspec.Everyone}, + &acl.File{Path: blob.Key}, + acl.AccessRead, + ); err != nil { + // don't reveal if the file is private or not + continue + } + if strings.HasPrefix(blob.Key, prefix) { relPath := strings.TrimPrefix(blob.Key, prefix) if relPath == "" { From 2d4ead88d0fbbfc9271b2f05c359c523de6ec3bc Mon Sep 17 00:00:00 2001 From: Yash Gorana Date: Wed, 18 Jun 2025 16:16:45 +0530 Subject: [PATCH 10/19] fix(server): use json log on stage & prod --- cmd/server/main.go | 38 +++++++++++++++++++------------------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/cmd/server/main.go b/cmd/server/main.go index 9379ae3e..e9a8d687 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -32,7 +32,6 @@ const ( var ( dotenvLoaded bool - prodEnv bool ) var rootCmd = &cobra.Command{ @@ -84,19 +83,31 @@ func init() { } else { dotenvLoaded = true } - - prodEnv = os.Getenv("SYFTBOX_ENV") == "PROD" } func main() { // Setup logger - var handler slog.Handler - if prodEnv { - handler = slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{ + logger := slog.New(setupHandler()) + slog.SetDefault(logger) + + // Setup root context with signal handling + ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) + defer stop() + + // server go brr + if err := rootCmd.ExecuteContext(ctx); err != nil { + os.Exit(1) + } +} + +func setupHandler() slog.Handler { + switch os.Getenv("SYFTBOX_ENV") { + case "PROD", "STAGE": + return slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{ Level: slog.LevelDebug, }) - } else { - handler = tint.NewHandler(os.Stdout, &tint.Options{ + default: + return tint.NewHandler(os.Stdout, &tint.Options{ Level: slog.LevelDebug, AddSource: true, TimeFormat: time.DateTime, @@ -108,17 +119,6 @@ func main() { }, }) } - logger := slog.New(handler) - slog.SetDefault(logger) - - // Setup root context with signal handling - ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) - defer stop() - - // server go brr - if err := rootCmd.ExecuteContext(ctx); err != nil { - os.Exit(1) - } } // loadConfig initializes viper, reads config file/env vars, and maps values to config From 4fba2b0632acce381322ed9259c2a60563db3a5e Mon Sep 17 00:00:00 2001 From: Yash Date: Wed, 18 Jun 2025 19:35:34 +0530 Subject: [PATCH 11/19] feat: migrate to new syftbox.net (#13) * feat: go to prod * chore: retire syftbox running openmined.org domain * fix: update error message --- README.md | 6 +- cmd/client/main.go | 10 +-- cmd/client/main_test.go | 8 +- config/server.example.yaml | 2 +- docker/docker-compose.yml | 8 +- internal/client/config/config.go | 2 +- internal/client/sync/sync_engine_test.go | 81 -------------------- internal/server/handlers/install/install.ps1 | 2 +- internal/server/handlers/install/install.sh | 2 +- internal/syftsdk/sdk_config.go | 2 +- 10 files changed, 16 insertions(+), 107 deletions(-) delete mode 100644 internal/client/sync/sync_engine_test.go diff --git a/README.md b/README.md index f95bc4f8..c780a2dd 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ SyftBox is an open-source protocol that enables developers and organizations to build, deploy, and federate privacy-preserving computations seamlessly across a network. Unlock the ability to run computations on distributed datasets without centralizing data—preserving security while gaining valuable insights. -Read the [documentation](https://syftbox-documentation.openmined.org/get-started) for more details. +Read the [documentation](https://www.syftbox.net) for more details. > [!WARNING] > This project is a rewrite of the [original Python version](https://github.com/OpenMined/syft). Consequently, the linked documentation may not fully reflect the current implementation. @@ -13,12 +13,12 @@ Using the GUI, from https://github.com/OpenMined/SyftUI/releases On macOS and Linux. ``` -curl -fsSL https://syftboxdev.openmined.org/install.sh | sh +curl -fsSL https://syftbox.net/install.sh | sh ``` On Windows using Powershell ``` -powershell -ExecutionPolicy ByPass -c "irm https://syftboxdev.openmined.org/install.ps1 | iex" +powershell -ExecutionPolicy ByPass -c "irm https://syftbox.net/install.ps1 | iex" ``` ## Contributing diff --git a/cmd/client/main.go b/cmd/client/main.go index 2d20b1a2..452e8169 100644 --- a/cmd/client/main.go +++ b/cmd/client/main.go @@ -22,9 +22,7 @@ import ( ) var ( - home, _ = os.UserHomeDir() - oldProdURL = "syftbox.openmined.org" - oldStageURL = "syftboxstage.openmined.org" + home, _ = os.UserHomeDir() ) var rootCmd = &cobra.Command{ @@ -166,11 +164,9 @@ func loadConfig(cmd *cobra.Command) (*config.Config, error) { return nil, fmt.Errorf("config read: %w", err) } - // perform migrations // this will error out because a re-auth with server will be required - if strings.Contains(cfg.ServerURL, oldProdURL) || - strings.Contains(cfg.ServerURL, oldStageURL) { - return nil, fmt.Errorf("legacy config detected. please run `syftbox login` to re-authenticate") + if strings.Contains(cfg.ServerURL, "openmined.org") { + return nil, fmt.Errorf("legacy server detected. run `syftbox login` to re-authenticate") } return cfg, nil diff --git a/cmd/client/main_test.go b/cmd/client/main_test.go index aeb4ea37..728e03ed 100644 --- a/cmd/client/main_test.go +++ b/cmd/client/main_test.go @@ -12,7 +12,7 @@ import ( func TestLoadConfigEnv(t *testing.T) { t.Setenv("SYFTBOX_EMAIL", "test@example.com") - t.Setenv("SYFTBOX_SERVER_URL", "https://test.openmined.org") + t.Setenv("SYFTBOX_SERVER_URL", "https://test.syftbox.net") t.Setenv("SYFTBOX_CLIENT_URL", "http://localhost:7938") t.Setenv("SYFTBOX_APPS_ENABLED", "true") t.Setenv("SYFTBOX_REFRESH_TOKEN", "test-refresh-token") @@ -34,7 +34,7 @@ func TestLoadConfigEnv(t *testing.T) { require.NoError(t, err) assert.Equal(t, "test@example.com", cfg.Email) - assert.Equal(t, "https://test.openmined.org", cfg.ServerURL) + assert.Equal(t, "https://test.syftbox.net", cfg.ServerURL) assert.Equal(t, "http://localhost:7938", cfg.ClientURL) assert.Equal(t, true, cfg.AppsEnabled) assert.Equal(t, "test-refresh-token", cfg.RefreshToken) @@ -55,7 +55,7 @@ func TestLoadConfigJSON(t *testing.T) { { "email": "test@example.com", "data_dir": "/tmp/syftbox-test-json", - "server_url": "https://test-json.openmined.org", + "server_url": "https://test-json.syftbox.net", "client_url": "http://localhost:8080", "refresh_token": "test-refresh-token-json", "access_token": "test-access-token-json" @@ -78,7 +78,7 @@ func TestLoadConfigJSON(t *testing.T) { require.Equal(t, dummyConfigFile, cfg.Path) assert.Equal(t, "test@example.com", cfg.Email) assert.Equal(t, "/tmp/syftbox-test-json", cfg.DataDir) - assert.Equal(t, "https://test-json.openmined.org", cfg.ServerURL) + assert.Equal(t, "https://test-json.syftbox.net", cfg.ServerURL) assert.Equal(t, "http://localhost:8080", cfg.ClientURL) assert.Equal(t, "test-refresh-token-json", cfg.RefreshToken) assert.Equal(t, "test-access-token-json", cfg.AccessToken) // can read, but not persist! diff --git a/config/server.example.yaml b/config/server.example.yaml index ee3eab4a..bb244fa2 100644 --- a/config/server.example.yaml +++ b/config/server.example.yaml @@ -24,7 +24,7 @@ auth: # whether to enable auth enabled: true # issuer of the JWT tokens (required) - token_issuer: https://example.com + token_issuer: https://test.syftbox.net # secret for the refresh token (required) # recommended to use SYFTBOX_AUTH_REFRESH_TOKEN_SECRET env var refresh_token_secret: refresh_token_secret diff --git a/docker/docker-compose.yml b/docker/docker-compose.yml index 9978d48a..ede3c7fa 100644 --- a/docker/docker-compose.yml +++ b/docker/docker-compose.yml @@ -30,13 +30,7 @@ services: environment: - SYFTBOX_ENV=DEV - SYFTBOX_AUTH_ENABLED=0 - - SYFTBOX_AUTH_EMAIL_OTP_LENGTH=8 - - SYFTBOX_AUTH_EMAIL_OTP_EXPIRY=5m - - SYFTBOX_AUTH_TOKEN_ISSUER=https://syftboxdev.openmined.org - - SYFTBOX_AUTH_REFRESH_TOKEN_SECRET=123 - - SYFTBOX_AUTH_REFRESH_TOKEN_EXPIRY=0 - - SYFTBOX_AUTH_ACCESS_TOKEN_SECRET=132 - - SYFTBOX_AUTH_ACCESS_TOKEN_EXPIRY=72h + - SYFTBOX_EMAIL_ENABLED=0 - SYFTBOX_BLOB_REGION=us-east-1 - SYFTBOX_BLOB_BUCKET_NAME=syftbox-local - SYFTBOX_BLOB_ENDPOINT=http://minio:9000 diff --git a/internal/client/config/config.go b/internal/client/config/config.go index 53709b58..46ccd313 100644 --- a/internal/client/config/config.go +++ b/internal/client/config/config.go @@ -17,7 +17,7 @@ var ( home, _ = os.UserHomeDir() DefaultConfigPath = filepath.Join(home, ".syftbox", "config.json") DefaultDataDir = filepath.Join(home, "SyftBox") - DefaultServerURL = "https://syftboxdev.openmined.org" + DefaultServerURL = "https://syftbox.net" DefaultClientURL = "http://localhost:7938" DefaultLogFilePath = filepath.Join(home, ".syftbox", "logs", "syftbox.log") DefaultAppsEnabled = true diff --git a/internal/client/sync/sync_engine_test.go b/internal/client/sync/sync_engine_test.go deleted file mode 100644 index cae32326..00000000 --- a/internal/client/sync/sync_engine_test.go +++ /dev/null @@ -1,81 +0,0 @@ -package sync - -// func TestSyncEngineFullSync(t *testing.T) { -// dummyDatasite, err := datasite.NewLocalDatasite("~/SyftBox", "yash@openmined.org") -// assert.NoError(t, err) - -// sdk, err := syftsdk.New("https://syftboxdev.openmined.org") -// sdk.Login("yash@openmined.org") -// assert.NoError(t, err) - -// ignore := NewSyncIgnore(dummyDatasite.DatasitesDir, dummyDatasite.DatasitesDir) -// watcher := NewFileWatcher(dummyDatasite.DatasitesDir) - -// syncEngine := NewSyncEngine(dummyDatasite, sdk, ignore, watcher) -// err = syncEngine.RunSync(context.Background()) -// assert.NoError(t, err) -// } - -// func TestSyncEngineReconcile(t *testing.T) { -// syncEngine := &SyncEngine{} - -// journal := map[string]*FileMetadata{ -// "/test/file1": { -// Path: "/test/file1", -// ETag: "123", -// Version: "1", -// Size: 100, -// LastModified: time.Now(), -// }, -// "/test/file4": { -// Path: "/test/file4", -// ETag: "sadlajklsd", -// Version: "1", -// Size: 1012310, -// LastModified: time.Now(), -// }, -// } - -// localState := map[string]*FileMetadata{ -// "/test/file1": { -// Path: "/test/file1", -// ETag: "123", -// Version: "1", -// Size: 100, -// LastModified: time.Now(), -// }, -// "/test/file3": { -// Path: "/test/file3", -// ETag: "ashldk", -// Version: "1", -// Size: 10, -// LastModified: time.Now(), -// }, -// } - -// remoteState := map[string]*FileMetadata{ -// "/test/file1": { -// Path: "/test/file1", -// ETag: "defg", -// Version: "2", // new version -// Size: 100, -// LastModified: time.Now(), -// }, -// "/test/file4": { -// Path: "/test/file4", -// ETag: "sadlajklsd", -// Version: "1", -// Size: 1012310, -// LastModified: time.Now(), -// }, -// } - -// // this should download file1, delete file4, and write file3 -// result := syncEngine.reconcile(localState, remoteState, journal) -// assert.Equal(t, 1, len(result.RemoteWrites)) -// assert.Equal(t, 1, len(result.LocalWrites)) -// assert.Equal(t, 1, len(result.RemoteDeletes)) -// assert.Equal(t, "/test/file3", result.RemoteWrites["/test/file3"].Path) -// assert.Equal(t, "/test/file1", result.LocalWrites["/test/file1"].Path) -// assert.Equal(t, "/test/file4", result.RemoteDeletes["/test/file4"].Path) -// } diff --git a/internal/server/handlers/install/install.ps1 b/internal/server/handlers/install/install.ps1 index aa552407..f59d56f3 100644 --- a/internal/server/handlers/install/install.ps1 +++ b/internal/server/handlers/install/install.ps1 @@ -8,7 +8,7 @@ $AskRunClient = $true $RunClient = $false $AppName = "syftbox" -$ArtifactBaseUrl = if ($env:ARTIFACT_BASE_URL) { $env:ARTIFACT_BASE_URL } else { "https://syftboxdev.openmined.org" } +$ArtifactBaseUrl = if ($env:ARTIFACT_BASE_URL) { $env:ARTIFACT_BASE_URL } else { "https://syftbox.net" } $ArtifactDownloadUrl = "$ArtifactBaseUrl/releases" function Write-Error-Exit($message) { diff --git a/internal/server/handlers/install/install.sh b/internal/server/handlers/install/install.sh index 7339de38..9d4d7a2d 100644 --- a/internal/server/handlers/install/install.sh +++ b/internal/server/handlers/install/install.sh @@ -13,7 +13,7 @@ RUN_CLIENT=0 INSTALL_APPS=${INSTALL_APPS:-""} APP_NAME="syftbox" -ARTIFACT_BASE_URL=${ARTIFACT_BASE_URL:-"https://syftboxdev.openmined.org"} +ARTIFACT_BASE_URL=${ARTIFACT_BASE_URL:-"https://syftbox.net"} ARTIFACT_DOWNLOAD_URL="$ARTIFACT_BASE_URL/releases" SYFTBOX_BINARY_PATH="$HOME/.local/bin/syftbox" diff --git a/internal/syftsdk/sdk_config.go b/internal/syftsdk/sdk_config.go index 9119347d..f79b53dc 100644 --- a/internal/syftsdk/sdk_config.go +++ b/internal/syftsdk/sdk_config.go @@ -5,7 +5,7 @@ import ( ) const ( - DefaultBaseURL = "https://syftboxdev.openmined.org" + DefaultBaseURL = "https://syftbox.net" ) // SyftSDKConfig is the configuration for the SyftSDK From 02454dd9ddf25be7ace23c75d5ef65d98d30e4a5 Mon Sep 17 00:00:00 2001 From: Yash Gorana Date: Wed, 18 Jun 2025 19:58:14 +0530 Subject: [PATCH 12/19] fix(client/sync): improved ignore file read --- internal/client/sync/sync_ignore.go | 60 +++++++++++++++++++---------- 1 file changed, 40 insertions(+), 20 deletions(-) diff --git a/internal/client/sync/sync_ignore.go b/internal/client/sync/sync_ignore.go index df39c324..b2555068 100644 --- a/internal/client/sync/sync_ignore.go +++ b/internal/client/sync/sync_ignore.go @@ -2,9 +2,11 @@ package sync import ( "bufio" + "fmt" "log/slog" "os" "path/filepath" + "strings" "github.com/openmined/syftbox/internal/utils" gitignore "github.com/sabhiram/go-gitignore" @@ -52,27 +54,12 @@ func (s *SyncIgnoreList) Load() { // read the syftignore file if it exists if utils.FileExists(ignorePath) { - rules := 0 - file, err := os.Open(ignorePath) + customRules, err := readIgnoreFile(ignorePath) if err != nil { - slog.Warn("Failed to open syftignore file", "path", ignorePath, "error", err) - } - defer file.Close() - - scanner := bufio.NewScanner(file) - for scanner.Scan() { - line := scanner.Text() - if line != "" { - ignoreLines = append(ignoreLines, line) - rules++ - } - } - - // Check for errors during the scan - if err := scanner.Err(); err != nil { - slog.Warn("Error reading syftignore file", "path", ignorePath, "error", err) - } else { - slog.Info("Loaded syftignore file", "path", ignorePath, "rules", rules) + slog.Warn("failed to read syftignore file", "path", ignorePath, "error", err) + } else if len(customRules) > 0 { + ignoreLines = append(ignoreLines, customRules...) + slog.Info("loaded syftignore file", "path", ignorePath, "rules", len(customRules)) } } @@ -82,3 +69,36 @@ func (s *SyncIgnoreList) Load() { func (s *SyncIgnoreList) ShouldIgnore(path string) bool { return s.ignore.MatchesPath(path) } + +func readIgnoreFile(path string) ([]string, error) { + if path == "" { + return nil, fmt.Errorf("ignore file path is empty") + } + + file, err := os.Open(path) + if err != nil { + return nil, fmt.Errorf("failed to open ignore file: %w", err) + } + defer file.Close() + + ignoreLines := []string{} + scanner := bufio.NewScanner(file) + + for scanner.Scan() { + line := scanner.Text() + line = strings.TrimSpace(line) + + // comments, empty lines, and null bytes + if strings.HasPrefix(line, "#") || line == "" || strings.Contains(line, "\x00") { + continue + } + + ignoreLines = append(ignoreLines, line) + } + + if err := scanner.Err(); err != nil { + return nil, fmt.Errorf("error reading ignore file: %w", err) + } + + return ignoreLines, nil +} From 23f5f6ea245fed58fd0134426af7e6144a9e3896 Mon Sep 17 00:00:00 2001 From: Yash Date: Wed, 18 Jun 2025 22:12:27 +0530 Subject: [PATCH 13/19] fix(server/acl): don't remove whole subtree on delete (#26) --- internal/server/acl/acl.go | 23 +++++++++----- internal/server/acl/cache.go | 7 ++++- internal/server/acl/node.go | 7 +++++ internal/server/acl/tree.go | 11 +++++-- internal/server/acl/tree_test.go | 53 ++++++++++++++++++++++++-------- 5 files changed, 77 insertions(+), 24 deletions(-) diff --git a/internal/server/acl/acl.go b/internal/server/acl/acl.go index e56dab20..d631bbaf 100644 --- a/internal/server/acl/acl.go +++ b/internal/server/acl/acl.go @@ -3,6 +3,7 @@ package acl import ( "fmt" "log/slog" + "strings" "github.com/openmined/syftbox/internal/aclspec" ) @@ -28,8 +29,8 @@ func (s *ACLService) AddRuleSet(ruleSet *aclspec.RuleSet) (ACLVersion, error) { return 0, err } - s.cache.DeletePrefix(ruleSet.Path) - slog.Debug("updated rule set", "path", node.path, "version", node.version) + deleted := s.cache.DeletePrefix(ruleSet.Path) + slog.Debug("updated rule set", "path", node.path, "version", node.version, "cache.deleted", deleted) return node.version, nil } @@ -39,7 +40,8 @@ func (s *ACLService) AddRuleSet(ruleSet *aclspec.RuleSet) (ACLVersion, error) { func (s *ACLService) RemoveRuleSet(path string) bool { path = aclspec.WithoutACLPath(path) if ok := s.tree.RemoveRuleSet(path); ok { - s.cache.DeletePrefix(path) + deleted := s.cache.DeletePrefix(path) + slog.Debug("deleted cached rules", "path", path, "count", deleted) return true } return false @@ -69,16 +71,17 @@ func (s *ACLService) GetRule(path string) (*ACLRule, error) { // CanAccess checks if a user has the specified access permission for a file. func (s *ACLService) CanAccess(user *User, file *File, level AccessLevel) error { + // early return if user is the owner + if isOwner(file.Path, user.ID) { + return nil + } + // get the effective rule for the file rule, err := s.GetRule(file.Path) if err != nil { return err } - // early return if user is the owner - if rule.Owner() == user.ID { - return nil - } // Elevate ACL file writes to admin level if aclspec.IsACLFile(file.Path) && level >= AccessCreate { level = AccessAdmin @@ -103,3 +106,9 @@ func (s *ACLService) CanAccess(user *User, file *File, level AccessLevel) error func (s *ACLService) String() string { return s.tree.String() } + +// checks if the user is the owner of the path +func isOwner(path string, user string) bool { + path = ACLNormPath(path) + return strings.HasPrefix(path, user) +} diff --git a/internal/server/acl/cache.go b/internal/server/acl/cache.go index 8786cd68..32a88e50 100644 --- a/internal/server/acl/cache.go +++ b/internal/server/acl/cache.go @@ -48,14 +48,19 @@ func (c *ACLCache) Delete(path string) { } // DeletePrefix deletes the effective ACL rule for all paths that match the given prefix. -func (c *ACLCache) DeletePrefix(path string) { +func (c *ACLCache) DeletePrefix(path string) int { c.mu.Lock() defer c.mu.Unlock() + deleted := 0 + // iterate over index keys and delete the entry for k := range c.index { if strings.HasPrefix(k, path) { delete(c.index, k) + deleted++ } } + + return deleted } diff --git a/internal/server/acl/node.go b/internal/server/acl/node.go index 4135cc48..44837218 100644 --- a/internal/server/acl/node.go +++ b/internal/server/acl/node.go @@ -71,6 +71,13 @@ func (n *ACLNode) SetChild(key string, child *ACLNode) { n.version++ } +// GetChildCount returns the number of children for the node. +func (n *ACLNode) GetChildCount() int { + n.mu.RLock() + defer n.mu.RUnlock() + return len(n.children) +} + // DeleteChild deletes the child for the node. func (n *ACLNode) DeleteChild(key string) { n.mu.Lock() diff --git a/internal/server/acl/tree.go b/internal/server/acl/tree.go index a53d302a..96603704 100644 --- a/internal/server/acl/tree.go +++ b/internal/server/acl/tree.go @@ -163,7 +163,8 @@ func (t *ACLTree) RemoveRuleSet(path string) bool { var parent *ACLNode var lastPart string - parts := ACLPathSegments(path) + normalizedPath := ACLNormPath(path) + parts := ACLPathSegments(normalizedPath) currentNode := t.root for _, part := range parts { @@ -177,8 +178,12 @@ func (t *ACLTree) RemoveRuleSet(path string) bool { lastPart = part } - // Need to lock parent since we're modifying its children - parent.DeleteChild(lastPart) + // clear the rules for the node, but if it has no children, delete the whole node from it's parent + if currentNode.GetChildCount() == 0 { + parent.DeleteChild(lastPart) + } else { + currentNode.ClearRules() + } return true } diff --git a/internal/server/acl/tree_test.go b/internal/server/acl/tree_test.go index 6d7d4b9e..4291a8c7 100644 --- a/internal/server/acl/tree_test.go +++ b/internal/server/acl/tree_test.go @@ -508,32 +508,59 @@ func TestNestedRuleSetRemoval(t *testing.T) { _, err = tree.AddRuleSet(ruleset2) assert.NoError(t, err) - // Remove parent - with original behavior, this removes the entire subtree + // Remove parent - with new behavior, this only clears rules since parent has children removed := tree.RemoveRuleSet("parent") assert.True(t, removed) - // Verify both parent and child are gone (original behavior) - _, ok := tree.root.GetChild("parent") - assert.False(t, ok, "Parent node should be completely removed") + // Verify parent node still exists but rules are cleared + parentNode, ok := tree.root.GetChild("parent") + assert.True(t, ok, "Parent node should still exist after removing ruleset") + assert.NotNil(t, parentNode, "Parent node should not be nil") + assert.Nil(t, parentNode.GetRules(), "Parent node rules should be cleared") + + // Verify child still exists since parent wasn't deleted + childNode, ok := parentNode.GetChild("child") + assert.True(t, ok, "Child node should still exist") + assert.NotNil(t, childNode, "Child node should not be nil") + assert.NotNil(t, childNode.GetRules(), "Child node rules should still exist") // Add the parent ruleset back _, err = tree.AddRuleSet(ruleset1) assert.NoError(t, err) - // Add the child ruleset back - _, err = tree.AddRuleSet(ruleset2) - assert.NoError(t, err) - - // Remove just the child + // Remove just the child - this should delete the child node since it has no children removed = tree.RemoveRuleSet("parent/child") assert.True(t, removed) - // Verify parent still exists but child was removed - parentNode, ok := tree.root.GetChild("parent") + // Verify parent still exists + parentNode, ok = tree.root.GetChild("parent") assert.True(t, ok, "Parent node should exist after removing just child") assert.NotNil(t, parentNode, "Parent node should not be nil") - // Verify child was removed + // Verify child was deleted since it had no children _, ok = parentNode.GetChild("child") - assert.False(t, ok, "Child node should be removed") + assert.False(t, ok, "Child node should be deleted since it had no children") + + // Test removing a leaf node (no children) - should delete the node + leafRuleset := aclspec.NewRuleSet( + "leaf", + aclspec.SetTerminal, + aclspec.NewRule("*.log", aclspec.PrivateAccess(), aclspec.DefaultLimits()), + ) + + _, err = tree.AddRuleSet(leafRuleset) + assert.NoError(t, err) + + // Verify leaf node exists + leafNode, ok := tree.root.GetChild("leaf") + assert.True(t, ok, "Leaf node should exist") + assert.NotNil(t, leafNode, "Leaf node should not be nil") + + // Remove leaf node - should delete it since it has no children + removed = tree.RemoveRuleSet("leaf") + assert.True(t, removed) + + // Verify leaf node was deleted + _, ok = tree.root.GetChild("leaf") + assert.False(t, ok, "Leaf node should be deleted since it had no children") } From bef99ba214d7f8f559b4297aab794f6ceddfc78a Mon Sep 17 00:00:00 2001 From: Shubham Gupta <11032835+shubham3121@users.noreply.github.com> Date: Thu, 19 Jun 2025 19:41:59 +0530 Subject: [PATCH 14/19] HTTP-based RPC Message Handling System (#10) * feat(server): add HTTP message handling and send endpoint - Implement HTTP message processing with /send/msg endpoint and HttpMsg types * feat(send): add polling endpoint for HTTP responses - Add /send/poll endpoint with configurable timeout and structured responses using PollObjectRequest and SendAcknowledgment * refactor(sync): clean up and rename HTTP message functions - Remove unused code and rename handleHttp to processHttpMessage with .http.request/.http.response extensions * refactor(cors): centralize CORS middleware * fix(server): resolve CORS and send timeout issues * refactor: improve RPC message handling - Remove base64 decoding, add FromSyftURL parser and enhanced JSON marshaling * feat: add automatic request/response cleanup * feat: standardize API responses with 202 status for timeouts * feat: comprehensive URL handling system (#14) - Implement SyftBoxURL with parsing, validation, header management, and flexible routing * feat(routes): embed HTML templates * fix: normalize file path separators * refactor: remove user agent check from polling API * refactor(send): add interfaces and comprehensive tests (#27) * test(middlewares): add guest access to JWTAuth for RPC routes (#28) --------- Co-authored-by: Yash Gorana --- internal/client/middleware/cors.go | 11 +- internal/client/sync/sync_engine.go | 61 ++ internal/server/handlers/send/poll.html | 45 ++ internal/server/handlers/send/send_handler.go | 206 +++++++ .../server/handlers/send/send_handler_test.go | 542 ++++++++++++++++++ .../handlers/send/send_handler_types.go | 114 ++++ internal/server/handlers/send/send_service.go | 350 +++++++++++ .../server/handlers/send/send_service_test.go | 428 ++++++++++++++ internal/server/handlers/ws/ws_hub.go | 1 + internal/server/middlewares/jwtauth.go | 27 +- internal/server/routes.go | 24 +- internal/syftmsg/http_msg.go | 46 ++ internal/syftmsg/msg.go | 6 + internal/syftmsg/msg_type.go | 3 + internal/syftmsg/rpc_msg.go | 263 +++++++++ internal/utils/url.go | 244 ++++++++ internal/utils/url_test.go | 438 ++++++++++++++ 17 files changed, 2802 insertions(+), 7 deletions(-) create mode 100644 internal/server/handlers/send/poll.html create mode 100644 internal/server/handlers/send/send_handler.go create mode 100644 internal/server/handlers/send/send_handler_test.go create mode 100644 internal/server/handlers/send/send_handler_types.go create mode 100644 internal/server/handlers/send/send_service.go create mode 100644 internal/server/handlers/send/send_service_test.go create mode 100644 internal/syftmsg/http_msg.go create mode 100644 internal/syftmsg/rpc_msg.go create mode 100644 internal/utils/url.go create mode 100644 internal/utils/url_test.go diff --git a/internal/client/middleware/cors.go b/internal/client/middleware/cors.go index fe51011c..ac1ed2be 100644 --- a/internal/client/middleware/cors.go +++ b/internal/client/middleware/cors.go @@ -9,9 +9,14 @@ import ( // control plane cors config var corsConfig = cors.Config{ - AllowAllOrigins: true, - AllowMethods: []string{"GET", "POST", "PUT", "PATCH", "DELETE", "HEAD"}, - AllowHeaders: []string{"Origin", "Content-Length", "Content-Type", "Authorization"}, + AllowAllOrigins: true, + AllowMethods: []string{"GET", "POST", "PUT", "PATCH", "DELETE", "HEAD"}, + AllowHeaders: []string{ + "Origin", + "Content-Length", + "Content-Type", + "Authorization", + }, AllowCredentials: true, MaxAge: 12 * time.Hour, } diff --git a/internal/client/sync/sync_engine.go b/internal/client/sync/sync_engine.go index 95293e45..dfae471b 100644 --- a/internal/client/sync/sync_engine.go +++ b/internal/client/sync/sync_engine.go @@ -2,6 +2,7 @@ package sync import ( "context" + "encoding/json" "errors" "fmt" "log/slog" @@ -470,6 +471,8 @@ func (se *SyncEngine) handleSocketEvents(ctx context.Context) { go se.handlePriorityError(msg) case syftmsg.MsgFileWrite: go se.handlePriorityDownload(msg) + case syftmsg.MsgHttp: + go se.processHttpMessage(msg) default: slog.Debug("websocket unhandled type", "type", msg.Type) } @@ -497,6 +500,64 @@ func (se *SyncEngine) handleWatcherEvents(ctx context.Context) { } } +func (se *SyncEngine) processHttpMessage(msg *syftmsg.Message) { + slog.Debug("processHttpMessage", "msgType", msg.Type, "msgId", msg.Id) + httpMsg, ok := msg.Data.(*syftmsg.HttpMsg) + if !ok { + slog.Error("processHttpMessage: invalid type assertion for msg.Data", + "msgType", msg.Type, + "msgId", msg.Id, + "dataType", fmt.Sprintf("%T", msg.Data)) + return + } + + slog.Debug("handle", "msgType", msg.Type, "msgId", msg.Id, "httpMsg", httpMsg) + + // Unwrap the into a syftmsg.SyftRPCMessage + syftRPCMsg, err := syftmsg.NewSyftRPCMessage(*httpMsg) + if err != nil { + slog.Error("processHttpMessage: failed to create syftRPCMsg", + "error", err, + "msgType", msg.Type, + "msgId", msg.Id, + "httpMsg", httpMsg) + return + } + + // rpc message file name + fileName := syftRPCMsg.ID.String() + "." + string(httpMsg.Type) + + filePath := filepath.Join( + se.workspace.DatasiteAbsPath(syftRPCMsg.URL.ToLocalPath()), // app_data/{app_name}/rpc/{endpoint} + fileName, + ) + + slog.Debug("Received RPC message", "RPCPath", filePath) + + // Convert the syftRPCMsg to json + jsonRPCMsg, err := json.Marshal(syftRPCMsg) + if err != nil { + slog.Error("handleHttp marshal syftRPCMsg", + "error", err, + "msgId", syftRPCMsg.ID, + "filePath", filePath) + return + } + + // write the RPCMsg to the file + err = os.WriteFile(filePath, jsonRPCMsg, 0644) + if err != nil { + slog.Error("handleHttp write file", + "error", err, + "filePath", filePath, + "msgId", syftRPCMsg.ID, + "fileSize", len(jsonRPCMsg)) + return + } + + slog.Debug("SyftRPC Message", "msg", string(jsonRPCMsg), "filePath", filePath) +} + func (se *SyncEngine) handleSystem(msg *syftmsg.Message) { systemMsg := msg.Data.(syftmsg.System) slog.Info("handle", "msgType", msg.Type, "msgId", msg.Id, "serverVersion", systemMsg.SystemVersion) diff --git a/internal/server/handlers/send/poll.html b/internal/server/handlers/send/poll.html new file mode 100644 index 00000000..9447c485 --- /dev/null +++ b/internal/server/handlers/send/poll.html @@ -0,0 +1,45 @@ + + + + + + + + Processing Request - SyftBox + + + + +

Processing Your Request

+
+

Your request is being processed. This page will automatically refresh in {{ .RefreshInterval }} seconds.

+

If you are not redirected, click here to check the status.

+
+ + + \ No newline at end of file diff --git a/internal/server/handlers/send/send_handler.go b/internal/server/handlers/send/send_handler.go new file mode 100644 index 00000000..18eb6445 --- /dev/null +++ b/internal/server/handlers/send/send_handler.go @@ -0,0 +1,206 @@ +package send + +import ( + "errors" + "fmt" + "io" + "log/slog" + "net/http" + "strconv" + + "github.com/gin-gonic/gin" +) + +// SendHandler handles HTTP requests for sending messages +type SendHandler struct { + service SendServiceInterface +} + +// New creates a new send handler +func New(msgDispatcher MessageDispatcher, msgStore RPCMsgStore) *SendHandler { + service := NewSendService(msgDispatcher, msgStore, nil) + return &SendHandler{service: service} +} + +// SendMsg handles sending a message +func (h *SendHandler) SendMsg(ctx *gin.Context) { + var req MessageRequest + + // Bind query parameters + if err := ctx.ShouldBindQuery(&req); err != nil { + ctx.PureJSON(http.StatusBadRequest, APIError{ + Error: ErrorInvalidRequest, + Message: err.Error(), + }) + return + } + + // Bind headers + if err := ctx.ShouldBindHeader(&req); err != nil { + ctx.PureJSON(http.StatusBadRequest, APIError{ + Error: ErrorInvalidRequest, + Message: err.Error(), + }) + return + } + + // Bind request method + req.Method = ctx.Request.Method + + // Bind headers + req.BindHeaders(ctx) + + // Read request body with size limit + bodyBytes, err := readRequestBody(ctx, h.service.GetConfig().MaxBodySize) + if err != nil { + ctx.PureJSON(http.StatusBadRequest, APIError{ + Error: ErrorInvalidRequest, + Message: err.Error(), + }) + return + } + + result, err := h.service.SendMessage(ctx.Request.Context(), &req, bodyBytes) + if err != nil { + slog.Error("failed to send message", "error", err) + ctx.PureJSON(http.StatusInternalServerError, APIError{ + Error: ErrorInternal, + Message: err.Error(), + }) + return + } + + if result.Response != nil { + ctx.PureJSON(result.Status, APIResponse{ + RequestID: result.RequestID, + Data: result.Response, + }) + return + } + + // add poll url as location header + ctx.Header("Location", result.PollURL) + + // return poll info + ctx.PureJSON(result.Status, APIResponse{ + RequestID: result.RequestID, + Data: PollInfo{ + PollURL: result.PollURL, + }, + Message: "Request has been accepted. Please check back later.", + }) +} + +// PollForResponse handles polling for a response +func (h *SendHandler) PollForResponse(ctx *gin.Context) { + var req PollObjectRequest + if err := ctx.ShouldBindQuery(&req); err != nil { + slog.Error("failed to bind query parameters", "error", err) + ctx.PureJSON(http.StatusBadRequest, APIError{ + Error: ErrorInvalidRequest, + Message: err.Error(), + RequestID: req.RequestID, + }) + return + } + + if err := ctx.ShouldBindHeader(&req); err != nil { + slog.Error("failed to bind headers", "error", err) + ctx.PureJSON(http.StatusBadRequest, APIError{ + Error: ErrorInvalidRequest, + Message: err.Error(), + RequestID: req.RequestID, + }) + return + } + + result, err := h.service.PollForResponse(ctx.Request.Context(), &req) + contentTypeHTML := ctx.Request.Header.Get("Content-Type") == "text/html" + + if err != nil { + if errors.Is(err, ErrPollTimeout) { + // Add poll URL to the Response header + pollURL := h.service.constructPollURL( + req.RequestID, + req.SyftURL, + req.From, + req.AsRaw, + ) + + // calculate refresh interval in seconds + var refreshInterval int + if req.Timeout > 0 { + refreshInterval = req.Timeout / 1000 + } else { + refreshInterval = h.service.GetConfig().DefaultTimeoutMs / 1000 + } + + // add poll url as location header and retry after header + ctx.Header("Location", pollURL) + ctx.Header("Retry-After", strconv.Itoa(refreshInterval)) + + if contentTypeHTML { + // Return a HTML page with a link to the poll URL + // with auto refresh capability + ctx.HTML(http.StatusAccepted, "poll.html", gin.H{ + "PollURL": pollURL, + "BaseURL": ctx.Request.Host, + "RefreshInterval": refreshInterval, // in seconds + }) + return + } else { + ctx.PureJSON(http.StatusAccepted, APIError{ + Error: ErrorTimeout, + Message: "Polling timeout reached. The request may still be processing.", + RequestID: req.RequestID, + }) + return + } + } + + if errors.Is(err, ErrNoRequest) { + ctx.PureJSON(http.StatusNotFound, APIError{ + Error: ErrorNotFound, + Message: "No request found.", + RequestID: req.RequestID, + }) + return + } + + slog.Error("failed to poll for response", "error", err) + ctx.PureJSON(http.StatusInternalServerError, APIError{ + Error: ErrorInternal, + Message: err.Error(), + RequestID: req.RequestID, + }) + return + } + + ctx.PureJSON(result.Status, APIResponse{ + RequestID: result.RequestID, + Data: result.Response, + }) +} + +// readRequestBody reads and validates the request body +func readRequestBody(ctx *gin.Context, maxSize int64) ([]byte, error) { + body := ctx.Request.Body + defer ctx.Request.Body.Close() + + // Read body bytes + bodyBytes, err := io.ReadAll(body) + if err != nil { + return nil, fmt.Errorf("failed to read request body: %w", err) + } + + // Check if body size exceeds maximum allowed size + if maxSize > 0 && int64(len(bodyBytes)) > maxSize { + return nil, fmt.Errorf( + "request body too large: %d bytes (max: %d bytes)", + len(bodyBytes), + maxSize, + ) + } + + return bodyBytes, nil +} diff --git a/internal/server/handlers/send/send_handler_test.go b/internal/server/handlers/send/send_handler_test.go new file mode 100644 index 00000000..3262f860 --- /dev/null +++ b/internal/server/handlers/send/send_handler_test.go @@ -0,0 +1,542 @@ +package send + +import ( + "context" + "embed" + "encoding/json" + "errors" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/gin-gonic/gin" + "github.com/openmined/syftbox/internal/utils" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +//go:embed *.html +var templateFS embed.FS + +// MockSendService implements the SendService interface for testing +type MockSendService struct { + mock.Mock +} + +func (m *MockSendService) SendMessage(ctx context.Context, req *MessageRequest, bodyBytes []byte) (*SendResult, error) { + args := m.Called(ctx, req, bodyBytes) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*SendResult), args.Error(1) +} + +func (m *MockSendService) PollForResponse(ctx context.Context, req *PollObjectRequest) (*PollResult, error) { + args := m.Called(ctx, req) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*PollResult), args.Error(1) +} + +func (m *MockSendService) constructPollURL(requestID string, syftURL utils.SyftBoxURL, from string, asRaw bool) string { + args := m.Called(requestID, syftURL, from, asRaw) + return args.String(0) +} + +func (m *MockSendService) GetConfig() *Config { + args := m.Called() + if args.Get(0) == nil { + return nil + } + return args.Get(0).(*Config) +} + +// Helper function to create a test gin context +func createTestContext( + method string, + url string, + body io.Reader, + query_params map[string]string, + headers map[string]string, +) (*gin.Context, *httptest.ResponseRecorder) { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + + // Build URL with query parameters + if len(query_params) > 0 { + params := make([]string, 0, len(query_params)) + for k, v := range query_params { + params = append(params, k+"="+v) + } + url = url + "?" + strings.Join(params, "&") + } + + req := httptest.NewRequest(method, url, body) + + // Add headers + for k, v := range headers { + req.Header.Set(k, v) + } + + c.Request = req + return c, w +} + +func TestSendHandler_SendMsg_SuccessWithResponse(t *testing.T) { + // Setup + mockService := &MockSendService{} + handler := &SendHandler{service: mockService} + + // Test data + body := `{"key": "value"}` + query_params := map[string]string{ + "x-syft-url": "syft://test@datasite.com/app_data/testapp/rpc/endpoint", + "x-syft-from": "testuser@example.com", + "timeout": "1000", + } + headers := map[string]string{ + "Content-Type": "application/json", + } + + c, w := createTestContext("POST", "/send/msg/", strings.NewReader(body), query_params, headers) + + // Mock expectations + expectedResult := &SendResult{ + Status: http.StatusOK, + RequestID: "test-request-id", + Response: map[string]interface{}{ + "message": "success", + }, + } + + mockService.On("SendMessage", mock.Anything, mock.MatchedBy(func(req *MessageRequest) bool { + return req.From == "testuser@example.com" && req.Method == "POST" + }), []byte(body)).Return(expectedResult, nil) + mockService.On("GetConfig").Return(&Config{MaxBodySize: 4 * 1024 * 1024}) // 4MB + + // Execute + handler.SendMsg(c) + + // Assertions + assert.Equal(t, http.StatusOK, w.Code) + + var response APIResponse + err := json.Unmarshal(w.Body.Bytes(), &response) + assert.NoError(t, err) + assert.Equal(t, "test-request-id", response.RequestID) + assert.NotNil(t, response.Data) + + mockService.AssertExpectations(t) +} + +func TestSendHandler_SendMsg_SuccessWithPolling(t *testing.T) { + // Setup + mockService := &MockSendService{} + handler := &SendHandler{service: mockService} + + // Test data + body := `{"key": "value"}` + query_params := map[string]string{ + "x-syft-url": "syft://test@datasite.com/app_data/testapp/rpc/endpoint", + "x-syft-from": "testuser@example.com", + "timeout": "1000", + } + headers := map[string]string{ + "x-syft-url": "syft://test@datasite.com/app_data/testapp/rpc/endpoint", + "x-syft-from": "testuser@example.com", + "Content-Type": "application/json", + } + + c, w := createTestContext("POST", "/send/msg/", strings.NewReader(body), query_params, headers) + + // Mock expectations + expectedResult := &SendResult{ + Status: http.StatusAccepted, + RequestID: "test-request-id", + PollURL: "/api/v1/send/poll?x-syft-request-id=test-request-id&x-syft-url=syft://test@datasite.com/app_data/testapp/rpc/endpoint&x-syft-from=testuser@example.com&x-syft-raw=false", + } + + mockService.On("SendMessage", mock.Anything, mock.Anything, []byte(body)).Return(expectedResult, nil) + mockService.On("GetConfig").Return(&Config{MaxBodySize: 4 * 1024 * 1024}) // 4MB + + // Execute + handler.SendMsg(c) + + // Assertions + assert.Equal(t, http.StatusAccepted, w.Code) + assert.Equal(t, expectedResult.PollURL, w.Header().Get("Location")) + + var response APIResponse + err := json.Unmarshal(w.Body.Bytes(), &response) + assert.NoError(t, err) + assert.Equal(t, "test-request-id", response.RequestID) + assert.Equal(t, "Request has been accepted. Please check back later.", response.Message) + + // Check PollInfo in response + pollInfo, ok := response.Data.(map[string]interface{}) + assert.True(t, ok) + assert.Equal(t, expectedResult.PollURL, pollInfo["poll_url"]) + + mockService.AssertExpectations(t) +} + +func TestSendHandler_SendMsg_InvalidRequestBinding(t *testing.T) { + // Setup + mockService := &MockSendService{} + handler := &SendHandler{service: mockService} + + // Test data - missing required fields + body := `{"key": "value"}` + query_params := map[string]string{ + "x-syft-url": "syft://test@datasite.com/app_data/testapp/rpc/endpoint", + } + headers := map[string]string{ + "Content-Type": "application/json", + } + + c, w := createTestContext("POST", "/send/msg/", strings.NewReader(body), query_params, headers) + + // Execute + handler.SendMsg(c) + + // Assertions + assert.Equal(t, http.StatusBadRequest, w.Code) + + var response APIError + err := json.Unmarshal(w.Body.Bytes(), &response) + assert.NoError(t, err) + assert.Equal(t, ErrorInvalidRequest, response.Error) + assert.Contains(t, response.Message, "required") +} + +func TestSendHandler_SendMsg_BodyTooLarge(t *testing.T) { + // Setup + mockService := &MockSendService{} + handler := &SendHandler{service: mockService} + + // Test data + largeBody := strings.Repeat("a", 5*1024*1024) // 5MB + query_params := map[string]string{ + "x-syft-url": "syft://test@datasite.com/app_data/testapp/rpc/endpoint", + "x-syft-from": "testuser@example.com", + } + headers := map[string]string{ + "Content-Type": "application/json", + } + + c, w := createTestContext("POST", "/send/msg/", strings.NewReader(largeBody), query_params, headers) + + // Mock expectations + mockService.On("GetConfig").Return(&Config{MaxBodySize: 4 * 1024 * 1024}) // 4MB + + // Execute + handler.SendMsg(c) + + // Assertions + assert.Equal(t, http.StatusBadRequest, w.Code) + + var response APIError + err := json.Unmarshal(w.Body.Bytes(), &response) + assert.NoError(t, err) + assert.Equal(t, ErrorInvalidRequest, response.Error) + assert.Contains(t, response.Message, "too large") +} + +func TestSendHandler_SendMsg_ServiceError(t *testing.T) { + // Setup + mockService := &MockSendService{} + handler := &SendHandler{service: mockService} + + // Test data + body := `{"key": "value"}` + query_params := map[string]string{ + "x-syft-url": "syft://test@datasite.com/app_data/testapp/rpc/endpoint", + "x-syft-from": "testuser@example.com", + } + headers := map[string]string{ + "Content-Type": "application/json", + } + + c, w := createTestContext("POST", "/send/msg/", strings.NewReader(body), query_params, headers) + + // Mock expectations + mockService.On("SendMessage", mock.Anything, mock.Anything, []byte(body)).Return(nil, errors.New("service error")) + mockService.On("GetConfig").Return(&Config{MaxBodySize: 4 * 1024 * 1024}) // 4MB + // Execute + handler.SendMsg(c) + + // Assertions + assert.Equal(t, http.StatusInternalServerError, w.Code) + + var response APIError + err := json.Unmarshal(w.Body.Bytes(), &response) + assert.NoError(t, err) + assert.Equal(t, ErrorInternal, response.Error) + assert.Equal(t, "service error", response.Message) + + mockService.AssertExpectations(t) +} + +func TestSendHandler_PollForResponse_Success(t *testing.T) { + // Setup + mockService := &MockSendService{} + handler := &SendHandler{service: mockService} + + // Test data + query_params := map[string]string{ + "x-syft-request-id": "test-request-id", + "x-syft-url": "syft://test@datasite.com/app_data/testapp/rpc/endpoint", + "x-syft-from": "test-user", + } + headers := map[string]string{ + "Content-Type": "application/json", + } + + c, w := createTestContext("GET", "/send/poll/", nil, query_params, headers) + + // Mock expectations + expectedResult := &PollResult{ + Status: http.StatusOK, + RequestID: "test-request-id", + Response: map[string]interface{}{ + "message": "success", + }, + } + + mockService.On("PollForResponse", mock.Anything, mock.MatchedBy(func(req *PollObjectRequest) bool { + return req.RequestID == "test-request-id" && req.From == "test-user" + })).Return(expectedResult, nil) + + // Execute + handler.PollForResponse(c) + + // Assertions + assert.Equal(t, http.StatusOK, w.Code) + + var response APIResponse + err := json.Unmarshal(w.Body.Bytes(), &response) + assert.NoError(t, err) + assert.Equal(t, "test-request-id", response.RequestID) + assert.NotNil(t, response.Data) + + mockService.AssertExpectations(t) +} + +func TestSendHandler_PollForResponse_TimeoutWithJSON(t *testing.T) { + // Setup + mockService := &MockSendService{} + handler := &SendHandler{service: mockService} + + // Test data + query_params := map[string]string{ + "x-syft-request-id": "test-request-id", + "x-syft-url": "syft://test@datasite.com/app_data/testapp/rpc/endpoint", + "x-syft-from": "test-user", + "timeout": "1000", + } + headers := map[string]string{ + "Content-Type": "application/json", + } + + c, w := createTestContext("GET", "/send/poll/", nil, query_params, headers) + + // Mock expectations + pollURL := "/api/v1/send/poll?x-syft-request-id=test-request-id&x-syft-url=syft://test@datasite.com/app_data/testapp/rpc/endpoint&x-syft-from=test-user&x-syft-raw=false" + + mockService.On("PollForResponse", mock.Anything, mock.Anything).Return(nil, ErrPollTimeout) + mockService.On("constructPollURL", "test-request-id", mock.Anything, "test-user", false).Return(pollURL) + + // Execute + handler.PollForResponse(c) + + // Assertions + assert.Equal(t, http.StatusAccepted, w.Code) + assert.Equal(t, pollURL, w.Header().Get("Location")) + assert.Equal(t, "1", w.Header().Get("Retry-After")) + + var response APIError + err := json.Unmarshal(w.Body.Bytes(), &response) + assert.NoError(t, err) + assert.Equal(t, ErrorTimeout, response.Error) + assert.Equal(t, "test-request-id", response.RequestID) + + mockService.AssertExpectations(t) +} + +func TestSendHandler_PollForResponse_NotFound(t *testing.T) { + // Setup + mockService := &MockSendService{} + handler := &SendHandler{service: mockService} + + // Test data + query_params := map[string]string{ + "x-syft-request-id": "test-request-id", + "x-syft-url": "syft://test@datasite.com/app_data/testapp/rpc/endpoint", + "x-syft-from": "test-user", + } + headers := map[string]string{ + "x-syft-request-id": "test-request-id", + "x-syft-url": "syft://test@datasite.com/app_data/testapp/rpc/endpoint", + "x-syft-from": "test-user", + } + + c, w := createTestContext("GET", "/send/poll/", nil, query_params, headers) + + // Mock expectations + mockService.On("PollForResponse", mock.Anything, mock.Anything).Return(nil, ErrNoRequest) + + // Execute + handler.PollForResponse(c) + + // Assertions + assert.Equal(t, http.StatusNotFound, w.Code) + + var response APIError + err := json.Unmarshal(w.Body.Bytes(), &response) + assert.NoError(t, err) + assert.Equal(t, ErrorNotFound, response.Error) + assert.Equal(t, "test-request-id", response.RequestID) + + mockService.AssertExpectations(t) +} + +func TestSendHandler_PollForResponse_InvalidRequest(t *testing.T) { + // Setup + mockService := &MockSendService{} + handler := &SendHandler{service: mockService} + + // Test data - missing required fields + query_params := map[string]string{ + "x-syft-from": "test-user", + } + headers := map[string]string{ + "Content-Type": "application/json", + // Missing required headers + } + + c, w := createTestContext("GET", "/send/poll/", nil, query_params, headers) + + // Execute + handler.PollForResponse(c) + + // Assertions + assert.Equal(t, http.StatusBadRequest, w.Code) + + var response APIError + err := json.Unmarshal(w.Body.Bytes(), &response) + assert.NoError(t, err) + assert.Equal(t, ErrorInvalidRequest, response.Error) + assert.Contains(t, response.Message, "required") +} + +func TestSendHandler_PollForResponse_ServiceError(t *testing.T) { + // Setup + mockService := &MockSendService{} + handler := &SendHandler{service: mockService} + + // Test data + query_params := map[string]string{ + "x-syft-request-id": "test-request-id", + "x-syft-url": "syft://test@datasite.com/app_data/testapp/rpc/endpoint", + "x-syft-from": "test-user", + } + headers := map[string]string{ + "Content-Type": "application/json", + } + + c, w := createTestContext("GET", "/send/poll/", nil, query_params, headers) + + // Mock expectations + mockService.On("PollForResponse", mock.Anything, mock.Anything).Return(nil, errors.New("service error")) + + // Execute + handler.PollForResponse(c) + + // Assertions + assert.Equal(t, http.StatusInternalServerError, w.Code) + + var response APIError + err := json.Unmarshal(w.Body.Bytes(), &response) + assert.NoError(t, err) + assert.Equal(t, ErrorInternal, response.Error) + assert.Equal(t, "service error", response.Message) + assert.Equal(t, "test-request-id", response.RequestID) + + mockService.AssertExpectations(t) +} + +func TestSendHandler_New(t *testing.T) { + // Setup + mockDispatcher := &MockMessageDispatcher{} + mockStore := &MockRPCMsgStore{} + + // Execute + handler := New(mockDispatcher, mockStore) + + // Assertions + assert.NotNil(t, handler) + assert.NotNil(t, handler.service) +} + +func TestReadRequestBody_Success(t *testing.T) { + // Setup + body := `{"key": "value"}` + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest("POST", "/", strings.NewReader(body)) + + // Execute + result, err := readRequestBody(c, 1024*1024) // 1MB limit + + // Assertions + assert.NoError(t, err) + assert.Equal(t, []byte(body), result) +} + +func TestReadRequestBody_TooLarge(t *testing.T) { + // Setup + largeBody := strings.Repeat("a", 1025) // 1025 bytes + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest("POST", "/", strings.NewReader(largeBody)) + + // Execute + result, err := readRequestBody(c, 1024) // 1KB limit + + // Assertions + assert.Error(t, err) + assert.Nil(t, result) + assert.Contains(t, err.Error(), "too large") +} + +func TestReadRequestBody_ReadError(t *testing.T) { + // Setup + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + + // Create a request with a body that will fail to read + req := httptest.NewRequest("POST", "/", &failingReader{}) + c.Request = req + + // Execute + result, err := readRequestBody(c, 1024) + + // Assertions + assert.Error(t, err) + assert.Nil(t, result) + assert.Contains(t, err.Error(), "failed to read request body") +} + +// failingReader is a reader that always fails +type failingReader struct{} + +func (f *failingReader) Read(p []byte) (n int, err error) { + return 0, errors.New("read error") +} diff --git a/internal/server/handlers/send/send_handler_types.go b/internal/server/handlers/send/send_handler_types.go new file mode 100644 index 00000000..d9fb3e74 --- /dev/null +++ b/internal/server/handlers/send/send_handler_types.go @@ -0,0 +1,114 @@ +package send + +import ( + "context" + "io" + + "github.com/gin-gonic/gin" + "github.com/openmined/syftbox/internal/syftmsg" + "github.com/openmined/syftbox/internal/utils" +) + +type PollStatus string + +const ( + PollStatusPending PollStatus = "pending" + PollStatusComplete PollStatus = "complete" +) + +// Error constants +const ( + ErrorTimeout = "timeout" + ErrorInvalidRequest = "invalid_request" + ErrorInternal = "internal_error" + ErrorNotFound = "not_found" + PollURL = "/api/v1/send/poll?x-syft-request-id=%s&x-syft-url=%s&x-syft-from=%s&x-syft-raw=%t" +) + +// APIError represents a standardized error response +type APIError struct { + Error string `json:"error"` + Message string `json:"message"` + RequestID string `json:"request_id,omitempty"` +} + +// APIResponse represents a standardized success response +type APIResponse struct { + RequestID string `json:"request_id"` + Data interface{} `json:"data,omitempty"` + Message string `json:"message,omitempty"` +} + +type PollInfo struct { + PollURL string `json:"poll_url"` +} + +type Headers map[string]string + +// MessageRequest represents the request for sending a message +type MessageRequest struct { + SyftURL utils.SyftBoxURL `form:"x-syft-url" binding:"required"` // Binds to the syft url using UnmarshalParam + From string `form:"x-syft-from" binding:"required"` // The sender of the message + Timeout int `form:"timeout" binding:"gte=0"` // The timeout for the request + AsRaw bool `form:"x-syft-raw" default:"false"` // If true, the request body will be read and sent as is + Method string // Will be set from request method + Headers Headers // Will be set from request headers +} + +func (h *MessageRequest) BindHeaders(ctx *gin.Context) { + + // TODO: Filter out headers that are not allowed + h.Headers = make(Headers) + for k, v := range ctx.Request.Header { + if len(v) > 0 { + h.Headers[k] = v[0] + } + } + // Bind x-syft-from to Headers + h.Headers["x-syft-from"] = h.From +} + +// PollObjectRequest represents the request for polling +type PollObjectRequest struct { + RequestID string `form:"x-syft-request-id" binding:"required"` + From string `form:"x-syft-from" binding:"required"` + SyftURL utils.SyftBoxURL `form:"x-syft-url" binding:"required"` + Timeout int `form:"timeout,omitempty" binding:"gte=0"` + UserAgent string `form:"user-agent,omitempty"` + AsRaw bool `form:"x-syft-raw" default:"false"` // If true, the request body will be read and sent as is +} + +// SendResult represents the result of a send operation +type SendResult struct { + Status int + RequestID string + PollURL string + Response map[string]interface{} +} + +// PollResult represents the result of a poll operation +type PollResult struct { + Status int + RequestID string + Response map[string]interface{} +} + +// Message store interface for storing and retrieving messages +type RPCMsgStore interface { + StoreMsg(ctx context.Context, path string, msg syftmsg.SyftRPCMessage) error + GetMsg(ctx context.Context, path string) (io.ReadCloser, error) + DeleteMsg(ctx context.Context, path string) error +} + +// Message dispatch interface for dispatching messages to users +type MessageDispatcher interface { + Dispatch(datasite string, msg *syftmsg.Message) bool +} + +// SendServiceInterface defines the interface for the send service +type SendServiceInterface interface { + SendMessage(ctx context.Context, req *MessageRequest, bodyBytes []byte) (*SendResult, error) + PollForResponse(ctx context.Context, req *PollObjectRequest) (*PollResult, error) + constructPollURL(requestID string, syftURL utils.SyftBoxURL, from string, asRaw bool) string + GetConfig() *Config +} diff --git a/internal/server/handlers/send/send_service.go b/internal/server/handlers/send/send_service.go new file mode 100644 index 00000000..910eea09 --- /dev/null +++ b/internal/server/handlers/send/send_service.go @@ -0,0 +1,350 @@ +package send + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "log/slog" + "net/http" + "path" + "time" + + "github.com/openmined/syftbox/internal/server/blob" + "github.com/openmined/syftbox/internal/server/handlers/ws" + "github.com/openmined/syftbox/internal/syftmsg" + "github.com/openmined/syftbox/internal/utils" +) + +var ( + ErrPollTimeout = errors.New("poll timeout") + ErrNoRequest = errors.New("no request found") +) + +type BlobMsgStore struct { + blob *blob.BlobService +} + +func (m *BlobMsgStore) GetMsg(ctx context.Context, path string) (io.ReadCloser, error) { + object, err := m.blob.Backend().GetObject(ctx, path) + if err != nil { + return nil, err + } + return object.Body, nil +} + +func (m *BlobMsgStore) DeleteMsg(ctx context.Context, path string) error { + _, err := m.blob.Backend().DeleteObject(ctx, path) + return err +} + +func (m *BlobMsgStore) StoreMsg(ctx context.Context, path string, msg syftmsg.SyftRPCMessage) error { + msgBytes, err := json.Marshal(msg) + if err != nil { + return fmt.Errorf("failed to marshal message: %w", err) + } + + _, err = m.blob.Backend().PutObject(ctx, &blob.PutObjectParams{ + Key: path, + ETag: msg.ID.String(), + Body: bytes.NewReader(msgBytes), + Size: int64(len(msgBytes)), + }) + return err +} + +type WSMsgDispatcher struct { + hub *ws.WebsocketHub +} + +func (m *WSMsgDispatcher) Dispatch(datasite string, msg *syftmsg.Message) bool { + return m.hub.SendMessageUser(datasite, msg) +} + +func NewWSMsgDispatcher(hub *ws.WebsocketHub) MessageDispatcher { + return &WSMsgDispatcher{hub: hub} +} + +func NewBlobMsgStore(blob *blob.BlobService) RPCMsgStore { + return &BlobMsgStore{blob: blob} +} + +// SendService handles the business logic for message sending and polling +type SendService struct { + dispatcher MessageDispatcher + store RPCMsgStore + cfg *Config +} + +// Config holds the service configuration +type Config struct { + DefaultTimeoutMs int + MaxTimeoutMs int + MaxBodySize int64 + PollIntervalMs int + RequestChkTimeoutMs int +} + +// NewSendService creates a new send service +func NewSendService(dispatch MessageDispatcher, store RPCMsgStore, cfg *Config) *SendService { + if cfg == nil { + cfg = &Config{ + DefaultTimeoutMs: 1000, // 1000 ms + MaxTimeoutMs: 10000, // 10 seconds + MaxBodySize: 4 << 20, // 4MB + PollIntervalMs: 500, // 500 ms + RequestChkTimeoutMs: 200, // 200 ms + } + } + return &SendService{dispatcher: dispatch, store: store, cfg: cfg} +} + +// SendMessage handles sending a message to a user +func (s *SendService) SendMessage(ctx context.Context, req *MessageRequest, bodyBytes []byte) (*SendResult, error) { + + // Create the HTTP message + + msg := syftmsg.NewHttpMsg( + req.From, + req.SyftURL, + req.Method, + bodyBytes, + req.Headers, + syftmsg.HttpMsgTypeRequest, + ) + + httpMsg := msg.Data.(*syftmsg.HttpMsg) + + // TODO: Check if user has permission to send message to this application + + // Dispatch the message to the user via websocket + if ok := s.dispatcher.Dispatch(req.SyftURL.Datasite, msg); !ok { + // If the message is not sent via websocket, handle it as an offline message + return s.handleOfflineMessage(ctx, req, httpMsg) + } + + // If the message is sent via websocket, handle the response + return s.handleOnlineMessage(ctx, req, httpMsg) +} + +// handleOfflineMessage handles sending a message when the user is offline +func (s *SendService) handleOfflineMessage( + ctx context.Context, + req *MessageRequest, + httpMsg *syftmsg.HttpMsg, +) (*SendResult, error) { + blobPath := path.Join( + req.SyftURL.ToLocalPath(), + fmt.Sprintf("%s.%s", httpMsg.Id, httpMsg.Type), + ) + + // Create the RPC message + rpcMsg, err := syftmsg.NewSyftRPCMessage(*httpMsg) + if err != nil { + return nil, fmt.Errorf("failed to create RPCMsg: %w", err) + } + + // Save the RPC message to blob storage + if err := s.store.StoreMsg(ctx, blobPath, *rpcMsg); err != nil { + return nil, fmt.Errorf("failed to save message to blob storage: %w", err) + } + + slog.Info("saved message to blob storage", "blobPath", blobPath) + return &SendResult{ + Status: http.StatusAccepted, + RequestID: httpMsg.Id, + PollURL: s.constructPollURL(httpMsg.Id, req.SyftURL, req.From, req.AsRaw), + }, nil +} + +// handleOnlineMessage handles sending a message when the user is online +func (s *SendService) handleOnlineMessage( + ctx context.Context, + req *MessageRequest, + httpMsg *syftmsg.HttpMsg, +) (*SendResult, error) { + blobPath := path.Join( + req.SyftURL.ToLocalPath(), + fmt.Sprintf("%s.response", httpMsg.Id), + ) + + timeout := req.Timeout + if timeout <= 0 { + timeout = s.cfg.DefaultTimeoutMs + } + + object, err := s.pollForObject(ctx, blobPath, timeout) + if err != nil { + if errors.Is(err, ErrPollTimeout) { + return &SendResult{ + Status: http.StatusAccepted, + RequestID: httpMsg.Id, + PollURL: s.constructPollURL(httpMsg.Id, req.SyftURL, req.From, req.AsRaw), + }, nil + } + return nil, err + } + + // Read the object + bodyBytes, err := io.ReadAll(object) + if err != nil { + return nil, fmt.Errorf("failed to read object: %w", err) + } + + responseBody, err := unmarshalResponse(bodyBytes, req.AsRaw) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal response: %w", err) + } + + // Clean up in background + go s.cleanReqResponse( + req.SyftURL.Datasite, + req.SyftURL.AppName, + req.SyftURL.Endpoint, + httpMsg.Id, + ) + + return &SendResult{ + Status: http.StatusOK, + RequestID: httpMsg.Id, + Response: responseBody, + }, nil +} + +// PollForResponse handles polling for a response +func (s *SendService) PollForResponse(ctx context.Context, req *PollObjectRequest) (*PollResult, error) { + + // Validate if the corresponding request exists + requestBlobPath := path.Join(req.SyftURL.ToLocalPath(), fmt.Sprintf("%s.request", req.RequestID)) + + _, err := s.pollForObject(ctx, requestBlobPath, s.cfg.RequestChkTimeoutMs) + + if err != nil { + if errors.Is(err, ErrPollTimeout) { + return nil, ErrNoRequest + } + return nil, err + } + + // Check if the corresponding response exists + responseFileName := fmt.Sprintf("%s.response", req.RequestID) + responseBlobPath := path.Join(req.SyftURL.ToLocalPath(), responseFileName) + + timeout := req.Timeout + if timeout <= 0 { + timeout = s.cfg.DefaultTimeoutMs + } + + object, err := s.pollForObject(ctx, responseBlobPath, timeout) + if err != nil { + return nil, err + } + + bodyBytes, err := io.ReadAll(object) + if err != nil { + return nil, fmt.Errorf("failed to read object: %w", err) + } + + responseBody, err := unmarshalResponse(bodyBytes, req.AsRaw) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal response: %w", err) + } + + // Clean up in background + go s.cleanReqResponse( + req.SyftURL.Datasite, + req.SyftURL.AppName, + req.SyftURL.Endpoint, + req.RequestID, + ) + + return &PollResult{ + Status: http.StatusOK, + RequestID: req.RequestID, + Response: responseBody, + }, nil +} + +// pollForObject polls for an object in blob storage +func (s *SendService) pollForObject(ctx context.Context, blobPath string, timeout int) (io.ReadCloser, error) { + startTime := time.Now() + maxTimeout := time.Duration(timeout) * time.Millisecond + + for { + if time.Since(startTime) > maxTimeout { + return nil, ErrPollTimeout + } + + object, err := s.store.GetMsg(ctx, blobPath) + if err != nil { + slog.Error("Failed to get object from backend", "error", err, "blobPath", blobPath) + continue + } + if object != nil { + return object, nil + } + + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-time.After(time.Duration(s.cfg.PollIntervalMs) * time.Millisecond): + continue + } + } +} + +// cleanReqResponse cleans up request and response files +func (s *SendService) cleanReqResponse(sender, appName, appEp, requestID string) { + requestPath := path.Join(sender, "app_data", appName, "rpc", appEp, fmt.Sprintf("%s.request", requestID)) + responsePath := path.Join(sender, "app_data", appName, "rpc", appEp, fmt.Sprintf("%s.response", requestID)) + + if err := s.store.DeleteMsg(context.Background(), requestPath); err != nil { + slog.Error("failed to delete request object", "error", err, "path", requestPath) + } + + if err := s.store.DeleteMsg(context.Background(), responsePath); err != nil { + slog.Error("failed to delete response object", "error", err, "path", responsePath) + } +} + +// constructPollURL constructs the poll URL for a request +func (s *SendService) constructPollURL(requestID string, syftURL utils.SyftBoxURL, from string, asRaw bool) string { + return fmt.Sprintf( + PollURL, + requestID, + syftURL.BaseURL(), + from, + asRaw, + ) +} + +// unmarshalResponse handles the unmarshaling of a response from blob storage +// It expects the response to have a base64 encoded body field that contains JSON +func unmarshalResponse(bodyBytes []byte, asRaw bool) (map[string]interface{}, error) { + // If the request is raw, return the body as bytes + if asRaw { + var bodyJson map[string]interface{} + err := json.Unmarshal(bodyBytes, &bodyJson) + if err != nil { + return nil, fmt.Errorf("failed to marshal response: %w", err) + } + return map[string]interface{}{"message": bodyJson}, nil + } + + // Otherwise, unmarshal it as a SyftRPCMessage + var rpcMsg syftmsg.SyftRPCMessage + err := json.Unmarshal(bodyBytes, &rpcMsg) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal response: %w", err) + } + // decode the body if it is base64 encoded + // return the SyftRPCMessage as a different json representation + return map[string]interface{}{"message": rpcMsg.ToJsonMap()}, nil +} + +// GetConfig returns the service configuration +func (s *SendService) GetConfig() *Config { + return s.cfg +} diff --git a/internal/server/handlers/send/send_service_test.go b/internal/server/handlers/send/send_service_test.go new file mode 100644 index 00000000..b2cf147b --- /dev/null +++ b/internal/server/handlers/send/send_service_test.go @@ -0,0 +1,428 @@ +package send + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "io" + "net/http" + "reflect" + "strings" + "sync" + "testing" + "time" + + "github.com/google/uuid" + "github.com/openmined/syftbox/internal/syftmsg" + "github.com/openmined/syftbox/internal/utils" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +// MockMessageDispatcher implements MessageDispatcher for testing +type MockMessageDispatcher struct { + mock.Mock +} + +func (m *MockMessageDispatcher) Dispatch(datasite string, msg *syftmsg.Message) bool { + args := m.Called(datasite, msg) + return args.Bool(0) +} + +// MockRPCMsgStore implements RPCMsgStore for testing +type MockRPCMsgStore struct { + mock.Mock +} + +func (m *MockRPCMsgStore) StoreMsg(ctx context.Context, path string, msg syftmsg.SyftRPCMessage) error { + args := m.Called(ctx, path, msg) + return args.Error(0) +} + +func (m *MockRPCMsgStore) GetMsg(ctx context.Context, path string) (io.ReadCloser, error) { + args := m.Called(ctx, path) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(io.ReadCloser), args.Error(1) +} + +func (m *MockRPCMsgStore) DeleteMsg(ctx context.Context, path string) error { + args := m.Called(ctx, path) + return args.Error(0) +} + +func TestNewSendService(t *testing.T) { + // Test with custom config + dispatcher := &MockMessageDispatcher{} + store := &MockRPCMsgStore{} + cfg := &Config{ + DefaultTimeoutMs: 2000, + MaxTimeoutMs: 20000, + MaxBodySize: 8 << 20, + PollIntervalMs: 1000, + RequestChkTimeoutMs: 400, + } + + service := NewSendService(dispatcher, store, cfg) + assert.NotNil(t, service) + assert.Equal(t, cfg, service.cfg) + + // Test with nil config (should use defaults) + service = NewSendService(dispatcher, store, nil) + assert.NotNil(t, service) + assert.NotNil(t, service.cfg) + assert.Equal(t, 1000, service.cfg.DefaultTimeoutMs) + assert.Equal(t, 10000, service.cfg.MaxTimeoutMs) + assert.Equal(t, int64(4<<20), service.cfg.MaxBodySize) + assert.Equal(t, 500, service.cfg.PollIntervalMs) + assert.Equal(t, 200, service.cfg.RequestChkTimeoutMs) +} + +func TestSendService_SendMessage_Online(t *testing.T) { + dispatcher := &MockMessageDispatcher{} + store := &MockRPCMsgStore{} + service := NewSendService(dispatcher, store, nil) + + // Create test data + from := "test-user" + syftURL, err := utils.FromSyftURL("syft://test@datasite.com/app_data/testapp/rpc/endpoint") + assert.NoError(t, err) + method := "POST" + body := []byte(`{"key": "value"}`) + headers := map[string]string{ + "Content-Type": "application/json", + } + + // Create request + req := &MessageRequest{ + From: from, + SyftURL: *syftURL, + Method: method, + Headers: headers, + } + + // Create expected message + msg := syftmsg.NewHttpMsg( + from, + *syftURL, + method, + body, + headers, + syftmsg.HttpMsgTypeRequest, + ) + + httpMsg := msg.Data.(*syftmsg.HttpMsg) + + // Set up mock expectations + dispatcher.On("Dispatch", syftURL.Datasite, mock.MatchedBy(func(msg *syftmsg.Message) bool { + httpMsg, ok := msg.Data.(*syftmsg.HttpMsg) + if !ok { + return false + } + return httpMsg.From == from && + httpMsg.Method == method && + httpMsg.Type == syftmsg.HttpMsgTypeRequest && + bytes.Equal(httpMsg.Body, body) && + reflect.DeepEqual(httpMsg.Headers, headers) + })).Return(true) + + // Create response message + responseMsg := &syftmsg.SyftRPCMessage{ + ID: uuid.MustParse(httpMsg.Id), + Sender: "test-datasite", + URL: *syftURL, + Body: []byte(`{"response": "success"}`), + Headers: headers, + Created: time.Now().UTC(), + Expires: time.Now().UTC().Add(24 * time.Hour), + Method: syftmsg.MethodPOST, + StatusCode: syftmsg.StatusOK, + } + + // Mock GetMsg to return the response + responseBytes, err := json.Marshal(responseMsg) + assert.NoError(t, err) + store.On("GetMsg", mock.Anything, mock.Anything).Return(io.NopCloser(bytes.NewReader(responseBytes)), nil) + + // Set up expectations for DeleteMsg calls in cleanReqResponse + wg := &sync.WaitGroup{} + wg.Add(2) + + store.On("DeleteMsg", mock.Anything, mock.MatchedBy(func(path string) bool { + return strings.HasSuffix(path, ".request") + })). + Run(func(args mock.Arguments) { + wg.Done() + }). + Return(nil) + + store.On("DeleteMsg", mock.Anything, mock.MatchedBy(func(path string) bool { + return strings.HasSuffix(path, ".response") + })). + Run(func(args mock.Arguments) { + wg.Done() + }). + Return(nil) + + // Call SendMessage + result, err := service.SendMessage(context.Background(), req, body) + assert.NoError(t, err) + assert.NotNil(t, result) + assert.Equal(t, http.StatusOK, result.Status) + assert.NotEmpty(t, result.RequestID) + assert.NotNil(t, result.Response) + + // Wait for cleanup goroutine to complete + // Use a reasonable timeout to prevent test hanging + done := make(chan struct{}) + go func() { + wg.Wait() + close(done) + }() + + select { + case <-done: + // Cleanup completed successfully + case <-time.After(5 * time.Second): + t.Fatal("Cleanup goroutine timed out") + } + + // Verify mock expectations + dispatcher.AssertExpectations(t) + store.AssertExpectations(t) +} + +func TestSendService_SendMessage_Offline(t *testing.T) { + dispatcher := &MockMessageDispatcher{} + store := &MockRPCMsgStore{} + service := NewSendService(dispatcher, store, nil) + + // Create test data + from := "test-user" + syftURL, err := utils.FromSyftURL("syft://test@datasite.com/app_data/testapp/rpc/endpoint") + assert.NoError(t, err) + + method := "POST" + body := []byte(`{"key": "value"}`) + headers := map[string]string{ + "Content-Type": "application/json", + } + + // Create request + req := &MessageRequest{ + From: from, + SyftURL: *syftURL, + Method: method, + Headers: headers, + } + + // Set up mock expectations + dispatcher.On("Dispatch", syftURL.Datasite, mock.MatchedBy(func(msg *syftmsg.Message) bool { + httpMsg, ok := msg.Data.(*syftmsg.HttpMsg) + if !ok { + return false + } + return httpMsg.From == from && + httpMsg.Method == method && + httpMsg.Type == syftmsg.HttpMsgTypeRequest && + bytes.Equal(httpMsg.Body, body) && + reflect.DeepEqual(httpMsg.Headers, headers) + })).Return(false) + store.On("StoreMsg", mock.Anything, mock.Anything, mock.Anything).Return(nil) + + // Call SendMessage + result, err := service.SendMessage(context.Background(), req, body) + assert.NoError(t, err) + assert.NotNil(t, result) + assert.Equal(t, http.StatusAccepted, result.Status) + assert.NotEmpty(t, result.RequestID) + assert.NotEmpty(t, result.PollURL) + + // Verify mock expectations + dispatcher.AssertExpectations(t) + store.AssertExpectations(t) +} + +func TestSendService_PollForResponse(t *testing.T) { + dispatcher := &MockMessageDispatcher{} + store := &MockRPCMsgStore{} + service := NewSendService(dispatcher, store, nil) + + // Create test data + requestID := uuid.New().String() + from := "test-user" + syftURL, err := utils.FromSyftURL("syft://test@datasite.com/app_data/testapp/rpc/endpoint") + assert.NoError(t, err) + + // Create request + req := &PollObjectRequest{ + RequestID: requestID, + From: from, + SyftURL: *syftURL, + } + + // Create response message + responseMsg := &syftmsg.SyftRPCMessage{ + ID: uuid.New(), + Sender: "test-datasite", + URL: *syftURL, + Body: []byte(`{"response": "success"}`), + Headers: map[string]string{"Content-Type": "application/json"}, + Created: time.Now().UTC(), + Expires: time.Now().UTC().Add(24 * time.Hour), + Method: syftmsg.MethodPOST, + StatusCode: syftmsg.StatusOK, + } + + // Mock GetMsg to return both request and response + requestBytes, err := json.Marshal(responseMsg) + assert.NoError(t, err) + responseBytes, err := json.Marshal(responseMsg) + assert.NoError(t, err) + + store.On("GetMsg", mock.Anything, mock.MatchedBy(func(path string) bool { + return path == syftURL.ToLocalPath()+"/"+requestID+".request" + })).Return(io.NopCloser(bytes.NewReader(requestBytes)), nil) + store.On("GetMsg", mock.Anything, mock.MatchedBy(func(path string) bool { + return path == syftURL.ToLocalPath()+"/"+requestID+".response" + })).Return(io.NopCloser(bytes.NewReader(responseBytes)), nil) + + // Set up expectations for DeleteMsg calls in cleanReqResponse + wg := &sync.WaitGroup{} + wg.Add(2) + + store.On("DeleteMsg", mock.Anything, mock.MatchedBy(func(path string) bool { + return strings.HasSuffix(path, ".request") + })). + Run(func(args mock.Arguments) { + wg.Done() + }). + Return(nil) + + store.On("DeleteMsg", mock.Anything, mock.MatchedBy(func(path string) bool { + return strings.HasSuffix(path, ".response") + })). + Run(func(args mock.Arguments) { + wg.Done() + }). + Return(nil) + + // Call PollForResponse + result, err := service.PollForResponse(context.Background(), req) + assert.NoError(t, err) + assert.NotNil(t, result) + assert.Equal(t, http.StatusOK, result.Status) + assert.Equal(t, requestID, result.RequestID) + assert.NotNil(t, result.Response) + + // Wait for cleanup goroutine to complete + // Use a reasonable timeout to prevent test hanging + done := make(chan struct{}) + go func() { + wg.Wait() + close(done) + }() + + select { + case <-done: + // Cleanup completed successfully + case <-time.After(5 * time.Second): + t.Fatal("Cleanup goroutine timed out") + } + + // Verify mock expectations + store.AssertExpectations(t) +} + +func TestSendService_PollForResponse_NoRequest(t *testing.T) { + dispatcher := &MockMessageDispatcher{} + store := &MockRPCMsgStore{} + service := NewSendService(dispatcher, store, nil) + + // Create test data + requestID := uuid.New().String() + from := "test-user" + syftURL, err := utils.FromSyftURL("syft://test@datasite.com/app_data/testapp/rpc/endpoint") + assert.NoError(t, err) + + // Create request + req := &PollObjectRequest{ + RequestID: requestID, + From: from, + SyftURL: *syftURL, + } + + // Mock GetMsg to return error for request + store.On("GetMsg", mock.Anything, mock.Anything).Return(nil, errors.New("not found")) + + // Call PollForResponse + result, err := service.PollForResponse(context.Background(), req) + assert.Error(t, err) + assert.Equal(t, ErrNoRequest, err) + assert.Nil(t, result) + + // Verify mock expectations + store.AssertExpectations(t) +} + +func TestSendService_PollForResponse_Timeout(t *testing.T) { + dispatcher := &MockMessageDispatcher{} + store := &MockRPCMsgStore{} + cfg := &Config{ + DefaultTimeoutMs: 100, + MaxTimeoutMs: 1000, + MaxBodySize: 4 << 20, + PollIntervalMs: 50, + RequestChkTimeoutMs: 50, + } + service := NewSendService(dispatcher, store, cfg) + + // Create test data + requestID := uuid.New().String() + from := "test-user" + syftURL, err := utils.FromSyftURL("syft://test@datasite.com/app_data/testapp/rpc/endpoint") + assert.NoError(t, err) + + // Create request + req := &PollObjectRequest{ + RequestID: requestID, + From: from, + SyftURL: *syftURL, + Timeout: 100, + } + + // Create response message for request check + responseMsg := &syftmsg.SyftRPCMessage{ + ID: uuid.New(), + Sender: "test-datasite", + URL: *syftURL, + Body: []byte(`{"response": "success"}`), + Headers: map[string]string{"Content-Type": "application/json"}, + Created: time.Now().UTC(), + Expires: time.Now().UTC().Add(24 * time.Hour), + Method: syftmsg.MethodPOST, + StatusCode: syftmsg.StatusOK, + } + + // Mock GetMsg to return request but timeout for response + requestBytes, err := json.Marshal(responseMsg) + assert.NoError(t, err) + + store.On("GetMsg", mock.Anything, mock.MatchedBy(func(path string) bool { + return path == syftURL.ToLocalPath()+"/"+requestID+".request" + })).Return(io.NopCloser(bytes.NewReader(requestBytes)), nil) + store.On("GetMsg", mock.Anything, mock.MatchedBy(func(path string) bool { + return path == syftURL.ToLocalPath()+"/"+requestID+".response" + })).Return(nil, errors.New("not found")) + + // Call PollForResponse + result, err := service.PollForResponse(context.Background(), req) + assert.Error(t, err) + assert.Equal(t, ErrPollTimeout, err) + assert.Nil(t, result) + + // Verify mock expectations + store.AssertExpectations(t) +} diff --git a/internal/server/handlers/ws/ws_hub.go b/internal/server/handlers/ws/ws_hub.go index 0f3ccf2a..a51a1d44 100644 --- a/internal/server/handlers/ws/ws_hub.go +++ b/internal/server/handlers/ws/ws_hub.go @@ -135,6 +135,7 @@ func (h *WebsocketHub) SendMessageUser(user string, msg *syftmsg.Message) bool { for _, client := range h.clients { if client.Info.User == user { + slog.Debug("sending message to client", "connId", client.ConnID, "user", user) select { case client.MsgTx <- msg: sent = true diff --git a/internal/server/middlewares/jwtauth.go b/internal/server/middlewares/jwtauth.go index 2e27dbda..9f006c2c 100644 --- a/internal/server/middlewares/jwtauth.go +++ b/internal/server/middlewares/jwtauth.go @@ -20,13 +20,20 @@ const ( // JWTAuth creates a Gin middleware function that validates access tokens. // It requires the AuthService to access token validation logic and configuration. -func JWTAuth(authService *auth.AuthService) gin.HandlerFunc { +func JWTAuth(authService *auth.AuthService, allowGuest bool) gin.HandlerFunc { if !authService.IsEnabled() { slog.Info("auth middleware disabled") return func(ctx *gin.Context) { // expect user to be an email address user := ctx.Query("user") + + // if user is not set, check if the request has a x-syft-from query parameter + if user == "" { + user = ctx.Query("x-syft-from") + } + + // check if the user is a valid email address if !utils.IsValidEmail(user) { api.AbortWithError(ctx, http.StatusUnauthorized, api.CodeInvalidRequest, fmt.Errorf("invalid email")) return @@ -36,9 +43,23 @@ func JWTAuth(authService *auth.AuthService) gin.HandlerFunc { } } - slog.Info("auth middleware enabled") - return func(ctx *gin.Context) { + // Check for guest access first if allowed + slog.Debug("Checking for guest access", "allowGuest", allowGuest) + if allowGuest { + user := ctx.Query("user") + if user == "" { + user = ctx.Query("x-syft-from") + } + slog.Debug("Attempting to access with user", "user", user) + if user == "guest@syft.org" { + ctx.Set("user", user) + ctx.Next() + return + } + } + + // Proceed with normal JWT authentication authHeaderValue := ctx.GetHeader(authHeader) if authHeaderValue == "" { api.AbortWithError(ctx, http.StatusUnauthorized, api.CodeAuthInvalidCredentials, fmt.Errorf("authorization header required")) diff --git a/internal/server/routes.go b/internal/server/routes.go index 88e565e9..fe0d4ca9 100644 --- a/internal/server/routes.go +++ b/internal/server/routes.go @@ -1,7 +1,9 @@ package server import ( + "embed" "fmt" + "html/template" "net/http" "os" @@ -14,11 +16,15 @@ import ( "github.com/openmined/syftbox/internal/server/handlers/datasite" "github.com/openmined/syftbox/internal/server/handlers/explorer" "github.com/openmined/syftbox/internal/server/handlers/install" + "github.com/openmined/syftbox/internal/server/handlers/send" "github.com/openmined/syftbox/internal/server/handlers/ws" "github.com/openmined/syftbox/internal/server/middlewares" "github.com/openmined/syftbox/internal/version" ) +//go:embed handlers/send/*.html +var templateFS embed.FS + func SetupRoutes(svc *Services, hub *ws.WebsocketHub, httpsEnabled bool) http.Handler { r := gin.New() @@ -32,6 +38,10 @@ func SetupRoutes(svc *Services, hub *ws.WebsocketHub, httpsEnabled bool) http.Ha r.Use(middlewares.HSTS()) } + // Load HTML templates from embedded filesystem + tmpl := template.Must(template.ParseFS(templateFS, "handlers/send/*.html")) + r.SetHTMLTemplate(tmpl) + // --------------------------- handlers --------------------------- blobH := blob.New(svc.Blob, svc.ACL) @@ -39,6 +49,7 @@ func SetupRoutes(svc *Services, hub *ws.WebsocketHub, httpsEnabled bool) http.Ha explorerH := explorer.New(svc.Blob, svc.ACL) authH := auth.New(svc.Auth) aclH := acl.NewACLHandler(svc.ACL) + sendH := send.New(send.NewWSMsgDispatcher(hub), send.NewBlobMsgStore(svc.Blob)) // --------------------------- routes --------------------------- @@ -62,7 +73,9 @@ func SetupRoutes(svc *Services, hub *ws.WebsocketHub, httpsEnabled bool) http.Ha } v1 := r.Group("/api/v1") - v1.Use(middlewares.JWTAuth(svc.Auth)) + + // enable auth middleware with no guest access + v1.Use(middlewares.JWTAuth(svc.Auth, false)) // v1.Use(middlewares.RateLimiter("100-S")) // todo { // blob @@ -83,6 +96,15 @@ func SetupRoutes(svc *Services, hub *ws.WebsocketHub, httpsEnabled bool) http.Ha // websocket events v1.GET("/events", hub.WebsocketHandler) + + } + + // rpc group with guest access + sendG := r.Group("/api/v1/send") + sendG.Use(middlewares.JWTAuth(svc.Auth, true)) + { + sendG.Any("/msg", sendH.SendMsg) + sendG.GET("/poll", sendH.PollForResponse) } r.NoRoute(func(c *gin.Context) { diff --git a/internal/syftmsg/http_msg.go b/internal/syftmsg/http_msg.go new file mode 100644 index 00000000..d4b061ee --- /dev/null +++ b/internal/syftmsg/http_msg.go @@ -0,0 +1,46 @@ +package syftmsg + +import ( + "github.com/google/uuid" + "github.com/openmined/syftbox/internal/utils" +) + +type HttpMsgType string + +const ( + HttpMsgTypeRequest HttpMsgType = "request" + HttpMsgTypeResponse HttpMsgType = "response" +) + +type HttpMsg struct { + From string `json:"from"` + SyftURL utils.SyftBoxURL `json:"syft_url"` + Method string `json:"method"` + Headers map[string]string `json:"headers,omitempty"` + Body []byte `json:"body,omitempty"` + Id string `json:"id,omitempty"` + Type HttpMsgType `json:"type,omitempty"` +} + +func NewHttpMsg( + from string, + syftURL utils.SyftBoxURL, + method string, + body []byte, + headers map[string]string, + msgType HttpMsgType, +) *Message { + return &Message{ + Id: generateID(), + Type: MsgHttp, + Data: &HttpMsg{ + From: from, + SyftURL: syftURL, + Method: method, + Body: body, + Headers: headers, + Id: uuid.New().String(), + Type: msgType, + }, + } +} diff --git a/internal/syftmsg/msg.go b/internal/syftmsg/msg.go index 591fb5c5..0184778a 100644 --- a/internal/syftmsg/msg.go +++ b/internal/syftmsg/msg.go @@ -59,6 +59,12 @@ func (m *Message) UnmarshalJSON(data []byte) error { return err } m.Data = fileDelete + case MsgHttp: + var httpMsg HttpMsg + if err := json.Unmarshal(temp.Data, &httpMsg); err != nil { + return err + } + m.Data = &httpMsg default: return fmt.Errorf("unknown message type: %d", m.Type) } diff --git a/internal/syftmsg/msg_type.go b/internal/syftmsg/msg_type.go index 114939da..e5f2aa80 100644 --- a/internal/syftmsg/msg_type.go +++ b/internal/syftmsg/msg_type.go @@ -11,6 +11,7 @@ const ( MsgFileDelete MsgAck MsgNack + MsgHttp ) func (t MessageType) String() string { @@ -27,6 +28,8 @@ func (t MessageType) String() string { return "ACK" case MsgNack: return "NACK" + case MsgHttp: + return "HTTP" default: return fmt.Sprintf("???(%d)", t) } diff --git a/internal/syftmsg/rpc_msg.go b/internal/syftmsg/rpc_msg.go new file mode 100644 index 00000000..5ad219c2 --- /dev/null +++ b/internal/syftmsg/rpc_msg.go @@ -0,0 +1,263 @@ +package syftmsg + +import ( + "encoding/json" + "fmt" + "time" + + "encoding/base64" + + "github.com/google/uuid" + "github.com/openmined/syftbox/internal/utils" +) + +// ValidationError represents a validation error +type ValidationError struct { + Field string + Message string +} + +func (e *ValidationError) Error() string { + return fmt.Sprintf("validation error in field %s: %s", e.Field, e.Message) +} + +// SyftMethod represents the HTTP method in the Syft protocol +type SyftMethod string + +// IsValid checks if the method is valid +func (m SyftMethod) IsValid() bool { + switch m { + case MethodGET, MethodPOST, MethodPUT, MethodDELETE: + return true + default: + return false + } +} + +// Validate validates the method +func (m SyftMethod) Validate() error { + if m == "" { + return nil + } + if !m.IsValid() { + return &ValidationError{ + Field: "method", + Message: fmt.Sprintf("invalid method: %s", m), + } + } + return nil +} + +// SyftStatus represents the status code in the Syft protocol +type SyftStatus int + +// IsValid checks if the status code is valid +func (s SyftStatus) IsValid() bool { + return s >= 100 && s <= 599 +} + +func (s SyftStatus) isDefined() bool { + return s != 0 +} + +// Validate validates the status code +func (s SyftStatus) Validate() error { + if !s.isDefined() { + return nil + } + if !s.IsValid() { + return &ValidationError{ + Field: "status_code", + Message: fmt.Sprintf("invalid status code: %d", s), + } + } + return nil +} + +const ( + // DefaultMessageExpiry is the default time in seconds before a message expires + // 1 day + DefaultMessageExpiry = 24 * 60 * 60 * time.Second + + // HTTP Methods + MethodGET SyftMethod = "GET" + MethodPOST SyftMethod = "POST" + MethodPUT SyftMethod = "PUT" + MethodDELETE SyftMethod = "DELETE" + + // Status codes + StatusOK SyftStatus = 200 +) + +// SyftMessage represents a base message for Syft protocol communication +type SyftRPCMessage struct { + // ID is the unique identifier of the message + ID uuid.UUID `json:"id"` + + // Sender is the sender of the message + Sender string `json:"sender"` + + // URL is the URL of the message + URL utils.SyftBoxURL `json:"url"` + + // Body is the body of the message in bytes + Body []byte `json:"body,omitempty"` + + // Headers contains additional headers for the message + Headers map[string]string `json:"headers"` + + // Created is the timestamp when the message was created + Created time.Time `json:"created"` + + // Expires is the timestamp when the message expires + Expires time.Time `json:"expires"` + + Method SyftMethod `json:"method,omitempty"` + + StatusCode SyftStatus `json:"status_code,omitempty"` +} + +// NewSyftMessage creates a new SyftMessage with default values +func NewSyftRPCMessage(httpMsg HttpMsg) (*SyftRPCMessage, error) { + + // Timezone is UTC by default for SyftRPC messages + now := time.Now().UTC() + + headers := httpMsg.Headers + if headers == nil { + headers = make(map[string]string) + } + + msg := &SyftRPCMessage{ + ID: uuid.MustParse(httpMsg.Id), + Sender: httpMsg.From, + URL: httpMsg.SyftURL, + Body: httpMsg.Body, + Headers: headers, + Created: now, + Expires: now.Add(time.Duration(DefaultMessageExpiry)), + Method: SyftMethod(httpMsg.Method), + } + + if err := msg.Validate(); err != nil { + return nil, err + } + + return msg, nil +} + +// MarshalJSON implements custom JSON marshaling to handle bytes as base64 +func (m *SyftRPCMessage) MarshalJSON() ([]byte, error) { + type Alias SyftRPCMessage + return json.Marshal(&struct { + *Alias + URL string `json:"url"` + Body string `json:"body,omitempty"` + }{ + Alias: (*Alias)(m), + URL: m.URL.String(), + Body: base64.URLEncoding.EncodeToString(m.Body), + }) +} + +// UnmarshalJSON implements custom JSON unmarshaling +func (m *SyftRPCMessage) UnmarshalJSON(data []byte) error { + type Alias struct { + ID uuid.UUID `json:"id"` + Sender string `json:"sender"` + URL string `json:"url"` + Body string `json:"body,omitempty"` + Headers map[string]string `json:"headers"` + Created time.Time `json:"created"` + Expires time.Time `json:"expires"` + Method SyftMethod `json:"method,omitempty"` + StatusCode SyftStatus `json:"status_code,omitempty"` + } + + var aux Alias + if err := json.Unmarshal(data, &aux); err != nil { + return err + } + + // Parse URL + url, err := utils.FromSyftURL(aux.URL) + if err != nil { + return fmt.Errorf("failed to parse URL: %w", err) + } + + // Set fields + m.ID = aux.ID + m.Sender = aux.Sender + m.URL = *url + m.Headers = aux.Headers + m.Created = aux.Created + m.Expires = aux.Expires + m.Method = aux.Method + m.StatusCode = aux.StatusCode + + // Handle body + if aux.Body != "" { + if body, err := base64.URLEncoding.DecodeString(aux.Body); err == nil { + m.Body = body + } else { + m.Body = []byte(aux.Body) + } + } + + // Validate the message + if err := m.Validate(); err != nil { + return fmt.Errorf("invalid message: %w", err) + } + + return nil +} + +// JSONString returns a properly formatted JSON string with decoded body +func (m *SyftRPCMessage) ToJsonMap() map[string]interface{} { + var bodyContent interface{} + if err := json.Unmarshal(m.Body, &bodyContent); err != nil { + bodyContent = string(m.Body) + } + + return map[string]interface{}{ + "id": m.ID, + "sender": m.Sender, + "url": m.URL.String(), + "headers": m.Headers, + "created": m.Created, + "expires": m.Expires, + "method": m.Method, + "status_code": m.StatusCode, + "body": bodyContent, + } +} + +// Validate validates the message +func (m *SyftRPCMessage) Validate() error { + if m.ID == uuid.Nil { + return &ValidationError{ + Field: "id", + Message: "id cannot be empty", + } + } + if m.Sender == "" { + return &ValidationError{ + Field: "sender", + Message: "sender cannot be empty", + } + } + if err := m.URL.Validate(); err != nil { + return err + } + + // If Method is defined, validate it + if err := m.Method.Validate(); err != nil { + return err + } + + // If StatusCode is defined, validate it + if err := m.StatusCode.Validate(); err != nil { + return err + } + return nil +} diff --git a/internal/utils/url.go b/internal/utils/url.go new file mode 100644 index 00000000..9ce85d70 --- /dev/null +++ b/internal/utils/url.go @@ -0,0 +1,244 @@ +package utils + +import ( + "fmt" + "log/slog" + "net/url" + "path/filepath" + "strings" +) + +const ( + // URL scheme and components + syftScheme = "syft://" + appDataPath = "app_data" + rpcPath = "rpc" + pathSeparator = "/" +) + +// SyftBoxURL represents a parsed syft:// URL with its components +type SyftBoxURL struct { + Datasite string `json:"datasite"` + AppName string `json:"app_name"` + Endpoint string `json:"endpoint"` + QueryParams map[string]string `json:"query_params"` +} + +// ValidationError represents a validation error with field context +type ValidationError struct { + Field string + Message string +} + +func (e *ValidationError) Error() string { + return fmt.Sprintf("validation error in field '%s': %s", e.Field, e.Message) +} + +// Syft base URL +func (s *SyftBoxURL) BaseURL() string { + endpoint := strings.Trim(s.Endpoint, pathSeparator) + return fmt.Sprintf("%s%s/%s/%s/%s/%s", + syftScheme, s.Datasite, appDataPath, s.AppName, rpcPath, endpoint) +} + +// String returns the string representation of the SyftBoxURL +func (s *SyftBoxURL) String() string { + baseURL := s.BaseURL() + + // Add query parameters if they exist + if len(s.QueryParams) > 0 { + queryParams := make([]string, 0, len(s.QueryParams)) + for key, value := range s.QueryParams { + queryParams = append(queryParams, fmt.Sprintf("%s=%s", key, url.QueryEscape(value))) + } + baseURL += "?" + strings.Join(queryParams, "&") + } + + return baseURL +} + +// ToLocalPath converts the SyftBoxURL to a local file system path +func (s *SyftBoxURL) ToLocalPath() string { + endpoint := strings.Trim(s.Endpoint, pathSeparator) + return filepath.ToSlash(filepath.Join(s.Datasite, appDataPath, s.AppName, rpcPath, endpoint)) +} + +// Validate validates the SyftBoxURL fields +func (s *SyftBoxURL) Validate() error { + if s.Datasite == "" { + return &ValidationError{ + Field: "datasite", + Message: "datasite cannot be empty", + } + } + + // Validate datasite follows email pattern + if !IsValidEmail(s.Datasite) { + return &ValidationError{ + Field: "datasite", + Message: "datasite must be a valid email address", + } + } + + if s.AppName == "" { + return &ValidationError{ + Field: "app_name", + Message: "app_name cannot be empty", + } + } + if s.Endpoint == "" { + return &ValidationError{ + Field: "endpoint", + Message: "endpoint cannot be empty", + } + } + + // Validate endpoint doesn't contain spaces or special characters + if strings.ContainsAny(s.Endpoint, " ?&=") { + return &ValidationError{ + Field: "endpoint", + Message: "endpoint cannot contain spaces or special characters (?&=)", + } + } + + return nil +} + +// UnmarshalParam implements gin.UnmarshalParam for automatic query param binding +func (s *SyftBoxURL) UnmarshalParam(param string) error { + slog.Debug("Unmarshalling syft url", "url", param) + parsed, err := FromSyftURL(param) + if err != nil { + slog.Error("Failed to parse syft url", "error", err, "url", param) + return err + } + *s = *parsed + return nil +} + +// NewSyftBoxURL creates a new SyftBoxURL with validation +func NewSyftBoxURL(datasite, appName, endpoint string) (*SyftBoxURL, error) { + syftURL := &SyftBoxURL{ + Datasite: datasite, + AppName: appName, + Endpoint: endpoint, + } + + if err := syftURL.Validate(); err != nil { + return nil, err + } + + return syftURL, nil +} + +// SetQueryParams sets the query parameters +func (s *SyftBoxURL) SetQueryParams(queryParams map[string]string) { + s.QueryParams = queryParams +} + +// parseQueryParams parses and validates query parameters from a URL +func parseQueryParams(rawQuery string) (map[string]string, error) { + if rawQuery == "" { + return nil, nil + } + + queryParams := make(map[string]string) + values, err := url.ParseQuery(rawQuery) + if err != nil { + return nil, fmt.Errorf("failed to parse query parameters: %w", err) + } + + for key, values := range values { + // Validate key doesn't contain spaces or special characters + if strings.ContainsAny(key, " ?&=") { + return nil, fmt.Errorf("query parameter key '%s' cannot contain spaces or special characters (?&=)", key) + } + // Use the first value if multiple values exist + if len(values) > 0 { + queryParams[key] = values[0] + } + } + + return queryParams, nil +} + +// FromSyftURL parses a syft URL string into a SyftBoxURL struct +func FromSyftURL(rawURL string) (*SyftBoxURL, error) { + // Parse the URL using standard library + parsedURL, err := url.Parse(rawURL) + if err != nil { + return nil, fmt.Errorf("failed to parse URL: %w", err) + } + + // Validate scheme + if parsedURL.Scheme != "syft" { + return nil, fmt.Errorf("invalid scheme: expected 'syft', got '%s'", parsedURL.Scheme) + } + + // datasite is the host of the URL + @ + username + datasite := parsedURL.Host + + if datasite == "" { + return nil, fmt.Errorf("invalid syft url: missing datasite (host)") + } + + if parsedURL.User == nil || parsedURL.User.Username() == "" { + return nil, fmt.Errorf("invalid syft url: invalid datasite name") + } + + username := parsedURL.User.Username() + datasite = username + "@" + datasite + + // Split path into components and remove empty strings + path := strings.Trim(parsedURL.Path, pathSeparator) + parts := make([]string, 0) + for _, part := range strings.Split(path, pathSeparator) { + if part != "" { + parts = append(parts, part) + } + } + + // Validate path structure + if len(parts) < 4 { + return nil, fmt.Errorf("invalid path: expected format 'app_data/app_name/rpc/endpoint'") + } + + // Validate expected structure + if parts[0] != appDataPath { + return nil, fmt.Errorf("invalid path: expected '%s' at position 1", appDataPath) + } + + // Find the index of rpc in the path + rpcIndex := -1 + for i, part := range parts { + if part == rpcPath { + rpcIndex = i + break + } + } + + if rpcIndex == -1 { + return nil, fmt.Errorf("invalid path: expected '%s' in path", rpcPath) + } + + // Extract components + appName := strings.Join(parts[1:rpcIndex], pathSeparator) + endpoint := strings.Join(parts[rpcIndex+1:], pathSeparator) + + // Create SyftBoxURL + syftURL, err := NewSyftBoxURL(datasite, appName, endpoint) + if err != nil { + return nil, fmt.Errorf("failed to create syft url from components: %w", err) + } + + // Validate query params + queryParams, err := parseQueryParams(parsedURL.RawQuery) + if err != nil { + return nil, fmt.Errorf("failed to parse query parameters: %w", err) + } + + // Set query params + syftURL.SetQueryParams(queryParams) + + return syftURL, nil +} diff --git a/internal/utils/url_test.go b/internal/utils/url_test.go new file mode 100644 index 00000000..e5121c5c --- /dev/null +++ b/internal/utils/url_test.go @@ -0,0 +1,438 @@ +package utils + +import ( + "testing" +) + +func TestFromSyftURL(t *testing.T) { + tests := []struct { + name string + url string + want *SyftBoxURL + wantErr bool + }{ + { + name: "valid basic url", + url: "syft://user@example.com/app_data/app1/rpc/endpoint1", + want: &SyftBoxURL{ + Datasite: "user@example.com", + AppName: "app1", + Endpoint: "endpoint1", + }, + wantErr: false, + }, + { + name: "valid url with query params", + url: "syft://user@example.com/app_data/app1/rpc/endpoint1?param1=value1¶m2=value2", + want: &SyftBoxURL{ + Datasite: "user@example.com", + AppName: "app1", + Endpoint: "endpoint1", + QueryParams: map[string]string{ + "param1": "value1", + "param2": "value2", + }, + }, + wantErr: false, + }, + { + name: "invalid scheme", + url: "http://user@example.com/app_data/app1/rpc/endpoint1", + want: nil, + wantErr: true, + }, + { + name: "missing app_data", + url: "syft://user@example.com/wrong/app1/rpc/endpoint1", + want: nil, + wantErr: true, + }, + { + name: "missing rpc", + url: "syft://user@example.com/app_data/app1/wrong/endpoint1", + want: nil, + wantErr: true, + }, + { + name: "empty url", + url: "", + want: nil, + wantErr: true, + }, + { + name: "malformed url", + url: "syft:///invalid", + want: nil, + wantErr: true, + }, + { + name: "invalid datasite format", + url: "syft://notanemail/app_data/app1/rpc/endpoint1", + want: nil, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := FromSyftURL(tt.url) + if (err != nil) != tt.wantErr { + t.Errorf("FromSyftURL() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !tt.wantErr { + if got.Datasite != tt.want.Datasite { + t.Errorf("FromSyftURL() Datasite = %v, want %v", got.Datasite, tt.want.Datasite) + } + if got.AppName != tt.want.AppName { + t.Errorf("FromSyftURL() AppName = %v, want %v", got.AppName, tt.want.AppName) + } + if got.Endpoint != tt.want.Endpoint { + t.Errorf("FromSyftURL() Endpoint = %v, want %v", got.Endpoint, tt.want.Endpoint) + } + if len(tt.want.QueryParams) > 0 { + if len(got.QueryParams) != len(tt.want.QueryParams) { + t.Errorf("FromSyftURL() QueryParams length = %v, want %v", len(got.QueryParams), len(tt.want.QueryParams)) + } + for k, v := range tt.want.QueryParams { + if gotVal, exists := got.QueryParams[k]; !exists || gotVal != v { + t.Errorf("FromSyftURL() QueryParams[%s] = %v, want %v", k, gotVal, v) + } + } + } + } + }) + } +} + +func TestFromSyftURL_QueryParamEncoding(t *testing.T) { + tests := []struct { + name string + url string + want *SyftBoxURL + wantErr bool + }{ + { + name: "url with encoded spaces in query param values", + url: "syft://test@example.com/app_data/app1/rpc/endpoint1?param1=value%20with%20spaces¶m2=value2", + want: &SyftBoxURL{ + Datasite: "test@example.com", + AppName: "app1", + Endpoint: "endpoint1", + QueryParams: map[string]string{ + "param1": "value with spaces", + "param2": "value2", + }, + }, + wantErr: false, + }, + { + name: "url with encoded special chars in query param values", + url: "syft://test@example.com/app_data/app1/rpc/endpoint1?param1=value%26with%26chars", + want: &SyftBoxURL{ + Datasite: "test@example.com", + AppName: "app1", + Endpoint: "endpoint1", + QueryParams: map[string]string{ + "param1": "value&with&chars", + }, + }, + wantErr: false, + }, + { + name: "url with spaces in query param keys", + url: "syft://test@example.com/app_data/app1/rpc/endpoint1?param with spaces=value1", + want: nil, + wantErr: true, + }, + { + name: "url with special chars in query param keys", + url: "syft://test@example.com/app_data/app1/rpc/endpoint1?param%26with%26chars=value1", + want: nil, + wantErr: true, + }, + { + name: "url with multiple values for same key", + url: "syft://test@example.com/app_data/app1/rpc/endpoint1?param1=value1¶m1=value2", + want: &SyftBoxURL{ + Datasite: "test@example.com", + AppName: "app1", + Endpoint: "endpoint1", + QueryParams: map[string]string{ + "param1": "value1", + }, + }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := FromSyftURL(tt.url) + if (err != nil) != tt.wantErr { + t.Errorf("FromSyftURL() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !tt.wantErr { + if got.Datasite != tt.want.Datasite { + t.Errorf("FromSyftURL() Datasite = %v, want %v", got.Datasite, tt.want.Datasite) + } + if got.AppName != tt.want.AppName { + t.Errorf("FromSyftURL() AppName = %v, want %v", got.AppName, tt.want.AppName) + } + if got.Endpoint != tt.want.Endpoint { + t.Errorf("FromSyftURL() Endpoint = %v, want %v", got.Endpoint, tt.want.Endpoint) + } + if len(tt.want.QueryParams) > 0 { + if len(got.QueryParams) != len(tt.want.QueryParams) { + t.Errorf("FromSyftURL() QueryParams length = %v, want %v", len(got.QueryParams), len(tt.want.QueryParams)) + } + for k, v := range tt.want.QueryParams { + if gotVal, exists := got.QueryParams[k]; !exists || gotVal != v { + t.Errorf("FromSyftURL() QueryParams[%s] = %v, want %v", k, gotVal, v) + } + } + } + } + }) + } +} + +func TestSyftBoxURL_String(t *testing.T) { + tests := []struct { + name string + url *SyftBoxURL + want string + }{ + { + name: "basic url", + url: &SyftBoxURL{ + Datasite: "user@example.com", + AppName: "app1", + Endpoint: "endpoint1", + }, + want: "syft://user@example.com/app_data/app1/rpc/endpoint1", + }, + { + name: "url with query params", + url: &SyftBoxURL{ + Datasite: "user@example.com", + AppName: "app1", + Endpoint: "endpoint1", + QueryParams: map[string]string{ + "param1": "value1", + "param2": "value2", + }, + }, + want: "syft://user@example.com/app_data/app1/rpc/endpoint1?param1=value1¶m2=value2", + }, + { + name: "url with spaces in query param values", + url: &SyftBoxURL{ + Datasite: "user@example.com", + AppName: "app1", + Endpoint: "endpoint1", + QueryParams: map[string]string{ + "param1": "value with spaces", + }, + }, + want: "syft://user@example.com/app_data/app1/rpc/endpoint1?param1=value+with+spaces", + }, + { + name: "url with special chars in query param values", + url: &SyftBoxURL{ + Datasite: "user@example.com", + AppName: "app1", + Endpoint: "endpoint1", + QueryParams: map[string]string{ + "param1": "value&with&chars", + }, + }, + want: "syft://user@example.com/app_data/app1/rpc/endpoint1?param1=value%26with%26chars", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.url.String(); got != tt.want { + t.Errorf("SyftBoxURL.String() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestSyftBoxURL_ToLocalPath(t *testing.T) { + tests := []struct { + name string + url *SyftBoxURL + want string + }{ + { + name: "basic path", + url: &SyftBoxURL{ + Datasite: "user@example.com", + AppName: "app1", + Endpoint: "endpoint1", + }, + want: "user@example.com/app_data/app1/rpc/endpoint1", + }, + { + name: "path with nested endpoint", + url: &SyftBoxURL{ + Datasite: "user@example.com", + AppName: "app1", + Endpoint: "endpoint1/sub/path", + }, + want: "user@example.com/app_data/app1/rpc/endpoint1/sub/path", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.url.ToLocalPath(); got != tt.want { + t.Errorf("SyftBoxURL.ToLocalPath() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestSyftBoxURL_Validate(t *testing.T) { + tests := []struct { + name string + url *SyftBoxURL + wantErr bool + }{ + { + name: "valid url", + url: &SyftBoxURL{ + Datasite: "user@example.com", + AppName: "app1", + Endpoint: "endpoint1", + }, + wantErr: false, + }, + { + name: "empty datasite", + url: &SyftBoxURL{ + Datasite: "", + AppName: "app1", + Endpoint: "endpoint1", + }, + wantErr: true, + }, + { + name: "empty app name", + url: &SyftBoxURL{ + Datasite: "user@example.com", + AppName: "", + Endpoint: "endpoint1", + }, + wantErr: true, + }, + { + name: "empty endpoint", + url: &SyftBoxURL{ + Datasite: "user@example.com", + AppName: "app1", + Endpoint: "", + }, + wantErr: true, + }, + { + name: "endpoint with spaces", + url: &SyftBoxURL{ + Datasite: "user@example.com", + AppName: "app1", + Endpoint: "endpoint with spaces", + }, + wantErr: true, + }, + { + name: "endpoint with special chars", + url: &SyftBoxURL{ + Datasite: "user@example.com", + AppName: "app1", + Endpoint: "endpoint?with=chars", + }, + wantErr: true, + }, + { + name: "invalid datasite format", + url: &SyftBoxURL{ + Datasite: "notanemail", + AppName: "app1", + Endpoint: "endpoint1", + }, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.url.Validate() + if (err != nil) != tt.wantErr { + t.Errorf("SyftBoxURL.Validate() error = %v, wantErr %v", err, tt.wantErr) + } + if tt.wantErr && err != nil { + validationErr, ok := err.(*ValidationError) + if !ok { + t.Errorf("expected ValidationError, got %T", err) + } + if validationErr.Field == "" { + t.Error("expected non-empty field in ValidationError") + } + } + }) + } +} + +func TestSyftBoxURL_SetQueryParams(t *testing.T) { + tests := []struct { + name string + url *SyftBoxURL + queryParams map[string]string + want map[string]string + }{ + { + name: "set query params", + url: &SyftBoxURL{ + Datasite: "user@example.com", + AppName: "app1", + Endpoint: "endpoint1", + }, + queryParams: map[string]string{ + "param1": "value1", + "param2": "value2", + }, + want: map[string]string{ + "param1": "value1", + "param2": "value2", + }, + }, + { + name: "set empty query params", + url: &SyftBoxURL{ + Datasite: "user@example.com", + AppName: "app1", + Endpoint: "endpoint1", + }, + queryParams: map[string]string{}, + want: map[string]string{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.url.SetQueryParams(tt.queryParams) + if len(tt.url.QueryParams) != len(tt.want) { + t.Errorf("SyftBoxURL.SetQueryParams() length = %v, want %v", len(tt.url.QueryParams), len(tt.want)) + return + } + for k, v := range tt.want { + if gotVal, exists := tt.url.QueryParams[k]; !exists || gotVal != v { + t.Errorf("SyftBoxURL.SetQueryParams()[%s] = %v, want %v", k, gotVal, v) + } + } + }) + } +} From e93e15a9c5787a50020fb7bdaa5decbdf19c69ff Mon Sep 17 00:00:00 2001 From: Shubham Gupta <11032835+shubham3121@users.noreply.github.com> Date: Fri, 20 Jun 2025 12:57:17 +0530 Subject: [PATCH 15/19] Fix syft url not rendering in offline case (#29) * fix(server/handlers/send): replace json.Marshal with custom MarshalJSON for message storage * sort query params in syfturl to main consistent ordering --- internal/server/handlers/send/send_service.go | 2 +- internal/utils/url.go | 11 ++++++++++- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/internal/server/handlers/send/send_service.go b/internal/server/handlers/send/send_service.go index 910eea09..75cae99e 100644 --- a/internal/server/handlers/send/send_service.go +++ b/internal/server/handlers/send/send_service.go @@ -41,7 +41,7 @@ func (m *BlobMsgStore) DeleteMsg(ctx context.Context, path string) error { } func (m *BlobMsgStore) StoreMsg(ctx context.Context, path string, msg syftmsg.SyftRPCMessage) error { - msgBytes, err := json.Marshal(msg) + msgBytes, err := msg.MarshalJSON() if err != nil { return fmt.Errorf("failed to marshal message: %w", err) } diff --git a/internal/utils/url.go b/internal/utils/url.go index 9ce85d70..94c85958 100644 --- a/internal/utils/url.go +++ b/internal/utils/url.go @@ -5,6 +5,7 @@ import ( "log/slog" "net/url" "path/filepath" + "sort" "strings" ) @@ -47,8 +48,16 @@ func (s *SyftBoxURL) String() string { // Add query parameters if they exist if len(s.QueryParams) > 0 { + // Sort query params by key lexicographically to ensure consistent ordering + sortedKeys := make([]string, 0, len(s.QueryParams)) + for key := range s.QueryParams { + sortedKeys = append(sortedKeys, key) + } + sort.Strings(sortedKeys) + queryParams := make([]string, 0, len(s.QueryParams)) - for key, value := range s.QueryParams { + for _, key := range sortedKeys { + value := s.QueryParams[key] queryParams = append(queryParams, fmt.Sprintf("%s=%s", key, url.QueryEscape(value))) } baseURL += "?" + strings.Join(queryParams, "&") From 7ba3038ba29b8c43f6f1182bb4f7af26168293fb Mon Sep 17 00:00:00 2001 From: Shubham Gupta <11032835+shubham3121@users.noreply.github.com> Date: Mon, 23 Jun 2025 20:45:43 +0530 Subject: [PATCH 16/19] Add CD flow (#30) * cd pipeline to deploy syftbox to dev, stage, prod * add just commands to bump and release package * fix bug in show versions in justfile * integrate version release in deployment action * auto release version on deployment to prod * update deploy workflow to set remote URL with token and push tags * refactor deployment workflows to separate production and non-production processes --- .github/workflows/deploy.yml | 81 +++++++++++++++++ .github/workflows/release.yml | 131 +++++++++++++++++++++++++++ justfile | 166 +++++++++++++++++++++++++++++++++- 3 files changed, 377 insertions(+), 1 deletion(-) create mode 100644 .github/workflows/deploy.yml create mode 100644 .github/workflows/release.yml diff --git a/.github/workflows/deploy.yml b/.github/workflows/deploy.yml new file mode 100644 index 00000000..537206e5 --- /dev/null +++ b/.github/workflows/deploy.yml @@ -0,0 +1,81 @@ +name: Syftbox Deploy + +# This workflow deploys Syftbox to development and staging environments. +# For production releases, use the release.yml workflow instead. + +on: + workflow_dispatch: + inputs: + environment: + description: 'Environment to deploy to' + required: true + default: 'dev' + type: choice + options: + - dev + - stage + +jobs: + build-and-deploy: + # Build and deploy to target environment + runs-on: macos-latest + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Setup Go + uses: actions/setup-go@v4 + with: + go-version: '1.21' + + - name: Install just + uses: taiki-e/install-action@just + + - name: Install goreleaser + uses: goreleaser/goreleaser-action@v4 + with: + version: latest + args: --version + + - name: Setup toolchain + run: just setup-toolchain + + - name: Setup SSH + run: | + mkdir -p ~/.ssh + + # Use environment-specific SSH private key + case "${{ inputs.environment }}" in + "dev") + echo "${{ secrets.SSH_PRIVATE_KEY_DEV }}" > ~/.ssh/id_rsa + ssh-keyscan -H ${{ secrets.SSH_HOST_DEV }} >> ~/.ssh/known_hosts + ;; + "stage") + echo "${{ secrets.SSH_PRIVATE_KEY_STAGE }}" > ~/.ssh/id_rsa + ssh-keyscan -H ${{ secrets.SSH_HOST_STAGE }} >> ~/.ssh/known_hosts + ;; + *) + echo "Unknown environment: ${{ inputs.environment }}" + exit 1 + ;; + esac + + chmod 600 ~/.ssh/id_rsa + + - name: Deploy to ${{ inputs.environment }} + run: | + case "${{ inputs.environment }}" in + "dev") + REMOTE="${{ secrets.SSH_USER_DEV }}@${{ secrets.SSH_HOST_DEV }}" + ;; + "stage") + REMOTE="${{ secrets.SSH_USER_STAGE }}@${{ secrets.SSH_HOST_STAGE }}" + ;; + *) + echo "Unknown environment: ${{ inputs.environment }}" + exit 1 + ;; + esac + + just deploy remote="$REMOTE" diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml new file mode 100644 index 00000000..3870a38b --- /dev/null +++ b/.github/workflows/release.yml @@ -0,0 +1,131 @@ +name: Syftbox Release + +# This workflow creates a new release and deploys to production. +# For dev/stage deployments, use the deploy.yml workflow instead. + +on: + workflow_dispatch: + inputs: + version_type: + description: 'Version type for the release' + required: true + type: choice + options: + - patch + - minor + - major + +jobs: + version: + # Handle version bumping and tagging + runs-on: macos-latest + outputs: + version: ${{ steps.bump-version.outputs.version }} + + steps: + - name: Checkout code + uses: actions/checkout@v4 + with: + fetch-depth: 0 # Required for svu to work properly with git history + + - name: Setup Go + uses: actions/setup-go@v4 + with: + go-version: '1.21' + + - name: Install just + uses: taiki-e/install-action@just + + - name: Install svu + run: go install github.com/caarlos0/svu@latest + + - name: Install jq + run: brew install jq + + - name: Setup git config + env: + GH_TOKEN: ${{ github.token }} + run: | + git config user.email "${GITHUB_ACTOR_ID}+${GITHUB_ACTOR}@users.noreply.github.com" + git config user.name "$(gh api /users/${GITHUB_ACTOR} | jq .name -r)" + + - name: Show current version + run: | + echo "Current version information:" + just show-version + + - name: Bump version + id: bump-version + run: | + echo "Releasing version for production deployment..." + just release ${{ inputs.version_type }} + version=$(git describe --tags --abbrev=0) + echo "version=${version}" >> $GITHUB_OUTPUT + + - name: Push version changes + run: | + # Set a new remote URL using HTTPS with the github token + git remote set-url origin https://x-access-token:${{ github.token }}@github.com/${{ github.repository }}.git + + # Push the current branch to the remote repo + git push origin + + # Push the tag to the remote repo + git push origin --tags + + - name: Show new version + run: | + echo "New version information:" + just show-version + + build-and-deploy: + needs: version + # Build and deploy to production + runs-on: macos-latest + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Setup Go + uses: actions/setup-go@v4 + with: + go-version: '1.21' + + - name: Install just + uses: taiki-e/install-action@just + + - name: Install goreleaser + uses: goreleaser/goreleaser-action@v4 + with: + version: latest + args: --version + + - name: Setup toolchain + run: just setup-toolchain + + - name: Setup SSH + run: | + mkdir -p ~/.ssh + echo "${{ secrets.SSH_PRIVATE_KEY_PROD }}" > ~/.ssh/id_rsa + ssh-keyscan -H ${{ secrets.SSH_HOST_PROD }} >> ~/.ssh/known_hosts + chmod 600 ~/.ssh/id_rsa + + - name: Deploy to production + run: | + REMOTE="${{ secrets.SSH_USER_PROD }}@${{ secrets.SSH_HOST_PROD }}" + just deploy remote="$REMOTE" + + - name: Create release + uses: ncipollo/release-action@v1 + with: + tag: ${{ needs.version.outputs.version }} + name: ${{ needs.version.outputs.version }} + draft: true + allowUpdates: true + omitBodyDuringUpdate: true + makeLatest: true + generateReleaseNotes: true + artifacts: | + releases/*.tar.gz + releases/*.zip diff --git a/justfile b/justfile index 3c442b0c..63f1f526 100644 --- a/justfile +++ b/justfile @@ -1,4 +1,4 @@ -SYFTBOX_VERSION := "0.5.0" +SYFTBOX_VERSION := `svu current` BUILD_COMMIT := `git rev-parse --short HEAD` BUILD_DATE := `date -u +%Y-%m-%dT%H:%M:%SZ` BUILD_LD_FLAGS := "-s -w" + " -X github.com/openmined/syftbox/internal/version.Version=" + SYFTBOX_VERSION + " -X github.com/openmined/syftbox/internal/version.Revision=" + BUILD_COMMIT + " -X github.com/openmined/syftbox/internal/version.BuildDate=" + BUILD_DATE @@ -212,6 +212,7 @@ build-all: [group('deploy')] deploy-client remote="syftbox-yash": build-all + #!/bin/bash echo "Deploying syftbox client to {{ _cyan }}{{ remote }}{{ _nc }}" rm -rf releases && mkdir releases cp -r .out/syftbox_client_*.{tar.gz,zip} releases/ @@ -221,6 +222,7 @@ deploy-client remote="syftbox-yash": build-all [group('deploy')] deploy-server remote="syftbox-yash": build-server + #!/bin/bash echo "Deploying syftbox server to {{ _cyan }}{{ remote }}{{ _nc }}" scp .out/syftbox_server_linux_amd64_v1/syftbox_server {{ remote }}:/home/azureuser/syftbox_server_new ssh {{ remote }} "rm -fv /home/azureuser/syftbox_server && mv -fv /home/azureuser/syftbox_server_new /home/azureuser/syftbox_server" @@ -235,7 +237,169 @@ setup-toolchain: go install github.com/swaggo/swag/v2/cmd/swag@latest go install github.com/bokwoon95/wgo@latest go install filippo.io/mkcert@latest + go install github.com/caarlos0/svu@latest [group('utils')] clean: rm -rf .data .out releases certs cover.out + +[group('version')] +bump type: + #!/bin/bash + set -eou pipefail + + # Version Management Commands + # + # This project uses semantic versioning with svu (https://github.com/caarlos0/svu) + # for automatic version calculation based on git tags. + # + # Workflow: + # 1. Use `just show-version` to see current version and next versions + # 2. Use `just bump type` to update files only (manual commit/tag) + # 3. Use `just release type` to update files, commit, and tag automatically + # 4. Use `just update-version-files version=X.Y.Z` for custom versions + # + # Examples: + # just show-version # Show current and next versions + # just bump patch # Update files to next patch version + # just bump minor # Update files to next minor version + # just bump major # Update files to next major version + # just release patch # Bump, commit, and tag patch version + # just update-version-files version=1.2.3 # Set specific version + + if [ -z "{{ type }}" ]; then + echo -e "{{ _red }}Error: bump type is required{{ _nc }}" + echo "Usage: just bump " + echo "Examples:" + echo " just bump patch" + echo " just bump minor" + echo " just bump major" + exit 1 + fi + + # Validate bump type + if [[ ! "{{ type }}" =~ ^(patch|minor|major)$ ]]; then + echo -e "{{ _red }}Error: Invalid bump type '{{ type }}'{{ _nc }}" + echo "Valid types: patch, minor, major" + exit 1 + fi + + echo -e "{{ _cyan }}Bumping {{ type }} version...{{ _nc }}" + new_version=$(svu {{ type }} | sed 's/^v//') + echo -e "{{ _green }}New version: $new_version{{ _nc }}" + just update-version-files version="$new_version" + echo -e "{{ _green }}Version bumped to $new_version{{ _nc }}" + echo -e "{{ _yellow }}Don't forget to commit and tag:{{ _nc }}" + echo " git add ." + echo " git commit -m \"chore: bump version to $new_version\"" + echo " git tag v$new_version" + +release type: + #!/bin/bash + set -eou pipefail + + if [ -z "{{ type }}" ]; then + echo -e "{{ _red }}Error: release type is required{{ _nc }}" + echo "Usage: just release " + echo "Examples:" + echo " just release patch" + echo " just release minor" + echo " just release major" + exit 1 + fi + + # Validate release type + if [[ ! "{{ type }}" =~ ^(patch|minor|major)$ ]]; then + echo -e "{{ _red }}Error: Invalid release type '{{ type }}'{{ _nc }}" + echo "Valid types: patch, minor, major" + exit 1 + fi + + echo -e "{{ _cyan }}Releasing {{ type }} version...{{ _nc }}" + new_version=$(svu {{ type }} | sed 's/^v//') + echo -e "{{ _green }}New version: $new_version{{ _nc }}" + just update-version-files version="$new_version" + just commit-and-tag version="$new_version" + echo -e "{{ _green }}✓ Released {{ type }} version $new_version{{ _nc }}" + +[group('version')] +show-version: + #!/bin/bash + set -eou pipefail + echo -e "{{ _cyan }}Current version information:{{ _nc }}" + + # Try to get current version, handle errors gracefully + current_version=$(svu current 2>/dev/null || echo "No valid version tags found") + echo " SVU current: $current_version" + + # Try to get next versions, handle errors gracefully + next_patch=$(svu patch 2>/dev/null || echo "Error") + next_minor=$(svu minor 2>/dev/null || echo "Error") + next_major=$(svu major 2>/dev/null || echo "Error") + + echo " SVU next patch: $next_patch" + echo " SVU next minor: $next_minor" + echo " SVU next major: $next_major" + echo " Git tags:" + git tag --sort=-version:refname | head -5 + +[group('version')] +commit-and-tag version: + #!/bin/bash + set -eou pipefail + + # Extract version from parameter (handle both "version=0.5.1" and "0.5.1" formats) + version_value="{{ version }}" + if [[ "$version_value" == version=* ]]; then + version_value="${version_value#version=}" + fi + + if [ -z "$version_value" ]; then + echo -e "{{ _red }}Error: version parameter is required{{ _nc }}" + echo "Usage: just commit-and-tag version=1.2.3" + exit 1 + fi + + echo -e "{{ _cyan }}Committing and tagging version $version_value...{{ _nc }}" + + # Check if there are changes to commit + if git diff --quiet && git diff --cached --quiet; then + echo -e "{{ _yellow }}No changes to commit{{ _nc }}" + else + git add . + git commit -m "chore: bump version to $version_value" + echo -e "{{ _green }}✓ Committed changes{{ _nc }}" + fi + + # Create tag + git tag v$version_value + echo -e "{{ _green }}✓ Tagged v$version_value{{ _nc }}" + + echo -e "{{ _green }}Version $version_value has been committed and tagged!{{ _nc }}" + +[group('version')] +update-version-files version: + #!/bin/bash + set -eou pipefail + + # Extract version from parameter (handle both "version=0.5.1" and "0.5.1" formats) + version_value="{{ version }}" + if [[ "$version_value" == version=* ]]; then + version_value="${version_value#version=}" + fi + + if [ -z "$version_value" ]; then + echo -e "{{ _red }}Error: version parameter is required{{ _nc }}" + echo "Usage: just update-version-files version=1.2.3" + exit 1 + fi + + echo -e "{{ _cyan }}Updating version to $version_value in all files...{{ _nc }}" + + # Update goreleaser.yaml + sed -i "s/-X github.com\/openmined\/syftbox\/internal\/version.Version=.*/-X github.com\/openmined\/syftbox\/internal\/version.Version=$version_value/g" .goreleaser.yaml + echo -e "{{ _green }}✓ Updated .goreleaser.yaml{{ _nc }}" + + # Update version.go + sed -i "s/Version = \".*\"/Version = \"$version_value\"/" internal/version/version.go + echo -e "{{ _green }}✓ Updated internal/version/version.go{{ _nc }}" From 5f846e2e15e684e21b6f02509a27c84db8293c84 Mon Sep 17 00:00:00 2001 From: Shubham Gupta <11032835+shubham3121@users.noreply.github.com> Date: Mon, 23 Jun 2025 21:35:47 +0530 Subject: [PATCH 17/19] refactor(build): streamline version variable calculations in justfile (#32) --- justfile | 26 ++++++++++++++++---------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/justfile b/justfile index 63f1f526..06f83599 100644 --- a/justfile +++ b/justfile @@ -1,7 +1,3 @@ -SYFTBOX_VERSION := `svu current` -BUILD_COMMIT := `git rev-parse --short HEAD` -BUILD_DATE := `date -u +%Y-%m-%dT%H:%M:%SZ` -BUILD_LD_FLAGS := "-s -w" + " -X github.com/openmined/syftbox/internal/version.Version=" + SYFTBOX_VERSION + " -X github.com/openmined/syftbox/internal/version.Revision=" + BUILD_COMMIT + " -X github.com/openmined/syftbox/internal/version.BuildDate=" + BUILD_DATE CLIENT_BUILD_TAGS := "go_json nomsgpack" SERVER_BUILD_TAGS := "sonic avx nomsgpack" @@ -179,14 +175,21 @@ test: [doc('Needs a platform specific compiler. Example: CC="aarch64-linux-musl-gcc" just build-client-target goos=linux goarch=arm64')] [group('build')] -build-client-target goos=`go env GOOS` goarch=`go env GOARCH`: +build-client-target goos=`go env GOOS` goarch=`go env GOARCH`: version-utils #!/bin/bash set -eou pipefail + # Calculate build variables locally + SYFTBOX_VERSION=$(svu current 2>/dev/null) + echo "SYFTBOX_VERSION: $SYFTBOX_VERSION" + BUILD_COMMIT=$(git rev-parse --short HEAD) + BUILD_DATE=$(date -u +%Y-%m-%dT%H:%M:%SZ) + BUILD_LD_FLAGS="-s -w -X github.com/openmined/syftbox/internal/version.Version=$SYFTBOX_VERSION -X github.com/openmined/syftbox/internal/version.Revision=$BUILD_COMMIT -X github.com/openmined/syftbox/internal/version.BuildDate=$BUILD_DATE" + export GOOS="{{ goos }}" export GOARCH="{{ goarch }}" export CGO_ENABLED=0 - export GO_LDFLAGS="$([ '{{ goos }}' = 'windows' ] && echo '-H windowsgui '){{ BUILD_LD_FLAGS }}" + export GO_LDFLAGS="$([ '{{ goos }}' = 'windows' ] && echo '-H windowsgui ')$BUILD_LD_FLAGS" if [ "{{ goos }}" = "darwin" ]; then echo "Building for darwin. CGO_ENABLED=1" @@ -237,14 +240,13 @@ setup-toolchain: go install github.com/swaggo/swag/v2/cmd/swag@latest go install github.com/bokwoon95/wgo@latest go install filippo.io/mkcert@latest - go install github.com/caarlos0/svu@latest [group('utils')] clean: rm -rf .data .out releases certs cover.out [group('version')] -bump type: +bump type: version-utils #!/bin/bash set -eou pipefail @@ -294,7 +296,7 @@ bump type: echo " git commit -m \"chore: bump version to $new_version\"" echo " git tag v$new_version" -release type: +release type: version-utils #!/bin/bash set -eou pipefail @@ -323,7 +325,7 @@ release type: echo -e "{{ _green }}✓ Released {{ type }} version $new_version{{ _nc }}" [group('version')] -show-version: +show-version: version-utils #!/bin/bash set -eou pipefail echo -e "{{ _cyan }}Current version information:{{ _nc }}" @@ -403,3 +405,7 @@ update-version-files version: # Update version.go sed -i "s/Version = \".*\"/Version = \"$version_value\"/" internal/version/version.go echo -e "{{ _green }}✓ Updated internal/version/version.go{{ _nc }}" + +[group('version')] +version-utils: + go install github.com/caarlos0/svu@latest \ No newline at end of file From 34966ffa5815b63ea5c5746a26fb032f50ae8124 Mon Sep 17 00:00:00 2001 From: Shubham Gupta <11032835+shubham3121@users.noreply.github.com> Date: Mon, 23 Jun 2025 22:25:26 +0530 Subject: [PATCH 18/19] fix: go releaser installation in workflow (#33) --- .github/workflows/deploy.yml | 13 ++++++------- .github/workflows/release.yml | 9 ++++----- 2 files changed, 10 insertions(+), 12 deletions(-) diff --git a/.github/workflows/deploy.yml b/.github/workflows/deploy.yml index 537206e5..5acca247 100644 --- a/.github/workflows/deploy.yml +++ b/.github/workflows/deploy.yml @@ -32,12 +32,11 @@ jobs: - name: Install just uses: taiki-e/install-action@just - - name: Install goreleaser - uses: goreleaser/goreleaser-action@v4 - with: - version: latest - args: --version - + - name: Install GoReleaser + run: | + brew install goreleaser/tap/goreleaser + goreleaser --version + - name: Setup toolchain run: just setup-toolchain @@ -62,7 +61,7 @@ jobs: esac chmod 600 ~/.ssh/id_rsa - + - name: Deploy to ${{ inputs.environment }} run: | case "${{ inputs.environment }}" in diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 3870a38b..77fe7514 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -95,11 +95,10 @@ jobs: - name: Install just uses: taiki-e/install-action@just - - name: Install goreleaser - uses: goreleaser/goreleaser-action@v4 - with: - version: latest - args: --version + - name: Install GoReleaser + run: | + brew install goreleaser/tap/goreleaser + goreleaser --version - name: Setup toolchain run: just setup-toolchain From 46191fc3b18656a6cc4d59a635549599050835f5 Mon Sep 17 00:00:00 2001 From: Shubham Gupta <11032835+shubham3121@users.noreply.github.com> Date: Tue, 24 Jun 2025 00:08:20 +0530 Subject: [PATCH 19/19] Fix/goreleaser installation (#34) * fix: go releaser installation in workflow * use cask arg in homebrew * pass remote username without remote= var --- .github/workflows/deploy.yml | 5 +++-- .github/workflows/release.yml | 5 +++-- justfile | 8 +++++--- 3 files changed, 11 insertions(+), 7 deletions(-) diff --git a/.github/workflows/deploy.yml b/.github/workflows/deploy.yml index 5acca247..2e3ec0c9 100644 --- a/.github/workflows/deploy.yml +++ b/.github/workflows/deploy.yml @@ -34,7 +34,7 @@ jobs: - name: Install GoReleaser run: | - brew install goreleaser/tap/goreleaser + brew install --cask goreleaser/tap/goreleaser goreleaser --version - name: Setup toolchain @@ -60,6 +60,7 @@ jobs: ;; esac + chmod 700 ~/.ssh chmod 600 ~/.ssh/id_rsa - name: Deploy to ${{ inputs.environment }} @@ -77,4 +78,4 @@ jobs: ;; esac - just deploy remote="$REMOTE" + just deploy $REMOTE diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 77fe7514..5038b436 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -97,7 +97,7 @@ jobs: - name: Install GoReleaser run: | - brew install goreleaser/tap/goreleaser + brew install --cask goreleaser/tap/goreleaser goreleaser --version - name: Setup toolchain @@ -109,11 +109,12 @@ jobs: echo "${{ secrets.SSH_PRIVATE_KEY_PROD }}" > ~/.ssh/id_rsa ssh-keyscan -H ${{ secrets.SSH_HOST_PROD }} >> ~/.ssh/known_hosts chmod 600 ~/.ssh/id_rsa + chmod 700 ~/.ssh - name: Deploy to production run: | REMOTE="${{ secrets.SSH_USER_PROD }}@${{ secrets.SSH_HOST_PROD }}" - just deploy remote="$REMOTE" + just deploy $REMOTE - name: Create release uses: ncipollo/release-action@v1 diff --git a/justfile b/justfile index 06f83599..0669183c 100644 --- a/justfile +++ b/justfile @@ -214,9 +214,10 @@ build-all: goreleaser release --snapshot --clean [group('deploy')] -deploy-client remote="syftbox-yash": build-all +deploy-client remote: build-all #!/bin/bash echo "Deploying syftbox client to {{ _cyan }}{{ remote }}{{ _nc }}" + rm -rf releases && mkdir releases cp -r .out/syftbox_client_*.{tar.gz,zip} releases/ ssh {{ remote }} "rm -rfv /home/azureuser/releases.new && mkdir -p /home/azureuser/releases.new" @@ -224,15 +225,16 @@ deploy-client remote="syftbox-yash": build-all ssh {{ remote }} "rm -rfv /home/azureuser/releases/ && mv -fv /home/azureuser/releases.new/ /home/azureuser/releases/" [group('deploy')] -deploy-server remote="syftbox-yash": build-server +deploy-server remote: build-server #!/bin/bash echo "Deploying syftbox server to {{ _cyan }}{{ remote }}{{ _nc }}" + scp .out/syftbox_server_linux_amd64_v1/syftbox_server {{ remote }}:/home/azureuser/syftbox_server_new ssh {{ remote }} "rm -fv /home/azureuser/syftbox_server && mv -fv /home/azureuser/syftbox_server_new /home/azureuser/syftbox_server" ssh {{ remote }} "sudo systemctl restart syftbox" [group('deploy')] -deploy remote="syftbox-yash": (deploy-client remote) (deploy-server remote) +deploy remote: (deploy-client remote) (deploy-server remote) echo "Deployed syftbox client & server to {{ _cyan }}{{ remote }}{{ _nc }}" [group('utils')]