Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions flyte-single-binary-local.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ tasks:
- K8S-ARRAY
- connector-service
- echo
- core-sleep
default-for-task-types:
- container: container
- container_array: K8S-ARRAY
Expand Down
155 changes: 155 additions & 0 deletions flyteplugins/go/tasks/plugins/core/sleep/plugin.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
package sleep

import (
"context"
"errors"
"fmt"
"sync"
"time"

"github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery"
core "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core"
idlcore "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core"
)

const sleepTaskType = "core-sleep"

type invalidInputError struct {
message string
}

func (e *invalidInputError) Error() string {
return e.message
}

type Plugin struct {
taskStartTimes map[string]time.Time
sync.Mutex
}

func (p *Plugin) GetID() string {
return sleepTaskType
}

func (p *Plugin) GetProperties() core.PluginProperties {
return core.PluginProperties{}
}

func (p *Plugin) getOrAddTaskStartTime(tCtx core.TaskExecutionContext) time.Time {
p.Lock()
defer p.Unlock()

taskExecutionID := tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName()
if startTime, exists := p.taskStartTimes[taskExecutionID]; exists {
return startTime
}

startTime := time.Now()
p.taskStartTimes[taskExecutionID] = startTime
return startTime
}

func (p *Plugin) removeTask(taskExecutionID string) {
p.Lock()
defer p.Unlock()
delete(p.taskStartTimes, taskExecutionID)
}

func (p *Plugin) Handle(ctx context.Context, tCtx core.TaskExecutionContext) (core.Transition, error) {
sleepDuration, err := resolveSleepDuration(ctx, tCtx)
if err != nil {
var invalidErr *invalidInputError
if errors.As(err, &invalidErr) {
return core.DoTransition(core.PhaseInfoFailure("BadTaskSpecification", invalidErr.Error(), nil)), nil
}
return core.UnknownTransition, err
}

if sleepDuration == 0 {
return core.DoTransition(core.PhaseInfoSuccess(nil)), nil
}

startTime := p.getOrAddTaskStartTime(tCtx)
if time.Since(startTime) >= sleepDuration {
return core.DoTransition(core.PhaseInfoSuccess(nil)), nil
}

return core.DoTransition(core.PhaseInfoRunning(core.DefaultPhaseVersion, nil)), nil
}

func (p *Plugin) Abort(ctx context.Context, tCtx core.TaskExecutionContext) error {
return nil
}

func (p *Plugin) Finalize(ctx context.Context, tCtx core.TaskExecutionContext) error {
taskExecutionID := tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName()
p.removeTask(taskExecutionID)
return nil
}

func resolveSleepDuration(ctx context.Context, tCtx core.TaskExecutionContext) (time.Duration, error) {
taskTemplate, err := tCtx.TaskReader().Read(ctx)
if err != nil {
return 0, fmt.Errorf("failed to read task template: %w", err)
}
if taskTemplate == nil {
return 0, fmt.Errorf("nil task template")
}

iface := taskTemplate.GetInterface()
if iface == nil || iface.GetInputs() == nil || len(iface.GetInputs().GetVariables()) != 1 {
return 0, &invalidInputError{message: fmt.Sprintf("task type [%s] requires exactly one duration input", sleepTaskType)}
}
if iface.GetOutputs() != nil && len(iface.GetOutputs().GetVariables()) != 0 {
return 0, &invalidInputError{message: fmt.Sprintf("task type [%s] does not support outputs", sleepTaskType)}
}

// v1 GetVariables() returns map[string]*Variable (v2 returns []*VariableEntry)
inputName := ""
for name, v := range iface.GetInputs().GetVariables() {
if v == nil || v.GetType() == nil || v.GetType().GetSimple() != idlcore.SimpleType_DURATION {
return 0, &invalidInputError{message: fmt.Sprintf("input [%s] must be typed as duration", name)}
}
inputName = name
}

inputs, err := tCtx.InputReader().Get(ctx)
if err != nil {
return 0, fmt.Errorf("failed to read task inputs: %w", err)
}
if inputs == nil {
return 0, &invalidInputError{message: fmt.Sprintf("task type [%s] requires a duration input value", sleepTaskType)}
}

literal, ok := inputs.GetLiterals()[inputName]
if !ok || literal == nil {
return 0, &invalidInputError{message: fmt.Sprintf("duration input [%s] is missing", inputName)}
}

durationValue := literal.GetScalar().GetPrimitive().GetDuration()
if durationValue == nil {
return 0, &invalidInputError{message: fmt.Sprintf("duration input [%s] must be a duration literal", inputName)}
}

sleepDuration := durationValue.AsDuration()
if sleepDuration < 0 {
return 0, &invalidInputError{message: fmt.Sprintf("duration input [%s] must be non-negative", inputName)}
}

return sleepDuration, nil
}

