diff --git a/flyte-single-binary-local.yaml b/flyte-single-binary-local.yaml index 0b3caf45f1..0977283e51 100644 --- a/flyte-single-binary-local.yaml +++ b/flyte-single-binary-local.yaml @@ -45,6 +45,7 @@ tasks: - K8S-ARRAY - connector-service - echo + - core-sleep default-for-task-types: - container: container - container_array: K8S-ARRAY diff --git a/flyteplugins/go/tasks/plugins/core/sleep/plugin.go b/flyteplugins/go/tasks/plugins/core/sleep/plugin.go new file mode 100644 index 0000000000..5f3d465ad1 --- /dev/null +++ b/flyteplugins/go/tasks/plugins/core/sleep/plugin.go @@ -0,0 +1,155 @@ +package sleep + +import ( + "context" + "errors" + "fmt" + "sync" + "time" + + "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery" + core "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core" + idlcore "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" +) + +const sleepTaskType = "core-sleep" + +type invalidInputError struct { + message string +} + +func (e *invalidInputError) Error() string { + return e.message +} + +type Plugin struct { + taskStartTimes map[string]time.Time + sync.Mutex +} + +func (p *Plugin) GetID() string { + return sleepTaskType +} + +func (p *Plugin) GetProperties() core.PluginProperties { + return core.PluginProperties{} +} + +func (p *Plugin) getOrAddTaskStartTime(tCtx core.TaskExecutionContext) time.Time { + p.Lock() + defer p.Unlock() + + taskExecutionID := tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName() + if startTime, exists := p.taskStartTimes[taskExecutionID]; exists { + return startTime + } + + startTime := time.Now() + p.taskStartTimes[taskExecutionID] = startTime + return startTime +} + +func (p *Plugin) removeTask(taskExecutionID string) { + p.Lock() + defer p.Unlock() + delete(p.taskStartTimes, taskExecutionID) +} + +func (p *Plugin) Handle(ctx context.Context, tCtx core.TaskExecutionContext) (core.Transition, error) { + sleepDuration, err := resolveSleepDuration(ctx, tCtx) + if err != nil { + var invalidErr *invalidInputError + if errors.As(err, &invalidErr) { + return core.DoTransition(core.PhaseInfoFailure("BadTaskSpecification", invalidErr.Error(), nil)), nil + } + return core.UnknownTransition, err + } + + if sleepDuration == 0 { + return core.DoTransition(core.PhaseInfoSuccess(nil)), nil + } + + startTime := p.getOrAddTaskStartTime(tCtx) + if time.Since(startTime) >= sleepDuration { + return core.DoTransition(core.PhaseInfoSuccess(nil)), nil + } + + return core.DoTransition(core.PhaseInfoRunning(core.DefaultPhaseVersion, nil)), nil +} + +func (p *Plugin) Abort(ctx context.Context, tCtx core.TaskExecutionContext) error { + return nil +} + +func (p *Plugin) Finalize(ctx context.Context, tCtx core.TaskExecutionContext) error { + taskExecutionID := tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName() + p.removeTask(taskExecutionID) + return nil +} + +func resolveSleepDuration(ctx context.Context, tCtx core.TaskExecutionContext) (time.Duration, error) { + taskTemplate, err := tCtx.TaskReader().Read(ctx) + if err != nil { + return 0, fmt.Errorf("failed to read task template: %w", err) + } + if taskTemplate == nil { + return 0, fmt.Errorf("nil task template") + } + + iface := taskTemplate.GetInterface() + if iface == nil || iface.GetInputs() == nil || len(iface.GetInputs().GetVariables()) != 1 { + return 0, &invalidInputError{message: fmt.Sprintf("task type [%s] requires exactly one duration input", sleepTaskType)} + } + if iface.GetOutputs() != nil && len(iface.GetOutputs().GetVariables()) != 0 { + return 0, &invalidInputError{message: fmt.Sprintf("task type [%s] does not support outputs", sleepTaskType)} + } + + // v1 GetVariables() returns map[string]*Variable (v2 returns []*VariableEntry) + inputName := "" + for name, v := range iface.GetInputs().GetVariables() { + if v == nil || v.GetType() == nil || v.GetType().GetSimple() != idlcore.SimpleType_DURATION { + return 0, &invalidInputError{message: fmt.Sprintf("input [%s] must be typed as duration", name)} + } + inputName = name + } + + inputs, err := tCtx.InputReader().Get(ctx) + if err != nil { + return 0, fmt.Errorf("failed to read task inputs: %w", err) + } + if inputs == nil { + return 0, &invalidInputError{message: fmt.Sprintf("task type [%s] requires a duration input value", sleepTaskType)} + } + + literal, ok := inputs.GetLiterals()[inputName] + if !ok || literal == nil { + return 0, &invalidInputError{message: fmt.Sprintf("duration input [%s] is missing", inputName)} + } + + durationValue := literal.GetScalar().GetPrimitive().GetDuration() + if durationValue == nil { + return 0, &invalidInputError{message: fmt.Sprintf("duration input [%s] must be a duration literal", inputName)} + } + + sleepDuration := durationValue.AsDuration() + if sleepDuration < 0 { + return 0, &invalidInputError{message: fmt.Sprintf("duration input [%s] must be non-negative", inputName)} + } + + return sleepDuration, nil +} + +func init() { + pluginmachinery.PluginRegistry().RegisterCorePlugin( + core.PluginEntry{ + ID: sleepTaskType, + RegisteredTaskTypes: []core.TaskType{sleepTaskType}, + LoadPlugin: func(ctx context.Context, iCtx core.SetupContext) (core.Plugin, error) { + return &Plugin{ + taskStartTimes: make(map[string]time.Time), + }, nil + }, + IsDefault: false, + }, + ) +} diff --git a/flyteplugins/go/tasks/plugins/core/sleep/plugin_test.go b/flyteplugins/go/tasks/plugins/core/sleep/plugin_test.go new file mode 100644 index 0000000000..355560591f --- /dev/null +++ b/flyteplugins/go/tasks/plugins/core/sleep/plugin_test.go @@ -0,0 +1,127 @@ +package sleep + +import ( + "context" + "testing" + "time" + + core "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core" + coreMocks "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core/mocks" + ioMocks "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/io/mocks" + idlcore "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/types/known/durationpb" +) + +func TestHandleSucceedsImmediatelyForZeroDuration(t *testing.T) { + ctx := context.Background() + plugin := &Plugin{taskStartTimes: make(map[string]time.Time)} + tCtx := newTaskExecutionContext(time.Duration(0), true, "task-1") + + trns, err := plugin.Handle(ctx, tCtx) + require.NoError(t, err) + assert.Equal(t, core.PhaseSuccess, trns.Info().Phase()) + assert.Empty(t, plugin.taskStartTimes) +} + +func TestHandleWaitsUntilDurationElapses(t *testing.T) { + ctx := context.Background() + plugin := &Plugin{taskStartTimes: make(map[string]time.Time)} + tCtx := newTaskExecutionContext(30*time.Second, true, "task-1") + + first, err := plugin.Handle(ctx, tCtx) + require.NoError(t, err) + assert.Equal(t, core.PhaseRunning, first.Info().Phase()) + + plugin.Lock() + plugin.taskStartTimes["task-1"] = time.Now().Add(-31 * time.Second) + plugin.Unlock() + + second, err := plugin.Handle(ctx, tCtx) + require.NoError(t, err) + assert.Equal(t, core.PhaseSuccess, second.Info().Phase()) + + require.NoError(t, plugin.Finalize(ctx, tCtx)) + assert.Empty(t, plugin.taskStartTimes) +} + +func TestHandleReturnsUserFailureForInvalidInput(t *testing.T) { + ctx := context.Background() + plugin := &Plugin{taskStartTimes: make(map[string]time.Time)} + tCtx := newTaskExecutionContext(time.Second, false, "task-1") + + trns, err := plugin.Handle(ctx, tCtx) + require.NoError(t, err) + assert.Equal(t, core.PhasePermanentFailure, trns.Info().Phase()) + require.NotNil(t, trns.Info().Err()) + assert.Equal(t, idlcore.ExecutionError_USER, trns.Info().Err().GetKind()) + assert.Contains(t, trns.Info().Err().GetMessage(), "duration") +} + +func newTaskExecutionContext(sleepDuration time.Duration, validInput bool, generatedName string) *coreMocks.TaskExecutionContext { + // v1 VariableMap uses map[string]*Variable (v2 uses []*VariableEntry) + taskTemplate := &idlcore.TaskTemplate{ + Type: sleepTaskType, + Interface: &idlcore.TypedInterface{ + Inputs: &idlcore.VariableMap{ + Variables: map[string]*idlcore.Variable{ + "duration": { + Type: literalType(validInput), + }, + }, + }, + }, + } + + inputs := &idlcore.LiteralMap{ + Literals: map[string]*idlcore.Literal{ + "duration": { + Value: &idlcore.Literal_Scalar{ + Scalar: &idlcore.Scalar{ + Value: &idlcore.Scalar_Primitive{ + Primitive: &idlcore.Primitive{ + Value: &idlcore.Primitive_Duration{ + Duration: durationpb.New(sleepDuration), + }, + }, + }, + }, + }, + }, + }, + } + + taskReader := &coreMocks.TaskReader{} + taskReader.EXPECT().Read(mock.Anything).Return(taskTemplate, nil) + + inputReader := &ioMocks.InputReader{} + inputReader.EXPECT().Get(mock.Anything).Return(inputs, nil) + + taskExecutionID := &coreMocks.TaskExecutionID{} + taskExecutionID.EXPECT().GetGeneratedName().Return(generatedName) + + metadata := &coreMocks.TaskExecutionMetadata{} + metadata.EXPECT().GetTaskExecutionID().Return(taskExecutionID) + + tCtx := &coreMocks.TaskExecutionContext{} + tCtx.EXPECT().TaskReader().Return(taskReader) + tCtx.EXPECT().InputReader().Return(inputReader) + tCtx.EXPECT().TaskExecutionMetadata().Return(metadata) + + return tCtx +} + +func literalType(validInput bool) *idlcore.LiteralType { + simpleType := idlcore.SimpleType_DURATION + if !validInput { + simpleType = idlcore.SimpleType_STRING + } + + return &idlcore.LiteralType{ + Type: &idlcore.LiteralType_Simple{ + Simple: simpleType, + }, + } +} diff --git a/flytepropeller/plugins/loader.go b/flytepropeller/plugins/loader.go index 87f03a072b..2c212ce317 100644 --- a/flytepropeller/plugins/loader.go +++ b/flytepropeller/plugins/loader.go @@ -3,6 +3,8 @@ package plugins import ( // Common place to import all plugins, so that it can be imported by Singlebinary (flytelite) or by propeller main + _ "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/core/sleep" + _ "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/array/awsbatch" _ "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/array/k8s" _ "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/k8s/dask"