diff --git a/runs/service/run_service.go b/runs/service/run_service.go
index 57fa5c8f0d..c926623f5e 100644
--- a/runs/service/run_service.go
+++ b/runs/service/run_service.go
@@ -1082,17 +1082,13 @@ func (s *RunService) WatchRunDetails(
 ) error {
 	logger.Infof(ctx, "Received WatchRunDetails request")
 
-	// For now, just send initial state and close
-	// TODO: Implement actual streaming with polling or database triggers
 	run, err := s.repo.ActionRepo().GetRun(ctx, req.Msg.RunId)
 	if err != nil {
 		return connect.NewError(connect.CodeNotFound, err)
 	}
 
 	resp := &workflow.WatchRunDetailsResponse{
-		Details: &workflow.RunDetails{
-			// Would populate from run model
-		},
+		Details: s.runModelToDetails(run, req.Msg.RunId),
 	}
 
 	if err := stream.Send(resp); err != nil {
@@ -1101,7 +1097,6 @@ func (s *RunService) WatchRunDetails(
 
 	logger.Infof(ctx, "Sent initial run details for: %s", run.Name)
 
-	// Keep connection open and send updates (simplified)
 	updates := make(chan *models.Run, 50)
 	errs := make(chan error, 1)
 
@@ -1115,9 +1110,7 @@ func (s *RunService) WatchRunDetails(
 			return connect.NewError(connect.CodeInternal, err)
 		case run := <-updates:
 			resp := &workflow.WatchRunDetailsResponse{
-				Details: &workflow.RunDetails{
-					// Would populate from run
-				},
+				Details: s.runModelToDetails(run, req.Msg.RunId),
 			}
 			if err := stream.Send(resp); err != nil {
 				return err
@@ -1462,8 +1455,42 @@ func (s *RunService) getClusterEventsInfo(
 	return info, nil
 }
 
+// runModelToDetails converts a DB Run model to a RunDetails proto.
+//
+// It returns nil when run is nil, since there is no row to convert. The
+// embedded action identifier combines runID with the run's own name.
+func (s *RunService) runModelToDetails(run *models.Run, runID *common.RunIdentifier) *workflow.RunDetails {
+	if run == nil {
+		return nil
+	}
+
+	var runSpec *task.RunSpec
+	if len(run.ActionSpec) > 0 {
+		var actionSpec workflow.ActionSpec
+		// Best effort: an undecodable spec leaves RunSpec unset instead of
+		// failing the whole conversion.
+		if err := json.Unmarshal(run.ActionSpec, &actionSpec); err == nil {
+			runSpec = actionSpec.RunSpec
+		}
+	}
+
+	id := &common.ActionIdentifier{
+		Run:  runID,
+		Name: run.Name,
+	}
+	return &workflow.RunDetails{
+		RunSpec: runSpec,
+		Action:  s.actionModelToDetails(run, id),
+	}
+}
+
 // actionModelToDetails converts a DB Action model to an ActionDetails proto.
+// It returns nil when action is nil.
 func (s *RunService) actionModelToDetails(action *models.Action, actionID *common.ActionIdentifier) *workflow.ActionDetails {
+	// Every field below is read from action, so a nil model cannot be converted.
+	if action == nil {
+		return nil
+	}
 	status := &workflow.ActionStatus{
 		Phase:       common.ActionPhase(action.Phase),
 		StartTime:   timestamppb.New(action.CreatedAt),
@@ -1472,6 +1499,14 @@ func (s *RunService) actionModelToDetails(action *models.Action, actionID *commo
 	}
 	if action.EndedAt.Valid {
 		status.EndTime = timestamppb.New(action.EndedAt.Time)
+		// Derive a duration from the timestamps; clamp at zero so an EndedAt
+		// that precedes CreatedAt cannot wrap around in the uint64 conversion.
+		elapsed := status.EndTime.AsTime().Sub(status.StartTime.AsTime()).Milliseconds()
+		if elapsed < 0 {
+			elapsed = 0
+		}
+		durationMs := uint64(elapsed)
+		status.DurationMs = &durationMs
 	}
 	if action.DurationMs.Valid {
 		durationMs := uint64(action.DurationMs.Int64)
diff --git a/runs/service/run_service_test.go b/runs/service/run_service_test.go
index 3c40bd5917..516a7cd08d 100644
--- a/runs/service/run_service_test.go
+++ b/runs/service/run_service_test.go
@@ -1735,3 +1735,431 @@ func TestGetActionLogContext(t *testing.T) {
 		assert.Equal(t, connect.CodeInternal, connect.CodeOf(err))
 	})
 }
+
+func TestActionModelToDetails(t *testing.T) {
+	svc := &RunService{}
+	now := time.Now().UTC().Truncate(time.Millisecond)
+	end := now.Add(5 * time.Second)
+
+	tests := []struct {
+		name     string
+		action   *models.Action
+		actionID *common.ActionIdentifier
+		verify   func(t *testing.T, result *workflow.ActionDetails)
+	}{
+		{
+			name:     "BothNilReturnsNil",
+			action:   nil,
+			actionID: nil,
+			verify: func(t *testing.T, result *workflow.ActionDetails) {
+				assert.Nil(t, result)
+			},
+		},
+		{
+			name:     "NilActionWithIDReturnsNil",
+			action:   nil,
+			actionID: testActionID,
+			verify: func(t *testing.T, result *workflow.ActionDetails) {
+				assert.Nil(t, result)
+			},
+		},
+		{
+			name: "BasicStatusFields",
+			action: &models.Action{
+				Phase:       int32(common.ActionPhase_ACTION_PHASE_RUNNING),
+				CreatedAt:   now,
+				Attempts:    2,
+				CacheStatus: core.CatalogCacheStatus_CACHE_HIT,
+			},
+			actionID: testActionID,
+			verify: func(t *testing.T, result *workflow.ActionDetails) {
+				require.NotNil(t, result)
+				assert.Equal(t, testActionID, result.Id)
+				assert.Equal(t, common.ActionPhase_ACTION_PHASE_RUNNING, result.Status.Phase)
+				assert.Equal(t, now, result.Status.StartTime.AsTime())
+				assert.Equal(t, uint32(2), result.Status.Attempts)
+				assert.Equal(t, core.CatalogCacheStatus_CACHE_HIT, result.Status.CacheStatus)
+				assert.Nil(t, result.Status.EndTime)
+				assert.Nil(t, result.Status.DurationMs)
+			},
+		},
+		{
+			name: "EndedAtSetsDurationFromTimestamps",
+			action: &models.Action{
+				Phase:     int32(common.ActionPhase_ACTION_PHASE_SUCCEEDED),
+				CreatedAt: now,
+				EndedAt:   sql.NullTime{Time: end, Valid: true},
+			},
+			actionID: testActionID,
+			verify: func(t *testing.T, result *workflow.ActionDetails) {
+				require.NotNil(t, result.Status.EndTime)
+				assert.Equal(t, end, result.Status.EndTime.AsTime())
+				require.NotNil(t, result.Status.DurationMs)
+				assert.Equal(t, uint64(5000), *result.Status.DurationMs)
+			},
+		},
+		{
+			name: "EndedAtBeforeCreatedAtClampsDuration",
+			action: &models.Action{
+				Phase:     int32(common.ActionPhase_ACTION_PHASE_SUCCEEDED),
+				CreatedAt: end,
+				EndedAt:   sql.NullTime{Time: now, Valid: true},
+			},
+			actionID: testActionID,
+			verify: func(t *testing.T, result *workflow.ActionDetails) {
+				require.NotNil(t, result.Status.DurationMs)
+				assert.Equal(t, uint64(0), *result.Status.DurationMs)
+			},
+		},
+		{
+			name: "MetadataOptionalFields",
+			action: &models.Action{
+				Phase:            int32(common.ActionPhase_ACTION_PHASE_RUNNING),
+				ActionType:       int32(workflow.ActionType_ACTION_TYPE_TASK),
+				FunctionName:     "my_func",
+				ParentActionName: sql.NullString{String: "parent-action", Valid: true},
+				ActionGroup:      sql.NullString{String: "group-1", Valid: true},
+				EnvironmentName:  sql.NullString{String: "prod", Valid: true},
+			},
+			actionID: testActionID,
+			verify: func(t *testing.T, result *workflow.ActionDetails) {
+				require.NotNil(t, result.Metadata)
+				assert.Equal(t, workflow.ActionType_ACTION_TYPE_TASK, result.Metadata.ActionType)
+				assert.Equal(t, "my_func", result.Metadata.FuntionName)
+				assert.Equal(t, "parent-action", result.Metadata.Parent)
+				assert.Equal(t, "group-1", result.Metadata.Group)
+				assert.Equal(t, "prod", result.Metadata.EnvironmentName)
+			},
+		},
+		{
+			name: "TaskMetadataWithFullID",
+			action: &models.Action{
+				Phase:         int32(common.ActionPhase_ACTION_PHASE_RUNNING),
+				ActionType:    int32(workflow.ActionType_ACTION_TYPE_TASK),
+				TaskType:      "python",
+				TaskProject:   sql.NullString{String: "proj", Valid: true},
+				TaskDomain:    sql.NullString{String: "dev", Valid: true},
+				TaskName:      sql.NullString{String: "my_task", Valid: true},
+				TaskVersion:   sql.NullString{String: "v1", Valid: true},
+				TaskShortName: sql.NullString{String: "short", Valid: true},
+			},
+			actionID: testActionID,
+			verify: func(t *testing.T, result *workflow.ActionDetails) {
+				taskMeta := result.Metadata.GetTask()
+				require.NotNil(t, taskMeta)
+				assert.Equal(t, "python", taskMeta.TaskType)
+				assert.Equal(t, "short", taskMeta.ShortName)
+				require.NotNil(t, taskMeta.Id)
+				assert.Equal(t, "proj", taskMeta.Id.Project)
+				assert.Equal(t, "dev", taskMeta.Id.Domain)
+				assert.Equal(t, "my_task", taskMeta.Id.Name)
+				assert.Equal(t, "v1", taskMeta.Id.Version)
+			},
+		},
+		{
+			name: "TraceMetadata",
+			action: &models.Action{
+				Phase:        int32(common.ActionPhase_ACTION_PHASE_RUNNING),
+				ActionType:   int32(workflow.ActionType_ACTION_TYPE_TRACE),
+				FunctionName: "trace_func",
+			},
+			actionID: testActionID,
+			verify: func(t *testing.T, result *workflow.ActionDetails) {
+				traceMeta := result.Metadata.GetTrace()
+				require.NotNil(t, traceMeta)
+				assert.Equal(t, "trace_func", traceMeta.Name)
+				assert.Nil(t, result.Metadata.GetTask())
+			},
+		},
+		{
+			name: "NilActionID",
+			action: &models.Action{
+				Phase: int32(common.ActionPhase_ACTION_PHASE_QUEUED),
+			},
+			actionID: nil,
+			verify: func(t *testing.T, result *workflow.ActionDetails) {
+				require.NotNil(t, result)
+				assert.Nil(t, result.Id)
+			},
+		},
+	}
+
+	for _, tc := range tests {
+		t.Run(tc.name, func(t *testing.T) {
+			tc.verify(t, svc.actionModelToDetails(tc.action, tc.actionID))
+		})
+	}
+}
+
+func TestRunModelToDetails(t *testing.T) {
+	svc := &RunService{}
+
+	testRunID := &common.RunIdentifier{
+		Org:     "test-org",
+		Project: "test-project",
+		Domain:  "test-domain",
+		Name:    "test-run",
+	}
+
+	validActionSpecBytes, err := json.Marshal(&workflow.ActionSpec{
+		RunSpec: &task.RunSpec{
+			Labels: &task.Labels{
+				Values: map[string]string{"env": "prod"},
+			},
+		},
+	})
+	require.NoError(t, err)
+
+	tests := []struct {
+		name   string
+		run    *models.Run
+		runID  *common.RunIdentifier
+		verify func(t *testing.T, result *workflow.RunDetails)
+	}{
+		{
+			name:  "BothNilReturnsNil",
+			run:   nil,
+			runID: nil,
+			verify: func(t *testing.T, result *workflow.RunDetails) {
+				assert.Nil(t, result)
+			},
+		},
+		{
+			name:  "NilRunWithRunIDReturnsNil",
+			run:   nil,
+			runID: testRunID,
+			verify: func(t *testing.T, result *workflow.RunDetails) {
+				assert.Nil(t, result)
+			},
+		},
+		{
+			name: "EmptyActionSpecNilRunSpec",
+			run: &models.Run{
+				Name:  "test-run",
+				Phase: int32(common.ActionPhase_ACTION_PHASE_RUNNING),
+			},
+			runID: testRunID,
+			verify: func(t *testing.T, result *workflow.RunDetails) {
+				require.NotNil(t, result)
+				assert.Nil(t, result.RunSpec)
+			},
+		},
+		{
+			name: "ValidActionSpecExtractsRunSpec",
+			run: &models.Run{
+				Name:       "test-run",
+				Phase:      int32(common.ActionPhase_ACTION_PHASE_RUNNING),
+				ActionSpec: validActionSpecBytes,
+			},
+			runID: testRunID,
+			verify: func(t *testing.T, result *workflow.RunDetails) {
+				require.NotNil(t, result)
+				require.NotNil(t, result.RunSpec)
+				require.NotNil(t, result.RunSpec.Labels)
+				assert.Equal(t, "prod", result.RunSpec.Labels.Values["env"])
+			},
+		},
+		{
+			name: "MalformedActionSpecNilRunSpec",
+			run: &models.Run{
+				Name:       "test-run",
+				Phase:      int32(common.ActionPhase_ACTION_PHASE_RUNNING),
+				ActionSpec: []byte("not-valid-json{{"),
+			},
+			runID: testRunID,
+			verify: func(t *testing.T, result *workflow.RunDetails) {
+				require.NotNil(t, result)
+				assert.Nil(t, result.RunSpec)
+			},
+		},
+		{
+			name: "RunNamePropagatedToActionID",
+			run: &models.Run{
+				Name:  "my-run-name",
+				Phase: int32(common.ActionPhase_ACTION_PHASE_QUEUED),
+			},
+			runID: testRunID,
+			verify: func(t *testing.T, result *workflow.RunDetails) {
+				require.NotNil(t, result)
+				require.NotNil(t, result.Action)
+				require.NotNil(t, result.Action.Id)
+				assert.Equal(t, "my-run-name", result.Action.Id.Name)
+			},
+		},
+		{
+			name: "RunIDPropagatedToActionID",
+			run: &models.Run{
+				Name:  testRunID.Name,
+				Phase: int32(common.ActionPhase_ACTION_PHASE_RUNNING),
+			},
+			runID: testRunID,
+			verify: func(t *testing.T, result *workflow.RunDetails) {
+				require.NotNil(t, result)
+				require.NotNil(t, result.Action)
+				require.NotNil(t, result.Action.Id)
+				assert.Equal(t, testRunID.Org, result.Action.Id.Run.Org)
+				assert.Equal(t, testRunID.Project, result.Action.Id.Run.Project)
+				assert.Equal(t, testRunID.Domain, result.Action.Id.Run.Domain)
+				assert.Equal(t, testRunID.Name, result.Action.Id.Run.Name)
+			},
+		},
+	}
+
+	for _, tc := range tests {
+		t.Run(tc.name, func(t *testing.T) {
+			tc.verify(t, svc.runModelToDetails(tc.run, tc.runID))
+		})
+	}
+}
+
+func newWatchRunDetailsTestClient(t *testing.T, actionRepo *repoMocks.ActionRepo) workflowconnect.RunServiceClient {
+	t.Helper()
+	taskRepo := &repoMocks.TaskRepo{}
+	repo := &repoMocks.Repository{}
+	repo.On("ActionRepo").Return(actionRepo)
+	repo.On("TaskRepo").Maybe().Return(taskRepo)
+
+	actionsClient := actionsconnectmocks.NewActionsServiceClient(t)
+	svc := &RunService{repo: repo, actionsClient: actionsClient}
+	path, handler := workflowconnect.NewRunServiceHandler(svc)
+
+	mux := http.NewServeMux()
+	mux.Handle(path, handler)
+	server := httptest.NewServer(mux)
+	t.Cleanup(server.Close)
+
+	return workflowconnect.NewRunServiceClient(http.DefaultClient, server.URL)
+}
+
+func TestWatchRunDetails(t *testing.T) {
+	runID := &common.RunIdentifier{
+		Org:     "test-org",
+		Project: "test-project",
+		Domain:  "test-domain",
+		Name:    "rtest-watch-1",
+	}
+
+	runModel := &models.Run{
+		Project: runID.Project,
+		Domain:  runID.Domain,
+		Name:    runID.Name,
+		Phase:   int32(common.ActionPhase_ACTION_PHASE_RUNNING),
+	}
+
+	tests := []struct {
+		name       string
+		setupMocks func(actionRepo *repoMocks.ActionRepo)
+		test       func(t *testing.T, client workflowconnect.RunServiceClient)
+	}{
+		{
+			name: "GetRun_Error_Returns_NotFound",
+			setupMocks: func(actionRepo *repoMocks.ActionRepo) {
+				actionRepo.EXPECT().GetRun(mock.Anything, matchRunID(runID)).Return(nil, fmt.Errorf("run not found"))
+			},
+			test: func(t *testing.T, client workflowconnect.RunServiceClient) {
+				ctx := context.Background()
+				stream, err := client.WatchRunDetails(ctx, connect.NewRequest(&workflow.WatchRunDetailsRequest{
+					RunId: runID,
+				}))
+				require.NoError(t, err)
+				assert.False(t, stream.Receive())
+				require.Error(t, stream.Err())
+				var connectErr *connect.Error
+				require.True(t, errors.As(stream.Err(), &connectErr))
+				assert.Equal(t, connect.CodeNotFound, connectErr.Code())
+			},
+		},
+		{
+			name: "Initial_State_Sent_On_Success",
+			setupMocks: func(actionRepo *repoMocks.ActionRepo) {
+				actionRepo.EXPECT().GetRun(mock.Anything, matchRunID(runID)).Return(runModel, nil)
+				actionRepo.EXPECT().WatchRunUpdates(mock.Anything, matchRunID(runID), mock.Anything, mock.Anything).
+					RunAndReturn(func(ctx context.Context, _ *common.RunIdentifier, _ chan<- *models.Run, _ chan<- error) {
+						<-ctx.Done()
+					}).Call.Maybe()
+			},
+			test: func(t *testing.T, client workflowconnect.RunServiceClient) {
+				ctx, cancel := context.WithCancel(context.Background())
+				defer cancel()
+
+				stream, err := client.WatchRunDetails(ctx, connect.NewRequest(&workflow.WatchRunDetailsRequest{
+					RunId: runID,
+				}))
+				require.NoError(t, err)
+
+				assert.True(t, stream.Receive())
+				require.NotNil(t, stream.Msg().Details)
+
+				cancel()
+				assert.False(t, stream.Receive())
+			},
+		},
+		{
+			name: "Watch_Error_Returns_Internal_Error",
+			setupMocks: func(actionRepo *repoMocks.ActionRepo) {
+				actionRepo.EXPECT().GetRun(mock.Anything, matchRunID(runID)).Return(runModel, nil)
+				actionRepo.EXPECT().WatchRunUpdates(mock.Anything, matchRunID(runID), mock.Anything, mock.Anything).
+					RunAndReturn(func(_ context.Context, _ *common.RunIdentifier, _ chan<- *models.Run, errs chan<- error) {
+						errs <- fmt.Errorf("db watch failed")
+					})
+			},
+			test: func(t *testing.T, client workflowconnect.RunServiceClient) {
+				ctx := context.Background()
+				stream, err := client.WatchRunDetails(ctx, connect.NewRequest(&workflow.WatchRunDetailsRequest{
+					RunId: runID,
+				}))
+				require.NoError(t, err)
+
+				assert.True(t, stream.Receive())
+
+				assert.False(t, stream.Receive())
+				require.Error(t, stream.Err())
+				var connectErr *connect.Error
+				require.True(t, errors.As(stream.Err(), &connectErr))
+				assert.Equal(t, connect.CodeInternal, connectErr.Code())
+			},
+		},
+		{
+			name: "Run_Update_Sends_New_Response",
+			setupMocks: func(actionRepo *repoMocks.ActionRepo) {
+				actionRepo.EXPECT().GetRun(mock.Anything, matchRunID(runID)).Return(runModel, nil)
+				actionRepo.EXPECT().WatchRunUpdates(mock.Anything, matchRunID(runID), mock.Anything, mock.Anything).
+					RunAndReturn(func(ctx context.Context, _ *common.RunIdentifier, updates chan<- *models.Run, _ chan<- error) {
+						updates <- &models.Run{
+							Project: runID.Project,
+							Domain:  runID.Domain,
+							Name:    runID.Name,
+							Phase:   int32(common.ActionPhase_ACTION_PHASE_SUCCEEDED),
+						}
+						<-ctx.Done()
+					})
+			},
+			test: func(t *testing.T, client workflowconnect.RunServiceClient) {
+				ctx, cancel := context.WithCancel(context.Background())
+				defer cancel()
+
+				stream, err := client.WatchRunDetails(ctx, connect.NewRequest(&workflow.WatchRunDetailsRequest{
+					RunId: runID,
+				}))
+				require.NoError(t, err)
+
+				assert.True(t, stream.Receive())
+				assert.True(t, stream.Receive())
+				assert.NotNil(t, stream.Msg().Details)
+
+				cancel()
+				assert.False(t, stream.Receive())
+			},
+		},
+	}
+
+	for _, tc := range tests {
+		t.Run(tc.name, func(t *testing.T) {
+			actionRepo := &repoMocks.ActionRepo{}
+			tc.setupMocks(actionRepo)
+			client := newWatchRunDetailsTestClient(t, actionRepo)
+			tc.test(t, client)
+			actionRepo.AssertExpectations(t)
+		})
+	}
+}