diff --git a/xls/passes/BUILD b/xls/passes/BUILD index cbcb73da0c..faabedff30 100644 --- a/xls/passes/BUILD +++ b/xls/passes/BUILD @@ -1662,6 +1662,7 @@ cc_library( "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -4846,13 +4847,12 @@ cc_library( "//xls/ir:source_location", "//xls/ir:type", "//xls/ir:value", - "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:btree", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", ], ) @@ -4880,6 +4880,7 @@ cc_test( "//xls/ir:function_builder", "//xls/ir:ir_matcher", "//xls/ir:ir_test_base", + "//xls/ir:op", "//xls/visualization:math_notation", "@com_google_absl//absl/log", "@com_google_absl//absl/status:statusor", diff --git a/xls/passes/visibility_analysis.cc b/xls/passes/visibility_analysis.cc index 30132ebf61..d21a646289 100644 --- a/xls/passes/visibility_analysis.cc +++ b/xls/passes/visibility_analysis.cc @@ -25,6 +25,7 @@ #include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "absl/functional/any_invocable.h" #include "absl/log/check.h" #include "absl/status/status.h" #include "absl/strings/str_format.h" @@ -538,6 +539,8 @@ BddNodeIndex OperandVisibilityAnalysis::ConditionOfUse(Node* node, return ConditionOnPredicate(node, user->As()->predicate()); } else if (user->Is()) { return ConditionOnNextUse(user->As(), node); + } else if (user->Is()) { + return ConditionOnPredicate(node, user->As()->condition()); } else if (user->OpIn({Op::kAnd, Op::kNand})) { return ConditionOfUseWithAnd(node, user->As()); } else if (user->OpIn({Op::kOr, Op::kNor})) { @@ -796,8 +799,10 @@ bool VisibilityAnalysis::IsMutuallyExclusive(Node* one, Node* other) const { return bdd.Implies(*GetInfo(one), bdd.Not(*GetInfo(other))) == bdd.one(); } -absl::StatusOr OperandVisibilityAnalysis::IsVisibilityIndependentOf( - Node* operand, Node* node, std::vector& sources) const { +namespace { + +std::vector GetVisibilityControlConditions(const Node* operand, + const Node* node) { std::vector conditions; if (node->Is()) { conditions.push_back(node->As()->selector()); @@ -807,16 +812,31 @@ absl::StatusOr OperandVisibilityAnalysis::IsVisibilityIndependentOf( conditions.push_back(*node->As()->predicate()); } else if (node->Is() && node->As()->predicate().has_value()) { conditions.push_back(*node->As()->predicate()); + } else if (node->Is()) { + conditions.push_back(node->As()->condition()); } else if (node->OpIn({Op::kAnd, Op::kOr, Op::kNand, Op::kNor})) { for (Node* other_op : node->operands()) { if (other_op != operand) { conditions.push_back(other_op); } } - } else { - return absl::InvalidArgumentError( - absl::StrFormat("Unsupported node type for visibility expression: %s", - node->ToString())); + } + return conditions; +} + +} // namespace + +absl::StatusOr OperandVisibilityAnalysis::IsVisibilityIndependentOf( + Node* operand, Node* node, std::vector& sources) const { + std::vector conditions = GetVisibilityControlConditions(operand, node); + if (conditions.empty()) { + if (!node->Is() && !node->Is()) { selector = select->As()) { - return GetVisibilityExprForSelect(node, user->As(), source, func, + is_live_source, get_remaining_delay); } else if (user->Is()) { return GetVisibilityExprForPrioritySelect(node, user->As(), - source, func); + source, func, is_live_source, + get_remaining_delay); } else if (user->Is()) { return GetVisibilityExprForPredicate(user->As()->predicate(), source, - func); + func, is_live_source, + get_remaining_delay); } else if (user->Is()) { return GetVisibilityExprForPredicate(user->As()->predicate(), source, - func); + func, is_live_source, + get_remaining_delay); + } else if (user->Is()) { + return GetVisibilityExprForPredicate(user->As()->condition(), source, + func, is_live_source, + get_remaining_delay); } else if (user->OpIn({Op::kAnd, Op::kNand})) { - return GetVisibilityExprForAnd(node, user->As(), source, func); + return GetVisibilityExprForAnd(node, user->As(), source, func, + is_live_source, get_remaining_delay); } else if (user->OpIn({Op::kOr, Op::kNor})) { - return GetVisibilityExprForOr(node, user->As(), source, func); + return GetVisibilityExprForOr(node, user->As(), source, func, + is_live_source, get_remaining_delay); } return nullptr; } // Builds predicate for node `u` being used by `v` on `func`. absl::StatusOr VisibilityBuilder::BuildVisibilityExpr( - Node* node, Node* user, Node* source, FunctionBase* func) { + Node* node, Node* user, Node* source, FunctionBase* func, + const std::function& is_live_source, + const std::function& get_remaining_delay) { auto cache_key = std::make_tuple(node, user, func); if (auto it = visibility_expr_cache_.find(cache_key); it != visibility_expr_cache_.end()) { return it->second; } - XLS_ASSIGN_OR_RETURN(Node * visibility, - BuildVisibilityExprHelper(node, user, source, func)); + XLS_ASSIGN_OR_RETURN( + Node * visibility, + BuildVisibilityExprHelper(node, user, source, func, is_live_source, + get_remaining_delay)); return visibility_expr_cache_[cache_key] = visibility; } @@ -339,7 +431,11 @@ absl::StatusOr VisibilityBuilder::BuildVisibilityIRExprFromEdges( const absl::flat_hash_set& conditional_edges, absl::flat_hash_map& node_to_visibility_ir_cache, - Literal* always_visible) { + Literal* always_visible, const std::function& is_live_source, + const std::function& get_remaining_delay) { + if (is_live_source && !is_live_source(node)) { + return always_visible; + } if (node->users().empty()) { return always_visible; } @@ -348,39 +444,81 @@ absl::StatusOr VisibilityBuilder::BuildVisibilityIRExprFromEdges( return it->second; } - absl::flat_hash_set user_visibilities; + absl::btree_set user_visibilities; for (Node* user : node->users()) { if (user->id() > prior_existing_id_) { continue; } - XLS_ASSIGN_OR_RETURN(Node * user_is_used, - BuildVisibilityIRExprFromEdges( - func, user, source, conditional_edges, - node_to_visibility_ir_cache, always_visible)); + XLS_ASSIGN_OR_RETURN( + Node * user_is_used, + BuildVisibilityIRExprFromEdges( + func, user, source, conditional_edges, node_to_visibility_ir_cache, + always_visible, is_live_source, get_remaining_delay)); Node* user_uses_node = always_visible; if (conditional_edges.contains({node, user})) { - XLS_ASSIGN_OR_RETURN(user_uses_node, - BuildVisibilityExpr(node, user, source, func)); + XLS_ASSIGN_OR_RETURN( + user_uses_node, + BuildVisibilityExpr(node, user, source, func, is_live_source, + get_remaining_delay)); if (!user_uses_node) { - return absl::InternalError(absl::StrCat( - "conditional edge exists but visibility expression could NOT be " - "made between ", - node->GetName(), " and ", user->GetName())); + user_uses_node = always_visible; } } if (user_uses_node == always_visible && user_is_used == always_visible) { return node_to_visibility_ir_cache[node] = always_visible; } + std::vector and_operands; + if (user_uses_node != always_visible) { + if (user_uses_node->op() == Op::kAnd) { + and_operands.insert(and_operands.end(), + user_uses_node->operands().begin(), + user_uses_node->operands().end()); + } else { + and_operands.push_back(user_uses_node); + } + } + if (user_is_used != always_visible) { + if (user_is_used->op() == Op::kAnd) { + and_operands.insert(and_operands.end(), + user_is_used->operands().begin(), + user_is_used->operands().end()); + } else { + and_operands.push_back(user_is_used); + } + } + Node* node_and_user_visible; - if (user_uses_node == always_visible) { - node_and_user_visible = user_is_used; - } else if (user_is_used == always_visible) { - node_and_user_visible = user_uses_node; + if (and_operands.empty()) { + node_and_user_visible = always_visible; + } else if (and_operands.size() == 1) { + node_and_user_visible = and_operands[0]; } else { XLS_ASSIGN_OR_RETURN( node_and_user_visible, - FindOrMakeBinaryNode(Op::kAnd, user_uses_node, user_is_used)); + func->MakeNode(SourceInfo(), and_operands, Op::kAnd)); + if (get_remaining_delay && + get_remaining_delay(node_and_user_visible) < 0) { + std::sort(and_operands.begin(), and_operands.end(), + [&](Node* a, Node* b) { + return get_remaining_delay(a) > get_remaining_delay(b); + }); + while (!and_operands.empty() && + get_remaining_delay(node_and_user_visible) < 0) { + and_operands.pop_back(); + if (and_operands.empty()) { + node_and_user_visible = always_visible; + break; + } + if (and_operands.size() == 1) { + node_and_user_visible = and_operands[0]; + } else { + XLS_ASSIGN_OR_RETURN( + node_and_user_visible, + func->MakeNode(SourceInfo(), and_operands, Op::kAnd)); + } + } + } } user_visibilities.insert(node_and_user_visible); } @@ -391,15 +529,30 @@ absl::StatusOr VisibilityBuilder::BuildVisibilityIRExprFromEdges( if (user_visibilities.size() == 1) { return node_to_visibility_ir_cache[node] = *user_visibilities.begin(); } - std::vector user_vis_vec{user_visibilities.begin(), - user_visibilities.end()}; - absl::c_sort(user_vis_vec, - [&](Node* a, Node* b) { return a->id() < b->id(); }); - Node* any_user_visible = user_vis_vec[0]; - for (int64_t i = 1; i < user_vis_vec.size(); ++i) { + std::vector user_vis_vec; + for (Node* n : user_visibilities) { + if (n != always_visible) { + if (n->op() == Op::kOr) { + user_vis_vec.insert(user_vis_vec.end(), n->operands().begin(), + n->operands().end()); + } else { + user_vis_vec.push_back(n); + } + } + } + + Node* any_user_visible; + if (user_vis_vec.empty()) { + any_user_visible = always_visible; + } else if (user_vis_vec.size() == 1) { + any_user_visible = user_vis_vec[0]; + } else { XLS_ASSIGN_OR_RETURN( any_user_visible, - FindOrMakeBinaryNode(Op::kOr, any_user_visible, user_vis_vec[i])); + func->MakeNode(SourceInfo(), user_vis_vec, Op::kOr)); + if (get_remaining_delay && get_remaining_delay(any_user_visible) < 0) { + any_user_visible = always_visible; + } } return node_to_visibility_ir_cache[node] = any_user_visible; } @@ -407,22 +560,80 @@ absl::StatusOr VisibilityBuilder::BuildVisibilityIRExprFromEdges( absl::StatusOr VisibilityBuilder::BuildVisibilityIRExpr( FunctionBase* func, Node* node, const absl::flat_hash_set& - conditional_edges) { + conditional_edges, + std::function is_live_source, + std::function get_remaining_delay, + std::optional target_stage) { + int64_t max_id_before = 0; + if (target_stage.has_value()) { + for (Node* n : func->nodes()) { + max_id_before = std::max(max_id_before, n->id()); + } + } + XLS_ASSIGN_OR_RETURN( Literal * always_visible, func->MakeNode(SourceInfo(), Value(UBits(1, 1)))); + + Node* result_expr = nullptr; if (conditional_edges.size() == 1) { XLS_ASSIGN_OR_RETURN( Node * user_uses_node, BuildVisibilityExpr(conditional_edges.begin()->operand, - conditional_edges.begin()->node, node, func)); - return user_uses_node ? user_uses_node : always_visible; - } - absl::flat_hash_map node_to_visibility_ir_cache; - absl::flat_hash_map, Node*> binary_op_cache; - return BuildVisibilityIRExprFromEdges(func, node, node, conditional_edges, - node_to_visibility_ir_cache, - always_visible); + conditional_edges.begin()->node, node, func, + is_live_source, get_remaining_delay)); + result_expr = user_uses_node ? user_uses_node : always_visible; + } else { + absl::flat_hash_map node_to_visibility_ir_cache; + absl::flat_hash_map, Node*> binary_op_cache; + XLS_ASSIGN_OR_RETURN( + result_expr, + BuildVisibilityIRExprFromEdges( + func, node, node, conditional_edges, node_to_visibility_ir_cache, + always_visible, is_live_source, get_remaining_delay)); + } + + if (target_stage.has_value()) { + for (Node* n : func->nodes()) { + if (n->id() > max_id_before) { + XLS_RETURN_IF_ERROR(func->AddNodeToStage(*target_stage, n).status()); + } + } + } + + return result_expr; +} + +absl::Status VisibilityBuilder::CleanUpUnusedNodes(FunctionBase* fb) { + std::vector worklist; + absl::flat_hash_set dead_nodes; + + for (Node* n : fb->nodes()) { + if (n->id() > prior_existing_id_ && n->IsDead()) { + dead_nodes.insert(n); + worklist.push_back(n); + } + } + + while (!worklist.empty()) { + Node* n = worklist.back(); + worklist.pop_back(); + + std::vector operands(n->operands().begin(), n->operands().end()); + + CHECK_GT(dead_nodes.erase(n), 0); + XLS_RETURN_IF_ERROR(fb->RemoveNode(n)); + + for (Node* operand : operands) { + if (operand->id() > prior_existing_id_ && operand->IsDead()) { + if (auto [_, inserted] = dead_nodes.insert(operand); inserted) { + worklist.push_back(operand); + } + } + } + } + + return absl::OkStatus(); } absl::StatusOr diff --git a/xls/passes/visibility_expr_builder.h b/xls/passes/visibility_expr_builder.h index 10a253c5e8..c2ff65cb50 100644 --- a/xls/passes/visibility_expr_builder.h +++ b/xls/passes/visibility_expr_builder.h @@ -16,6 +16,7 @@ #define XLS_PASSES_VISIBILITY_EXPR_BUILDER_H_ #include +#include #include #include @@ -100,38 +101,55 @@ class VisibilityBuilder : public ExpressionBuilder { absl::StatusOr BuildVisibilityIRExpr( FunctionBase* func, Node* node, const absl::flat_hash_set& - conditional_edges); + conditional_edges, + std::function is_live_source = nullptr, + std::function get_remaining_delay = nullptr, + std::optional target_stage = std::nullopt); + + // Cleans up any nodes created by this builder that have no users. + // This is useful for removing dead visibility expressions. + absl::Status CleanUpUnusedNodes(FunctionBase* fb); private: absl::StatusOr MakeParamIfTmpFunc(Node* node, FunctionBase* func) { return func == TmpFunc() ? TmpFuncNodeOrParam(node) : node; } - absl::StatusOr GetSelectorIfIndependent(Node* node, Node* select, - Node* source, - FunctionBase* func); + absl::StatusOr GetSelectorIfIndependent( + Node* node, Node* select, Node* source, FunctionBase* func, + const std::function& is_live_source); bool DoesCaseImplyNoPrevCase(PrioritySelect* select, int64_t case_index); absl::StatusOr GetVisibilityExprForPrioritySelect( - Node* node, PrioritySelect* select, Node* source, FunctionBase* func); - absl::StatusOr GetVisibilityExprForSelect(Node* node, Select* select, - Node* source, - FunctionBase* func); - absl::StatusOr GetVisibilityExprForAnd(Node* node, NaryOp* and_node, - Node* source, - FunctionBase* func); - absl::StatusOr GetVisibilityExprForOr(Node* node, NaryOp* or_node, - Node* source, - FunctionBase* func); + Node* node, PrioritySelect* select, Node* source, FunctionBase* func, + const std::function& is_live_source, + const std::function& get_remaining_delay); + absl::StatusOr GetVisibilityExprForSelect( + Node* node, Select* select, Node* source, FunctionBase* func, + const std::function& is_live_source, + const std::function& get_remaining_delay); + absl::StatusOr GetVisibilityExprForAnd( + Node* node, NaryOp* and_node, Node* source, FunctionBase* func, + const std::function& is_live_source, + const std::function& get_remaining_delay); + absl::StatusOr GetVisibilityExprForOr( + Node* node, NaryOp* or_node, Node* source, FunctionBase* func, + const std::function& is_live_source, + const std::function& get_remaining_delay); absl::StatusOr GetVisibilityExprForPredicate( - std::optional predicate, Node* source, FunctionBase* func); - - absl::StatusOr BuildVisibilityExprHelper(Node* node, Node* user, - Node* source, - FunctionBase* func); + std::optional predicate, Node* source, FunctionBase* func, + const std::function& is_live_source, + const std::function& get_remaining_delay); + + absl::StatusOr BuildVisibilityExprHelper( + Node* node, Node* user, Node* source, FunctionBase* func, + const std::function& is_live_source, + const std::function& get_remaining_delay); // Builds predicate for node `u` being used by `v` on `func`. - absl::StatusOr BuildVisibilityExpr(Node* node, Node* user, - Node* source, FunctionBase* func); + absl::StatusOr BuildVisibilityExpr( + Node* node, Node* user, Node* source, FunctionBase* func, + const std::function& is_live_source, + const std::function& get_remaining_delay); absl::StatusOr BuildNodeAndUserVisibleExpr(FunctionBase* func, Node* user_uses_node, Node* user_is_used, @@ -142,10 +160,12 @@ class VisibilityBuilder : public ExpressionBuilder { const absl::flat_hash_set& conditional_edges, absl::flat_hash_map& node_to_visibility_ir_cache, - Literal* always_visible); + Literal* always_visible, const std::function& is_live_source, + const std::function& get_remaining_delay); - absl::StatusOr GetNonRepeatedSourceOf(Node* operand, - FunctionBase* func); + absl::StatusOr GetNonRepeatedSourceOf( + Node* operand, FunctionBase* func, + const std::function& is_live_source); }; class VisibilityEstimator : public VisibilityBuilder { diff --git a/xls/passes/visibility_expr_builder_test.cc b/xls/passes/visibility_expr_builder_test.cc index 81c0f87b32..716b00f095 100644 --- a/xls/passes/visibility_expr_builder_test.cc +++ b/xls/passes/visibility_expr_builder_test.cc @@ -14,6 +14,7 @@ #include "xls/passes/visibility_expr_builder.h" +#include #include #include #include @@ -33,6 +34,9 @@ #include "xls/ir/function_builder.h" #include "xls/ir/ir_matcher.h" #include "xls/ir/ir_test_base.h" +#include "xls/ir/nodes.h" +#include "xls/ir/op.h" +#include "xls/ir/scheduled_builder.h" #include "xls/passes/bdd_query_engine.h" #include "xls/passes/bit_provenance_analysis.h" #include "xls/passes/node_dependency_analysis.h" @@ -113,20 +117,19 @@ TEST_F(VisibilityExprBuilderTest, ExampleInFunctionHeaderComment) { VLOG(3) << "is 'y' used:\n" << ToMathNotation(is_y_used.first); EXPECT_THAT(is_x_used.first, - m::Or(m::Or(m::Eq(m::Param("op1"), m::Literal(0)), - m::Eq(m::Param("op1"), m::Literal(2))), + m::Or(m::Eq(m::Param("op1"), m::Literal(0)), + m::Eq(m::Param("op1"), m::Literal(2)), m::And(m::ULt(m::Param("op2"), m::Literal(5)), m::UGe(m::Param("op1"), m::Literal(3))))); - // 6 instead of 7 because of the re-use of the ult. - EXPECT_EQ(is_x_used.second.area, 6); + // 5 instead of 6 because of flat OR gate consolidation. + EXPECT_EQ(is_x_used.second.area, 5); EXPECT_EQ(is_x_used.second.delay, 3); - EXPECT_THAT(is_y_used.first, - m::Or(m::Or(m::Eq(m::Param("op1"), m::Literal(1)), - m::UGe(m::Param("op1"), m::Literal(3))), - m::Or(m::Eq(m::Param("op1"), m::Literal(0)), - m::Eq(m::Param("op1"), m::Literal(2))))); - EXPECT_EQ(is_y_used.second.area, 7); - EXPECT_EQ(is_y_used.second.delay, 3); + EXPECT_THAT(is_y_used.first, m::Or(m::Eq(m::Param("op1"), m::Literal(1)), + m::UGe(m::Param("op1"), m::Literal(3)), + m::Eq(m::Param("op1"), m::Literal(0)), + m::Eq(m::Param("op1"), m::Literal(2)))); + EXPECT_EQ(is_y_used.second.area, 5); + EXPECT_EQ(is_y_used.second.delay, 2); // Now that the returned expression must be mutually exclusive with z's // visibility, it must condition on the selection criteria of 'select2'. @@ -136,11 +139,11 @@ TEST_F(VisibilityExprBuilderTest, ExampleInFunctionHeaderComment) { VLOG(3) << "is 'x' used and 'z' not used:\n" << ToMathNotation(is_x_used_and_z_not.first); EXPECT_THAT(is_x_used_and_z_not.first, - m::Or(m::Or(m::Eq(m::Param("op1"), m::Literal(0)), - m::Eq(m::Param("op1"), m::Literal(2))), + m::Or(m::Eq(m::Param("op1"), m::Literal(0)), + m::Eq(m::Param("op1"), m::Literal(2)), m::UGe(m::Param("op1"), m::Literal(3)))); - EXPECT_EQ(is_x_used_and_z_not.second.area, 5); - EXPECT_EQ(is_x_used_and_z_not.second.delay, 3); + EXPECT_EQ(is_x_used_and_z_not.second.area, 4); + EXPECT_EQ(is_x_used_and_z_not.second.delay, 2); } TEST_F(VisibilityExprBuilderTest, PrioritySelectOneHot) { @@ -207,9 +210,9 @@ TEST_F(VisibilityExprBuilderTest, Ors) { std::pair is_y_used; XLS_ASSERT_OK_AND_ASSIGN(is_y_used, BuildDefaultVisibilityExpr(f, y.node(), {})); - ASSERT_THAT(is_y_used.first, - m::Ne(m::BitSlice(m::Param("x"), 1, 2), m::Literal(3))); - EXPECT_EQ(is_y_used.first->operand(0), bits12.node()); + EXPECT_THAT(is_y_used.first, + m::Or(m::Not(m::BitSlice(m::Param("x"), 1, 1)), + m::Ne(m::BitSlice(m::Param("x"), 1, 2), m::Literal(3)))); } TEST_F(VisibilityExprBuilderTest, Ands) { @@ -226,7 +229,9 @@ TEST_F(VisibilityExprBuilderTest, Ands) { std::pair is_y_used; XLS_ASSERT_OK_AND_ASSIGN(is_y_used, BuildDefaultVisibilityExpr(f, y.node(), {})); - EXPECT_EQ(is_y_used.first, bit1.node()); + EXPECT_THAT(is_y_used.first, + m::And(m::Ne(m::BitSlice(m::Param("x"), 1, 2), m::Literal(0)), + m::BitSlice(m::Param("x"), 1, 1))); } TEST_F(VisibilityExprBuilderTest, FindsSourceOfOperandInComparison) { @@ -260,5 +265,330 @@ TEST_F(VisibilityExprBuilderTest, NotAFunctionOfSelf) { EXPECT_THAT(is_x_used.first, m::Ne(m::Param("y"), m::Literal(7))); } +TEST_F(VisibilityExprBuilderTest, LivenessHalting) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + BValue x = fb.Param("x", p->GetBitsType(32)); + BValue y = fb.Param("y", p->GetBitsType(32)); + BValue z = fb.Param("z", p->GetBitsType(32)); + BValue op1 = fb.Param("op1", p->GetBitsType(4)); + BValue op2 = fb.Param("op2", p->GetBitsType(4)); + BValue select1 = fb.Select(op1, {x, y, x}, y); + BValue lt1 = fb.ULt(op2, fb.Literal(UBits(5, 4))); + BValue and1 = fb.And(x, fb.SignExtend(lt1, 32)); + BValue select2 = fb.Select(op1, {y, z, y}, and1); + BValue ret = fb.Tuple({select1, select2}); + XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.BuildWithReturnValue(ret)); + + NodeForwardDependencyAnalysis nda; + XLS_ASSERT_OK(nda.Attach(f).status()); + LazyPostDominatorAnalysis post_dom; + XLS_ASSERT_OK(post_dom.Attach(f).status()); + std::unique_ptr bdd_engine = BddQueryEngine::MakeDefault(); + XLS_ASSERT_OK(bdd_engine->Populate(f).status()); + XLS_ASSERT_OK_AND_ASSIGN( + auto operand_visibility, + OperandVisibilityAnalysis::Create(&nda, bdd_engine.get())); + XLS_ASSERT_OK_AND_ASSIGN( + auto visibility, VisibilityAnalysis::Create(&operand_visibility, + bdd_engine.get(), &post_dom)); + XLS_ASSERT_OK_AND_ASSIGN( + auto conditional_edges, + visibility->GetEdgesForMutuallyExclusiveVisibilityExpr(x.node(), {}, -1)); + + auto last_node_id = f->nodes_reversed().begin()->id(); + XLS_ASSERT_OK_AND_ASSIGN(AreaEstimator * ae, GetAreaEstimator("unit")); + XLS_ASSERT_OK_AND_ASSIGN(DelayEstimator * de, GetDelayEstimator("unit")); + BitProvenanceAnalysis bpa; + VisibilityEstimator estimator(last_node_id, bdd_engine.get(), nda, bpa, ae, + de); + + // A callback that says 'and1' is not live! + auto is_live_source = [&](Node* n) { return n != and1.node(); }; + + XLS_ASSERT_OK_AND_ASSIGN(Node * expr, + estimator.BuildVisibilityIRExpr( + f, x.node(), conditional_edges, is_live_source)); + + // Because and1 is ignored, we only accumulate its immediate condition (lt1 != + // 0) combined with the conditions from select1! + EXPECT_THAT(expr, m::Or(m::ULt(m::Param("op2"), m::Literal(5)), + m::Eq(m::Param("op1"), m::Literal(0)), + m::Eq(m::Param("op1"), m::Literal(2)))); +} + +TEST_F(VisibilityExprBuilderTest, AndPruning) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + BValue y = fb.Param("y", p->GetBitsType(1)); + BValue a = fb.Param("a", p->GetBitsType(1)); + BValue b = fb.Param("b", p->GetBitsType(1)); + BValue c = fb.Param("c", p->GetBitsType(1)); + + BValue and_gate = fb.And({a, b, c, y}); + XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.BuildWithReturnValue(and_gate)); + + NodeForwardDependencyAnalysis nda; + XLS_ASSERT_OK(nda.Attach(f).status()); + LazyPostDominatorAnalysis post_dom; + XLS_ASSERT_OK(post_dom.Attach(f).status()); + std::unique_ptr bdd_engine = BddQueryEngine::MakeDefault(); + XLS_ASSERT_OK(bdd_engine->Populate(f).status()); + XLS_ASSERT_OK_AND_ASSIGN( + auto operand_visibility, + OperandVisibilityAnalysis::Create(&nda, bdd_engine.get())); + XLS_ASSERT_OK_AND_ASSIGN( + auto visibility, VisibilityAnalysis::Create(&operand_visibility, + bdd_engine.get(), &post_dom)); + XLS_ASSERT_OK_AND_ASSIGN( + auto conditional_edges, + visibility->GetEdgesForMutuallyExclusiveVisibilityExpr(y.node(), {}, -1)); + + auto last_node_id = f->nodes_reversed().begin()->id(); + XLS_ASSERT_OK_AND_ASSIGN(AreaEstimator * ae, GetAreaEstimator("unit")); + XLS_ASSERT_OK_AND_ASSIGN(DelayEstimator * de, GetDelayEstimator("unit")); + BitProvenanceAnalysis bpa; + VisibilityEstimator estimator(last_node_id, bdd_engine.get(), nda, bpa, ae, + de); + + // Mock slacks: we make 'c' have worst slack (10), then 'b' (20), then 'a' + // (30). + auto get_remaining_delay = [&](Node* n) -> int64_t { + if (n == c.node()) { + return 10; + } + if (n == b.node()) { + return 20; + } + if (n == a.node()) { + return 30; + } + + // For the resulting AND, if it includes all three expected operands, we + // return -1 (exceeds limit!). Otherwise, we return a positive slack. + if (n->op() == Op::kAnd && n->operands().size() == 3) { + return -1; + } + if (n->op() == Op::kAnd && n->operands().size() == 2) { + return 5; + } + return 100; + }; + + XLS_ASSERT_OK_AND_ASSIGN( + Node * expr, estimator.BuildVisibilityIRExpr( + f, y.node(), conditional_edges, + /*is_live_source=*/nullptr, get_remaining_delay)); + + // The expected result should omit 'c'! So it should be AND(a, b)! + EXPECT_THAT(expr, m::And(m::Param("a"), m::Param("b"))); +} + +TEST_F(VisibilityExprBuilderTest, OrPruning) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + BValue x = fb.Param("x", p->GetBitsType(3)); + BValue y = fb.Param("y", p->GetBitsType(2)); + BValue bit1 = fb.BitSlice(x, 1, 1); + BValue or_y = fb.Or(fb.SignExtend(bit1, 2), y); + BValue bits12 = fb.BitSlice(x, 1, 2); + BValue or_y2 = fb.Or(bits12, y); + XLS_ASSERT_OK_AND_ASSIGN(Function * f, + fb.BuildWithReturnValue(fb.Tuple({or_y, or_y2}))); + + NodeForwardDependencyAnalysis nda; + XLS_ASSERT_OK(nda.Attach(f).status()); + LazyPostDominatorAnalysis post_dom; + XLS_ASSERT_OK(post_dom.Attach(f).status()); + std::unique_ptr bdd_engine = BddQueryEngine::MakeDefault(); + XLS_ASSERT_OK(bdd_engine->Populate(f).status()); + XLS_ASSERT_OK_AND_ASSIGN( + auto operand_visibility, + OperandVisibilityAnalysis::Create(&nda, bdd_engine.get())); + XLS_ASSERT_OK_AND_ASSIGN( + auto visibility, VisibilityAnalysis::Create(&operand_visibility, + bdd_engine.get(), &post_dom)); + XLS_ASSERT_OK_AND_ASSIGN( + auto conditional_edges, + visibility->GetEdgesForMutuallyExclusiveVisibilityExpr(y.node(), {}, -1)); + + auto last_node_id = f->nodes_reversed().begin()->id(); + XLS_ASSERT_OK_AND_ASSIGN(AreaEstimator * ae, GetAreaEstimator("unit")); + XLS_ASSERT_OK_AND_ASSIGN(DelayEstimator * de, GetDelayEstimator("unit")); + BitProvenanceAnalysis bpa; + VisibilityEstimator estimator(last_node_id, bdd_engine.get(), nda, bpa, ae, + de); + + // Mock delay limit! If the OR node contains conditions from bits12 (which + // corresponds to user 2), it will exceed timing slack limits! + auto get_remaining_delay = [&](Node* n) -> int64_t { + if (n->op() == Op::kOr) { + for (Node* operand : n->operands()) { + if (operand->Is() && operand->operand(0) == bits12.node()) { + return -1; + } + } + } + return 100; + }; + + XLS_ASSERT_OK_AND_ASSIGN( + Node * expr, estimator.BuildVisibilityIRExpr( + f, y.node(), conditional_edges, + /*is_live_source=*/nullptr, get_remaining_delay)); + + // The expected result should be true because the operation is pruned to + // literal(1)! + EXPECT_THAT(expr, m::Literal(1)); +} + +TEST_F(VisibilityExprBuilderTest, SelectPruning) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + BValue x = fb.Param("x", p->GetBitsType(32)); + BValue y = fb.Param("y", p->GetBitsType(32)); + BValue op = fb.Param("op", p->GetBitsType(1)); + BValue sel = fb.Select(op, {x}, y); + XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.BuildWithReturnValue(sel)); + + NodeForwardDependencyAnalysis nda; + XLS_ASSERT_OK(nda.Attach(f).status()); + LazyPostDominatorAnalysis post_dom; + XLS_ASSERT_OK(post_dom.Attach(f).status()); + std::unique_ptr bdd_engine = BddQueryEngine::MakeDefault(); + XLS_ASSERT_OK(bdd_engine->Populate(f).status()); + XLS_ASSERT_OK_AND_ASSIGN( + auto operand_visibility, + OperandVisibilityAnalysis::Create(&nda, bdd_engine.get())); + XLS_ASSERT_OK_AND_ASSIGN( + auto visibility, VisibilityAnalysis::Create(&operand_visibility, + bdd_engine.get(), &post_dom)); + XLS_ASSERT_OK_AND_ASSIGN( + auto conditional_edges, + visibility->GetEdgesForMutuallyExclusiveVisibilityExpr(x.node(), {}, -1)); + + auto last_node_id = f->nodes_reversed().begin()->id(); + XLS_ASSERT_OK_AND_ASSIGN(AreaEstimator * ae, GetAreaEstimator("unit")); + XLS_ASSERT_OK_AND_ASSIGN(DelayEstimator * de, GetDelayEstimator("unit")); + BitProvenanceAnalysis bpa; + VisibilityEstimator estimator(last_node_id, bdd_engine.get(), nda, bpa, ae, + de); + + auto get_remaining_delay = [&](Node* n) -> int64_t { + if (n->op() == Op::kOr || n->op() == Op::kEq) { + return -1; + } + return 100; + }; + + XLS_ASSERT_OK_AND_ASSIGN( + Node * expr, estimator.BuildVisibilityIRExpr( + f, x.node(), conditional_edges, + /*is_live_source=*/nullptr, get_remaining_delay)); + + EXPECT_THAT(expr, m::Literal(1)); +} + +TEST_F(VisibilityExprBuilderTest, PrioritySelectPruning) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + BValue x = fb.Param("x", p->GetBitsType(32)); + BValue y = fb.Param("y", p->GetBitsType(32)); + BValue sel = fb.Param("sel", p->GetBitsType(2)); + BValue prio = fb.PrioritySelect(sel, {x, y}, fb.Literal(UBits(0, 32))); + XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.BuildWithReturnValue(prio)); + + NodeForwardDependencyAnalysis nda; + XLS_ASSERT_OK(nda.Attach(f).status()); + LazyPostDominatorAnalysis post_dom; + XLS_ASSERT_OK(post_dom.Attach(f).status()); + std::unique_ptr bdd_engine = BddQueryEngine::MakeDefault(); + XLS_ASSERT_OK(bdd_engine->Populate(f).status()); + XLS_ASSERT_OK_AND_ASSIGN( + auto operand_visibility, + OperandVisibilityAnalysis::Create(&nda, bdd_engine.get())); + XLS_ASSERT_OK_AND_ASSIGN( + auto visibility, VisibilityAnalysis::Create(&operand_visibility, + bdd_engine.get(), &post_dom)); + XLS_ASSERT_OK_AND_ASSIGN( + auto conditional_edges, + visibility->GetEdgesForMutuallyExclusiveVisibilityExpr(x.node(), {}, -1)); + + auto last_node_id = f->nodes_reversed().begin()->id(); + XLS_ASSERT_OK_AND_ASSIGN(AreaEstimator * ae, GetAreaEstimator("unit")); + XLS_ASSERT_OK_AND_ASSIGN(DelayEstimator * de, GetDelayEstimator("unit")); + BitProvenanceAnalysis bpa; + VisibilityEstimator estimator(last_node_id, bdd_engine.get(), nda, bpa, ae, + de); + + auto get_remaining_delay = [&](Node* n) -> int64_t { + if (n->op() == Op::kOr || n->op() == Op::kEq || n->op() == Op::kBitSlice) { + return -1; + } + return 100; + }; + + XLS_ASSERT_OK_AND_ASSIGN( + Node * expr, estimator.BuildVisibilityIRExpr( + f, x.node(), conditional_edges, + /*is_live_source=*/nullptr, get_remaining_delay)); + + EXPECT_THAT(expr, m::Literal(1)); +} + +TEST_F(VisibilityExprBuilderTest, TargetStageEnforcement) { + auto p = CreatePackage(); + ScheduledFunctionBuilder fb(TestName(), p.get()); + BValue x = fb.Param("x", p->GetBitsType(32)); + BValue y = fb.Param("y", p->GetBitsType(32)); + BValue select1 = fb.Select(fb.Literal(UBits(1, 1)), {x}, y); + XLS_ASSERT_OK_AND_ASSIGN(ScheduledFunction * f, + fb.BuildWithReturnValue(select1)); + + // Add empty stages to make it a scheduled function! + f->AddEmptyStages(3); + + NodeForwardDependencyAnalysis nda; + XLS_ASSERT_OK(nda.Attach(f).status()); + LazyPostDominatorAnalysis post_dom; + XLS_ASSERT_OK(post_dom.Attach(f).status()); + std::unique_ptr bdd_engine = BddQueryEngine::MakeDefault(); + XLS_ASSERT_OK(bdd_engine->Populate(f).status()); + XLS_ASSERT_OK_AND_ASSIGN( + auto operand_visibility, + OperandVisibilityAnalysis::Create(&nda, bdd_engine.get())); + XLS_ASSERT_OK_AND_ASSIGN( + auto visibility, VisibilityAnalysis::Create(&operand_visibility, + bdd_engine.get(), &post_dom)); + XLS_ASSERT_OK_AND_ASSIGN( + auto conditional_edges, + visibility->GetEdgesForMutuallyExclusiveVisibilityExpr(x.node(), {}, -1)); + + auto last_node_id = f->nodes_reversed().begin()->id(); + XLS_ASSERT_OK_AND_ASSIGN(AreaEstimator * ae, GetAreaEstimator("unit")); + XLS_ASSERT_OK_AND_ASSIGN(DelayEstimator * de, GetDelayEstimator("unit")); + BitProvenanceAnalysis bpa; + VisibilityEstimator estimator(last_node_id, bdd_engine.get(), nda, bpa, ae, + de); + + int64_t kTargetStage = 2; + XLS_ASSERT_OK_AND_ASSIGN( + Node * expr, + estimator.BuildVisibilityIRExpr( + f, x.node(), conditional_edges, /*is_live_source=*/nullptr, + /*get_remaining_delay=*/nullptr, kTargetStage)); + (void)expr; + + // All nodes created during BuildVisibilityIRExpr should be assigned to stage + // 2! + for (Node* n : f->nodes()) { + if (n->id() > last_node_id) { + EXPECT_TRUE(f->IsStaged(n)); + XLS_ASSERT_OK_AND_ASSIGN(int64_t stage, f->GetStageIndex(n)); + EXPECT_EQ(stage, kTargetStage); + } + } +} + } // namespace } // namespace xls