func init() {
pluginmachinery.PluginRegistry().RegisterCorePlugin(
core.PluginEntry{
ID: sleepTaskType,
RegisteredTaskTypes: []core.TaskType{sleepTaskType},
LoadPlugin: func(ctx context.Context, iCtx core.SetupContext) (core.Plugin, error) {
return &Plugin{
taskStartTimes: make(map[string]time.Time),
}, nil
},
IsDefault: false,
},
)
}
127 changes: 127 additions & 0 deletions flyteplugins/go/tasks/plugins/core/sleep/plugin_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
package sleep

import (
"context"
"testing"
"time"

core "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core"
coreMocks "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core/mocks"
ioMocks "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/io/mocks"
idlcore "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/types/known/durationpb"
)

func TestHandleSucceedsImmediatelyForZeroDuration(t *testing.T) {
ctx := context.Background()
plugin := &Plugin{taskStartTimes: make(map[string]time.Time)}
tCtx := newTaskExecutionContext(time.Duration(0), true, "task-1")

trns, err := plugin.Handle(ctx, tCtx)
require.NoError(t, err)
assert.Equal(t, core.PhaseSuccess, trns.Info().Phase())
assert.Empty(t, plugin.taskStartTimes)
}

func TestHandleWaitsUntilDurationElapses(t *testing.T) {
ctx := context.Background()
plugin := &Plugin{taskStartTimes: make(map[string]time.Time)}
tCtx := newTaskExecutionContext(30*time.Second, true, "task-1")

first, err := plugin.Handle(ctx, tCtx)
require.NoError(t, err)
assert.Equal(t, core.PhaseRunning, first.Info().Phase())

plugin.Lock()
plugin.taskStartTimes["task-1"] = time.Now().Add(-31 * time.Second)
plugin.Unlock()

second, err := plugin.Handle(ctx, tCtx)
require.NoError(t, err)
assert.Equal(t, core.PhaseSuccess, second.Info().Phase())

require.NoError(t, plugin.Finalize(ctx, tCtx))
assert.Empty(t, plugin.taskStartTimes)
}

func TestHandleReturnsUserFailureForInvalidInput(t *testing.T) {
ctx := context.Background()
plugin := &Plugin{taskStartTimes: make(map[string]time.Time)}
tCtx := newTaskExecutionContext(time.Second, false, "task-1")

trns, err := plugin.Handle(ctx, tCtx)
require.NoError(t, err)
assert.Equal(t, core.PhasePermanentFailure, trns.Info().Phase())
require.NotNil(t, trns.Info().Err())
assert.Equal(t, idlcore.ExecutionError_USER, trns.Info().Err().GetKind())
assert.Contains(t, trns.Info().Err().GetMessage(), "duration")
}

func newTaskExecutionContext(sleepDuration time.Duration, validInput bool, generatedName string) *coreMocks.TaskExecutionContext {
// v1 VariableMap uses map[string]*Variable (v2 uses []*VariableEntry)
taskTemplate := &idlcore.TaskTemplate{
Type: sleepTaskType,
Interface: &idlcore.TypedInterface{
Inputs: &idlcore.VariableMap{
Variables: map[string]*idlcore.Variable{
"duration": {
Type: literalType(validInput),
},
},
},
},
}

inputs := &idlcore.LiteralMap{
Literals: map[string]*idlcore.Literal{
"duration": {
Value: &idlcore.Literal_Scalar{
Scalar: &idlcore.Scalar{
Value: &idlcore.Scalar_Primitive{
Primitive: &idlcore.Primitive{
Value: &idlcore.Primitive_Duration{
Duration: durationpb.New(sleepDuration),
},
},
},
},
},
},
},
}

taskReader := &coreMocks.TaskReader{}
taskReader.EXPECT().Read(mock.Anything).Return(taskTemplate, nil)

inputReader := &ioMocks.InputReader{}
inputReader.EXPECT().Get(mock.Anything).Return(inputs, nil)

taskExecutionID := &coreMocks.TaskExecutionID{}
taskExecutionID.EXPECT().GetGeneratedName().Return(generatedName)

metadata := &coreMocks.TaskExecutionMetadata{}
metadata.EXPECT().GetTaskExecutionID().Return(taskExecutionID)

tCtx := &coreMocks.TaskExecutionContext{}
tCtx.EXPECT().TaskReader().Return(taskReader)
tCtx.EXPECT().InputReader().Return(inputReader)
tCtx.EXPECT().TaskExecutionMetadata().Return(metadata)

return tCtx
}

func literalType(validInput bool) *idlcore.LiteralType {
simpleType := idlcore.SimpleType_DURATION
if !validInput {
simpleType = idlcore.SimpleType_STRING
}

return &idlcore.LiteralType{
Type: &idlcore.LiteralType_Simple{
Simple: simpleType,
},
}
}
2 changes: 2 additions & 0 deletions flytepropeller/plugins/loader.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package plugins

import (
// Common place to import all plugins, so that it can be imported by Singlebinary (flytelite) or by propeller main
_ "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/core/sleep"

_ "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/array/awsbatch"
_ "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/array/k8s"
_ "github.com/flyteorg/flyte/flyteplugins/go/tasks/plugins/k8s/dask"
Expand Down
Loading