diff --git a/cel_expr_python/BUILD b/cel_expr_python/BUILD index d4b206a..aca5cc3 100644 --- a/cel_expr_python/BUILD +++ b/cel_expr_python/BUILD @@ -105,14 +105,25 @@ pybind_extension( # For pybind11-based CEL extensions. pybind_library( name = "cel_extension", - hdrs = ["cel_extension.h"], + srcs = [ + "py_error_status.cc", + ], + hdrs = [ + "cel_extension.h", + "py_error_status.h", + ], visibility = ["//visibility:public"], deps = [ + ":status_macros", + "@com_google_absl//absl/base:no_destructor", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_cel_cpp//compiler", "@com_google_cel_cpp//runtime:runtime_builder", "@com_google_cel_cpp//runtime:runtime_options", - "@com_google_protobuf//:protobuf", ], ) @@ -141,7 +152,10 @@ py_test( srcs = ["cel_env_test.py"], deps = [ ":cel", + "//cel_expr_python/ext:ext_bindings", "//cel_expr_python/ext:ext_math", + "//cel_expr_python/ext:ext_optional", + "//cel_expr_python/ext:ext_strings", "//testing:proto2_test_all_types_py_pb2", "@com_google_absl_py//absl/testing:absltest", ], diff --git a/cel_expr_python/cel_env_test.py b/cel_expr_python/cel_env_test.py index 2aa16c2..79a9700 100644 --- a/cel_expr_python/cel_env_test.py +++ b/cel_expr_python/cel_env_test.py @@ -22,7 +22,10 @@ from absl.testing import absltest from cel_expr_python import cel +from cel_expr_python.ext import ext_bindings from cel_expr_python.ext import ext_math +from cel_expr_python.ext import ext_optional +from cel_expr_python.ext import ext_strings from cel.expr.conformance.proto2 import test_all_types_pb2 as test_all_types_pb @@ -95,9 +98,7 @@ def test_invalid_yaml(self): ) def test_config_export_container(self): - env = cel.NewEnv( - container="test.container" - ) + env = cel.NewEnv(container="test.container") yaml = env.config().to_yaml() self.assertEqual( normalize_yaml(yaml), @@ -251,6 +252,52 @@ def test_config_variable_types(self): self.assertEqual(res.type(), cel.Type.INT) self.assertEqual(res.value(), 42) + def test_config_export_extension_version(self): + env = cel.NewEnv( + extensions=[ + ext_math.ExtMath(0), + ext_optional.ExtOptional(1), + ext_strings.ExtStrings(2), + ext_bindings.ExtBindings(), + ], + ) + yaml = env.config().to_yaml() + self.assertEqual( + normalize_yaml(yaml), + normalize_yaml(""" + extensions: + - name: "bindings" + - name: "math" + version: 0 + - name: "optional" + version: 1 + - name: "strings" + version: 2 + """), + ) + + def test_config_extension_version_out_of_range(self): + cases = [ + [ + lambda: ext_math.ExtMath(42), + r"'math' extension version: 42 not in range \[0, \d+\]", + ], + [ + lambda: ext_optional.ExtOptional(6), + r"'optional' extension version: 6 not in range \[0, \d+\]", + ], + [ + lambda: ext_strings.ExtStrings(18), + r"'strings' extension version: 18 not in range \[0, \d+\]", + ], + ] + for test_case in cases: + with self.assertRaises(Exception) as e: + cel.NewEnv( + extensions=[test_case[0]()], + ) + self.assertRegex(str(e.exception), test_case[1]) + def test_config_extensions(self): config = cel.NewEnvConfigFromYaml(""" extensions: @@ -276,14 +323,28 @@ def test_config_extensions(self): res = env.compile("hello('World')").eval() self.assertEqual(res.value(), "Hello, World!") - def test_config_extensions_override(self): - # TODO(b/498655870): add assertion based on extension aliases once - # supported. + def test_config_extension_override_same_version(self): config = cel.NewEnvConfigFromYaml(""" extensions: - name: cel.lib.ext.math + version: 1 + - name: strings + version: 2 + """) + env = cel.NewEnv( + config=config, + extensions=[ext_math.ExtMath(1), ext_strings.ExtStrings(2)], + ) + res = env.compile("'%.3f'.format([math.floor(3.14)])").eval() + self.assertEqual(res.value(), "3.000") + + def test_config_extension_override_different_version(self): + config = cel.NewEnvConfigFromYaml(""" + extensions: + - name: math version: 0 - name: cel.lib.ext.strings + version: 2 """) with self.assertRaises(Exception) as e: cel.NewEnv( @@ -291,8 +352,18 @@ def test_config_extensions_override(self): extensions=[ext_math.ExtMath()], ) self.assertIn( - "Extension 'cel.lib.ext.math' version 0 is already included. Cannot" - " also include version 'latest'", + "Extension 'math' version 0 is already included. Cannot" + " also include version 2", + str(e.exception), + ) + with self.assertRaises(Exception) as e: + cel.NewEnv( + config=config, + extensions=[ext_strings.ExtStrings(1)], + ) + self.assertIn( + "Extension 'cel.lib.ext.strings' version 2 is already included. Cannot" + " also include version 1", str(e.exception), ) diff --git a/cel_expr_python/cel_extension.h b/cel_expr_python/cel_extension.h index 86bac22..48a7151 100644 --- a/cel_expr_python/cel_extension.h +++ b/cel_expr_python/cel_extension.h @@ -38,7 +38,9 @@ namespace cel_python { // Python. class CelExtension { public: - explicit CelExtension(std::string name) : name_(std::move(name)) {}; + explicit CelExtension(std::string name, std::string alias = "", + int version = -1) + : name_(std::move(name)), alias_(std::move(alias)), version_(version) {} virtual ~CelExtension() = default; virtual cel::CompilerLibrary GetCompilerLibrary() { @@ -51,9 +53,13 @@ class CelExtension { } std::string name() const { return name_; } + std::string alias() const { return alias_; } + int version() const { return version_; } private: std::string name_; + std::string alias_; + int version_; }; #define CEL_MODULE_NAME "cel_expr_python.cel" @@ -80,6 +86,13 @@ class CelExtension { .def(pybind11::init<>()); \ } +#define CEL_VERSIONED_EXTENSION_MODULE(module_name, class_name) \ + PYBIND11_MODULE(module_name, m) { \ + pybind11::module_::import(CEL_MODULE_NAME); \ + pybind11::class_(m, #class_name) \ + .def(pybind11::init<>()) \ + .def(pybind11::init(), pybind11::arg("version")); \ + } } // namespace cel_python #endif // THIRD_PARTY_CEL_PYTHON_CEL_EXTENSION_H_ diff --git a/cel_expr_python/ext/BUILD b/cel_expr_python/ext/BUILD index 7e32348..cdb038e 100644 --- a/cel_expr_python/ext/BUILD +++ b/cel_expr_python/ext/BUILD @@ -49,6 +49,7 @@ pybind_extension( deps = [ "//cel_expr_python:cel_extension", "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", "@com_google_cel_cpp//compiler", "@com_google_cel_cpp//extensions:math_ext", "@com_google_cel_cpp//extensions:math_ext_decls", @@ -69,6 +70,8 @@ pybind_extension( deps = [ "//cel_expr_python:cel_extension", "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_cel_cpp//checker:optional", "@com_google_cel_cpp//compiler", "@com_google_cel_cpp//compiler:optional", "@com_google_cel_cpp//runtime:optional_types", @@ -94,9 +97,9 @@ pybind_extension( ) pybind_extension( - name = "ext_string", + name = "ext_strings", srcs = [ - "ext_string.cc", + "ext_strings.cc", ], data = [ "//cel_expr_python:cel", @@ -105,6 +108,7 @@ pybind_extension( deps = [ "//cel_expr_python:cel_extension", "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", "@com_google_cel_cpp//compiler", "@com_google_cel_cpp//extensions:strings", "@com_google_cel_cpp//runtime:runtime_builder", diff --git a/cel_expr_python/ext/ext_bindings.cc b/cel_expr_python/ext/ext_bindings.cc index 08caebc..49d9731 100644 --- a/cel_expr_python/ext/ext_bindings.cc +++ b/cel_expr_python/ext/ext_bindings.cc @@ -20,7 +20,8 @@ namespace cel_python { class ExtBindings : public CelExtension { public: - explicit ExtBindings() : CelExtension("cel.lib.ext.cel.bindings") {} + explicit ExtBindings() + : CelExtension("cel.lib.ext.cel.bindings", "bindings") {} cel::CompilerLibrary GetCompilerLibrary() override { return cel::extensions::BindingsCompilerLibrary(); diff --git a/cel_expr_python/ext/ext_encoders.cc b/cel_expr_python/ext/ext_encoders.cc index cc9e2ab..59c67ae 100644 --- a/cel_expr_python/ext/ext_encoders.cc +++ b/cel_expr_python/ext/ext_encoders.cc @@ -23,7 +23,7 @@ namespace cel_python { class ExtEncoders : public CelExtension { public: - explicit ExtEncoders() : CelExtension("cel.lib.ext.encoders") {} + explicit ExtEncoders() : CelExtension("cel.lib.ext.encoders", "encoders") {} cel::CompilerLibrary GetCompilerLibrary() override { return cel::extensions::EncodersCompilerLibrary(); diff --git a/cel_expr_python/ext/ext_math.cc b/cel_expr_python/ext/ext_math.cc index bb50581..4ed18ec 100644 --- a/cel_expr_python/ext/ext_math.cc +++ b/cel_expr_python/ext/ext_math.cc @@ -13,30 +13,41 @@ // limitations under the License. #include "absl/status/status.h" +#include "absl/strings/str_cat.h" #include "compiler/compiler.h" #include "extensions/math_ext.h" #include "extensions/math_ext_decls.h" #include "runtime/runtime_builder.h" #include "runtime/runtime_options.h" #include "cel_expr_python/cel_extension.h" +#include "cel_expr_python/py_error_status.h" namespace cel_python { class ExtMath : public CelExtension { public: - explicit ExtMath() : CelExtension("cel.lib.ext.math") {} + explicit ExtMath(int version) + : CelExtension("cel.lib.ext.math", "math", version) { + if (version < 0 || version > cel::extensions::kMathExtensionLatestVersion) { + throw StatusToException(absl::InvalidArgumentError(absl::StrCat( + "'math' extension version: ", version, " not in range [0, ", + cel::extensions::kMathExtensionLatestVersion, "]"))); + } + } + + ExtMath() : ExtMath(cel::extensions::kMathExtensionLatestVersion) {} cel::CompilerLibrary GetCompilerLibrary() override { - return cel::extensions::MathCompilerLibrary(); + return cel::extensions::MathCompilerLibrary(version()); } absl::Status ConfigureRuntime(cel::RuntimeBuilder& runtime_builder, const cel::RuntimeOptions& opts) override { return cel::extensions::RegisterMathExtensionFunctions( - runtime_builder.function_registry(), opts); + runtime_builder.function_registry(), opts, version()); } }; -CEL_EXTENSION_MODULE(ext_math, ExtMath); +CEL_VERSIONED_EXTENSION_MODULE(ext_math, ExtMath); } // namespace cel_python diff --git a/cel_expr_python/ext/ext_optional.cc b/cel_expr_python/ext/ext_optional.cc index f18f057..a54242d 100644 --- a/cel_expr_python/ext/ext_optional.cc +++ b/cel_expr_python/ext/ext_optional.cc @@ -13,21 +13,32 @@ // limitations under the License. #include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "checker/optional.h" #include "compiler/compiler.h" #include "compiler/optional.h" #include "runtime/optional_types.h" #include "runtime/runtime_builder.h" #include "runtime/runtime_options.h" #include "cel_expr_python/cel_extension.h" +#include "cel_expr_python/py_error_status.h" namespace cel_python { class ExtOptional : public CelExtension { public: - explicit ExtOptional() : CelExtension("optional") {} + explicit ExtOptional(int version) : CelExtension("optional", "", version) { + if (version < 0 || version > cel::kOptionalExtensionLatestVersion) { + throw StatusToException(absl::InvalidArgumentError(absl::StrCat( + "'optional' extension version: ", version, " not in range [0, ", + cel::kOptionalExtensionLatestVersion, "]"))); + } + } + + ExtOptional() : ExtOptional(cel::kOptionalExtensionLatestVersion) {} cel::CompilerLibrary GetCompilerLibrary() override { - return cel::OptionalCompilerLibrary(); + return cel::OptionalCompilerLibrary(version()); } absl::Status ConfigureRuntime(cel::RuntimeBuilder& runtime_builder, @@ -40,6 +51,6 @@ class ExtOptional : public CelExtension { } }; -CEL_EXTENSION_MODULE(ext_optional, ExtOptional); +CEL_VERSIONED_EXTENSION_MODULE(ext_optional, ExtOptional); } // namespace cel_python diff --git a/cel_expr_python/ext/ext_string.cc b/cel_expr_python/ext/ext_strings.cc similarity index 56% rename from cel_expr_python/ext/ext_string.cc rename to cel_expr_python/ext/ext_strings.cc index fbda450..d1e6456 100644 --- a/cel_expr_python/ext/ext_string.cc +++ b/cel_expr_python/ext/ext_strings.cc @@ -13,29 +13,42 @@ // limitations under the License. #include "absl/status/status.h" +#include "absl/strings/str_cat.h" #include "compiler/compiler.h" #include "extensions/strings.h" #include "runtime/runtime_builder.h" #include "runtime/runtime_options.h" #include "cel_expr_python/cel_extension.h" +#include "cel_expr_python/py_error_status.h" namespace cel_python { -class ExtString : public CelExtension { +class ExtStrings : public CelExtension { public: - explicit ExtString() : CelExtension("cel.lib.ext.string") {} + explicit ExtStrings(int version) + : CelExtension("cel.lib.ext.strings", "strings", version) { + if (version < 0 || + version > cel::extensions::kStringsExtensionLatestVersion) { + throw StatusToException(absl::InvalidArgumentError(absl::StrCat( + "'strings' extension version: ", version, " not in range [0, ", + cel::extensions::kStringsExtensionLatestVersion, "]"))); + } + } + + ExtStrings() : ExtStrings(cel::extensions::kStringsExtensionLatestVersion) {} cel::CompilerLibrary GetCompilerLibrary() override { - return cel::extensions::StringsCompilerLibrary(); + return cel::extensions::StringsCompilerLibrary(version()); } absl::Status ConfigureRuntime(cel::RuntimeBuilder& runtime_builder, const cel::RuntimeOptions& opts) override { return cel::extensions::RegisterStringsFunctions( - runtime_builder.function_registry(), opts); + runtime_builder.function_registry(), opts, + cel::extensions::StringsExtensionOptions{.version = version()}); } }; -CEL_EXTENSION_MODULE(ext_string, ExtString); +CEL_VERSIONED_EXTENSION_MODULE(ext_strings, ExtStrings); } // namespace cel_python diff --git a/cel_expr_python/py_cel_env_internal.cc b/cel_expr_python/py_cel_env_internal.cc index 7af9af2..64a1697 100644 --- a/cel_expr_python/py_cel_env_internal.cc +++ b/cel_expr_python/py_cel_env_internal.cc @@ -108,8 +108,29 @@ PyCelEnvInternal::NewCelEnvInternal( CelExtensionHandle handle(ext); CEL_PYTHON_ASSIGN_OR_RETURN(CelExtension * extension, handle.GetExtension()); - // TODO(b/498655870): support extension version. - CEL_PYTHON_RETURN_IF_ERROR(config.AddExtensionConfig(extension->name())); + + std::string name; + if (!extension->alias().empty() && + extension->alias() != extension->name()) { + // If the configuration lists the extension by name, use the name; + // otherwise, use the alias. This allows us to detect conflicting + // extension registrations, whether they are included by the extension + // name or alias. + name = extension->alias(); + for (const cel::Config::ExtensionConfig& extension_config : + config.GetExtensionConfigs()) { + if (extension_config.name == extension->name()) { + name = extension_config.name; + break; + } + } + } else { + name = extension->name(); + } + CEL_PYTHON_RETURN_IF_ERROR(config.AddExtensionConfig( + name, extension->version() >= 0 + ? extension->version() + : cel::Config::ExtensionConfig::kLatest)); extension_handles.push_back(std::move(handle)); } diff --git a/conformance/BUILD b/conformance/BUILD index 3101953..bed1210 100644 --- a/conformance/BUILD +++ b/conformance/BUILD @@ -45,7 +45,7 @@ py_test( "//cel_expr_python/ext:ext_math", "//cel_expr_python/ext:ext_optional", "//cel_expr_python/ext:ext_proto", - "//cel_expr_python/ext:ext_string", + "//cel_expr_python/ext:ext_strings", "@com_google_absl_py//absl/testing:absltest", "//testing:proto2_test_all_types_py_pb2", "//testing:proto3_test_all_types_py_pb2", diff --git a/conformance/conformance_test.py b/conformance/conformance_test.py index 58a82b9..10329c6 100644 --- a/conformance/conformance_test.py +++ b/conformance/conformance_test.py @@ -33,7 +33,7 @@ from cel_expr_python.ext import ext_math from cel_expr_python.ext import ext_optional from cel_expr_python.ext import ext_proto -from cel_expr_python.ext import ext_string +from cel_expr_python.ext import ext_strings from cel.expr.conformance.proto2 import test_all_types_extensions_pb2 as test_all_types_extensions_proto2 # pylint: disable=unused-import from cel.expr.conformance.proto2 import test_all_types_pb2 as test_all_types_proto2 # pylint: disable=unused-import from cel.expr.conformance.proto3 import test_all_types_pb2 as test_all_types_proto3 # pylint: disable=unused-import @@ -138,7 +138,7 @@ class ConformanceTest(absltest.TestCase): "math_ext": [ext_math.ExtMath()], "optionals": [ext_optional.ExtOptional()], "proto2_ext": [ext_proto.ExtProto()], - "string_ext": [ext_string.ExtString()], + "string_ext": [ext_strings.ExtStrings()], "type_deduction": [ext_optional.ExtOptional()], }