Skip to content
Merged
8 changes: 6 additions & 2 deletions backend/internal/cli/dto_drift_e2e_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,10 @@ func (f *fakeSessionService) Spawn(_ context.Context, cfg ports.SpawnConfig) (do
}, nil
}

func (f *fakeSessionService) SpawnOrchestrator(ctx context.Context, projectID domain.ProjectID, _ bool) (domain.Session, error) {
return f.Spawn(ctx, ports.SpawnConfig{ProjectID: projectID, Kind: domain.KindOrchestrator})
}

func (f *fakeSessionService) Get(context.Context, domain.SessionID) (domain.Session, error) {
return domain.Session{}, nil
}
Expand Down Expand Up @@ -124,10 +128,10 @@ func startDriftTestDaemon(t *testing.T, sessions controllers.SessionService, pro
t.Helper()

log := slog.New(slog.NewTextHandler(io.Discard, nil))
router := httpd.NewRouterWithAPI(config.Config{}, log, nil, httpd.APIDeps{
router := httpd.NewRouterWithControl(config.Config{}, log, nil, httpd.APIDeps{
Sessions: sessions,
Projects: projects,
})
}, httpd.ControlDeps{})
srv := httptest.NewServer(router)
t.Cleanup(srv.Close)

Expand Down
66 changes: 66 additions & 0 deletions backend/internal/httpd/apierr/apierr.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
// Package apierr defines the REST API's error vocabulary: a single structured
// error type every service returns and the controllers render into the locked
// APIError envelope with one errors.As. It is deliberately scoped to the HTTP
// API tree — these services exist to serve the daemon's REST surface — and
// imports nothing, so any layer may depend on it without an import cycle.
package apierr

// Kind is a semantic failure category. It is not an HTTP status or word: the
// envelope layer is the only place a Kind is translated into one.
type Kind int

const (
// KindInternal is an unexpected failure; it maps to 500. As iota's zero
// value it is also the Kind of a zero-value Error, so an Error built without
// a Kind safely defaults to a 500.
KindInternal Kind = iota
// KindInvalid is malformed or rejected input; it maps to 400.
KindInvalid
// KindNotFound is a missing resource; it maps to 404.
KindNotFound
// KindConflict is a state/uniqueness clash; it maps to 409.
KindConflict
)

// Error is the structured error every service returns. Code is a stable machine
// identifier (e.g. "SESSION_NOT_FOUND"); Message is the human-facing text. It
// reaches the controller through fmt.Errorf("...: %w", err) wrapping and is
// matched there with errors.As.
type Error struct {
Kind Kind
Code string
Message string
Details map[string]any
}

func (e *Error) Error() string {
if e == nil {
return ""
}
return e.Message
}

// New builds an Error from its parts.
func New(kind Kind, code, message string, details map[string]any) *Error {
return &Error{Kind: kind, Code: code, Message: message, Details: details}
}

// Invalid is a 400-class error.
func Invalid(code, message string, details map[string]any) *Error {
return New(KindInvalid, code, message, details)
}

// NotFound is a 404-class error.
func NotFound(code, message string) *Error {
return New(KindNotFound, code, message, nil)
}

// Conflict is a 409-class error.
func Conflict(code, message string, details map[string]any) *Error {
return New(KindConflict, code, message, details)
}

// Internal is a 500-class error.
func Internal(code, message string) *Error {
return New(KindInternal, code, message, nil)
}
2 changes: 1 addition & 1 deletion backend/internal/httpd/apispec/parity_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ import (
// spec coverage, and the spec can't describe a route that isn't served.
func TestRouteSpecParity(t *testing.T) {
log := slog.New(slog.NewTextHandler(io.Discard, nil))
router := httpd.NewRouter(config.Config{}, log, nil)
router := httpd.NewRouterWithControl(config.Config{}, log, nil, httpd.APIDeps{}, httpd.ControlDeps{})

mounted := map[string]bool{}
err := chi.Walk(router, func(method, route string, _ http.Handler, _ ...func(http.Handler) http.Handler) error {
Expand Down
33 changes: 4 additions & 29 deletions backend/internal/httpd/controllers/projects.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ package controllers

import (
"encoding/json"
"errors"
"net/http"

"github.com/go-chi/chi/v5"
Expand Down Expand Up @@ -38,7 +37,7 @@ func (c *ProjectsController) list(w http.ResponseWriter, r *http.Request) {
}
projects, err := c.Mgr.List(r.Context())
if err != nil {
writeProjectError(w, r, err)
envelope.WriteError(w, r, err)
return
}
if projects == nil {
Expand All @@ -59,7 +58,7 @@ func (c *ProjectsController) add(w http.ResponseWriter, r *http.Request) {
}
p, err := c.Mgr.Add(r.Context(), in)
if err != nil {
writeProjectError(w, r, err)
envelope.WriteError(w, r, err)
return
}
envelope.WriteJSON(w, http.StatusCreated, ProjectResponse{Project: p})
Expand All @@ -72,7 +71,7 @@ func (c *ProjectsController) get(w http.ResponseWriter, r *http.Request) {
}
got, err := c.Mgr.Get(r.Context(), projectID(r))
if err != nil {
writeProjectError(w, r, err)
envelope.WriteError(w, r, err)
return
}
resp, err := newGetProjectResponse(got)
Expand All @@ -90,7 +89,7 @@ func (c *ProjectsController) remove(w http.ResponseWriter, r *http.Request) {
}
result, err := c.Mgr.Remove(r.Context(), projectID(r))
if err != nil {
writeProjectError(w, r, err)
envelope.WriteError(w, r, err)
return
}
envelope.WriteJSON(w, http.StatusOK, result)
Expand All @@ -103,27 +102,3 @@ func projectID(r *http.Request) domain.ProjectID {
func decodeJSON(r *http.Request, out any) error {
return json.NewDecoder(r.Body).Decode(out)
}

// writeProjectError maps a projectsvc.Error to its HTTP status, falling back to
// 500 for an unrecognized kind or a non-projectsvc.Error.
func writeProjectError(w http.ResponseWriter, r *http.Request, err error) {
var pe *projectsvc.Error
if errors.As(err, &pe) {
status := http.StatusInternalServerError
switch pe.Kind {
case "bad_request":
status = http.StatusBadRequest
case "not_found":
status = http.StatusNotFound
case "conflict":
status = http.StatusConflict
case "not_implemented":
status = http.StatusNotImplemented
case "internal":
status = http.StatusInternalServerError
}
envelope.WriteAPIError(w, r, status, pe.Kind, pe.Code, pe.Message, pe.Details)
return
}
envelope.WriteAPIError(w, r, http.StatusInternalServerError, "internal", "INTERNAL_ERROR", "Internal server error", nil)
}
10 changes: 5 additions & 5 deletions backend/internal/httpd/controllers/projects_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,10 @@ func TestProjectsAPI_GetEmptyResultIs500(t *testing.T) {

log := slog.New(slog.NewTextHandler(io.Discard, nil))

srv := httptest.NewServer(httpd.NewRouterWithAPI(config.Config{}, log, nil, httpd.APIDeps{
srv := httptest.NewServer(httpd.NewRouterWithControl(config.Config{}, log, nil, httpd.APIDeps{

Projects: emptyGetManager{},
}))
}, httpd.ControlDeps{}))

t.Cleanup(srv.Close)

Expand Down Expand Up @@ -89,10 +89,10 @@ func newTestServer(t *testing.T) *httptest.Server {

t.Cleanup(func() { _ = store.Close() })

srv := httptest.NewServer(httpd.NewRouterWithAPI(config.Config{}, log, nil, httpd.APIDeps{
srv := httptest.NewServer(httpd.NewRouterWithControl(config.Config{}, log, nil, httpd.APIDeps{

Projects: projectsvc.New(store),
}))
}, httpd.ControlDeps{}))

t.Cleanup(srv.Close)

Expand All @@ -104,7 +104,7 @@ func TestProjectsRoutes_DefaultToStubsWithoutManager(t *testing.T) {

log := slog.New(slog.NewTextHandler(io.Discard, nil))

srv := httptest.NewServer(httpd.NewRouter(config.Config{}, log, nil))
srv := httptest.NewServer(httpd.NewRouterWithControl(config.Config{}, log, nil, httpd.APIDeps{}, httpd.ControlDeps{}))

t.Cleanup(srv.Close)

Expand Down
2 changes: 1 addition & 1 deletion backend/internal/httpd/controllers/prs_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func (f *fakePRService) ResolveComments(_ context.Context, _ string, _ []string)
func newPRTestServer(t *testing.T, svc prsvc.ActionManager) *httptest.Server {
t.Helper()
log := slog.New(slog.NewTextHandler(io.Discard, nil))
srv := httptest.NewServer(httpd.NewRouterWithAPI(config.Config{}, log, nil, httpd.APIDeps{PRs: svc}))
srv := httptest.NewServer(httpd.NewRouterWithControl(config.Config{}, log, nil, httpd.APIDeps{PRs: svc}, httpd.ControlDeps{}))
t.Cleanup(srv.Close)
return srv
}
Expand Down
59 changes: 14 additions & 45 deletions backend/internal/httpd/controllers/sessions.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ import (
"github.com/aoagents/agent-orchestrator/backend/internal/httpd/envelope"
"github.com/aoagents/agent-orchestrator/backend/internal/ports"
sessionsvc "github.com/aoagents/agent-orchestrator/backend/internal/service/session"
sessionmanager "github.com/aoagents/agent-orchestrator/backend/internal/session_manager"
)

const (
Expand All @@ -27,6 +26,7 @@ const (
type SessionService interface {
List(ctx context.Context, filter sessionsvc.ListFilter) ([]domain.Session, error)
Spawn(ctx context.Context, cfg ports.SpawnConfig) (domain.Session, error)
SpawnOrchestrator(ctx context.Context, projectID domain.ProjectID, clean bool) (domain.Session, error)
Get(ctx context.Context, id domain.SessionID) (domain.Session, error)
Restore(ctx context.Context, id domain.SessionID) (domain.Session, error)
Kill(ctx context.Context, id domain.SessionID) (bool, error)
Expand Down Expand Up @@ -68,7 +68,7 @@ func (c *SessionsController) list(w http.ResponseWriter, r *http.Request) {
}
sessions, err := c.Svc.List(r.Context(), filter)
if err != nil {
writeSessionError(w, r, err)
envelope.WriteError(w, r, err)
return
}
envelope.WriteJSON(w, http.StatusOK, ListSessionsResponse{Sessions: sessions})
Expand Down Expand Up @@ -97,7 +97,7 @@ func (c *SessionsController) spawn(w http.ResponseWriter, r *http.Request) {
}
sess, err := c.Svc.Spawn(r.Context(), ports.SpawnConfig{ProjectID: in.ProjectID, IssueID: in.IssueID, Kind: in.Kind, Harness: in.Harness, Branch: in.Branch, Prompt: in.Prompt, AgentRules: in.AgentRules})
if err != nil {
writeSessionError(w, r, err)
envelope.WriteError(w, r, err)
return
}
envelope.WriteJSON(w, http.StatusCreated, SessionResponse{Session: sess})
Expand All @@ -110,7 +110,7 @@ func (c *SessionsController) get(w http.ResponseWriter, r *http.Request) {
}
sess, err := c.Svc.Get(r.Context(), sessionID(r))
if err != nil {
writeSessionError(w, r, err)
envelope.WriteError(w, r, err)
return
}
envelope.WriteJSON(w, http.StatusOK, SessionResponse{Session: sess})
Expand All @@ -132,7 +132,7 @@ func (c *SessionsController) rename(w http.ResponseWriter, r *http.Request) {
return
}
if err := c.Svc.Rename(r.Context(), sessionID(r), displayName); err != nil {
writeSessionError(w, r, err)
envelope.WriteError(w, r, err)
return
}
envelope.WriteJSON(w, http.StatusOK, RenameSessionResponse{OK: true, SessionID: sessionID(r), DisplayName: displayName})
Expand All @@ -145,7 +145,7 @@ func (c *SessionsController) restore(w http.ResponseWriter, r *http.Request) {
}
sess, err := c.Svc.Restore(r.Context(), sessionID(r))
if err != nil {
writeSessionError(w, r, err)
envelope.WriteError(w, r, err)
return
}
envelope.WriteJSON(w, http.StatusOK, RestoreSessionResponse{OK: true, SessionID: sessionID(r), Session: sess})
Expand All @@ -158,7 +158,7 @@ func (c *SessionsController) kill(w http.ResponseWriter, r *http.Request) {
}
freed, err := c.Svc.Kill(r.Context(), sessionID(r))
if err != nil {
writeSessionError(w, r, err)
envelope.WriteError(w, r, err)
return
}
envelope.WriteJSON(w, http.StatusOK, KillSessionResponse{OK: true, SessionID: sessionID(r), Freed: freed})
Expand All @@ -171,7 +171,7 @@ func (c *SessionsController) cleanup(w http.ResponseWriter, r *http.Request) {
}
cleaned, err := c.Svc.Cleanup(r.Context(), domain.ProjectID(r.URL.Query().Get("project")))
if err != nil {
writeSessionError(w, r, err)
envelope.WriteError(w, r, err)
return
}
envelope.WriteJSON(w, http.StatusOK, CleanupSessionsResponse{OK: true, Cleaned: cleaned})
Expand All @@ -197,7 +197,7 @@ func (c *SessionsController) send(w http.ResponseWriter, r *http.Request) {
}
message := stripUnsafeControlChars(in.Message)
if err := c.Svc.Send(r.Context(), sessionID(r), message); err != nil {
writeSessionError(w, r, err)
envelope.WriteError(w, r, err)
return
}
envelope.WriteJSON(w, http.StatusOK, SendSessionMessageResponse{OK: true, SessionID: sessionID(r), Message: message})
Expand All @@ -217,23 +217,9 @@ func (c *SessionsController) spawnOrchestrator(w http.ResponseWriter, r *http.Re
envelope.WriteAPIError(w, r, http.StatusBadRequest, "bad_request", "PROJECT_ID_REQUIRED", "projectId is required", nil)
return
}
if in.Clean {
active := true
orchestrators, err := c.Svc.List(r.Context(), sessionsvc.ListFilter{ProjectID: in.ProjectID, Active: &active, OrchestratorOnly: true})
if err != nil {
writeSessionError(w, r, err)
return
}
for _, existing := range orchestrators {
if _, err := c.Svc.Kill(r.Context(), existing.ID); err != nil {
writeSessionError(w, r, err)
return
}
}
}
sess, err := c.Svc.Spawn(r.Context(), ports.SpawnConfig{ProjectID: in.ProjectID, Kind: domain.KindOrchestrator})
sess, err := c.Svc.SpawnOrchestrator(r.Context(), in.ProjectID, in.Clean)
if err != nil {
writeSessionError(w, r, err)
envelope.WriteError(w, r, err)
return
}
envelope.WriteJSON(w, http.StatusCreated, SpawnOrchestratorResponse{
Expand All @@ -248,7 +234,7 @@ func (c *SessionsController) listOrchestrators(w http.ResponseWriter, r *http.Re
}
sessions, err := c.Svc.List(r.Context(), sessionsvc.ListFilter{OrchestratorOnly: true})
if err != nil {
writeSessionError(w, r, err)
envelope.WriteError(w, r, err)
return
}
envelope.WriteJSON(w, http.StatusOK, ListSessionsResponse{Sessions: sessions})
Expand All @@ -261,11 +247,11 @@ func (c *SessionsController) getOrchestrator(w http.ResponseWriter, r *http.Requ
}
sess, err := c.Svc.Get(r.Context(), orchestratorID(r))
if err != nil {
writeSessionError(w, r, err)
envelope.WriteError(w, r, err)
return
}
if sess.Kind != domain.KindOrchestrator {
writeSessionError(w, r, sessionmanager.ErrNotFound)
envelope.WriteAPIError(w, r, http.StatusNotFound, "not_found", "SESSION_NOT_FOUND", "Unknown session", nil)
return
}
envelope.WriteJSON(w, http.StatusOK, SessionResponse{Session: sess})
Expand Down Expand Up @@ -314,20 +300,3 @@ func stripUnsafeControlChars(message string) string {
return r
}, message)
}

func writeSessionError(w http.ResponseWriter, r *http.Request, err error) {
switch {
case errors.Is(err, sessionmanager.ErrNotFound):
envelope.WriteAPIError(w, r, http.StatusNotFound, "not_found", "SESSION_NOT_FOUND", "Unknown session", nil)
case errors.Is(err, sessionmanager.ErrNotRestorable):
envelope.WriteAPIError(w, r, http.StatusConflict, "conflict", "SESSION_NOT_RESTORABLE", "Session is not restorable", nil)
case errors.Is(err, sessionmanager.ErrTerminated):
envelope.WriteAPIError(w, r, http.StatusConflict, "conflict", "SESSION_TERMINATED", "Session is terminated", nil)
case errors.Is(err, sessionmanager.ErrIncompleteHandle):
envelope.WriteAPIError(w, r, http.StatusConflict, "conflict", "SESSION_INCOMPLETE_HANDLE", "Session is missing runtime or workspace handles", nil)
case errors.Is(err, sessionmanager.ErrProjectNotResolvable):
envelope.WriteAPIError(w, r, http.StatusBadRequest, "bad_request", "PROJECT_NOT_RESOLVABLE", "Project is not registered or has no repo — register it with `ao project add`", nil)
default:
envelope.WriteAPIError(w, r, http.StatusInternalServerError, "internal", "SESSION_OPERATION_FAILED", "Session operation failed", nil)
}
}
Loading
Loading