From 37e02b2e929b474dc62fe79d8d21c9ab9667875c Mon Sep 17 00:00:00 2001 From: Nelson Liang Date: Tue, 16 Jun 2026 13:46:59 -0700 Subject: [PATCH] [Explicit State Access] Support multiple reads for state element, verify labels are propagated, and decouple next nodes from state reads when transforming state elements. PiperOrigin-RevId: 933278396 --- xls/ir/BUILD | 1 + xls/ir/ir_matcher.cc | 21 +++- xls/ir/ir_matcher.h | 39 +++++++- xls/ir/proc.cc | 140 +++++++++++++++++++------- xls/ir/proc_test.cc | 234 +++++++++++++++++++++++++++++++++++++------ 5 files changed, 365 insertions(+), 70 deletions(-) diff --git a/xls/ir/BUILD b/xls/ir/BUILD index 32dab03c75..6d0d62313b 100644 --- a/xls/ir/BUILD +++ b/xls/ir/BUILD @@ -1893,6 +1893,7 @@ cc_test( "@com_google_absl//absl/status", "@com_google_absl//absl/status:status_matchers", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/types:span", "@googletest//:gtest", ], ) diff --git a/xls/ir/ir_matcher.cc b/xls/ir/ir_matcher.cc index 3955802d99..ce6f504c66 100644 --- a/xls/ir/ir_matcher.cc +++ b/xls/ir/ir_matcher.cc @@ -654,8 +654,19 @@ bool NextMatcher::MatchAndExplain( if (!NodeMatcher::MatchAndExplain(node, listener)) { return false; } - if (label_.has_value() && - !label_->MatchAndExplain(node->As()->label(), listener)) { + const xls::Next* next = node->As(); + if (state_element_.has_value()) { + if (next->has_state_read()) { + *listener << " expected decoupled Next node with only state_element, but " + "it has state_read"; + return false; + } + if (!state_element_->MatchAndExplain(next->state_element(), listener)) { + *listener << " has incorrect state_element"; + return false; + } + } + if (label_.has_value() && !label_->MatchAndExplain(next->label(), listener)) { *listener << " has incorrect label"; return false; } @@ -664,6 +675,12 @@ bool NextMatcher::MatchAndExplain( void NextMatcher::DescribeTo(::std::ostream* os) const { std::vector additional_fields; + if (state_element_.has_value()) { + std::stringstream ss; + ss << "state_element="; + state_element_->DescribeTo(&ss); + additional_fields.push_back(ss.str()); + } if (label_.has_value()) { std::stringstream ss; ss << "label="; diff --git a/xls/ir/ir_matcher.h b/xls/ir/ir_matcher.h index 497fb2d8f7..80d43f4c17 100644 --- a/xls/ir/ir_matcher.h +++ b/xls/ir/ir_matcher.h @@ -1329,9 +1329,18 @@ inline ::testing::Matcher StateRead() { // EXPECT_THAT(x, m::Next()); // EXPECT_THAT(x, m::Next(m::StateRead("foo"), m::Literal(1))); // EXPECT_THAT(x, m::Next(m::StateRead("foo"), m::Literal(1), -// m::Literal(1))) -// EXPECT_THAT(x, m::Next(m::StateRead("foo"), m::Literal(1), -// m::Literal(1), "some_label")) +// m::Literal(1))); +// EXPECT_THAT(x, m::NextWithLabel(m::StateRead("foo"), m::Literal(1), +// "some_label")); +// EXPECT_THAT(x, m::NextWithLabel(m::StateRead("foo"), m::Literal(1), +// m::Literal(1), "some_label")); +// +// NextStateElement matcher. Supported forms: +// +// EXPECT_THAT(x, m::NextStateElement(state_element, m::Literal(1), +// m::Literal(1))); +// EXPECT_THAT(x, m::NextStateElementWithLabel(state_element, m::Literal(1), +// m::Literal(1), "some_label")); class NextMatcher : public NodeMatcher { public: explicit NextMatcher( @@ -1340,11 +1349,21 @@ class NextMatcher : public NodeMatcher { label = std::nullopt) : NodeMatcher(Op::kNext, operands), label_(std::move(label)) {} + explicit NextMatcher( + ::testing::Matcher state_element, + absl::Span> operands = {}, + std::optional<::testing::Matcher&>> + label = std::nullopt) + : NodeMatcher(Op::kNext, operands), + state_element_(std::move(state_element)), + label_(std::move(label)) {} + bool MatchAndExplain(const Node* node, ::testing::MatchResultListener* listener) const override; void DescribeTo(::std::ostream* os) const override; private: + std::optional<::testing::Matcher> state_element_; std::optional<::testing::Matcher&>> label_; }; @@ -1374,6 +1393,20 @@ inline ::testing::Matcher NextWithLabel( return NextMatcher({state_read, value, predicate}, label); } +inline ::testing::Matcher NextStateElement( + ::testing::Matcher state_element, + ::testing::Matcher value, + ::testing::Matcher predicate) { + return NextMatcher(std::move(state_element), {value, predicate}); +} +inline ::testing::Matcher NextStateElementWithLabel( + ::testing::Matcher state_element, + ::testing::Matcher value, + ::testing::Matcher predicate, + ::testing::Matcher&> label) { + return NextMatcher(std::move(state_element), {value, predicate}, label); +} + // RegisterRead matcher. Matches register name only. Supported forms: // // EXPECT_THAT(x, m::RegisterRead()); diff --git a/xls/ir/proc.cc b/xls/ir/proc.cc index 4d88219dd8..db9ec03062 100644 --- a/xls/ir/proc.cc +++ b/xls/ir/proc.cc @@ -15,6 +15,7 @@ #include "xls/ir/proc.h" #include +#include #include #include #include @@ -940,84 +941,149 @@ absl::Status Proc::ConvertToNewStyle() { absl::StatusOr Proc::TransformStateElement( StateElement* old_state_element, const Value& init_value, Proc::StateElementTransformer& transform) { - StateRead* old_state_read = GetStateReadByStateElement(old_state_element); + absl::Span old_state_reads = + GetStateReadsByStateElement(old_state_element); std::string orig_name(old_state_element->name()); - std::string orig_read_name(old_state_read->GetNameView()); - XLS_ASSIGN_OR_RETURN(std::optional read_predicate, - transform.TransformReadPredicate(this, old_state_read)); + + std::vector orig_read_names; + orig_read_names.reserve(old_state_reads.size()); + for (StateRead* old_state_read : old_state_reads) { + orig_read_names.push_back(std::string(old_state_read->GetNameView())); + } + + // Create new state element (unread initially). XLS_ASSIGN_OR_RETURN( - StateRead * new_state_read, - AppendStateElement(absl::StrFormat("TEMP_NAME__%s__", orig_name), - init_value, read_predicate, - /*next_state=*/std::nullopt)); - new_state_read->SetLoc(old_state_read->loc()); - new_state_read->set_label(old_state_read->label()); - if (old_state_read->state_element()->non_synthesizable()) { - new_state_read->state_element()->SetNonSynthesizable(); + StateElement * new_state_element, + InsertUnreadStateElement(GetStateElementCount(), + absl::StrFormat("TEMP_NAME__%s__", orig_name), + init_value)); + if (old_state_element->non_synthesizable()) { + new_state_element->SetNonSynthesizable(); } - StateElement* new_state_element = new_state_read->state_element(); std::string temp_name = new_state_element->name(); - XLS_ASSIGN_OR_RETURN( - Node * new_state_value, - transform.TransformStateRead(this, new_state_read, old_state_read)); - std::vector> to_replace{ - {old_state_read, new_state_value}}; + std::vector> to_replace; + std::vector new_state_reads; + new_state_reads.reserve(old_state_reads.size()); + // We track the first state read because we need to check the return type of + // the state read after transformation. The type is identical across all + // state reads. + StateRead* first_new_read = nullptr; + + // Transform and create new reads for each old read. + for (StateRead* old_state_read : old_state_reads) { + XLS_ASSIGN_OR_RETURN( + std::optional read_predicate, + transform.TransformReadPredicate(this, old_state_read)); + XLS_ASSIGN_OR_RETURN( + StateRead * new_state_read, + AddStateRead(new_state_element, read_predicate, old_state_read->label(), + old_state_read->loc())); + new_state_reads.push_back(new_state_read); + if (first_new_read == nullptr) { + first_new_read = new_state_read; + } + + XLS_ASSIGN_OR_RETURN( + Node * new_state_value, + transform.TransformStateRead(this, new_state_read, old_state_read)); + to_replace.push_back({old_state_read, new_state_value}); + } + struct NextTransformation { Next* old_next; Node* new_value; std::optional new_predicate; + StateRead* new_state_read; }; std::vector transforms; for (Next* nxt : next_values(old_state_element)) { NextTransformation& new_next = transforms.emplace_back(); new_next.old_next = nxt; - XLS_ASSIGN_OR_RETURN(new_next.new_value, transform.TransformNextValue( - this, new_state_read, nxt)); - XLS_RET_CHECK(new_next.new_value->GetType() == new_state_read->GetType()) + + // Find the corresponding new state read. + new_next.new_state_read = first_new_read; + if (nxt->has_state_read()) { + StateRead* old_read = nxt->state_read()->As(); + for (size_t i = 0; i < old_state_reads.size(); ++i) { + if (old_state_reads[i] == old_read) { + new_next.new_state_read = new_state_reads[i]; + break; + } + } + } + + XLS_ASSIGN_OR_RETURN( + new_next.new_value, + transform.TransformNextValue(this, new_next.new_state_read, nxt)); + XLS_RET_CHECK(new_next.new_value->GetType() == + new_next.new_state_read->GetType()) << "New value is not compatible type. Expected: " - << new_state_read->GetType() << " got " << new_next.new_value; + << new_next.new_state_read->GetType() << " got " << new_next.new_value; XLS_ASSIGN_OR_RETURN( new_next.new_predicate, - transform.TransformNextPredicate(this, new_state_read, nxt)); + transform.TransformNextPredicate(this, new_next.new_state_read, nxt)); } - // We've transformed all the graph elements. Start replacing them. - - // Switch old_state_read's name to a temporary to-remove name + // Rename old element & reads to a temporary to-remove name. std::string to_remove_name = UniquifyStateName( absl::StrFormat("TO_REMOVE_TRANSFORMED_STATE__%s__", orig_name)); auto orig_storage = state_elements_.extract(orig_name); orig_storage.key() = to_remove_name; old_state_element->SetName(to_remove_name); - old_state_read->SetName(to_remove_name); + for (StateRead* old_state_read : old_state_reads) { + old_state_read->SetName(to_remove_name); + } CHECK(state_elements_.insert(std::move(orig_storage)).inserted); // Take over the old state element & read names. auto new_storage = state_elements_.extract(temp_name); new_storage.key() = orig_name; new_state_element->SetName(orig_name); - new_state_read->SetNameDirectly(orig_read_name); + for (size_t i = 0; i < new_state_reads.size(); ++i) { + new_state_reads[i]->SetNameDirectly(orig_read_names[i]); + } CHECK(state_elements_.insert(std::move(new_storage)).inserted); // Identity-ify the old next nodes and create new ones. for (const NextTransformation& nt : transforms) { - // Make the next - XLS_ASSIGN_OR_RETURN( - Next * nxt, - MakeNodeWithName(nt.old_next->loc(), new_state_read, nt.new_value, - nt.new_predicate, nt.old_next->label(), - nt.old_next->GetName())); + Next* nxt; + if (nt.old_next->has_state_read()) { + // Coupled: use the matched new_state_read + XLS_ASSIGN_OR_RETURN( + nxt, + MakeNodeWithName(nt.old_next->loc(), nt.new_state_read, + nt.new_value, nt.new_predicate, + nt.old_next->label(), nt.old_next->GetName())); + } else { + // Decoupled: use new_state_element directly + XLS_ASSIGN_OR_RETURN( + nxt, + MakeNodeWithName(nt.old_next->loc(), new_state_element, + nt.new_value, nt.new_predicate, + nt.old_next->label(), nt.old_next->GetName())); + } to_replace.push_back({nt.old_next, nxt}); // Identity-ify the old next. - XLS_RETURN_IF_ERROR(nt.old_next->ReplaceOperandNumber( - Next::kValueOperand, nt.old_next->state_read())); + if (nt.old_next->has_state_read()) { + XLS_RETURN_IF_ERROR(nt.old_next->ReplaceOperandNumber( + Next::kValueOperand, nt.old_next->state_read())); + } else { + XLS_ASSIGN_OR_RETURN( + Node * dummy, + MakeNode(nt.old_next->loc(), + ZeroOfType(old_state_element->type()))); + XLS_RET_CHECK(nt.old_next->ReplaceOperand(nt.old_next->value(), dummy)); + } } + + // Replace uses. for (const auto& [old_n, new_n] : to_replace) { XLS_RETURN_IF_ERROR(old_n->ReplaceUsesWith( new_n, [&](Node* n) { - if (n->Is() && n->As()->state_read() == old_n) { + if (n->Is() && n->As()->has_state_read() && + n->As()->state_read() == old_n) { return false; } return true; diff --git a/xls/ir/proc_test.cc b/xls/ir/proc_test.cc index 16b68917f4..66862e5b54 100644 --- a/xls/ir/proc_test.cc +++ b/xls/ir/proc_test.cc @@ -26,6 +26,7 @@ #include "absl/status/status.h" #include "absl/status/status_matchers.h" #include "absl/status/statusor.h" +#include "absl/types/span.h" #include "xls/common/status/matchers.h" #include "xls/common/status/status_macros.h" #include "xls/ir/bits.h" @@ -57,7 +58,36 @@ using ::testing::HasSubstr; using ::testing::IsEmpty; using ::testing::UnorderedElementsAre; -class ProcTest : public IrTestBase {}; +class ProcTest : public IrTestBase { + protected: + struct TestTransformer : public Proc::StateElementTransformer { + public: + absl::StatusOr TransformStateRead( + Proc* proc, StateRead* new_state_read, + StateRead* old_state_read) override { + return proc->MakeNode(new_state_read->loc(), new_state_read, + Op::kNeg); + } + absl::StatusOr TransformNextValue(Proc* proc, + StateRead* new_state_read, + Next* old_next) override { + return proc->MakeNode(old_next->value()->loc(), old_next->value(), + Op::kNeg); + } + absl::StatusOr> TransformNextPredicate( + Proc* proc, StateRead* new_state_read, Next* old_next) override { + XLS_ASSIGN_OR_RETURN( + Node * true_const, + proc->MakeNode(old_next->loc(), Value::Bool(true))); + if (old_next->predicate()) { + return proc->MakeNode( + old_next->predicate().value()->loc(), + std::array{true_const, *old_next->predicate()}, Op::kAnd); + } + return true_const; + } + }; +}; TEST_F(ProcTest, SimpleProc) { auto p = CreatePackage(); @@ -508,33 +538,6 @@ TEST_F(ProcTest, TransformStateElement) { XLS_ASSERT_OK_AND_ASSIGN(Proc * proc, pb.Build()); // Test transformer that inverts the param. - struct TestTransformer : public Proc::StateElementTransformer { - public: - absl::StatusOr TransformStateRead( - Proc* proc, StateRead* new_state_read, - StateRead* old_state_read) override { - return proc->MakeNode(new_state_read->loc(), new_state_read, - Op::kNeg); - } - absl::StatusOr TransformNextValue(Proc* proc, - StateRead* new_state_read, - Next* old_next) override { - return proc->MakeNode(old_next->value()->loc(), old_next->value(), - Op::kNeg); - } - absl::StatusOr> TransformNextPredicate( - Proc* proc, StateRead* new_state_read, Next* old_next) override { - XLS_ASSIGN_OR_RETURN( - Node * true_const, - proc->MakeNode(old_next->loc(), Value::Bool(true))); - if (old_next->predicate()) { - return proc->MakeNode( - old_next->predicate().value()->loc(), - std::array{true_const, *old_next->predicate()}, Op::kAnd); - } - return true_const; - } - }; TestTransformer tt; ScopedRecordIr sri(p.get()); XLS_ASSERT_OK_AND_ASSIGN( @@ -568,6 +571,181 @@ TEST_F(ProcTest, TransformStateElement) { EXPECT_THAT(user.node(), m::Tuple(m::Neg(new_st))); } +TEST_F(ProcTest, TransformStateElementDecoupled) { + auto p = CreatePackage(); + TokenlessProcBuilder pb(TestName(), "tkn", p.get()); + auto cond = pb.StateElement("cond", UBits(0, 1)); + + XLS_ASSERT_OK_AND_ASSIGN( + StateElement * state_element, + pb.UnreadStateElement("st", Value(UBits(0b1010, 4)))); + + BValue st_read = pb.StateRead(state_element); + + // Create decoupled Next nodes (using state_element instead of state read) + BValue add_st = + pb.Next(state_element, pb.Add(st_read, pb.Literal(UBits(1, 4))), cond); + BValue sub_st = + pb.Next(state_element, pb.Subtract(st_read, pb.Literal(UBits(1, 4))), + pb.Not(cond)); + + XLS_ASSERT_OK_AND_ASSIGN(Proc * proc, pb.Build()); + // Test transformer that inverts the param. + TestTransformer tt; + ScopedRecordIr sri(p.get()); + XLS_ASSERT_OK_AND_ASSIGN( + StateElement * new_st_element, + proc->TransformStateElement(state_element, Value(UBits(0b0101, 4)), tt)); + StateRead* new_st = proc->GetStateReadByStateElement(new_st_element); + + // Make sure the st nexts has been identity-ified (dummy value Zero is set + // for decoupled) + EXPECT_THAT(st_read.node(), m::StateRead(testing::Not("st"))); + EXPECT_THAT(st_read.node()->users(), ::testing::IsEmpty()); + EXPECT_THAT(add_st.node(), + m::NextStateElement(state_element, m::Literal(0), cond.node())); + EXPECT_THAT(sub_st.node(), m::NextStateElement(state_element, m::Literal(0), + m::Not(cond.node()))); + + // Make sure that 'new_state_read' takes over the name and everything. + EXPECT_THAT(new_st, m::StateRead("st")); + EXPECT_THAT(new_st->users(), UnorderedElementsAre(m::Neg(new_st))); + EXPECT_THAT(proc->next_values(new_st_element), + UnorderedElementsAre( + m::NextStateElement( + new_st_element, + m::Neg(m::Add(m::Neg(new_st), m::Literal(UBits(1, 4)))), + m::And(m::Literal(UBits(1, 1)), cond.node())), + m::NextStateElement( + new_st_element, + m::Neg(m::Sub(m::Neg(new_st), m::Literal(UBits(1, 4)))), + m::And(m::Literal(UBits(1, 1)), m::Not(cond.node()))))); +} + +TEST_F(ProcTest, TransformStateElementMultipleReadsCoupled) { + auto p = CreatePackage(); + TokenlessProcBuilder pb(TestName(), "tkn", p.get()); + auto cond = pb.StateElement("cond", UBits(0, 1)); + + XLS_ASSERT_OK_AND_ASSIGN( + StateElement * st_element, + pb.UnreadStateElement("st", Value(UBits(0b1010, 4)))); + + BValue not_cond_val = pb.Not(cond); + + // Add the first read. + BValue st_read_1 = pb.StateRead(st_element, cond, "read_1_label"); + + // Add a second read. + BValue st_read_2 = pb.StateRead(st_element, not_cond_val, "read_2_label"); + + // Add next values (coupled). + pb.Next(st_read_1, pb.Add(st_read_1, pb.Literal(UBits(1, 4))), cond, + "next_1_label"); + pb.Next(st_read_2, pb.Subtract(st_read_2, pb.Literal(UBits(1, 4))), + not_cond_val, "next_2_label"); + + XLS_ASSERT_OK_AND_ASSIGN(Proc * proc, pb.Build()); + Node* not_cond = not_cond_val.node(); + + TestTransformer tt; + XLS_ASSERT_OK_AND_ASSIGN( + StateElement * new_st_element, + proc->TransformStateElement(st_element, Value(UBits(0b0101, 4)), tt)); + + absl::Span new_reads = + proc->GetStateReadsByStateElement(new_st_element); + ASSERT_EQ(new_reads.size(), 2); + StateRead* new_read_1 = new_reads[0]; + StateRead* new_read_2 = new_reads[1]; + + EXPECT_THAT( + new_read_1, + m::StateRead("st", testing::Optional(std::string("read_1_label")))); + EXPECT_THAT( + new_read_2, + m::StateRead("st", testing::Optional(std::string("read_2_label")))); + + EXPECT_THAT(new_read_1->predicate(), ::testing::Optional(cond.node())); + EXPECT_THAT(new_read_2->predicate(), ::testing::Optional(not_cond)); + + EXPECT_THAT( + proc->next_values(new_st_element), + UnorderedElementsAre( + m::NextWithLabel( + new_read_1, + m::Neg(m::Add(m::Neg(new_read_1), m::Literal(UBits(1, 4)))), + m::And(m::Literal(UBits(1, 1)), cond.node()), + testing::Optional(std::string("next_1_label"))), + m::NextWithLabel( + new_read_2, + m::Neg(m::Sub(m::Neg(new_read_2), m::Literal(UBits(1, 4)))), + m::And(m::Literal(UBits(1, 1)), not_cond), + testing::Optional(std::string("next_2_label"))))); +} + +TEST_F(ProcTest, TransformStateElementMultipleReadsDecoupled) { + auto p = CreatePackage(); + TokenlessProcBuilder pb(TestName(), "tkn", p.get()); + auto cond = pb.StateElement("cond", UBits(0, 1)); + + XLS_ASSERT_OK_AND_ASSIGN( + StateElement * st_element, + pb.UnreadStateElement("st", Value(UBits(0b1010, 4)))); + + BValue not_cond_val = pb.Not(cond); + + // Add the first read. + BValue st_read_1 = pb.StateRead(st_element, cond, "read_1_label"); + + // Add a second read. + BValue st_read_2 = pb.StateRead(st_element, not_cond_val, "read_2_label"); + + // Add next values (decoupled). + pb.Next(st_element, pb.Add(st_read_1, pb.Literal(UBits(1, 4))), cond, + "next_1_label"); + pb.Next(st_element, pb.Subtract(st_read_2, pb.Literal(UBits(1, 4))), + not_cond_val, "next_2_label"); + + XLS_ASSERT_OK_AND_ASSIGN(Proc * proc, pb.Build()); + Node* not_cond = not_cond_val.node(); + + TestTransformer tt; + XLS_ASSERT_OK_AND_ASSIGN( + StateElement * new_st_element, + proc->TransformStateElement(st_element, Value(UBits(0b0101, 4)), tt)); + + absl::Span new_reads = + proc->GetStateReadsByStateElement(new_st_element); + ASSERT_EQ(new_reads.size(), 2); + StateRead* new_read_1 = new_reads[0]; + StateRead* new_read_2 = new_reads[1]; + + EXPECT_THAT( + new_read_1, + m::StateRead("st", testing::Optional(std::string("read_1_label")))); + EXPECT_THAT( + new_read_2, + m::StateRead("st", testing::Optional(std::string("read_2_label")))); + + EXPECT_THAT(new_read_1->predicate(), ::testing::Optional(cond.node())); + EXPECT_THAT(new_read_2->predicate(), ::testing::Optional(not_cond)); + + EXPECT_THAT( + proc->next_values(new_st_element), + UnorderedElementsAre( + m::NextStateElementWithLabel( + new_st_element, + m::Neg(m::Add(m::Neg(new_read_1), m::Literal(UBits(1, 4)))), + m::And(m::Literal(UBits(1, 1)), cond.node()), + testing::Optional(std::string("next_1_label"))), + m::NextStateElementWithLabel( + new_st_element, + m::Neg(m::Sub(m::Neg(new_read_2), m::Literal(UBits(1, 4)))), + m::And(m::Literal(UBits(1, 1)), not_cond), + testing::Optional(std::string("next_2_label"))))); +} + class ScheduledProcTest : public IrTestBase { protected: absl::StatusOr CreateScheduledProc(Package* p) {