Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions flyteplugins/go/tasks/plugins/k8s/ray/ray.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
pluginsCore "github.com/flyteorg/flyte/v2/flyteplugins/go/tasks/pluginmachinery/core"
"github.com/flyteorg/flyte/v2/flyteplugins/go/tasks/pluginmachinery/flytek8s"
"github.com/flyteorg/flyte/v2/flyteplugins/go/tasks/pluginmachinery/flytek8s/config"
"github.com/flyteorg/flyte/v2/flyteplugins/go/tasks/pluginmachinery/ioutils"
"github.com/flyteorg/flyte/v2/flyteplugins/go/tasks/pluginmachinery/k8s"
"github.com/flyteorg/flyte/v2/flyteplugins/go/tasks/pluginmachinery/tasklog"
"github.com/flyteorg/flyte/v2/flyteplugins/go/tasks/pluginmachinery/utils"
Expand Down Expand Up @@ -830,7 +831,20 @@ func (plugin rayJobResourceHandler) GetTaskPhase(ctx context.Context, pluginCont
return pluginsCore.PhaseInfoQueuedWithTaskInfo(time.Now(), pluginsCore.DefaultPhaseVersion, "cluster is suspended", info), nil
case rayv1.JobDeploymentStatusFailed:
failInfo := fmt.Sprintf("Failed to run Ray job %s with error: [%s] %s", rayJob.Name, rayJob.Status.Reason, rayJob.Status.Message)
// Honor a RECOVERABLE error.pb (written by sdk) so the task's retries fire. A failed RayJob surfaces here as a
// terminal phase, so -- unlike the success path -- the k8s plugin manager never reads the
// error file on our behalf. Key off the proto-level recoverability so only a genuine
// RECOVERABLE container error retries: an absent, unreadable, or malformed error file is
// reported by the reader as a SYSTEM error and stays terminal, preserving previous behavior.
phaseInfo, err = pluginsCore.PhaseInfoFailure(flyteerr.TaskFailedWithError, failInfo, info), nil
if ow := pluginContext.OutputWriter(); ow != nil {
reader := ioutils.NewRemoteFileOutputReader(ctx, pluginContext.DataStore(), ow, 0)
if hasErr, readerErr := reader.IsError(ctx); readerErr == nil && hasErr {
if execErr, readerErr := reader.ReadError(ctx); readerErr == nil && execErr.GetRecoverability() == core.ContainerError_RECOVERABLE {
phaseInfo = pluginsCore.PhaseInfoRetryableFailure(flyteerr.TaskFailedWithError, failInfo, info)
}
}
}
default:
// We already handle all known deployment status, so this should never happen unless a future version of ray
// introduced a new job status.
Expand Down
101 changes: 98 additions & 3 deletions flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@ import (
k8smocks "github.com/flyteorg/flyte/v2/flyteplugins/go/tasks/pluginmachinery/k8s/mocks"
"github.com/flyteorg/flyte/v2/flyteplugins/go/tasks/pluginmachinery/tasklog"
"github.com/flyteorg/flyte/v2/flyteplugins/go/tasks/pluginmachinery/utils"
"github.com/flyteorg/flyte/v2/flytestdlib/contextutils"
"github.com/flyteorg/flyte/v2/flytestdlib/promutils"
"github.com/flyteorg/flyte/v2/flytestdlib/promutils/labeled"
"github.com/flyteorg/flyte/v2/flytestdlib/storage"
"github.com/flyteorg/flyte/v2/gen/go/flyteidl2/core"
"github.com/flyteorg/flyte/v2/gen/go/flyteidl2/plugins"
)
Expand Down Expand Up @@ -1463,6 +1467,8 @@ func newPluginContext(pluginState k8s.PluginState) *k8smocks.PluginContext {
}

func init() {
labeled.SetMetricKeys(contextutils.NamespaceKey)

f := defaultConfig
f.Logs = logs.LogConfig{
IsKubernetesEnabled: true,
Expand Down Expand Up @@ -1491,7 +1497,7 @@ func TestGetTaskPhase(t *testing.T) {
{rayv1.JobDeploymentStatusSuspending, pluginsCore.PhaseQueued, false},
}

startTime := time.Date(2024, 0, 0, 0, 0, 0, 0, time.UTC)
startTime := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
endTime := startTime.Add(time.Hour)
podName, contName, initCont := "ray-clust-ray-head", "ray-head", "init"
logCtx := &core.LogContext{
Expand Down Expand Up @@ -1561,6 +1567,69 @@ func TestGetTaskPhase(t *testing.T) {
}
}

// TestGetTaskPhase_RecoverableErrorFile verifies that a failed RayJob is mapped to a *retryable*
// failure when the task wrote a RECOVERABLE error.pb (e.g. user code raised
// FlyteRecoverableException) and stays terminal when the error file marks a non-recoverable error
// or is absent. This mirrors how the k8s plugin manager honors the error file for container/pod
// tasks on the success path, which a terminal RayJob phase would otherwise bypass.
func TestGetTaskPhase_RecoverableErrorFile(t *testing.T) {
ctx := context.Background()
handler := rayJobResourceHandler{}

newFailedRayJob := func() *rayv1.RayJob {
startTime := metav1.NewTime(time.Now())
return &rayv1.RayJob{
Spec: rayv1.RayJobSpec{
RayClusterSpec: &rayv1.RayClusterSpec{
HeadGroupSpec: rayv1.HeadGroupSpec{
Template: corev1.PodTemplateSpec{
Spec: corev1.PodSpec{
Containers: []corev1.Container{
{Name: "ray-head", Image: "rayproject/ray:latest"},
},
},
},
},
},
},
Status: rayv1.RayJobStatus{
JobDeploymentStatus: rayv1.JobDeploymentStatusFailed,
RayClusterName: "ray-clust",
Reason: "AppFailed",
Message: "Job entrypoint command failed with exit code 1",
StartTime: &startTime,
},
}
}

newErrorDoc := func(kind core.ContainerError_Kind) *core.ErrorDocument {
return &core.ErrorDocument{Error: &core.ContainerError{
Code: "USER:Unknown",
Message: "boom",
Kind: kind,
Origin: core.ExecutionError_USER,
}}
}

for _, tc := range []struct {
name string
errorDoc *core.ErrorDocument
expectedPhase pluginsCore.Phase
}{
{"recoverable error.pb maps to retryable failure", newErrorDoc(core.ContainerError_RECOVERABLE), pluginsCore.PhaseRetryableFailure},
{"non-recoverable error.pb stays terminal", newErrorDoc(core.ContainerError_NON_RECOVERABLE), pluginsCore.PhasePermanentFailure},
{"absent error.pb stays terminal", nil, pluginsCore.PhasePermanentFailure},
{"malformed error.pb (nil Error) stays terminal", &core.ErrorDocument{}, pluginsCore.PhasePermanentFailure},
} {
t.Run(tc.name, func(t *testing.T) {
pluginCtx := rayPluginContextWithErrorDoc(k8s.PluginState{}, tc.errorDoc)
phaseInfo, err := handler.GetTaskPhase(ctx, pluginCtx, newFailedRayJob())
assert.NoError(t, err)
assert.Equal(t, tc.expectedPhase.String(), phaseInfo.Phase().String())
})
}
}

func TestGetTaskPhaseIncreasePhaseVersion(t *testing.T) {
rayJobResourceHandler := rayJobResourceHandler{}

Expand Down Expand Up @@ -1947,9 +2016,9 @@ func TestGetPropertiesRay(t *testing.T) {
assert.Equal(t, expected, rayJobResourceHandler.GetProperties())
}

func rayPluginContext(pluginState k8s.PluginState) *k8smocks.PluginContext {
func rayPluginContextWithErrorDoc(pluginState k8s.PluginState, errorDoc *core.ErrorDocument) *k8smocks.PluginContext {
pluginCtx := newPluginContext(pluginState)
startTime := time.Date(2024, 0, 0, 0, 0, 0, 0, time.UTC)
startTime := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
endTime := startTime.Add(time.Hour)
podName, contName, initCont := "ray-clust-ray-head", "ray-head", "init"
podList := []runtime.Object{
Expand Down Expand Up @@ -1993,9 +2062,35 @@ func rayPluginContext(pluginState k8s.PluginState) *k8smocks.PluginContext {
}
reader := fake.NewFakeClient(podList...)
pluginCtx.EXPECT().K8sReader().Return(reader)
wireErrorFile(pluginCtx, errorDoc)
return pluginCtx
}

// rayPluginContext builds a plugin context whose task error file is absent (the common case).
func rayPluginContext(pluginState k8s.PluginState) *k8smocks.PluginContext {
return rayPluginContextWithErrorDoc(pluginState, nil)
}

// wireErrorFile backs the plugin context's OutputWriter/DataStore with an in-memory store. A
// non-nil errorDoc is written to the task's error.pb path so GetTaskPhase can read it back; a nil
// errorDoc leaves the store empty, modeling a task that produced no error file.
func wireErrorFile(pluginCtx *k8smocks.PluginContext, errorDoc *core.ErrorDocument) {
store, err := storage.NewDataStore(&storage.Config{Type: storage.TypeMemory}, promutils.NewTestScope())
if err != nil {
panic(err)
}
errPath := storage.DataReference("/error.pb")
if errorDoc != nil {
if err := store.WriteProtobuf(context.Background(), errPath, storage.Options{}, errorDoc); err != nil {
panic(err)
}
}
ow := &pluginIOMocks.OutputWriter{}
ow.EXPECT().GetErrorPath().Return(errPath).Maybe()
pluginCtx.EXPECT().OutputWriter().Return(ow).Maybe()
pluginCtx.EXPECT().DataStore().Return(store).Maybe()
}

func transformStructToStructPB(t *testing.T, obj interface{}) *structpb.Struct {
data, err := json.Marshal(obj)
assert.Nil(t, err)
Expand Down
Loading