From 6a7a573e16f1f4a3962616e2e0d750429e442c3c Mon Sep 17 00:00:00 2001 From: Eric Astor Date: Wed, 17 Jun 2026 14:17:25 -0700 Subject: [PATCH] [opt] Extend VisibilityExprBuilder to support expression pruning We have some upcoming work that wants the ability to trim visibility expressions to remove dependencies on inputs that are "too far away" in some sense. PiperOrigin-RevId: 933919516 --- xls/passes/BUILD | 5 +- xls/passes/visibility_analysis.cc | 81 ++++- xls/passes/visibility_analysis.h | 10 + xls/passes/visibility_expr_builder.cc | 347 +++++++++++++++---- xls/passes/visibility_expr_builder.h | 68 ++-- xls/passes/visibility_expr_builder_test.cc | 368 +++++++++++++++++++-- 6 files changed, 760 insertions(+), 119 deletions(-) diff --git a/xls/passes/BUILD b/xls/passes/BUILD index cbcb73da0c..faabedff30 100644 --- a/xls/passes/BUILD +++ b/xls/passes/BUILD @@ -1662,6 +1662,7 @@ cc_library( "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -4846,13 +4847,12 @@ cc_library( "//xls/ir:source_location", "//xls/ir:type", "//xls/ir:value", - "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:btree", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", ], ) @@ -4880,6 +4880,7 @@ cc_test( "//xls/ir:function_builder", "//xls/ir:ir_matcher", "//xls/ir:ir_test_base", + "//xls/ir:op", "//xls/visualization:math_notation", "@com_google_absl//absl/log", "@com_google_absl//absl/status:statusor", diff --git a/xls/passes/visibility_analysis.cc b/xls/passes/visibility_analysis.cc index 30132ebf61..d21a646289 100644 --- a/xls/passes/visibility_analysis.cc +++ b/xls/passes/visibility_analysis.cc @@ -25,6 +25,7 @@ #include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "absl/functional/any_invocable.h" #include "absl/log/check.h" #include "absl/status/status.h" #include "absl/strings/str_format.h" @@ -538,6 +539,8 @@ BddNodeIndex OperandVisibilityAnalysis::ConditionOfUse(Node* node, return ConditionOnPredicate(node, user->As()->predicate()); } else if (user->Is()) { return ConditionOnNextUse(user->As(), node); + } else if (user->Is()) { + return ConditionOnPredicate(node, user->As()->condition()); } else if (user->OpIn({Op::kAnd, Op::kNand})) { return ConditionOfUseWithAnd(node, user->As()); } else if (user->OpIn({Op::kOr, Op::kNor})) { @@ -796,8 +799,10 @@ bool VisibilityAnalysis::IsMutuallyExclusive(Node* one, Node* other) const { return bdd.Implies(*GetInfo(one), bdd.Not(*GetInfo(other))) == bdd.one(); } -absl::StatusOr OperandVisibilityAnalysis::IsVisibilityIndependentOf( - Node* operand, Node* node, std::vector& sources) const { +namespace { + +std::vector GetVisibilityControlConditions(const Node* operand, + const Node* node) { std::vector conditions; if (node->Is()) { conditions.push_back(node->As()->selector()); @@ -807,16 +812,31 @@ absl::StatusOr OperandVisibilityAnalysis::IsVisibilityIndependentOf( conditions.push_back(*node->As()->predicate()); } else if (node->Is() && node->As()->predicate().has_value()) { conditions.push_back(*node->As()->predicate()); + } else if (node->Is()) { + conditions.push_back(node->As()->condition()); } else if (node->OpIn({Op::kAnd, Op::kOr, Op::kNand, Op::kNor})) { for (Node* other_op : node->operands()) { if (other_op != operand) { conditions.push_back(other_op); } } - } else { - return absl::InvalidArgumentError( - absl::StrFormat("Unsupported node type for visibility expression: %s", - node->ToString())); + } + return conditions; +} + +} // namespace + +absl::StatusOr OperandVisibilityAnalysis::IsVisibilityIndependentOf( + Node* operand, Node* node, std::vector& sources) const { + std::vector conditions = GetVisibilityControlConditions(operand, node); + if (conditions.empty()) { + if (!node->Is() && !node->Is()) { selector = select->As()) { - return GetVisibilityExprForSelect(node, user->As(), source, func, + is_live_source, get_remaining_delay); } else if (user->Is()) { return GetVisibilityExprForPrioritySelect(node, user->As(), - source, func); + source, func, is_live_source, + get_remaining_delay); } else if (user->Is()) { return GetVisibilityExprForPredicate(user->As()->predicate(), source, - func); + func, is_live_source, + get_remaining_delay); } else if (user->Is()) { return GetVisibilityExprForPredicate(user->As()->predicate(), source, - func); + func, is_live_source, + get_remaining_delay); + } else if (user->Is()) { + return GetVisibilityExprForPredicate(user->As()->condition(), source, + func, is_live_source, + get_remaining_delay); } else if (user->OpIn({Op::kAnd, Op::kNand})) { - return GetVisibilityExprForAnd(node, user->As(), source, func); + return GetVisibilityExprForAnd(node, user->As(), source, func, + is_live_source, get_remaining_delay); } else if (user->OpIn({Op::kOr, Op::kNor})) { - return GetVisibilityExprForOr(node, user->As(), source, func); + return GetVisibilityExprForOr(node, user->As(), source, func, + is_live_source, get_remaining_delay); } return nullptr; } // Builds predicate for node `u` being used by `v` on `func`. absl::StatusOr VisibilityBuilder::BuildVisibilityExpr( - Node* node, Node* user, Node* source, FunctionBase* func) { + Node* node, Node* user, Node* source, FunctionBase* func, + const std::function& is_live_source, + const std::function& get_remaining_delay) { auto cache_key = std::make_tuple(node, user, func); if (auto it = visibility_expr_cache_.find(cache_key); it != visibility_expr_cache_.end()) { return it->second; } - XLS_ASSIGN_OR_RETURN(Node * visibility, - BuildVisibilityExprHelper(node, user, source, func)); + XLS_ASSIGN_OR_RETURN( + Node * visibility, + BuildVisibilityExprHelper(node, user, source, func, is_live_source, + get_remaining_delay)); return visibility_expr_cache_[cache_key] = visibility; } @@ -339,7 +431,11 @@ absl::StatusOr VisibilityBuilder::BuildVisibilityIRExprFromEdges( const absl::flat_hash_set& conditional_edges, absl::flat_hash_map& node_to_visibility_ir_cache, - Literal* always_visible) { + Literal* always_visible, const std::function& is_live_source, + const std::function& get_remaining_delay) { + if (is_live_source && !is_live_source(node)) { + return always_visible; + } if (node->users().empty()) { return always_visible; } @@ -348,39 +444,81 @@ absl::StatusOr VisibilityBuilder::BuildVisibilityIRExprFromEdges( return it->second; } - absl::flat_hash_set user_visibilities; + absl::btree_set user_visibilities; for (Node* user : node->users()) { if (user->id() > prior_existing_id_) { continue; } - XLS_ASSIGN_OR_RETURN(Node * user_is_used, - BuildVisibilityIRExprFromEdges( - func, user, source, conditional_edges, - node_to_visibility_ir_cache, always_visible)); + XLS_ASSIGN_OR_RETURN( + Node * user_is_used, + BuildVisibilityIRExprFromEdges( + func, user, source, conditional_edges, node_to_visibility_ir_cache, + always_visible, is_live_source, get_remaining_delay)); Node* user_uses_node = always_visible; if (conditional_edges.contains({node, user})) { - XLS_ASSIGN_OR_RETURN(user_uses_node, - BuildVisibilityExpr(node, user, source, func)); + XLS_ASSIGN_OR_RETURN( + user_uses_node, + BuildVisibilityExpr(node, user, source, func, is_live_source, + get_remaining_delay)); if (!user_uses_node) { - return absl::InternalError(absl::StrCat( - "conditional edge exists but visibility expression could NOT be " - "made between ", - node->GetName(), " and ", user->GetName())); + user_uses_node = always_visible; } } if (user_uses_node == always_visible && user_is_used == always_visible) { return node_to_visibility_ir_cache[node] = always_visible; } + std::vector and_operands; + if (user_uses_node != always_visible) { + if (user_uses_node->op() == Op::kAnd) { + and_operands.insert(and_operands.end(), + user_uses_node->operands().begin(), + user_uses_node->operands().end()); + } else { + and_operands.push_back(user_uses_node); + } + } + if (user_is_used != always_visible) { + if (user_is_used->op() == Op::kAnd) { + and_operands.insert(and_operands.end(), + user_is_used->operands().begin(), + user_is_used->operands().end()); + } else { + and_operands.push_back(user_is_used); + } + } + Node* node_and_user_visible; - if (user_uses_node == always_visible) { - node_and_user_visible = user_is_used; - } else if (user_is_used == always_visible) { - node_and_user_visible = user_uses_node; + if (and_operands.empty()) { + node_and_user_visible = always_visible; + } else if (and_operands.size() == 1) { + node_and_user_visible = and_operands[0]; } else { XLS_ASSIGN_OR_RETURN( node_and_user_visible, - FindOrMakeBinaryNode(Op::kAnd, user_uses_node, user_is_used)); + func->MakeNode(SourceInfo(), and_operands, Op::kAnd)); + if (get_remaining_delay && + get_remaining_delay(node_and_user_visible) < 0) { + std::sort(and_operands.begin(), and_operands.end(), + [&](Node* a, Node* b) { + return get_remaining_delay(a) > get_remaining_delay(b); + }); + while (!and_operands.empty() && + get_remaining_delay(node_and_user_visible) < 0) { + and_operands.pop_back(); + if (and_operands.empty()) { + node_and_user_visible = always_visible; + break; + } + if (and_operands.size() == 1) { + node_and_user_visible = and_operands[0]; + } else { + XLS_ASSIGN_OR_RETURN( + node_and_user_visible, + func->MakeNode(SourceInfo(), and_operands, Op::kAnd)); + } + } + } } user_visibilities.insert(node_and_user_visible); } @@ -391,15 +529,30 @@ absl::StatusOr VisibilityBuilder::BuildVisibilityIRExprFromEdges( if (user_visibilities.size() == 1) { return node_to_visibility_ir_cache[node] = *user_visibilities.begin(); } - std::vector user_vis_vec{user_visibilities.begin(), - user_visibilities.end()}; - absl::c_sort(user_vis_vec, - [&](Node* a, Node* b) { return a->id() < b->id(); }); - Node* any_user_visible = user_vis_vec[0]; - for (int64_t i = 1; i < user_vis_vec.size(); ++i) { + std::vector user_vis_vec; + for (Node* n : user_visibilities) { + if (n != always_visible) { + if (n->op() == Op::kOr) { + user_vis_vec.insert(user_vis_vec.end(), n->operands().begin(), + n->operands().end()); + } else { + user_vis_vec.push_back(n); + } + } + } + + Node* any_user_visible; + if (user_vis_vec.empty()) { + any_user_visible = always_visible; + } else if (user_vis_vec.size() == 1) { + any_user_visible = user_vis_vec[0]; + } else { XLS_ASSIGN_OR_RETURN( any_user_visible, - FindOrMakeBinaryNode(Op::kOr, any_user_visible, user_vis_vec[i])); + func->MakeNode(SourceInfo(), user_vis_vec, Op::kOr)); + if (get_remaining_delay && get_remaining_delay(any_user_visible) < 0) { + any_user_visible = always_visible; + } } return node_to_visibility_ir_cache[node] = any_user_visible; } @@ -407,22 +560,80 @@ absl::StatusOr VisibilityBuilder::BuildVisibilityIRExprFromEdges( absl::StatusOr VisibilityBuilder::BuildVisibilityIRExpr( FunctionBase* func, Node* node, const absl::flat_hash_set& - conditional_edges) { + conditional_edges, + std::function is_live_source, + std::function get_remaining_delay, + std::optional target_stage) { + int64_t max_id_before = 0; + if (target_stage.has_value()) { + for (Node* n : func->nodes()) { + max_id_before = std::max(max_id_before, n->id()); + } + } + XLS_ASSIGN_OR_RETURN( Literal * always_visible, func->MakeNode(SourceInfo(), Value(UBits(1, 1)))); + + Node* result_expr = nullptr; if (conditional_edges.size() == 1) { XLS_ASSIGN_OR_RETURN( Node * user_uses_node, BuildVisibilityExpr(conditional_edges.begin()->operand, - conditional_edges.begin()->node, node, func)); - return user_uses_node ? user_uses_node : always_visible; - } - absl::flat_hash_map node_to_visibility_ir_cache; - absl::flat_hash_map, Node*> binary_op_cache; - return BuildVisibilityIRExprFromEdges(func, node, node, conditional_edges, - node_to_visibility_ir_cache, - always_visible); + conditional_edges.begin()->node, node, func, + is_live_source, get_remaining_delay)); + result_expr = user_uses_node ? user_uses_node : always_visible; + } else { + absl::flat_hash_map node_to_visibility_ir_cache; + absl::flat_hash_map, Node*> binary_op_cache; + XLS_ASSIGN_OR_RETURN( + result_expr, + BuildVisibilityIRExprFromEdges( + func, node, node, conditional_edges, node_to_visibility_ir_cache, + always_visible, is_live_source, get_remaining_delay)); + } + + if (target_stage.has_value()) { + for (Node* n : func->nodes()) { + if (n->id() > max_id_before) { + XLS_RETURN_IF_ERROR(func->AddNodeToStage(*target_stage, n).status()); + } + } + } + + return result_expr; +} + +absl::Status VisibilityBuilder::CleanUpUnusedNodes(FunctionBase* fb) { + std::vector worklist; + absl::flat_hash_set dead_nodes; + + for (Node* n : fb->nodes()) { + if (n->id() > prior_existing_id_ && n->IsDead()) { + dead_nodes.insert(n); + worklist.push_back(n); + } + } + + while (!worklist.empty()) { + Node* n = worklist.back(); + worklist.pop_back(); + + std::vector operands(n->operands().begin(), n->operands().end()); + + CHECK_GT(dead_nodes.erase(n), 0); + XLS_RETURN_IF_ERROR(fb->RemoveNode(n)); + + for (Node* operand : operands) { + if (operand->id() > prior_existing_id_ && operand->IsDead()) { + if (auto [_, inserted] = dead_nodes.insert(operand); inserted) { + worklist.push_back(operand); + } + } + } + } + + return absl::OkStatus(); } absl::StatusOr diff --git a/xls/passes/visibility_expr_builder.h b/xls/passes/visibility_expr_builder.h index 10a253c5e8..c2ff65cb50 100644 --- a/xls/passes/visibility_expr_builder.h +++ b/xls/passes/visibility_expr_builder.h @@ -16,6 +16,7 @@ #define XLS_PASSES_VISIBILITY_EXPR_BUILDER_H_ #include +#include #include #include @@ -100,38 +101,55 @@ class VisibilityBuilder : public ExpressionBuilder { absl::StatusOr BuildVisibilityIRExpr( FunctionBase* func, Node* node, const absl::flat_hash_set& - conditional_edges); + conditional_edges, + std::function is_live_source = nullptr, + std::function get_remaining_delay = nullptr, + std::optional target_stage = std::nullopt); + + // Cleans up any nodes created by this builder that have no users. + // This is useful for removing dead visibility expressions. + absl::Status CleanUpUnusedNodes(FunctionBase* fb); private: absl::StatusOr MakeParamIfTmpFunc(Node* node, FunctionBase* func) { return func == TmpFunc() ? TmpFuncNodeOrParam(node) : node; } - absl::StatusOr GetSelectorIfIndependent(Node* node, Node* select, - Node* source, - FunctionBase* func); + absl::StatusOr GetSelectorIfIndependent( + Node* node, Node* select, Node* source, FunctionBase* func, + const std::function& is_live_source); bool DoesCaseImplyNoPrevCase(PrioritySelect* select, int64_t case_index); absl::StatusOr GetVisibilityExprForPrioritySelect( - Node* node, PrioritySelect* select, Node* source, FunctionBase* func); - absl::StatusOr GetVisibilityExprForSelect(Node* node, Select* select, - Node* source, - FunctionBase* func); - absl::StatusOr GetVisibilityExprForAnd(Node* node, NaryOp* and_node, - Node* source, - FunctionBase* func); - absl::StatusOr GetVisibilityExprForOr(Node* node, NaryOp* or_node, - Node* source, - FunctionBase* func); + Node* node, PrioritySelect* select, Node* source, FunctionBase* func, + const std::function& is_live_source, + const std::function& get_remaining_delay); + absl::StatusOr GetVisibilityExprForSelect( + Node* node, Select* select, Node* source, FunctionBase* func, + const std::function& is_live_source, + const std::function& get_remaining_delay); + absl::StatusOr GetVisibilityExprForAnd( + Node* node, NaryOp* and_node, Node* source, FunctionBase* func, + const std::function& is_live_source, + const std::function& get_remaining_delay); + absl::StatusOr GetVisibilityExprForOr( + Node* node, NaryOp* or_node, Node* source, FunctionBase* func, + const std::function& is_live_source, + const std::function& get_remaining_delay); absl::StatusOr GetVisibilityExprForPredicate( - std::optional predicate, Node* source, FunctionBase* func); - - absl::StatusOr BuildVisibilityExprHelper(Node* node, Node* user, - Node* source, - FunctionBase* func); + std::optional predicate, Node* source, FunctionBase* func, + const std::function& is_live_source, + const std::function& get_remaining_delay); + + absl::StatusOr BuildVisibilityExprHelper( + Node* node, Node* user, Node* source, FunctionBase* func, + const std::function& is_live_source, + const std::function& get_remaining_delay); // Builds predicate for node `u` being used by `v` on `func`. - absl::StatusOr BuildVisibilityExpr(Node* node, Node* user, - Node* source, FunctionBase* func); + absl::StatusOr BuildVisibilityExpr( + Node* node, Node* user, Node* source, FunctionBase* func, + const std::function& is_live_source, + const std::function& get_remaining_delay); absl::StatusOr BuildNodeAndUserVisibleExpr(FunctionBase* func, Node* user_uses_node, Node* user_is_used, @@ -142,10 +160,12 @@ class VisibilityBuilder : public ExpressionBuilder { const absl::flat_hash_set& conditional_edges, absl::flat_hash_map& node_to_visibility_ir_cache, - Literal* always_visible); + Literal* always_visible, const std::function& is_live_source, + const std::function& get_remaining_delay); - absl::StatusOr GetNonRepeatedSourceOf(Node* operand, - FunctionBase* func); + absl::StatusOr GetNonRepeatedSourceOf( + Node* operand, FunctionBase* func, + const std::function& is_live_source); }; class VisibilityEstimator : public VisibilityBuilder { diff --git a/xls/passes/visibility_expr_builder_test.cc b/xls/passes/visibility_expr_builder_test.cc index 81c0f87b32..716b00f095 100644 --- a/xls/passes/visibility_expr_builder_test.cc +++ b/xls/passes/visibility_expr_builder_test.cc @@ -14,6 +14,7 @@ #include "xls/passes/visibility_expr_builder.h" +#include #include #include #include @@ -33,6 +34,9 @@ #include "xls/ir/function_builder.h" #include "xls/ir/ir_matcher.h" #include "xls/ir/ir_test_base.h" +#include "xls/ir/nodes.h" +#include "xls/ir/op.h" +#include "xls/ir/scheduled_builder.h" #include "xls/passes/bdd_query_engine.h" #include "xls/passes/bit_provenance_analysis.h" #include "xls/passes/node_dependency_analysis.h" @@ -113,20 +117,19 @@ TEST_F(VisibilityExprBuilderTest, ExampleInFunctionHeaderComment) { VLOG(3) << "is 'y' used:\n" << ToMathNotation(is_y_used.first); EXPECT_THAT(is_x_used.first, - m::Or(m::Or(m::Eq(m::Param("op1"), m::Literal(0)), - m::Eq(m::Param("op1"), m::Literal(2))), + m::Or(m::Eq(m::Param("op1"), m::Literal(0)), + m::Eq(m::Param("op1"), m::Literal(2)), m::And(m::ULt(m::Param("op2"), m::Literal(5)), m::UGe(m::Param("op1"), m::Literal(3))))); - // 6 instead of 7 because of the re-use of the ult. - EXPECT_EQ(is_x_used.second.area, 6); + // 5 instead of 6 because of flat OR gate consolidation. + EXPECT_EQ(is_x_used.second.area, 5); EXPECT_EQ(is_x_used.second.delay, 3); - EXPECT_THAT(is_y_used.first, - m::Or(m::Or(m::Eq(m::Param("op1"), m::Literal(1)), - m::UGe(m::Param("op1"), m::Literal(3))), - m::Or(m::Eq(m::Param("op1"), m::Literal(0)), - m::Eq(m::Param("op1"), m::Literal(2))))); - EXPECT_EQ(is_y_used.second.area, 7); - EXPECT_EQ(is_y_used.second.delay, 3); + EXPECT_THAT(is_y_used.first, m::Or(m::Eq(m::Param("op1"), m::Literal(1)), + m::UGe(m::Param("op1"), m::Literal(3)), + m::Eq(m::Param("op1"), m::Literal(0)), + m::Eq(m::Param("op1"), m::Literal(2)))); + EXPECT_EQ(is_y_used.second.area, 5); + EXPECT_EQ(is_y_used.second.delay, 2); // Now that the returned expression must be mutually exclusive with z's // visibility, it must condition on the selection criteria of 'select2'. @@ -136,11 +139,11 @@ TEST_F(VisibilityExprBuilderTest, ExampleInFunctionHeaderComment) { VLOG(3) << "is 'x' used and 'z' not used:\n" << ToMathNotation(is_x_used_and_z_not.first); EXPECT_THAT(is_x_used_and_z_not.first, - m::Or(m::Or(m::Eq(m::Param("op1"), m::Literal(0)), - m::Eq(m::Param("op1"), m::Literal(2))), + m::Or(m::Eq(m::Param("op1"), m::Literal(0)), + m::Eq(m::Param("op1"), m::Literal(2)), m::UGe(m::Param("op1"), m::Literal(3)))); - EXPECT_EQ(is_x_used_and_z_not.second.area, 5); - EXPECT_EQ(is_x_used_and_z_not.second.delay, 3); + EXPECT_EQ(is_x_used_and_z_not.second.area, 4); + EXPECT_EQ(is_x_used_and_z_not.second.delay, 2); } TEST_F(VisibilityExprBuilderTest, PrioritySelectOneHot) { @@ -207,9 +210,9 @@ TEST_F(VisibilityExprBuilderTest, Ors) { std::pair is_y_used; XLS_ASSERT_OK_AND_ASSIGN(is_y_used, BuildDefaultVisibilityExpr(f, y.node(), {})); - ASSERT_THAT(is_y_used.first, - m::Ne(m::BitSlice(m::Param("x"), 1, 2), m::Literal(3))); - EXPECT_EQ(is_y_used.first->operand(0), bits12.node()); + EXPECT_THAT(is_y_used.first, + m::Or(m::Not(m::BitSlice(m::Param("x"), 1, 1)), + m::Ne(m::BitSlice(m::Param("x"), 1, 2), m::Literal(3)))); } TEST_F(VisibilityExprBuilderTest, Ands) { @@ -226,7 +229,9 @@ TEST_F(VisibilityExprBuilderTest, Ands) { std::pair is_y_used; XLS_ASSERT_OK_AND_ASSIGN(is_y_used, BuildDefaultVisibilityExpr(f, y.node(), {})); - EXPECT_EQ(is_y_used.first, bit1.node()); + EXPECT_THAT(is_y_used.first, + m::And(m::Ne(m::BitSlice(m::Param("x"), 1, 2), m::Literal(0)), + m::BitSlice(m::Param("x"), 1, 1))); } TEST_F(VisibilityExprBuilderTest, FindsSourceOfOperandInComparison) { @@ -260,5 +265,330 @@ TEST_F(VisibilityExprBuilderTest, NotAFunctionOfSelf) { EXPECT_THAT(is_x_used.first, m::Ne(m::Param("y"), m::Literal(7))); } +TEST_F(VisibilityExprBuilderTest, LivenessHalting) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + BValue x = fb.Param("x", p->GetBitsType(32)); + BValue y = fb.Param("y", p->GetBitsType(32)); + BValue z = fb.Param("z", p->GetBitsType(32)); + BValue op1 = fb.Param("op1", p->GetBitsType(4)); + BValue op2 = fb.Param("op2", p->GetBitsType(4)); + BValue select1 = fb.Select(op1, {x, y, x}, y); + BValue lt1 = fb.ULt(op2, fb.Literal(UBits(5, 4))); + BValue and1 = fb.And(x, fb.SignExtend(lt1, 32)); + BValue select2 = fb.Select(op1, {y, z, y}, and1); + BValue ret = fb.Tuple({select1, select2}); + XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.BuildWithReturnValue(ret)); + + NodeForwardDependencyAnalysis nda; + XLS_ASSERT_OK(nda.Attach(f).status()); + LazyPostDominatorAnalysis post_dom; + XLS_ASSERT_OK(post_dom.Attach(f).status()); + std::unique_ptr bdd_engine = BddQueryEngine::MakeDefault(); + XLS_ASSERT_OK(bdd_engine->Populate(f).status()); + XLS_ASSERT_OK_AND_ASSIGN( + auto operand_visibility, + OperandVisibilityAnalysis::Create(&nda, bdd_engine.get())); + XLS_ASSERT_OK_AND_ASSIGN( + auto visibility, VisibilityAnalysis::Create(&operand_visibility, + bdd_engine.get(), &post_dom)); + XLS_ASSERT_OK_AND_ASSIGN( + auto conditional_edges, + visibility->GetEdgesForMutuallyExclusiveVisibilityExpr(x.node(), {}, -1)); + + auto last_node_id = f->nodes_reversed().begin()->id(); + XLS_ASSERT_OK_AND_ASSIGN(AreaEstimator * ae, GetAreaEstimator("unit")); + XLS_ASSERT_OK_AND_ASSIGN(DelayEstimator * de, GetDelayEstimator("unit")); + BitProvenanceAnalysis bpa; + VisibilityEstimator estimator(last_node_id, bdd_engine.get(), nda, bpa, ae, + de); + + // A callback that says 'and1' is not live! + auto is_live_source = [&](Node* n) { return n != and1.node(); }; + + XLS_ASSERT_OK_AND_ASSIGN(Node * expr, + estimator.BuildVisibilityIRExpr( + f, x.node(), conditional_edges, is_live_source)); + + // Because and1 is ignored, we only accumulate its immediate condition (lt1 != + // 0) combined with the conditions from select1! + EXPECT_THAT(expr, m::Or(m::ULt(m::Param("op2"), m::Literal(5)), + m::Eq(m::Param("op1"), m::Literal(0)), + m::Eq(m::Param("op1"), m::Literal(2)))); +} + +TEST_F(VisibilityExprBuilderTest, AndPruning) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + BValue y = fb.Param("y", p->GetBitsType(1)); + BValue a = fb.Param("a", p->GetBitsType(1)); + BValue b = fb.Param("b", p->GetBitsType(1)); + BValue c = fb.Param("c", p->GetBitsType(1)); + + BValue and_gate = fb.And({a, b, c, y}); + XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.BuildWithReturnValue(and_gate)); + + NodeForwardDependencyAnalysis nda; + XLS_ASSERT_OK(nda.Attach(f).status()); + LazyPostDominatorAnalysis post_dom; + XLS_ASSERT_OK(post_dom.Attach(f).status()); + std::unique_ptr bdd_engine = BddQueryEngine::MakeDefault(); + XLS_ASSERT_OK(bdd_engine->Populate(f).status()); + XLS_ASSERT_OK_AND_ASSIGN( + auto operand_visibility, + OperandVisibilityAnalysis::Create(&nda, bdd_engine.get())); + XLS_ASSERT_OK_AND_ASSIGN( + auto visibility, VisibilityAnalysis::Create(&operand_visibility, + bdd_engine.get(), &post_dom)); + XLS_ASSERT_OK_AND_ASSIGN( + auto conditional_edges, + visibility->GetEdgesForMutuallyExclusiveVisibilityExpr(y.node(), {}, -1)); + + auto last_node_id = f->nodes_reversed().begin()->id(); + XLS_ASSERT_OK_AND_ASSIGN(AreaEstimator * ae, GetAreaEstimator("unit")); + XLS_ASSERT_OK_AND_ASSIGN(DelayEstimator * de, GetDelayEstimator("unit")); + BitProvenanceAnalysis bpa; + VisibilityEstimator estimator(last_node_id, bdd_engine.get(), nda, bpa, ae, + de); + + // Mock slacks: we make 'c' have worst slack (10), then 'b' (20), then 'a' + // (30). + auto get_remaining_delay = [&](Node* n) -> int64_t { + if (n == c.node()) { + return 10; + } + if (n == b.node()) { + return 20; + } + if (n == a.node()) { + return 30; + } + + // For the resulting AND, if it includes all three expected operands, we + // return -1 (exceeds limit!). Otherwise, we return a positive slack. + if (n->op() == Op::kAnd && n->operands().size() == 3) { + return -1; + } + if (n->op() == Op::kAnd && n->operands().size() == 2) { + return 5; + } + return 100; + }; + + XLS_ASSERT_OK_AND_ASSIGN( + Node * expr, estimator.BuildVisibilityIRExpr( + f, y.node(), conditional_edges, + /*is_live_source=*/nullptr, get_remaining_delay)); + + // The expected result should omit 'c'! So it should be AND(a, b)! + EXPECT_THAT(expr, m::And(m::Param("a"), m::Param("b"))); +} + +TEST_F(VisibilityExprBuilderTest, OrPruning) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + BValue x = fb.Param("x", p->GetBitsType(3)); + BValue y = fb.Param("y", p->GetBitsType(2)); + BValue bit1 = fb.BitSlice(x, 1, 1); + BValue or_y = fb.Or(fb.SignExtend(bit1, 2), y); + BValue bits12 = fb.BitSlice(x, 1, 2); + BValue or_y2 = fb.Or(bits12, y); + XLS_ASSERT_OK_AND_ASSIGN(Function * f, + fb.BuildWithReturnValue(fb.Tuple({or_y, or_y2}))); + + NodeForwardDependencyAnalysis nda; + XLS_ASSERT_OK(nda.Attach(f).status()); + LazyPostDominatorAnalysis post_dom; + XLS_ASSERT_OK(post_dom.Attach(f).status()); + std::unique_ptr bdd_engine = BddQueryEngine::MakeDefault(); + XLS_ASSERT_OK(bdd_engine->Populate(f).status()); + XLS_ASSERT_OK_AND_ASSIGN( + auto operand_visibility, + OperandVisibilityAnalysis::Create(&nda, bdd_engine.get())); + XLS_ASSERT_OK_AND_ASSIGN( + auto visibility, VisibilityAnalysis::Create(&operand_visibility, + bdd_engine.get(), &post_dom)); + XLS_ASSERT_OK_AND_ASSIGN( + auto conditional_edges, + visibility->GetEdgesForMutuallyExclusiveVisibilityExpr(y.node(), {}, -1)); + + auto last_node_id = f->nodes_reversed().begin()->id(); + XLS_ASSERT_OK_AND_ASSIGN(AreaEstimator * ae, GetAreaEstimator("unit")); + XLS_ASSERT_OK_AND_ASSIGN(DelayEstimator * de, GetDelayEstimator("unit")); + BitProvenanceAnalysis bpa; + VisibilityEstimator estimator(last_node_id, bdd_engine.get(), nda, bpa, ae, + de); + + // Mock delay limit! If the OR node contains conditions from bits12 (which + // corresponds to user 2), it will exceed timing slack limits! + auto get_remaining_delay = [&](Node* n) -> int64_t { + if (n->op() == Op::kOr) { + for (Node* operand : n->operands()) { + if (operand->Is() && operand->operand(0) == bits12.node()) { + return -1; + } + } + } + return 100; + }; + + XLS_ASSERT_OK_AND_ASSIGN( + Node * expr, estimator.BuildVisibilityIRExpr( + f, y.node(), conditional_edges, + /*is_live_source=*/nullptr, get_remaining_delay)); + + // The expected result should be true because the operation is pruned to + // literal(1)! + EXPECT_THAT(expr, m::Literal(1)); +} + +TEST_F(VisibilityExprBuilderTest, SelectPruning) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + BValue x = fb.Param("x", p->GetBitsType(32)); + BValue y = fb.Param("y", p->GetBitsType(32)); + BValue op = fb.Param("op", p->GetBitsType(1)); + BValue sel = fb.Select(op, {x}, y); + XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.BuildWithReturnValue(sel)); + + NodeForwardDependencyAnalysis nda; + XLS_ASSERT_OK(nda.Attach(f).status()); + LazyPostDominatorAnalysis post_dom; + XLS_ASSERT_OK(post_dom.Attach(f).status()); + std::unique_ptr bdd_engine = BddQueryEngine::MakeDefault(); + XLS_ASSERT_OK(bdd_engine->Populate(f).status()); + XLS_ASSERT_OK_AND_ASSIGN( + auto operand_visibility, + OperandVisibilityAnalysis::Create(&nda, bdd_engine.get())); + XLS_ASSERT_OK_AND_ASSIGN( + auto visibility, VisibilityAnalysis::Create(&operand_visibility, + bdd_engine.get(), &post_dom)); + XLS_ASSERT_OK_AND_ASSIGN( + auto conditional_edges, + visibility->GetEdgesForMutuallyExclusiveVisibilityExpr(x.node(), {}, -1)); + + auto last_node_id = f->nodes_reversed().begin()->id(); + XLS_ASSERT_OK_AND_ASSIGN(AreaEstimator * ae, GetAreaEstimator("unit")); + XLS_ASSERT_OK_AND_ASSIGN(DelayEstimator * de, GetDelayEstimator("unit")); + BitProvenanceAnalysis bpa; + VisibilityEstimator estimator(last_node_id, bdd_engine.get(), nda, bpa, ae, + de); + + auto get_remaining_delay = [&](Node* n) -> int64_t { + if (n->op() == Op::kOr || n->op() == Op::kEq) { + return -1; + } + return 100; + }; + + XLS_ASSERT_OK_AND_ASSIGN( + Node * expr, estimator.BuildVisibilityIRExpr( + f, x.node(), conditional_edges, + /*is_live_source=*/nullptr, get_remaining_delay)); + + EXPECT_THAT(expr, m::Literal(1)); +} + +TEST_F(VisibilityExprBuilderTest, PrioritySelectPruning) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + BValue x = fb.Param("x", p->GetBitsType(32)); + BValue y = fb.Param("y", p->GetBitsType(32)); + BValue sel = fb.Param("sel", p->GetBitsType(2)); + BValue prio = fb.PrioritySelect(sel, {x, y}, fb.Literal(UBits(0, 32))); + XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.BuildWithReturnValue(prio)); + + NodeForwardDependencyAnalysis nda; + XLS_ASSERT_OK(nda.Attach(f).status()); + LazyPostDominatorAnalysis post_dom; + XLS_ASSERT_OK(post_dom.Attach(f).status()); + std::unique_ptr bdd_engine = BddQueryEngine::MakeDefault(); + XLS_ASSERT_OK(bdd_engine->Populate(f).status()); + XLS_ASSERT_OK_AND_ASSIGN( + auto operand_visibility, + OperandVisibilityAnalysis::Create(&nda, bdd_engine.get())); + XLS_ASSERT_OK_AND_ASSIGN( + auto visibility, VisibilityAnalysis::Create(&operand_visibility, + bdd_engine.get(), &post_dom)); + XLS_ASSERT_OK_AND_ASSIGN( + auto conditional_edges, + visibility->GetEdgesForMutuallyExclusiveVisibilityExpr(x.node(), {}, -1)); + + auto last_node_id = f->nodes_reversed().begin()->id(); + XLS_ASSERT_OK_AND_ASSIGN(AreaEstimator * ae, GetAreaEstimator("unit")); + XLS_ASSERT_OK_AND_ASSIGN(DelayEstimator * de, GetDelayEstimator("unit")); + BitProvenanceAnalysis bpa; + VisibilityEstimator estimator(last_node_id, bdd_engine.get(), nda, bpa, ae, + de); + + auto get_remaining_delay = [&](Node* n) -> int64_t { + if (n->op() == Op::kOr || n->op() == Op::kEq || n->op() == Op::kBitSlice) { + return -1; + } + return 100; + }; + + XLS_ASSERT_OK_AND_ASSIGN( + Node * expr, estimator.BuildVisibilityIRExpr( + f, x.node(), conditional_edges, + /*is_live_source=*/nullptr, get_remaining_delay)); + + EXPECT_THAT(expr, m::Literal(1)); +} + +TEST_F(VisibilityExprBuilderTest, TargetStageEnforcement) { + auto p = CreatePackage(); + ScheduledFunctionBuilder fb(TestName(), p.get()); + BValue x = fb.Param("x", p->GetBitsType(32)); + BValue y = fb.Param("y", p->GetBitsType(32)); + BValue select1 = fb.Select(fb.Literal(UBits(1, 1)), {x}, y); + XLS_ASSERT_OK_AND_ASSIGN(ScheduledFunction * f, + fb.BuildWithReturnValue(select1)); + + // Add empty stages to make it a scheduled function! + f->AddEmptyStages(3); + + NodeForwardDependencyAnalysis nda; + XLS_ASSERT_OK(nda.Attach(f).status()); + LazyPostDominatorAnalysis post_dom; + XLS_ASSERT_OK(post_dom.Attach(f).status()); + std::unique_ptr bdd_engine = BddQueryEngine::MakeDefault(); + XLS_ASSERT_OK(bdd_engine->Populate(f).status()); + XLS_ASSERT_OK_AND_ASSIGN( + auto operand_visibility, + OperandVisibilityAnalysis::Create(&nda, bdd_engine.get())); + XLS_ASSERT_OK_AND_ASSIGN( + auto visibility, VisibilityAnalysis::Create(&operand_visibility, + bdd_engine.get(), &post_dom)); + XLS_ASSERT_OK_AND_ASSIGN( + auto conditional_edges, + visibility->GetEdgesForMutuallyExclusiveVisibilityExpr(x.node(), {}, -1)); + + auto last_node_id = f->nodes_reversed().begin()->id(); + XLS_ASSERT_OK_AND_ASSIGN(AreaEstimator * ae, GetAreaEstimator("unit")); + XLS_ASSERT_OK_AND_ASSIGN(DelayEstimator * de, GetDelayEstimator("unit")); + BitProvenanceAnalysis bpa; + VisibilityEstimator estimator(last_node_id, bdd_engine.get(), nda, bpa, ae, + de); + + int64_t kTargetStage = 2; + XLS_ASSERT_OK_AND_ASSIGN( + Node * expr, + estimator.BuildVisibilityIRExpr( + f, x.node(), conditional_edges, /*is_live_source=*/nullptr, + /*get_remaining_delay=*/nullptr, kTargetStage)); + (void)expr; + + // All nodes created during BuildVisibilityIRExpr should be assigned to stage + // 2! + for (Node* n : f->nodes()) { + if (n->id() > last_node_id) { + EXPECT_TRUE(f->IsStaged(n)); + XLS_ASSERT_OK_AND_ASSIGN(int64_t stage, f->GetStageIndex(n)); + EXPECT_EQ(stage, kTargetStage); + } + } +} + } // namespace } // namespace xls