diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/ray.go b/flyteplugins/go/tasks/plugins/k8s/ray/ray.go index 0e9a26e873..e5ccd567df 100644 --- a/flyteplugins/go/tasks/plugins/k8s/ray/ray.go +++ b/flyteplugins/go/tasks/plugins/k8s/ray/ray.go @@ -25,6 +25,7 @@ import ( pluginsCore "github.com/flyteorg/flyte/v2/flyteplugins/go/tasks/pluginmachinery/core" "github.com/flyteorg/flyte/v2/flyteplugins/go/tasks/pluginmachinery/flytek8s" "github.com/flyteorg/flyte/v2/flyteplugins/go/tasks/pluginmachinery/flytek8s/config" + "github.com/flyteorg/flyte/v2/flyteplugins/go/tasks/pluginmachinery/ioutils" "github.com/flyteorg/flyte/v2/flyteplugins/go/tasks/pluginmachinery/k8s" "github.com/flyteorg/flyte/v2/flyteplugins/go/tasks/pluginmachinery/tasklog" "github.com/flyteorg/flyte/v2/flyteplugins/go/tasks/pluginmachinery/utils" @@ -830,7 +831,20 @@ func (plugin rayJobResourceHandler) GetTaskPhase(ctx context.Context, pluginCont return pluginsCore.PhaseInfoQueuedWithTaskInfo(time.Now(), pluginsCore.DefaultPhaseVersion, "cluster is suspended", info), nil case rayv1.JobDeploymentStatusFailed: failInfo := fmt.Sprintf("Failed to run Ray job %s with error: [%s] %s", rayJob.Name, rayJob.Status.Reason, rayJob.Status.Message) + // Honor a RECOVERABLE error.pb (written by sdk) so the task's retries fire. A failed RayJob surfaces here as a + // terminal phase, so -- unlike the success path -- the k8s plugin manager never reads the + // error file on our behalf. Key off the proto-level recoverability so only a genuine + // RECOVERABLE container error retries: an absent, unreadable, or malformed error file is + // reported by the reader as a SYSTEM error and stays terminal, preserving previous behavior. phaseInfo, err = pluginsCore.PhaseInfoFailure(flyteerr.TaskFailedWithError, failInfo, info), nil + if ow := pluginContext.OutputWriter(); ow != nil { + reader := ioutils.NewRemoteFileOutputReader(ctx, pluginContext.DataStore(), ow, 0) + if hasErr, readerErr := reader.IsError(ctx); readerErr == nil && hasErr { + if execErr, readerErr := reader.ReadError(ctx); readerErr == nil && execErr.GetRecoverability() == core.ContainerError_RECOVERABLE { + phaseInfo = pluginsCore.PhaseInfoRetryableFailure(flyteerr.TaskFailedWithError, failInfo, info) + } + } + } default: // We already handle all known deployment status, so this should never happen unless a future version of ray // introduced a new job status. diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go b/flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go index 7bc066d5c4..f13f51940b 100644 --- a/flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go @@ -29,6 +29,10 @@ import ( k8smocks "github.com/flyteorg/flyte/v2/flyteplugins/go/tasks/pluginmachinery/k8s/mocks" "github.com/flyteorg/flyte/v2/flyteplugins/go/tasks/pluginmachinery/tasklog" "github.com/flyteorg/flyte/v2/flyteplugins/go/tasks/pluginmachinery/utils" + "github.com/flyteorg/flyte/v2/flytestdlib/contextutils" + "github.com/flyteorg/flyte/v2/flytestdlib/promutils" + "github.com/flyteorg/flyte/v2/flytestdlib/promutils/labeled" + "github.com/flyteorg/flyte/v2/flytestdlib/storage" "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/core" "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/plugins" ) @@ -1463,6 +1467,8 @@ func newPluginContext(pluginState k8s.PluginState) *k8smocks.PluginContext { } func init() { + labeled.SetMetricKeys(contextutils.NamespaceKey) + f := defaultConfig f.Logs = logs.LogConfig{ IsKubernetesEnabled: true, @@ -1491,7 +1497,7 @@ func TestGetTaskPhase(t *testing.T) { {rayv1.JobDeploymentStatusSuspending, pluginsCore.PhaseQueued, false}, } - startTime := time.Date(2024, 0, 0, 0, 0, 0, 0, time.UTC) + startTime := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC) endTime := startTime.Add(time.Hour) podName, contName, initCont := "ray-clust-ray-head", "ray-head", "init" logCtx := &core.LogContext{ @@ -1561,6 +1567,69 @@ func TestGetTaskPhase(t *testing.T) { } } +// TestGetTaskPhase_RecoverableErrorFile verifies that a failed RayJob is mapped to a *retryable* +// failure when the task wrote a RECOVERABLE error.pb (e.g. user code raised +// FlyteRecoverableException) and stays terminal when the error file marks a non-recoverable error +// or is absent. This mirrors how the k8s plugin manager honors the error file for container/pod +// tasks on the success path, which a terminal RayJob phase would otherwise bypass. +func TestGetTaskPhase_RecoverableErrorFile(t *testing.T) { + ctx := context.Background() + handler := rayJobResourceHandler{} + + newFailedRayJob := func() *rayv1.RayJob { + startTime := metav1.NewTime(time.Now()) + return &rayv1.RayJob{ + Spec: rayv1.RayJobSpec{ + RayClusterSpec: &rayv1.RayClusterSpec{ + HeadGroupSpec: rayv1.HeadGroupSpec{ + Template: corev1.PodTemplateSpec{ + Spec: corev1.PodSpec{ + Containers: []corev1.Container{ + {Name: "ray-head", Image: "rayproject/ray:latest"}, + }, + }, + }, + }, + }, + }, + Status: rayv1.RayJobStatus{ + JobDeploymentStatus: rayv1.JobDeploymentStatusFailed, + RayClusterName: "ray-clust", + Reason: "AppFailed", + Message: "Job entrypoint command failed with exit code 1", + StartTime: &startTime, + }, + } + } + + newErrorDoc := func(kind core.ContainerError_Kind) *core.ErrorDocument { + return &core.ErrorDocument{Error: &core.ContainerError{ + Code: "USER:Unknown", + Message: "boom", + Kind: kind, + Origin: core.ExecutionError_USER, + }} + } + + for _, tc := range []struct { + name string + errorDoc *core.ErrorDocument + expectedPhase pluginsCore.Phase + }{ + {"recoverable error.pb maps to retryable failure", newErrorDoc(core.ContainerError_RECOVERABLE), pluginsCore.PhaseRetryableFailure}, + {"non-recoverable error.pb stays terminal", newErrorDoc(core.ContainerError_NON_RECOVERABLE), pluginsCore.PhasePermanentFailure}, + {"absent error.pb stays terminal", nil, pluginsCore.PhasePermanentFailure}, + {"malformed error.pb (nil Error) stays terminal", &core.ErrorDocument{}, pluginsCore.PhasePermanentFailure}, + } { + t.Run(tc.name, func(t *testing.T) { + pluginCtx := rayPluginContextWithErrorDoc(k8s.PluginState{}, tc.errorDoc) + phaseInfo, err := handler.GetTaskPhase(ctx, pluginCtx, newFailedRayJob()) + assert.NoError(t, err) + assert.Equal(t, tc.expectedPhase.String(), phaseInfo.Phase().String()) + }) + } +} + func TestGetTaskPhaseIncreasePhaseVersion(t *testing.T) { rayJobResourceHandler := rayJobResourceHandler{} @@ -1947,9 +2016,9 @@ func TestGetPropertiesRay(t *testing.T) { assert.Equal(t, expected, rayJobResourceHandler.GetProperties()) } -func rayPluginContext(pluginState k8s.PluginState) *k8smocks.PluginContext { +func rayPluginContextWithErrorDoc(pluginState k8s.PluginState, errorDoc *core.ErrorDocument) *k8smocks.PluginContext { pluginCtx := newPluginContext(pluginState) - startTime := time.Date(2024, 0, 0, 0, 0, 0, 0, time.UTC) + startTime := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC) endTime := startTime.Add(time.Hour) podName, contName, initCont := "ray-clust-ray-head", "ray-head", "init" podList := []runtime.Object{ @@ -1993,9 +2062,35 @@ func rayPluginContext(pluginState k8s.PluginState) *k8smocks.PluginContext { } reader := fake.NewFakeClient(podList...) pluginCtx.EXPECT().K8sReader().Return(reader) + wireErrorFile(pluginCtx, errorDoc) return pluginCtx } +// rayPluginContext builds a plugin context whose task error file is absent (the common case). +func rayPluginContext(pluginState k8s.PluginState) *k8smocks.PluginContext { + return rayPluginContextWithErrorDoc(pluginState, nil) +} + +// wireErrorFile backs the plugin context's OutputWriter/DataStore with an in-memory store. A +// non-nil errorDoc is written to the task's error.pb path so GetTaskPhase can read it back; a nil +// errorDoc leaves the store empty, modeling a task that produced no error file. +func wireErrorFile(pluginCtx *k8smocks.PluginContext, errorDoc *core.ErrorDocument) { + store, err := storage.NewDataStore(&storage.Config{Type: storage.TypeMemory}, promutils.NewTestScope()) + if err != nil { + panic(err) + } + errPath := storage.DataReference("/error.pb") + if errorDoc != nil { + if err := store.WriteProtobuf(context.Background(), errPath, storage.Options{}, errorDoc); err != nil { + panic(err) + } + } + ow := &pluginIOMocks.OutputWriter{} + ow.EXPECT().GetErrorPath().Return(errPath).Maybe() + pluginCtx.EXPECT().OutputWriter().Return(ow).Maybe() + pluginCtx.EXPECT().DataStore().Return(store).Maybe() +} + func transformStructToStructPB(t *testing.T, obj interface{}) *structpb.Struct { data, err := json.Marshal(obj) assert.Nil(t, err)