diff --git a/xls/codegen/codegen_options.h b/xls/codegen/codegen_options.h index ca5b9321cf..aa17d09d69 100644 --- a/xls/codegen/codegen_options.h +++ b/xls/codegen/codegen_options.h @@ -16,6 +16,7 @@ #define XLS_CODEGEN_CODEGEN_OPTIONS_H_ #include +#include #include #include #include @@ -37,11 +38,32 @@ namespace xls::verilog { +class CodegenOptionExtension { + public: + virtual ~CodegenOptionExtension() = default; + virtual std::string_view extension_name() const = 0; +}; + // Options describing how codegen should be performed. class CodegenOptions { public: explicit CodegenOptions() = default; + template + const T* get_extension() const { + for (const auto& ext : extensions_) { + if (ext->extension_name() == T::kExtensionName) { + return static_cast(ext.get()); + } + } + return nullptr; + } + + CodegenOptions& add_extension(std::unique_ptr ext) { + extensions_.push_back(std::move(ext)); + return *this; + } + // Enum to describe which codegen version to use. enum class Version : uint8_t { kDefault = 0, @@ -461,6 +483,7 @@ class CodegenOptions { int64_t max_trace_verbosity_ = 0; RegisterMergeStrategy register_merge_strategy_ = RegisterMergeStrategy::kDefault; + SourceAnnotationStrategy source_annotation_strategy_ = SourceAnnotationStrategy::kNone; std::optional package_interface_; @@ -474,6 +497,8 @@ class CodegenOptions { std::vector randomize_order_seed_; std::optional residual_data_; std::optional ir_dump_path_; + + std::vector> extensions_; }; template diff --git a/xls/codegen/module_builder.cc b/xls/codegen/module_builder.cc index 49f93f110b..fee3fde9d1 100644 --- a/xls/codegen/module_builder.cc +++ b/xls/codegen/module_builder.cc @@ -1247,6 +1247,13 @@ absl::StatusOr ModuleBuilder::EmitGate( XLS_ASSIGN_OR_RETURN(LogicRef * ref, DeclareVariable(gate->GetName(), gate->GetType())); + if (gate->gate_type() == GateType::kIgnorableGate) { + // Ignorable gates don't perform zero-clamping; they can be evaluated as + // identity operations. + XLS_RETURN_IF_ERROR(Assign(ref, data, gate->GetType())); + return ref; + } + // Emit the gate as an AND of the (potentially replicated) condition and the // data. For example: // diff --git a/xls/codegen/verilog_conversion.cc b/xls/codegen/verilog_conversion.cc index caa9b58cf4..6b21d619c0 100644 --- a/xls/codegen/verilog_conversion.cc +++ b/xls/codegen/verilog_conversion.cc @@ -652,6 +652,13 @@ class BlockGenerator { std::optional op_override = options_.GetOpOverride(node->op()); + // Ignorable-value gates are emitted as identity, so they don't need to + // support overrides. + if (node->op() == Op::kGate && + node->As()->gate_type() == GateType::kIgnorableGate) { + op_override = std::nullopt; + } + if (op_override.has_value()) { std::vector inputs; for (const Node* operand : node->operands()) { diff --git a/xls/codegen/verilog_conversion_test.cc b/xls/codegen/verilog_conversion_test.cc index 4947029410..d1fd402333 100644 --- a/xls/codegen/verilog_conversion_test.cc +++ b/xls/codegen/verilog_conversion_test.cc @@ -2400,12 +2400,15 @@ class ZeroWidthVerilogConversionTest } std::filesystem::path GoldenFilePath( std::string_view test_file_name, - const std::filesystem::path& testdata_dir) override { + const std::filesystem::path& testdata_dir, + std::optional forced_test_base_name = + std::nullopt) override { // We suffix the golden reference files with "txt" on top of the extension // just to indicate they're compiler byproduct comparison points and not // Verilog files that have been written by hand. std::string filename = absl::StrCat( - test_file_name, "_", TestBaseName(), "Input", + test_file_name, "_", forced_test_base_name.value_or(TestBaseName()), + "Input", ParameterizedFloppingName(std::get<0>(std::get<1>(GetParam()))), "Output", ParameterizedFloppingName(std::get<1>(std::get<1>(GetParam()))), ".", diff --git a/xls/codegen_v_1_5/BUILD b/xls/codegen_v_1_5/BUILD index 78e4f005cb..bb2558dfa0 100644 --- a/xls/codegen_v_1_5/BUILD +++ b/xls/codegen_v_1_5/BUILD @@ -61,6 +61,7 @@ cc_library( hdrs = ["block_conversion_pass.h"], deps = [ "//xls/codegen:codegen_options", + "//xls/estimators/delay_model:delay_estimator", "//xls/passes:optimization_pass", "//xls/passes:pass_base", "//xls/scheduling:pipeline_schedule_cc_proto", diff --git a/xls/codegen_v_1_5/block_conversion_pass.h b/xls/codegen_v_1_5/block_conversion_pass.h index 86ae3059fe..e06ce8200f 100644 --- a/xls/codegen_v_1_5/block_conversion_pass.h +++ b/xls/codegen_v_1_5/block_conversion_pass.h @@ -16,6 +16,7 @@ #define XLS_CODEGEN_V_1_5_BLOCK_CONVERSION_PASS_H_ #include "xls/codegen/codegen_options.h" +#include "xls/estimators/delay_model/delay_estimator.h" #include "xls/passes/optimization_pass.h" #include "xls/passes/pass_base.h" #include "xls/scheduling/pipeline_schedule.pb.h" @@ -28,6 +29,9 @@ namespace xls::codegen { struct BlockConversionPassOptions : public PassOptionsBase { verilog::CodegenOptions codegen_options; PackageScheduleProto package_schedule; + + // The delay estimator used for scheduling & schedule-checking purposes. + const DelayEstimator* delay_estimator = nullptr; }; struct BlockConversionContext { diff --git a/xls/codegen_v_1_5/convert_to_block.cc b/xls/codegen_v_1_5/convert_to_block.cc index 58a1981b23..2ed26e1424 100644 --- a/xls/codegen_v_1_5/convert_to_block.cc +++ b/xls/codegen_v_1_5/convert_to_block.cc @@ -58,6 +58,7 @@ absl::Status ConvertToBlock( BlockConversionPassOptions options{ .codegen_options = std::move(codegen_options), .package_schedule = std::move(schedule), + .delay_estimator = delay_estimator, }; XLS_ASSIGN_OR_RETURN(std::unique_ptr pipeline, diff --git a/xls/codegen_v_1_5/pipeline_register_insertion_pass.cc b/xls/codegen_v_1_5/pipeline_register_insertion_pass.cc index 309179ec5e..0d72ffd6ee 100644 --- a/xls/codegen_v_1_5/pipeline_register_insertion_pass.cc +++ b/xls/codegen_v_1_5/pipeline_register_insertion_pass.cc @@ -96,6 +96,23 @@ absl::StatusOr CreatePipelineRegister( block->AddRegister(name, node->GetType(), reset_value)); Node* load_enable = stage_done; + + if (node->Is() && + node->As()->gate_type() == GateType::kIgnorableGate) { + // Ignorable gates are used to indicate that the node is conditionally + // visible. We can strengthen the logic for gating the pipeline register by + // ANDing the stage done signal with the gate's predicate, reducing the + // number of times the register needs to be updated. + Gate* gate = node->As(); + Node* predicate = gate->condition(); + node = gate->data(); + XLS_ASSIGN_OR_RETURN( + load_enable, + block->MakeNode(node->loc(), + absl::MakeConstSpan({stage_done, predicate}), + Op::kAnd)); + } + if (block->GetResetPort().has_value() && options.codegen_options.reset().has_value() && !options.codegen_options.reset()->reset_data_path()) { @@ -109,9 +126,9 @@ absl::StatusOr CreatePipelineRegister( } XLS_ASSIGN_OR_RETURN( load_enable, - block->MakeNode(node->loc(), - absl::MakeConstSpan({stage_done, reset_active}), - Op::kOr)); + block->MakeNode( + node->loc(), absl::MakeConstSpan({load_enable, reset_active}), + Op::kOr)); } // NOTE: The RegisterWrite is added to the block, but not to the stage. Its // `load_enable` depends on `outputs_ready`, which comes from outside @@ -139,7 +156,13 @@ absl::StatusOr CreatePipelineRegister( absl::StatusOr AddPipelineRegisterFor( Node* node, int64_t stage_index, Node* stage_done, ScheduledBlock* block, const BlockConversionPassOptions& options) { - std::string base_name = PipelineSignalName(node->GetName(), stage_index); + Node* name_source = node; + if (node->Is() && + node->As()->gate_type() == GateType::kIgnorableGate) { + name_source = node->As()->data(); + } + std::string base_name = + PipelineSignalName(name_source->GetName(), stage_index); // As a special case, check if the node is a tuple // containing types that are of zero-width. If so, separate them out so diff --git a/xls/ir/function.cc b/xls/ir/function.cc index 45157be200..83f49daf0c 100644 --- a/xls/ir/function.cc +++ b/xls/ir/function.cc @@ -135,17 +135,12 @@ std::string Function::DumpIr(const IrAnnotator& annotate) const { private: Node* return_value_; }; - if (IsScheduled()) { - absl::StrAppend(&res, - DumpFunctionBaseNodes(IrAnnotatorJoiner( - SkipParamsAnnotator{}, AddRetAnnotator{return_value()}, - IrAnnotatorRef(annotate)))); - } else { - absl::StrAppend( - &res, - DumpFunctionBaseNodes(IrAnnotatorJoiner( - SkipParamsAnnotator{}, AddRetAnnotator{return_value()}, - IrAnnotatorRef(annotate), TopoSortAnnotator(!is_block_source_)))); + absl::StrAppend( + &res, + DumpFunctionBaseNodes(IrAnnotatorJoiner( + SkipParamsAnnotator{}, AddRetAnnotator{return_value()}, + IrAnnotatorRef(annotate), TopoSortAnnotator(!is_block_source_)))); + if (!IsScheduled()) { // Need to add the 'ret' specially if the ret is also a parameter. if (return_value() != nullptr && return_value()->op() == Op::kParam) { absl::StrAppendFormat(&res, " ret %s\n", diff --git a/xls/ir/function_base.cc b/xls/ir/function_base.cc index 4c422affee..3bd3854067 100644 --- a/xls/ir/function_base.cc +++ b/xls/ir/function_base.cc @@ -64,10 +64,33 @@ namespace { // sections may be empty after it runs. class StageSectioner : public DfsVisitorWithDefault { public: - explicit StageSectioner(const FunctionBase* fb) - : fb_(fb), sections_(fb->stages().size() * 2 + 1) {} + struct Section { + std::optional stage_index = std::nullopt; + std::list nodes; - absl::Status Run() { + bool IsStage() const { return stage_index.has_value(); } + + bool empty() const { return nodes.empty(); } + + auto begin() const { return nodes.begin(); } + auto end() const { return nodes.end(); } + + auto cbegin() const { return nodes.cbegin(); } + auto cend() const { return nodes.cend(); } + + void push_back(Node* node) { nodes.push_back(node); } + }; + + explicit StageSectioner(const FunctionBase* fb) : fb_(fb) { + sections_.reserve(fb->stages().size() * 2 + 1); + for (int i = 0; i < fb_->stages().size(); i++) { + sections_.push_back(Section{.stage_index = std::nullopt}); + sections_.push_back(Section{.stage_index = i}); + } + sections_.push_back(Section{.stage_index = std::nullopt}); + } + + absl::Status Run(std::optional> order = std::nullopt) { // Capture stages and their deps. for (int i = 0; i < fb_->stages().size(); i++) { const Stage& stage = fb_->stages()[i]; @@ -107,6 +130,26 @@ class StageSectioner : public DfsVisitorWithDefault { XLS_RETURN_IF_ERROR(node->Accept(this)); } } + + // If provided, we use the given order to sort within each section. + if (order.has_value()) { + std::optional> node_to_index; + node_to_index.emplace(); + node_to_index->reserve(order->size()); + for (int i = 0; i < order->size(); ++i) { + node_to_index->emplace(order->at(i), i); + } + + for (Section& section : sections_) { + CHECK(absl::c_all_of(section.nodes, [&](Node* node) { + return node_to_index->contains(node); + })) << "Some nodes not found in order"; + section.nodes.sort([&](Node* a, Node* b) { + return node_to_index->at(a) < node_to_index->at(b); + }); + } + } + return absl::OkStatus(); } @@ -126,11 +169,11 @@ class StageSectioner : public DfsVisitorWithDefault { return absl::OkStatus(); } - const std::vector>& sections() const { return sections_; } + const std::vector
& sections() const { return sections_; } private: const FunctionBase* fb_; - std::vector> sections_; + std::vector
sections_; int64_t current_stage_; absl::flat_hash_set handled_; }; @@ -266,15 +309,12 @@ std::string FunctionBase::DumpFunctionBaseNodes( const IrAnnotator& annotate) const { std::string res; if (IsScheduled()) { - CHECK(!annotate.NodeOrder(const_cast(this))) - << "Cannot use custom node order for scheduled entities"; StageSectioner sectioner(this); - CHECK(sectioner.Run().ok()); - for (const std::list& section : sectioner.sections()) { - if (section.empty()) { - continue; - } - + // In a scheduled entity, we use the annotator's order to sort within each + // section. + CHECK(sectioner.Run(annotate.NodeOrder(const_cast(this))) + .ok()); + for (const StageSectioner::Section& section : sectioner.sections()) { struct ScheduledAnnotations : public IrAnnotator { public: ScheduledAnnotations(Node* outputs_valid, Node* active_inputs_valid) @@ -294,9 +334,8 @@ std::string FunctionBase::DumpFunctionBaseNodes( Node* active_inputs_valid_; }; - Node* first = *section.begin(); - if (IsStaged(first)) { - const Stage& stage = stages()[*GetStageIndex(first)]; + if (section.IsStage()) { + const Stage& stage = stages()[*section.stage_index]; IrAnnotatorJoiner joiner( ScheduledAnnotations(stage.outputs_valid(), stage.active_inputs_valid()), diff --git a/xls/ir/function_base.h b/xls/ir/function_base.h index bcd89105ec..6caf311013 100644 --- a/xls/ir/function_base.h +++ b/xls/ir/function_base.h @@ -102,6 +102,10 @@ class Stage { active_outputs_.erase(node); } + bool empty() const { + return active_inputs_.empty() && logic_.empty() && active_outputs_.empty(); + } + inline auto begin() { return iter::chain(active_inputs_, logic_, active_outputs_).begin(); } diff --git a/xls/ir/function_builder.cc b/xls/ir/function_builder.cc index cbcbdbfe12..688f9794c3 100644 --- a/xls/ir/function_builder.cc +++ b/xls/ir/function_builder.cc @@ -1821,8 +1821,8 @@ BValue BuilderBase::Cover(BValue condition, std::string_view label, /*original_label=*/std::nullopt, name); } -BValue BuilderBase::Gate(BValue condition, BValue data, const SourceInfo& loc, - std::string_view name) { +BValue BuilderBase::Gate(BValue condition, BValue data, GateType gate_type, + const SourceInfo& loc, std::string_view name) { if (ErrorPending()) { return BValue(); } @@ -1834,7 +1834,8 @@ BValue BuilderBase::Gate(BValue condition, BValue data, const SourceInfo& loc, condition.GetType()->ToString()), loc); } - return AddNode(loc, condition.node(), data.node(), name); + return AddNode(loc, condition.node(), data.node(), gate_type, + name); } BValue TokenlessProcBuilder::MinDelay(int64_t delay, const SourceInfo& loc, diff --git a/xls/ir/function_builder.h b/xls/ir/function_builder.h index ecce51f364..dcde9405c8 100644 --- a/xls/ir/function_builder.h +++ b/xls/ir/function_builder.h @@ -653,6 +653,16 @@ class BuilderBase { // Adds a gate operation. The output of the operation is `data` if `cond` is // true and zero-valued otherwise. Gates are side-effecting. BValue Gate(BValue condition, BValue data, + const SourceInfo& loc = SourceInfo(), + std::string_view name = "") { + return Gate(condition, data, GateType::kZeroGate, loc, name); + } + + // Adds a gate operation with a specified gate type. The output of the + // operation is `data` if `cond` is true and otherwise depends on the gate + // type. Gates are side-effecting. + // are side-effecting. + BValue Gate(BValue condition, BValue data, GateType gate_type, const SourceInfo& loc = SourceInfo(), std::string_view name = ""); // Add a receive operation. The type of the data value received is diff --git a/xls/ir/ir_parser.cc b/xls/ir/ir_parser.cc index 46be4098ea..acb4c9b625 100644 --- a/xls/ir/ir_parser.cc +++ b/xls/ir/ir_parser.cc @@ -1351,8 +1351,16 @@ absl::StatusOr Parser::ParseNode( break; } case Op::kGate: { + std::optional* gate_type_str = + arg_parser.AddOptionalKeywordArg("gate_type"); XLS_ASSIGN_OR_RETURN(operands, arg_parser.Run(/*arity=*/2)); - bvalue = fb->Gate(operands[0], operands[1], *loc, node_name); + if (gate_type_str->has_value()) { + XLS_ASSIGN_OR_RETURN(GateType gate_type, + ParseGateType(gate_type_str->value().value)); + bvalue = fb->Gate(operands[0], operands[1], gate_type, *loc, node_name); + } else { + bvalue = fb->Gate(operands[0], operands[1], *loc, node_name); + } break; } case Op::kInstantiationInput: { diff --git a/xls/ir/node.cc b/xls/ir/node.cc index f652fe1a5e..7c7e4345ee 100644 --- a/xls/ir/node.cc +++ b/xls/ir/node.cc @@ -639,6 +639,13 @@ std::string Node::ToStringInternal(bool include_operand_types) const { } break; } + case Op::kGate: { + const Gate* gate = As(); + if (gate->gate_type() == GateType::kIgnorableGate) { + args.push_back("gate_type=ignorable"); + } + break; + } case Op::kNext: { const Next* next = As(); std::string param_name = next->has_state_read() diff --git a/xls/ir/nodes.cc b/xls/ir/nodes.cc index a34aa926d7..12f960e2b1 100755 --- a/xls/ir/nodes.cc +++ b/xls/ir/nodes.cc @@ -26,6 +26,7 @@ #include "absl/log/check.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" #include "absl/types/span.h" #include "xls/common/math_util.h" #include "xls/common/status/ret_check.h" @@ -1319,9 +1320,20 @@ bool InstantiationInput::IsDefinitelyEqualTo(const Node* other) const { port_name_ == other->As()->port_name_; } +absl::StatusOr ParseGateType(std::string_view gate_type) { + if (gate_type == "zero") { + return GateType::kZeroGate; + } else if (gate_type == "ignorable") { + return GateType::kIgnorableGate; + } + return absl::InvalidArgumentError( + absl::StrCat("Unknown gate type: ", gate_type)); +} + Gate::Gate(const SourceInfo& loc, Node* condition, Node* data, - std::string_view name, FunctionBase* function) - : Node(Op::kGate, data->GetType(), loc, name, function) { + GateType gate_type, std::string_view name, FunctionBase* function) + : Node(Op::kGate, data->GetType(), loc, name, function), + gate_type_(gate_type) { CHECK(IsOpClass(op_)) << "Op `" << op_ << "` is not a valid op for Node class `Gate`."; AddOperand(condition); @@ -1331,8 +1343,8 @@ Gate::Gate(const SourceInfo& loc, Node* condition, Node* data, absl::StatusOr Gate::CloneInNewFunction( absl::Span new_operands, FunctionBase* new_function) const { XLS_RET_CHECK_EQ(operand_count(), new_operands.size()); - return new_function->MakeNodeWithName(loc(), new_operands[0], - new_operands[1], GetNameView()); + return new_function->MakeNodeWithName( + loc(), new_operands[0], new_operands[1], gate_type(), GetNameView()); } SliceData Concat::GetOperandSliceData(int64_t operandno) const { diff --git a/xls/ir/nodes.h b/xls/ir/nodes.h index a2151db377..b5ae391aa9 100755 --- a/xls/ir/nodes.h +++ b/xls/ir/nodes.h @@ -514,6 +514,22 @@ class ExtendOp final : public Node { int64_t new_bit_count_; }; +enum class GateType : uint8_t { kZeroGate, kIgnorableGate }; + +template +void AbslStringify(Sink& sink, const GateType& gate_type) { + switch (gate_type) { + case GateType::kZeroGate: + absl::Format(&sink, "zero"); + return; + case GateType::kIgnorableGate: + absl::Format(&sink, "ignorable"); + return; + } +} + +absl::StatusOr ParseGateType(std::string_view gate_type); + class Gate final : public Node { public: static constexpr std::array kOps = {Op::kGate}; @@ -521,6 +537,11 @@ class Gate final : public Node { static constexpr int64_t kDataOperand = 1; Gate(const SourceInfo& loc, Node* condition, Node* data, + std::string_view name, FunctionBase* function) + : Gate(loc, condition, data, /*gate_type=*/GateType::kZeroGate, name, + function) {} + + Gate(const SourceInfo& loc, Node* condition, Node* data, GateType gate_type, std::string_view name, FunctionBase* function); absl::StatusOr CloneInNewFunction( @@ -529,6 +550,11 @@ class Gate final : public Node { Node* condition() const { return operand(0); } Node* data() const { return operand(1); } + + GateType gate_type() const { return gate_type_; } + + private: + GateType gate_type_; }; enum PortDirection : uint8_t { kInput, kOutput }; diff --git a/xls/ir/proc.cc b/xls/ir/proc.cc index 4d88219dd8..aa1d1602cc 100644 --- a/xls/ir/proc.cc +++ b/xls/ir/proc.cc @@ -143,11 +143,10 @@ std::string Proc::DumpIr(const IrAnnotator& annotate) const { } } - if (IsScheduled()) { - absl::StrAppend(&res, DumpFunctionBaseNodes(annotate)); - } else if (!is_block_source_) { + if (IsScheduled() || !is_block_source_) { absl::StrAppend(&res, DumpFunctionBaseNodes(IrAnnotatorJoiner( - IrAnnotatorRef(annotate), TopoSortAnnotator()))); + IrAnnotatorRef(annotate), + TopoSortAnnotator(!is_block_source_)))); } absl::StrAppend(&res, "}\n"); return res; diff --git a/xls/passes/BUILD b/xls/passes/BUILD index cbcb73da0c..9074f41b7f 100644 --- a/xls/passes/BUILD +++ b/xls/passes/BUILD @@ -1662,6 +1662,7 @@ cc_library( "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -4846,13 +4847,12 @@ cc_library( "//xls/ir:source_location", "//xls/ir:type", "//xls/ir:value", - "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:btree", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", ], ) @@ -4880,7 +4880,9 @@ cc_test( "//xls/ir:function_builder", "//xls/ir:ir_matcher", "//xls/ir:ir_test_base", + "//xls/ir:op", "//xls/visualization:math_notation", + "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/log", "@com_google_absl//absl/status:statusor", "@googletest//:gtest", diff --git a/xls/passes/visibility_analysis.cc b/xls/passes/visibility_analysis.cc index 30132ebf61..96b768a793 100644 --- a/xls/passes/visibility_analysis.cc +++ b/xls/passes/visibility_analysis.cc @@ -25,6 +25,7 @@ #include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "absl/functional/any_invocable.h" #include "absl/log/check.h" #include "absl/status/status.h" #include "absl/strings/str_format.h" @@ -538,6 +539,8 @@ BddNodeIndex OperandVisibilityAnalysis::ConditionOfUse(Node* node, return ConditionOnPredicate(node, user->As()->predicate()); } else if (user->Is()) { return ConditionOnNextUse(user->As(), node); + } else if (user->Is()) { + return ConditionOnPredicate(node, user->As()->condition()); } else if (user->OpIn({Op::kAnd, Op::kNand})) { return ConditionOfUseWithAnd(node, user->As()); } else if (user->OpIn({Op::kOr, Op::kNor})) { @@ -796,8 +799,10 @@ bool VisibilityAnalysis::IsMutuallyExclusive(Node* one, Node* other) const { return bdd.Implies(*GetInfo(one), bdd.Not(*GetInfo(other))) == bdd.one(); } -absl::StatusOr OperandVisibilityAnalysis::IsVisibilityIndependentOf( - Node* operand, Node* node, std::vector& sources) const { +namespace { + +absl::StatusOr> GetVisibilityControlConditions( + const Node* operand, const Node* node) { std::vector conditions; if (node->Is()) { conditions.push_back(node->As()->selector()); @@ -807,6 +812,8 @@ absl::StatusOr OperandVisibilityAnalysis::IsVisibilityIndependentOf( conditions.push_back(*node->As()->predicate()); } else if (node->Is() && node->As()->predicate().has_value()) { conditions.push_back(*node->As()->predicate()); + } else if (node->Is()) { + conditions.push_back(node->As()->condition()); } else if (node->OpIn({Op::kAnd, Op::kOr, Op::kNand, Op::kNor})) { for (Node* other_op : node->operands()) { if (other_op != operand) { @@ -818,6 +825,15 @@ absl::StatusOr OperandVisibilityAnalysis::IsVisibilityIndependentOf( absl::StrFormat("Unsupported node type for visibility expression: %s", node->ToString())); } + return conditions; +} + +} // namespace + +absl::StatusOr OperandVisibilityAnalysis::IsVisibilityIndependentOf( + Node* operand, Node* node, std::vector& sources) const { + XLS_ASSIGN_OR_RETURN(std::vector conditions, + GetVisibilityControlConditions(operand, node)); for (Node* condition : conditions) { for (Node* source : sources) { @@ -977,6 +993,52 @@ VisibilityAnalysis::GetEdgesForMutuallyExclusiveVisibilityExpr( return kept_edges; } +absl::StatusOr> +VisibilityAnalysis::GetEdgesForConservativeVisibilityExpr( + Node* one, absl::AnyInvocable is_live_source, + int64_t max_edges_to_handle) const { + absl::flat_hash_set kept_edges; + std::queue worklist; + worklist.push(one); + absl::flat_hash_set visited = {one}; + + while (!worklist.empty()) { + Node* node = worklist.front(); + worklist.pop(); + + for (Node* user : node->users()) { + // Whether or not we want to keep this edge, we should add the user to + // the worklist if it hasn't been visited yet. + if (auto [_, inserted] = visited.insert(user); inserted) { + worklist.push(user); + } + + XLS_ASSIGN_OR_RETURN(std::vector conditions, + GetVisibilityControlConditions(node, user)); + + if (!conditions.empty() && + absl::c_none_of(conditions, [&](Node* condition) { + return is_live_source(condition); + })) { + // There are conditions, but none of them are live; we have to consider + // this edge always active. + continue; + } + + BddNodeIndex visibility = + operand_visibility_->OperandVisibilityThroughNode(node, user); + if (visibility != bdd_query_engine_->bdd().one()) { + kept_edges.insert({node, user}); + } + } + } + + if (max_edges_to_handle >= 0 && kept_edges.size() > max_edges_to_handle) { + return absl::flat_hash_set{}; + } + return kept_edges; +} + /* static */ absl::StatusOr> SingleSelectVisibilityAnalysis::Create( const OperandVisibilityAnalysis* operand_vis, diff --git a/xls/passes/visibility_analysis.h b/xls/passes/visibility_analysis.h index c962382398..dc83fcf280 100644 --- a/xls/passes/visibility_analysis.h +++ b/xls/passes/visibility_analysis.h @@ -24,6 +24,7 @@ #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "absl/functional/any_invocable.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" @@ -229,6 +230,15 @@ class VisibilityAnalysis : public LazyNodeData { absl::Span others, int64_t max_edges_to_handle) const; + // Returns the (node -> user) edges necessary to compute a conservative + // visibility expression for 'one' such that only source nodes satisfying the + // given liveness predicate are included. Invalid edges are dynamically + // pruned to fold the expression conservatively. + absl::StatusOr> + GetEdgesForConservativeVisibilityExpr( + Node* one, absl::AnyInvocable is_live_source, + int64_t max_edges_to_handle) const; + BddNodeIndex VisibilityOfNearestPostDominator(Node* node) const; const BddQueryEngine* bdd_query_engine() const { return bdd_query_engine_; } diff --git a/xls/passes/visibility_expr_builder.cc b/xls/passes/visibility_expr_builder.cc index 807d542f16..df88474a91 100644 --- a/xls/passes/visibility_expr_builder.cc +++ b/xls/passes/visibility_expr_builder.cc @@ -14,19 +14,20 @@ #include "xls/passes/visibility_expr_builder.h" +#include #include +#include #include #include #include #include -#include "absl/algorithm/container.h" +#include "absl/container/btree_set.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/log/check.h" #include "absl/status/status.h" #include "absl/status/statusor.h" -#include "absl/strings/str_cat.h" #include "absl/types/span.h" #include "xls/common/status/status_macros.h" #include "xls/data_structures/binary_decision_diagram.h" @@ -50,14 +51,16 @@ namespace xls { absl::StatusOr VisibilityBuilder::GetSelectorIfIndependent( - Node* node, Node* select, Node* source, FunctionBase* func) { + Node* node, Node* select, Node* source, FunctionBase* func, + const std::function& is_live_source) { Node* selector = nullptr; if (select->Is()->selector(); } else if (select->Is()) { selector = select->As()->selector(); } - if (!selector || nda_.IsDependent(source, selector)) { + if (!selector || nda_.IsDependent(source, selector) || + (is_live_source && !is_live_source(selector))) { return nullptr; } return MakeParamIfTmpFunc(selector, func); @@ -91,9 +94,12 @@ bool VisibilityBuilder::DoesCaseImplyNoPrevCase(PrioritySelect* select, } absl::StatusOr VisibilityBuilder::GetVisibilityExprForPrioritySelect( - Node* node, PrioritySelect* select, Node* source, FunctionBase* func) { - XLS_ASSIGN_OR_RETURN(Node * selector, - GetSelectorIfIndependent(node, select, source, func)); + Node* node, PrioritySelect* select, Node* source, FunctionBase* func, + const std::function& is_live_source, + const std::function& get_remaining_delay) { + XLS_ASSIGN_OR_RETURN( + Node * selector, + GetSelectorIfIndependent(node, select, source, func, is_live_source)); if (!selector) { return nullptr; } @@ -138,13 +144,22 @@ absl::StatusOr VisibilityBuilder::GetVisibilityExprForPrioritySelect( func->MakeNode(select->loc(), selector, zero, Op::kEq)); or_cases.push_back(selector_is_zero); } - return OrOperands(or_cases); + + XLS_ASSIGN_OR_RETURN(Node * result, OrOperands(or_cases)); + if (get_remaining_delay && get_remaining_delay(result) < 0) { + // This would push the condition over the acceptable delay. + return nullptr; + } + return result; } absl::StatusOr VisibilityBuilder::GetVisibilityExprForSelect( - Node* node, Select* select, Node* source, FunctionBase* func) { - XLS_ASSIGN_OR_RETURN(Node * selector, - GetSelectorIfIndependent(node, select, source, func)); + Node* node, Select* select, Node* source, FunctionBase* func, + const std::function& is_live_source, + const std::function& get_remaining_delay) { + XLS_ASSIGN_OR_RETURN( + Node * selector, + GetSelectorIfIndependent(node, select, source, func, is_live_source)); if (!selector) { return nullptr; } @@ -175,12 +190,18 @@ absl::StatusOr VisibilityBuilder::GetVisibilityExprForSelect( value_num_cases, Op::kUGe)); or_cases.push_back(is_default); } - return OrOperands(or_cases); + XLS_ASSIGN_OR_RETURN(Node * result, OrOperands(or_cases)); + if (get_remaining_delay && get_remaining_delay(result) < 0) { + // This would push the condition over the acceptable delay. + return nullptr; + } + return result; } // Find the source bits of operand, and if unknown, use the operand itself. absl::StatusOr VisibilityBuilder::GetNonRepeatedSourceOf( - Node* operand, FunctionBase* func) { + Node* operand, FunctionBase* func, + const std::function& is_live_source) { XLS_ASSIGN_OR_RETURN(auto bit_sources_tree, bpa_.GetBitSources(operand)); LeafTypeTree trimmed_bit_sources_tree = BitProvenanceAnalysis::TrimRepeatedSourceBits( @@ -189,6 +210,9 @@ absl::StatusOr VisibilityBuilder::GetNonRepeatedSourceOf( if (source_ranges.size() == 1) { const TreeBitSources::BitRange& single_range = source_ranges[0]; Node* source = single_range.source_node(); + if (is_live_source && !is_live_source(source)) { + return operand; + } // Clone the source node if building expressions in a temp function. XLS_ASSIGN_OR_RETURN(source, MakeParamIfTmpFunc(source, func)); // If derived from a single range of contiguous bits, find or create a bit @@ -210,8 +234,36 @@ absl::StatusOr VisibilityBuilder::GetNonRepeatedSourceOf( return MakeParamIfTmpFunc(operand, func); } +namespace { + +absl::StatusOr FilterByDelay( + FunctionBase* func, const SourceInfo& loc, std::vector& operands, + const std::function& get_remaining_delay, + const std::function(const SourceInfo&)>& build_empty, + const std::function(std::vector&)>& build_fn) { + XLS_ASSIGN_OR_RETURN(Node * result, build_fn(operands)); + if (get_remaining_delay && get_remaining_delay(result) < 0) { + std::sort(operands.begin(), operands.end(), [&](Node* a, Node* b) { + return get_remaining_delay(a) > get_remaining_delay(b); + }); + while (!operands.empty() && get_remaining_delay(result) < 0) { + operands.pop_back(); + if (operands.empty()) { + XLS_ASSIGN_OR_RETURN(result, build_empty(loc)); + break; + } + XLS_ASSIGN_OR_RETURN(result, build_fn(operands)); + } + } + return result; +} + +} // namespace + absl::StatusOr VisibilityBuilder::GetVisibilityExprForAnd( - Node* node, NaryOp* and_node, Node* source, FunctionBase* func) { + Node* node, NaryOp* and_node, Node* source, FunctionBase* func, + const std::function& is_live_source, + const std::function& get_remaining_delay) { std::vector others_not_zero; for (Node* operand : and_node->operands()) { if (nda_.IsDependent(source, operand)) { @@ -219,10 +271,8 @@ absl::StatusOr VisibilityBuilder::GetVisibilityExprForAnd( } XLS_ASSIGN_OR_RETURN(Node * compare_val, - GetNonRepeatedSourceOf(operand, func)); + GetNonRepeatedSourceOf(operand, func, is_live_source)); - // If operand is derived from a single bit, we can simply check that the - // source bit is set to determine that @node is visible if (compare_val->GetType()->GetFlatBitCount() == 1) { others_not_zero.push_back(compare_val); continue; @@ -238,11 +288,30 @@ absl::StatusOr VisibilityBuilder::GetVisibilityExprForAnd( value_zero, Op::kNe)); others_not_zero.push_back(is_not_zero); } - return AndOperands(others_not_zero); + + if (others_not_zero.empty()) { + return AndOperands(others_not_zero); + } + if (others_not_zero.size() == 1) { + return others_not_zero[0]; + } + + return FilterByDelay( + func, and_node->loc(), others_not_zero, get_remaining_delay, + /*build_empty=*/ + [&](const SourceInfo& loc) -> absl::StatusOr { + return func->MakeNode(loc, Value(UBits(1, 1))); + }, + /*build_fn=*/ + [&](std::vector& ops) -> absl::StatusOr { + return AndOperands(ops); + }); } absl::StatusOr VisibilityBuilder::GetVisibilityExprForOr( - Node* node, NaryOp* or_node, Node* source, FunctionBase* func) { + Node* node, NaryOp* or_node, Node* source, FunctionBase* func, + const std::function& is_live_source, + const std::function& get_remaining_delay) { std::vector others_not_ones; for (Node* operand : or_node->operands()) { if (nda_.IsDependent(source, operand)) { @@ -250,10 +319,8 @@ absl::StatusOr VisibilityBuilder::GetVisibilityExprForOr( } XLS_ASSIGN_OR_RETURN(Node * compare_val, - GetNonRepeatedSourceOf(operand, func)); + GetNonRepeatedSourceOf(operand, func, is_live_source)); - // If operand is derived from a single bit, we can simply check that the - // source bit is NOT set to determine that @node is visible if (compare_val->GetType()->GetFlatBitCount() == 1) { XLS_ASSIGN_OR_RETURN( Node * not_single_bit_value, @@ -272,48 +339,81 @@ absl::StatusOr VisibilityBuilder::GetVisibilityExprForOr( value_ones, Op::kNe)); others_not_ones.push_back(is_not_ones); } - return AndOperands(others_not_ones); + + if (others_not_ones.empty()) { + return AndOperands(others_not_ones); + } + + return FilterByDelay( + func, or_node->loc(), others_not_ones, get_remaining_delay, + /*build_empty=*/ + [&](const SourceInfo& loc) -> absl::StatusOr { + return func->MakeNode(loc, Value(UBits(1, 1))); + }, + /*build_fn=*/ + [&](std::vector& ops) -> absl::StatusOr { + return AndOperands(ops); + }); } absl::StatusOr VisibilityBuilder::GetVisibilityExprForPredicate( - std::optional predicate, Node* source, FunctionBase* func) { - if (!predicate.has_value() || nda_.IsDependent(source, *predicate)) { + std::optional predicate, Node* source, FunctionBase* func, + const std::function& is_live_source, + const std::function& get_remaining_delay) { + if (!predicate.has_value() || nda_.IsDependent(source, *predicate) || + (is_live_source && !is_live_source(*predicate))) { return nullptr; } return MakeParamIfTmpFunc(*predicate, func); } absl::StatusOr VisibilityBuilder::BuildVisibilityExprHelper( - Node* node, Node* user, Node* source, FunctionBase* func) { + Node* node, Node* user, Node* source, FunctionBase* func, + const std::function& is_live_source, + const std::function& get_remaining_delay) { if (user->Is(), source, func); + return GetVisibilityExprForSelect(node, user->As