Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ VM_ORCH_BASE_URL=http://localhost:8082
VM_ORCH_TIMEOUT=5s
VM_CREATE_WINDOW=1m
VM_CREATE_MAX=1
VM_CLEANUP_INTERVAL=30m

# Logging
LOG_DIR=logs
Expand Down
4 changes: 4 additions & 0 deletions cmd/server/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,10 @@ func main() {
ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
defer stop()

if cfg.VM.Enabled {
vmSvc.StartTTLReaper(ctx, cfg.VM.CleanupInterval)
}

leaderboardBus := realtime.NewScoreboardBus(redisClient, cfg, scoreSvc, logger)
leaderboardBus.Start(ctx)

Expand Down
5 changes: 3 additions & 2 deletions docs/docs/vms.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@ Notes:
- VM APIs are backed by `sandboxd-orch` (HTTP API).
- Legacy Stack APIs remain available, but VM APIs are independent.
- For authenticated `POST`, `PUT`, `PATCH`, and `DELETE` requests, send both `csrf_token` cookie and matching `X-CSRF-Token` header.
- VM row deletion from wargame DB is explicit through user/admin delete endpoints.
- VM rows with expired `ttl_expires_at` are deleted automatically by a background cleanup loop (`VM_CLEANUP_INTERVAL`, default `30m`).
- VM rows are not deleted automatically after a correct flag submission.
- VM rows are not deleted automatically just because orchestrator status changed to error or missing.
- An orchestrator status of error does not, by itself, cause a VM row to be deleted. However, if a refresh/read finds that the orchestrator sandbox is missing, the VM row can be removed automatically. This is in addition to the TTL cleanup above; user/admin delete endpoints can also remove rows.

## VM Response Schema

Expand Down Expand Up @@ -128,6 +128,7 @@ Create behavior:
- Solved users can still create or recreate VMs for the same challenge.
- Wargame parses the YAML, rewrites `id` to generated `vm_id`, and sends the rewritten spec to orchestrator.
- If an existing VM row already exists for `(user_id, challenge_id)`, create returns the existing VM instead of creating a second one.
- Before create/limit checks, wargame removes the caller's expired VM rows from DB using `ttl_expires_at` (rows with `ttl_expires_at = null` are kept).
- User VM cap and VM creation rate limit are enforced server-side.

---
Expand Down
10 changes: 10 additions & 0 deletions internal/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ type VMConfig struct {
OrchestratorTimeout time.Duration
CreateWindow time.Duration
CreateMax int
CleanupInterval time.Duration
}

const defaultJWTSecret = "change-me"
Expand Down Expand Up @@ -297,6 +298,10 @@ func Load() (Config, error) {
if err != nil {
errs = append(errs, err)
}
vmCleanupInterval, err := getDuration("VM_CLEANUP_INTERVAL", 30*time.Minute)
if err != nil {
errs = append(errs, err)
}

cfg := Config{
AppEnv: appEnv,
Expand Down Expand Up @@ -385,6 +390,7 @@ func Load() (Config, error) {
OrchestratorTimeout: vmTimeout,
CreateWindow: vmCreateWindow,
CreateMax: vmCreateMax,
CleanupInterval: vmCleanupInterval,
},
}

Expand Down Expand Up @@ -592,6 +598,9 @@ func validateConfig(cfg Config) error {
if cfg.VM.CreateMax <= 0 {
errs = append(errs, errors.New("VM_CREATE_MAX must be positive"))
}
if cfg.VM.CleanupInterval <= 0 {
errs = append(errs, errors.New("VM_CLEANUP_INTERVAL must be positive"))
}
}

if len(errs) == 0 {
Expand Down Expand Up @@ -734,6 +743,7 @@ func FormatForLog(cfg Config) map[string]any {
"orchestrator_timeout": seconds(cfg.VM.OrchestratorTimeout),
"create_window": seconds(cfg.VM.CreateWindow),
"create_max": cfg.VM.CreateMax,
"cleanup_interval": seconds(cfg.VM.CleanupInterval),
},
}
}
Expand Down
67 changes: 67 additions & 0 deletions internal/config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,10 @@ func TestLoadConfigDefaults(t *testing.T) {
if cfg.Stack.CreateMax != 1 {
t.Errorf("expected Stack.CreateMax 1, got %d", cfg.Stack.CreateMax)
}

if cfg.VM.CleanupInterval != 30*time.Minute {
t.Errorf("expected VM.CleanupInterval 30m, got %v", cfg.VM.CleanupInterval)
}
}

func TestLoadConfigCustomValues(t *testing.T) {
Expand Down Expand Up @@ -121,6 +125,7 @@ func TestLoadConfigCustomValues(t *testing.T) {
os.Setenv("STACKS_PROVISIONER_TIMEOUT", "9s")
os.Setenv("STACKS_CREATE_WINDOW", "2m")
os.Setenv("STACKS_CREATE_MAX", "2")
os.Setenv("VM_CLEANUP_INTERVAL", "45m")

defer os.Clearenv()

Expand Down Expand Up @@ -223,6 +228,9 @@ func TestLoadConfigCustomValues(t *testing.T) {
if cfg.Stack.MaxPer != 5 {
t.Errorf("expected Stack.MaxPer 5, got %d", cfg.Stack.MaxPer)
}
if cfg.VM.CleanupInterval != 45*time.Minute {
t.Errorf("expected VM.CleanupInterval 45m, got %v", cfg.VM.CleanupInterval)
}
}

func TestLoadConfigInvalidValues(t *testing.T) {
Expand All @@ -248,6 +256,7 @@ func TestLoadConfigInvalidValues(t *testing.T) {
{"invalid s3 media force path", "S3_MEDIA_FORCE_PATH_STYLE", "bad-bool"},
{"invalid stack max scope", "STACKS_MAX_SCOPE", "org"},
{"invalid leaderboard cache ttl", "LEADERBOARD_CACHE_TTL", "bad-duration"},
{"invalid vm cleanup interval", "VM_CLEANUP_INTERVAL", "bad-duration"},
}

for _, tt := range tests {
Expand Down Expand Up @@ -722,6 +731,64 @@ func TestValidateConfigAdditionalValidation(t *testing.T) {
}
}

// TestValidateConfigInvalidVMConfig checks that validateConfig returns an
// error mentioning VM_CLEANUP_INTERVAL when the VM feature is enabled and
// CleanupInterval is zero.
func TestValidateConfigInvalidVMConfig(t *testing.T) {
	var cfg Config

	cfg.AppEnv = "local"
	cfg.HTTPAddr = ":8080"
	cfg.BcryptCost = bcrypt.DefaultCost

	cfg.DB = DBConfig{
		Host:            "localhost",
		Port:            5432,
		User:            "user",
		Name:            "db",
		MaxOpenConns:    10,
		MaxIdleConns:    5,
		ConnMaxLifetime: time.Minute,
	}
	cfg.Redis = RedisConfig{
		Addr:     "localhost:6379",
		PoolSize: 10,
	}
	cfg.JWT = JWTConfig{
		Secret:     "secret",
		Issuer:     "issuer",
		AccessTTL:  time.Hour,
		RefreshTTL: 24 * time.Hour,
	}
	cfg.Security = SecurityConfig{
		SubmissionWindow: time.Minute,
		SubmissionMax:    10,
	}
	cfg.Cache = CacheConfig{
		TimelineTTL:    time.Minute,
		LeaderboardTTL: time.Minute,
	}
	cfg.Logging = LoggingConfig{
		Dir:          "logs",
		FilePrefix:   "app",
		MaxBodyBytes: 1024,
	}

	// CleanupInterval is deliberately left at zero: it is the value under test.
	cfg.VM = VMConfig{
		Enabled:             true,
		MaxPer:              1,
		OrchestratorBaseURL: "http://localhost:8082",
		OrchestratorTimeout: time.Second,
		CreateWindow:        time.Minute,
		CreateMax:           1,
		CleanupInterval:     0,
	}

	validationErr := validateConfig(cfg)
	switch {
	case validationErr == nil:
		t.Fatal("expected vm validation error")
	case !strings.Contains(validationErr.Error(), "VM_CLEANUP_INTERVAL"):
		t.Fatalf("expected VM_CLEANUP_INTERVAL error, got %v", validationErr)
	}
}

func TestRedact(t *testing.T) {
cfg := Config{
DB: DBConfig{Password: "dbpass"},
Expand Down
19 changes: 19 additions & 0 deletions internal/repo/vm_repo.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package repo

import (
"context"
"time"

"wargame/internal/models"

Expand Down Expand Up @@ -125,3 +126,21 @@ func (r *VMRepo) Delete(ctx context.Context, vm *models.VM) error {

return nil
}

// DeleteExpired deletes every VM row whose ttl_expires_at is set and is at or
// before now. It returns the number of rows the statement affected; rows with
// a NULL ttl_expires_at are excluded by the query.
func (r *VMRepo) DeleteExpired(ctx context.Context, now time.Time) (int64, error) {
	stmt := r.db.NewDelete().
		Model((*models.VM)(nil)).
		Where("ttl_expires_at IS NOT NULL").
		Where("ttl_expires_at <= ?", now)

	result, execErr := stmt.Exec(ctx)
	if execErr != nil {
		return 0, wrapError("vmRepo.DeleteExpired", execErr)
	}

	removed, countErr := result.RowsAffected()
	if countErr != nil {
		return 0, wrapError("vmRepo.DeleteExpired rows affected", countErr)
	}

	return removed, nil
}
66 changes: 66 additions & 0 deletions internal/repo/vm_repo_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,3 +142,69 @@ func TestVMRepoNotFound(t *testing.T) {
t.Fatalf("expected ErrNotFound, got %v", err)
}
}

// TestVMRepoDeleteExpired seeds three VM rows (TTL in the past, TTL in the
// future, and no TTL) and verifies that DeleteExpired removes exactly the
// row whose TTL has already passed.
func TestVMRepoDeleteExpired(t *testing.T) {
	env := setupRepoTest(t)
	vms := NewVMRepo(env.db)
	ctx := context.Background()

	owner := createUserForTestUserScope(t, env, "vmexpired@example.com", "vmexpired", "pass", models.UserRole)
	chalExpired := createChallenge(t, env, "VM Expired 1", 100, "flag{vmexpired1}", true)
	chalActive := createChallenge(t, env, "VM Expired 2", 100, "flag{vmexpired2}", true)
	chalNoTTL := createChallenge(t, env, "VM Expired 3", 100, "flag{vmexpired3}", true)

	now := time.Now().UTC()
	seededAt := now.Add(-time.Hour)

	seed := []*models.VM{
		{
			// TTL one minute in the past: should be deleted.
			UserID:       owner.ID,
			ChallengeID:  chalExpired.ID,
			VMID:         "vm-expired",
			Status:       "Running",
			TTLExpiresAt: ptrTime(now.Add(-time.Minute)),
			CreatedAt:    seededAt,
			UpdatedAt:    seededAt,
		},
		{
			// TTL one hour in the future: should remain.
			UserID:       owner.ID,
			ChallengeID:  chalActive.ID,
			VMID:         "vm-active",
			Status:       "Running",
			TTLExpiresAt: ptrTime(now.Add(time.Hour)),
			CreatedAt:    seededAt,
			UpdatedAt:    seededAt,
		},
		{
			// No TTL set: should remain.
			UserID:      owner.ID,
			ChallengeID: chalNoTTL.ID,
			VMID:        "vm-nottl",
			Status:      "Running",
			CreatedAt:   seededAt,
			UpdatedAt:   seededAt,
		},
	}

	for _, vm := range seed {
		if err := vms.Create(ctx, vm); err != nil {
			t.Fatalf("create vm(%s): %v", vm.VMID, err)
		}
	}

	removed, err := vms.DeleteExpired(ctx, now)
	if err != nil {
		t.Fatalf("DeleteExpired: %v", err)
	}

	if removed != 1 {
		t.Fatalf("expected 1 deleted row, got %d", removed)
	}

	_, lookupErr := vms.GetByVMID(ctx, "vm-expired")
	if !errors.Is(lookupErr, ErrNotFound) {
		t.Fatalf("expected expired row deleted, got err=%v", lookupErr)
	}

	_, lookupErr = vms.GetByVMID(ctx, "vm-active")
	if lookupErr != nil {
		t.Fatalf("expected active row to remain, got %v", lookupErr)
	}

	_, lookupErr = vms.GetByVMID(ctx, "vm-nottl")
	if lookupErr != nil {
		t.Fatalf("expected null-ttl row to remain, got %v", lookupErr)
	}
}
66 changes: 66 additions & 0 deletions internal/service/vm_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
"log/slog"
"strconv"
"strings"
"time"
Expand Down Expand Up @@ -127,6 +128,10 @@ func (s *VMService) GetOrCreateVM(ctx context.Context, userID, challengeID int64
return nil, err
}

if err := s.cleanupExpiredUserVMs(ctx, userID); err != nil {
return nil, err
}

existing, err := s.findExistingVM(ctx, userID, challengeID)
if err != nil {
return nil, err
Expand Down Expand Up @@ -273,6 +278,67 @@ func (s *VMService) ensureUserLimit(ctx context.Context, userID int64) error {
return nil
}

// cleanupExpiredUserVMs lists the VM rows of userID and deletes each one whose
// TTLExpiresAt is set and is not after the current UTC time. Rows without a
// TTL are skipped. The first list or delete error aborts the pass and is
// returned wrapped.
func (s *VMService) cleanupExpiredUserVMs(ctx context.Context, userID int64) error {
	vms, listErr := s.vmRepo.ListByUser(ctx, userID)
	if listErr != nil {
		return fmt.Errorf("vm.cleanupExpiredUserVMs list: %w", listErr)
	}

	cutoff := time.Now().UTC()
	for idx := range vms {
		vm := &vms[idx]

		// Keep rows that never expire or have not expired yet.
		if vm.TTLExpiresAt == nil || vm.TTLExpiresAt.After(cutoff) {
			continue
		}

		if delErr := s.vmRepo.Delete(ctx, vm); delErr != nil {
			return fmt.Errorf("vm.cleanupExpiredUserVMs delete: %w", delErr)
		}
	}

	return nil
}

// CleanupExpiredVMs asks the VM repository to delete all rows that are expired
// as of the current UTC time and returns the number of rows it reported as
// deleted. A repository error is returned wrapped, with a zero count.
func (s *VMService) CleanupExpiredVMs(ctx context.Context) (int64, error) {
	cutoff := time.Now().UTC()

	removed, repoErr := s.vmRepo.DeleteExpired(ctx, cutoff)
	if repoErr != nil {
		return 0, fmt.Errorf("vm.CleanupExpiredVMs: %w", repoErr)
	}

	return removed, nil
}

// StartTTLReaper starts a background goroutine that calls CleanupExpiredVMs
// once per interval until ctx is cancelled, then returns immediately.
//
// A non-positive interval disables the reaper: a warning is logged and no
// goroutine is started. The first sweep runs one interval after the call, not
// at startup. A failed sweep is logged and retried on the next tick; a
// completion message is logged only when at least one row was deleted.
func (s *VMService) StartTTLReaper(ctx context.Context, interval time.Duration) {
	// time.NewTicker panics on a non-positive duration, so bail out first.
	if interval <= 0 {
		slog.Warn("vm ttl reaper disabled due to non-positive interval", slog.Duration("interval", interval))
		return
	}

	ticker := time.NewTicker(interval)
	go func() {
		defer ticker.Stop()
		for {
			select {
			case <-ctx.Done():
				return
			case <-ticker.C:
				deleted, err := s.CleanupExpiredVMs(ctx)
				if err != nil {
					// If ctx is done, the sweep was interrupted by shutdown
					// (or select picked the tick while ctx was already
					// cancelled). That is not an operational failure, so exit
					// quietly instead of logging a misleading warning.
					if ctx.Err() != nil {
						return
					}

					slog.Warn("vm ttl cleanup failed", slog.Any("error", err))
					continue
				}

				if deleted > 0 {
					slog.Info("vm ttl cleanup completed", slog.Int64("deleted", deleted))
				}
			}
		}
	}()
}

func (s *VMService) createVM(ctx context.Context, userID, challengeID int64, spec string) (*models.VM, error) {
vmID := newVMID(userID, challengeID)
sandbox, err := s.client.CreateSandbox(ctx, vmID, spec)
Expand Down
Loading
Loading