From 9d6a99d178702d40f4a1d1831dfdd271dd4c1ad6 Mon Sep 17 00:00:00 2001 From: Jason Parraga Date: Wed, 20 May 2026 15:12:43 -0700 Subject: [PATCH 1/4] Migrate protobuf Signed-off-by: Jason Parraga --- flytecopilot/cmd/sidecar_test.go | 2 +- flytecopilot/data/download.go | 27 ++-- flytecopilot/data/upload.go | 2 +- .../go/coreutils/extract_literal_test.go | 2 +- flyteidl2/clients/go/coreutils/literals.go | 8 +- .../clients/go/coreutils/literals_test.go | 2 +- .../go/tasks/pluginmachinery/core/phase.go | 2 +- .../tasks/pluginmachinery/flytek8s/copilot.go | 2 +- .../pluginmachinery/flytek8s/copilot_test.go | 2 +- .../flytek8s/k8s_resource_adds_test.go | 3 +- .../pluginmachinery/flytek8s/pod_helper.go | 2 +- .../ioutils/remote_file_output_reader_test.go | 6 +- .../ioutils/task_reader_test.go | 2 +- .../pluginmachinery/utils/marshal_utils.go | 26 ++-- .../utils/marshal_utils_test.go | 2 +- .../utils/secrets/marshaler.go | 11 +- .../utils/secrets/marshaler_test.go | 4 +- .../go/tasks/plugins/k8s/dask/dask_test.go | 2 +- .../plugins/k8s/kfoperators/mpi/mpi_test.go | 2 +- .../k8s/kfoperators/pytorch/pytorch_test.go | 2 +- .../kfoperators/tensorflow/tensorflow_test.go | 2 +- .../go/tasks/plugins/k8s/pod/sidecar_test.go | 2 +- .../go/tasks/plugins/k8s/spark/spark_test.go | 2 +- flytestdlib/app/error.go | 5 +- .../flytestdlib/storage/mocks/mocks.go | 2 +- flytestdlib/pbhash/pbhash.go | 11 +- flytestdlib/pbhash/pbhash_test.go | 116 +++++++----------- flytestdlib/storage/mocks/mocks.go | 2 +- flytestdlib/storage/protobuf_store.go | 2 +- flytestdlib/storage/protobuf_store_test.go | 46 ++----- flytestdlib/storage/storage.go | 2 +- flytestdlib/utils/marshal_utils.go | 39 +++--- flytestdlib/utils/marshal_utils_test.go | 4 +- flytestdlib/utils/prototest/test_type.pb.go | 5 - go.mod | 2 +- runs/repository/transformers/task.go | 2 +- runs/service/run_service_test.go | 2 +- 37 files changed, 156 insertions(+), 201 deletions(-) diff --git a/flytecopilot/cmd/sidecar_test.go b/flytecopilot/cmd/sidecar_test.go index c8a89c60e5..31bde46349 100644 --- a/flytecopilot/cmd/sidecar_test.go +++ b/flytecopilot/cmd/sidecar_test.go @@ -6,8 +6,8 @@ import ( "os" "testing" - "github.com/golang/protobuf/proto" "github.com/stretchr/testify/assert" + "google.golang.org/protobuf/proto" "github.com/flyteorg/flyte/v2/flytecopilot/cmd/containerwatcher" "github.com/flyteorg/flyte/v2/flytestdlib/promutils" diff --git a/flytecopilot/data/download.go b/flytecopilot/data/download.go index a79c016e93..e849afecd1 100644 --- a/flytecopilot/data/download.go +++ b/flytecopilot/data/download.go @@ -15,10 +15,10 @@ import ( "time" "github.com/ghodss/yaml" - "github.com/golang/protobuf/jsonpb" //nolint: staticcheck - "github.com/golang/protobuf/proto" //nolint: staticcheck - structpb "github.com/golang/protobuf/ptypes/struct" "github.com/pkg/errors" + "google.golang.org/protobuf/encoding/protojson" + "google.golang.org/protobuf/proto" //nolint: staticcheck + structpb "google.golang.org/protobuf/types/known/structpb" "github.com/flyteorg/flyte/v2/flytestdlib/futures" "github.com/flyteorg/flyte/v2/flytestdlib/logger" @@ -289,7 +289,7 @@ func (d Downloader) handleError(_ context.Context, b *core.Error, toFilePath str func (d Downloader) handleGeneric(ctx context.Context, b *structpb.Struct, toFilePath string, writeToFile bool) (interface{}, error) { if writeToFile && b != nil { - m := jsonpb.Marshaler{} + m := protojson.MarshalOptions{} writer, err := os.Create(toFilePath) if err != nil { return nil, errors.Wrapf(err, "failed to open file at path %s", toFilePath) @@ -300,7 +300,12 @@ func (d Downloader) handleGeneric(ctx context.Context, b *structpb.Struct, toFil logger.Errorf(ctx, "failed to close File write stream. Error: %s", err) } }() - return b, m.Marshal(writer, b) + raw, err := m.Marshal(b) + if err != nil { + return nil, err + } + _, err = writer.Write(raw) + return b, err } return b, nil } @@ -310,7 +315,6 @@ func (d Downloader) handlePrimitive(primitive *core.Primitive, toFilePath string var toByteArray func() ([]byte, error) var v interface{} - var err error switch primitive.GetValue().(type) { case *core.Primitive_StringValue: @@ -334,18 +338,18 @@ func (d Downloader) handlePrimitive(primitive *core.Primitive, toFilePath string return []byte(strconv.FormatFloat(primitive.GetFloatValue(), 'f', -1, 64)), nil } case *core.Primitive_Datetime: - v = primitive.GetDatetime().AsTime() - if err != nil { + if err := primitive.GetDatetime().CheckValid(); err != nil { return nil, err } + v = primitive.GetDatetime().AsTime() toByteArray = func() ([]byte, error) { return []byte(primitive.GetDatetime().AsTime().Format(time.RFC3339Nano)), nil } case *core.Primitive_Duration: - v = primitive.GetDuration().AsDuration() - if err != nil { + if err := primitive.GetDuration().CheckValid(); err != nil { return nil, err } + v = primitive.GetDuration().AsDuration() toByteArray = func() ([]byte, error) { return []byte(primitive.GetDuration().AsDuration().String()), nil } @@ -537,6 +541,9 @@ func (d Downloader) DownloadInputs(ctx context.Context, inputRef storage.DataRef if err != nil { return errors.Wrapf(err, "failed to download input variable from remote store") } + if len(lMap.GetLiterals()) == 0 { + return nil + } // We will always write the protobuf b, err := proto.Marshal(lMap) diff --git a/flytecopilot/data/upload.go b/flytecopilot/data/upload.go index 8412e3a9ee..664916b409 100644 --- a/flytecopilot/data/upload.go +++ b/flytecopilot/data/upload.go @@ -9,8 +9,8 @@ import ( "path/filepath" "reflect" - "github.com/golang/protobuf/proto" //nolint: staticcheck "github.com/pkg/errors" + "google.golang.org/protobuf/proto" //nolint: staticcheck "github.com/flyteorg/flyte/v2/flyteidl2/clients/go/coreutils" "github.com/flyteorg/flyte/v2/flytestdlib/futures" diff --git a/flyteidl2/clients/go/coreutils/extract_literal_test.go b/flyteidl2/clients/go/coreutils/extract_literal_test.go index 8781c6b3a5..66562f5ec3 100644 --- a/flyteidl2/clients/go/coreutils/extract_literal_test.go +++ b/flyteidl2/clients/go/coreutils/extract_literal_test.go @@ -8,8 +8,8 @@ import ( "testing" "time" - structpb "github.com/golang/protobuf/ptypes/struct" "github.com/stretchr/testify/assert" + structpb "google.golang.org/protobuf/types/known/structpb" "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/core" ) diff --git a/flyteidl2/clients/go/coreutils/literals.go b/flyteidl2/clients/go/coreutils/literals.go index b3d4bfed36..1925ee216b 100644 --- a/flyteidl2/clients/go/coreutils/literals.go +++ b/flyteidl2/clients/go/coreutils/literals.go @@ -11,11 +11,11 @@ import ( "strings" "time" - "github.com/golang/protobuf/jsonpb" //nolint: staticcheck - structpb "github.com/golang/protobuf/ptypes/struct" "github.com/pkg/errors" "github.com/shamaton/msgpack/v2" + "google.golang.org/protobuf/encoding/protojson" "google.golang.org/protobuf/types/known/durationpb" + structpb "google.golang.org/protobuf/types/known/structpb" "google.golang.org/protobuf/types/known/timestamppb" "github.com/flyteorg/flyte/v2/flytestdlib/storage" @@ -378,8 +378,8 @@ func MakeLiteralForSimpleType(t core.SimpleType, s string) (*core.Literal, error switch t { case core.SimpleType_STRUCT: st := &structpb.Struct{} - unmarshaler := jsonpb.Unmarshaler{AllowUnknownFields: true} - err := unmarshaler.Unmarshal(strings.NewReader(s), st) + unmarshaler := protojson.UnmarshalOptions{DiscardUnknown: true} + err := unmarshaler.Unmarshal([]byte(s), st) if err != nil { return nil, errors.Wrapf(err, "failed to load generic type as json.") } diff --git a/flyteidl2/clients/go/coreutils/literals_test.go b/flyteidl2/clients/go/coreutils/literals_test.go index ac851cab4b..99aa043c9e 100644 --- a/flyteidl2/clients/go/coreutils/literals_test.go +++ b/flyteidl2/clients/go/coreutils/literals_test.go @@ -13,11 +13,11 @@ import ( "github.com/flyteorg/flyte/v2/flytestdlib/storage" "github.com/go-test/deep" - structpb "github.com/golang/protobuf/ptypes/struct" "github.com/pkg/errors" "github.com/shamaton/msgpack/v2" "github.com/stretchr/testify/assert" "google.golang.org/protobuf/types/known/durationpb" + structpb "google.golang.org/protobuf/types/known/structpb" "google.golang.org/protobuf/types/known/timestamppb" "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/core" diff --git a/flyteplugins/go/tasks/pluginmachinery/core/phase.go b/flyteplugins/go/tasks/pluginmachinery/core/phase.go index de3f48cdbd..e7cd81d501 100644 --- a/flyteplugins/go/tasks/pluginmachinery/core/phase.go +++ b/flyteplugins/go/tasks/pluginmachinery/core/phase.go @@ -4,7 +4,7 @@ import ( "fmt" "time" - structpb "github.com/golang/protobuf/ptypes/struct" + structpb "google.golang.org/protobuf/types/known/structpb" "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/core" ) diff --git a/flyteplugins/go/tasks/pluginmachinery/flytek8s/copilot.go b/flyteplugins/go/tasks/pluginmachinery/flytek8s/copilot.go index 126a9f7f94..1c9ba62fb8 100644 --- a/flyteplugins/go/tasks/pluginmachinery/flytek8s/copilot.go +++ b/flyteplugins/go/tasks/pluginmachinery/flytek8s/copilot.go @@ -7,8 +7,8 @@ import ( "strconv" "time" - "github.com/golang/protobuf/proto" //nolint: staticcheck "github.com/pkg/errors" + "google.golang.org/protobuf/proto" //nolint: staticcheck v1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/api/resource" diff --git a/flyteplugins/go/tasks/pluginmachinery/flytek8s/copilot_test.go b/flyteplugins/go/tasks/pluginmachinery/flytek8s/copilot_test.go index f74f3e702a..054c008b5f 100644 --- a/flyteplugins/go/tasks/pluginmachinery/flytek8s/copilot_test.go +++ b/flyteplugins/go/tasks/pluginmachinery/flytek8s/copilot_test.go @@ -7,8 +7,8 @@ import ( "testing" "time" - "github.com/golang/protobuf/proto" "github.com/stretchr/testify/assert" + "google.golang.org/protobuf/proto" v1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/api/resource" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" diff --git a/flyteplugins/go/tasks/pluginmachinery/flytek8s/k8s_resource_adds_test.go b/flyteplugins/go/tasks/pluginmachinery/flytek8s/k8s_resource_adds_test.go index 4968242800..aa8cfcacf0 100644 --- a/flyteplugins/go/tasks/pluginmachinery/flytek8s/k8s_resource_adds_test.go +++ b/flyteplugins/go/tasks/pluginmachinery/flytek8s/k8s_resource_adds_test.go @@ -6,7 +6,6 @@ import ( "reflect" "testing" - "github.com/golang/protobuf/proto" "github.com/stretchr/testify/assert" v12 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/api/resource" @@ -64,7 +63,7 @@ func TestGetExecutionEnvVars(t *testing.T) { envVars := GetExecutionEnvVars(mock, tt.consoleURL) assert.Len(t, envVars, tt.expectedEnvVars) if tt.expectedEnvVar != nil { - assert.True(t, proto.Equal(&envVars[5], tt.expectedEnvVar)) + assert.Equal(t, tt.expectedEnvVar, &envVars[5]) } } } diff --git a/flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper.go b/flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper.go index 1b1a67f163..5521b22fe8 100644 --- a/flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper.go +++ b/flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper.go @@ -9,8 +9,8 @@ import ( "strings" "time" - "github.com/golang/protobuf/proto" //nolint: staticcheck "github.com/imdario/mergo" + "google.golang.org/protobuf/proto" //nolint: staticcheck "google.golang.org/protobuf/types/known/timestamppb" v1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/api/resource" diff --git a/flyteplugins/go/tasks/pluginmachinery/ioutils/remote_file_output_reader_test.go b/flyteplugins/go/tasks/pluginmachinery/ioutils/remote_file_output_reader_test.go index cedd8fb4d2..fddb7aa530 100644 --- a/flyteplugins/go/tasks/pluginmachinery/ioutils/remote_file_output_reader_test.go +++ b/flyteplugins/go/tasks/pluginmachinery/ioutils/remote_file_output_reader_test.go @@ -10,7 +10,7 @@ import ( "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/core" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" - "google.golang.org/protobuf/runtime/protoiface" + "google.golang.org/protobuf/proto" ) type MemoryMetadata struct { @@ -54,7 +54,7 @@ func TestReadOrigin(t *testing.T) { }, } store := &storageMocks.ComposedProtobufStore{} - store.EXPECT().ReadProtobuf(mock.Anything, mock.Anything, mock.Anything).Run(func(ctx context.Context, ref storage.DataReference, msg protoiface.MessageV1) { + store.EXPECT().ReadProtobuf(mock.Anything, mock.Anything, mock.Anything).Run(func(ctx context.Context, ref storage.DataReference, msg proto.Message) { assert.NotNil(t, msg) casted := msg.(*core.ErrorDocument) casted.Error = errorDoc.Error @@ -89,7 +89,7 @@ func TestReadOrigin(t *testing.T) { }, } store := &storageMocks.ComposedProtobufStore{} - store.EXPECT().ReadProtobuf(mock.Anything, mock.Anything, mock.Anything).Run(func(ctx context.Context, ref storage.DataReference, msg protoiface.MessageV1) { + store.EXPECT().ReadProtobuf(mock.Anything, mock.Anything, mock.Anything).Run(func(ctx context.Context, ref storage.DataReference, msg proto.Message) { assert.NotNil(t, msg) casted := msg.(*core.ErrorDocument) casted.Error = errorDoc.Error diff --git a/flyteplugins/go/tasks/pluginmachinery/ioutils/task_reader_test.go b/flyteplugins/go/tasks/pluginmachinery/ioutils/task_reader_test.go index c27f8788b3..dbfccb7383 100644 --- a/flyteplugins/go/tasks/pluginmachinery/ioutils/task_reader_test.go +++ b/flyteplugins/go/tasks/pluginmachinery/ioutils/task_reader_test.go @@ -5,8 +5,8 @@ import ( "fmt" "testing" - "github.com/golang/protobuf/proto" "github.com/stretchr/testify/assert" + "google.golang.org/protobuf/proto" "github.com/flyteorg/flyte/v2/flyteplugins/go/tasks/pluginmachinery/core/mocks" "github.com/flyteorg/flyte/v2/flytestdlib/contextutils" diff --git a/flyteplugins/go/tasks/pluginmachinery/utils/marshal_utils.go b/flyteplugins/go/tasks/pluginmachinery/utils/marshal_utils.go index 80d453bc53..f04978f3f6 100755 --- a/flyteplugins/go/tasks/pluginmachinery/utils/marshal_utils.go +++ b/flyteplugins/go/tasks/pluginmachinery/utils/marshal_utils.go @@ -3,16 +3,15 @@ package utils import ( "encoding/json" "fmt" - "strings" - "github.com/golang/protobuf/jsonpb" //nolint: staticcheck - "github.com/golang/protobuf/proto" //nolint: staticcheck - structpb "github.com/golang/protobuf/ptypes/struct" + "google.golang.org/protobuf/encoding/protojson" + "google.golang.org/protobuf/proto" //nolint: staticcheck + structpb "google.golang.org/protobuf/types/known/structpb" ) -var jsonPbMarshaler = jsonpb.Marshaler{} -var jsonPbUnmarshaler = &jsonpb.Unmarshaler{ - AllowUnknownFields: true, +var jsonPbMarshaler = protojson.MarshalOptions{} +var jsonPbUnmarshaler = protojson.UnmarshalOptions{ + DiscardUnknown: true, } // Deprecated: Use flytestdlib/utils.UnmarshalStructToPb instead. @@ -21,12 +20,12 @@ func UnmarshalStruct(structObj *structpb.Struct, msg proto.Message) error { return fmt.Errorf("nil Struct Object passed") } - jsonObj, err := jsonPbMarshaler.MarshalToString(structObj) + jsonObj, err := jsonPbMarshaler.Marshal(structObj) if err != nil { return err } - if err = jsonPbUnmarshaler.Unmarshal(strings.NewReader(jsonObj), msg); err != nil { + if err = jsonPbUnmarshaler.Unmarshal(jsonObj, msg); err != nil { return err } @@ -39,12 +38,12 @@ func MarshalStruct(in proto.Message, out *structpb.Struct) error { return fmt.Errorf("nil Struct Object passed") } - jsonObj, err := jsonPbMarshaler.MarshalToString(in) + jsonObj, err := jsonPbMarshaler.Marshal(in) if err != nil { return err } - if err = jsonpb.UnmarshalString(jsonObj, out); err != nil { + if err = jsonPbUnmarshaler.Unmarshal(jsonObj, out); err != nil { return err } @@ -53,7 +52,8 @@ func MarshalStruct(in proto.Message, out *structpb.Struct) error { // Deprecated: Use flytestdlib/utils.MarshalToString instead. func MarshalToString(msg proto.Message) (string, error) { - return jsonPbMarshaler.MarshalToString(msg) + b, err := jsonPbMarshaler.Marshal(msg) + return string(b), err } // Deprecated: Use flytestdlib/utils.MarshalObjToStruct instead. @@ -66,7 +66,7 @@ func MarshalObjToStruct(input interface{}) (*structpb.Struct, error) { // Turn JSON into a protobuf struct structObj := &structpb.Struct{} - if err := jsonpb.UnmarshalString(string(b), structObj); err != nil { + if err := jsonPbUnmarshaler.Unmarshal(b, structObj); err != nil { return nil, err } return structObj, nil diff --git a/flyteplugins/go/tasks/pluginmachinery/utils/marshal_utils_test.go b/flyteplugins/go/tasks/pluginmachinery/utils/marshal_utils_test.go index abe1b7d2a2..3f34c4ef07 100644 --- a/flyteplugins/go/tasks/pluginmachinery/utils/marshal_utils_test.go +++ b/flyteplugins/go/tasks/pluginmachinery/utils/marshal_utils_test.go @@ -5,8 +5,8 @@ import ( "testing" "github.com/go-test/deep" - structpb "github.com/golang/protobuf/ptypes/struct" "github.com/stretchr/testify/assert" + structpb "google.golang.org/protobuf/types/known/structpb" v1 "k8s.io/api/core/v1" ) diff --git a/flyteplugins/go/tasks/pluginmachinery/utils/secrets/marshaler.go b/flyteplugins/go/tasks/pluginmachinery/utils/secrets/marshaler.go index b4997a6384..1bc0c84927 100644 --- a/flyteplugins/go/tasks/pluginmachinery/utils/secrets/marshaler.go +++ b/flyteplugins/go/tasks/pluginmachinery/utils/secrets/marshaler.go @@ -2,13 +2,12 @@ package secrets import ( "fmt" + "strconv" + "strings" "github.com/flyteorg/flyte/v2/flyteplugins/go/tasks/pluginmachinery/encoding" "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/core" - "github.com/golang/protobuf/proto" //nolint: staticcheck - - "strconv" - "strings" + "google.golang.org/protobuf/encoding/prototext" ) const ( @@ -36,7 +35,7 @@ func decodeSecret(encoded string) (string, error) { } func marshalSecret(s *core.Secret) string { - return encodeSecret(proto.MarshalTextString(s)) + return encodeSecret(prototext.MarshalOptions{Multiline: false}.Format(s)) } func unmarshalSecret(encoded string) (*core.Secret, error) { @@ -46,7 +45,7 @@ func unmarshalSecret(encoded string) (*core.Secret, error) { } s := &core.Secret{} - err = proto.UnmarshalText(decoded, s) + err = prototext.Unmarshal([]byte(decoded), s) return s, err } diff --git a/flyteplugins/go/tasks/pluginmachinery/utils/secrets/marshaler_test.go b/flyteplugins/go/tasks/pluginmachinery/utils/secrets/marshaler_test.go index b07899ee63..f769ed2683 100644 --- a/flyteplugins/go/tasks/pluginmachinery/utils/secrets/marshaler_test.go +++ b/flyteplugins/go/tasks/pluginmachinery/utils/secrets/marshaler_test.go @@ -35,14 +35,14 @@ func TestMarshalSecretsToMapStrings(t *testing.T) { Group: ";':/\\", }, }}, want: map[string]string{ - "flyte.secrets/s0": "m4zg54lqhiqceozhhixvyxbcbi", + "flyte.secrets/s0": "m4zg54lqhirdwjz1f4ofyiq", }, wantErr: false}, {name: "Without group", args: args{secrets: []*core.Secret{ { Key: "my_key", }, }}, want: map[string]string{ - "flyte.secrets/s0": "nnsxsoraejwxsx2lmv3secq", + "flyte.secrets/s0": "nnsxsorcnv3v512fpera", }, wantErr: false}, } for _, tt := range tests { diff --git a/flyteplugins/go/tasks/plugins/k8s/dask/dask_test.go b/flyteplugins/go/tasks/plugins/k8s/dask/dask_test.go index 386b0d960a..8b44c5f450 100644 --- a/flyteplugins/go/tasks/plugins/k8s/dask/dask_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/dask/dask_test.go @@ -7,9 +7,9 @@ import ( "time" daskAPI "github.com/dask/dask-kubernetes/v2023/dask_kubernetes/operator/go_client/pkg/apis/kubernetes.dask.org/v1" - "github.com/golang/protobuf/proto" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" + "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/known/structpb" v1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/api/resource" diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go b/flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go index 1db7cdc0e4..ff18dee236 100644 --- a/flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go @@ -7,10 +7,10 @@ import ( "testing" "time" - structpb "github.com/golang/protobuf/ptypes/struct" kubeflowv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" + structpb "google.golang.org/protobuf/types/known/structpb" corev1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/api/resource" v1 "k8s.io/apimachinery/pkg/apis/meta/v1" diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go b/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go index 896b6960aa..c5780c0fa5 100644 --- a/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go @@ -7,10 +7,10 @@ import ( "testing" "time" - structpb "github.com/golang/protobuf/ptypes/struct" kubeflowv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" + structpb "google.golang.org/protobuf/types/known/structpb" corev1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/api/resource" v1 "k8s.io/apimachinery/pkg/apis/meta/v1" diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow_test.go b/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow_test.go index 068526581e..fa3290e685 100644 --- a/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow_test.go @@ -7,11 +7,11 @@ import ( "testing" "time" - structpb "github.com/golang/protobuf/ptypes/struct" kubeflowv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "google.golang.org/protobuf/proto" + structpb "google.golang.org/protobuf/types/known/structpb" corev1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/api/resource" v1 "k8s.io/apimachinery/pkg/apis/meta/v1" diff --git a/flyteplugins/go/tasks/plugins/k8s/pod/sidecar_test.go b/flyteplugins/go/tasks/plugins/k8s/pod/sidecar_test.go index ca33040c92..890f0aed19 100644 --- a/flyteplugins/go/tasks/plugins/k8s/pod/sidecar_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/pod/sidecar_test.go @@ -8,10 +8,10 @@ import ( "path" "testing" - structpb "github.com/golang/protobuf/ptypes/struct" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "google.golang.org/protobuf/proto" + structpb "google.golang.org/protobuf/types/known/structpb" v1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/api/resource" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" diff --git a/flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go b/flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go index 2d7e2ffa26..17b6d0a21b 100644 --- a/flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go @@ -10,9 +10,9 @@ import ( sj "github.com/GoogleCloudPlatform/spark-on-k8s-operator/pkg/apis/sparkoperator.k8s.io/v1beta2" sparkOp "github.com/GoogleCloudPlatform/spark-on-k8s-operator/pkg/apis/sparkoperator.k8s.io/v1beta2" - structpb "github.com/golang/protobuf/ptypes/struct" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" + structpb "google.golang.org/protobuf/types/known/structpb" corev1 "k8s.io/api/core/v1" v1 "k8s.io/apimachinery/pkg/apis/meta/v1" "sigs.k8s.io/controller-runtime/pkg/client" diff --git a/flytestdlib/app/error.go b/flytestdlib/app/error.go index 8433f149f0..92cf96b1f7 100644 --- a/flytestdlib/app/error.go +++ b/flytestdlib/app/error.go @@ -3,9 +3,10 @@ package app import ( "fmt" - "github.com/golang/protobuf/proto" //nolint: staticcheck "google.golang.org/grpc/codes" "google.golang.org/grpc/status" + "google.golang.org/protobuf/proto" //nolint: staticcheck + "google.golang.org/protobuf/protoadapt" ) type ServerError interface { @@ -27,7 +28,7 @@ func (e *serverError) Code() codes.Code { } func (e *serverError) WithDetails(details proto.Message) (ServerError, error) { - s, err := e.status.WithDetails(details) + s, err := e.status.WithDetails(protoadapt.MessageV1Of(details)) if err != nil { return nil, err } diff --git a/flytestdlib/flytestdlib/storage/mocks/mocks.go b/flytestdlib/flytestdlib/storage/mocks/mocks.go index a55277e13d..3eeb9b02b1 100644 --- a/flytestdlib/flytestdlib/storage/mocks/mocks.go +++ b/flytestdlib/flytestdlib/storage/mocks/mocks.go @@ -10,7 +10,7 @@ import ( "github.com/flyteorg/flyte/v2/flytestdlib/storage" mock "github.com/stretchr/testify/mock" - "github.com/golang/protobuf/proto" //nolint: staticcheck + "google.golang.org/protobuf/proto" //nolint: staticcheck ) // NewMetadata creates a new instance of Metadata. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. diff --git a/flytestdlib/pbhash/pbhash.go b/flytestdlib/pbhash/pbhash.go index d2f46b2a29..3f00f8dd89 100644 --- a/flytestdlib/pbhash/pbhash.go +++ b/flytestdlib/pbhash/pbhash.go @@ -6,13 +6,13 @@ import ( "encoding/base64" goObjectHash "github.com/benlaurie/objecthash/go/objecthash" - "github.com/golang/protobuf/jsonpb" //nolint: staticcheck - "github.com/golang/protobuf/proto" //nolint: staticcheck + "google.golang.org/protobuf/encoding/protojson" + "google.golang.org/protobuf/proto" //nolint: staticcheck "github.com/flyteorg/flyte/v2/flytestdlib/logger" ) -var marshaller = &jsonpb.Marshaler{} +var marshaller = protojson.MarshalOptions{} func fromHashToByteArray(input [32]byte) []byte { output := make([]byte, 32) @@ -24,16 +24,17 @@ func fromHashToByteArray(input [32]byte) []byte { func ComputeHash(ctx context.Context, pb proto.Message) ([]byte, error) { // We marshal the pb object to JSON first which should provide a consistent mapping of pb to json fields as stated // here: https://developers.google.com/protocol-buffers/docs/proto3#json - // jsonpb marshalling includes: + // protojson marshalling includes: // - sorting map values to provide a stable output // - omitting empty values which supports backwards compatibility of old protobuf definitions // We do not use protobuf marshalling because it does not guarantee stable output because of how it handles // unknown fields and ordering of fields. https://github.com/protocolbuffers/protobuf/issues/2830 - pbJSON, err := marshaller.MarshalToString(pb) + pbJSONBytes, err := marshaller.Marshal(pb) if err != nil { logger.Warning(ctx, "failed to marshal pb [%+v] to JSON with err %v", pb, err) return nil, err } + pbJSON := string(pbJSONBytes) // Deterministically hash the JSON object to a byte array. The library will sort the map keys of the JSON object // so that we do not run into the issues from pb marshalling. diff --git a/flytestdlib/pbhash/pbhash_test.go b/flytestdlib/pbhash/pbhash_test.go index 75735b4135..23abae1286 100644 --- a/flytestdlib/pbhash/pbhash_test.go +++ b/flytestdlib/pbhash/pbhash_test.go @@ -5,74 +5,41 @@ import ( "testing" "time" - "github.com/golang/protobuf/proto" - "github.com/golang/protobuf/ptypes" - "github.com/golang/protobuf/ptypes/duration" - "github.com/golang/protobuf/ptypes/timestamp" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/types/known/durationpb" + "google.golang.org/protobuf/types/known/structpb" + "google.golang.org/protobuf/types/known/timestamppb" ) -// Mock a Protobuf generated GO object -type mockProtoMessage struct { - Integer int64 `protobuf:"varint,1,opt,name=integer,proto3" json:"integer,omitempty"` - FloatValue float64 `protobuf:"fixed64,2,opt,name=float_value,json=floatValue,proto3" json:"float_value,omitempty"` - StringValue string `protobuf:"bytes,3,opt,name=string_value,json=stringValue,proto3" json:"string_value,omitempty"` - Boolean bool `protobuf:"varint,4,opt,name=boolean,proto3" json:"boolean,omitempty"` - Datetime *timestamp.Timestamp `protobuf:"bytes,5,opt,name=datetime,proto3" json:"datetime,omitempty"` - Duration *duration.Duration `protobuf:"bytes,6,opt,name=duration,proto3" json:"duration,omitempty"` - MapValue map[string]string `protobuf:"bytes,7,rep,name=map_value,json=mapValue,proto3" json:"map_value,omitempty" protobuf_key:"bytes,1,opt,name=key,proto3" protobuf_val:"bytes,2,opt,name=value,proto3"` - Collections []string `protobuf:"bytes,8,rep,name=collections,proto3" json:"collections,omitempty"` -} - -func (mockProtoMessage) Reset() { -} - -func (m mockProtoMessage) String() string { - return proto.CompactTextString(m) -} - -func (mockProtoMessage) ProtoMessage() { -} - -// Mock an older version of the above pb object that doesn't have some fields -type mockOlderProto struct { - Integer int64 `protobuf:"varint,1,opt,name=integer,proto3" json:"integer,omitempty"` - FloatValue float64 `protobuf:"fixed64,2,opt,name=float_value,json=floatValue,proto3" json:"float_value,omitempty"` - StringValue string `protobuf:"bytes,3,opt,name=string_value,json=stringValue,proto3" json:"string_value,omitempty"` - Boolean bool `protobuf:"varint,4,opt,name=boolean,proto3" json:"boolean,omitempty"` -} +var sampleTime = timestamppb.New(time.Date(2019, 03, 29, 12, 0, 0, 0, time.UTC)) -func (mockOlderProto) Reset() { -} - -func (m mockOlderProto) String() string { - return proto.CompactTextString(m) -} +func makeStruct(t *testing.T, fields map[string]interface{}) *structpb.Struct { + t.Helper() -func (mockOlderProto) ProtoMessage() { + s, err := structpb.NewStruct(fields) + require.NoError(t, err) + return s } -var sampleTime, _ = ptypes.TimestampProto( - time.Date(2019, 03, 29, 12, 0, 0, 0, time.UTC)) - func TestProtoHash(t *testing.T) { - mockProto := &mockProtoMessage{ - Integer: 18, - FloatValue: 1.3, - StringValue: "lets test this", - Boolean: true, - Datetime: sampleTime, - Duration: ptypes.DurationProto(time.Millisecond), - MapValue: map[string]string{ + mockProto := makeStruct(t, map[string]interface{}{ + "integer": 18, + "floatValue": 1.3, + "stringValue": "lets test this", + "boolean": true, + "datetime": sampleTime.AsTime().Format(time.RFC3339Nano), + "duration": durationpb.New(time.Millisecond).AsDuration().String(), + "mapValue": map[string]interface{}{ "z": "last", "a": "first", }, - Collections: []string{"1", "2", "3"}, - } + "collections": []interface{}{"1", "2", "3"}, + }) - expectedHashedMockProto := []byte{0x62, 0x95, 0xb2, 0x2c, 0x23, 0xf5, 0x35, 0x6d, 0x3, 0x56, 0x4d, 0xc7, 0x8f, 0xae, - 0x2d, 0x2b, 0xbd, 0x7, 0xff, 0xdb, 0x7e, 0xe5, 0xf4, 0x25, 0x8f, 0xbc, 0xb2, 0xc, 0xad, 0xa5, 0x48, 0x44} - expectedHashString := "YpWyLCP1NW0DVk3Hj64tK70H/9t+5fQlj7yyDK2lSEQ=" + expectedHashedMockProto := []byte{0x45, 0xd1, 0xe, 0x9, 0x5e, 0xe3, 0xf7, 0x3e, 0xe9, 0x9, 0xe9, 0xc9, 0x27, 0xd6, + 0xf5, 0x79, 0x81, 0xf6, 0x52, 0x48, 0x3f, 0x71, 0x8c, 0x2, 0x87, 0x1, 0x98, 0x58, 0x5b, 0x7e, 0xf, 0xda} + expectedHashString := "RdEOCV7j9z7pCenJJ9b1eYH2Ukg/cYwChwGYWFt+D9o=" t.Run("TestFullProtoHash", func(t *testing.T) { hashedBytes, err := ComputeHash(context.Background(), mockProto) @@ -86,7 +53,10 @@ func TestProtoHash(t *testing.T) { }) t.Run("TestFullProtoHashReorderKeys", func(t *testing.T) { - mockProto.MapValue = map[string]string{"a": "first", "z": "last"} + mockProto.Fields["mapValue"] = structpb.NewStructValue(makeStruct(t, map[string]interface{}{ + "a": "first", + "z": "last", + })) hashedBytes, err := ComputeHash(context.Background(), mockProto) assert.Nil(t, err) assert.Equal(t, expectedHashedMockProto, hashedBytes) @@ -100,18 +70,18 @@ func TestProtoHash(t *testing.T) { func TestPartialFilledProtoHash(t *testing.T) { - mockProtoOmitEmpty := &mockProtoMessage{ - Integer: 18, - FloatValue: 1.3, - StringValue: "lets test this", - Boolean: true, - } + mockProtoOmitEmpty := makeStruct(t, map[string]interface{}{ + "integer": 18, + "floatValue": 1.3, + "stringValue": "lets test this", + "boolean": true, + }) - expectedHashedMockProtoOmitEmpty := []byte{0x1a, 0x13, 0xcc, 0x4c, 0xab, 0xc9, 0x7d, 0x43, 0xc7, 0x2b, 0xc5, 0x37, - 0xbc, 0x49, 0xa8, 0x8b, 0xfc, 0x1d, 0x54, 0x1c, 0x7b, 0x21, 0x04, 0x8f, 0xab, 0x28, 0xc6, 0x5c, 0x06, 0x73, - 0xaa, 0xe2} + expectedHashedMockProtoOmitEmpty := []byte{0x6d, 0xfa, 0xc1, 0xc2, 0xe0, 0xee, 0xad, 0xe2, 0xa5, 0xad, 0x7d, 0x9e, + 0xad, 0x1c, 0x94, 0x11, 0x6a, 0x21, 0x23, 0xe1, 0xfb, 0xe2, 0x35, 0xd5, 0x37, 0x89, 0xf3, 0xfc, 0xa, 0xfb, + 0x3d, 0xe9} - expectedHashStringOmitEmpty := "GhPMTKvJfUPHK8U3vEmoi/wdVBx7IQSPqyjGXAZzquI=" + expectedHashStringOmitEmpty := "bfrBwuDureKlrX2erRyUEWohI+H74jXVN4nz/Ar7Pek=" t.Run("TestPartial", func(t *testing.T) { hashedBytes, err := ComputeHash(context.Background(), mockProtoOmitEmpty) @@ -124,12 +94,12 @@ func TestPartialFilledProtoHash(t *testing.T) { assert.Equal(t, hashedString, expectedHashStringOmitEmpty) }) - mockOldProtoMessage := &mockOlderProto{ - Integer: 18, - FloatValue: 1.3, - StringValue: "lets test this", - Boolean: true, - } + mockOldProtoMessage := makeStruct(t, map[string]interface{}{ + "integer": 18, + "floatValue": 1.3, + "stringValue": "lets test this", + "boolean": true, + }) t.Run("TestOlderProto", func(t *testing.T) { hashedBytes, err := ComputeHash(context.Background(), mockOldProtoMessage) diff --git a/flytestdlib/storage/mocks/mocks.go b/flytestdlib/storage/mocks/mocks.go index 83180ebabd..dad6cfbc9b 100644 --- a/flytestdlib/storage/mocks/mocks.go +++ b/flytestdlib/storage/mocks/mocks.go @@ -9,8 +9,8 @@ import ( "io" "github.com/flyteorg/flyte/v2/flytestdlib/storage" - "github.com/golang/protobuf/proto" mock "github.com/stretchr/testify/mock" + "google.golang.org/protobuf/proto" ) // NewMetadata creates a new instance of Metadata. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. diff --git a/flytestdlib/storage/protobuf_store.go b/flytestdlib/storage/protobuf_store.go index 44e04d6acb..3b6fe0dc1d 100644 --- a/flytestdlib/storage/protobuf_store.go +++ b/flytestdlib/storage/protobuf_store.go @@ -6,9 +6,9 @@ import ( "fmt" "time" - "github.com/golang/protobuf/proto" //nolint: staticcheck errs "github.com/pkg/errors" "github.com/prometheus/client_golang/prometheus" + "google.golang.org/protobuf/proto" //nolint: staticcheck "github.com/flyteorg/flyte/v2/flytestdlib/ioutils" "github.com/flyteorg/flyte/v2/flytestdlib/logger" diff --git a/flytestdlib/storage/protobuf_store_test.go b/flytestdlib/storage/protobuf_store_test.go index 019c61e7d4..6ac6e210dc 100644 --- a/flytestdlib/storage/protobuf_store_test.go +++ b/flytestdlib/storage/protobuf_store_test.go @@ -9,56 +9,28 @@ import ( "net/http/httptest" "testing" - "github.com/golang/protobuf/proto" errs "github.com/pkg/errors" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "google.golang.org/protobuf/types/known/wrapperspb" "github.com/flyteorg/flyte/v2/flytestdlib/promutils" "github.com/flyteorg/stow/s3" ) -type mockProtoMessage struct { - X int64 `protobuf:"varint,2,opt,name=x,json=x,proto3" json:"x,omitempty"` -} - -type mockBigDataProtoMessage struct { - X []byte `protobuf:"bytes,1,opt,name=X,proto3" json:"X,omitempty"` -} - -func (mockProtoMessage) Reset() { -} - -func (m mockProtoMessage) String() string { - return proto.CompactTextString(m) -} - -func (mockProtoMessage) ProtoMessage() { -} - -func (mockBigDataProtoMessage) Reset() { -} - -func (m mockBigDataProtoMessage) String() string { - return proto.CompactTextString(m) -} - -func (mockBigDataProtoMessage) ProtoMessage() { -} - func TestDefaultProtobufStore(t *testing.T) { t.Run("Read after Write", func(t *testing.T) { testScope := promutils.NewTestScope() s, err := NewDataStore(&Config{Type: TypeMemory}, testScope) assert.NoError(t, err) - err = s.WriteProtobuf(context.TODO(), "hello", Options{}, &mockProtoMessage{X: 5}) + err = s.WriteProtobuf(context.TODO(), "hello", Options{}, wrapperspb.Int64(5)) assert.NoError(t, err) - m := &mockProtoMessage{} + m := &wrapperspb.Int64Value{} err = s.ReadProtobuf(context.TODO(), "hello", m) assert.NoError(t, err) - assert.Equal(t, int64(5), m.X) + assert.Equal(t, int64(5), m.Value) }) t.Run("RefreshConfig", func(t *testing.T) { @@ -127,15 +99,15 @@ func TestDefaultProtobufStore_BigDataReadAfterWrite(t *testing.T) { _, err = rand.Read(bigD) assert.NoError(t, err) - mockMessage := &mockBigDataProtoMessage{X: bigD} + mockMessage := wrapperspb.Bytes(bigD) err = s.WriteProtobuf(context.TODO(), DataReference("bigK"), Options{}, mockMessage) assert.NoError(t, err) - m := &mockBigDataProtoMessage{} + m := &wrapperspb.BytesValue{} err = s.ReadProtobuf(context.TODO(), DataReference("bigK"), m) assert.NoError(t, err) - assert.Equal(t, bigD, m.X) + assert.Equal(t, bigD, m.Value) }) } @@ -159,13 +131,13 @@ func TestDefaultProtobufStore_HardErrors(t *testing.T) { } pbErroneousStore := NewDefaultProtobufStoreWithMetrics(store, metrics.protoMetrics) t.Run("Test if hard write errors are handled correctly", func(t *testing.T) { - err := pbErroneousStore.WriteProtobuf(ctx, k1, Options{}, &mockProtoMessage{X: 5}) + err := pbErroneousStore.WriteProtobuf(ctx, k1, Options{}, wrapperspb.Int64(5)) assert.False(t, IsFailedWriteToCache(err)) assert.Equal(t, dummyWriteErrorMsg, errs.Cause(err).Error()) }) t.Run("Test if hard read errors are handled correctly", func(t *testing.T) { - m := &mockProtoMessage{} + m := &wrapperspb.Int64Value{} err := pbErroneousStore.ReadProtobuf(ctx, k1, m) assert.False(t, IsFailedWriteToCache(err)) assert.Equal(t, dummyReadErrorMsg, errs.Cause(err).Error()) diff --git a/flytestdlib/storage/storage.go b/flytestdlib/storage/storage.go index 6a2ba589ef..af2107edc3 100644 --- a/flytestdlib/storage/storage.go +++ b/flytestdlib/storage/storage.go @@ -15,7 +15,7 @@ import ( "time" "github.com/flyteorg/stow" - "github.com/golang/protobuf/proto" //nolint: staticcheck + "google.golang.org/protobuf/proto" //nolint: staticcheck ) // DataReference defines a reference to data location. diff --git a/flytestdlib/utils/marshal_utils.go b/flytestdlib/utils/marshal_utils.go index d0fde42545..f5a09e8246 100644 --- a/flytestdlib/utils/marshal_utils.go +++ b/flytestdlib/utils/marshal_utils.go @@ -4,17 +4,16 @@ import ( "bytes" "encoding/json" "fmt" - "strings" - "github.com/golang/protobuf/jsonpb" //nolint: staticcheck - "github.com/golang/protobuf/proto" //nolint: staticcheck - structpb "github.com/golang/protobuf/ptypes/struct" "github.com/pkg/errors" + "google.golang.org/protobuf/encoding/protojson" + "google.golang.org/protobuf/proto" //nolint: staticcheck + structpb "google.golang.org/protobuf/types/known/structpb" ) -var jsonPbMarshaler = jsonpb.Marshaler{} -var jsonPbUnmarshaler = jsonpb.Unmarshaler{ - AllowUnknownFields: true, +var jsonPbMarshaler = protojson.MarshalOptions{} +var jsonPbUnmarshaler = protojson.UnmarshalOptions{ + DiscardUnknown: true, } // UnmarshalStructToPb unmarshals a proto struct into a proto message using jsonPb marshaler. @@ -27,12 +26,12 @@ func UnmarshalStructToPb(structObj *structpb.Struct, msg proto.Message) error { return fmt.Errorf("nil proto.Message object passed") } - jsonObj, err := jsonPbMarshaler.MarshalToString(structObj) + jsonObj, err := jsonPbMarshaler.Marshal(structObj) if err != nil { return errors.WithMessage(err, "Failed to marshal strcutObj input") } - if err = UnmarshalStringToPb(jsonObj, msg); err != nil { + if err = UnmarshalBytesToPb(jsonObj, msg); err != nil { return errors.WithMessage(err, "Failed to unmarshal json obj into proto") } @@ -46,9 +45,11 @@ func MarshalPbToStruct(in proto.Message) (out *structpb.Struct, err error) { } var buf bytes.Buffer - if err := jsonPbMarshaler.Marshal(&buf, in); err != nil { + b, err := jsonPbMarshaler.Marshal(in) + if err != nil { return nil, errors.WithMessage(err, "Failed to marshal input proto message") } + buf.Write(b) out = &structpb.Struct{} if err = UnmarshalBytesToPb(buf.Bytes(), out); err != nil { @@ -60,27 +61,37 @@ func MarshalPbToStruct(in proto.Message) (out *structpb.Struct, err error) { // MarshalPbToString marshals a proto message using jsonPb marshaler to string. func MarshalPbToString(msg proto.Message) (string, error) { - return jsonPbMarshaler.MarshalToString(msg) + if msg == nil { + return "", fmt.Errorf("nil proto message passed") + } + + b, err := jsonPbMarshaler.Marshal(msg) + return string(b), err } // UnmarshalStringToPb unmarshals a string to a proto message func UnmarshalStringToPb(s string, msg proto.Message) error { - return jsonPbUnmarshaler.Unmarshal(strings.NewReader(s), msg) + return jsonPbUnmarshaler.Unmarshal([]byte(s), msg) } // MarshalPbToBytes marshals a proto message to a byte slice func MarshalPbToBytes(msg proto.Message) ([]byte, error) { + if msg == nil { + return nil, fmt.Errorf("nil proto message passed") + } + var buf bytes.Buffer - err := jsonPbMarshaler.Marshal(&buf, msg) + b, err := jsonPbMarshaler.Marshal(msg) if err != nil { return nil, err } + buf.Write(b) return buf.Bytes(), nil } // UnmarshalBytesToPb unmarshals a byte slice to a proto message func UnmarshalBytesToPb(b []byte, msg proto.Message) error { - return jsonPbUnmarshaler.Unmarshal(bytes.NewReader(b), msg) + return jsonPbUnmarshaler.Unmarshal(b, msg) } // MarshalObjToStruct marshals obj into a struct. Will use jsonPb if input is a proto message, otherwise, it'll use json diff --git a/flytestdlib/utils/marshal_utils_test.go b/flytestdlib/utils/marshal_utils_test.go index 3ac0c3f339..ad86ca4f8e 100644 --- a/flytestdlib/utils/marshal_utils_test.go +++ b/flytestdlib/utils/marshal_utils_test.go @@ -4,9 +4,9 @@ import ( "testing" "github.com/go-test/deep" - "github.com/golang/protobuf/proto" - structpb "github.com/golang/protobuf/ptypes/struct" "github.com/stretchr/testify/assert" + "google.golang.org/protobuf/proto" + structpb "google.golang.org/protobuf/types/known/structpb" v1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/util/json" diff --git a/flytestdlib/utils/prototest/test_type.pb.go b/flytestdlib/utils/prototest/test_type.pb.go index 7495748680..103f4d8d45 100644 --- a/flytestdlib/utils/prototest/test_type.pb.go +++ b/flytestdlib/utils/prototest/test_type.pb.go @@ -10,7 +10,6 @@ import ( reflect "reflect" sync "sync" - proto "github.com/golang/protobuf/proto" protoreflect "google.golang.org/protobuf/reflect/protoreflect" protoimpl "google.golang.org/protobuf/runtime/protoimpl" ) @@ -22,10 +21,6 @@ const ( _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) ) -// This is a compile-time assertion that a sufficiently up-to-date version -// of the legacy proto package is being used. -const _ = proto.ProtoPackageIsVersion4 - type TestProto struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache diff --git a/go.mod b/go.mod index f7c24b527f..7bccc1bf00 100644 --- a/go.mod +++ b/go.mod @@ -27,7 +27,6 @@ require ( github.com/ghodss/yaml v1.0.0 github.com/go-test/deep v1.1.1 github.com/go-viper/mapstructure/v2 v2.4.0 - github.com/golang/protobuf v1.5.4 github.com/googleapis/gax-go/v2 v2.15.0 github.com/grpc-ecosystem/go-grpc-middleware/providers/prometheus v1.1.0 github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3 @@ -144,6 +143,7 @@ require ( github.com/golang-jwt/jwt/v5 v5.3.0 // indirect github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect github.com/golang/mock v1.6.0 // indirect + github.com/golang/protobuf v1.5.4 // indirect github.com/google/btree v1.1.3 // indirect github.com/google/cel-go v0.26.0 // indirect github.com/google/gnostic-models v0.7.0 // indirect diff --git a/runs/repository/transformers/task.go b/runs/repository/transformers/task.go index ec7ccb16c6..4c0f30c98a 100644 --- a/runs/repository/transformers/task.go +++ b/runs/repository/transformers/task.go @@ -5,7 +5,7 @@ import ( "database/sql" "strings" - "github.com/golang/protobuf/proto" //nolint: staticcheck + "google.golang.org/protobuf/proto" //nolint: staticcheck "google.golang.org/protobuf/types/known/timestamppb" "github.com/flyteorg/flyte/v2/flytestdlib/logger" diff --git a/runs/service/run_service_test.go b/runs/service/run_service_test.go index 3c40bd5917..6c7a580525 100644 --- a/runs/service/run_service_test.go +++ b/runs/service/run_service_test.go @@ -13,10 +13,10 @@ import ( "time" "connectrpc.com/connect" - "github.com/golang/protobuf/proto" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" + "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/known/timestamppb" "github.com/flyteorg/flyte/v2/flytestdlib/storage" From 3d6844fe2915a913a85d88291333ce54117a52f2 Mon Sep 17 00:00:00 2001 From: Jason Parraga Date: Wed, 20 May 2026 15:24:04 -0700 Subject: [PATCH 2/4] fix unit test Signed-off-by: Jason Parraga --- .../tasks/pluginmachinery/secret/secrets_test.go | 14 ++++++++++---- runs/service/internal_run_service.go | 3 ++- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/flyteplugins/go/tasks/pluginmachinery/secret/secrets_test.go b/flyteplugins/go/tasks/pluginmachinery/secret/secrets_test.go index 17891551dd..10048d2282 100644 --- a/flyteplugins/go/tasks/pluginmachinery/secret/secrets_test.go +++ b/flyteplugins/go/tasks/pluginmachinery/secret/secrets_test.go @@ -12,6 +12,8 @@ import ( "github.com/flyteorg/flyte/v2/flyteplugins/go/tasks/pluginmachinery/secret/config" "github.com/flyteorg/flyte/v2/flyteplugins/go/tasks/pluginmachinery/secret/mocks" + secretUtils "github.com/flyteorg/flyte/v2/flyteplugins/go/tasks/pluginmachinery/utils/secrets" + "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/core" ) func TestSecretsWebhook_Mutate(t *testing.T) { @@ -23,12 +25,16 @@ func TestSecretsWebhook_Mutate(t *testing.T) { }) namespace := "test-namespace" + secretAnnotations, err := secretUtils.MarshalSecretsToMapStrings([]*core.Secret{ + { + Key: "my_key", + }, + }) + assert.NoError(t, err) podWithAnnotations := &corev1.Pod{ ObjectMeta: v1.ObjectMeta{ - Namespace: namespace, - Annotations: map[string]string{ - "flyte.secrets/s0": "nnsxsorcnv4v623fperca", - }, + Namespace: namespace, + Annotations: secretAnnotations, }, } diff --git a/runs/service/internal_run_service.go b/runs/service/internal_run_service.go index 32eebc334b..577e7686dd 100644 --- a/runs/service/internal_run_service.go +++ b/runs/service/internal_run_service.go @@ -5,10 +5,11 @@ import ( "database/sql" "errors" "fmt" - "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/common" "io" "time" + "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/common" + "connectrpc.com/connect" grpcstatus "google.golang.org/genproto/googleapis/rpc/status" "google.golang.org/protobuf/proto" From a4d0b826d4c46644461d7616a1eae283ac60d506 Mon Sep 17 00:00:00 2001 From: Jason Parraga Date: Wed, 20 May 2026 15:33:27 -0700 Subject: [PATCH 3/4] Add targeted unit test Signed-off-by: Jason Parraga --- flytecopilot/data/download.go | 7 +++--- flytecopilot/data/download_test.go | 2 +- flytestdlib/storage/protobuf_store_test.go | 28 ++++++++++++++++++++++ 3 files changed, 33 insertions(+), 4 deletions(-) diff --git a/flytecopilot/data/download.go b/flytecopilot/data/download.go index e849afecd1..c89ca424c2 100644 --- a/flytecopilot/data/download.go +++ b/flytecopilot/data/download.go @@ -537,13 +537,14 @@ func (d Downloader) DownloadInputs(ctx context.Context, inputRef storage.DataRef logger.Errorf(ctx, "Failed to download inputs from [%s], err [%s]", inputRef, err) return errors.Wrapf(err, "failed to download input metadata message from remote store") } + if len(inputs.GetLiterals()) == 0 { + return nil + } + varMap, lMap, err := d.RecursiveDownload(ctx, inputs, outputDir, true) if err != nil { return errors.Wrapf(err, "failed to download input variable from remote store") } - if len(lMap.GetLiterals()) == 0 { - return nil - } // We will always write the protobuf b, err := proto.Marshal(lMap) diff --git a/flytecopilot/data/download_test.go b/flytecopilot/data/download_test.go index 2fb23847f6..83b6df9b18 100644 --- a/flytecopilot/data/download_test.go +++ b/flytecopilot/data/download_test.go @@ -265,7 +265,7 @@ func TestRecursiveDownload(t *testing.T) { } // Mock reading the offloaded metadata - err = s.WriteProtobuf(context.Background(), storage.DataReference("s3://container/offloaded"), storage.Options{}, &core.Literal{ + err = s.WriteProtobuf(context.Background(), "s3://container/offloaded", storage.Options{}, &core.Literal{ Value: &core.Literal_Map{ Map: &core.LiteralMap{ Literals: map[string]*core.Literal{ diff --git a/flytestdlib/storage/protobuf_store_test.go b/flytestdlib/storage/protobuf_store_test.go index 6ac6e210dc..b6562ac589 100644 --- a/flytestdlib/storage/protobuf_store_test.go +++ b/flytestdlib/storage/protobuf_store_test.go @@ -15,6 +15,7 @@ import ( "google.golang.org/protobuf/types/known/wrapperspb" "github.com/flyteorg/flyte/v2/flytestdlib/promutils" + "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/core" "github.com/flyteorg/stow/s3" ) @@ -80,6 +81,33 @@ func TestDefaultProtobufStore(t *testing.T) { }) } +func TestDefaultProtobufStore_EmptyLiteralMap(t *testing.T) { + testScope := promutils.NewTestScope() + s, err := NewDataStore(&Config{Type: TypeMemory}, testScope) + require.NoError(t, err) + + ref := DataReference("empty-literal-map") + require.NoError(t, s.WriteProtobuf(context.TODO(), ref, Options{}, &core.LiteralMap{})) + + raw, err := s.ReadRaw(context.TODO(), ref) + require.NoError(t, err) + defer func() { + require.NoError(t, raw.Close()) + }() + + rawBytes, err := io.ReadAll(raw) + require.NoError(t, err) + assert.Empty(t, rawBytes) + + got := &core.LiteralMap{ + Literals: map[string]*core.Literal{ + "stale": {}, + }, + } + require.NoError(t, s.ReadProtobuf(context.TODO(), ref, got)) + assert.Empty(t, got.GetLiterals()) +} + func TestDefaultProtobufStore_BigDataReadAfterWrite(t *testing.T) { t.Run("Read after Write with Big Data", func(t *testing.T) { testScope := promutils.NewTestScope() From 12d30a8ffa4b37ef9e2f16728e771973788f5a7c Mon Sep 17 00:00:00 2001 From: Jason Parraga Date: Wed, 20 May 2026 15:35:07 -0700 Subject: [PATCH 4/4] cleanup nolint Signed-off-by: Jason Parraga --- flytecopilot/data/download.go | 2 +- flytecopilot/data/upload.go | 2 +- flyteplugins/go/tasks/pluginmachinery/flytek8s/copilot.go | 2 +- flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper.go | 2 +- flyteplugins/go/tasks/pluginmachinery/utils/marshal_utils.go | 2 +- flytestdlib/app/error.go | 2 +- flytestdlib/flytestdlib/storage/mocks/mocks.go | 2 +- flytestdlib/pbhash/pbhash.go | 2 +- flytestdlib/storage/protobuf_store.go | 2 +- flytestdlib/storage/storage.go | 2 +- flytestdlib/utils/marshal_utils.go | 2 +- runs/repository/transformers/task.go | 2 +- 12 files changed, 12 insertions(+), 12 deletions(-) diff --git a/flytecopilot/data/download.go b/flytecopilot/data/download.go index c89ca424c2..fccaa38c8c 100644 --- a/flytecopilot/data/download.go +++ b/flytecopilot/data/download.go @@ -17,7 +17,7 @@ import ( "github.com/ghodss/yaml" "github.com/pkg/errors" "google.golang.org/protobuf/encoding/protojson" - "google.golang.org/protobuf/proto" //nolint: staticcheck + "google.golang.org/protobuf/proto" structpb "google.golang.org/protobuf/types/known/structpb" "github.com/flyteorg/flyte/v2/flytestdlib/futures" diff --git a/flytecopilot/data/upload.go b/flytecopilot/data/upload.go index 664916b409..ca3b49b263 100644 --- a/flytecopilot/data/upload.go +++ b/flytecopilot/data/upload.go @@ -10,7 +10,7 @@ import ( "reflect" "github.com/pkg/errors" - "google.golang.org/protobuf/proto" //nolint: staticcheck + "google.golang.org/protobuf/proto" "github.com/flyteorg/flyte/v2/flyteidl2/clients/go/coreutils" "github.com/flyteorg/flyte/v2/flytestdlib/futures" diff --git a/flyteplugins/go/tasks/pluginmachinery/flytek8s/copilot.go b/flyteplugins/go/tasks/pluginmachinery/flytek8s/copilot.go index 1c9ba62fb8..50488f5799 100644 --- a/flyteplugins/go/tasks/pluginmachinery/flytek8s/copilot.go +++ b/flyteplugins/go/tasks/pluginmachinery/flytek8s/copilot.go @@ -8,7 +8,7 @@ import ( "time" "github.com/pkg/errors" - "google.golang.org/protobuf/proto" //nolint: staticcheck + "google.golang.org/protobuf/proto" v1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/api/resource" diff --git a/flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper.go b/flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper.go index 5521b22fe8..ce2ccd48f0 100644 --- a/flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper.go +++ b/flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper.go @@ -10,7 +10,7 @@ import ( "time" "github.com/imdario/mergo" - "google.golang.org/protobuf/proto" //nolint: staticcheck + "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/known/timestamppb" v1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/api/resource" diff --git a/flyteplugins/go/tasks/pluginmachinery/utils/marshal_utils.go b/flyteplugins/go/tasks/pluginmachinery/utils/marshal_utils.go index f04978f3f6..fda9d208fe 100755 --- a/flyteplugins/go/tasks/pluginmachinery/utils/marshal_utils.go +++ b/flyteplugins/go/tasks/pluginmachinery/utils/marshal_utils.go @@ -5,7 +5,7 @@ import ( "fmt" "google.golang.org/protobuf/encoding/protojson" - "google.golang.org/protobuf/proto" //nolint: staticcheck + "google.golang.org/protobuf/proto" structpb "google.golang.org/protobuf/types/known/structpb" ) diff --git a/flytestdlib/app/error.go b/flytestdlib/app/error.go index 92cf96b1f7..55984aee09 100644 --- a/flytestdlib/app/error.go +++ b/flytestdlib/app/error.go @@ -5,7 +5,7 @@ import ( "google.golang.org/grpc/codes" "google.golang.org/grpc/status" - "google.golang.org/protobuf/proto" //nolint: staticcheck + "google.golang.org/protobuf/proto" "google.golang.org/protobuf/protoadapt" ) diff --git a/flytestdlib/flytestdlib/storage/mocks/mocks.go b/flytestdlib/flytestdlib/storage/mocks/mocks.go index 3eeb9b02b1..dad6cfbc9b 100644 --- a/flytestdlib/flytestdlib/storage/mocks/mocks.go +++ b/flytestdlib/flytestdlib/storage/mocks/mocks.go @@ -10,7 +10,7 @@ import ( "github.com/flyteorg/flyte/v2/flytestdlib/storage" mock "github.com/stretchr/testify/mock" - "google.golang.org/protobuf/proto" //nolint: staticcheck + "google.golang.org/protobuf/proto" ) // NewMetadata creates a new instance of Metadata. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. diff --git a/flytestdlib/pbhash/pbhash.go b/flytestdlib/pbhash/pbhash.go index 3f00f8dd89..e599c0f497 100644 --- a/flytestdlib/pbhash/pbhash.go +++ b/flytestdlib/pbhash/pbhash.go @@ -7,7 +7,7 @@ import ( goObjectHash "github.com/benlaurie/objecthash/go/objecthash" "google.golang.org/protobuf/encoding/protojson" - "google.golang.org/protobuf/proto" //nolint: staticcheck + "google.golang.org/protobuf/proto" "github.com/flyteorg/flyte/v2/flytestdlib/logger" ) diff --git a/flytestdlib/storage/protobuf_store.go b/flytestdlib/storage/protobuf_store.go index 3b6fe0dc1d..59747e6eb8 100644 --- a/flytestdlib/storage/protobuf_store.go +++ b/flytestdlib/storage/protobuf_store.go @@ -8,7 +8,7 @@ import ( errs "github.com/pkg/errors" "github.com/prometheus/client_golang/prometheus" - "google.golang.org/protobuf/proto" //nolint: staticcheck + "google.golang.org/protobuf/proto" "github.com/flyteorg/flyte/v2/flytestdlib/ioutils" "github.com/flyteorg/flyte/v2/flytestdlib/logger" diff --git a/flytestdlib/storage/storage.go b/flytestdlib/storage/storage.go index af2107edc3..7ef5ac92ae 100644 --- a/flytestdlib/storage/storage.go +++ b/flytestdlib/storage/storage.go @@ -15,7 +15,7 @@ import ( "time" "github.com/flyteorg/stow" - "google.golang.org/protobuf/proto" //nolint: staticcheck + "google.golang.org/protobuf/proto" ) // DataReference defines a reference to data location. diff --git a/flytestdlib/utils/marshal_utils.go b/flytestdlib/utils/marshal_utils.go index f5a09e8246..90180f4b0b 100644 --- a/flytestdlib/utils/marshal_utils.go +++ b/flytestdlib/utils/marshal_utils.go @@ -7,7 +7,7 @@ import ( "github.com/pkg/errors" "google.golang.org/protobuf/encoding/protojson" - "google.golang.org/protobuf/proto" //nolint: staticcheck + "google.golang.org/protobuf/proto" structpb "google.golang.org/protobuf/types/known/structpb" ) diff --git a/runs/repository/transformers/task.go b/runs/repository/transformers/task.go index 4c0f30c98a..2a7c8246d1 100644 --- a/runs/repository/transformers/task.go +++ b/runs/repository/transformers/task.go @@ -5,7 +5,7 @@ import ( "database/sql" "strings" - "google.golang.org/protobuf/proto" //nolint: staticcheck + "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/known/timestamppb" "github.com/flyteorg/flyte/v2/flytestdlib/logger"