From 6975536ba104d314fcaab2cca1cfa179239554c9 Mon Sep 17 00:00:00 2001 From: Tristan Swadell Date: Thu, 11 Jun 2026 11:02:18 -0700 Subject: [PATCH] Fix variadic logical operator planning PiperOrigin-RevId: 930627361 --- conformance/BUILD | 23 +++++-- conformance/run.bzl | 17 ++++-- conformance/run.cc | 5 ++ conformance/service.cc | 46 +++++++++----- conformance/service.h | 1 + eval/compiler/BUILD | 1 + eval/compiler/flat_expr_builder.cc | 80 +++++++++++++------------ eval/compiler/flat_expr_builder_test.cc | 7 ++- 8 files changed, 113 insertions(+), 67 deletions(-) diff --git a/conformance/BUILD b/conformance/BUILD index a6f25e001..35d554c7b 100644 --- a/conformance/BUILD +++ b/conformance/BUILD @@ -32,7 +32,6 @@ cc_library( "//common:ast", "//common:ast_proto", "//common:decl_proto_v1alpha1", - "//common:expr", "//common:source", "//common:value", "//common/internal:value_conversion", @@ -57,8 +56,6 @@ cc_library( "//extensions/protobuf:enum_adapter", "//internal:status_macros", "//parser", - "//parser:macro", - "//parser:macro_expr_factory", "//parser:macro_registry", "//parser:options", "//parser:standard_macros", @@ -75,8 +72,6 @@ cc_library( "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:optional", - "@com_google_absl//absl/types:span", "@com_google_cel_spec//proto/cel/expr:syntax_cc_proto", "@com_google_cel_spec//proto/cel/expr/conformance/proto2:test_all_types_cc_proto", "@com_google_cel_spec//proto/cel/expr/conformance/proto3:test_all_types_cc_proto", @@ -302,6 +297,24 @@ gen_conformance_tests( skip_tests = _TESTS_TO_SKIP_MODERN + _TESTS_TO_SKIP_CHECKED, ) +gen_conformance_tests( + name = "conformance_variadic", + checked = True, + data = _ALL_TESTS, + enable_variadic_logical_operators = True, + modern = True, + skip_tests = _TESTS_TO_SKIP_MODERN + _TESTS_TO_SKIP_CHECKED, +) + +gen_conformance_tests( + name = "conformance_legacy_variadic", + checked = True, + data = _ALL_TESTS, + enable_variadic_logical_operators = True, + modern = False, + skip_tests = _TESTS_TO_SKIP_LEGACY + _TESTS_TO_SKIP_CHECKED, +) + # Generates a bunch of `cc_test` whose names follow the pattern # `conformance_dashboard_..._{arena|refcount}_{optimized|unoptimized}_{recursive|iterative}`. gen_conformance_tests( diff --git a/conformance/run.bzl b/conformance/run.bzl index 2c0b51c0e..8faeb6c16 100644 --- a/conformance/run.bzl +++ b/conformance/run.bzl @@ -56,7 +56,7 @@ def _conformance_test_name(name, optimize, recursive): ], ) -def _conformance_test_args(modern, optimize, recursive, select_opt, skip_check, dashboard): +def _conformance_test_args(modern, optimize, recursive, select_opt, skip_check, dashboard, enable_variadic_logical_operators): args = [] if modern: args.append("--modern") @@ -72,12 +72,14 @@ def _conformance_test_args(modern, optimize, recursive, select_opt, skip_check, args.append("--noskip_check") if dashboard: args.append("--dashboard") + if enable_variadic_logical_operators: + args.append("--enable_variadic_logical_operators") return args -def _conformance_test(name, data, modern, optimize, recursive, select_opt, skip_check, skip_tests, tags, dashboard): +def _conformance_test(name, data, modern, optimize, recursive, select_opt, skip_check, skip_tests, tags, dashboard, enable_variadic_logical_operators): cc_test( name = _conformance_test_name(name, optimize, recursive), - args = _conformance_test_args(modern, optimize, recursive, select_opt, skip_check, dashboard) + ["$(rlocationpath {})".format(test) for test in data], + args = _conformance_test_args(modern, optimize, recursive, select_opt, skip_check, dashboard, enable_variadic_logical_operators) + ["$(rlocationpath {})".format(test) for test in data], env = select( { "@platforms//os:windows": {"CEL_SKIP_TESTS": ",".join(skip_tests + _TESTS_TO_SKIP_WINDOWS)}, @@ -89,18 +91,20 @@ def _conformance_test(name, data, modern, optimize, recursive, select_opt, skip_ tags = tags, ) -def gen_conformance_tests(name, data, modern = False, checked = False, select_opt = False, dashboard = False, skip_tests = [], tags = []): +def gen_conformance_tests(name, data, modern = False, checked = False, select_opt = False, dashboard = False, skip_tests = [], tags = [], enable_variadic_logical_operators = False): """Generates conformance tests. Args: name: prefix for all tests + data: textproto targets describing conformance tests modern: run using modern APIs checked: whether to apply type checking - data: textproto targets describing conformance tests + select_opt: enable select optimization + dashboard: enable dashboard mode skip_tests: tests to skip in the format of the cel-spec test runner. See documentation in github.com/google/cel-spec/tests/simple/simple_test.go tags: tags added to the generated targets - dashboard: enable dashboard mode + enable_variadic_logical_operators: enable variadic logical operators """ skip_check = not checked tests = [] @@ -119,6 +123,7 @@ def gen_conformance_tests(name, data, modern = False, checked = False, select_op skip_tests = _expand_tests_to_skip(skip_tests), tags = tags, dashboard = dashboard, + enable_variadic_logical_operators = enable_variadic_logical_operators, ) native.test_suite( name = name, diff --git a/conformance/run.cc b/conformance/run.cc index 4a0493494..1be16ba60 100644 --- a/conformance/run.cc +++ b/conformance/run.cc @@ -66,6 +66,9 @@ ABSL_FLAG(std::vector, skip_tests, {}, "Tests to skip"); ABSL_FLAG(bool, dashboard, false, "Dashboard mode, ignore test failures"); ABSL_FLAG(bool, skip_check, true, "Skip type checking the expressions"); ABSL_FLAG(bool, select_optimization, false, "Enable select optimization."); +ABSL_FLAG(bool, enable_variadic_logical_operators, false, + "Enable parsing logical AND & OR operators as a single flat variadic " + "call."); namespace { @@ -261,6 +264,8 @@ NewConformanceServiceFromFlags() { .modern = absl::GetFlag(FLAGS_modern), .recursive = absl::GetFlag(FLAGS_recursive), .select_optimization = absl::GetFlag(FLAGS_select_optimization), + .enable_variadic_logical_operators = + absl::GetFlag(FLAGS_enable_variadic_logical_operators), }); ABSL_CHECK_OK(status_or_service); return std::shared_ptr( diff --git a/conformance/service.cc b/conformance/service.cc index 7e3eded82..d81200cad 100644 --- a/conformance/service.cc +++ b/conformance/service.cc @@ -128,13 +128,15 @@ cel::expr::Expr ExtractExpr( absl::Status LegacyParse(const conformance::v1alpha1::ParseRequest& request, conformance::v1alpha1::ParseResponse& response, - bool enable_optional_syntax) { + bool enable_optional_syntax, + bool enable_variadic_logical_operators) { if (request.cel_source().empty()) { return absl::InvalidArgumentError("no source code"); } cel::ParserOptions options; options.enable_optional_syntax = enable_optional_syntax; options.enable_quoted_identifiers = true; + options.enable_variadic_logical_operators = enable_variadic_logical_operators; cel::MacroRegistry macros; CEL_RETURN_IF_ERROR(cel::RegisterStandardMacros(macros, options)); CEL_RETURN_IF_ERROR( @@ -236,7 +238,8 @@ absl::Status CheckImpl(google::protobuf::Arena* arena, class LegacyConformanceServiceImpl : public ConformanceServiceInterface { public: static absl::StatusOr> Create( - bool optimize, bool recursive, bool select_optimization) { + bool optimize, bool recursive, bool select_optimization, + bool enable_variadic_logical_operators) { static auto* constant_arena = new Arena(); google::protobuf::LinkMessageReflection< @@ -313,14 +316,15 @@ class LegacyConformanceServiceImpl : public ConformanceServiceInterface { CEL_RETURN_IF_ERROR(cel::extensions::RegisterMathExtensionFunctions( builder->GetRegistry(), options)); - return absl::WrapUnique( - new LegacyConformanceServiceImpl(std::move(builder))); + return absl::WrapUnique(new LegacyConformanceServiceImpl( + std::move(builder), enable_variadic_logical_operators)); } void Parse(const conformance::v1alpha1::ParseRequest& request, conformance::v1alpha1::ParseResponse& response) override { auto status = - LegacyParse(request, response, /*enable_optional_syntax=*/false); + LegacyParse(request, response, /*enable_optional_syntax=*/false, + enable_variadic_logical_operators_); if (!status.ok()) { auto* issue = response.add_issues(); issue->set_code(ToGrpcCode(status.code())); @@ -418,17 +422,20 @@ class LegacyConformanceServiceImpl : public ConformanceServiceInterface { } private: - explicit LegacyConformanceServiceImpl( - std::unique_ptr builder) - : builder_(std::move(builder)) {} + LegacyConformanceServiceImpl(std::unique_ptr builder, + bool enable_variadic_logical_operators) + : builder_(std::move(builder)), + enable_variadic_logical_operators_(enable_variadic_logical_operators) {} std::unique_ptr builder_; + bool enable_variadic_logical_operators_; }; class ModernConformanceServiceImpl : public ConformanceServiceInterface { public: static absl::StatusOr> Create( - bool optimize, bool recursive, bool select_optimization) { + bool optimize, bool recursive, bool select_optimization, + bool enable_variadic_logical_operators) { google::protobuf::LinkMessageReflection< cel::expr::conformance::proto3::TestAllTypes>(); google::protobuf::LinkMessageReflection< @@ -470,8 +477,9 @@ class ModernConformanceServiceImpl : public ConformanceServiceInterface { options.max_recursion_depth = 48; } - return absl::WrapUnique(new ModernConformanceServiceImpl( - options, optimize, select_optimization)); + return absl::WrapUnique( + new ModernConformanceServiceImpl(options, optimize, select_optimization, + enable_variadic_logical_operators)); } absl::StatusOr> Setup( @@ -523,7 +531,8 @@ class ModernConformanceServiceImpl : public ConformanceServiceInterface { void Parse(const conformance::v1alpha1::ParseRequest& request, conformance::v1alpha1::ParseResponse& response) override { auto status = - LegacyParse(request, response, /*enable_optional_syntax=*/true); + LegacyParse(request, response, /*enable_optional_syntax=*/true, + enable_variadic_logical_operators_); if (!status.ok()) { auto* issue = response.add_issues(); issue->set_code(ToGrpcCode(status.code())); @@ -614,10 +623,12 @@ class ModernConformanceServiceImpl : public ConformanceServiceInterface { private: ModernConformanceServiceImpl(const RuntimeOptions& options, bool enable_optimizations, - bool enable_select_optimization) + bool enable_select_optimization, + bool enable_variadic_logical_operators) : options_(options), enable_optimizations_(enable_optimizations), - enable_select_optimization_(enable_select_optimization) {} + enable_select_optimization_(enable_select_optimization), + enable_variadic_logical_operators_(enable_variadic_logical_operators) {} static absl::StatusOr> Plan( const cel::Runtime& runtime, @@ -648,6 +659,7 @@ class ModernConformanceServiceImpl : public ConformanceServiceInterface { RuntimeOptions options_; bool enable_optimizations_; bool enable_select_optimization_; + bool enable_variadic_logical_operators_; }; } // namespace @@ -660,10 +672,12 @@ absl::StatusOr> NewConformanceService(const ConformanceServiceOptions& options) { if (options.modern) { return google::api::expr::runtime::ModernConformanceServiceImpl::Create( - options.optimize, options.recursive, options.select_optimization); + options.optimize, options.recursive, options.select_optimization, + options.enable_variadic_logical_operators); } else { return google::api::expr::runtime::LegacyConformanceServiceImpl::Create( - options.optimize, options.recursive, options.select_optimization); + options.optimize, options.recursive, options.select_optimization, + options.enable_variadic_logical_operators); } } diff --git a/conformance/service.h b/conformance/service.h index 2dd2abf32..8eb97296e 100644 --- a/conformance/service.h +++ b/conformance/service.h @@ -46,6 +46,7 @@ struct ConformanceServiceOptions { bool arena; bool recursive; bool select_optimization; + bool enable_variadic_logical_operators = false; }; absl::StatusOr> diff --git a/eval/compiler/BUILD b/eval/compiler/BUILD index ed8e4d20c..f7300cb58 100644 --- a/eval/compiler/BUILD +++ b/eval/compiler/BUILD @@ -193,6 +193,7 @@ cc_test( "//internal:status_macros", "//internal:testing", "//parser", + "//parser:options", "//runtime:function", "//runtime:function_adapter", "//runtime:runtime_options", diff --git a/eval/compiler/flat_expr_builder.cc b/eval/compiler/flat_expr_builder.cc index d6ccdf040..fc6d87b16 100644 --- a/eval/compiler/flat_expr_builder.cc +++ b/eval/compiler/flat_expr_builder.cc @@ -2154,7 +2154,7 @@ void BinaryCondVisitor::PreVisit(const cel::Expr* expr) { case BinaryCond::kOr: visitor_->ValidateOrError( !expr->call_expr().has_target() && - expr->call_expr().args().size() == 2, + expr->call_expr().args().size() >= 2, "Invalid argument count for a binary function call."); break; case BinaryCond::kOptionalOr: @@ -2172,28 +2172,40 @@ void BinaryCondVisitor::PostVisitArg(int arg_num, const cel::Expr* expr) { return; } const int last_arg_index = expr->call_expr().args().size() - 1; - if (short_circuiting_ && arg_num < last_arg_index && - (cond_ == BinaryCond::kAnd || cond_ == BinaryCond::kOr)) { - // If first branch evaluation result is enough to determine output, - // jump over the second branch and provide result of the first argument as - // final output. - // Retain pointers to the jump steps so we can update the target after - // planning the next arguments. - std::unique_ptr jump_step; - switch (cond_) { - case BinaryCond::kAnd: - jump_step = CreateCondJumpStep(false, true, {}, expr->id()); - break; - case BinaryCond::kOr: - jump_step = CreateCondJumpStep(true, true, {}, expr->id()); - break; - default: - ABSL_UNREACHABLE(); + if (cond_ == BinaryCond::kAnd || cond_ == BinaryCond::kOr) { + if (arg_num > 0) { + switch (cond_) { + case BinaryCond::kAnd: + visitor_->AddStep(CreateAndStep(expr->id())); + break; + case BinaryCond::kOr: + visitor_->AddStep(CreateOrStep(expr->id())); + break; + default: + break; + } + if (short_circuiting_ && !jump_steps_.empty()) { + visitor_->SetProgressStatusIfError( + jump_steps_.back().set_target(visitor_->GetCurrentIndex())); + } } - ProgramStepIndex index = visitor_->GetCurrentIndex(); - if (JumpStepBase* jump_step_ptr = visitor_->AddStep(std::move(jump_step)); - jump_step_ptr) { - jump_steps_.push_back(Jump(index, jump_step_ptr)); + if (short_circuiting_ && arg_num < last_arg_index) { + std::unique_ptr jump_step; + switch (cond_) { + case BinaryCond::kAnd: + jump_step = CreateCondJumpStep(false, true, {}, expr->id()); + break; + case BinaryCond::kOr: + jump_step = CreateCondJumpStep(true, true, {}, expr->id()); + break; + default: + ABSL_UNREACHABLE(); + } + ProgramStepIndex index = visitor_->GetCurrentIndex(); + if (JumpStepBase* jump_step_ptr = visitor_->AddStep(std::move(jump_step)); + jump_step_ptr) { + jump_steps_.push_back(Jump(index, jump_step_ptr)); + } } } } @@ -2251,17 +2263,9 @@ void BinaryCondVisitor::PostVisit(const cel::Expr* expr) { return; } - int args_count = (cond_ == BinaryCond::kAnd || cond_ == BinaryCond::kOr) - ? expr->call_expr().args().size() - : 2; - for (int i = 0; i < args_count - 1; ++i) { + if (cond_ == BinaryCond::kOptionalOr || + cond_ == BinaryCond::kOptionalOrValue) { switch (cond_) { - case BinaryCond::kAnd: - visitor_->AddStep(CreateAndStep(expr->id())); - break; - case BinaryCond::kOr: - visitor_->AddStep(CreateOrStep(expr->id())); - break; case BinaryCond::kOptionalOr: visitor_->AddStep( CreateOptionalOrStep(/*is_or_value=*/false, expr->id())); @@ -2273,13 +2277,11 @@ void BinaryCondVisitor::PostVisit(const cel::Expr* expr) { default: ABSL_UNREACHABLE(); } - } - if (short_circuiting_) { - // If short-circuiting is enabled, point the conditional jump past the - // boolean operator step. - for (auto& jump : jump_steps_) { - visitor_->SetProgressStatusIfError( - jump.set_target(visitor_->GetCurrentIndex())); + if (short_circuiting_) { + for (auto& jump : jump_steps_) { + visitor_->SetProgressStatusIfError( + jump.set_target(visitor_->GetCurrentIndex())); + } } } } diff --git a/eval/compiler/flat_expr_builder_test.cc b/eval/compiler/flat_expr_builder_test.cc index e2581e3fd..105060282 100644 --- a/eval/compiler/flat_expr_builder_test.cc +++ b/eval/compiler/flat_expr_builder_test.cc @@ -64,6 +64,7 @@ #include "internal/proto_matchers.h" #include "internal/status_macros.h" #include "internal/testing.h" +#include "parser/options.h" #include "parser/parser.h" #include "runtime/function.h" #include "runtime/function_adapter.h" @@ -2916,7 +2917,11 @@ class FlatExprBuilderVariadicLogicalTest TEST_P(FlatExprBuilderVariadicLogicalTest, Evaluate) { const auto& test_case = GetParam(); - ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, parser::Parse(test_case.expr)); + parser::ParserOptions parser_options; + parser_options.enable_variadic_logical_operators = true; + ASSERT_OK_AND_ASSIGN( + ParsedExpr parsed_expr, + parser::Parse(test_case.expr, test_case.label, parser_options)); cel::RuntimeOptions options; options.unknown_processing =