From 5a36384e2d058ca1f4b7d2b6aa8ae64183bee74f Mon Sep 17 00:00:00 2001 From: Eric Astor Date: Thu, 11 Jun 2026 15:49:02 -0700 Subject: [PATCH] [opt] Support lifting operations through OneHotSelect This change introduces safety checks for lifting operations through OneHotSelect based on whether the operation preserves zero and distributes over bitwise OR. This allows more aggressive optimization of OneHotSelect nodes while maintaining correctness. If the OneHotSelect's selector might be zero, lifting the operation is only safe if the operation preserves zero; i.e., replacing input X by 0 makes the output 0. If the OneHotSelect's selector might have more than one bit set, lifting the operation is only safe if the operation distributes over bitwise OR; i.e., f(replace input X with A | B) produces the same result as replacing the result with f(replace input X with A) | f(replace input X with B). PiperOrigin-RevId: 930781718 --- xls/passes/BUILD | 7 + xls/passes/select_lifting_pass.cc | 285 ++++++++++++++++--------- xls/passes/select_lifting_pass_test.cc | 208 +++++++++++++++++- 3 files changed, 396 insertions(+), 104 deletions(-) diff --git a/xls/passes/BUILD b/xls/passes/BUILD index 7632f6e402..cbcb73da0c 100644 --- a/xls/passes/BUILD +++ b/xls/passes/BUILD @@ -1144,9 +1144,15 @@ xls_pass( hdrs = ["select_lifting_pass.h"], pass_class = "SelectLiftingPass", deps = [ + ":bdd_query_engine", ":critical_path_delay_analysis", + ":lazy_ternary_query_engine", ":optimization_pass", + ":partial_info_query_engine", ":pass_base", + ":query_engine", + ":stateless_query_engine", + ":union_query_engine", "//xls/common/status:ret_check", "//xls/common/status:status_macros", "//xls/estimators/delay_model:delay_estimator", @@ -4602,6 +4608,7 @@ cc_test( "//xls/ir", "//xls/ir:bits", "//xls/ir:function_builder", + "//xls/ir:ir_matcher", "//xls/ir:ir_test_base", "//xls/ir:op", "//xls/solvers:z3_ir_equivalence_testutils", diff --git a/xls/passes/select_lifting_pass.cc b/xls/passes/select_lifting_pass.cc index dc7a2c04da..7844de099c 100644 --- a/xls/passes/select_lifting_pass.cc +++ b/xls/passes/select_lifting_pass.cc @@ -15,6 +15,7 @@ #include "xls/passes/select_lifting_pass.h" #include +#include #include #include #include @@ -43,9 +44,15 @@ #include "xls/ir/source_location.h" #include "xls/ir/type.h" #include "xls/ir/value.h" +#include "xls/passes/bdd_query_engine.h" #include "xls/passes/critical_path_delay_analysis.h" +#include "xls/passes/lazy_ternary_query_engine.h" #include "xls/passes/optimization_pass.h" +#include "xls/passes/partial_info_query_engine.h" #include "xls/passes/pass_base.h" +#include "xls/passes/query_engine.h" +#include "xls/passes/stateless_query_engine.h" +#include "xls/passes/union_query_engine.h" namespace xls { @@ -72,21 +79,7 @@ struct LiftedOpInfo { bool default_is_identity; // True if the default case is an identity. }; -std::optional GetDefaultValue(Node* select) { - if (select->Is()) { - return select->As()->default_value(); - } - CHECK(select->Is()->default_value(); -} -absl::Span GetCases(Node* select) { - if (select->Is()) { - return select->As()->cases(); - } - CHECK(select->Is()->cases(); -} bool MatchesIndexBitwidth(ArrayIndex* ai, int64_t shared_index_bitwidth) { absl::Span current_case_indices = ai->indices(); @@ -365,6 +358,45 @@ std::optional GetLiftableOperationInfoForOp( }; } +// Returns true if the operation preserves zero, i.e., op(0) == 0. +// This is the safety requirement for lifting across a OneHotSelect where the +// selector might be zero. +bool OpPreservesZero(Op op) { + switch (op) { + case Op::kAnd: + case Op::kUMul: + case Op::kSMul: + case Op::kZeroExt: + case Op::kSignExt: + case Op::kTupleIndex: + case Op::kReverse: + case Op::kOrReduce: + return true; + default: + return false; + } +} + +// Returns true if the operation distributes over bitwise OR, i.e., +// op(A v B) == op(A) v op(B). +// This is the safety requirement for lifting across a OneHotSelect where +// more than one bit might be active. +bool OpDistributesOverOr(Op op) { + switch (op) { + case Op::kAnd: + case Op::kOr: + case Op::kZeroExt: + case Op::kSignExt: + case Op::kTupleIndex: + case Op::kReverse: + case Op::kOrReduce: + case Op::kConcat: + return true; + default: + return false; + } +} + // Attempts to find a liftable operation in the select cases. std::optional GetLiftableOperationInfo( absl::Span cases, std::optional default_case, @@ -383,14 +415,27 @@ std::optional GetLiftableOperationInfo( return std::nullopt; } +std::unique_ptr GetQueryEngine( + FunctionBase* f, const OptimizationPassOptions& options, + OptimizationContext& context) { + if (options.opt_level < 3) { + return std::make_unique(UnionQueryEngine::Of( + StatelessQueryEngine(), + context.SharedQueryEngine(f))); + } + return std::make_unique( + UnionQueryEngine::Of(StatelessQueryEngine(), + context.SharedQueryEngine(f), + context.SharedQueryEngine(f))); +} + absl::StatusOr> CanLiftSelect( - FunctionBase* func, Node* select_to_optimize) { + FunctionBase* func, GenericSelect select_to_optimize, + const QueryEngine& query_engine) { VLOG(3) << " Checking the applicability guard"; - // Only "select" nodes with specific properties can be optimized by this - // transformation. - absl::Span cases = GetCases(select_to_optimize); - std::optional default_case = GetDefaultValue(select_to_optimize); + absl::Span cases = select_to_optimize.cases(); + std::optional default_case = select_to_optimize.default_value(); if (cases.empty()) { VLOG(3) << " Select has no cases, not liftable."; @@ -418,15 +463,49 @@ absl::StatusOr> CanLiftSelect( for (Node* potential_shared : potential_shared_nodes) { std::optional info = GetLiftableOperationInfo(cases, default_case, potential_shared); - if (info.has_value()) { - return info; + if (!info.has_value()) { + continue; } + if (select_to_optimize.AsNode()->op() == Op::kOneHotSel) { + if (!OpPreservesZero(info->lifted_op) && + !query_engine.AtLeastOneBitTrue(select_to_optimize.selector())) { + VLOG(3) << " OneHotSelect selector is not provably non-zero, and " + << OpToString(info->lifted_op) + << " is not safe to commute over a zero selector. Skipping."; + continue; // Try other potential shared nodes + } + if (!OpDistributesOverOr(info->lifted_op) && + !query_engine.AtMostOneBitTrue(select_to_optimize.selector())) { + VLOG(3) << " OneHotSelect selector could have more than one bit " + "set, and " + << OpToString(info->lifted_op) + << " is not safe to commute over a selector with multiple " + "bits set. Skipping."; + continue; // Try other potential shared nodes + } + } + return info; } // Check for ArrayIndex lifting opportunity. std::optional shared_array = CheckArrayIndexLiftable(cases, default_case); if (shared_array.has_value()) { + if (select_to_optimize.AsNode()->op() == Op::kOneHotSel) { + if (!query_engine.AtMostOneBitTrue(select_to_optimize.selector())) { + VLOG(3) + << " OneHotSelect selector could have more than one bit set, " + "and ArrayIndex is not safe to commute over a selector with " + "multiple bits set. Skipping."; + return std::nullopt; + } + if (!query_engine.AtLeastOneBitTrue(select_to_optimize.selector())) { + VLOG(3) << " OneHotSelect selector is not provably non-zero, and " + "ArrayIndex is not safe to commute without that guarantee. " + "Skipping."; + return std::nullopt; + } + } // Build LiftedOpInfo for ArrayIndex std::vector index_operands; for (Node* case_node : cases) { @@ -451,23 +530,10 @@ absl::StatusOr> CanLiftSelect( return std::nullopt; } -absl::StatusOr MakeSelectNode(FunctionBase* func, Node* old_select, - const std::vector& new_cases, - std::optional new_default) { - if (old_select->Is()) { - return func->MakeNode( - SourceInfo(), old_select->As()->selector(), new_cases, - *new_default); - } else { - return func->MakeNode()->selector(), - new_cases, new_default); - } -} - absl::StatusOr CheckLatencyIncrease( - FunctionBase* func, Node* select_to_optimize, const LiftedOpInfo& info, - const OptimizationPassOptions& options, OptimizationContext& context) { + FunctionBase* func, GenericSelect select_to_optimize, + const LiftedOpInfo& info, const OptimizationPassOptions& options, + OptimizationContext& context) { const DelayEstimator* delay_estimator = options.delay_estimator; XLS_RET_CHECK(delay_estimator != nullptr) << "No delay estimator configured."; XLS_ASSIGN_OR_RETURN( @@ -476,7 +542,7 @@ absl::StatusOr CheckLatencyIncrease( _ << "Failed to get CriticalPathDelayAnalysis for delay model: " << delay_estimator->name()); // Check the (unscheduled) critical path through the select we're optimizing - int64_t t_before = *analysis->GetInfo(select_to_optimize); + int64_t t_before = *analysis->GetInfo(select_to_optimize.AsNode()); // To make it easy to estimate the critical path after lifting the select, we // add nodes to represent the post-optimization result. @@ -500,10 +566,10 @@ absl::StatusOr CheckLatencyIncrease( }; if (info.lifted_op == Op::kArrayIndex) { - XLS_ASSIGN_OR_RETURN( - tmp_new_select, - MakeSelectNode(func, select_to_optimize, info.other_operands, - info.default_other_operand)); + XLS_ASSIGN_OR_RETURN(tmp_new_select, + select_to_optimize.MakeSelectLikeWithNewArms( + info.other_operands, info.default_other_operand, + select_to_optimize.AsNode()->loc())); XLS_ASSIGN_OR_RETURN( tmp_lifted_op, func->MakeNode(SourceInfo(), info.shared_node, @@ -518,7 +584,7 @@ absl::StatusOr CheckLatencyIncrease( std::vector tmp_new_cases; std::optional tmp_new_default; - absl::Span original_cases = GetCases(select_to_optimize); + absl::Span original_cases = select_to_optimize.cases(); int64_t other_operand_idx = 0; for (int64_t i = 0; i < original_cases.size(); ++i) { if (info.identity_case_indices.contains(i)) { @@ -531,7 +597,7 @@ absl::StatusOr CheckLatencyIncrease( tmp_new_cases.push_back(info.other_operands[other_operand_idx++]); } } - std::optional original_default = GetDefaultValue(select_to_optimize); + std::optional original_default = select_to_optimize.default_value(); if (original_default.has_value()) { if (info.default_is_identity) { XLS_ASSIGN_OR_RETURN( @@ -544,8 +610,9 @@ absl::StatusOr CheckLatencyIncrease( } } XLS_ASSIGN_OR_RETURN(tmp_new_select, - MakeSelectNode(func, select_to_optimize, tmp_new_cases, - tmp_new_default)); + select_to_optimize.MakeSelectLikeWithNewArms( + tmp_new_cases, tmp_new_default, + select_to_optimize.AsNode()->loc())); Node* lhs = info.shared_is_lhs ? info.shared_node : tmp_new_select; Node* rhs = info.shared_is_lhs ? tmp_new_select : info.shared_node; switch (info.lifted_op) { @@ -571,10 +638,11 @@ absl::StatusOr CheckLatencyIncrease( case Op::kUMul: case Op::kSMul: { XLS_ASSIGN_OR_RETURN( - tmp_lifted_op, func->MakeNode( - SourceInfo(), lhs, rhs, - select_to_optimize->GetType()->GetFlatBitCount(), - info.lifted_op)); + tmp_lifted_op, + func->MakeNode( + SourceInfo(), lhs, rhs, + select_to_optimize.AsNode()->GetType()->GetFlatBitCount(), + info.lifted_op)); break; } default: @@ -588,9 +656,9 @@ absl::StatusOr CheckLatencyIncrease( return t_after > t_before; } -absl::StatusOr ProfitabilityGuardForArrayIndex(FunctionBase* func, - Node* select_to_optimize, - Node* array_reference) { +absl::StatusOr ProfitabilityGuardForArrayIndex( + FunctionBase* func, GenericSelect select_to_optimize, + Node* array_reference) { // The next properties when hold guarantee that it is profitable to transform // the "select" node. // @@ -656,7 +724,7 @@ absl::StatusOr ProfitabilityGuardForArrayIndex(FunctionBase* func, Type* array_reference_type = array_reference->GetType(); ArrayType* array_reference_type_as_array_type = array_reference_type->AsArrayOrDie(); - absl::Span select_cases = GetCases(select_to_optimize); + absl::Span select_cases = select_to_optimize.cases(); Type* array_element_type = array_reference_type_as_array_type->element_type(); int64_t array_element_bitwidth = array_element_type->GetFlatBitCount(); for (Node* current_select_case_as_node : select_cases) { @@ -696,8 +764,9 @@ absl::StatusOr ProfitabilityGuardForArrayIndex(FunctionBase* func, } absl::StatusOr ProfitabilityGuardForBinaryOperation( - FunctionBase* func, Node* select_to_optimize, const LiftedOpInfo& info, - const OptimizationPassOptions& options, OptimizationContext& context) { + FunctionBase* func, GenericSelect select_to_optimize, + const LiftedOpInfo& info, const OptimizationPassOptions& options, + OptimizationContext& context) { // If lifting a shift operation combines literal shift amounts into a select, // this creates a variable shift from constant shifts, which is more // expensive. If we don't have a delay model, avoid lifting in this case. @@ -721,9 +790,7 @@ absl::StatusOr ProfitabilityGuardForBinaryOperation( // Heuristically: If the selector depends on `shared_node`, lifting will // likely serialize more operations & worsen the critical path. - Node* selector = select_to_optimize->Is()->selector() - : select_to_optimize->As()->selector(); + Node* selector = select_to_optimize.selector(); if (IsAncestorOf(info.shared_node, selector)) { VLOG(3) << " Selector depends on shared node, avoiding lift due to " "potential latency increase."; @@ -733,8 +800,9 @@ absl::StatusOr ProfitabilityGuardForBinaryOperation( // Calculate Cost Before: // Sum of bitwidths of the original select and any single-use non-identity // case nodes. - int64_t initial_bitwidths = select_to_optimize->GetType()->GetFlatBitCount(); - absl::Span cases = GetCases(select_to_optimize); + int64_t initial_bitwidths = + select_to_optimize.AsNode()->GetType()->GetFlatBitCount(); + absl::Span cases = select_to_optimize.cases(); for (int64_t i = 0; i < cases.size(); ++i) { if (!info.identity_case_indices.contains(i)) { Node* case_node = cases[i]; @@ -743,7 +811,7 @@ absl::StatusOr ProfitabilityGuardForBinaryOperation( } } } - std::optional default_case = GetDefaultValue(select_to_optimize); + std::optional default_case = select_to_optimize.default_value(); if (default_case.has_value() && !info.default_is_identity) { if (HasSingleUse(*default_case)) { initial_bitwidths += (*default_case)->GetType()->GetFlatBitCount(); @@ -766,7 +834,7 @@ absl::StatusOr ProfitabilityGuardForBinaryOperation( // The output width of the lifted op is the same as the original select. int64_t lifted_op_output_width = - select_to_optimize->GetType()->GetFlatBitCount(); + select_to_optimize.AsNode()->GetType()->GetFlatBitCount(); int64_t remaining_bitwidths = new_select_width + lifted_op_output_width; @@ -776,7 +844,7 @@ absl::StatusOr ProfitabilityGuardForBinaryOperation( } absl::StatusOr ShouldLiftSelect(FunctionBase* func, - Node* select_to_optimize, + GenericSelect select_to_optimize, const LiftedOpInfo& info, const OptimizationPassOptions& options, OptimizationContext& context) { @@ -812,7 +880,8 @@ absl::StatusOr ShouldLiftSelect(FunctionBase* func, } absl::StatusOr LiftSelectForArrayIndex( - FunctionBase* func, Node* select_to_optimize, const LiftedOpInfo& info) { + FunctionBase* func, GenericSelect select_to_optimize, + const LiftedOpInfo& info) { TransformationResult result; Node* array_reference = info.shared_node; @@ -823,37 +892,37 @@ absl::StatusOr LiftSelectForArrayIndex( const std::vector& new_cases = info.other_operands; Node* new_select; - XLS_ASSIGN_OR_RETURN( - new_select, - MakeSelectNode(func, select_to_optimize, new_cases, new_default_value)); + XLS_ASSIGN_OR_RETURN(new_select, select_to_optimize.MakeSelectLikeWithNewArms( + new_cases, new_default_value, + select_to_optimize.AsNode()->loc())); // Step 1: add the new array access VLOG(3) << " Step 1: add the new arrayIndex node"; - std::vector new_indices; - new_indices.push_back(new_select); XLS_ASSIGN_OR_RETURN( Node * new_array_index, func->MakeNode(SourceInfo(), array_reference, - absl::Span(new_indices))); + absl::MakeConstSpan({new_select}))); // Step 2: replace the uses of the original "select" node with the only // exception of the new array access VLOG(3) << " Step 2: replace the uses of the original \"select\""; - XLS_RETURN_IF_ERROR(select_to_optimize->ReplaceUsesWith(new_array_index)); - VLOG(3) << " New select : " << select_to_optimize->ToString(); + XLS_RETURN_IF_ERROR( + select_to_optimize.AsNode()->ReplaceUsesWith(new_array_index)); + VLOG(3) << " New select : " + << select_to_optimize.AsNode()->ToString(); VLOG(3) << " New array index: " << new_array_index->ToString(); // Step 3: remove the original "select" node as it just became dead. This is // done by adding such node to the list of nodes to delete at the end of the // main loop of this transformation. VLOG(3) << " Step 3: mark the old \"select\" to be deleted"; - result.nodes_to_delete.insert(select_to_optimize); + result.nodes_to_delete.insert(select_to_optimize.AsNode()); // Step 4: check if new "select" nodes become optimizable. These are users of // the new arrayIndex node VLOG(3) << " Step 4: check if more \"select\" nodes should be considered"; for (Node* user : new_array_index->users()) { - if (user->OpIn({Op::kSel, Op::kPrioritySel})) { + if (user->OpIn({Op::kSel, Op::kPrioritySel, Op::kOneHotSel})) { result.new_selects_to_consider.insert(user); } } @@ -863,9 +932,10 @@ absl::StatusOr LiftSelectForArrayIndex( } absl::StatusOr LiftSelectForBinaryOperation( - FunctionBase* func, Node* select_to_optimize, const LiftedOpInfo& info) { + FunctionBase* func, GenericSelect select_to_optimize, + const LiftedOpInfo& info) { TransformationResult result; - absl::Span original_cases = GetCases(select_to_optimize); + absl::Span original_cases = select_to_optimize.cases(); VLOG(3) << " Step 1: Build new cases for the inner select"; std::vector new_cases; @@ -898,7 +968,7 @@ absl::StatusOr LiftSelectForBinaryOperation( VLOG(3) << " Step 2: Build new default for the inner select"; std::optional new_default; - std::optional original_default = GetDefaultValue(select_to_optimize); + std::optional original_default = select_to_optimize.default_value(); if (original_default.has_value()) { if (info.default_is_identity) { XLS_ASSIGN_OR_RETURN( @@ -913,7 +983,8 @@ absl::StatusOr LiftSelectForBinaryOperation( XLS_ASSIGN_OR_RETURN( Node * new_select, - MakeSelectNode(func, select_to_optimize, new_cases, new_default)); + select_to_optimize.MakeSelectLikeWithNewArms( + new_cases, new_default, select_to_optimize.AsNode()->loc())); VLOG(3) << " Step 3: Create the lifted binary operation"; Node* lhs = info.shared_is_lhs ? info.shared_node : new_select; @@ -942,11 +1013,12 @@ absl::StatusOr LiftSelectForBinaryOperation( } case Op::kUMul: case Op::kSMul: { - XLS_ASSIGN_OR_RETURN(new_binop, - func->MakeNode( - SourceInfo(), lhs, rhs, - select_to_optimize->GetType()->GetFlatBitCount(), - info.lifted_op)); + XLS_ASSIGN_OR_RETURN( + new_binop, + func->MakeNode( + SourceInfo(), lhs, rhs, + select_to_optimize.AsNode()->GetType()->GetFlatBitCount(), + info.lifted_op)); } break; default: return absl::InternalError(absl::StrCat( @@ -955,16 +1027,16 @@ absl::StatusOr LiftSelectForBinaryOperation( } VLOG(3) << " Step 4: Replace uses of the original \"select\""; - XLS_RETURN_IF_ERROR(select_to_optimize->ReplaceUsesWith(new_binop)); + XLS_RETURN_IF_ERROR(select_to_optimize.AsNode()->ReplaceUsesWith(new_binop)); VLOG(3) << " New select: " << new_select->ToString(); VLOG(3) << " New binop : " << new_binop->ToString(); VLOG(3) << " Step 5: mark the old \"select\" to be deleted"; - result.nodes_to_delete.insert(select_to_optimize); + result.nodes_to_delete.insert(select_to_optimize.AsNode()); VLOG(3) << " Step 6: check if more \"select\" nodes should be considered"; for (Node* user : new_binop->users()) { - if (user->OpIn({Op::kSel, Op::kPrioritySel})) { + if (user->OpIn({Op::kSel, Op::kPrioritySel, Op::kOneHotSel})) { result.new_selects_to_consider.insert(user); } } @@ -974,9 +1046,9 @@ absl::StatusOr LiftSelectForBinaryOperation( return result; } -absl::StatusOr LiftSelect(FunctionBase* func, - Node* select_to_optimize, - const LiftedOpInfo& info) { +absl::StatusOr LiftSelect( + FunctionBase* func, GenericSelect select_to_optimize, + const LiftedOpInfo& info) { TransformationResult result; VLOG(3) << " Apply the transformation"; @@ -1014,12 +1086,16 @@ absl::StatusOr LiftSelect(FunctionBase* func, absl::StatusOr LiftSelect( FunctionBase* func, Node* select_to_optimize, - const OptimizationPassOptions& options, OptimizationContext& context) { + const OptimizationPassOptions& options, OptimizationContext& context, + QueryEngine& query_engine) { TransformationResult result; + XLS_ASSIGN_OR_RETURN(GenericSelect sel, + GenericSelect::From(select_to_optimize)); + // Check if it is safe to apply the transformation XLS_ASSIGN_OR_RETURN(std::optional applicability_guard_result, - CanLiftSelect(func, select_to_optimize)); + CanLiftSelect(func, sel, query_engine)); if (!applicability_guard_result) { VLOG(3) << " It is not safe to apply the transformation for this select"; @@ -1031,9 +1107,8 @@ absl::StatusOr LiftSelect( // It is safe to apply the transformation // // Check if it is profitable to apply the transformation - XLS_ASSIGN_OR_RETURN( - bool should_lift, - ShouldLiftSelect(func, select_to_optimize, info, options, context)); + XLS_ASSIGN_OR_RETURN(bool should_lift, + ShouldLiftSelect(func, sel, info, options, context)); if (!should_lift) { VLOG(3) << " This transformation is not profitable for this select"; @@ -1045,7 +1120,7 @@ absl::StatusOr LiftSelect( // It is now the time to apply it. VLOG(3) << " This transformation is applicable and profitable for this " "select"; - XLS_ASSIGN_OR_RETURN(result, LiftSelect(func, select_to_optimize, info)); + XLS_ASSIGN_OR_RETURN(result, LiftSelect(func, sel, info)); return result; } @@ -1053,7 +1128,8 @@ absl::StatusOr LiftSelect( absl::StatusOr LiftSelects( FunctionBase* func, const absl::btree_set& selects_to_consider, - const OptimizationPassOptions& options, OptimizationContext& context) { + const OptimizationPassOptions& options, OptimizationContext& context, + QueryEngine& query_engine) { TransformationResult result; // Try to optimize all "select" nodes @@ -1066,8 +1142,9 @@ absl::StatusOr LiftSelects( VLOG(3) << "Select: " << select_node->ToString(); // Try to optimize the current "select" node - XLS_ASSIGN_OR_RETURN(TransformationResult current_transformation_result, - LiftSelect(func, select_node, options, context)); + XLS_ASSIGN_OR_RETURN( + TransformationResult current_transformation_result, + LiftSelect(func, select_node, options, context, query_engine)); // Accumulate the result of the transformation result.was_code_modified |= current_transformation_result.was_code_modified; @@ -1107,9 +1184,11 @@ absl::StatusOr SelectLiftingPass::RunOnFunctionBaseInternal( // Collect the "select" nodes that might be optimizable VLOG(3) << "Optimizing the function at level " << options.opt_level; + std::unique_ptr query_engine = + GetQueryEngine(func, options, context); for (Node* node : func->nodes()) { // Only consider selects. - if (!node->OpIn({Op::kSel, Op::kPrioritySel})) { + if (!node->OpIn({Op::kSel, Op::kPrioritySel, Op::kOneHotSel})) { continue; } @@ -1128,9 +1207,9 @@ absl::StatusOr SelectLiftingPass::RunOnFunctionBaseInternal( VLOG(3) << " New optimization iteration"; // Optimize all "select" nodes. - XLS_ASSIGN_OR_RETURN( - TransformationResult current_result, - LiftSelects(func, selects_to_consider, options, context)); + XLS_ASSIGN_OR_RETURN(TransformationResult current_result, + LiftSelects(func, selects_to_consider, options, + context, *query_engine)); // Check if we have modified the code. was_code_modified |= current_result.was_code_modified; diff --git a/xls/passes/select_lifting_pass_test.cc b/xls/passes/select_lifting_pass_test.cc index 408ced7674..1c7e539545 100644 --- a/xls/passes/select_lifting_pass_test.cc +++ b/xls/passes/select_lifting_pass_test.cc @@ -35,7 +35,9 @@ #include "xls/ir/bits.h" #include "xls/ir/function.h" #include "xls/ir/function_builder.h" +#include "xls/ir/ir_matcher.h" #include "xls/ir/ir_test_base.h" +#include "xls/ir/lsb_or_msb.h" #include "xls/ir/node.h" #include "xls/ir/nodes.h" #include "xls/ir/op.h" @@ -45,8 +47,9 @@ #include "xls/passes/pass_base.h" #include "xls/solvers/z3_ir_equivalence_testutils.h" -namespace xls { +namespace m = ::xls::op_matchers; +namespace xls { namespace { class FakeDelayEstimator : public DelayEstimator { @@ -947,6 +950,209 @@ TEST_F(SelectLiftingPassTest, DontLiftMulWithIdentityIfLatencyIncreases) { EXPECT_THAT(Run(f, opts), absl_testing::IsOkAndHolds(false)); } +TEST_F(SelectLiftingPassTest, LiftThroughOneHotSelectWithOneHotSelector) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + Type* u32_type = p->GetBitsType(32); + + BValue x = fb.Param("x", u32_type); + BValue y = fb.Param("y", u32_type); + BValue selector = fb.Param("selector", p->GetBitsType(2)); + + BValue oh_selector = fb.OneHot(selector, LsbOrMsb::kLsb); + + BValue zero = fb.Literal(UBits(0, 32)); + BValue x_add_zero = fb.Add(x, zero); + BValue x_add_y = fb.Add(x, y); + + BValue ohs = fb.OneHotSelect(oh_selector, {x_add_zero, x_add_y, x_add_zero}); + + XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.BuildWithReturnValue(ohs)); + + EXPECT_THAT(Run(f), absl_testing::IsOkAndHolds(true)); + + EXPECT_THAT( + f->return_value(), + m::Add(m::Param("x"), + m::OneHotSelect(m::OneHot(m::Param("selector"), LsbOrMsb::kLsb), + {m::Literal(0), m::Param("y"), m::Literal(0)}))); +} + +TEST_F(SelectLiftingPassTest, NoLiftThroughOneHotSelectWithoutOneHotSelector) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + Type* u32_type = p->GetBitsType(32); + + BValue x = fb.Param("x", u32_type); + BValue y = fb.Param("y", u32_type); + BValue selector = fb.Param("selector", p->GetBitsType(3)); + + BValue zero = fb.Literal(UBits(0, 32)); + BValue x_add_zero = fb.Add(x, zero); + BValue x_add_y = fb.Add(x, y); + + BValue ohs = fb.OneHotSelect(selector, {x_add_zero, x_add_y, x_add_zero}); + + XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.BuildWithReturnValue(ohs)); + + EXPECT_THAT(Run(f), absl_testing::IsOkAndHolds(false)); +} + +TEST_F(SelectLiftingPassTest, LiftAndThroughOneHotSelect) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + Type* u32_type = p->GetBitsType(32); + + BValue x = fb.Param("x", u32_type); + BValue y = fb.Param("y", u32_type); + BValue z = fb.Param("z", u32_type); + BValue selector = fb.Param("selector", p->GetBitsType(2)); + + BValue x_and_y = fb.And(x, y); + BValue x_and_z = fb.And(x, z); + + BValue ohs = fb.OneHotSelect(selector, {x_and_y, x_and_z}); + + XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.BuildWithReturnValue(ohs)); + + EXPECT_THAT(Run(f), absl_testing::IsOkAndHolds(true)); + + EXPECT_THAT( + f->return_value(), + m::And(m::Param("x"), m::OneHotSelect(m::Param("selector"), + {m::Param("y"), m::Param("z")}))); +} + +TEST_F(SelectLiftingPassTest, LiftSubThroughOneHotSelectWithOneHotSelector) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + Type* u32_type = p->GetBitsType(32); + + BValue val = fb.Param("val", u32_type); + BValue x = fb.Param("x", u32_type); + BValue y = fb.Param("y", u32_type); + BValue selector = fb.Param("selector", p->GetBitsType(1)); + + BValue oh_selector = fb.OneHot(selector, LsbOrMsb::kLsb); + + BValue sub_x = fb.Subtract(val, x); + BValue sub_y = fb.Subtract(val, y); + + BValue ohs = fb.OneHotSelect(oh_selector, {sub_x, sub_y}); + + XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.BuildWithReturnValue(ohs)); + + EXPECT_THAT(Run(f), absl_testing::IsOkAndHolds(true)); + + EXPECT_THAT( + f->return_value(), + m::Sub(m::Param("val"), + m::OneHotSelect(m::OneHot(m::Param("selector"), LsbOrMsb::kLsb), + {m::Param("x"), m::Param("y")}))); +} + +TEST_F(SelectLiftingPassTest, + LiftMulThroughOneHotSelectWithAtMostOneHotSelector) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + Type* u32_type = p->GetBitsType(32); + + BValue val = fb.Param("val", u32_type); + BValue x = fb.Param("x", u32_type); + BValue a = fb.Param("a", u32_type); + BValue b = fb.Param("b", u32_type); + + BValue eq_1 = fb.Eq(x, fb.Literal(UBits(1, 32))); + BValue eq_2 = fb.Eq(x, fb.Literal(UBits(2, 32))); + BValue selector = fb.Concat({eq_2, eq_1}); + + BValue mul_a = fb.UMul(val, a); + BValue mul_b = fb.UMul(val, b); + + BValue ohs = fb.OneHotSelect(selector, {mul_a, mul_b}); + + XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.BuildWithReturnValue(ohs)); + + EXPECT_THAT(Run(f), absl_testing::IsOkAndHolds(true)); + + EXPECT_THAT( + f->return_value(), + m::UMul(m::Param("val"), + m::OneHotSelect(m::Concat(m::Eq(m::Param("x"), m::Literal(2)), + m::Eq(m::Param("x"), m::Literal(1))), + {m::Param("a"), m::Param("b")}))); +} + +TEST_F(SelectLiftingPassTest, + NoLiftMulThroughOneHotSelectWithNonOneHotSelector) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + Type* u32_type = p->GetBitsType(32); + + BValue val = fb.Param("val", u32_type); + BValue selector = fb.Param("selector", p->GetBitsType(2)); + BValue a = fb.Param("a", u32_type); + BValue b = fb.Param("b", u32_type); + + BValue mul_a = fb.UMul(val, a); + BValue mul_b = fb.UMul(val, b); + + BValue ohs = fb.OneHotSelect(selector, {mul_a, mul_b}); + + XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.BuildWithReturnValue(ohs)); + + EXPECT_THAT(Run(f), absl_testing::IsOkAndHolds(false)); +} + +TEST_F(SelectLiftingPassTest, LiftOrThroughOneHotSelectWithNonzeroSelector) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + Type* u32_type = p->GetBitsType(32); + + BValue a = fb.Param("a", u32_type); + BValue b = fb.Param("b", u32_type); + BValue param_p = fb.Param("p", p->GetBitsType(2)); + + BValue selector = fb.Or(param_p, fb.Literal(UBits(1, 2))); + + BValue constant = fb.Literal(UBits(42, 32)); + BValue or_a = fb.Or(a, constant); + BValue or_b = fb.Or(b, constant); + + BValue ohs = fb.OneHotSelect(selector, {or_a, or_b}); + + XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.BuildWithReturnValue(ohs)); + + // This should fail initially (RED state) + EXPECT_THAT(Run(f), absl_testing::IsOkAndHolds(true)); + + EXPECT_THAT( + f->return_value(), + m::Or(m::Literal(42), m::OneHotSelect(m::Or(m::Param("p"), m::Literal(1)), + {m::Param("a"), m::Param("b")}))); +} + +TEST_F(SelectLiftingPassTest, + NoLiftOrThroughOneHotSelectWithMaybeZeroSelector) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + Type* u32_type = p->GetBitsType(32); + + BValue a = fb.Param("a", u32_type); + BValue b = fb.Param("b", u32_type); + BValue selector = fb.Param("selector", p->GetBitsType(2)); + + BValue constant = fb.Literal(UBits(42, 32)); + BValue or_a = fb.Or(a, constant); + BValue or_b = fb.Or(b, constant); + + BValue ohs = fb.OneHotSelect(selector, {or_a, or_b}); + + XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.BuildWithReturnValue(ohs)); + + EXPECT_THAT(Run(f), absl_testing::IsOkAndHolds(false)); +} + FUZZ_TEST(IrFuzzTest, IrFuzzSelectLifting) .WithDomains(IrFuzzDomainWithArgs(/*arg_set_count=*/10));