From b1ef45bb95a1fd25b4d7407947574922d4acb190 Mon Sep 17 00:00:00 2001 From: Carina-TzuHsuan Date: Thu, 23 Apr 2026 16:37:00 -0500 Subject: [PATCH] runs: delegate RunLogsService.TailLogs to DataProxyService (#7252) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Remove the duplicate K8sLogStreamer from the runs service and replace it with a thin forwarding shim that calls DataProxyService.TailLogs. The /RunLogsService/TailLogs endpoint is kept registered for backward compatibility — clients see no change. Co-Authored-By: Claude Sonnet 4.6 --- runs/config/config.go | 10 +- runs/service/k8s_log_streamer.go | 97 --------- runs/service/k8s_log_streamer_test.go | 144 ------------- runs/service/run_logs_service.go | 93 +++------ runs/service/run_logs_service_test.go | 280 +++++++------------------- runs/setup.go | 17 +- 6 files changed, 114 insertions(+), 527 deletions(-) delete mode 100644 runs/service/k8s_log_streamer.go delete mode 100644 runs/service/k8s_log_streamer_test.go diff --git a/runs/config/config.go b/runs/config/config.go index 9c2ab1a73b..2f66bdda77 100644 --- a/runs/config/config.go +++ b/runs/config/config.go @@ -16,9 +16,10 @@ var defaultConfig = &Config{ Port: 8090, Host: "0.0.0.0", }, - WatchBufferSize: 100, - ActionsServiceURL: "http://localhost:8090", - StoragePrefix: "file:///tmp/flyte/data", + WatchBufferSize: 100, + ActionsServiceURL: "http://localhost:8090", + DataProxyServiceURL: "http://localhost:8088", + StoragePrefix: "file:///tmp/flyte/data", SeedProjects: []string{"flytesnacks"}, Domains: []DomainConfig{ {ID: "development", Name: "Development"}, @@ -50,6 +51,9 @@ type Config struct { // Actions service URL for enqueuing actions ActionsServiceURL string `json:"actionsServiceUrl" pflag:",URL of the actions service"` + // DataProxyServiceURL is the URL of the DataProxy service, used by RunLogsService to delegate log streaming. + DataProxyServiceURL string `json:"dataProxyServiceUrl" pflag:",URL of the DataProxy service"` + // StoragePrefix is the base URI for storing run data (inputs, outputs) // e.g. "s3://my-bucket" or "gs://my-bucket" or "file:///tmp/flyte/data" StoragePrefix string `json:"storagePrefix" pflag:",Base URI prefix for storing run inputs and outputs"` diff --git a/runs/service/k8s_log_streamer.go b/runs/service/k8s_log_streamer.go deleted file mode 100644 index 6ddf40d86d..0000000000 --- a/runs/service/k8s_log_streamer.go +++ /dev/null @@ -1,97 +0,0 @@ -package service - -import ( - "context" - "fmt" - - "connectrpc.com/connect" - corev1 "k8s.io/api/core/v1" - k8serrors "k8s.io/apimachinery/pkg/api/errors" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - "k8s.io/client-go/kubernetes" - "k8s.io/client-go/rest" - - "github.com/flyteorg/flyte/v2/flytestdlib/k8s/podlogs" - "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/core" - "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/logs/dataplane" - "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/workflow" -) - -const defaultInitialLines = int64(1000) - -// K8sLogStreamer streams logs directly from Kubernetes pods. -type K8sLogStreamer struct { - clientset kubernetes.Interface -} - -// NewK8sLogStreamer creates a K8sLogStreamer from a Kubernetes REST config. -// It clears the timeout so that long-lived log streams are not interrupted. -func NewK8sLogStreamer(k8sConfig *rest.Config) (*K8sLogStreamer, error) { - cfg := rest.CopyConfig(k8sConfig) - cfg.Timeout = 0 - clientset, err := kubernetes.NewForConfig(cfg) - if err != nil { - return nil, fmt.Errorf("failed to create kubernetes clientset: %w", err) - } - return &K8sLogStreamer{clientset: clientset}, nil -} - -// TailLogs streams log lines for the given LogContext from a Kubernetes pod. -func (s *K8sLogStreamer) TailLogs(ctx context.Context, logContext *core.LogContext, stream *connect.ServerStream[workflow.TailLogsResponse]) error { - pod, container, err := getPrimaryPodAndContainer(logContext) - if err != nil { - return connect.NewError(connect.CodeNotFound, err) - } - - tailLines := defaultInitialLines - opts := &corev1.PodLogOptions{ - Container: container.GetContainerName(), - Follow: true, - Timestamps: true, - TailLines: &tailLines, - } - - // Set SinceTime from container start time if available. - // When SinceTime is set, it takes precedence and we clear TailLines - // to stream all logs from that point forward. - if startTime := container.GetProcess().GetContainerStartTime(); startTime != nil { - t := metav1.NewTime(startTime.AsTime()) - opts.SinceTime = &t - opts.TailLines = nil - } - - // Only follow logs when the pod is actively running. For pending or - // terminated pods, disable follow so existing logs are returned immediately. - podObj, err := s.clientset.CoreV1().Pods(pod.GetNamespace()).Get(ctx, pod.GetPodName(), metav1.GetOptions{}) - if err != nil { - if k8serrors.IsNotFound(err) { - return connect.NewError(connect.CodeNotFound, fmt.Errorf("pod %s not found in namespace %s", pod.GetPodName(), pod.GetNamespace())) - } - return connect.NewError(connect.CodeInternal, fmt.Errorf("failed to get pod: %w", err)) - } - opts.Follow = podObj.Status.Phase == corev1.PodRunning - - // Create a context without the incoming gRPC deadline so long-lived follow - // streams are not killed by a short client/proxy timeout. Cancellation is - // still propagated so the stream closes when the client disconnects. - streamCtx, streamCancel := context.WithCancel(context.Background()) - defer streamCancel() - stop := context.AfterFunc(ctx, streamCancel) - defer stop() - - logStream, err := s.clientset.CoreV1().Pods(pod.GetNamespace()).GetLogs(pod.GetPodName(), opts).Stream(streamCtx) - if err != nil { - return connect.NewError(connect.CodeInternal, fmt.Errorf("failed to stream pod logs: %w", err)) - } - defer logStream.Close() - - err = podlogs.Stream(ctx, logStream, podlogs.DefaultBatchSize, func(lines []*dataplane.LogLine) error { - return stream.Send(&workflow.TailLogsResponse{ - Logs: []*workflow.TailLogsResponse_Logs{{Lines: lines}}, - }) - }) - if err != nil { - return connect.NewError(connect.CodeInternal, fmt.Errorf("error reading log stream: %w", err)) - } - return nil -} diff --git a/runs/service/k8s_log_streamer_test.go b/runs/service/k8s_log_streamer_test.go deleted file mode 100644 index 9dfb5148b2..0000000000 --- a/runs/service/k8s_log_streamer_test.go +++ /dev/null @@ -1,144 +0,0 @@ -package service - -import ( - "context" - "testing" - - "connectrpc.com/connect" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - corev1 "k8s.io/api/core/v1" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - "k8s.io/client-go/kubernetes/fake" - - "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/core" -) - -func TestGetPrimaryPodAndContainer_HappyPath(t *testing.T) { - logCtx := &core.LogContext{ - PrimaryPodName: "my-pod", - Pods: []*core.PodLogContext{ - { - PodName: "my-pod", - Namespace: "default", - PrimaryContainerName: "main", - Containers: []*core.ContainerContext{ - {ContainerName: "main"}, - {ContainerName: "sidecar"}, - }, - }, - }, - } - - pod, container, err := getPrimaryPodAndContainer(logCtx) - assert.NoError(t, err) - assert.Equal(t, "my-pod", pod.GetPodName()) - assert.Equal(t, "default", pod.GetNamespace()) - assert.Equal(t, "main", container.GetContainerName()) -} - -func TestGetPrimaryPodAndContainer_EmptyPodName(t *testing.T) { - logCtx := &core.LogContext{ - PrimaryPodName: "", - } - - _, _, err := getPrimaryPodAndContainer(logCtx) - assert.Error(t, err) - assert.Contains(t, err.Error(), "primary pod name is empty") -} - -func TestGetPrimaryPodAndContainer_PodNotFound(t *testing.T) { - logCtx := &core.LogContext{ - PrimaryPodName: "missing-pod", - Pods: []*core.PodLogContext{ - {PodName: "other-pod"}, - }, - } - - _, _, err := getPrimaryPodAndContainer(logCtx) - assert.Error(t, err) - assert.Contains(t, err.Error(), "not found in log context") -} - -func TestGetPrimaryPodAndContainer_ContainerNotFound(t *testing.T) { - logCtx := &core.LogContext{ - PrimaryPodName: "my-pod", - Pods: []*core.PodLogContext{ - { - PodName: "my-pod", - PrimaryContainerName: "missing-container", - Containers: []*core.ContainerContext{ - {ContainerName: "other"}, - }, - }, - }, - } - - _, _, err := getPrimaryPodAndContainer(logCtx) - assert.Error(t, err) - assert.Contains(t, err.Error(), "primary container") -} - -func newTestLogContext(podName, namespace, containerName string) *core.LogContext { - return &core.LogContext{ - PrimaryPodName: podName, - Pods: []*core.PodLogContext{ - { - PodName: podName, - Namespace: namespace, - PrimaryContainerName: containerName, - Containers: []*core.ContainerContext{ - {ContainerName: containerName}, - }, - }, - }, - } -} - -func TestTailLogs_PodNotFound(t *testing.T) { - clientset := fake.NewSimpleClientset() // no pods - - streamer := &K8sLogStreamer{clientset: clientset} - logCtx := newTestLogContext("missing-pod", "default", "main") - - err := streamer.TailLogs(context.Background(), logCtx, nil) - require.Error(t, err) - assert.Equal(t, connect.CodeNotFound, connect.CodeOf(err)) - assert.Contains(t, err.Error(), "not found") -} - -func TestTailLogs_FollowSetBasedOnPodPhase(t *testing.T) { - tests := []struct { - name string - phase corev1.PodPhase - wantFollow bool - }{ - {"running pod should follow", corev1.PodRunning, true}, - {"succeeded pod should not follow", corev1.PodSucceeded, false}, - {"failed pod should not follow", corev1.PodFailed, false}, - {"pending pod should not follow", corev1.PodPending, false}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - clientset := fake.NewSimpleClientset(&corev1.Pod{ - ObjectMeta: metav1.ObjectMeta{ - Name: "my-pod", - Namespace: "default", - }, - Status: corev1.PodStatus{ - Phase: tt.phase, - }, - }) - - // Verify we can fetch the pod and the phase is correct. - podObj, err := clientset.CoreV1().Pods("default").Get(context.Background(), "my-pod", metav1.GetOptions{}) - require.NoError(t, err) - assert.Equal(t, tt.phase, podObj.Status.Phase) - - // Verify the follow logic: Follow should only be true when phase is Running. - gotFollow := podObj.Status.Phase == corev1.PodRunning - assert.Equal(t, tt.wantFollow, gotFollow) - }) - } -} diff --git a/runs/service/run_logs_service.go b/runs/service/run_logs_service.go index ac4b947b36..aa393bca54 100644 --- a/runs/service/run_logs_service.go +++ b/runs/service/run_logs_service.go @@ -2,46 +2,33 @@ package service import ( "context" - "errors" "fmt" - "database/sql" - "connectrpc.com/connect" "golang.org/x/sync/semaphore" - "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/common" - "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/core" + "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/dataproxy" + "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/dataproxy/dataproxyconnect" "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/workflow" - "github.com/flyteorg/flyte/v2/runs/repository/interfaces" - - "github.com/samber/lo" ) const defaultMaxConcurrentStreams = 100 -// LogStreamer abstracts log fetching from different backends. -type LogStreamer interface { - TailLogs(ctx context.Context, logContext *core.LogContext, stream *connect.ServerStream[workflow.TailLogsResponse]) error -} - // RunLogsService implements the RunLogsServiceHandler interface. type RunLogsService struct { - repo interfaces.Repository - streamer LogStreamer - sem *semaphore.Weighted + dataProxyClient dataproxyconnect.DataProxyServiceClient + sem *semaphore.Weighted } // NewRunLogsService creates a new RunLogsService. -func NewRunLogsService(repo interfaces.Repository, streamer LogStreamer) *RunLogsService { +func NewRunLogsService(dataProxyClient dataproxyconnect.DataProxyServiceClient) *RunLogsService { return &RunLogsService{ - repo: repo, - streamer: streamer, - sem: semaphore.NewWeighted(defaultMaxConcurrentStreams), + dataProxyClient: dataProxyClient, + sem: semaphore.NewWeighted(defaultMaxConcurrentStreams), } } -// TailLogs streams pod logs for an action attempt. +// TailLogs streams pod logs for an action attempt by delegating to DataProxyService. func (s *RunLogsService) TailLogs(ctx context.Context, req *connect.Request[workflow.TailLogsRequest], stream *connect.ServerStream[workflow.TailLogsResponse]) error { msg := req.Msg if msg.GetActionId() == nil { @@ -52,56 +39,26 @@ func (s *RunLogsService) TailLogs(ctx context.Context, req *connect.Request[work } defer s.sem.Release(1) - logContext, err := getLogContextForAttempt(ctx, s.repo, msg.GetActionId(), msg.GetAttempt()) + dpStream, err := s.dataProxyClient.TailLogs(ctx, connect.NewRequest(&dataproxy.TailLogsRequest{ + ActionId: msg.GetActionId(), + Attempt: msg.GetAttempt(), + })) if err != nil { return err } - - return s.streamer.TailLogs(ctx, logContext, stream) -} - -// getLogContextForAttempt fetches the latest event for the given attempt and -// extracts its LogContext. Uses a targeted DB query instead of scanning all events. -func getLogContextForAttempt(ctx context.Context, repo interfaces.Repository, actionID *common.ActionIdentifier, attempt uint32) (*core.LogContext, error) { - m, err := repo.ActionRepo().GetLatestEventByAttempt(ctx, actionID, attempt) - if err != nil { - if errors.Is(err, sql.ErrNoRows) { - return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("no event found for action %v attempt %d", actionID, attempt)) + defer dpStream.Close() + + for dpStream.Receive() { + dpResp := dpStream.Msg() + logs := make([]*workflow.TailLogsResponse_Logs, 0, len(dpResp.GetLogs())) + for _, l := range dpResp.GetLogs() { + logs = append(logs, &workflow.TailLogsResponse_Logs{ + Lines: l.GetLines(), + }) + } + if err := stream.Send(&workflow.TailLogsResponse{Logs: logs}); err != nil { + return err } - return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("failed to get event for action %v attempt %d: %w", actionID, attempt, err)) - } - - event, err := m.ToActionEvent() - if err != nil { - return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("failed to deserialize event: %w", err)) - } - - if event.GetLogContext() == nil { - return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("no log context found for action %v attempt %d", actionID, attempt)) - } - - return event.GetLogContext(), nil -} - -// getPrimaryPodAndContainer finds the primary pod and container from a LogContext. -func getPrimaryPodAndContainer(logContext *core.LogContext) (*core.PodLogContext, *core.ContainerContext, error) { - if logContext.GetPrimaryPodName() == "" { - return nil, nil, fmt.Errorf("primary pod name is empty in log context") - } - - pod, found := lo.Find(logContext.GetPods(), func(pod *core.PodLogContext) bool { - return pod.GetPodName() == logContext.GetPrimaryPodName() - }) - if !found { - return nil, nil, fmt.Errorf("primary pod %s not found in log context", logContext.GetPrimaryPodName()) - } - - container, found := lo.Find(pod.GetContainers(), func(c *core.ContainerContext) bool { - return c.GetContainerName() == pod.GetPrimaryContainerName() - }) - if !found { - return nil, nil, fmt.Errorf("primary container %s not found in pod %s", pod.GetPrimaryContainerName(), pod.GetPodName()) } - - return pod, container, nil + return dpStream.Err() } diff --git a/runs/service/run_logs_service_test.go b/runs/service/run_logs_service_test.go index aaec10e399..bdce8b6216 100644 --- a/runs/service/run_logs_service_test.go +++ b/runs/service/run_logs_service_test.go @@ -2,35 +2,34 @@ package service import ( "context" + "fmt" "net/http" "net/http/httptest" "testing" - "fmt" - "connectrpc.com/connect" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" - "google.golang.org/protobuf/types/known/timestamppb" - "database/sql" "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/common" - "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/core" + "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/dataproxy" + "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/dataproxy/dataproxyconnect" "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/logs/dataplane" "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/workflow" "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/workflow/workflowconnect" - repoMocks "github.com/flyteorg/flyte/v2/runs/repository/mocks" - "github.com/flyteorg/flyte/v2/runs/repository/models" ) -// mockLogStreamer is a test double for LogStreamer. -type mockLogStreamer struct { - mock.Mock +// fakeDataProxyHandler is a minimal DataProxyServiceHandler for tests. +// Only TailLogs is overridden; all other methods return CodeUnimplemented via the embedded struct. +type fakeDataProxyHandler struct { + dataproxyconnect.UnimplementedDataProxyServiceHandler + tailLogsFn func(ctx context.Context, req *connect.Request[dataproxy.TailLogsRequest], stream *connect.ServerStream[dataproxy.TailLogsResponse]) error } -func (m *mockLogStreamer) TailLogs(ctx context.Context, logContext *core.LogContext, stream *connect.ServerStream[workflow.TailLogsResponse]) error { - args := m.Called(ctx, logContext, stream) - return args.Error(0) +func (f *fakeDataProxyHandler) TailLogs(ctx context.Context, req *connect.Request[dataproxy.TailLogsRequest], stream *connect.ServerStream[dataproxy.TailLogsResponse]) error { + if f.tailLogsFn != nil { + return f.tailLogsFn(ctx, req, stream) + } + return nil } var tailLogsActionID = &common.ActionIdentifier{ @@ -43,64 +42,43 @@ var tailLogsActionID = &common.ActionIdentifier{ Name: "action-1", } -func newTailLogsTestClient(t *testing.T, actionRepo *repoMocks.ActionRepo, streamer *mockLogStreamer) workflowconnect.RunLogsServiceClient { - repo := &repoMocks.Repository{} - repo.On("ActionRepo").Return(actionRepo) +// newTailLogsTestClient wires a RunLogsService backed by a fake DataProxy server and +// returns a client connected to the RunLogsService. +func newTailLogsTestClient(t *testing.T, dpHandler dataproxyconnect.DataProxyServiceHandler) workflowconnect.RunLogsServiceClient { + t.Helper() - svc := NewRunLogsService(repo, streamer) - path, handler := workflowconnect.NewRunLogsServiceHandler(svc) + dpMux := http.NewServeMux() + dpPath, dpHTTPHandler := dataproxyconnect.NewDataProxyServiceHandler(dpHandler) + dpMux.Handle(dpPath, dpHTTPHandler) + dpServer := httptest.NewServer(dpMux) + t.Cleanup(dpServer.Close) - mux := http.NewServeMux() - mux.Handle(path, handler) - server := httptest.NewServer(mux) - t.Cleanup(server.Close) + dpClient := dataproxyconnect.NewDataProxyServiceClient(http.DefaultClient, dpServer.URL) + svc := NewRunLogsService(dpClient) - client := workflowconnect.NewRunLogsServiceClient(http.DefaultClient, server.URL) - return client -} + runLogsMux := http.NewServeMux() + path, handler := workflowconnect.NewRunLogsServiceHandler(svc) + runLogsMux.Handle(path, handler) + server := httptest.NewServer(runLogsMux) + t.Cleanup(server.Close) -func makeEventWithLogContext(actionID *common.ActionIdentifier, attempt uint32, phase common.ActionPhase, logCtx *core.LogContext) *models.ActionEvent { - event := &workflow.ActionEvent{ - Id: actionID, - Attempt: attempt, - Phase: phase, - Version: 0, - UpdatedTime: timestamppb.Now(), - LogContext: logCtx, - } - m, _ := models.NewActionEventModel(event) - return m + return workflowconnect.NewRunLogsServiceClient(http.DefaultClient, server.URL) } func TestTailLogs_HappyPath(t *testing.T) { - actionRepo := &repoMocks.ActionRepo{} - streamer := &mockLogStreamer{} - - logCtx := &core.LogContext{ - PrimaryPodName: "my-pod", - Pods: []*core.PodLogContext{ - {PodName: "my-pod", Namespace: "ns"}, - }, - } - - eventModel := makeEventWithLogContext(tailLogsActionID, 1, common.ActionPhase_ACTION_PHASE_RUNNING, logCtx) - actionRepo.On("GetLatestEventByAttempt", mock.Anything, mock.Anything, uint32(1)).Return(eventModel, nil) - - // The streamer should be called with the logContext and send some response. - streamer.On("TailLogs", mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) { - stream := args.Get(2).(*connect.ServerStream[workflow.TailLogsResponse]) - _ = stream.Send(&workflow.TailLogsResponse{ - Logs: []*workflow.TailLogsResponse_Logs{ - { - Lines: []*dataplane.LogLine{ + dpHandler := &fakeDataProxyHandler{ + tailLogsFn: func(_ context.Context, _ *connect.Request[dataproxy.TailLogsRequest], stream *connect.ServerStream[dataproxy.TailLogsResponse]) error { + return stream.Send(&dataproxy.TailLogsResponse{ + Logs: []*dataproxy.TailLogsResponse_Logs{ + {Lines: []*dataplane.LogLine{ {Message: "hello world", Originator: dataplane.LogLineOriginator_USER}, - }, + }}, }, - }, - }) - }).Return(nil) + }) + }, + } - client := newTailLogsTestClient(t, actionRepo, streamer) + client := newTailLogsTestClient(t, dpHandler) stream, err := client.TailLogs(context.Background(), connect.NewRequest(&workflow.TailLogsRequest{ ActionId: tailLogsActionID, @@ -116,73 +94,29 @@ func TestTailLogs_HappyPath(t *testing.T) { assert.False(t, stream.Receive()) assert.NoError(t, stream.Err()) - - actionRepo.AssertExpectations(t) - streamer.AssertExpectations(t) } -func TestTailLogs_NoLogContext(t *testing.T) { - actionRepo := &repoMocks.ActionRepo{} - streamer := &mockLogStreamer{} - - // Event without LogContext. - eventModel := makeEventWithLogContext(tailLogsActionID, 1, common.ActionPhase_ACTION_PHASE_RUNNING, nil) - actionRepo.On("GetLatestEventByAttempt", mock.Anything, mock.Anything, uint32(1)).Return(eventModel, nil) - - client := newTailLogsTestClient(t, actionRepo, streamer) +func TestTailLogs_MissingActionID(t *testing.T) { + client := newTailLogsTestClient(t, &fakeDataProxyHandler{}) stream, err := client.TailLogs(context.Background(), connect.NewRequest(&workflow.TailLogsRequest{ - ActionId: tailLogsActionID, - Attempt: 1, + Attempt: 1, })) assert.NoError(t, err) assert.False(t, stream.Receive()) assert.Error(t, stream.Err()) - assert.Equal(t, connect.CodeNotFound, connect.CodeOf(stream.Err())) - - actionRepo.AssertExpectations(t) + assert.Equal(t, connect.CodeInvalidArgument, connect.CodeOf(stream.Err())) } -func TestTailLogs_GetLatestEventError(t *testing.T) { - actionRepo := &repoMocks.ActionRepo{} - streamer := &mockLogStreamer{} - - actionRepo.On("GetLatestEventByAttempt", mock.Anything, mock.Anything, uint32(1)).Return(nil, assert.AnError) - - client := newTailLogsTestClient(t, actionRepo, streamer) - - stream, err := client.TailLogs(context.Background(), connect.NewRequest(&workflow.TailLogsRequest{ - ActionId: tailLogsActionID, - Attempt: 1, - })) - assert.NoError(t, err) - - assert.False(t, stream.Receive()) - assert.Error(t, stream.Err()) - assert.Equal(t, connect.CodeInternal, connect.CodeOf(stream.Err())) - - actionRepo.AssertExpectations(t) -} - -func TestTailLogs_StreamerError(t *testing.T) { - actionRepo := &repoMocks.ActionRepo{} - streamer := &mockLogStreamer{} - - logCtx := &core.LogContext{ - PrimaryPodName: "my-pod", - Pods: []*core.PodLogContext{ - {PodName: "my-pod", Namespace: "ns"}, +func TestTailLogs_DataProxyError(t *testing.T) { + dpHandler := &fakeDataProxyHandler{ + tailLogsFn: func(_ context.Context, _ *connect.Request[dataproxy.TailLogsRequest], _ *connect.ServerStream[dataproxy.TailLogsResponse]) error { + return connect.NewError(connect.CodeNotFound, fmt.Errorf("action not found")) }, } - eventModel := makeEventWithLogContext(tailLogsActionID, 1, common.ActionPhase_ACTION_PHASE_RUNNING, logCtx) - actionRepo.On("GetLatestEventByAttempt", mock.Anything, mock.Anything, uint32(1)).Return(eventModel, nil) - - streamerErr := connect.NewError(connect.CodeInternal, assert.AnError) - streamer.On("TailLogs", mock.Anything, mock.Anything, mock.Anything).Return(streamerErr) - - client := newTailLogsTestClient(t, actionRepo, streamer) + client := newTailLogsTestClient(t, dpHandler) stream, err := client.TailLogs(context.Background(), connect.NewRequest(&workflow.TailLogsRequest{ ActionId: tailLogsActionID, @@ -192,43 +126,24 @@ func TestTailLogs_StreamerError(t *testing.T) { assert.False(t, stream.Receive()) assert.Error(t, stream.Err()) - - actionRepo.AssertExpectations(t) - streamer.AssertExpectations(t) } func TestTailLogs_ConcurrencyLimit(t *testing.T) { - actionRepo := &repoMocks.ActionRepo{} - streamer := &mockLogStreamer{} - - logCtx := &core.LogContext{ - PrimaryPodName: "my-pod", - Pods: []*core.PodLogContext{ - {PodName: "my-pod", Namespace: "ns"}, - }, - } - - eventModel := makeEventWithLogContext(tailLogsActionID, 1, common.ActionPhase_ACTION_PHASE_RUNNING, logCtx) - actionRepo.On("GetLatestEventByAttempt", mock.Anything, mock.Anything, mock.Anything).Return(eventModel, nil) - - // Block in streamer to hold semaphore slots. - blocker := make(chan struct{}) - streamer.On("TailLogs", mock.Anything, mock.Anything, mock.Anything).Run(func(_ mock.Arguments) { - <-blocker - }).Return(nil) - - repo := &repoMocks.Repository{} - repo.On("ActionRepo").Return(actionRepo) - - svc := NewRunLogsService(repo, streamer) - // Exhaust the semaphore by acquiring all slots. - svc.sem.Acquire(context.Background(), defaultMaxConcurrentStreams) - - // The next TailLogs should be rejected. + dpMux := http.NewServeMux() + dpPath, dpHTTPHandler := dataproxyconnect.NewDataProxyServiceHandler(&fakeDataProxyHandler{}) + dpMux.Handle(dpPath, dpHTTPHandler) + dpServer := httptest.NewServer(dpMux) + t.Cleanup(dpServer.Close) + + dpClient := dataproxyconnect.NewDataProxyServiceClient(http.DefaultClient, dpServer.URL) + svc := NewRunLogsService(dpClient) + // Exhaust the semaphore so the next request is rejected. + svc.sem.Acquire(context.Background(), defaultMaxConcurrentStreams) //nolint:errcheck + + runLogsMux := http.NewServeMux() path, handler := workflowconnect.NewRunLogsServiceHandler(svc) - mux := http.NewServeMux() - mux.Handle(path, handler) - server := httptest.NewServer(mux) + runLogsMux.Handle(path, handler) + server := httptest.NewServer(runLogsMux) t.Cleanup(server.Close) client := workflowconnect.NewRunLogsServiceClient(http.DefaultClient, server.URL) @@ -242,76 +157,29 @@ func TestTailLogs_ConcurrencyLimit(t *testing.T) { assert.Error(t, stream.Err()) assert.Equal(t, connect.CodeResourceExhausted, connect.CodeOf(stream.Err())) - // Release so cleanup doesn't hang. svc.sem.Release(defaultMaxConcurrentStreams) - close(blocker) } -func TestTailLogs_AttemptZero(t *testing.T) { - actionRepo := &repoMocks.ActionRepo{} - streamer := &mockLogStreamer{} - - logCtx := &core.LogContext{ - PrimaryPodName: "my-pod", - Pods: []*core.PodLogContext{ - {PodName: "my-pod", Namespace: "ns"}, +func TestTailLogs_RequestForwardedToDataProxy(t *testing.T) { + var capturedReq *dataproxy.TailLogsRequest + dpHandler := &fakeDataProxyHandler{ + tailLogsFn: func(_ context.Context, req *connect.Request[dataproxy.TailLogsRequest], _ *connect.ServerStream[dataproxy.TailLogsResponse]) error { + capturedReq = req.Msg + return nil }, } - eventModel := makeEventWithLogContext(tailLogsActionID, 0, common.ActionPhase_ACTION_PHASE_RUNNING, logCtx) - actionRepo.On("GetLatestEventByAttempt", mock.Anything, mock.Anything, uint32(0)).Return(eventModel, nil) - - streamer.On("TailLogs", mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) { - stream := args.Get(2).(*connect.ServerStream[workflow.TailLogsResponse]) - _ = stream.Send(&workflow.TailLogsResponse{ - Logs: []*workflow.TailLogsResponse_Logs{ - { - Lines: []*dataplane.LogLine{ - {Message: "attempt zero log", Originator: dataplane.LogLineOriginator_USER}, - }, - }, - }, - }) - }).Return(nil) - - client := newTailLogsTestClient(t, actionRepo, streamer) + client := newTailLogsTestClient(t, dpHandler) stream, err := client.TailLogs(context.Background(), connect.NewRequest(&workflow.TailLogsRequest{ ActionId: tailLogsActionID, - Attempt: 0, + Attempt: 3, })) assert.NoError(t, err) - - assert.True(t, stream.Receive()) - resp := stream.Msg() - assert.Len(t, resp.Logs, 1) - assert.Equal(t, "attempt zero log", resp.Logs[0].Lines[0].Message) - assert.False(t, stream.Receive()) assert.NoError(t, stream.Err()) - actionRepo.AssertExpectations(t) - streamer.AssertExpectations(t) -} - -func TestTailLogs_EventNotFound(t *testing.T) { - actionRepo := &repoMocks.ActionRepo{} - streamer := &mockLogStreamer{} - - actionRepo.On("GetLatestEventByAttempt", mock.Anything, mock.Anything, uint32(5)). - Return(nil, fmt.Errorf("event not found for attempt 5: %w", sql.ErrNoRows)) - - client := newTailLogsTestClient(t, actionRepo, streamer) - - stream, err := client.TailLogs(context.Background(), connect.NewRequest(&workflow.TailLogsRequest{ - ActionId: tailLogsActionID, - Attempt: 5, - })) - assert.NoError(t, err) - - assert.False(t, stream.Receive()) - assert.Error(t, stream.Err()) - assert.Equal(t, connect.CodeNotFound, connect.CodeOf(stream.Err())) - - actionRepo.AssertExpectations(t) + assert.NotNil(t, capturedReq) + assert.Equal(t, tailLogsActionID.GetName(), capturedReq.GetActionId().GetName()) + assert.Equal(t, uint32(3), capturedReq.GetAttempt()) } diff --git a/runs/setup.go b/runs/setup.go index 22483574ed..77f70362a7 100644 --- a/runs/setup.go +++ b/runs/setup.go @@ -120,16 +120,15 @@ func Setup(ctx context.Context, sc *app.SetupContext) error { sc.Mux.Handle(projectPath, projectHandler) logger.Infof(ctx, "Mounted ProjectService at %s", projectPath) - if sc.K8sConfig != nil { - logStreamer, err := service.NewK8sLogStreamer(sc.K8sConfig) - if err != nil { - return fmt.Errorf("runs: failed to create k8s log streamer: %w", err) - } - runLogsSvc := service.NewRunLogsService(repo, logStreamer) - runLogsPath, runLogsHandler := workflowconnect.NewRunLogsServiceHandler(runLogsSvc) - sc.Mux.Handle(runLogsPath, runLogsHandler) - logger.Infof(ctx, "Mounted RunLogsService at %s", runLogsPath) + dataProxyURL := cfg.DataProxyServiceURL + if sc.BaseURL != "" { + dataProxyURL = sc.BaseURL } + dataProxyClient := dataproxyconnect.NewDataProxyServiceClient(http.DefaultClient, dataProxyURL) + runLogsSvc := service.NewRunLogsService(dataProxyClient) + runLogsPath, runLogsHandler := workflowconnect.NewRunLogsServiceHandler(runLogsSvc) + sc.Mux.Handle(runLogsPath, runLogsHandler) + logger.Infof(ctx, "Mounted RunLogsService at %s", runLogsPath) if err := seedProjects(ctx, impl.NewProjectRepo(sc.DB), cfg.SeedProjects); err != nil { return fmt.Errorf("runs: failed to seed projects: %w", err)