diff --git a/.env.example b/.env.example index a2a503d..2adccc5 100644 --- a/.env.example +++ b/.env.example @@ -58,6 +58,7 @@ VM_ORCH_BASE_URL=http://localhost:8082 VM_ORCH_TIMEOUT=5s VM_CREATE_WINDOW=1m VM_CREATE_MAX=1 +VM_CLEANUP_INTERVAL=30m # Logging LOG_DIR=logs diff --git a/cmd/server/main.go b/cmd/server/main.go index 84716ca..4258977 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -127,6 +127,10 @@ func main() { ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) defer stop() + if cfg.VM.Enabled { + vmSvc.StartTTLReaper(ctx, cfg.VM.CleanupInterval) + } + leaderboardBus := realtime.NewScoreboardBus(redisClient, cfg, scoreSvc, logger) leaderboardBus.Start(ctx) diff --git a/docs/docs/vms.md b/docs/docs/vms.md index 02c0ac3..03813d9 100644 --- a/docs/docs/vms.md +++ b/docs/docs/vms.md @@ -8,9 +8,9 @@ Notes: - VM APIs are backed by `sandboxd-orch` (HTTP API). - Legacy Stack APIs remain available, but VM APIs are independent. - For authenticated `POST`, `PUT`, `PATCH`, and `DELETE` requests, send both `csrf_token` cookie and matching `X-CSRF-Token` header. -- VM row deletion from wargame DB is explicit through user/admin delete endpoints. +- VM rows with expired `ttl_expires_at` are deleted automatically by a background cleanup loop (`VM_CLEANUP_INTERVAL`, default `30m`). - VM rows are not deleted automatically after a correct flag submission. -- VM rows are not deleted automatically just because orchestrator status changed to error or missing. +- VM rows are not deleted automatically just because orchestrator status changed to error; however, if a refresh/read finds that the orchestrator sandbox is missing, the VM row can be removed automatically, in addition to the TTL cleanup above (user/admin delete endpoints can also remove rows). ## VM Response Schema @@ -128,6 +128,7 @@ Create behavior: - Solved users can still create or recreate VMs for the same challenge. - Wargame parses the YAML, rewrites `id` to generated `vm_id`, and sends the rewritten spec to orchestrator. - If an existing VM row already exists for `(user_id, challenge_id)`, create returns the existing VM instead of creating a second one. +- Before create/limit checks, wargame removes the caller's expired VM rows from DB using `ttl_expires_at` (rows with `ttl_expires_at = null` are kept). - User VM cap and VM creation rate limit are enforced server-side. --- diff --git a/internal/config/config.go b/internal/config/config.go index b870d17..9fcf80c 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -111,6 +111,7 @@ type VMConfig struct { OrchestratorTimeout time.Duration CreateWindow time.Duration CreateMax int + CleanupInterval time.Duration } const defaultJWTSecret = "change-me" @@ -297,6 +298,10 @@ func Load() (Config, error) { if err != nil { errs = append(errs, err) } + vmCleanupInterval, err := getDuration("VM_CLEANUP_INTERVAL", 30*time.Minute) + if err != nil { + errs = append(errs, err) + } cfg := Config{ AppEnv: appEnv, @@ -385,6 +390,7 @@ func Load() (Config, error) { OrchestratorTimeout: vmTimeout, CreateWindow: vmCreateWindow, CreateMax: vmCreateMax, + CleanupInterval: vmCleanupInterval, }, } @@ -592,6 +598,9 @@ func validateConfig(cfg Config) error { if cfg.VM.CreateMax <= 0 { errs = append(errs, errors.New("VM_CREATE_MAX must be positive")) } + if cfg.VM.CleanupInterval <= 0 { + errs = append(errs, errors.New("VM_CLEANUP_INTERVAL must be positive")) + } } if len(errs) == 0 { @@ -734,6 +743,7 @@ func FormatForLog(cfg Config) map[string]any { "orchestrator_timeout": seconds(cfg.VM.OrchestratorTimeout), "create_window": seconds(cfg.VM.CreateWindow), "create_max": cfg.VM.CreateMax, + "cleanup_interval": seconds(cfg.VM.CleanupInterval), }, } } diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 4fddd96..0bbefaf 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -74,6 +74,10 @@ func TestLoadConfigDefaults(t *testing.T) { if cfg.Stack.CreateMax != 1 { t.Errorf("expected Stack.CreateMax 1, got %d", cfg.Stack.CreateMax) } + + if cfg.VM.CleanupInterval != 30*time.Minute { + t.Errorf("expected VM.CleanupInterval 30m, got %v", cfg.VM.CleanupInterval) + } } func TestLoadConfigCustomValues(t *testing.T) { @@ -121,6 +125,7 @@ func TestLoadConfigCustomValues(t *testing.T) { os.Setenv("STACKS_PROVISIONER_TIMEOUT", "9s") os.Setenv("STACKS_CREATE_WINDOW", "2m") os.Setenv("STACKS_CREATE_MAX", "2") + os.Setenv("VM_CLEANUP_INTERVAL", "45m") defer os.Clearenv() @@ -223,6 +228,9 @@ func TestLoadConfigCustomValues(t *testing.T) { if cfg.Stack.MaxPer != 5 { t.Errorf("expected Stack.MaxPer 5, got %d", cfg.Stack.MaxPer) } + if cfg.VM.CleanupInterval != 45*time.Minute { + t.Errorf("expected VM.CleanupInterval 45m, got %v", cfg.VM.CleanupInterval) + } } func TestLoadConfigInvalidValues(t *testing.T) { @@ -248,6 +256,7 @@ func TestLoadConfigInvalidValues(t *testing.T) { {"invalid s3 media force path", "S3_MEDIA_FORCE_PATH_STYLE", "bad-bool"}, {"invalid stack max scope", "STACKS_MAX_SCOPE", "org"}, {"invalid leaderboard cache ttl", "LEADERBOARD_CACHE_TTL", "bad-duration"}, + {"invalid vm cleanup interval", "VM_CLEANUP_INTERVAL", "bad-duration"}, } for _, tt := range tests { @@ -722,6 +731,64 @@ func TestValidateConfigAdditionalValidation(t *testing.T) { } } +func TestValidateConfigInvalidVMConfig(t *testing.T) { + cfg := Config{ + AppEnv: "local", + HTTPAddr: ":8080", + BcryptCost: bcrypt.DefaultCost, + DB: DBConfig{ + Host: "localhost", + Port: 5432, + User: "user", + Name: "db", + MaxOpenConns: 10, + MaxIdleConns: 5, + ConnMaxLifetime: time.Minute, + }, + Redis: RedisConfig{ + Addr: "localhost:6379", + PoolSize: 10, + }, + JWT: JWTConfig{ + Secret: "secret", + Issuer: "issuer", + AccessTTL: time.Hour, + RefreshTTL: 24 * time.Hour, + }, + Security: SecurityConfig{ + SubmissionWindow: time.Minute, + SubmissionMax: 10, + }, + Cache: CacheConfig{ + TimelineTTL: time.Minute, + LeaderboardTTL: time.Minute, + }, + Logging: LoggingConfig{ + Dir: "logs", + FilePrefix: "app", + MaxBodyBytes: 1024, + }, + VM: VMConfig{ + Enabled: true, + MaxPer: 1, + OrchestratorBaseURL: "http://localhost:8082", + OrchestratorTimeout: time.Second, + CreateWindow: time.Minute, + CreateMax: 1, + CleanupInterval: 0, + }, + } + + err := validateConfig(cfg) + if err == nil { + t.Fatal("expected vm validation error") + } + + if !strings.Contains(err.Error(), "VM_CLEANUP_INTERVAL") { + t.Fatalf("expected VM_CLEANUP_INTERVAL error, got %v", err) + } +} + func TestRedact(t *testing.T) { cfg := Config{ DB: DBConfig{Password: "dbpass"}, diff --git a/internal/repo/vm_repo.go b/internal/repo/vm_repo.go index 9ac2c44..693e21d 100644 --- a/internal/repo/vm_repo.go +++ b/internal/repo/vm_repo.go @@ -2,6 +2,7 @@ package repo import ( "context" + "time" "wargame/internal/models" @@ -125,3 +126,21 @@ func (r *VMRepo) Delete(ctx context.Context, vm *models.VM) error { return nil } + +func (r *VMRepo) DeleteExpired(ctx context.Context, now time.Time) (int64, error) { + res, err := r.db.NewDelete(). + Model((*models.VM)(nil)). + Where("ttl_expires_at IS NOT NULL"). + Where("ttl_expires_at <= ?", now). + Exec(ctx) + if err != nil { + return 0, wrapError("vmRepo.DeleteExpired", err) + } + + affected, err := res.RowsAffected() + if err != nil { + return 0, wrapError("vmRepo.DeleteExpired rows affected", err) + } + + return affected, nil +} diff --git a/internal/repo/vm_repo_test.go b/internal/repo/vm_repo_test.go index 10b8053..2a1b803 100644 --- a/internal/repo/vm_repo_test.go +++ b/internal/repo/vm_repo_test.go @@ -142,3 +142,69 @@ func TestVMRepoNotFound(t *testing.T) { t.Fatalf("expected ErrNotFound, got %v", err) } } + +func TestVMRepoDeleteExpired(t *testing.T) { + env := setupRepoTest(t) + vmRepo := NewVMRepo(env.db) + user := createUserForTestUserScope(t, env, "vmexpired@example.com", "vmexpired", "pass", models.UserRole) + challenge1 := createChallenge(t, env, "VM Expired 1", 100, "flag{vmexpired1}", true) + challenge2 := createChallenge(t, env, "VM Expired 2", 100, "flag{vmexpired2}", true) + challenge3 := createChallenge(t, env, "VM Expired 3", 100, "flag{vmexpired3}", true) + + now := time.Now().UTC() + expired := &models.VM{ + UserID: user.ID, + ChallengeID: challenge1.ID, + VMID: "vm-expired", + Status: "Running", + TTLExpiresAt: ptrTime(now.Add(-time.Minute)), + CreatedAt: now.Add(-time.Hour), + UpdatedAt: now.Add(-time.Hour), + } + + active := &models.VM{ + UserID: user.ID, + ChallengeID: challenge2.ID, + VMID: "vm-active", + Status: "Running", + TTLExpiresAt: ptrTime(now.Add(time.Hour)), + CreatedAt: now.Add(-time.Hour), + UpdatedAt: now.Add(-time.Hour), + } + + noTTL := &models.VM{ + UserID: user.ID, + ChallengeID: challenge3.ID, + VMID: "vm-nottl", + Status: "Running", + CreatedAt: now.Add(-time.Hour), + UpdatedAt: now.Add(-time.Hour), + } + + for _, row := range []*models.VM{expired, active, noTTL} { + if err := vmRepo.Create(context.Background(), row); err != nil { + t.Fatalf("create vm(%s): %v", row.VMID, err) + } + } + + deleted, err := vmRepo.DeleteExpired(context.Background(), now) + if err != nil { + t.Fatalf("DeleteExpired: %v", err) + } + + if deleted != 1 { + t.Fatalf("expected 1 deleted row, got %d", deleted) + } + + if _, err := vmRepo.GetByVMID(context.Background(), "vm-expired"); !errors.Is(err, ErrNotFound) { + t.Fatalf("expected expired row deleted, got err=%v", err) + } + + if _, err := vmRepo.GetByVMID(context.Background(), "vm-active"); err != nil { + t.Fatalf("expected active row to remain, got %v", err) + } + + if _, err := vmRepo.GetByVMID(context.Background(), "vm-nottl"); err != nil { + t.Fatalf("expected null-ttl row to remain, got %v", err) + } +} diff --git a/internal/service/vm_service.go b/internal/service/vm_service.go index f772df7..0eff429 100644 --- a/internal/service/vm_service.go +++ b/internal/service/vm_service.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "log/slog" "strconv" "strings" "time" @@ -127,6 +128,10 @@ func (s *VMService) GetOrCreateVM(ctx context.Context, userID, challengeID int64 return nil, err } + if err := s.cleanupExpiredUserVMs(ctx, userID); err != nil { + return nil, err + } + existing, err := s.findExistingVM(ctx, userID, challengeID) if err != nil { return nil, err @@ -273,6 +278,67 @@ func (s *VMService) ensureUserLimit(ctx context.Context, userID int64) error { return nil } +func (s *VMService) cleanupExpiredUserVMs(ctx context.Context, userID int64) error { + rows, err := s.vmRepo.ListByUser(ctx, userID) + if err != nil { + return fmt.Errorf("vm.cleanupExpiredUserVMs list: %w", err) + } + + now := time.Now().UTC() + for i := range rows { + if rows[i].TTLExpiresAt == nil { + continue + } + + if rows[i].TTLExpiresAt.After(now) { + continue + } + + if err := s.vmRepo.Delete(ctx, &rows[i]); err != nil { + return fmt.Errorf("vm.cleanupExpiredUserVMs delete: %w", err) + } + } + + return nil +} + +func (s *VMService) CleanupExpiredVMs(ctx context.Context) (int64, error) { + deleted, err := s.vmRepo.DeleteExpired(ctx, time.Now().UTC()) + if err != nil { + return 0, fmt.Errorf("vm.CleanupExpiredVMs: %w", err) + } + + return deleted, nil +} + +func (s *VMService) StartTTLReaper(ctx context.Context, interval time.Duration) { + if interval <= 0 { + slog.Warn("vm ttl reaper disabled due to non-positive interval", slog.Duration("interval", interval)) + return + } + + ticker := time.NewTicker(interval) + go func() { + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + deleted, err := s.CleanupExpiredVMs(ctx) + if err != nil { + slog.Warn("vm ttl cleanup failed", slog.Any("error", err)) + continue + } + + if deleted > 0 { + slog.Info("vm ttl cleanup completed", slog.Int64("deleted", deleted)) + } + } + } + }() +} + func (s *VMService) createVM(ctx context.Context, userID, challengeID int64, spec string) (*models.VM, error) { vmID := newVMID(userID, challengeID) sandbox, err := s.client.CreateSandbox(ctx, vmID, spec) diff --git a/internal/service/vm_service_test.go b/internal/service/vm_service_test.go index 921e2cf..069b51b 100644 --- a/internal/service/vm_service_test.go +++ b/internal/service/vm_service_test.go @@ -267,3 +267,330 @@ func TestVMServiceAdminListDoesNotRefreshRows(t *testing.T) { t.Fatalf("expected vm row to remain in db without refresh, got %v", err) } } + +func TestVMServiceGetOrCreateVMCleansUpExpiredRowsBeforeLimit(t *testing.T) { + env := setupServiceTest(t) + user := createUser(t, env, "vm-cleanup-limit@example.com", "vm-cleanup-limit", "pass", models.UserRole) + challenge1 := createVMChallenge(t, env, "vm-cleanup-limit-1") + challenge2 := createVMChallenge(t, env, "vm-cleanup-limit-2") + svc, vmRepo := newVMServiceForTest(env, &vm.MockClient{ + CreateSandboxFn: func(ctx context.Context, id string, specYAML string) (*vm.Sandbox, error) { + exp := time.Now().UTC().Add(time.Hour) + return &vm.Sandbox{ID: id, Status: vm.SandboxStatus{Phase: "Pending", ExpireAt: &exp}}, nil + }, + GetSandboxFn: func(ctx context.Context, id string) (*vm.Sandbox, error) { + return &vm.Sandbox{ID: id, Status: vm.SandboxStatus{Phase: "Running"}}, nil + }, + }, config.VMConfig{Enabled: true, MaxPer: 1, CreateWindow: time.Minute, CreateMax: 5, CleanupInterval: 30 * time.Minute}) + + now := time.Now().UTC() + if err := vmRepo.Create(context.Background(), &models.VM{ + UserID: user.ID, + ChallengeID: challenge1.ID, + VMID: "vm-expired-limit", + Status: "Running", + TTLExpiresAt: ptrTime(now.Add(-time.Minute)), + CreatedAt: now.Add(-time.Hour), + UpdatedAt: now.Add(-time.Hour), + }); err != nil { + t.Fatalf("create expired vm: %v", err) + } + + created, err := svc.GetOrCreateVM(context.Background(), user.ID, challenge2.ID) + if err != nil { + t.Fatalf("GetOrCreateVM: %v", err) + } + if created == nil || created.VMID == "" { + t.Fatalf("expected created vm, got %+v", created) + } + + if _, err := vmRepo.GetByVMID(context.Background(), "vm-expired-limit"); !errors.Is(err, repo.ErrNotFound) { + t.Fatalf("expected expired vm to be removed before limit check, got %v", err) + } +} + +func TestVMServiceGetOrCreateVMKeepsRowsWithNullTTL(t *testing.T) { + env := setupServiceTest(t) + user := createUser(t, env, "vm-null-ttl@example.com", "vm-null-ttl", "pass", models.UserRole) + challenge1 := createVMChallenge(t, env, "vm-null-ttl-1") + challenge2 := createVMChallenge(t, env, "vm-null-ttl-2") + svc, vmRepo := newVMServiceForTest(env, &vm.MockClient{ + CreateSandboxFn: func(ctx context.Context, id string, specYAML string) (*vm.Sandbox, error) { + exp := time.Now().UTC().Add(time.Hour) + return &vm.Sandbox{ID: id, Status: vm.SandboxStatus{Phase: "Pending", ExpireAt: &exp}}, nil + }, + }, config.VMConfig{Enabled: true, MaxPer: 1, CreateWindow: time.Minute, CreateMax: 5, CleanupInterval: 30 * time.Minute}) + + now := time.Now().UTC() + if err := vmRepo.Create(context.Background(), &models.VM{ + UserID: user.ID, + ChallengeID: challenge1.ID, + VMID: "vm-null-ttl-limit", + Status: "Running", + CreatedAt: now.Add(-time.Hour), + UpdatedAt: now.Add(-time.Hour), + }); err != nil { + t.Fatalf("create null ttl vm: %v", err) + } + + if _, err := svc.GetOrCreateVM(context.Background(), user.ID, challenge2.ID); !errors.Is(err, ErrVMLimitReached) { + t.Fatalf("expected ErrVMLimitReached because null ttl row remains, got %v", err) + } +} + +func TestVMServiceCleanupExpiredVMs(t *testing.T) { + env := setupServiceTest(t) + user := createUser(t, env, "vm-cleanup-all@example.com", "vm-cleanup-all", "pass", models.UserRole) + challenge1 := createVMChallenge(t, env, "vm-cleanup-all-1") + challenge2 := createVMChallenge(t, env, "vm-cleanup-all-2") + challenge3 := createVMChallenge(t, env, "vm-cleanup-all-3") + svc, vmRepo := newVMServiceForTest(env, &vm.MockClient{}, config.VMConfig{Enabled: true, MaxPer: 3, CreateWindow: time.Minute, CreateMax: 5, CleanupInterval: 30 * time.Minute}) + + now := time.Now().UTC() + for _, row := range []*models.VM{ + {UserID: user.ID, ChallengeID: challenge1.ID, VMID: "vm-clean-expired", Status: "Running", TTLExpiresAt: ptrTime(now.Add(-time.Minute)), CreatedAt: now, UpdatedAt: now}, + {UserID: user.ID, ChallengeID: challenge2.ID, VMID: "vm-clean-active", Status: "Running", TTLExpiresAt: ptrTime(now.Add(time.Minute)), CreatedAt: now, UpdatedAt: now}, + {UserID: user.ID, ChallengeID: challenge3.ID, VMID: "vm-clean-null", Status: "Running", CreatedAt: now, UpdatedAt: now}, + } { + if err := vmRepo.Create(context.Background(), row); err != nil { + t.Fatalf("create vm(%s): %v", row.VMID, err) + } + } + + deleted, err := svc.CleanupExpiredVMs(context.Background()) + if err != nil { + t.Fatalf("CleanupExpiredVMs: %v", err) + } + if deleted != 1 { + t.Fatalf("expected deleted=1, got %d", deleted) + } + + if _, err := vmRepo.GetByVMID(context.Background(), "vm-clean-expired"); !errors.Is(err, repo.ErrNotFound) { + t.Fatalf("expected expired vm deleted, got %v", err) + } + if _, err := vmRepo.GetByVMID(context.Background(), "vm-clean-active"); err != nil { + t.Fatalf("expected active vm remain, got %v", err) + } + if _, err := vmRepo.GetByVMID(context.Background(), "vm-clean-null"); err != nil { + t.Fatalf("expected null-ttl vm remain, got %v", err) + } +} + +func ptrTime(value time.Time) *time.Time { + return &value +} + +func TestVMServiceUserVMSummary(t *testing.T) { + env := setupServiceTest(t) + user := createUser(t, env, "vm-summary@example.com", "vm-summary", "pass", models.UserRole) + challenge := createVMChallenge(t, env, "vm-summary") + _, vmRepo := newVMServiceForTest(env, &vm.MockClient{}, config.VMConfig{Enabled: true, MaxPer: 2, CreateWindow: time.Minute, CreateMax: 1, CleanupInterval: time.Minute}) + + now := time.Now().UTC() + if err := vmRepo.Create(context.Background(), &models.VM{ + UserID: user.ID, + ChallengeID: challenge.ID, + VMID: "vm-summary-1", + Status: "Running", + CreatedAt: now, + UpdatedAt: now, + }); err != nil { + t.Fatalf("create vm: %v", err) + } + + svcEnabled, _ := newVMServiceForTest(env, &vm.MockClient{}, config.VMConfig{Enabled: true, MaxPer: 2, CreateWindow: time.Minute, CreateMax: 1, CleanupInterval: time.Minute}) + count, limit, err := svcEnabled.UserVMSummary(context.Background(), user.ID) + if err != nil { + t.Fatalf("UserVMSummary(enabled): %v", err) + } + + if count != 1 || limit != 2 { + t.Fatalf("expected 1/2, got %d/%d", count, limit) + } + + svcDisabled, _ := newVMServiceForTest(env, &vm.MockClient{}, config.VMConfig{Enabled: false, MaxPer: 2, CreateWindow: time.Minute, CreateMax: 1, CleanupInterval: time.Minute}) + count, limit, err = svcDisabled.UserVMSummary(context.Background(), user.ID) + if err != nil { + t.Fatalf("UserVMSummary(disabled): %v", err) + } + + if count != 0 || limit != 0 { + t.Fatalf("expected 0/0 when disabled, got %d/%d", count, limit) + } +} + +func TestVMServiceListAndGetByVMID(t *testing.T) { + env := setupServiceTest(t) + user := createUser(t, env, "vm-list-get@example.com", "vm-list-get", "pass", models.UserRole) + challenge := createVMChallenge(t, env, "vm-list-get") + svc, vmRepo := newVMServiceForTest(env, &vm.MockClient{ + GetSandboxFn: func(ctx context.Context, id string) (*vm.Sandbox, error) { + return &vm.Sandbox{ID: id, Status: vm.SandboxStatus{Phase: "Running", ExternalIP: "127.0.0.1"}}, nil + }, + }, config.VMConfig{Enabled: true, MaxPer: 3, CreateWindow: time.Minute, CreateMax: 1, CleanupInterval: time.Minute}) + + now := time.Now().UTC() + if err := vmRepo.Create(context.Background(), &models.VM{ + UserID: user.ID, + ChallengeID: challenge.ID, + VMID: "vm-list-get-1", + Status: "Pending", + CreatedAt: now, + UpdatedAt: now, + }); err != nil { + t.Fatalf("create vm: %v", err) + } + + rows, err := svc.ListUserVMs(context.Background(), user.ID) + if err != nil || len(rows) != 1 { + t.Fatalf("ListUserVMs err=%v len=%d", err, len(rows)) + } + + row, err := svc.GetVMByVMID(context.Background(), "vm-list-get-1") + if err != nil { + t.Fatalf("GetVMByVMID: %v", err) + } + + if row.Status != "Running" { + t.Fatalf("expected refreshed status Running, got %s", row.Status) + } +} + +func TestVMServiceDeletePaths(t *testing.T) { + env := setupServiceTest(t) + user := createUser(t, env, "vm-delete@example.com", "vm-delete", "pass", models.UserRole) + challenge := createVMChallenge(t, env, "vm-delete") + svc, vmRepo := newVMServiceForTest(env, &vm.MockClient{ + DeleteSandboxFn: func(ctx context.Context, id string) error { return vm.ErrNotFound }, + }, config.VMConfig{Enabled: true, MaxPer: 3, CreateWindow: time.Minute, CreateMax: 1, CleanupInterval: time.Minute}) + + now := time.Now().UTC() + if err := vmRepo.Create(context.Background(), &models.VM{ + UserID: user.ID, + ChallengeID: challenge.ID, + VMID: "vm-delete-1", + Status: "Running", + CreatedAt: now, + UpdatedAt: now, + }); err != nil { + t.Fatalf("create vm: %v", err) + } + + if err := svc.DeleteVM(context.Background(), user.ID, challenge.ID); err != nil { + t.Fatalf("DeleteVM: %v", err) + } + + if _, err := vmRepo.GetByVMID(context.Background(), "vm-delete-1"); !errors.Is(err, repo.ErrNotFound) { + t.Fatalf("expected deleted row, got %v", err) + } + + if err := vmRepo.Create(context.Background(), &models.VM{ + UserID: user.ID, + ChallengeID: challenge.ID, + VMID: "vm-delete-2", + Status: "Running", + CreatedAt: now, + UpdatedAt: now, + }); err != nil { + t.Fatalf("create vm2: %v", err) + } + + if err := svc.DeleteVMByVMID(context.Background(), "vm-delete-2"); err != nil { + t.Fatalf("DeleteVMByVMID: %v", err) + } +} + +func TestVMServiceStartTTLReaper(t *testing.T) { + env := setupServiceTest(t) + user := createUser(t, env, "vm-reaper@example.com", "vm-reaper", "pass", models.UserRole) + challenge := createVMChallenge(t, env, "vm-reaper") + svc, vmRepo := newVMServiceForTest(env, &vm.MockClient{}, config.VMConfig{Enabled: true, MaxPer: 3, CreateWindow: time.Minute, CreateMax: 1, CleanupInterval: time.Minute}) + + now := time.Now().UTC() + if err := vmRepo.Create(context.Background(), &models.VM{ + UserID: user.ID, + ChallengeID: challenge.ID, + VMID: "vm-reaper-expired", + Status: "Running", + TTLExpiresAt: ptrTime(now.Add(-time.Minute)), + CreatedAt: now, + UpdatedAt: now, + }); err != nil { + t.Fatalf("create vm: %v", err) + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + svc.StartTTLReaper(ctx, 10*time.Millisecond) + + deadline := time.Now().Add(500 * time.Millisecond) + for time.Now().Before(deadline) { + if _, err := vmRepo.GetByVMID(context.Background(), "vm-reaper-expired"); errors.Is(err, repo.ErrNotFound) { + return + } + + time.Sleep(20 * time.Millisecond) + } + + t.Fatal("expected StartTTLReaper to delete expired vm row") +} + +func TestVMServiceStartTTLReaperNonPositiveInterval(t *testing.T) { + env := setupServiceTest(t) + user := createUser(t, env, "vm-reaper-nonpos@example.com", "vm-reaper-nonpos", "pass", models.UserRole) + challenge := createVMChallenge(t, env, "vm-reaper-nonpos") + svc, vmRepo := newVMServiceForTest(env, &vm.MockClient{}, config.VMConfig{Enabled: true, MaxPer: 3, CreateWindow: time.Minute, CreateMax: 1, CleanupInterval: time.Minute}) + + now := time.Now().UTC() + if err := vmRepo.Create(context.Background(), &models.VM{ + UserID: user.ID, + ChallengeID: challenge.ID, + VMID: "vm-reaper-still-there", + Status: "Running", + TTLExpiresAt: ptrTime(now.Add(-time.Minute)), + CreatedAt: now, + UpdatedAt: now, + }); err != nil { + t.Fatalf("create vm: %v", err) + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + svc.StartTTLReaper(ctx, 0) + time.Sleep(50 * time.Millisecond) + + if _, err := vmRepo.GetByVMID(context.Background(), "vm-reaper-still-there"); err != nil { + t.Fatalf("expected vm row to remain with non-positive interval, got %v", err) + } +} + +func TestVMServiceHelpers(t *testing.T) { + if got := mapVMOrchestratorError(vm.ErrNotFound); !errors.Is(got, ErrVMNotFound) { + t.Fatalf("expected ErrVMNotFound, got %v", got) + } + + if got := mapVMOrchestratorError(vm.ErrInvalid); !errors.Is(got, ErrVMInvalidSpec) { + t.Fatalf("expected ErrVMInvalidSpec, got %v", got) + } + + if got := mapVMOrchestratorError(vm.ErrUnavailable); !errors.Is(got, ErrVMOrchestratorDown) { + t.Fatalf("expected ErrVMOrchestratorDown, got %v", got) + } + + ports := toVMPortMappings([]vm.PortMapping{{HostPort: 10000, ContainerPort: 31337, Protocol: "tcp"}}) + if len(ports) != 1 || ports[0].HostPort != 10000 { + t.Fatalf("unexpected mapped ports: %+v", ports) + } + + if toVMPortMappings(nil) != nil { + t.Fatal("expected nil for empty port mappings") + } + + if !isUniqueVMConflict(errors.New("duplicate key value violates unique constraint")) { + t.Fatal("expected unique conflict=true") + } + + if isUniqueVMConflict(errors.New("other error")) { + t.Fatal("expected unique conflict=false") + } +}