From 68af8a22ee89ab7ca848bd807974b2d2c3a27901 Mon Sep 17 00:00:00 2001 From: Eric Astor Date: Fri, 12 Jun 2026 14:45:11 -0700 Subject: [PATCH] [opt] Support lifting operations through OneHotSelect in more circumstances This change extends SelectLiftingPass to lift operations through OneHotSelect, including comparisons and shifts. It introduces a safety analysis based on what's known about the OneHotSelect's selector's state (zero-hot, one-hot, multi-hot) and the behavior of the operations. PiperOrigin-RevId: 931343361 --- xls/passes/BUILD | 7 + xls/passes/select_lifting_pass.cc | 497 +++++++++++++++++++------ xls/passes/select_lifting_pass_test.cc | 383 ++++++++++++++++++- 3 files changed, 764 insertions(+), 123 deletions(-) diff --git a/xls/passes/BUILD b/xls/passes/BUILD index 7632f6e402..cbcb73da0c 100644 --- a/xls/passes/BUILD +++ b/xls/passes/BUILD @@ -1144,9 +1144,15 @@ xls_pass( hdrs = ["select_lifting_pass.h"], pass_class = "SelectLiftingPass", deps = [ + ":bdd_query_engine", ":critical_path_delay_analysis", + ":lazy_ternary_query_engine", ":optimization_pass", + ":partial_info_query_engine", ":pass_base", + ":query_engine", + ":stateless_query_engine", + ":union_query_engine", "//xls/common/status:ret_check", "//xls/common/status:status_macros", "//xls/estimators/delay_model:delay_estimator", @@ -4602,6 +4608,7 @@ cc_test( "//xls/ir", "//xls/ir:bits", "//xls/ir:function_builder", + "//xls/ir:ir_matcher", "//xls/ir:ir_test_base", "//xls/ir:op", "//xls/solvers:z3_ir_equivalence_testutils", diff --git a/xls/passes/select_lifting_pass.cc b/xls/passes/select_lifting_pass.cc index dc7a2c04da..ad5f8c9555 100644 --- a/xls/passes/select_lifting_pass.cc +++ b/xls/passes/select_lifting_pass.cc @@ -15,6 +15,7 @@ #include "xls/passes/select_lifting_pass.h" #include +#include #include #include #include @@ -43,9 +44,15 @@ #include "xls/ir/source_location.h" #include "xls/ir/type.h" #include "xls/ir/value.h" +#include "xls/passes/bdd_query_engine.h" #include "xls/passes/critical_path_delay_analysis.h" +#include "xls/passes/lazy_ternary_query_engine.h" #include "xls/passes/optimization_pass.h" +#include "xls/passes/partial_info_query_engine.h" #include "xls/passes/pass_base.h" +#include "xls/passes/query_engine.h" +#include "xls/passes/stateless_query_engine.h" +#include "xls/passes/union_query_engine.h" namespace xls { @@ -72,21 +79,7 @@ struct LiftedOpInfo { bool default_is_identity; // True if the default case is an identity. }; -std::optional GetDefaultValue(Node* select) { - if (select->Is()) { - return select->As()->default_value(); - } - CHECK(select->Is()->default_value(); -} -absl::Span GetCases(Node* select) { - if (select->Is()) { - return select->As()->cases(); - } - CHECK(select->Is()->cases(); -} bool MatchesIndexBitwidth(ArrayIndex* ai, int64_t shared_index_bitwidth) { absl::Span current_case_indices = ai->indices(); @@ -205,6 +198,25 @@ std::optional CheckArrayIndexLiftable( return ApplicabilityGuardForArrayIndex(cases, default_case); } +// Returns true if the operation has a registered identity literal. +bool HasIdentityLiteral(Op op) { + switch (op) { + case Op::kAdd: + case Op::kSub: + case Op::kOr: + case Op::kXor: + case Op::kShll: + case Op::kShrl: + case Op::kShra: + case Op::kAnd: + case Op::kUMul: + case Op::kSMul: + return true; + default: + return false; + } +} + // Creates a Literal node representing the right-identity for the given // operation. e.g., 0 for Add/Sub/Xor, -1 for And, 1 for Mul. // Note that all operations in kLiftableBinaryOps have a constant @@ -348,6 +360,14 @@ std::optional GetLiftableOperationInfoForOp( return std::nullopt; } + if ((!identity_case_indices.empty() || default_is_identity) && + !HasIdentityLiteral(test_op)) { + VLOG(3) << "Cannot lift: operation " << OpToString(test_op) + << " does not have a registered identity literal, but identity " + "cases are present."; + return std::nullopt; + } + // If the op is commutative, we can always act as if the shared node was on // the LHS. if (op_is_commutative) { @@ -365,13 +385,109 @@ std::optional GetLiftableOperationInfoForOp( }; } +// Returns true if the operation preserves zero, i.e., op(0) == 0. +// For comparison and shift operations, this is conditional on the value +// and position of the shared operand X. +bool OpPreservesZero(const LiftedOpInfo& info, + const QueryEngine& query_engine) { + Op op = info.lifted_op; + switch (op) { + case Op::kAnd: + case Op::kUMul: + case Op::kSMul: + case Op::kZeroExt: + case Op::kSignExt: + case Op::kTupleIndex: + case Op::kReverse: + case Op::kOrReduce: + return true; + + // If the select is the value being shifted, then this is always true; + // shifting 0 by any amount is 0. + // + // It's technically also safe if the value being shifted is identically + // zero, but that should be simplified to zero elsewhere. + case Op::kShll: + case Op::kShrl: + case Op::kShra: { + return !info.shared_is_lhs; + } + + // For comparisons, the safety depends on whether the shared node is 0 (and + // for inequalities, which side it's on). + case Op::kEq: { + if (!info.shared_node->GetType()->IsBits()) { + return false; + } + // If the shared node is non-zero, then comparing shared == 0 will be + // false - so it returns 0. + return query_engine.AtLeastOneBitTrue(info.shared_node); + } + case Op::kNe: { + // If the shared node is zero, then comparing shared != 0 will be false - + // so it returns 0. + return query_engine.IsAllZeros(info.shared_node); + } + case Op::kULe: { + // If the shared node is the LHS and non-zero, then comparing shared <= 0 + // will be false - so it returns 0. + return info.shared_is_lhs && + query_engine.AtLeastOneBitTrue(info.shared_node); + } + case Op::kUGe: { + // If the shared node is the RHS and non-zero, then comparing 0 <= shared + // will be false - so it returns 0. + return !info.shared_is_lhs && + query_engine.AtLeastOneBitTrue(info.shared_node); + } + case Op::kULt: { + // If the shared node is the LHS, then comparing shared < 0 will always be + // false, so it returns 0. + // If it's the RHS, then comparing 0 < shared will + // be only be false if shared is 0. + return info.shared_is_lhs || query_engine.IsAllZeros(info.shared_node); + } + case Op::kUGt: { + // If the shared node is the RHS, then comparing 0 > shared will always be + // false, so it returns 0. + // If it's the LHS, then comparing shared > 0 will + // be only be false if shared is 0. + return !info.shared_is_lhs || query_engine.IsAllZeros(info.shared_node); + } + + default: + return false; + } +} + +// Returns true if the operation distributes over bitwise OR, i.e., +// op(A v B) == op(A) v op(B). +// This is the safety requirement for lifting across a OneHotSelect where +// more than one bit might be active. +bool OpDistributesOverOr(Op op) { + switch (op) { + case Op::kAnd: + case Op::kOr: + case Op::kZeroExt: + case Op::kSignExt: + case Op::kTupleIndex: + case Op::kReverse: + case Op::kOrReduce: + case Op::kConcat: + return true; + default: + return false; + } +} + // Attempts to find a liftable operation in the select cases. std::optional GetLiftableOperationInfo( absl::Span cases, std::optional default_case, Node* potential_shared_node) { constexpr Op kLiftableBinaryOps[] = { - Op::kAdd, Op::kSub, Op::kAnd, Op::kOr, Op::kXor, - Op::kUMul, Op::kSMul, Op::kShll, Op::kShrl, Op::kShra}; + Op::kAdd, Op::kSub, Op::kAnd, Op::kOr, Op::kXor, Op::kUMul, + Op::kSMul, Op::kShll, Op::kShrl, Op::kShra, Op::kEq, Op::kNe, + Op::kULe, Op::kUGe, Op::kULt, Op::kUGt}; for (Op test_op : kLiftableBinaryOps) { std::optional info = GetLiftableOperationInfoForOp( @@ -383,14 +499,54 @@ std::optional GetLiftableOperationInfo( return std::nullopt; } +std::unique_ptr GetQueryEngine( + FunctionBase* f, const OptimizationPassOptions& options, + OptimizationContext& context) { + if (options.opt_level < 3) { + return std::make_unique(UnionQueryEngine::Of( + StatelessQueryEngine(), + context.SharedQueryEngine(f))); + } + return std::make_unique( + UnionQueryEngine::Of(StatelessQueryEngine(), + context.SharedQueryEngine(f), + context.SharedQueryEngine(f))); +} + +// Returns true if the 0-th element of the array is proven to be all-zeros. +// Natively supports nested arrays, tuples, and structs by slicing the leaf +// tree. +bool IsFirstElementZero(Node* array, const QueryEngine& query_engine) { + if (!array->GetType()->IsArray()) { + return false; + } + // Fast-path: If it is structurally an Array node, check its 0-th operand + // directly + if (array->Is()) { + return query_engine.IsAllZeros(array->operand(0)); + } + // Fallback: Slice the ternary tree at index 0 and check all leaves + std::optional> ternary = + query_engine.GetTernary(array); + if (!ternary.has_value()) { + return false; + } + LeafTypeTreeView element_view = ternary->AsView({0}); + for (const TernaryVector& leaf_vector : element_view.elements()) { + if (!ternary_ops::IsKnownZero(leaf_vector)) { + return false; + } + } + return true; +} + absl::StatusOr> CanLiftSelect( - FunctionBase* func, Node* select_to_optimize) { + FunctionBase* func, GenericSelect select_to_optimize, + const QueryEngine& query_engine) { VLOG(3) << " Checking the applicability guard"; - // Only "select" nodes with specific properties can be optimized by this - // transformation. - absl::Span cases = GetCases(select_to_optimize); - std::optional default_case = GetDefaultValue(select_to_optimize); + absl::Span cases = select_to_optimize.cases(); + std::optional default_case = select_to_optimize.default_value(); if (cases.empty()) { VLOG(3) << " Select has no cases, not liftable."; @@ -418,15 +574,60 @@ absl::StatusOr> CanLiftSelect( for (Node* potential_shared : potential_shared_nodes) { std::optional info = GetLiftableOperationInfo(cases, default_case, potential_shared); - if (info.has_value()) { - return info; + if (!info.has_value()) { + continue; + } + if (select_to_optimize.AsNode()->op() == Op::kOneHotSel) { + bool preserves_zero = OpPreservesZero(*info, query_engine); + bool at_least_one_bit = + query_engine.AtLeastOneBitTrue(select_to_optimize.selector()); + bool distributes_or = OpDistributesOverOr(info->lifted_op); + bool at_most_one_bit = + query_engine.AtMostOneBitTrue(select_to_optimize.selector()); + VLOG(3) << "SelectLifting: Checking Op " << OpToString(info->lifted_op); + VLOG(3) << " preserves_zero: " << preserves_zero + << ", at_least_one_bit: " << at_least_one_bit; + VLOG(3) << " distributes_or: " << distributes_or + << ", at_most_one_bit: " << at_most_one_bit; + + if (!preserves_zero && !at_least_one_bit) { + VLOG(3) << " OneHotSelect selector is not provably non-zero, and " + << OpToString(info->lifted_op) + << " is not safe to commute over a zero selector. Skipping."; + continue; // Try other potential shared nodes + } + if (!distributes_or && !at_most_one_bit) { + VLOG(3) << " OneHotSelect selector could have more than one bit " + "set, and " + << OpToString(info->lifted_op) + << " is not safe to commute over a selector with multiple " + "bits set. Skipping."; + continue; // Try other potential shared nodes + } } + return info; } // Check for ArrayIndex lifting opportunity. std::optional shared_array = CheckArrayIndexLiftable(cases, default_case); if (shared_array.has_value()) { + if (select_to_optimize.AsNode()->op() == Op::kOneHotSel) { + if (!query_engine.ExactlyOneBitTrue(select_to_optimize.selector())) { + // Class B ArrayIndex lifting requires AtMostOneBitTrue AND A[0] == 0! + if (query_engine.AtMostOneBitTrue(select_to_optimize.selector()) && + IsFirstElementZero(shared_array.value(), query_engine)) { + VLOG(3) << " OneHotSelect selector is Class B, and first array " + "element " + << shared_array.value()->ToString() + << "[0] is zero. Proceeding."; + } else { + VLOG(3) << " OneHotSelect selector is not Class A, and ArrayIndex " + "is not safe. Skipping."; + return std::nullopt; + } + } + } // Build LiftedOpInfo for ArrayIndex std::vector index_operands; for (Node* case_node : cases) { @@ -451,23 +652,10 @@ absl::StatusOr> CanLiftSelect( return std::nullopt; } -absl::StatusOr MakeSelectNode(FunctionBase* func, Node* old_select, - const std::vector& new_cases, - std::optional new_default) { - if (old_select->Is()) { - return func->MakeNode( - SourceInfo(), old_select->As()->selector(), new_cases, - *new_default); - } else { - return func->MakeNode()->selector(), - new_cases, new_default); - } -} - absl::StatusOr CheckLatencyIncrease( - FunctionBase* func, Node* select_to_optimize, const LiftedOpInfo& info, - const OptimizationPassOptions& options, OptimizationContext& context) { + FunctionBase* func, GenericSelect select_to_optimize, + const LiftedOpInfo& info, const OptimizationPassOptions& options, + OptimizationContext& context) { const DelayEstimator* delay_estimator = options.delay_estimator; XLS_RET_CHECK(delay_estimator != nullptr) << "No delay estimator configured."; XLS_ASSIGN_OR_RETURN( @@ -476,7 +664,7 @@ absl::StatusOr CheckLatencyIncrease( _ << "Failed to get CriticalPathDelayAnalysis for delay model: " << delay_estimator->name()); // Check the (unscheduled) critical path through the select we're optimizing - int64_t t_before = *analysis->GetInfo(select_to_optimize); + int64_t t_before = *analysis->GetInfo(select_to_optimize.AsNode()); // To make it easy to estimate the critical path after lifting the select, we // add nodes to represent the post-optimization result. @@ -500,10 +688,10 @@ absl::StatusOr CheckLatencyIncrease( }; if (info.lifted_op == Op::kArrayIndex) { - XLS_ASSIGN_OR_RETURN( - tmp_new_select, - MakeSelectNode(func, select_to_optimize, info.other_operands, - info.default_other_operand)); + XLS_ASSIGN_OR_RETURN(tmp_new_select, + select_to_optimize.MakeSelectLikeWithNewArms( + info.other_operands, info.default_other_operand, + select_to_optimize.AsNode()->loc())); XLS_ASSIGN_OR_RETURN( tmp_lifted_op, func->MakeNode(SourceInfo(), info.shared_node, @@ -518,7 +706,7 @@ absl::StatusOr CheckLatencyIncrease( std::vector tmp_new_cases; std::optional tmp_new_default; - absl::Span original_cases = GetCases(select_to_optimize); + absl::Span original_cases = select_to_optimize.cases(); int64_t other_operand_idx = 0; for (int64_t i = 0; i < original_cases.size(); ++i) { if (info.identity_case_indices.contains(i)) { @@ -531,7 +719,7 @@ absl::StatusOr CheckLatencyIncrease( tmp_new_cases.push_back(info.other_operands[other_operand_idx++]); } } - std::optional original_default = GetDefaultValue(select_to_optimize); + std::optional original_default = select_to_optimize.default_value(); if (original_default.has_value()) { if (info.default_is_identity) { XLS_ASSIGN_OR_RETURN( @@ -544,8 +732,9 @@ absl::StatusOr CheckLatencyIncrease( } } XLS_ASSIGN_OR_RETURN(tmp_new_select, - MakeSelectNode(func, select_to_optimize, tmp_new_cases, - tmp_new_default)); + select_to_optimize.MakeSelectLikeWithNewArms( + tmp_new_cases, tmp_new_default, + select_to_optimize.AsNode()->loc())); Node* lhs = info.shared_is_lhs ? info.shared_node : tmp_new_select; Node* rhs = info.shared_is_lhs ? tmp_new_select : info.shared_node; switch (info.lifted_op) { @@ -559,6 +748,21 @@ absl::StatusOr CheckLatencyIncrease( func->MakeNode(SourceInfo(), lhs, rhs, info.lifted_op)); break; } + case Op::kEq: + case Op::kNe: + case Op::kULt: + case Op::kULe: + case Op::kUGt: + case Op::kUGe: + case Op::kSLt: + case Op::kSLe: + case Op::kSGt: + case Op::kSGe: { + XLS_ASSIGN_OR_RETURN( + tmp_lifted_op, + func->MakeNode(SourceInfo(), lhs, rhs, info.lifted_op)); + break; + } case Op::kAnd: case Op::kOr: case Op::kXor: { @@ -571,10 +775,11 @@ absl::StatusOr CheckLatencyIncrease( case Op::kUMul: case Op::kSMul: { XLS_ASSIGN_OR_RETURN( - tmp_lifted_op, func->MakeNode( - SourceInfo(), lhs, rhs, - select_to_optimize->GetType()->GetFlatBitCount(), - info.lifted_op)); + tmp_lifted_op, + func->MakeNode( + SourceInfo(), lhs, rhs, + select_to_optimize.AsNode()->GetType()->GetFlatBitCount(), + info.lifted_op)); break; } default: @@ -588,9 +793,9 @@ absl::StatusOr CheckLatencyIncrease( return t_after > t_before; } -absl::StatusOr ProfitabilityGuardForArrayIndex(FunctionBase* func, - Node* select_to_optimize, - Node* array_reference) { +absl::StatusOr ProfitabilityGuardForArrayIndex( + FunctionBase* func, GenericSelect select_to_optimize, + Node* array_reference) { // The next properties when hold guarantee that it is profitable to transform // the "select" node. // @@ -656,7 +861,7 @@ absl::StatusOr ProfitabilityGuardForArrayIndex(FunctionBase* func, Type* array_reference_type = array_reference->GetType(); ArrayType* array_reference_type_as_array_type = array_reference_type->AsArrayOrDie(); - absl::Span select_cases = GetCases(select_to_optimize); + absl::Span select_cases = select_to_optimize.cases(); Type* array_element_type = array_reference_type_as_array_type->element_type(); int64_t array_element_bitwidth = array_element_type->GetFlatBitCount(); for (Node* current_select_case_as_node : select_cases) { @@ -696,8 +901,9 @@ absl::StatusOr ProfitabilityGuardForArrayIndex(FunctionBase* func, } absl::StatusOr ProfitabilityGuardForBinaryOperation( - FunctionBase* func, Node* select_to_optimize, const LiftedOpInfo& info, - const OptimizationPassOptions& options, OptimizationContext& context) { + FunctionBase* func, GenericSelect select_to_optimize, + const LiftedOpInfo& info, const OptimizationPassOptions& options, + OptimizationContext& context) { // If lifting a shift operation combines literal shift amounts into a select, // this creates a variable shift from constant shifts, which is more // expensive. If we don't have a delay model, avoid lifting in this case. @@ -721,54 +927,69 @@ absl::StatusOr ProfitabilityGuardForBinaryOperation( // Heuristically: If the selector depends on `shared_node`, lifting will // likely serialize more operations & worsen the critical path. - Node* selector = select_to_optimize->Is()->selector() - : select_to_optimize->As()->selector(); + Node* selector = select_to_optimize.selector(); if (IsAncestorOf(info.shared_node, selector)) { VLOG(3) << " Selector depends on shared node, avoiding lift due to " "potential latency increase."; return false; } + // Calculate Cost After first, because we need `new_select_width` for the + // cost of comparison ops in "Cost Before". + Type* other_operand_type = nullptr; + if (!info.other_operands.empty()) { + other_operand_type = info.other_operands[0]->GetType(); + } else if (info.default_other_operand.has_value()) { + other_operand_type = (*info.default_other_operand)->GetType(); + } else { + // This should not happen if CanLiftSelect passed. + return false; + } + int64_t new_select_width = other_operand_type->GetFlatBitCount(); + + bool is_comparison = absl::c_contains(CompareOp::kOps, info.lifted_op); + // Calculate Cost Before: // Sum of bitwidths of the original select and any single-use non-identity // case nodes. - int64_t initial_bitwidths = select_to_optimize->GetType()->GetFlatBitCount(); - absl::Span cases = GetCases(select_to_optimize); + int64_t initial_bitwidths = + select_to_optimize.AsNode()->GetType()->GetFlatBitCount(); + absl::Span cases = select_to_optimize.cases(); for (int64_t i = 0; i < cases.size(); ++i) { if (!info.identity_case_indices.contains(i)) { Node* case_node = cases[i]; if (HasSingleUse(case_node)) { - initial_bitwidths += case_node->GetType()->GetFlatBitCount(); + if (is_comparison) { + // For comparisons, the cost of the op is proportional to its input + // bitwidth, not its 1-bit output. + initial_bitwidths += new_select_width; + } else { + initial_bitwidths += case_node->GetType()->GetFlatBitCount(); + } } } } - std::optional default_case = GetDefaultValue(select_to_optimize); + std::optional default_case = select_to_optimize.default_value(); if (default_case.has_value() && !info.default_is_identity) { if (HasSingleUse(*default_case)) { - initial_bitwidths += (*default_case)->GetType()->GetFlatBitCount(); + if (is_comparison) { + initial_bitwidths += new_select_width; + } else { + initial_bitwidths += (*default_case)->GetType()->GetFlatBitCount(); + } } } - // Calculate Cost After: - // Bitwidth of the new select + bitwidth of the lifted binary operation - // output. - Type* other_operand_type = nullptr; - if (!info.other_operands.empty()) { - other_operand_type = info.other_operands[0]->GetType(); - } else if (info.default_other_operand.has_value()) { - other_operand_type = (*info.default_other_operand)->GetType(); + // The output width of the lifted op. For comparisons, we use the input + // bitwidth as a proxy for the cost of the lifted op. + int64_t lifted_op_cost = 0; + if (is_comparison) { + lifted_op_cost = new_select_width; } else { - // This should not happen if CanLiftSelect passed. - return false; + lifted_op_cost = select_to_optimize.AsNode()->GetType()->GetFlatBitCount(); } - int64_t new_select_width = other_operand_type->GetFlatBitCount(); - // The output width of the lifted op is the same as the original select. - int64_t lifted_op_output_width = - select_to_optimize->GetType()->GetFlatBitCount(); - - int64_t remaining_bitwidths = new_select_width + lifted_op_output_width; + int64_t remaining_bitwidths = new_select_width + lifted_op_cost; VLOG(3) << " Profitability: Initial bitwidths: " << initial_bitwidths << ", Remaining bitwidths: " << remaining_bitwidths; @@ -776,7 +997,7 @@ absl::StatusOr ProfitabilityGuardForBinaryOperation( } absl::StatusOr ShouldLiftSelect(FunctionBase* func, - Node* select_to_optimize, + GenericSelect select_to_optimize, const LiftedOpInfo& info, const OptimizationPassOptions& options, OptimizationContext& context) { @@ -812,7 +1033,8 @@ absl::StatusOr ShouldLiftSelect(FunctionBase* func, } absl::StatusOr LiftSelectForArrayIndex( - FunctionBase* func, Node* select_to_optimize, const LiftedOpInfo& info) { + FunctionBase* func, GenericSelect select_to_optimize, + const LiftedOpInfo& info) { TransformationResult result; Node* array_reference = info.shared_node; @@ -823,37 +1045,37 @@ absl::StatusOr LiftSelectForArrayIndex( const std::vector& new_cases = info.other_operands; Node* new_select; - XLS_ASSIGN_OR_RETURN( - new_select, - MakeSelectNode(func, select_to_optimize, new_cases, new_default_value)); + XLS_ASSIGN_OR_RETURN(new_select, select_to_optimize.MakeSelectLikeWithNewArms( + new_cases, new_default_value, + select_to_optimize.AsNode()->loc())); // Step 1: add the new array access VLOG(3) << " Step 1: add the new arrayIndex node"; - std::vector new_indices; - new_indices.push_back(new_select); XLS_ASSIGN_OR_RETURN( Node * new_array_index, func->MakeNode(SourceInfo(), array_reference, - absl::Span(new_indices))); + absl::MakeConstSpan({new_select}))); // Step 2: replace the uses of the original "select" node with the only // exception of the new array access VLOG(3) << " Step 2: replace the uses of the original \"select\""; - XLS_RETURN_IF_ERROR(select_to_optimize->ReplaceUsesWith(new_array_index)); - VLOG(3) << " New select : " << select_to_optimize->ToString(); + XLS_RETURN_IF_ERROR( + select_to_optimize.AsNode()->ReplaceUsesWith(new_array_index)); + VLOG(3) << " New select : " + << select_to_optimize.AsNode()->ToString(); VLOG(3) << " New array index: " << new_array_index->ToString(); // Step 3: remove the original "select" node as it just became dead. This is // done by adding such node to the list of nodes to delete at the end of the // main loop of this transformation. VLOG(3) << " Step 3: mark the old \"select\" to be deleted"; - result.nodes_to_delete.insert(select_to_optimize); + result.nodes_to_delete.insert(select_to_optimize.AsNode()); // Step 4: check if new "select" nodes become optimizable. These are users of // the new arrayIndex node VLOG(3) << " Step 4: check if more \"select\" nodes should be considered"; for (Node* user : new_array_index->users()) { - if (user->OpIn({Op::kSel, Op::kPrioritySel})) { + if (user->OpIn({Op::kSel, Op::kPrioritySel, Op::kOneHotSel})) { result.new_selects_to_consider.insert(user); } } @@ -863,9 +1085,10 @@ absl::StatusOr LiftSelectForArrayIndex( } absl::StatusOr LiftSelectForBinaryOperation( - FunctionBase* func, Node* select_to_optimize, const LiftedOpInfo& info) { + FunctionBase* func, GenericSelect select_to_optimize, + const LiftedOpInfo& info) { TransformationResult result; - absl::Span original_cases = GetCases(select_to_optimize); + absl::Span original_cases = select_to_optimize.cases(); VLOG(3) << " Step 1: Build new cases for the inner select"; std::vector new_cases; @@ -898,7 +1121,7 @@ absl::StatusOr LiftSelectForBinaryOperation( VLOG(3) << " Step 2: Build new default for the inner select"; std::optional new_default; - std::optional original_default = GetDefaultValue(select_to_optimize); + std::optional original_default = select_to_optimize.default_value(); if (original_default.has_value()) { if (info.default_is_identity) { XLS_ASSIGN_OR_RETURN( @@ -913,7 +1136,8 @@ absl::StatusOr LiftSelectForBinaryOperation( XLS_ASSIGN_OR_RETURN( Node * new_select, - MakeSelectNode(func, select_to_optimize, new_cases, new_default)); + select_to_optimize.MakeSelectLikeWithNewArms( + new_cases, new_default, select_to_optimize.AsNode()->loc())); VLOG(3) << " Step 3: Create the lifted binary operation"; Node* lhs = info.shared_is_lhs ? info.shared_node : new_select; @@ -931,6 +1155,21 @@ absl::StatusOr LiftSelectForBinaryOperation( func->MakeNode(SourceInfo(), lhs, rhs, info.lifted_op)); break; } + case Op::kEq: + case Op::kNe: + case Op::kULe: + case Op::kUGe: + case Op::kULt: + case Op::kUGt: + case Op::kSLe: + case Op::kSGe: + case Op::kSLt: + case Op::kSGt: { + XLS_ASSIGN_OR_RETURN( + new_binop, + func->MakeNode(SourceInfo(), lhs, rhs, info.lifted_op)); + break; + } case Op::kAnd: case Op::kOr: case Op::kXor: { @@ -942,11 +1181,12 @@ absl::StatusOr LiftSelectForBinaryOperation( } case Op::kUMul: case Op::kSMul: { - XLS_ASSIGN_OR_RETURN(new_binop, - func->MakeNode( - SourceInfo(), lhs, rhs, - select_to_optimize->GetType()->GetFlatBitCount(), - info.lifted_op)); + XLS_ASSIGN_OR_RETURN( + new_binop, + func->MakeNode( + SourceInfo(), lhs, rhs, + select_to_optimize.AsNode()->GetType()->GetFlatBitCount(), + info.lifted_op)); } break; default: return absl::InternalError(absl::StrCat( @@ -955,16 +1195,16 @@ absl::StatusOr LiftSelectForBinaryOperation( } VLOG(3) << " Step 4: Replace uses of the original \"select\""; - XLS_RETURN_IF_ERROR(select_to_optimize->ReplaceUsesWith(new_binop)); + XLS_RETURN_IF_ERROR(select_to_optimize.AsNode()->ReplaceUsesWith(new_binop)); VLOG(3) << " New select: " << new_select->ToString(); VLOG(3) << " New binop : " << new_binop->ToString(); VLOG(3) << " Step 5: mark the old \"select\" to be deleted"; - result.nodes_to_delete.insert(select_to_optimize); + result.nodes_to_delete.insert(select_to_optimize.AsNode()); VLOG(3) << " Step 6: check if more \"select\" nodes should be considered"; for (Node* user : new_binop->users()) { - if (user->OpIn({Op::kSel, Op::kPrioritySel})) { + if (user->OpIn({Op::kSel, Op::kPrioritySel, Op::kOneHotSel})) { result.new_selects_to_consider.insert(user); } } @@ -974,9 +1214,9 @@ absl::StatusOr LiftSelectForBinaryOperation( return result; } -absl::StatusOr LiftSelect(FunctionBase* func, - Node* select_to_optimize, - const LiftedOpInfo& info) { +absl::StatusOr LiftSelect( + FunctionBase* func, GenericSelect select_to_optimize, + const LiftedOpInfo& info) { TransformationResult result; VLOG(3) << " Apply the transformation"; @@ -1000,6 +1240,12 @@ absl::StatusOr LiftSelect(FunctionBase* func, case Op::kShll: case Op::kShrl: case Op::kShra: + case Op::kEq: + case Op::kNe: + case Op::kULe: + case Op::kUGe: + case Op::kULt: + case Op::kUGt: return LiftSelectForBinaryOperation(func, select_to_optimize, info); default: @@ -1014,12 +1260,16 @@ absl::StatusOr LiftSelect(FunctionBase* func, absl::StatusOr LiftSelect( FunctionBase* func, Node* select_to_optimize, - const OptimizationPassOptions& options, OptimizationContext& context) { + const OptimizationPassOptions& options, OptimizationContext& context, + QueryEngine& query_engine) { TransformationResult result; + XLS_ASSIGN_OR_RETURN(GenericSelect sel, + GenericSelect::From(select_to_optimize)); + // Check if it is safe to apply the transformation XLS_ASSIGN_OR_RETURN(std::optional applicability_guard_result, - CanLiftSelect(func, select_to_optimize)); + CanLiftSelect(func, sel, query_engine)); if (!applicability_guard_result) { VLOG(3) << " It is not safe to apply the transformation for this select"; @@ -1031,9 +1281,8 @@ absl::StatusOr LiftSelect( // It is safe to apply the transformation // // Check if it is profitable to apply the transformation - XLS_ASSIGN_OR_RETURN( - bool should_lift, - ShouldLiftSelect(func, select_to_optimize, info, options, context)); + XLS_ASSIGN_OR_RETURN(bool should_lift, + ShouldLiftSelect(func, sel, info, options, context)); if (!should_lift) { VLOG(3) << " This transformation is not profitable for this select"; @@ -1045,7 +1294,7 @@ absl::StatusOr LiftSelect( // It is now the time to apply it. VLOG(3) << " This transformation is applicable and profitable for this " "select"; - XLS_ASSIGN_OR_RETURN(result, LiftSelect(func, select_to_optimize, info)); + XLS_ASSIGN_OR_RETURN(result, LiftSelect(func, sel, info)); return result; } @@ -1053,7 +1302,8 @@ absl::StatusOr LiftSelect( absl::StatusOr LiftSelects( FunctionBase* func, const absl::btree_set& selects_to_consider, - const OptimizationPassOptions& options, OptimizationContext& context) { + const OptimizationPassOptions& options, OptimizationContext& context, + QueryEngine& query_engine) { TransformationResult result; // Try to optimize all "select" nodes @@ -1066,8 +1316,9 @@ absl::StatusOr LiftSelects( VLOG(3) << "Select: " << select_node->ToString(); // Try to optimize the current "select" node - XLS_ASSIGN_OR_RETURN(TransformationResult current_transformation_result, - LiftSelect(func, select_node, options, context)); + XLS_ASSIGN_OR_RETURN( + TransformationResult current_transformation_result, + LiftSelect(func, select_node, options, context, query_engine)); // Accumulate the result of the transformation result.was_code_modified |= current_transformation_result.was_code_modified; @@ -1107,9 +1358,11 @@ absl::StatusOr SelectLiftingPass::RunOnFunctionBaseInternal( // Collect the "select" nodes that might be optimizable VLOG(3) << "Optimizing the function at level " << options.opt_level; + std::unique_ptr query_engine = + GetQueryEngine(func, options, context); for (Node* node : func->nodes()) { // Only consider selects. - if (!node->OpIn({Op::kSel, Op::kPrioritySel})) { + if (!node->OpIn({Op::kSel, Op::kPrioritySel, Op::kOneHotSel})) { continue; } @@ -1128,9 +1381,9 @@ absl::StatusOr SelectLiftingPass::RunOnFunctionBaseInternal( VLOG(3) << " New optimization iteration"; // Optimize all "select" nodes. - XLS_ASSIGN_OR_RETURN( - TransformationResult current_result, - LiftSelects(func, selects_to_consider, options, context)); + XLS_ASSIGN_OR_RETURN(TransformationResult current_result, + LiftSelects(func, selects_to_consider, options, + context, *query_engine)); // Check if we have modified the code. was_code_modified |= current_result.was_code_modified; diff --git a/xls/passes/select_lifting_pass_test.cc b/xls/passes/select_lifting_pass_test.cc index 408ced7674..3caaec01a9 100644 --- a/xls/passes/select_lifting_pass_test.cc +++ b/xls/passes/select_lifting_pass_test.cc @@ -35,7 +35,9 @@ #include "xls/ir/bits.h" #include "xls/ir/function.h" #include "xls/ir/function_builder.h" +#include "xls/ir/ir_matcher.h" #include "xls/ir/ir_test_base.h" +#include "xls/ir/lsb_or_msb.h" #include "xls/ir/node.h" #include "xls/ir/nodes.h" #include "xls/ir/op.h" @@ -45,8 +47,9 @@ #include "xls/passes/pass_base.h" #include "xls/solvers/z3_ir_equivalence_testutils.h" -namespace xls { +namespace m = ::xls::op_matchers; +namespace xls { namespace { class FakeDelayEstimator : public DelayEstimator { @@ -947,6 +950,384 @@ TEST_F(SelectLiftingPassTest, DontLiftMulWithIdentityIfLatencyIncreases) { EXPECT_THAT(Run(f, opts), absl_testing::IsOkAndHolds(false)); } +TEST_F(SelectLiftingPassTest, LiftCompareIfLatencyPermits) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + BValue s = fb.Param("s", p->GetBitsType(1)); + BValue x = fb.Param("x", p->GetBitsType(32)); + BValue y = fb.Param("y", p->GetBitsType(32)); + + BValue x_ne_y = fb.Ne(x, y); + BValue zero = fb.Literal(UBits(0, 32)); + BValue x_ne_zero = fb.Ne(x, zero); + + // sel(s, [x != y, x != 0]) -> should lift to Ne(x, sel(s, [y, 0])) + BValue sel = fb.Select(s, {x_ne_y, x_ne_zero}); + XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.BuildWithReturnValue(sel)); + + // Force latency check by providing a delay estimator + OptimizationPassOptions opts; + XLS_ASSERT_OK_AND_ASSIGN(opts.delay_estimator, GetDelayEstimator("unit")); + + // This should fail/crash initially with the production error! + EXPECT_THAT(Run(f, opts), absl_testing::IsOkAndHolds(true)); +} + +TEST_F(SelectLiftingPassTest, LiftThroughOneHotSelectWithOneHotSelector) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + Type* u32_type = p->GetBitsType(32); + + BValue x = fb.Param("x", u32_type); + BValue y = fb.Param("y", u32_type); + BValue selector = fb.Param("selector", p->GetBitsType(2)); + + BValue oh_selector = fb.OneHot(selector, LsbOrMsb::kLsb); + + BValue zero = fb.Literal(UBits(0, 32)); + BValue x_add_zero = fb.Add(x, zero); + BValue x_add_y = fb.Add(x, y); + + BValue ohs = fb.OneHotSelect(oh_selector, {x_add_zero, x_add_y, x_add_zero}); + + XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.BuildWithReturnValue(ohs)); + + EXPECT_THAT(Run(f), absl_testing::IsOkAndHolds(true)); + + EXPECT_THAT( + f->return_value(), + m::Add(m::Param("x"), + m::OneHotSelect(m::OneHot(m::Param("selector"), LsbOrMsb::kLsb), + {m::Literal(0), m::Param("y"), m::Literal(0)}))); +} + +TEST_F(SelectLiftingPassTest, NoLiftThroughOneHotSelectWithoutOneHotSelector) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + Type* u32_type = p->GetBitsType(32); + + BValue x = fb.Param("x", u32_type); + BValue y = fb.Param("y", u32_type); + BValue selector = fb.Param("selector", p->GetBitsType(3)); + + BValue zero = fb.Literal(UBits(0, 32)); + BValue x_add_zero = fb.Add(x, zero); + BValue x_add_y = fb.Add(x, y); + + BValue ohs = fb.OneHotSelect(selector, {x_add_zero, x_add_y, x_add_zero}); + + XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.BuildWithReturnValue(ohs)); + + EXPECT_THAT(Run(f), absl_testing::IsOkAndHolds(false)); +} + +TEST_F(SelectLiftingPassTest, LiftAndThroughOneHotSelect) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + Type* u32_type = p->GetBitsType(32); + + BValue x = fb.Param("x", u32_type); + BValue y = fb.Param("y", u32_type); + BValue z = fb.Param("z", u32_type); + BValue selector = fb.Param("selector", p->GetBitsType(2)); + + BValue x_and_y = fb.And(x, y); + BValue x_and_z = fb.And(x, z); + + BValue ohs = fb.OneHotSelect(selector, {x_and_y, x_and_z}); + + XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.BuildWithReturnValue(ohs)); + + EXPECT_THAT(Run(f), absl_testing::IsOkAndHolds(true)); + + EXPECT_THAT( + f->return_value(), + m::And(m::Param("x"), m::OneHotSelect(m::Param("selector"), + {m::Param("y"), m::Param("z")}))); +} + +TEST_F(SelectLiftingPassTest, LiftSubThroughOneHotSelectWithOneHotSelector) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + Type* u32_type = p->GetBitsType(32); + + BValue val = fb.Param("val", u32_type); + BValue x = fb.Param("x", u32_type); + BValue y = fb.Param("y", u32_type); + BValue selector = fb.Param("selector", p->GetBitsType(1)); + + BValue oh_selector = fb.OneHot(selector, LsbOrMsb::kLsb); + + BValue sub_x = fb.Subtract(val, x); + BValue sub_y = fb.Subtract(val, y); + + BValue ohs = fb.OneHotSelect(oh_selector, {sub_x, sub_y}); + + XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.BuildWithReturnValue(ohs)); + + EXPECT_THAT(Run(f), absl_testing::IsOkAndHolds(true)); + + EXPECT_THAT( + f->return_value(), + m::Sub(m::Param("val"), + m::OneHotSelect(m::OneHot(m::Param("selector"), LsbOrMsb::kLsb), + {m::Param("x"), m::Param("y")}))); +} + +TEST_F(SelectLiftingPassTest, + LiftMulThroughOneHotSelectWithAtMostOneHotSelector) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + Type* u32_type = p->GetBitsType(32); + + BValue val = fb.Param("val", u32_type); + BValue x = fb.Param("x", u32_type); + BValue a = fb.Param("a", u32_type); + BValue b = fb.Param("b", u32_type); + + BValue eq_1 = fb.Eq(x, fb.Literal(UBits(1, 32))); + BValue eq_2 = fb.Eq(x, fb.Literal(UBits(2, 32))); + BValue selector = fb.Concat({eq_2, eq_1}); + + BValue mul_a = fb.UMul(val, a); + BValue mul_b = fb.UMul(val, b); + + BValue ohs = fb.OneHotSelect(selector, {mul_a, mul_b}); + + XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.BuildWithReturnValue(ohs)); + + EXPECT_THAT(Run(f), absl_testing::IsOkAndHolds(true)); + + EXPECT_THAT( + f->return_value(), + m::UMul(m::Param("val"), + m::OneHotSelect(m::Concat(m::Eq(m::Param("x"), m::Literal(2)), + m::Eq(m::Param("x"), m::Literal(1))), + {m::Param("a"), m::Param("b")}))); +} + +TEST_F(SelectLiftingPassTest, + NoLiftMulThroughOneHotSelectWithNonOneHotSelector) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + Type* u32_type = p->GetBitsType(32); + + BValue val = fb.Param("val", u32_type); + BValue selector = fb.Param("selector", p->GetBitsType(2)); + BValue a = fb.Param("a", u32_type); + BValue b = fb.Param("b", u32_type); + + BValue mul_a = fb.UMul(val, a); + BValue mul_b = fb.UMul(val, b); + + BValue ohs = fb.OneHotSelect(selector, {mul_a, mul_b}); + + XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.BuildWithReturnValue(ohs)); + + EXPECT_THAT(Run(f), absl_testing::IsOkAndHolds(false)); +} + +TEST_F(SelectLiftingPassTest, LiftOrThroughOneHotSelectWithNonzeroSelector) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + Type* u32_type = p->GetBitsType(32); + + BValue a = fb.Param("a", u32_type); + BValue b = fb.Param("b", u32_type); + BValue param_p = fb.Param("p", p->GetBitsType(2)); + + BValue selector = fb.Or(param_p, fb.Literal(UBits(1, 2))); + + BValue constant = fb.Literal(UBits(42, 32)); + BValue or_a = fb.Or(a, constant); + BValue or_b = fb.Or(b, constant); + + BValue ohs = fb.OneHotSelect(selector, {or_a, or_b}); + + XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.BuildWithReturnValue(ohs)); + + // This should fail initially (RED state) + EXPECT_THAT(Run(f), absl_testing::IsOkAndHolds(true)); + + EXPECT_THAT( + f->return_value(), + m::Or(m::Literal(42), m::OneHotSelect(m::Or(m::Param("p"), m::Literal(1)), + {m::Param("a"), m::Param("b")}))); +} + +TEST_F(SelectLiftingPassTest, + NoLiftOrThroughOneHotSelectWithMaybeZeroSelector) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + Type* u32_type = p->GetBitsType(32); + + BValue a = fb.Param("a", u32_type); + BValue b = fb.Param("b", u32_type); + BValue selector = fb.Param("selector", p->GetBitsType(2)); + + BValue constant = fb.Literal(UBits(42, 32)); + BValue or_a = fb.Or(a, constant); + BValue or_b = fb.Or(b, constant); + + BValue ohs = fb.OneHotSelect(selector, {or_a, or_b}); + + XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.BuildWithReturnValue(ohs)); + + EXPECT_THAT(Run(f), absl_testing::IsOkAndHolds(false)); +} + +TEST_F(SelectLiftingPassTest, + LiftArrayIndexThroughOneHotSelectWithAtMostOneHotSelector) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + Type* u32_type = p->GetBitsType(32); + + BValue zero = fb.Literal(UBits(0, 32)); + BValue val1 = fb.Param("val1", u32_type); + BValue val2 = fb.Param("val2", u32_type); + BValue array = fb.Array({zero, val1, val2}, u32_type); + + BValue x = fb.Param("x", u32_type); + BValue eq_1 = fb.Eq(x, fb.Literal(UBits(1, 32))); + BValue eq_2 = fb.Eq(x, fb.Literal(UBits(2, 32))); + BValue selector = fb.Concat({eq_2, eq_1}); + + BValue idx0 = fb.Param("idx0", u32_type); + BValue idx1 = fb.Param("idx1", u32_type); + BValue access0 = fb.ArrayIndex(array, {idx0}); + BValue access1 = fb.ArrayIndex(array, {idx1}); + BValue ohs = fb.OneHotSelect(selector, {access0, access1}); + + XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.BuildWithReturnValue(ohs)); + + // This should fail initially (RED state) + EXPECT_THAT(Run(f), absl_testing::IsOkAndHolds(true)); + + EXPECT_THAT( + f->return_value(), + m::ArrayIndex( + m::Array(m::Literal(0), m::Param("val1"), m::Param("val2")), + {m::OneHotSelect(m::Concat(m::Eq(m::Param("x"), m::Literal(2)), + m::Eq(m::Param("x"), m::Literal(1))), + {m::Param("idx0"), m::Param("idx1")})})); +} + +TEST_F(SelectLiftingPassTest, + NoLiftArrayIndexThroughOneHotSelectWithAtMostOneHotSelector) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + Type* u32_type = p->GetBitsType(32); + + BValue non_zero = fb.Literal(UBits(42, 32)); + BValue val1 = fb.Param("val1", u32_type); + BValue val2 = fb.Param("val2", u32_type); + BValue array = fb.Array({non_zero, val1, val2}, u32_type); + + BValue x = fb.Param("x", u32_type); + BValue eq_1 = fb.Eq(x, fb.Literal(UBits(1, 32))); + BValue eq_2 = fb.Eq(x, fb.Literal(UBits(2, 32))); + BValue selector = fb.Concat({eq_2, eq_1}); + + BValue idx0 = fb.Param("idx0", u32_type); + BValue idx1 = fb.Param("idx1", u32_type); + BValue access0 = fb.ArrayIndex(array, {idx0}); + BValue access1 = fb.ArrayIndex(array, {idx1}); + BValue ohs = fb.OneHotSelect(selector, {access0, access1}); + + XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.BuildWithReturnValue(ohs)); + + EXPECT_THAT(Run(f), absl_testing::IsOkAndHolds(false)); +} + +TEST_F(SelectLiftingPassTest, + LiftShllValueThroughOneHotSelectWithAtMostOneHotSelector) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + Type* u32_type = p->GetBitsType(32); + + BValue a = fb.Param("a", u32_type); + BValue b = fb.Param("b", u32_type); + BValue shift_amount = fb.Param("shift_amount", u32_type); + + BValue shll_a = fb.Shll(a, shift_amount); + BValue shll_b = fb.Shll(b, shift_amount); + + BValue x = fb.Param("x", u32_type); + BValue eq_1 = fb.Eq(x, fb.Literal(UBits(1, 32))); + BValue eq_2 = fb.Eq(x, fb.Literal(UBits(2, 32))); + BValue selector = fb.Concat({eq_2, eq_1}); + + BValue ohs = fb.OneHotSelect(selector, {shll_a, shll_b}); + + XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.BuildWithReturnValue(ohs)); + + EXPECT_THAT(Run(f), absl_testing::IsOkAndHolds(true)); + + EXPECT_THAT( + f->return_value(), + m::Shll(m::OneHotSelect(m::Concat(m::Eq(m::Param("x"), m::Literal(2)), + m::Eq(m::Param("x"), m::Literal(1))), + {m::Param("a"), m::Param("b")}), + m::Param("shift_amount"))); +} + +TEST_F(SelectLiftingPassTest, + LiftEqThroughOneHotSelectWithAtMostOneHotSelector) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + Type* u32_type = p->GetBitsType(32); + + BValue a = fb.Param("a", u32_type); + BValue b = fb.Param("b", u32_type); + BValue const_42 = fb.Literal(UBits(42, 32)); + + BValue eq_a = fb.Eq(const_42, a); + BValue eq_b = fb.Eq(const_42, b); + + BValue x = fb.Param("x", u32_type); + BValue eq_1 = fb.Eq(x, fb.Literal(UBits(1, 32))); + BValue eq_2 = fb.Eq(x, fb.Literal(UBits(2, 32))); + BValue selector = fb.Concat({eq_2, eq_1}); + + BValue ohs = fb.OneHotSelect(selector, {eq_a, eq_b}); + + XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.BuildWithReturnValue(ohs)); + + EXPECT_THAT(Run(f), absl_testing::IsOkAndHolds(true)); + + EXPECT_THAT( + f->return_value(), + m::Eq(m::Literal(42), + m::OneHotSelect(m::Concat(m::Eq(m::Param("x"), m::Literal(2)), + m::Eq(m::Param("x"), m::Literal(1))), + {m::Param("a"), m::Param("b")}))); +} + +TEST_F(SelectLiftingPassTest, + NoLiftEqThroughOneHotSelectWithAtMostOneHotSelectorAndUnknownShared) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + Type* u32_type = p->GetBitsType(32); + + BValue a = fb.Param("a", u32_type); + BValue b = fb.Param("b", u32_type); + BValue unknown_val = fb.Param("unknown_val", u32_type); + + BValue eq_a = fb.Eq(unknown_val, a); + BValue eq_b = fb.Eq(unknown_val, b); + + BValue x = fb.Param("x", u32_type); + BValue eq_1 = fb.Eq(x, fb.Literal(UBits(1, 32))); + BValue eq_2 = fb.Eq(x, fb.Literal(UBits(2, 32))); + BValue selector = fb.Concat({eq_2, eq_1}); + + BValue ohs = fb.OneHotSelect(selector, {eq_a, eq_b}); + + XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.BuildWithReturnValue(ohs)); + + EXPECT_THAT(Run(f), absl_testing::IsOkAndHolds(false)); +} + FUZZ_TEST(IrFuzzTest, IrFuzzSelectLifting) .WithDomains(IrFuzzDomainWithArgs(/*arg_set_count=*/10));