diff --git a/runs/repository/impl/action.go b/runs/repository/impl/action.go
index 53e142e882..79326cb0de 100644
--- a/runs/repository/impl/action.go
+++ b/runs/repository/impl/action.go
@@ -356,6 +356,49 @@ func (r *actionRepo) ListActions(ctx context.Context, input interfaces.ListResou
 	return actions, nil
 }
 
+// ListActionPhasesForCounts fetches a slim projection of every action in a run:
+// just name, parent_action_name, phase and created_at -- the columns needed to
+// seed the run-state tree's child phase counts. The large bytea columns
+// (action_spec, action_details, ...) are deliberately left out so the whole run
+// loads in one fast query.
+//
+// Why it exists: WatchActions otherwise has to stream every child in order to
+// count them, and the console's stream deadline truncates that on big map tasks.
+// Seeding the counts up front from this query keeps the aggregate correct from
+// the first streamed page.
+//
+// Rows come back ordered by created_at ASC so that parents are seen before their
+// children (the run-state manager requires the parent node to exist when a child
+// is inserted); name ASC is a deterministic tiebreaker among equal created_at.
+//
+// NOTE(review): the name tiebreak makes the order deterministic but does not by
+// itself put a parent ahead of a child that shares its created_at -- confirm
+// that cannot happen.
+//
+// ponytail: loads all rows for the run into memory. Fine to ~100k actions; beyond
+// that switch to a recursive SQL CTE or paginate.
+func (r *actionRepo) ListActionPhasesForCounts(ctx context.Context, runID *common.RunIdentifier) ([]*models.Action, error) {
+	filter, err := NewRunActionsFilter(runID).QueryExpression("")
+	if err != nil {
+		return nil, fmt.Errorf("failed to build filter expression: %w", err)
+	}
+
+	// Only the WHERE clause varies; the projection and ordering are fixed.
+	const (
+		selectClause = "SELECT name, parent_action_name, phase, created_at FROM actions WHERE "
+		orderClause  = " ORDER BY created_at ASC, name ASC"
+	)
+	stmt := sqlx.Rebind(sqlx.DOLLAR, selectClause+filter.Query+orderClause)
+
+	var rows []*models.Action
+	err = sqlx.SelectContext(ctx, r.db, &rows, stmt, filter.Args...)
+	if err != nil {
+		return nil, fmt.Errorf("failed to list action phases: %w", err)
+	}
+
+	return rows, nil
+}
+
 // UpdateActionPhase updates the phase of an action.
 // endTime should be set when the action reaches a terminal phase.
 func (r *actionRepo) UpdateActionPhase(
diff --git a/runs/repository/impl/action_test.go b/runs/repository/impl/action_test.go
index 09ff040eab..58902300e9 100644
--- a/runs/repository/impl/action_test.go
+++ b/runs/repository/impl/action_test.go
@@ -695,3 +695,83 @@ func TestUpdateActionPhase_AbortedDoesNotInsertEvent(t *testing.T) {
 	require.NoError(t, err)
 	assert.Equal(t, 0, count, "UpdateActionPhase(ABORTED) must not insert a synthetic action_events row")
 }
+
+// TestListActionPhasesForCounts exercises the lightweight query that seeds child
+// phase counts: it must return every action in the run (root + children) with
+// name, parent and phase populated, ordered so parents precede their children.
+func TestListActionPhasesForCounts(t *testing.T) {
+	db := setupActionDB(t)
+	defer func() { db.Exec("DELETE FROM actions") }()
+
+	repo, err := NewActionRepo(db, testDbConfig)
+	require.NoError(t, err)
+
+	ctx := context.Background()
+	runID := &common.RunIdentifier{Project: "proj1", Domain: "domain1", Name: "run1"}
+	epoch := time.Unix(1700000000, 0)
+
+	// The root is created first and has no parent (parent_action_name NULL).
+	root := &models.Action{
+		Project:   runID.Project,
+		Domain:    runID.Domain,
+		RunName:   runID.Name,
+		Name:      rootActionName,
+		Phase:     int32(common.ActionPhase_ACTION_PHASE_RUNNING),
+		CreatedAt: epoch,
+	}
+	_, err = repo.CreateAction(ctx, root, false)
+	require.NoError(t, err)
+
+	// Children hang off the root: even indexes SUCCEEDED, odd indexes QUEUED,
+	// each created one second after the previous row.
+	const children = 120
+	want := map[common.ActionPhase]int{}
+	for i := 0; i < children; i++ {
+		phase := common.ActionPhase_ACTION_PHASE_SUCCEEDED
+		if i%2 != 0 {
+			phase = common.ActionPhase_ACTION_PHASE_QUEUED
+		}
+		want[phase]++
+
+		child := &models.Action{
+			Project:          runID.Project,
+			Domain:           runID.Domain,
+			RunName:          runID.Name,
+			Name:             fmt.Sprintf("c%03d", i),
+			ParentActionName: sql.NullString{String: rootActionName, Valid: true},
+			Phase:            int32(phase),
+			CreatedAt:        epoch.Add(time.Duration(i+1) * time.Second),
+		}
+		_, createErr := repo.CreateAction(ctx, child, false)
+		require.NoError(t, createErr)
+	}
+
+	rows, err := repo.ListActionPhasesForCounts(ctx, runID)
+	require.NoError(t, err)
+	require.Len(t, rows, children+1, "must return every action in the run")
+
+	// The root sorts first (earliest created_at) and carries no parent.
+	first := rows[0]
+	require.Equal(t, rootActionName, first.Name)
+	require.False(t, first.ParentActionName.Valid, "root has no parent")
+
+	// Every other row must point at the root; tally them by phase.
+	got := map[common.ActionPhase]int{}
+	for _, row := range rows {
+		if row.Name == rootActionName {
+			continue
+		}
+		require.True(t, row.ParentActionName.Valid)
+		require.Equal(t, rootActionName, row.ParentActionName.String)
+		got[common.ActionPhase(row.Phase)]++
+	}
+	assert.Equal(t, want[common.ActionPhase_ACTION_PHASE_QUEUED], got[common.ActionPhase_ACTION_PHASE_QUEUED])
+	assert.Equal(t, want[common.ActionPhase_ACTION_PHASE_SUCCEEDED], got[common.ActionPhase_ACTION_PHASE_SUCCEEDED])
+
+	// Ordering: created_at never decreases (parents-before-children invariant).
+	prev := rows[0]
+	for _, cur := range rows[1:] {
+		require.False(t, cur.CreatedAt.Before(prev.CreatedAt), "rows must be ordered by created_at ASC")
+		prev = cur
+	}
+}
diff --git a/runs/repository/interfaces/action.go b/runs/repository/interfaces/action.go
index 4879e42f41..46a5929ba7 100644
--- a/runs/repository/interfaces/action.go
+++ b/runs/repository/interfaces/action.go
@@ -26,6 +26,10 @@ type ActionRepo interface {
 	GetLatestEventByAttempt(ctx context.Context, actionID *common.ActionIdentifier, attempt uint32) (*models.ActionEvent, error)
 	GetAction(ctx context.Context, actionID *common.ActionIdentifier) (*models.Action, error)
 	ListActions(ctx context.Context, input ListResourceInput) ([]*models.Action, error)
+	// ListActionPhasesForCounts returns lightweight rows (name, parent, phase) for
+	// every action in a run, used to seed child phase counts without streaming all
+	// children. See the impl for details.
+	ListActionPhasesForCounts(ctx context.Context, runID *common.RunIdentifier) ([]*models.Action, error)
 	UpdateActionPhase(ctx context.Context, actionID *common.ActionIdentifier, phase common.ActionPhase, attempts uint32, cacheStatus core.CatalogCacheStatus, endTime *time.Time) error
 	// AbortAction marks only the targeted action as ABORTED and sets abort_requested_at.
 	// K8s cascades CRD deletion to descendants via OwnerReferences.
diff --git a/runs/repository/mocks/mocks.go b/runs/repository/mocks/mocks.go
index 3b3b8d3af7..8b67c552f8 100644
--- a/runs/repository/mocks/mocks.go
+++ b/runs/repository/mocks/mocks.go
@@ -644,6 +644,74 @@ func (_c *ActionRepo_InsertEvents_Call) RunAndReturn(run func(ctx context.Contex
 	return _c
 }
 
+// ListActionPhasesForCounts provides a mock function for the type ActionRepo. It panics when no return values were configured for the call.
+func (_mock *ActionRepo) ListActionPhasesForCounts(ctx context.Context, runID *common.RunIdentifier) ([]*models.Action, error) {
+	ret := _mock.Called(ctx, runID)
+
+	if len(ret) == 0 {
+		panic("no return value specified for ListActionPhasesForCounts")
+	}
+
+	var r0 []*models.Action
+	var r1 error
+	if returnFunc, ok := ret.Get(0).(func(context.Context, *common.RunIdentifier) ([]*models.Action, error)); ok {
+		return returnFunc(ctx, runID) // one function supplies both results (this is what RunAndReturn stores)
+	}
+	if returnFunc, ok := ret.Get(0).(func(context.Context, *common.RunIdentifier) []*models.Action); ok {
+		r0 = returnFunc(ctx, runID)
+	} else {
+		if ret.Get(0) != nil {
+			r0 = ret.Get(0).([]*models.Action)
+		}
+	}
+	if returnFunc, ok := ret.Get(1).(func(context.Context, *common.RunIdentifier) error); ok {
+		r1 = returnFunc(ctx, runID)
+	} else {
+		r1 = ret.Error(1)
+	}
+	return r0, r1
+}
+
+// ActionRepo_ListActionPhasesForCounts_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListActionPhasesForCounts'
+type ActionRepo_ListActionPhasesForCounts_Call struct {
+	*mock.Call
+}
+
+// ListActionPhasesForCounts is a helper method to define mock.On call
+//   - ctx context.Context
+//   - runID *common.RunIdentifier
+func (_e *ActionRepo_Expecter) ListActionPhasesForCounts(ctx interface{}, runID interface{}) *ActionRepo_ListActionPhasesForCounts_Call {
+	return &ActionRepo_ListActionPhasesForCounts_Call{Call: _e.mock.On("ListActionPhasesForCounts", ctx, runID)}
+}
+
+func (_c *ActionRepo_ListActionPhasesForCounts_Call) Run(run func(ctx context.Context, runID *common.RunIdentifier)) *ActionRepo_ListActionPhasesForCounts_Call {
+	_c.Call.Run(func(args mock.Arguments) {
+		var arg0 context.Context // stays nil when args[0] is nil
+		if args[0] != nil {
+			arg0 = args[0].(context.Context)
+		}
+		var arg1 *common.RunIdentifier // stays nil when args[1] is nil
+		if args[1] != nil {
+			arg1 = args[1].(*common.RunIdentifier)
+		}
+		run(
+			arg0,
+			arg1,
+		)
+	})
+	return _c
+}
+
+func (_c *ActionRepo_ListActionPhasesForCounts_Call) Return(actions []*models.Action, err error) *ActionRepo_ListActionPhasesForCounts_Call {
+	_c.Call.Return(actions, err)
+	return _c
+}
+
+func (_c *ActionRepo_ListActionPhasesForCounts_Call) RunAndReturn(run func(ctx context.Context, runID *common.RunIdentifier) ([]*models.Action, error)) *ActionRepo_ListActionPhasesForCounts_Call {
+	_c.Call.Return(run) // the func itself is the stored return value; the mock method above calls it
+	return _c
+}
+
 // ListActions provides a mock function for the type ActionRepo
 func (_mock *ActionRepo) ListActions(ctx context.Context, input interfaces.ListResourceInput) ([]*models.Action, error) {
 	ret := _mock.Called(ctx, input)
diff --git a/runs/service/run_service.go b/runs/service/run_service.go
index 21be8e208f..0ed826c0f0 100644
--- a/runs/service/run_service.go
+++ b/runs/service/run_service.go
@@ -1307,6 +1307,23 @@ func (s *RunService) listAndSendAllActions(
 	rsm *runStateManager,
 	stream *connect.ServerStream[workflow.WatchActionsResponse],
 ) error {
+	// Seed child phase counts from a single lightweight query before streaming the
+	// full snapshot. Otherwise ChildPhaseCounts is built up as every child streams
+	// in (~25s for a 20k map task), and the console's stream deadline truncates it
+	// mid-climb -- showing a count far below the real total. Seeding makes the count
+	// correct from the first streamed page; re-streaming the same rows below is
+	// count-neutral (same phase => no-op in modifyPhaseCounters), so nothing
+	// regresses. Mirrors cloud's SQL-aggregated phase counts.
+	seed, err := s.repo.ActionRepo().ListActionPhasesForCounts(ctx, runID)
+	if err != nil {
+		return err
+	}
+	// State only -- the per-node updates are re-sent by the streaming loop below
+	// (with full action data), so we discard them here.
+	if _, err := rsm.upsertActions(ctx, seed); err != nil {
+		return err
+	}
+
 	const pageSize = 100
 	offset := 0
 	for {
diff --git a/runs/service/run_service_test.go b/runs/service/run_service_test.go
index f56d0ba5c3..68537e1fa8 100644
--- a/runs/service/run_service_test.go
+++ b/runs/service/run_service_test.go
@@ -795,6 +795,9 @@ func TestListAndSendAllActionsUsesAscendingSort(t *testing.T) {
 
 	runID := &common.RunIdentifier{Project: "p", Domain: "d", Name: "run-1"}
 
+	// The seed query runs first; return no rows so the streaming loop below is exercised.
+	actionRepo.On("ListActionPhasesForCounts", mock.Anything, mock.Anything).Return([]*models.Action{}, nil).Once()
+
 	var captured interfaces.ListResourceInput
 	actionRepo.On("ListActions", mock.Anything, mock.MatchedBy(func(input interfaces.ListResourceInput) bool {
 		captured = input
diff --git a/runs/service/run_state_manager_test.go b/runs/service/run_state_manager_test.go
index 194748f8f1..c081f662c7 100644
--- a/runs/service/run_state_manager_test.go
+++ b/runs/service/run_state_manager_test.go
@@ -101,6 +101,70 @@ func TestRunStateManagerErrorsWhenParentMissing(t *testing.T) {
 	require.Nil(t, rsm.GetActionTreeNodeByName("child"))
 }
 
+// TestRunStateManagerSeedThenRestreamIsCountNeutral pins the invariant the
+// aggregate-count fix relies on. Child phase counts are seeded up front from a
+// lightweight query (so the count is correct from the first streamed page), and
+// the full snapshot then re-streams the same rows. Upserting a node again with
+// an unchanged phase must be count-neutral: the seeded total may neither double
+// nor regress while the streaming loop replays it.
+func TestRunStateManagerSeedThenRestreamIsCountNeutral(t *testing.T) {
+	const (
+		children = 150 // grandchildren under mapTask, mirroring a large map task
+		pageSize = 100 // page size used for the re-stream below
+	)
+	queued := common.ActionPhase_ACTION_PHASE_QUEUED
+	running := common.ActionPhase_ACTION_PHASE_RUNNING
+	succeeded := common.ActionPhase_ACTION_PHASE_SUCCEEDED
+
+	rsm, err := newRunStateManager(nil)
+	require.NoError(t, err)
+	ctx := context.Background()
+
+	// childCount reads one phase bucket from the named node's ChildPhaseCounts.
+	childCount := func(node string, phase common.ActionPhase) int {
+		return rsm.GetActionTreeNodeByName(node).ChildPhaseCounts[phase]
+	}
+
+	// root -> mapTask -> 150 QUEUED children.
+	seed := make([]*models.Action, 0, children+2)
+	seed = append(seed,
+		testAction("root", nil, running, 1),
+		testAction("mapTask", stringPtr("root"), running, 2),
+	)
+	for i := 0; i < children; i++ {
+		name := fmt.Sprintf("c%03d", i)
+		seed = append(seed, testAction(name, stringPtr("mapTask"), queued, int64(10+i)))
+	}
+
+	// Seed (listAndSendAllActions discards these updates).
+	_, err = rsm.upsertActions(ctx, seed)
+	require.NoError(t, err)
+
+	require.Equal(t, children, childCount("mapTask", queued))
+	// Transitive: root counts the mapTask node (RUNNING) plus all 150 grandchildren.
+	require.Equal(t, children, childCount("root", queued))
+	require.Equal(t, 1, childCount("root", running))
+
+	// Re-stream the same rows in pages (same phases) -> counts must not move.
+	for start := 0; start < len(seed); start += pageSize {
+		page := seed[start:]
+		if len(page) > pageSize {
+			page = page[:pageSize]
+		}
+		_, err = rsm.upsertActions(ctx, page)
+		require.NoError(t, err)
+		require.Equal(t, children, childCount("mapTask", queued),
+			"re-streaming with the same phase must be count-neutral")
+	}
+
+	// A genuine live phase change after the snapshot still adjusts the count.
+	changed := testAction("c000", stringPtr("mapTask"), succeeded, 10)
+	_, err = rsm.upsertActions(ctx, []*models.Action{changed})
+	require.NoError(t, err)
+	require.Equal(t, children-1, childCount("mapTask", queued))
+	require.Equal(t, 1, childCount("mapTask", succeeded))
+}
+
 func testAction(name string, parent *string, phase common.ActionPhase, createdAtSec int64) *models.Action {
 	return testActionWithTask(name, parent, phase, createdAtSec, "")
 }