From edb2d109b3f215c412f8ad666cbd3dd51a8947f7 Mon Sep 17 00:00:00 2001 From: Eric Astor Date: Fri, 12 Jun 2026 16:06:43 -0700 Subject: [PATCH] [opt] Support more OneHotSelect simplifications Specifically allows constant-selector simplification to apply to complex-type OneHotSelects, and relaxes the type constraint for identical case simplification. PiperOrigin-RevId: 931381835 --- xls/data_structures/leaf_type_tree.h | 12 + xls/ir/node_util.h | 3 + xls/passes/BUILD | 10 + xls/passes/array_simplification_pass.cc | 71 +-- xls/passes/array_simplification_pass_test.cc | 77 +++ xls/passes/bit_slice_simplification_pass.cc | 97 ++-- .../bit_slice_simplification_pass_test.cc | 28 + xls/passes/select_lifting_pass.cc | 497 +++++++++++++----- xls/passes/select_lifting_pass_test.cc | 383 +++++++++++++- xls/passes/select_simplification_pass.cc | 117 ++++- xls/passes/select_simplification_pass_test.cc | 65 +++ 11 files changed, 1102 insertions(+), 258 deletions(-) diff --git a/xls/data_structures/leaf_type_tree.h b/xls/data_structures/leaf_type_tree.h index df7e32fbb6..0f6510e898 100644 --- a/xls/data_structures/leaf_type_tree.h +++ b/xls/data_structures/leaf_type_tree.h @@ -940,6 +940,18 @@ LeafTypeTree Map(LeafTypeTreeView ltt, } return LeafTypeTree::CreateFromVector(ltt.type(), std::move(result)); } +template +absl::StatusOr> MapStatus( + LeafTypeTreeView ltt, + std::function(const R& element)> function) { + typename LeafTypeTree::DataContainerT result; + result.reserve(ltt.size()); + for (const R& element : ltt.elements()) { + XLS_ASSIGN_OR_RETURN(T value, function(element)); + result.push_back(std::move(value)); + } + return LeafTypeTree::CreateFromVector(ltt.type(), std::move(result)); +} // Use the given function to update each leaf element in this `LeafTypeTree` // using the corresponding element in the `other`. Return an error if the given diff --git a/xls/ir/node_util.h b/xls/ir/node_util.h index 501a03c7c6..08e931a443 100644 --- a/xls/ir/node_util.h +++ b/xls/ir/node_util.h @@ -636,6 +636,9 @@ class GenericSelect { GenericSelect& operator=(const GenericSelect&) = default; GenericSelect& operator=(GenericSelect&&) = default; static absl::StatusOr From(Node* n); + static bool IsSelect(const Node* n) { + return n->OpIn({Op::kSel, Op::kPrioritySel, Op::kOneHotSel}); + } // Assignment operators from underlying types. GenericSelect& operator=(Select* select) { diff --git a/xls/passes/BUILD b/xls/passes/BUILD index 7632f6e402..421f9dfc7c 100644 --- a/xls/passes/BUILD +++ b/xls/passes/BUILD @@ -1144,9 +1144,15 @@ xls_pass( hdrs = ["select_lifting_pass.h"], pass_class = "SelectLiftingPass", deps = [ + ":bdd_query_engine", ":critical_path_delay_analysis", + ":lazy_ternary_query_engine", ":optimization_pass", + ":partial_info_query_engine", ":pass_base", + ":query_engine", + ":stateless_query_engine", + ":union_query_engine", "//xls/common/status:ret_check", "//xls/common/status:status_macros", "//xls/estimators/delay_model:delay_estimator", @@ -1224,6 +1230,7 @@ cc_library( "//xls/ir:interval_ops", "//xls/ir:node_util", "//xls/ir:op", + "//xls/ir:source_location", "//xls/ir:ternary", "//xls/ir:type", "//xls/ir:value", @@ -1240,6 +1247,7 @@ cc_library( "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", "@com_google_absl//absl/types:variant", + "@cppitertools", ], # Temporary until pipeline changes are further along. alwayslink = 1, @@ -3183,6 +3191,7 @@ cc_test( "//xls/ir:source_location", "//xls/ir:type", "//xls/ir:value", + "//xls/ir:value_utils", "//xls/solvers:z3_ir_equivalence_testutils", "@com_google_absl//absl/status:status_matchers", "@com_google_absl//absl/status:statusor", @@ -4602,6 +4611,7 @@ cc_test( "//xls/ir", "//xls/ir:bits", "//xls/ir:function_builder", + "//xls/ir:ir_matcher", "//xls/ir:ir_test_base", "//xls/ir:op", "//xls/solvers:z3_ir_equivalence_testutils", diff --git a/xls/passes/array_simplification_pass.cc b/xls/passes/array_simplification_pass.cc index 146f4493e4..8dc52aa714 100644 --- a/xls/passes/array_simplification_pass.cc +++ b/xls/passes/array_simplification_pass.cc @@ -1381,14 +1381,13 @@ absl::StatusOr SimplifyConditionalAssign( Node* select, const QueryEngine& query_engine) { absl::Span original_cases; std::optional original_default_value; - if (select->Is()->cases(); - original_default_value = select->As()) { - XLS_ASSIGN_OR_RETURN(selected_value, - select->function_base()->MakeNode()->selector(), - /*cases=*/ - case_values, - /*default_value=*/default_value)); - } else { - XLS_RET_CHECK(select->Is()); - XLS_RET_CHECK(default_value.has_value()); - XLS_ASSIGN_OR_RETURN( - selected_value, - select->function_base()->MakeNode( - select->loc(), select->As()->selector(), - /*cases=*/ - case_values, - /*default_value=*/*default_value)); - } + XLS_ASSIGN_OR_RETURN( + selected_value, + sel.CloneSelectLike(sel.selector(), case_values, default_value)); XLS_ASSIGN_OR_RETURN( ArrayUpdate * overall_update, @@ -1535,16 +1519,9 @@ absl::StatusOr SimplifyConditionalAssign( // further optimization. On the other hand, this can replicate the select // logic, which can be expensive in area, so we limit by the number of cases. absl::StatusOr SimplifySelectOfArrays(Node* select) { - absl::Span original_cases; - std::optional original_default_value; - if (select->Is()->cases(); - original_default_value = select->As()) { - XLS_ASSIGN_OR_RETURN( - selected_element, - select->function_base()->MakeNode()->selector(), - /*cases=*/elements, /*default=*/default_element)); - } else { - XLS_RET_CHECK(select->Is()); - XLS_RET_CHECK(default_element.has_value()); - XLS_ASSIGN_OR_RETURN( - selected_element, - select->function_base()->MakeNode( - select->loc(), select->As()->selector(), - /*cases=*/elements, /*default=*/*default_element)); - } + XLS_ASSIGN_OR_RETURN( + selected_element, + sel.CloneSelectLike(sel.selector(), elements, default_element)); selected_elements.push_back(selected_element); } XLS_ASSIGN_OR_RETURN(Array * new_array, @@ -1655,7 +1620,7 @@ absl::StatusOr SimplifyArraySlice( // Simplify various forms of a select of array-typed values. absl::StatusOr SimplifySelect(Node* select, const QueryEngine& query_engine) { - XLS_RET_CHECK(select->Is() || node->Is()) { + } else if (GenericSelect::IsSelect(node)) { XLS_ASSIGN_OR_RETURN(result, SimplifySelect(node, query_engine)); } else if (node->Is()) { XLS_ASSIGN_OR_RETURN( diff --git a/xls/passes/array_simplification_pass_test.cc b/xls/passes/array_simplification_pass_test.cc index 1601f33313..d98c44a937 100644 --- a/xls/passes/array_simplification_pass_test.cc +++ b/xls/passes/array_simplification_pass_test.cc @@ -1722,6 +1722,83 @@ TEST_F(ArraySimplificationPassTest, IndexOfOneHotSelect) { m::ArrayIndex(m::Param("b"), {m::Param("i"), m::Param("j")})})); } +TEST_F(ArraySimplificationPassTest, + SimplifyConditionalAssignWithOneHotSelectAndOneHotSelector) { + Package* p = GetPackage(); + FunctionBuilder fb(TestName(), p); + Type* u32 = p->GetBitsType(32); + BValue a_param = fb.Param("a", u32); + BValue b_param = fb.Param("b", u32); + BValue c_param = fb.Param("c", u32); + BValue A = fb.Array({a_param, b_param, c_param}, u32); + + BValue x = fb.Param("x", p->GetBitsType(1)); + BValue selector = fb.Concat({x, fb.Not(x)}); + + BValue v = fb.Param("v", u32); + BValue idx = fb.Literal(UBits(1, 32)); + BValue A_updated = fb.ArrayUpdate(A, v, {idx}); + + BValue ohs = fb.OneHotSelect(selector, {A, A_updated}); + + XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.BuildWithReturnValue(ohs)); + ASSERT_THAT(Run(f), IsOkAndHolds(true)); + + EXPECT_THAT( + f->return_value(), + m::Array(m::Param("a"), + m::OneHotSelect(m::Concat(m::Param("x"), m::Not(m::Param("x"))), + {m::Param("b"), m::Param("v")}), + m::Param("c"))); +} + +TEST_F(ArraySimplificationPassTest, + NoSimplifyConditionalAssignWithOneHotSelectAndNonOneHotSelector) { + Package* p = GetPackage(); + FunctionBuilder fb(TestName(), p); + Type* u32 = p->GetBitsType(32); + BValue a_param = fb.Param("a", u32); + BValue b_param = fb.Param("b", u32); + BValue c_param = fb.Param("c", u32); + BValue A = fb.Array({a_param, b_param, c_param}, u32); + + BValue selector = fb.Param("selector", p->GetBitsType(2)); + + BValue v = fb.Param("v", u32); + BValue idx = fb.Literal(UBits(1, 32)); + BValue A_updated = fb.ArrayUpdate(A, v, {idx}); + + BValue ohs = fb.OneHotSelect(selector, {A, A_updated}); + + XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.BuildWithReturnValue(ohs)); + ASSERT_THAT(Run(f), IsOkAndHolds(false)); +} + +TEST_F(ArraySimplificationPassTest, SimplifySelectOfArraysWithOneHotSelect) { + Package* p = GetPackage(); + FunctionBuilder fb(TestName(), p); + Type* u32 = p->GetBitsType(32); + BValue a = fb.Param("a", u32); + BValue b = fb.Param("b", u32); + BValue c = fb.Param("c", u32); + BValue d = fb.Param("d", u32); + + BValue A = fb.Array({a, b}, u32); + BValue B = fb.Array({c, d}, u32); + + BValue selector = fb.Param("selector", p->GetBitsType(2)); + BValue ohs = fb.OneHotSelect(selector, {A, B}); + + XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.BuildWithReturnValue(ohs)); + ASSERT_THAT(Run(f), IsOkAndHolds(true)); + + EXPECT_THAT(f->return_value(), + m::Array(m::OneHotSelect(m::Param("selector"), + {m::Param("a"), m::Param("c")}), + m::OneHotSelect(m::Param("selector"), + {m::Param("b"), m::Param("d")}))); +} + void IrFuzzArraySimplification(FuzzPackageWithArgs fuzz_package_with_args) { ArraySimplificationPass pass; OptimizationPassChangesOutputs(std::move(fuzz_package_with_args), pass); diff --git a/xls/passes/bit_slice_simplification_pass.cc b/xls/passes/bit_slice_simplification_pass.cc index 3fbb672530..d6e329e109 100644 --- a/xls/passes/bit_slice_simplification_pass.cc +++ b/xls/passes/bit_slice_simplification_pass.cc @@ -625,27 +625,34 @@ absl::StatusOr SimplifyLiteralBitSliceUpdate(BitSliceUpdate* update, return true; } -bool IsSelectOfLiterals(Node* node, QueryEngine* query_engine) { +bool IsSelectOfLiterals(GenericSelect sel, QueryEngine* query_engine) { + // Only allow OneHotSelect if the selector is actually one-hot. + if (sel.kind() == GenericSelect::Kind::kOneHotSel && + !query_engine->ExactlyOneBitTrue(sel.selector())) { + return false; + } + auto is_literal_or_select_of_literals = [&](Node* node) { - return query_engine->IsFullyKnown(node) || - IsSelectOfLiterals(node, query_engine); + if (query_engine->IsFullyKnown(node)) { + return true; + } + if (!GenericSelect::IsSelect(node)) { + return false; + } + absl::StatusOr node_sel = GenericSelect::From(node); + CHECK_OK(node_sel); + return IsSelectOfLiterals(*std::move(node_sel), query_engine); }; - if (node->Is(); - return sel->AllCases([&](Node* case_value) { - return is_literal_or_select_of_literals(case_value); - }); + if (!absl::c_all_of(sel.cases(), [&](Node* case_value) { + return is_literal_or_select_of_literals(case_value); + })) { + return false; } - if (node->Is()) { - PrioritySelect* sel = node->As(); - return absl::c_all_of(sel->cases(), - [&](Node* case_value) { - return is_literal_or_select_of_literals(case_value); - }) && - is_literal_or_select_of_literals(sel->default_value()); + if (sel.default_value().has_value()) { + return is_literal_or_select_of_literals(*sel.default_value()); } - return false; + return true; } absl::StatusOr LiftThroughSelectsOfLiterals( @@ -656,54 +663,24 @@ absl::StatusOr LiftThroughSelectsOfLiterals( return lift_to_literal(*known_value); } - Node* selector; - absl::Span cases; - std::optional default_value; - if (node->Is(); - selector = sel->selector(); - cases = sel->cases(); - default_value = sel->default_value(); - } else if (node->Is()) { - PrioritySelect* sel = node->As(); - selector = sel->selector(); - cases = sel->cases(); - default_value = sel->default_value(); - } else { - return absl::InternalError( - absl::StrCat("LiftThroughSelectsOfLiterals invoked on a node that was " - "not a select of literals: ", - node->ToString())); - } + XLS_ASSIGN_OR_RETURN(GenericSelect sel, GenericSelect::From(node)); std::vector new_cases; std::optional new_default_value = std::nullopt; - new_cases.reserve(cases.size()); - for (Node* case_value : cases) { + new_cases.reserve(sel.cases().size()); + for (Node* case_value : sel.cases()) { XLS_ASSIGN_OR_RETURN(Node * new_case_value, LiftThroughSelectsOfLiterals(case_value, query_engine, lift_to_literal)); new_cases.push_back(new_case_value); } - if (default_value.has_value()) { - XLS_ASSIGN_OR_RETURN(new_default_value, - LiftThroughSelectsOfLiterals( - *default_value, query_engine, lift_to_literal)); + if (sel.default_value().has_value()) { + XLS_ASSIGN_OR_RETURN(new_default_value, LiftThroughSelectsOfLiterals( + *sel.default_value(), + query_engine, lift_to_literal)); } - if (node->Is( - node->loc(), selector, new_cases, new_default_value); - } - if (node->Is()) { - XLS_RET_CHECK(new_default_value.has_value()); - return node->function_base()->MakeNode( - node->loc(), selector, new_cases, *new_default_value); - } - XLS_RET_CHECK(node->Is()); - XLS_RET_CHECK(!new_default_value.has_value()); - return node->function_base()->MakeNode(node->loc(), selector, - new_cases); + return sel.CloneSelectLike(sel.selector(), new_cases, new_default_value); } // Hoist bit slice updates above selects of literals, where they can be turned @@ -711,7 +688,11 @@ absl::StatusOr LiftThroughSelectsOfLiterals( absl::StatusOr SimplifySelectOfLiteralsBitSliceUpdate( BitSliceUpdate* update, QueryEngine* query_engine) { Node* start = update->start(); - if (!IsSelectOfLiterals(start, query_engine)) { + if (!GenericSelect::IsSelect(start)) { + return false; + } + XLS_ASSIGN_OR_RETURN(GenericSelect sel, GenericSelect::From(start)); + if (!IsSelectOfLiterals(sel, query_engine)) { return false; } @@ -822,7 +803,11 @@ absl::StatusOr SimplifyLiteralDynamicBitSlice(DynamicBitSlice* bit_slice, absl::StatusOr SimplifySelectOfLiteralsDynamicBitSlice( DynamicBitSlice* bit_slice, QueryEngine* query_engine) { Node* start = bit_slice->start(); - if (!IsSelectOfLiterals(start, query_engine)) { + if (!GenericSelect::IsSelect(start)) { + return false; + } + XLS_ASSIGN_OR_RETURN(GenericSelect sel, GenericSelect::From(start)); + if (!IsSelectOfLiterals(sel, query_engine)) { return false; } diff --git a/xls/passes/bit_slice_simplification_pass_test.cc b/xls/passes/bit_slice_simplification_pass_test.cc index d8e43f0c42..7fff179861 100644 --- a/xls/passes/bit_slice_simplification_pass_test.cc +++ b/xls/passes/bit_slice_simplification_pass_test.cc @@ -1221,6 +1221,34 @@ TEST_F(BitSliceSimplificationPassTest, BitSliceCannotReachAllBits) { })); } +TEST_F( + BitSliceSimplificationPassTest, + SimplifySelectOfLiteralsDynamicBitSliceWithOneHotSelectAndOneHotSelector) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + BValue to_slice = fb.Param("to_slice", p->GetBitsType(30)); + BValue x = fb.Param("x", p->GetBitsType(1)); + BValue selector = fb.Concat({x, fb.Not(x)}); + BValue start_index = fb.OneHotSelect( + selector, {fb.Literal(UBits(5, 32)), fb.Literal(UBits(25, 32))}); + fb.DynamicBitSlice(to_slice, start_index, /*width=*/15); + + XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.Build()); + + solvers::z3::ScopedVerifyEquivalence sve(f); + ASSERT_THAT(Run(f), IsOkAndHolds(true)); + + EXPECT_THAT( + f->return_value(), + m::OneHotSelect( + m::Concat(m::Param("x"), m::Not(m::Param("x"))), + { + m::BitSlice(m::Param("to_slice"), /*start=*/5, /*width=*/15), + m::ZeroExt( + m::BitSlice(m::Param("to_slice"), /*start=*/25, /*width=*/5)), + })); +} + void IrFuzzBitSliceSimplification(FuzzPackageWithArgs fuzz_package_with_args) { BitSliceSimplificationPass pass; OptimizationPassChangesOutputs(std::move(fuzz_package_with_args), pass); diff --git a/xls/passes/select_lifting_pass.cc b/xls/passes/select_lifting_pass.cc index dc7a2c04da..ad5f8c9555 100644 --- a/xls/passes/select_lifting_pass.cc +++ b/xls/passes/select_lifting_pass.cc @@ -15,6 +15,7 @@ #include "xls/passes/select_lifting_pass.h" #include +#include #include #include #include @@ -43,9 +44,15 @@ #include "xls/ir/source_location.h" #include "xls/ir/type.h" #include "xls/ir/value.h" +#include "xls/passes/bdd_query_engine.h" #include "xls/passes/critical_path_delay_analysis.h" +#include "xls/passes/lazy_ternary_query_engine.h" #include "xls/passes/optimization_pass.h" +#include "xls/passes/partial_info_query_engine.h" #include "xls/passes/pass_base.h" +#include "xls/passes/query_engine.h" +#include "xls/passes/stateless_query_engine.h" +#include "xls/passes/union_query_engine.h" namespace xls { @@ -72,21 +79,7 @@ struct LiftedOpInfo { bool default_is_identity; // True if the default case is an identity. }; -std::optional GetDefaultValue(Node* select) { - if (select->Is()) { - return select->As()->default_value(); - } - CHECK(select->Is()->default_value(); -} -absl::Span GetCases(Node* select) { - if (select->Is()) { - return select->As()->cases(); - } - CHECK(select->Is()->cases(); -} bool MatchesIndexBitwidth(ArrayIndex* ai, int64_t shared_index_bitwidth) { absl::Span current_case_indices = ai->indices(); @@ -205,6 +198,25 @@ std::optional CheckArrayIndexLiftable( return ApplicabilityGuardForArrayIndex(cases, default_case); } +// Returns true if the operation has a registered identity literal. +bool HasIdentityLiteral(Op op) { + switch (op) { + case Op::kAdd: + case Op::kSub: + case Op::kOr: + case Op::kXor: + case Op::kShll: + case Op::kShrl: + case Op::kShra: + case Op::kAnd: + case Op::kUMul: + case Op::kSMul: + return true; + default: + return false; + } +} + // Creates a Literal node representing the right-identity for the given // operation. e.g., 0 for Add/Sub/Xor, -1 for And, 1 for Mul. // Note that all operations in kLiftableBinaryOps have a constant @@ -348,6 +360,14 @@ std::optional GetLiftableOperationInfoForOp( return std::nullopt; } + if ((!identity_case_indices.empty() || default_is_identity) && + !HasIdentityLiteral(test_op)) { + VLOG(3) << "Cannot lift: operation " << OpToString(test_op) + << " does not have a registered identity literal, but identity " + "cases are present."; + return std::nullopt; + } + // If the op is commutative, we can always act as if the shared node was on // the LHS. if (op_is_commutative) { @@ -365,13 +385,109 @@ std::optional GetLiftableOperationInfoForOp( }; } +// Returns true if the operation preserves zero, i.e., op(0) == 0. +// For comparison and shift operations, this is conditional on the value +// and position of the shared operand X. +bool OpPreservesZero(const LiftedOpInfo& info, + const QueryEngine& query_engine) { + Op op = info.lifted_op; + switch (op) { + case Op::kAnd: + case Op::kUMul: + case Op::kSMul: + case Op::kZeroExt: + case Op::kSignExt: + case Op::kTupleIndex: + case Op::kReverse: + case Op::kOrReduce: + return true; + + // If the select is the value being shifted, then this is always true; + // shifting 0 by any amount is 0. + // + // It's technically also safe if the value being shifted is identically + // zero, but that should be simplified to zero elsewhere. + case Op::kShll: + case Op::kShrl: + case Op::kShra: { + return !info.shared_is_lhs; + } + + // For comparisons, the safety depends on whether the shared node is 0 (and + // for inequalities, which side it's on). + case Op::kEq: { + if (!info.shared_node->GetType()->IsBits()) { + return false; + } + // If the shared node is non-zero, then comparing shared == 0 will be + // false - so it returns 0. + return query_engine.AtLeastOneBitTrue(info.shared_node); + } + case Op::kNe: { + // If the shared node is zero, then comparing shared != 0 will be false - + // so it returns 0. + return query_engine.IsAllZeros(info.shared_node); + } + case Op::kULe: { + // If the shared node is the LHS and non-zero, then comparing shared <= 0 + // will be false - so it returns 0. + return info.shared_is_lhs && + query_engine.AtLeastOneBitTrue(info.shared_node); + } + case Op::kUGe: { + // If the shared node is the RHS and non-zero, then comparing 0 <= shared + // will be false - so it returns 0. + return !info.shared_is_lhs && + query_engine.AtLeastOneBitTrue(info.shared_node); + } + case Op::kULt: { + // If the shared node is the LHS, then comparing shared < 0 will always be + // false, so it returns 0. + // If it's the RHS, then comparing 0 < shared will + // be only be false if shared is 0. + return info.shared_is_lhs || query_engine.IsAllZeros(info.shared_node); + } + case Op::kUGt: { + // If the shared node is the RHS, then comparing 0 > shared will always be + // false, so it returns 0. + // If it's the LHS, then comparing shared > 0 will + // be only be false if shared is 0. + return !info.shared_is_lhs || query_engine.IsAllZeros(info.shared_node); + } + + default: + return false; + } +} + +// Returns true if the operation distributes over bitwise OR, i.e., +// op(A v B) == op(A) v op(B). +// This is the safety requirement for lifting across a OneHotSelect where +// more than one bit might be active. +bool OpDistributesOverOr(Op op) { + switch (op) { + case Op::kAnd: + case Op::kOr: + case Op::kZeroExt: + case Op::kSignExt: + case Op::kTupleIndex: + case Op::kReverse: + case Op::kOrReduce: + case Op::kConcat: + return true; + default: + return false; + } +} + // Attempts to find a liftable operation in the select cases. std::optional GetLiftableOperationInfo( absl::Span cases, std::optional default_case, Node* potential_shared_node) { constexpr Op kLiftableBinaryOps[] = { - Op::kAdd, Op::kSub, Op::kAnd, Op::kOr, Op::kXor, - Op::kUMul, Op::kSMul, Op::kShll, Op::kShrl, Op::kShra}; + Op::kAdd, Op::kSub, Op::kAnd, Op::kOr, Op::kXor, Op::kUMul, + Op::kSMul, Op::kShll, Op::kShrl, Op::kShra, Op::kEq, Op::kNe, + Op::kULe, Op::kUGe, Op::kULt, Op::kUGt}; for (Op test_op : kLiftableBinaryOps) { std::optional info = GetLiftableOperationInfoForOp( @@ -383,14 +499,54 @@ std::optional GetLiftableOperationInfo( return std::nullopt; } +std::unique_ptr GetQueryEngine( + FunctionBase* f, const OptimizationPassOptions& options, + OptimizationContext& context) { + if (options.opt_level < 3) { + return std::make_unique(UnionQueryEngine::Of( + StatelessQueryEngine(), + context.SharedQueryEngine(f))); + } + return std::make_unique( + UnionQueryEngine::Of(StatelessQueryEngine(), + context.SharedQueryEngine(f), + context.SharedQueryEngine(f))); +} + +// Returns true if the 0-th element of the array is proven to be all-zeros. +// Natively supports nested arrays, tuples, and structs by slicing the leaf +// tree. +bool IsFirstElementZero(Node* array, const QueryEngine& query_engine) { + if (!array->GetType()->IsArray()) { + return false; + } + // Fast-path: If it is structurally an Array node, check its 0-th operand + // directly + if (array->Is()) { + return query_engine.IsAllZeros(array->operand(0)); + } + // Fallback: Slice the ternary tree at index 0 and check all leaves + std::optional> ternary = + query_engine.GetTernary(array); + if (!ternary.has_value()) { + return false; + } + LeafTypeTreeView element_view = ternary->AsView({0}); + for (const TernaryVector& leaf_vector : element_view.elements()) { + if (!ternary_ops::IsKnownZero(leaf_vector)) { + return false; + } + } + return true; +} + absl::StatusOr> CanLiftSelect( - FunctionBase* func, Node* select_to_optimize) { + FunctionBase* func, GenericSelect select_to_optimize, + const QueryEngine& query_engine) { VLOG(3) << " Checking the applicability guard"; - // Only "select" nodes with specific properties can be optimized by this - // transformation. - absl::Span cases = GetCases(select_to_optimize); - std::optional default_case = GetDefaultValue(select_to_optimize); + absl::Span cases = select_to_optimize.cases(); + std::optional default_case = select_to_optimize.default_value(); if (cases.empty()) { VLOG(3) << " Select has no cases, not liftable."; @@ -418,15 +574,60 @@ absl::StatusOr> CanLiftSelect( for (Node* potential_shared : potential_shared_nodes) { std::optional info = GetLiftableOperationInfo(cases, default_case, potential_shared); - if (info.has_value()) { - return info; + if (!info.has_value()) { + continue; + } + if (select_to_optimize.AsNode()->op() == Op::kOneHotSel) { + bool preserves_zero = OpPreservesZero(*info, query_engine); + bool at_least_one_bit = + query_engine.AtLeastOneBitTrue(select_to_optimize.selector()); + bool distributes_or = OpDistributesOverOr(info->lifted_op); + bool at_most_one_bit = + query_engine.AtMostOneBitTrue(select_to_optimize.selector()); + VLOG(3) << "SelectLifting: Checking Op " << OpToString(info->lifted_op); + VLOG(3) << " preserves_zero: " << preserves_zero + << ", at_least_one_bit: " << at_least_one_bit; + VLOG(3) << " distributes_or: " << distributes_or + << ", at_most_one_bit: " << at_most_one_bit; + + if (!preserves_zero && !at_least_one_bit) { + VLOG(3) << " OneHotSelect selector is not provably non-zero, and " + << OpToString(info->lifted_op) + << " is not safe to commute over a zero selector. Skipping."; + continue; // Try other potential shared nodes + } + if (!distributes_or && !at_most_one_bit) { + VLOG(3) << " OneHotSelect selector could have more than one bit " + "set, and " + << OpToString(info->lifted_op) + << " is not safe to commute over a selector with multiple " + "bits set. Skipping."; + continue; // Try other potential shared nodes + } } + return info; } // Check for ArrayIndex lifting opportunity. std::optional shared_array = CheckArrayIndexLiftable(cases, default_case); if (shared_array.has_value()) { + if (select_to_optimize.AsNode()->op() == Op::kOneHotSel) { + if (!query_engine.ExactlyOneBitTrue(select_to_optimize.selector())) { + // Class B ArrayIndex lifting requires AtMostOneBitTrue AND A[0] == 0! + if (query_engine.AtMostOneBitTrue(select_to_optimize.selector()) && + IsFirstElementZero(shared_array.value(), query_engine)) { + VLOG(3) << " OneHotSelect selector is Class B, and first array " + "element " + << shared_array.value()->ToString() + << "[0] is zero. Proceeding."; + } else { + VLOG(3) << " OneHotSelect selector is not Class A, and ArrayIndex " + "is not safe. Skipping."; + return std::nullopt; + } + } + } // Build LiftedOpInfo for ArrayIndex std::vector index_operands; for (Node* case_node : cases) { @@ -451,23 +652,10 @@ absl::StatusOr> CanLiftSelect( return std::nullopt; } -absl::StatusOr MakeSelectNode(FunctionBase* func, Node* old_select, - const std::vector& new_cases, - std::optional new_default) { - if (old_select->Is()) { - return func->MakeNode( - SourceInfo(), old_select->As()->selector(), new_cases, - *new_default); - } else { - return func->MakeNode()->selector(), - new_cases, new_default); - } -} - absl::StatusOr CheckLatencyIncrease( - FunctionBase* func, Node* select_to_optimize, const LiftedOpInfo& info, - const OptimizationPassOptions& options, OptimizationContext& context) { + FunctionBase* func, GenericSelect select_to_optimize, + const LiftedOpInfo& info, const OptimizationPassOptions& options, + OptimizationContext& context) { const DelayEstimator* delay_estimator = options.delay_estimator; XLS_RET_CHECK(delay_estimator != nullptr) << "No delay estimator configured."; XLS_ASSIGN_OR_RETURN( @@ -476,7 +664,7 @@ absl::StatusOr CheckLatencyIncrease( _ << "Failed to get CriticalPathDelayAnalysis for delay model: " << delay_estimator->name()); // Check the (unscheduled) critical path through the select we're optimizing - int64_t t_before = *analysis->GetInfo(select_to_optimize); + int64_t t_before = *analysis->GetInfo(select_to_optimize.AsNode()); // To make it easy to estimate the critical path after lifting the select, we // add nodes to represent the post-optimization result. @@ -500,10 +688,10 @@ absl::StatusOr CheckLatencyIncrease( }; if (info.lifted_op == Op::kArrayIndex) { - XLS_ASSIGN_OR_RETURN( - tmp_new_select, - MakeSelectNode(func, select_to_optimize, info.other_operands, - info.default_other_operand)); + XLS_ASSIGN_OR_RETURN(tmp_new_select, + select_to_optimize.MakeSelectLikeWithNewArms( + info.other_operands, info.default_other_operand, + select_to_optimize.AsNode()->loc())); XLS_ASSIGN_OR_RETURN( tmp_lifted_op, func->MakeNode(SourceInfo(), info.shared_node, @@ -518,7 +706,7 @@ absl::StatusOr CheckLatencyIncrease( std::vector tmp_new_cases; std::optional tmp_new_default; - absl::Span original_cases = GetCases(select_to_optimize); + absl::Span original_cases = select_to_optimize.cases(); int64_t other_operand_idx = 0; for (int64_t i = 0; i < original_cases.size(); ++i) { if (info.identity_case_indices.contains(i)) { @@ -531,7 +719,7 @@ absl::StatusOr CheckLatencyIncrease( tmp_new_cases.push_back(info.other_operands[other_operand_idx++]); } } - std::optional original_default = GetDefaultValue(select_to_optimize); + std::optional original_default = select_to_optimize.default_value(); if (original_default.has_value()) { if (info.default_is_identity) { XLS_ASSIGN_OR_RETURN( @@ -544,8 +732,9 @@ absl::StatusOr CheckLatencyIncrease( } } XLS_ASSIGN_OR_RETURN(tmp_new_select, - MakeSelectNode(func, select_to_optimize, tmp_new_cases, - tmp_new_default)); + select_to_optimize.MakeSelectLikeWithNewArms( + tmp_new_cases, tmp_new_default, + select_to_optimize.AsNode()->loc())); Node* lhs = info.shared_is_lhs ? info.shared_node : tmp_new_select; Node* rhs = info.shared_is_lhs ? tmp_new_select : info.shared_node; switch (info.lifted_op) { @@ -559,6 +748,21 @@ absl::StatusOr CheckLatencyIncrease( func->MakeNode(SourceInfo(), lhs, rhs, info.lifted_op)); break; } + case Op::kEq: + case Op::kNe: + case Op::kULt: + case Op::kULe: + case Op::kUGt: + case Op::kUGe: + case Op::kSLt: + case Op::kSLe: + case Op::kSGt: + case Op::kSGe: { + XLS_ASSIGN_OR_RETURN( + tmp_lifted_op, + func->MakeNode(SourceInfo(), lhs, rhs, info.lifted_op)); + break; + } case Op::kAnd: case Op::kOr: case Op::kXor: { @@ -571,10 +775,11 @@ absl::StatusOr CheckLatencyIncrease( case Op::kUMul: case Op::kSMul: { XLS_ASSIGN_OR_RETURN( - tmp_lifted_op, func->MakeNode( - SourceInfo(), lhs, rhs, - select_to_optimize->GetType()->GetFlatBitCount(), - info.lifted_op)); + tmp_lifted_op, + func->MakeNode( + SourceInfo(), lhs, rhs, + select_to_optimize.AsNode()->GetType()->GetFlatBitCount(), + info.lifted_op)); break; } default: @@ -588,9 +793,9 @@ absl::StatusOr CheckLatencyIncrease( return t_after > t_before; } -absl::StatusOr ProfitabilityGuardForArrayIndex(FunctionBase* func, - Node* select_to_optimize, - Node* array_reference) { +absl::StatusOr ProfitabilityGuardForArrayIndex( + FunctionBase* func, GenericSelect select_to_optimize, + Node* array_reference) { // The next properties when hold guarantee that it is profitable to transform // the "select" node. // @@ -656,7 +861,7 @@ absl::StatusOr ProfitabilityGuardForArrayIndex(FunctionBase* func, Type* array_reference_type = array_reference->GetType(); ArrayType* array_reference_type_as_array_type = array_reference_type->AsArrayOrDie(); - absl::Span select_cases = GetCases(select_to_optimize); + absl::Span select_cases = select_to_optimize.cases(); Type* array_element_type = array_reference_type_as_array_type->element_type(); int64_t array_element_bitwidth = array_element_type->GetFlatBitCount(); for (Node* current_select_case_as_node : select_cases) { @@ -696,8 +901,9 @@ absl::StatusOr ProfitabilityGuardForArrayIndex(FunctionBase* func, } absl::StatusOr ProfitabilityGuardForBinaryOperation( - FunctionBase* func, Node* select_to_optimize, const LiftedOpInfo& info, - const OptimizationPassOptions& options, OptimizationContext& context) { + FunctionBase* func, GenericSelect select_to_optimize, + const LiftedOpInfo& info, const OptimizationPassOptions& options, + OptimizationContext& context) { // If lifting a shift operation combines literal shift amounts into a select, // this creates a variable shift from constant shifts, which is more // expensive. If we don't have a delay model, avoid lifting in this case. @@ -721,54 +927,69 @@ absl::StatusOr ProfitabilityGuardForBinaryOperation( // Heuristically: If the selector depends on `shared_node`, lifting will // likely serialize more operations & worsen the critical path. - Node* selector = select_to_optimize->Is()->selector() - : select_to_optimize->As()->selector(); + Node* selector = select_to_optimize.selector(); if (IsAncestorOf(info.shared_node, selector)) { VLOG(3) << " Selector depends on shared node, avoiding lift due to " "potential latency increase."; return false; } + // Calculate Cost After first, because we need `new_select_width` for the + // cost of comparison ops in "Cost Before". + Type* other_operand_type = nullptr; + if (!info.other_operands.empty()) { + other_operand_type = info.other_operands[0]->GetType(); + } else if (info.default_other_operand.has_value()) { + other_operand_type = (*info.default_other_operand)->GetType(); + } else { + // This should not happen if CanLiftSelect passed. + return false; + } + int64_t new_select_width = other_operand_type->GetFlatBitCount(); + + bool is_comparison = absl::c_contains(CompareOp::kOps, info.lifted_op); + // Calculate Cost Before: // Sum of bitwidths of the original select and any single-use non-identity // case nodes. - int64_t initial_bitwidths = select_to_optimize->GetType()->GetFlatBitCount(); - absl::Span cases = GetCases(select_to_optimize); + int64_t initial_bitwidths = + select_to_optimize.AsNode()->GetType()->GetFlatBitCount(); + absl::Span cases = select_to_optimize.cases(); for (int64_t i = 0; i < cases.size(); ++i) { if (!info.identity_case_indices.contains(i)) { Node* case_node = cases[i]; if (HasSingleUse(case_node)) { - initial_bitwidths += case_node->GetType()->GetFlatBitCount(); + if (is_comparison) { + // For comparisons, the cost of the op is proportional to its input + // bitwidth, not its 1-bit output. + initial_bitwidths += new_select_width; + } else { + initial_bitwidths += case_node->GetType()->GetFlatBitCount(); + } } } } - std::optional default_case = GetDefaultValue(select_to_optimize); + std::optional default_case = select_to_optimize.default_value(); if (default_case.has_value() && !info.default_is_identity) { if (HasSingleUse(*default_case)) { - initial_bitwidths += (*default_case)->GetType()->GetFlatBitCount(); + if (is_comparison) { + initial_bitwidths += new_select_width; + } else { + initial_bitwidths += (*default_case)->GetType()->GetFlatBitCount(); + } } } - // Calculate Cost After: - // Bitwidth of the new select + bitwidth of the lifted binary operation - // output. - Type* other_operand_type = nullptr; - if (!info.other_operands.empty()) { - other_operand_type = info.other_operands[0]->GetType(); - } else if (info.default_other_operand.has_value()) { - other_operand_type = (*info.default_other_operand)->GetType(); + // The output width of the lifted op. For comparisons, we use the input + // bitwidth as a proxy for the cost of the lifted op. + int64_t lifted_op_cost = 0; + if (is_comparison) { + lifted_op_cost = new_select_width; } else { - // This should not happen if CanLiftSelect passed. - return false; + lifted_op_cost = select_to_optimize.AsNode()->GetType()->GetFlatBitCount(); } - int64_t new_select_width = other_operand_type->GetFlatBitCount(); - // The output width of the lifted op is the same as the original select. - int64_t lifted_op_output_width = - select_to_optimize->GetType()->GetFlatBitCount(); - - int64_t remaining_bitwidths = new_select_width + lifted_op_output_width; + int64_t remaining_bitwidths = new_select_width + lifted_op_cost; VLOG(3) << " Profitability: Initial bitwidths: " << initial_bitwidths << ", Remaining bitwidths: " << remaining_bitwidths; @@ -776,7 +997,7 @@ absl::StatusOr ProfitabilityGuardForBinaryOperation( } absl::StatusOr ShouldLiftSelect(FunctionBase* func, - Node* select_to_optimize, + GenericSelect select_to_optimize, const LiftedOpInfo& info, const OptimizationPassOptions& options, OptimizationContext& context) { @@ -812,7 +1033,8 @@ absl::StatusOr ShouldLiftSelect(FunctionBase* func, } absl::StatusOr LiftSelectForArrayIndex( - FunctionBase* func, Node* select_to_optimize, const LiftedOpInfo& info) { + FunctionBase* func, GenericSelect select_to_optimize, + const LiftedOpInfo& info) { TransformationResult result; Node* array_reference = info.shared_node; @@ -823,37 +1045,37 @@ absl::StatusOr LiftSelectForArrayIndex( const std::vector& new_cases = info.other_operands; Node* new_select; - XLS_ASSIGN_OR_RETURN( - new_select, - MakeSelectNode(func, select_to_optimize, new_cases, new_default_value)); + XLS_ASSIGN_OR_RETURN(new_select, select_to_optimize.MakeSelectLikeWithNewArms( + new_cases, new_default_value, + select_to_optimize.AsNode()->loc())); // Step 1: add the new array access VLOG(3) << " Step 1: add the new arrayIndex node"; - std::vector new_indices; - new_indices.push_back(new_select); XLS_ASSIGN_OR_RETURN( Node * new_array_index, func->MakeNode(SourceInfo(), array_reference, - absl::Span(new_indices))); + absl::MakeConstSpan({new_select}))); // Step 2: replace the uses of the original "select" node with the only // exception of the new array access VLOG(3) << " Step 2: replace the uses of the original \"select\""; - XLS_RETURN_IF_ERROR(select_to_optimize->ReplaceUsesWith(new_array_index)); - VLOG(3) << " New select : " << select_to_optimize->ToString(); + XLS_RETURN_IF_ERROR( + select_to_optimize.AsNode()->ReplaceUsesWith(new_array_index)); + VLOG(3) << " New select : " + << select_to_optimize.AsNode()->ToString(); VLOG(3) << " New array index: " << new_array_index->ToString(); // Step 3: remove the original "select" node as it just became dead. This is // done by adding such node to the list of nodes to delete at the end of the // main loop of this transformation. VLOG(3) << " Step 3: mark the old \"select\" to be deleted"; - result.nodes_to_delete.insert(select_to_optimize); + result.nodes_to_delete.insert(select_to_optimize.AsNode()); // Step 4: check if new "select" nodes become optimizable. These are users of // the new arrayIndex node VLOG(3) << " Step 4: check if more \"select\" nodes should be considered"; for (Node* user : new_array_index->users()) { - if (user->OpIn({Op::kSel, Op::kPrioritySel})) { + if (user->OpIn({Op::kSel, Op::kPrioritySel, Op::kOneHotSel})) { result.new_selects_to_consider.insert(user); } } @@ -863,9 +1085,10 @@ absl::StatusOr LiftSelectForArrayIndex( } absl::StatusOr LiftSelectForBinaryOperation( - FunctionBase* func, Node* select_to_optimize, const LiftedOpInfo& info) { + FunctionBase* func, GenericSelect select_to_optimize, + const LiftedOpInfo& info) { TransformationResult result; - absl::Span original_cases = GetCases(select_to_optimize); + absl::Span original_cases = select_to_optimize.cases(); VLOG(3) << " Step 1: Build new cases for the inner select"; std::vector new_cases; @@ -898,7 +1121,7 @@ absl::StatusOr LiftSelectForBinaryOperation( VLOG(3) << " Step 2: Build new default for the inner select"; std::optional new_default; - std::optional original_default = GetDefaultValue(select_to_optimize); + std::optional original_default = select_to_optimize.default_value(); if (original_default.has_value()) { if (info.default_is_identity) { XLS_ASSIGN_OR_RETURN( @@ -913,7 +1136,8 @@ absl::StatusOr LiftSelectForBinaryOperation( XLS_ASSIGN_OR_RETURN( Node * new_select, - MakeSelectNode(func, select_to_optimize, new_cases, new_default)); + select_to_optimize.MakeSelectLikeWithNewArms( + new_cases, new_default, select_to_optimize.AsNode()->loc())); VLOG(3) << " Step 3: Create the lifted binary operation"; Node* lhs = info.shared_is_lhs ? info.shared_node : new_select; @@ -931,6 +1155,21 @@ absl::StatusOr LiftSelectForBinaryOperation( func->MakeNode(SourceInfo(), lhs, rhs, info.lifted_op)); break; } + case Op::kEq: + case Op::kNe: + case Op::kULe: + case Op::kUGe: + case Op::kULt: + case Op::kUGt: + case Op::kSLe: + case Op::kSGe: + case Op::kSLt: + case Op::kSGt: { + XLS_ASSIGN_OR_RETURN( + new_binop, + func->MakeNode(SourceInfo(), lhs, rhs, info.lifted_op)); + break; + } case Op::kAnd: case Op::kOr: case Op::kXor: { @@ -942,11 +1181,12 @@ absl::StatusOr LiftSelectForBinaryOperation( } case Op::kUMul: case Op::kSMul: { - XLS_ASSIGN_OR_RETURN(new_binop, - func->MakeNode( - SourceInfo(), lhs, rhs, - select_to_optimize->GetType()->GetFlatBitCount(), - info.lifted_op)); + XLS_ASSIGN_OR_RETURN( + new_binop, + func->MakeNode( + SourceInfo(), lhs, rhs, + select_to_optimize.AsNode()->GetType()->GetFlatBitCount(), + info.lifted_op)); } break; default: return absl::InternalError(absl::StrCat( @@ -955,16 +1195,16 @@ absl::StatusOr LiftSelectForBinaryOperation( } VLOG(3) << " Step 4: Replace uses of the original \"select\""; - XLS_RETURN_IF_ERROR(select_to_optimize->ReplaceUsesWith(new_binop)); + XLS_RETURN_IF_ERROR(select_to_optimize.AsNode()->ReplaceUsesWith(new_binop)); VLOG(3) << " New select: " << new_select->ToString(); VLOG(3) << " New binop : " << new_binop->ToString(); VLOG(3) << " Step 5: mark the old \"select\" to be deleted"; - result.nodes_to_delete.insert(select_to_optimize); + result.nodes_to_delete.insert(select_to_optimize.AsNode()); VLOG(3) << " Step 6: check if more \"select\" nodes should be considered"; for (Node* user : new_binop->users()) { - if (user->OpIn({Op::kSel, Op::kPrioritySel})) { + if (user->OpIn({Op::kSel, Op::kPrioritySel, Op::kOneHotSel})) { result.new_selects_to_consider.insert(user); } } @@ -974,9 +1214,9 @@ absl::StatusOr LiftSelectForBinaryOperation( return result; } -absl::StatusOr LiftSelect(FunctionBase* func, - Node* select_to_optimize, - const LiftedOpInfo& info) { +absl::StatusOr LiftSelect( + FunctionBase* func, GenericSelect select_to_optimize, + const LiftedOpInfo& info) { TransformationResult result; VLOG(3) << " Apply the transformation"; @@ -1000,6 +1240,12 @@ absl::StatusOr LiftSelect(FunctionBase* func, case Op::kShll: case Op::kShrl: case Op::kShra: + case Op::kEq: + case Op::kNe: + case Op::kULe: + case Op::kUGe: + case Op::kULt: + case Op::kUGt: return LiftSelectForBinaryOperation(func, select_to_optimize, info); default: @@ -1014,12 +1260,16 @@ absl::StatusOr LiftSelect(FunctionBase* func, absl::StatusOr LiftSelect( FunctionBase* func, Node* select_to_optimize, - const OptimizationPassOptions& options, OptimizationContext& context) { + const OptimizationPassOptions& options, OptimizationContext& context, + QueryEngine& query_engine) { TransformationResult result; + XLS_ASSIGN_OR_RETURN(GenericSelect sel, + GenericSelect::From(select_to_optimize)); + // Check if it is safe to apply the transformation XLS_ASSIGN_OR_RETURN(std::optional applicability_guard_result, - CanLiftSelect(func, select_to_optimize)); + CanLiftSelect(func, sel, query_engine)); if (!applicability_guard_result) { VLOG(3) << " It is not safe to apply the transformation for this select"; @@ -1031,9 +1281,8 @@ absl::StatusOr LiftSelect( // It is safe to apply the transformation // // Check if it is profitable to apply the transformation - XLS_ASSIGN_OR_RETURN( - bool should_lift, - ShouldLiftSelect(func, select_to_optimize, info, options, context)); + XLS_ASSIGN_OR_RETURN(bool should_lift, + ShouldLiftSelect(func, sel, info, options, context)); if (!should_lift) { VLOG(3) << " This transformation is not profitable for this select"; @@ -1045,7 +1294,7 @@ absl::StatusOr LiftSelect( // It is now the time to apply it. VLOG(3) << " This transformation is applicable and profitable for this " "select"; - XLS_ASSIGN_OR_RETURN(result, LiftSelect(func, select_to_optimize, info)); + XLS_ASSIGN_OR_RETURN(result, LiftSelect(func, sel, info)); return result; } @@ -1053,7 +1302,8 @@ absl::StatusOr LiftSelect( absl::StatusOr LiftSelects( FunctionBase* func, const absl::btree_set& selects_to_consider, - const OptimizationPassOptions& options, OptimizationContext& context) { + const OptimizationPassOptions& options, OptimizationContext& context, + QueryEngine& query_engine) { TransformationResult result; // Try to optimize all "select" nodes @@ -1066,8 +1316,9 @@ absl::StatusOr LiftSelects( VLOG(3) << "Select: " << select_node->ToString(); // Try to optimize the current "select" node - XLS_ASSIGN_OR_RETURN(TransformationResult current_transformation_result, - LiftSelect(func, select_node, options, context)); + XLS_ASSIGN_OR_RETURN( + TransformationResult current_transformation_result, + LiftSelect(func, select_node, options, context, query_engine)); // Accumulate the result of the transformation result.was_code_modified |= current_transformation_result.was_code_modified; @@ -1107,9 +1358,11 @@ absl::StatusOr SelectLiftingPass::RunOnFunctionBaseInternal( // Collect the "select" nodes that might be optimizable VLOG(3) << "Optimizing the function at level " << options.opt_level; + std::unique_ptr query_engine = + GetQueryEngine(func, options, context); for (Node* node : func->nodes()) { // Only consider selects. - if (!node->OpIn({Op::kSel, Op::kPrioritySel})) { + if (!node->OpIn({Op::kSel, Op::kPrioritySel, Op::kOneHotSel})) { continue; } @@ -1128,9 +1381,9 @@ absl::StatusOr SelectLiftingPass::RunOnFunctionBaseInternal( VLOG(3) << " New optimization iteration"; // Optimize all "select" nodes. - XLS_ASSIGN_OR_RETURN( - TransformationResult current_result, - LiftSelects(func, selects_to_consider, options, context)); + XLS_ASSIGN_OR_RETURN(TransformationResult current_result, + LiftSelects(func, selects_to_consider, options, + context, *query_engine)); // Check if we have modified the code. was_code_modified |= current_result.was_code_modified; diff --git a/xls/passes/select_lifting_pass_test.cc b/xls/passes/select_lifting_pass_test.cc index 408ced7674..3caaec01a9 100644 --- a/xls/passes/select_lifting_pass_test.cc +++ b/xls/passes/select_lifting_pass_test.cc @@ -35,7 +35,9 @@ #include "xls/ir/bits.h" #include "xls/ir/function.h" #include "xls/ir/function_builder.h" +#include "xls/ir/ir_matcher.h" #include "xls/ir/ir_test_base.h" +#include "xls/ir/lsb_or_msb.h" #include "xls/ir/node.h" #include "xls/ir/nodes.h" #include "xls/ir/op.h" @@ -45,8 +47,9 @@ #include "xls/passes/pass_base.h" #include "xls/solvers/z3_ir_equivalence_testutils.h" -namespace xls { +namespace m = ::xls::op_matchers; +namespace xls { namespace { class FakeDelayEstimator : public DelayEstimator { @@ -947,6 +950,384 @@ TEST_F(SelectLiftingPassTest, DontLiftMulWithIdentityIfLatencyIncreases) { EXPECT_THAT(Run(f, opts), absl_testing::IsOkAndHolds(false)); } +TEST_F(SelectLiftingPassTest, LiftCompareIfLatencyPermits) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + BValue s = fb.Param("s", p->GetBitsType(1)); + BValue x = fb.Param("x", p->GetBitsType(32)); + BValue y = fb.Param("y", p->GetBitsType(32)); + + BValue x_ne_y = fb.Ne(x, y); + BValue zero = fb.Literal(UBits(0, 32)); + BValue x_ne_zero = fb.Ne(x, zero); + + // sel(s, [x != y, x != 0]) -> should lift to Ne(x, sel(s, [y, 0])) + BValue sel = fb.Select(s, {x_ne_y, x_ne_zero}); + XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.BuildWithReturnValue(sel)); + + // Force latency check by providing a delay estimator + OptimizationPassOptions opts; + XLS_ASSERT_OK_AND_ASSIGN(opts.delay_estimator, GetDelayEstimator("unit")); + + // This should fail/crash initially with the production error! + EXPECT_THAT(Run(f, opts), absl_testing::IsOkAndHolds(true)); +} + +TEST_F(SelectLiftingPassTest, LiftThroughOneHotSelectWithOneHotSelector) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + Type* u32_type = p->GetBitsType(32); + + BValue x = fb.Param("x", u32_type); + BValue y = fb.Param("y", u32_type); + BValue selector = fb.Param("selector", p->GetBitsType(2)); + + BValue oh_selector = fb.OneHot(selector, LsbOrMsb::kLsb); + + BValue zero = fb.Literal(UBits(0, 32)); + BValue x_add_zero = fb.Add(x, zero); + BValue x_add_y = fb.Add(x, y); + + BValue ohs = fb.OneHotSelect(oh_selector, {x_add_zero, x_add_y, x_add_zero}); + + XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.BuildWithReturnValue(ohs)); + + EXPECT_THAT(Run(f), absl_testing::IsOkAndHolds(true)); + + EXPECT_THAT( + f->return_value(), + m::Add(m::Param("x"), + m::OneHotSelect(m::OneHot(m::Param("selector"), LsbOrMsb::kLsb), + {m::Literal(0), m::Param("y"), m::Literal(0)}))); +} + +TEST_F(SelectLiftingPassTest, NoLiftThroughOneHotSelectWithoutOneHotSelector) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + Type* u32_type = p->GetBitsType(32); + + BValue x = fb.Param("x", u32_type); + BValue y = fb.Param("y", u32_type); + BValue selector = fb.Param("selector", p->GetBitsType(3)); + + BValue zero = fb.Literal(UBits(0, 32)); + BValue x_add_zero = fb.Add(x, zero); + BValue x_add_y = fb.Add(x, y); + + BValue ohs = fb.OneHotSelect(selector, {x_add_zero, x_add_y, x_add_zero}); + + XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.BuildWithReturnValue(ohs)); + + EXPECT_THAT(Run(f), absl_testing::IsOkAndHolds(false)); +} + +TEST_F(SelectLiftingPassTest, LiftAndThroughOneHotSelect) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + Type* u32_type = p->GetBitsType(32); + + BValue x = fb.Param("x", u32_type); + BValue y = fb.Param("y", u32_type); + BValue z = fb.Param("z", u32_type); + BValue selector = fb.Param("selector", p->GetBitsType(2)); + + BValue x_and_y = fb.And(x, y); + BValue x_and_z = fb.And(x, z); + + BValue ohs = fb.OneHotSelect(selector, {x_and_y, x_and_z}); + + XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.BuildWithReturnValue(ohs)); + + EXPECT_THAT(Run(f), absl_testing::IsOkAndHolds(true)); + + EXPECT_THAT( + f->return_value(), + m::And(m::Param("x"), m::OneHotSelect(m::Param("selector"), + {m::Param("y"), m::Param("z")}))); +} + +TEST_F(SelectLiftingPassTest, LiftSubThroughOneHotSelectWithOneHotSelector) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + Type* u32_type = p->GetBitsType(32); + + BValue val = fb.Param("val", u32_type); + BValue x = fb.Param("x", u32_type); + BValue y = fb.Param("y", u32_type); + BValue selector = fb.Param("selector", p->GetBitsType(1)); + + BValue oh_selector = fb.OneHot(selector, LsbOrMsb::kLsb); + + BValue sub_x = fb.Subtract(val, x); + BValue sub_y = fb.Subtract(val, y); + + BValue ohs = fb.OneHotSelect(oh_selector, {sub_x, sub_y}); + + XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.BuildWithReturnValue(ohs)); + + EXPECT_THAT(Run(f), absl_testing::IsOkAndHolds(true)); + + EXPECT_THAT( + f->return_value(), + m::Sub(m::Param("val"), + m::OneHotSelect(m::OneHot(m::Param("selector"), LsbOrMsb::kLsb), + {m::Param("x"), m::Param("y")}))); +} + +TEST_F(SelectLiftingPassTest, + LiftMulThroughOneHotSelectWithAtMostOneHotSelector) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + Type* u32_type = p->GetBitsType(32); + + BValue val = fb.Param("val", u32_type); + BValue x = fb.Param("x", u32_type); + BValue a = fb.Param("a", u32_type); + BValue b = fb.Param("b", u32_type); + + BValue eq_1 = fb.Eq(x, fb.Literal(UBits(1, 32))); + BValue eq_2 = fb.Eq(x, fb.Literal(UBits(2, 32))); + BValue selector = fb.Concat({eq_2, eq_1}); + + BValue mul_a = fb.UMul(val, a); + BValue mul_b = fb.UMul(val, b); + + BValue ohs = fb.OneHotSelect(selector, {mul_a, mul_b}); + + XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.BuildWithReturnValue(ohs)); + + EXPECT_THAT(Run(f), absl_testing::IsOkAndHolds(true)); + + EXPECT_THAT( + f->return_value(), + m::UMul(m::Param("val"), + m::OneHotSelect(m::Concat(m::Eq(m::Param("x"), m::Literal(2)), + m::Eq(m::Param("x"), m::Literal(1))), + {m::Param("a"), m::Param("b")}))); +} + +TEST_F(SelectLiftingPassTest, + NoLiftMulThroughOneHotSelectWithNonOneHotSelector) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + Type* u32_type = p->GetBitsType(32); + + BValue val = fb.Param("val", u32_type); + BValue selector = fb.Param("selector", p->GetBitsType(2)); + BValue a = fb.Param("a", u32_type); + BValue b = fb.Param("b", u32_type); + + BValue mul_a = fb.UMul(val, a); + BValue mul_b = fb.UMul(val, b); + + BValue ohs = fb.OneHotSelect(selector, {mul_a, mul_b}); + + XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.BuildWithReturnValue(ohs)); + + EXPECT_THAT(Run(f), absl_testing::IsOkAndHolds(false)); +} + +TEST_F(SelectLiftingPassTest, LiftOrThroughOneHotSelectWithNonzeroSelector) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + Type* u32_type = p->GetBitsType(32); + + BValue a = fb.Param("a", u32_type); + BValue b = fb.Param("b", u32_type); + BValue param_p = fb.Param("p", p->GetBitsType(2)); + + BValue selector = fb.Or(param_p, fb.Literal(UBits(1, 2))); + + BValue constant = fb.Literal(UBits(42, 32)); + BValue or_a = fb.Or(a, constant); + BValue or_b = fb.Or(b, constant); + + BValue ohs = fb.OneHotSelect(selector, {or_a, or_b}); + + XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.BuildWithReturnValue(ohs)); + + // This should fail initially (RED state) + EXPECT_THAT(Run(f), absl_testing::IsOkAndHolds(true)); + + EXPECT_THAT( + f->return_value(), + m::Or(m::Literal(42), m::OneHotSelect(m::Or(m::Param("p"), m::Literal(1)), + {m::Param("a"), m::Param("b")}))); +} + +TEST_F(SelectLiftingPassTest, + NoLiftOrThroughOneHotSelectWithMaybeZeroSelector) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + Type* u32_type = p->GetBitsType(32); + + BValue a = fb.Param("a", u32_type); + BValue b = fb.Param("b", u32_type); + BValue selector = fb.Param("selector", p->GetBitsType(2)); + + BValue constant = fb.Literal(UBits(42, 32)); + BValue or_a = fb.Or(a, constant); + BValue or_b = fb.Or(b, constant); + + BValue ohs = fb.OneHotSelect(selector, {or_a, or_b}); + + XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.BuildWithReturnValue(ohs)); + + EXPECT_THAT(Run(f), absl_testing::IsOkAndHolds(false)); +} + +TEST_F(SelectLiftingPassTest, + LiftArrayIndexThroughOneHotSelectWithAtMostOneHotSelector) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + Type* u32_type = p->GetBitsType(32); + + BValue zero = fb.Literal(UBits(0, 32)); + BValue val1 = fb.Param("val1", u32_type); + BValue val2 = fb.Param("val2", u32_type); + BValue array = fb.Array({zero, val1, val2}, u32_type); + + BValue x = fb.Param("x", u32_type); + BValue eq_1 = fb.Eq(x, fb.Literal(UBits(1, 32))); + BValue eq_2 = fb.Eq(x, fb.Literal(UBits(2, 32))); + BValue selector = fb.Concat({eq_2, eq_1}); + + BValue idx0 = fb.Param("idx0", u32_type); + BValue idx1 = fb.Param("idx1", u32_type); + BValue access0 = fb.ArrayIndex(array, {idx0}); + BValue access1 = fb.ArrayIndex(array, {idx1}); + BValue ohs = fb.OneHotSelect(selector, {access0, access1}); + + XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.BuildWithReturnValue(ohs)); + + // This should fail initially (RED state) + EXPECT_THAT(Run(f), absl_testing::IsOkAndHolds(true)); + + EXPECT_THAT( + f->return_value(), + m::ArrayIndex( + m::Array(m::Literal(0), m::Param("val1"), m::Param("val2")), + {m::OneHotSelect(m::Concat(m::Eq(m::Param("x"), m::Literal(2)), + m::Eq(m::Param("x"), m::Literal(1))), + {m::Param("idx0"), m::Param("idx1")})})); +} + +TEST_F(SelectLiftingPassTest, + NoLiftArrayIndexThroughOneHotSelectWithAtMostOneHotSelector) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + Type* u32_type = p->GetBitsType(32); + + BValue non_zero = fb.Literal(UBits(42, 32)); + BValue val1 = fb.Param("val1", u32_type); + BValue val2 = fb.Param("val2", u32_type); + BValue array = fb.Array({non_zero, val1, val2}, u32_type); + + BValue x = fb.Param("x", u32_type); + BValue eq_1 = fb.Eq(x, fb.Literal(UBits(1, 32))); + BValue eq_2 = fb.Eq(x, fb.Literal(UBits(2, 32))); + BValue selector = fb.Concat({eq_2, eq_1}); + + BValue idx0 = fb.Param("idx0", u32_type); + BValue idx1 = fb.Param("idx1", u32_type); + BValue access0 = fb.ArrayIndex(array, {idx0}); + BValue access1 = fb.ArrayIndex(array, {idx1}); + BValue ohs = fb.OneHotSelect(selector, {access0, access1}); + + XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.BuildWithReturnValue(ohs)); + + EXPECT_THAT(Run(f), absl_testing::IsOkAndHolds(false)); +} + +TEST_F(SelectLiftingPassTest, + LiftShllValueThroughOneHotSelectWithAtMostOneHotSelector) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + Type* u32_type = p->GetBitsType(32); + + BValue a = fb.Param("a", u32_type); + BValue b = fb.Param("b", u32_type); + BValue shift_amount = fb.Param("shift_amount", u32_type); + + BValue shll_a = fb.Shll(a, shift_amount); + BValue shll_b = fb.Shll(b, shift_amount); + + BValue x = fb.Param("x", u32_type); + BValue eq_1 = fb.Eq(x, fb.Literal(UBits(1, 32))); + BValue eq_2 = fb.Eq(x, fb.Literal(UBits(2, 32))); + BValue selector = fb.Concat({eq_2, eq_1}); + + BValue ohs = fb.OneHotSelect(selector, {shll_a, shll_b}); + + XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.BuildWithReturnValue(ohs)); + + EXPECT_THAT(Run(f), absl_testing::IsOkAndHolds(true)); + + EXPECT_THAT( + f->return_value(), + m::Shll(m::OneHotSelect(m::Concat(m::Eq(m::Param("x"), m::Literal(2)), + m::Eq(m::Param("x"), m::Literal(1))), + {m::Param("a"), m::Param("b")}), + m::Param("shift_amount"))); +} + +TEST_F(SelectLiftingPassTest, + LiftEqThroughOneHotSelectWithAtMostOneHotSelector) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + Type* u32_type = p->GetBitsType(32); + + BValue a = fb.Param("a", u32_type); + BValue b = fb.Param("b", u32_type); + BValue const_42 = fb.Literal(UBits(42, 32)); + + BValue eq_a = fb.Eq(const_42, a); + BValue eq_b = fb.Eq(const_42, b); + + BValue x = fb.Param("x", u32_type); + BValue eq_1 = fb.Eq(x, fb.Literal(UBits(1, 32))); + BValue eq_2 = fb.Eq(x, fb.Literal(UBits(2, 32))); + BValue selector = fb.Concat({eq_2, eq_1}); + + BValue ohs = fb.OneHotSelect(selector, {eq_a, eq_b}); + + XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.BuildWithReturnValue(ohs)); + + EXPECT_THAT(Run(f), absl_testing::IsOkAndHolds(true)); + + EXPECT_THAT( + f->return_value(), + m::Eq(m::Literal(42), + m::OneHotSelect(m::Concat(m::Eq(m::Param("x"), m::Literal(2)), + m::Eq(m::Param("x"), m::Literal(1))), + {m::Param("a"), m::Param("b")}))); +} + +TEST_F(SelectLiftingPassTest, + NoLiftEqThroughOneHotSelectWithAtMostOneHotSelectorAndUnknownShared) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + Type* u32_type = p->GetBitsType(32); + + BValue a = fb.Param("a", u32_type); + BValue b = fb.Param("b", u32_type); + BValue unknown_val = fb.Param("unknown_val", u32_type); + + BValue eq_a = fb.Eq(unknown_val, a); + BValue eq_b = fb.Eq(unknown_val, b); + + BValue x = fb.Param("x", u32_type); + BValue eq_1 = fb.Eq(x, fb.Literal(UBits(1, 32))); + BValue eq_2 = fb.Eq(x, fb.Literal(UBits(2, 32))); + BValue selector = fb.Concat({eq_2, eq_1}); + + BValue ohs = fb.OneHotSelect(selector, {eq_a, eq_b}); + + XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.BuildWithReturnValue(ohs)); + + EXPECT_THAT(Run(f), absl_testing::IsOkAndHolds(false)); +} + FUZZ_TEST(IrFuzzTest, IrFuzzSelectLifting) .WithDomains(IrFuzzDomainWithArgs(/*arg_set_count=*/10)); diff --git a/xls/passes/select_simplification_pass.cc b/xls/passes/select_simplification_pass.cc index d49c8b2a61..c65c72e2f3 100644 --- a/xls/passes/select_simplification_pass.cc +++ b/xls/passes/select_simplification_pass.cc @@ -16,6 +16,7 @@ #include #include +#include #include #include #include @@ -41,6 +42,7 @@ #include "absl/strings/str_format.h" #include "absl/types/span.h" #include "absl/types/variant.h" +#include "cppitertools/zip.hpp" #include "xls/common/status/ret_check.h" #include "xls/common/status/status_macros.h" #include "xls/common/visitor.h" @@ -57,6 +59,7 @@ #include "xls/ir/node_util.h" #include "xls/ir/nodes.h" #include "xls/ir/op.h" +#include "xls/ir/source_location.h" #include "xls/ir/ternary.h" #include "xls/ir/type.h" #include "xls/ir/value.h" @@ -1288,6 +1291,47 @@ absl::StatusOr TryHoistCompareThroughSelectLike( return true; } +absl::StatusOr ElementwiseOr(absl::Span inputs, Type* type, + FunctionBase* fb, SourceInfo loc) { + XLS_RET_CHECK(!TypeHasToken(type)); + + if (inputs.empty()) { + return fb->MakeNode(loc, ZeroOfType(type)); + } + if (inputs.size() == 1) { + return inputs[0]; + } + + // For simple types, this reduces to NaryOrIfNeeded. + if (type->IsBits()) { + return NaryOrIfNeeded(fb, inputs); + } + + // First, decompose each active case into a tree of nodes, collecting the + // leaves into a tree of vectors. + LeafTypeTree> accumulator(type, std::vector{}); + for (std::vector& leaf_vector : accumulator.elements()) { + leaf_vector.reserve(inputs.size()); + } + for (Node* input : inputs) { + XLS_RET_CHECK_EQ(input->GetType(), type); + XLS_ASSIGN_OR_RETURN(LeafTypeTree input_tree, ToTreeOfNodes(input)); + for (size_t i = 0; i < accumulator.elements().size(); ++i) { + accumulator.elements()[i].push_back(input_tree.elements()[i]); + } + } + + // Now that all the leaves are aligned, OR the leaves together at each + // position, then reconstruct the complex-typed node from the resulting tree. + XLS_ASSIGN_OR_RETURN( + LeafTypeTree result_tree, + (leaf_type_tree::MapStatus>( + accumulator.AsView(), [&](const std::vector& cases) { + return NaryOrIfNeeded(fb, cases); + }))); + return FromTreeOfNodes(fb, result_tree.AsView(), "", loc); +} + absl::StatusOr SimplifyNode(Node* node, const QueryEngine& query_engine, BitProvenanceAnalysis& provenance, int64_t opt_level, bool range_analysis) { @@ -1379,36 +1423,55 @@ absl::StatusOr SimplifyNode(Node* node, const QueryEngine& query_engine, } } - // One-hot-select with a constant selector can be replaced with OR of the - // activated cases. + // One-hot-select with a constant selector. if (node->Is() && - query_engine.IsFullyKnown(node->As()->selector()) && - node->GetType()->IsBits()) { + query_engine.IsFullyKnown(node->As()->selector())) { OneHotSelect* sel = node->As(); const Bits selector = *query_engine.KnownValueAsBits(sel->selector()); - Node* replacement = nullptr; - for (int64_t i = 0; i < selector.bit_count(); ++i) { - if (selector.Get(i)) { - if (replacement == nullptr) { - replacement = sel->get_case(i); - } else { - XLS_ASSIGN_OR_RETURN( - replacement, - node->function_base()->MakeNode( - node->loc(), - std::vector{replacement, sel->get_case(i)}, Op::kOr)); + + // Case 1: Zero-active (zero-hot); returns a constant zero of the same type. + if (selector.IsZero()) { + VLOG(2) << absl::StrFormat( + "Simplifying one-hot-select with constant zero selector: %s", + node->ToString()); + XLS_ASSIGN_OR_RETURN(Node * zero, + node->function_base()->MakeNode( + node->loc(), ZeroOfType(node->GetType()))); + XLS_RETURN_IF_ERROR(sel->ReplaceUsesWith(zero)); + return true; + } + + // Case 2: Exactly one active bit (one-hot); returns the corresponding case. + if (selector.PopCount() == 1) { + int64_t active_bit = -1; + for (int64_t i = 0; i < selector.bit_count(); ++i) { + if (selector.Get(i)) { + active_bit = i; + break; } } + XLS_RET_CHECK_NE(active_bit, -1); + VLOG(2) << absl::StrFormat( + "Simplifying one-hot-select with constant one-active selector: %s", + node->ToString()); + XLS_RETURN_IF_ERROR(sel->ReplaceUsesWith(sel->get_case(active_bit))); + return true; } - if (replacement == nullptr) { - XLS_ASSIGN_OR_RETURN( - replacement, - node->function_base()->MakeNode( - node->loc(), Value(UBits(0, node->BitCountOrDie())))); - } + + // Case 3: Multi-active (multi-hot); returns the OR of the active cases. VLOG(2) << absl::StrFormat( - "Simplifying one-hot-select with constant selector: %s", + "Simplifying one-hot-select with constant multi-active selector: %s", node->ToString()); + std::vector active_cases; + active_cases.reserve(selector.PopCount()); + for (auto [bit, case_node] : iter::zip(selector, sel->cases())) { + if (bit) { + active_cases.push_back(case_node); + } + } + XLS_ASSIGN_OR_RETURN(Node * replacement, + ElementwiseOr(active_cases, node->GetType(), + node->function_base(), node->loc())); XLS_RETURN_IF_ERROR(sel->ReplaceUsesWith(replacement)); return true; } @@ -1426,10 +1489,12 @@ absl::StatusOr SimplifyNode(Node* node, const QueryEngine& query_engine, } // OneHotSelect with identical cases can be replaced with a select between one - // of the identical case and the default value where the selector is: original - // selector == 0 - if (node->Is() && node->GetType()->IsBits() && - node->BitCountOrDie() > 1) { + // of the identical case and the default value where the selector is replaced + // with "original selector != 0". + // + // Skip 1-bit or empty Bits selects to avoid bloating the IR. + if (node->Is() && + (!node->GetType()->IsBits() || node->BitCountOrDie() > 1)) { Node* selector = node->As()->selector(); absl::Span cases = node->As()->cases(); if (absl::c_all_of(cases, [&](Node* c) { return c == cases[0]; })) { diff --git a/xls/passes/select_simplification_pass_test.cc b/xls/passes/select_simplification_pass_test.cc index e99ef515b1..b850a00d1f 100644 --- a/xls/passes/select_simplification_pass_test.cc +++ b/xls/passes/select_simplification_pass_test.cc @@ -41,6 +41,7 @@ #include "xls/ir/source_location.h" #include "xls/ir/type.h" #include "xls/ir/value.h" +#include "xls/ir/value_utils.h" #include "xls/passes/optimization_pass.h" #include "xls/passes/pass_base.h" #include "xls/solvers/z3_ir_equivalence_testutils.h" @@ -1853,6 +1854,70 @@ TEST_P(SelectSimplificationPassTest, PredicatedStateReadFeedSelector) { EXPECT_THAT(Run(proc), IsOkAndHolds(false)); } +TEST_P(SelectSimplificationPassTest, + SimplifyOneHotSelectConstantSelectorZeroActiveTuple) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + Type* tuple_type = p->GetTupleType({p->GetBitsType(32), p->GetBitsType(32)}); + BValue selector = fb.Literal(UBits(0, 2)); + BValue t0 = fb.Param("t0", tuple_type); + BValue t1 = fb.Param("t1", tuple_type); + fb.OneHotSelect(selector, {t0, t1}); + XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.Build()); + + EXPECT_THAT(Run(f), IsOkAndHolds(true)); + EXPECT_THAT(f->return_value(), m::Literal(ZeroOfType(tuple_type))); +} + +TEST_P(SelectSimplificationPassTest, + SimplifyOneHotSelectConstantSelectorOneActiveArray) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + Type* array_type = p->GetArrayType(3, p->GetBitsType(32)); + BValue selector = fb.Literal(UBits(2, 2)); + BValue a0 = fb.Param("a0", array_type); + BValue a1 = fb.Param("a1", array_type); + fb.OneHotSelect(selector, {a0, a1}); + XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.Build()); + + EXPECT_THAT(Run(f), IsOkAndHolds(true)); + EXPECT_THAT(f->return_value(), m::Param("a1")); +} + +TEST_P(SelectSimplificationPassTest, SimplifyOneHotSelectIdenticalCasesTuple) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + Type* tuple_type = p->GetTupleType({p->GetBitsType(32), p->GetBitsType(32)}); + BValue selector = fb.Param("selector", p->GetBitsType(2)); + BValue t0 = fb.Param("t0", tuple_type); + fb.OneHotSelect(selector, {t0, t0}); + XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.Build()); + + EXPECT_THAT(Run(f), IsOkAndHolds(true)); + EXPECT_THAT(f->return_value(), + m::Select(m::Ne(m::Param("selector"), m::Literal(UBits(0, 2))), + {m::Literal(ZeroOfType(tuple_type)), m::Param("t0")})); +} + +TEST_P(SelectSimplificationPassTest, + SimplifyOneHotSelectConstantSelectorMultiActiveTuple) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + Type* tuple_type = p->GetTupleType({p->GetBitsType(32), p->GetBitsType(32)}); + BValue selector = fb.Literal(UBits(3, 2)); // Both bits active (11b) + BValue t0 = fb.Param("t0", tuple_type); + BValue t1 = fb.Param("t1", tuple_type); + fb.OneHotSelect(selector, {t0, t1}); + XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.Build()); + + EXPECT_THAT(Run(f), IsOkAndHolds(true)); + EXPECT_THAT(f->return_value(), + m::Tuple(m::Or(m::TupleIndex(m::Param("t0"), 0), + m::TupleIndex(m::Param("t1"), 0)), + m::Or(m::TupleIndex(m::Param("t0"), 1), + m::TupleIndex(m::Param("t1"), 1)))); +} + INSTANTIATE_TEST_SUITE_P(SelectSimplificationPassTest, SelectSimplificationPassTest, testing::Values(AnalysisType::kTernary,