From ee20cc1d5c21c8b1ef247f6e955bc09b3cdbc450 Mon Sep 17 00:00:00 2001 From: Nelson Liang Date: Wed, 17 Jun 2026 10:19:22 -0700 Subject: [PATCH] [Explicit State Access] Handle decoupled next nodes when transforming state elements in IR PiperOrigin-RevId: 933795307 --- xls/ir/ir_matcher.cc | 20 ++++++- xls/ir/ir_matcher.h | 49 ++++++++++++++-- xls/ir/proc.cc | 38 +++++++++--- xls/ir/proc.h | 4 +- xls/ir/proc_test.cc | 136 ++++++++++++++++++++++++++++++++++--------- 5 files changed, 200 insertions(+), 47 deletions(-) diff --git a/xls/ir/ir_matcher.cc b/xls/ir/ir_matcher.cc index 3955802d99..7e60e010c8 100644 --- a/xls/ir/ir_matcher.cc +++ b/xls/ir/ir_matcher.cc @@ -654,8 +654,18 @@ bool NextMatcher::MatchAndExplain( if (!NodeMatcher::MatchAndExplain(node, listener)) { return false; } - if (label_.has_value() && - !label_->MatchAndExplain(node->As()->label(), listener)) { + const xls::Next* next = node->As(); + if (state_element_.has_value()) { + if (next->has_state_read()) { + *listener << " is a coupled Next node, but expected decoupled"; + return false; + } + if (!state_element_->MatchAndExplain(next->state_element(), listener)) { + *listener << " has incorrect state element"; + return false; + } + } + if (label_.has_value() && !label_->MatchAndExplain(next->label(), listener)) { *listener << " has incorrect label"; return false; } @@ -664,6 +674,12 @@ bool NextMatcher::MatchAndExplain( void NextMatcher::DescribeTo(::std::ostream* os) const { std::vector additional_fields; + if (state_element_.has_value()) { + std::stringstream ss; + ss << "state_element="; + state_element_->DescribeTo(&ss); + additional_fields.push_back(ss.str()); + } if (label_.has_value()) { std::stringstream ss; ss << "label="; diff --git a/xls/ir/ir_matcher.h b/xls/ir/ir_matcher.h index 497fb2d8f7..b5d405c39b 100644 --- a/xls/ir/ir_matcher.h +++ b/xls/ir/ir_matcher.h @@ -1328,17 +1328,28 @@ inline ::testing::Matcher StateRead() { // // EXPECT_THAT(x, m::Next()); // EXPECT_THAT(x, m::Next(m::StateRead("foo"), m::Literal(1))); -// EXPECT_THAT(x, m::Next(m::StateRead("foo"), m::Literal(1), -// m::Literal(1))) -// EXPECT_THAT(x, m::Next(m::StateRead("foo"), m::Literal(1), -// m::Literal(1), "some_label")) +// EXPECT_THAT(x, m::Next(m::StateRead("foo"), m::Literal(1), m::Literal(1))); +// EXPECT_THAT(x, m::NextWithLabel(m::StateRead("foo"), m::Literal(1), +// "label")); +// +// Decoupled forms (asserts that the next node has no StateRead operand): +// EXPECT_THAT(x, m::NextWithStateElement(se, m::Literal(1))); +// EXPECT_THAT(x, m::NextWithStateElement(se, m::Literal(1), m::Literal(1))); +// EXPECT_THAT(x, m::NextWithStateElementWithLabel(se, m::Literal(1), +// "label")); +// EXPECT_THAT(x, m::NextWithStateElementWithLabel(se, m::Literal(1), +// m::Literal(1), "label")); class NextMatcher : public NodeMatcher { public: explicit NextMatcher( absl::Span> operands = {}, std::optional<::testing::Matcher&>> - label = std::nullopt) - : NodeMatcher(Op::kNext, operands), label_(std::move(label)) {} + label = std::nullopt, + std::optional<::testing::Matcher> + state_element = std::nullopt) + : NodeMatcher(Op::kNext, operands), + label_(std::move(label)), + state_element_(std::move(state_element)) {} bool MatchAndExplain(const Node* node, ::testing::MatchResultListener* listener) const override; @@ -1346,6 +1357,7 @@ class NextMatcher : public NodeMatcher { private: std::optional<::testing::Matcher&>> label_; + std::optional<::testing::Matcher> state_element_; }; inline ::testing::Matcher Next() { return NextMatcher(); } @@ -1374,6 +1386,31 @@ inline ::testing::Matcher NextWithLabel( return NextMatcher({state_read, value, predicate}, label); } +inline ::testing::Matcher NextWithStateElement( + ::testing::Matcher state_element, + ::testing::Matcher value) { + return NextMatcher({value}, std::nullopt, state_element); +} +inline ::testing::Matcher NextWithStateElement( + ::testing::Matcher state_element, + ::testing::Matcher value, + ::testing::Matcher predicate) { + return NextMatcher({value, predicate}, std::nullopt, state_element); +} +inline ::testing::Matcher NextWithStateElementWithLabel( + ::testing::Matcher state_element, + ::testing::Matcher value, + ::testing::Matcher&> label) { + return NextMatcher({value}, label, state_element); +} +inline ::testing::Matcher NextWithStateElementWithLabel( + ::testing::Matcher state_element, + ::testing::Matcher value, + ::testing::Matcher predicate, + ::testing::Matcher&> label) { + return NextMatcher({value, predicate}, label, state_element); +} + // RegisterRead matcher. Matches register name only. Supported forms: // // EXPECT_THAT(x, m::RegisterRead()); diff --git a/xls/ir/proc.cc b/xls/ir/proc.cc index 4d88219dd8..09a3b12801 100644 --- a/xls/ir/proc.cc +++ b/xls/ir/proc.cc @@ -1002,22 +1002,42 @@ absl::StatusOr Proc::TransformStateElement( // Identity-ify the old next nodes and create new ones. for (const NextTransformation& nt : transforms) { - // Make the next - XLS_ASSIGN_OR_RETURN( - Next * nxt, - MakeNodeWithName(nt.old_next->loc(), new_state_read, nt.new_value, - nt.new_predicate, nt.old_next->label(), - nt.old_next->GetName())); + Next* nxt; + if (nt.old_next->has_state_read()) { + // Coupled: use the matched new_state_read + XLS_ASSIGN_OR_RETURN( + nxt, + MakeNodeWithName(nt.old_next->loc(), new_state_read, + nt.new_value, nt.new_predicate, + nt.old_next->label(), nt.old_next->GetName())); + } else { + // Decoupled: use new_state_element directly + XLS_ASSIGN_OR_RETURN( + nxt, + MakeNodeWithName(nt.old_next->loc(), new_state_element, + nt.new_value, nt.new_predicate, + nt.old_next->label(), nt.old_next->GetName())); + } to_replace.push_back({nt.old_next, nxt}); // Identity-ify the old next. - XLS_RETURN_IF_ERROR(nt.old_next->ReplaceOperandNumber( - Next::kValueOperand, nt.old_next->state_read())); + if (nt.old_next->has_state_read()) { + XLS_RETURN_IF_ERROR(nt.old_next->ReplaceOperandNumber( + Next::kValueOperand, nt.old_next->state_read())); + } else { + XLS_ASSIGN_OR_RETURN( + Node * placeholder, + MakeNode(nt.old_next->loc(), + ZeroOfType(old_state_element->type()))); + XLS_RET_CHECK( + nt.old_next->ReplaceOperand(nt.old_next->value(), placeholder)); + } } for (const auto& [old_n, new_n] : to_replace) { XLS_RETURN_IF_ERROR(old_n->ReplaceUsesWith( new_n, [&](Node* n) { - if (n->Is() && n->As()->state_read() == old_n) { + if (n->Is() && n->As()->has_state_read() && + n->As()->state_read() == old_n) { return false; } return true; diff --git a/xls/ir/proc.h b/xls/ir/proc.h index d6dd5d9b55..6a15cce9ca 100644 --- a/xls/ir/proc.h +++ b/xls/ir/proc.h @@ -200,8 +200,8 @@ class Proc : public FunctionBase { // switch users of the old param to the new one. // // The old state element will continue to exist with a new name and all - // identity next nodes. It should be cleaned up using the - // NextValueOptimizationPass. + // identity next nodes with all users removed. It should be cleaned up by + // ProcStateOptimizationPass in RemoveUnobservableStateElements. // // The proc must only use 'next' nodes to call this function. absl::StatusOr TransformStateElement( diff --git a/xls/ir/proc_test.cc b/xls/ir/proc_test.cc index 16b68917f4..687b943f11 100644 --- a/xls/ir/proc_test.cc +++ b/xls/ir/proc_test.cc @@ -20,6 +20,7 @@ #include #include #include +#include #include "gmock/gmock.h" #include "gtest/gtest.h" @@ -57,7 +58,36 @@ using ::testing::HasSubstr; using ::testing::IsEmpty; using ::testing::UnorderedElementsAre; -class ProcTest : public IrTestBase {}; +class ProcTest : public IrTestBase { + protected: + struct TestTransformer : public Proc::StateElementTransformer { + public: + absl::StatusOr TransformStateRead( + Proc* proc, StateRead* new_state_read, + StateRead* old_state_read) override { + return proc->MakeNode(new_state_read->loc(), new_state_read, + Op::kNeg); + } + absl::StatusOr TransformNextValue(Proc* proc, + StateRead* new_state_read, + Next* old_next) override { + return proc->MakeNode(old_next->value()->loc(), old_next->value(), + Op::kNeg); + } + absl::StatusOr> TransformNextPredicate( + Proc* proc, StateRead* new_state_read, Next* old_next) override { + XLS_ASSIGN_OR_RETURN( + Node * true_const, + proc->MakeNode(old_next->loc(), Value::Bool(true))); + if (old_next->predicate()) { + return proc->MakeNode( + old_next->predicate().value()->loc(), + std::array{true_const, *old_next->predicate()}, Op::kAnd); + } + return true_const; + } + }; +}; TEST_F(ProcTest, SimpleProc) { auto p = CreatePackage(); @@ -508,33 +538,6 @@ TEST_F(ProcTest, TransformStateElement) { XLS_ASSERT_OK_AND_ASSIGN(Proc * proc, pb.Build()); // Test transformer that inverts the param. - struct TestTransformer : public Proc::StateElementTransformer { - public: - absl::StatusOr TransformStateRead( - Proc* proc, StateRead* new_state_read, - StateRead* old_state_read) override { - return proc->MakeNode(new_state_read->loc(), new_state_read, - Op::kNeg); - } - absl::StatusOr TransformNextValue(Proc* proc, - StateRead* new_state_read, - Next* old_next) override { - return proc->MakeNode(old_next->value()->loc(), old_next->value(), - Op::kNeg); - } - absl::StatusOr> TransformNextPredicate( - Proc* proc, StateRead* new_state_read, Next* old_next) override { - XLS_ASSIGN_OR_RETURN( - Node * true_const, - proc->MakeNode(old_next->loc(), Value::Bool(true))); - if (old_next->predicate()) { - return proc->MakeNode( - old_next->predicate().value()->loc(), - std::array{true_const, *old_next->predicate()}, Op::kAnd); - } - return true_const; - } - }; TestTransformer tt; ScopedRecordIr sri(p.get()); XLS_ASSERT_OK_AND_ASSIGN( @@ -568,6 +571,83 @@ TEST_F(ProcTest, TransformStateElement) { EXPECT_THAT(user.node(), m::Tuple(m::Neg(new_st))); } +TEST_F(ProcTest, TransformStateElementDecoupled) { + auto p = CreatePackage(); + TokenlessProcBuilder pb(TestName(), "tkn", p.get()); + auto cond = pb.StateElement("cond", UBits(0, 1)); + + XLS_ASSERT_OK_AND_ASSIGN( + StateElement * state_element, + pb.UnreadStateElement("st", Value(UBits(0b1010, 4)))); + + BValue st_read = pb.StateRead(state_element, std::nullopt, "my_read_label"); + + // Labeled next node + BValue add_st = + pb.Next(state_element, pb.Add(st_read, pb.Literal(UBits(1, 4))), cond, + /*label=*/"my_next_label"); + // Unlabeled next node + BValue sub_st = + pb.Next(state_element, pb.Subtract(st_read, pb.Literal(UBits(1, 4))), + pb.Not(cond)); + + XLS_ASSERT_OK_AND_ASSIGN(Proc * proc, pb.Build()); + // Test transformer that inverts the param. + TestTransformer tt; + ScopedRecordIr sri(p.get()); + XLS_ASSERT_OK_AND_ASSIGN( + StateElement * new_st_element, + proc->TransformStateElement(state_element, Value(UBits(0b0101, 4)), tt)); + StateRead* new_st = proc->GetStateReadByStateElement(new_st_element); + + // Make sure the st next has been identity-ified (dummy value Zero is set + // for decoupled) + EXPECT_THAT(st_read.node(), m::StateRead(testing::Not("st"))); + EXPECT_THAT(st_read.node()->users(), ::testing::IsEmpty()); + + // Verify old next nodes are identity-ified (labeled and unlabeled) + EXPECT_THAT(add_st.node(), m::NextWithStateElementWithLabel( + state_element, m::Literal(0), cond.node(), + std::optional("my_next_label"))); + EXPECT_THAT(sub_st.node(), + m::NextWithStateElement(state_element, m::Literal(0), + m::Not(cond.node()))); + + // Make sure that 'new_state_read' takes over the name and everything. + EXPECT_THAT(new_st, + m::StateRead("st", std::optional("my_read_label"))); + EXPECT_THAT(new_st->users(), UnorderedElementsAre(m::Neg(new_st))); + + // Verify new next nodes (labeled and unlabeled) + EXPECT_THAT(proc->next_values(new_st_element), + UnorderedElementsAre( + m::NextWithStateElementWithLabel( + new_st_element, + m::Neg(m::Add(m::Neg(new_st), m::Literal(UBits(1, 4)))), + m::And(m::Literal(UBits(1, 1)), cond.node()), + std::optional("my_next_label")), + m::NextWithStateElement( + new_st_element, + m::Neg(m::Sub(m::Neg(new_st), m::Literal(UBits(1, 4)))), + m::And(m::Literal(UBits(1, 1)), m::Not(cond.node()))))); + + XLS_ASSERT_OK_AND_ASSIGN(int64_t old_idx, + proc->GetStateElementIndex(state_element)); + + // Remove the old Next nodes. + std::vector old_nexts(proc->next_values(state_element).begin(), + proc->next_values(state_element).end()); + for (Next* next : old_nexts) { + XLS_ASSERT_OK(proc->RemoveNode(next)); + } + + // Remove the old state element (and its StateRead) with no users. + XLS_EXPECT_OK(proc->RemoveStateElement(old_idx)); + + // Verify only 'cond' and 'new_st' remain. + EXPECT_EQ(proc->GetStateElementCount(), 2); +} + class ScheduledProcTest : public IrTestBase { protected: absl::StatusOr CreateScheduledProc(Package* p) {