Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 18 additions & 2 deletions xls/ir/ir_matcher.cc
Original file line number Diff line number Diff line change
Expand Up @@ -654,8 +654,18 @@ bool NextMatcher::MatchAndExplain(
if (!NodeMatcher::MatchAndExplain(node, listener)) {
return false;
}
if (label_.has_value() &&
!label_->MatchAndExplain(node->As<xls::Next>()->label(), listener)) {
const xls::Next* next = node->As<xls::Next>();
if (state_element_.has_value()) {
if (next->has_state_read()) {
*listener << " is a coupled Next node, but expected decoupled";
return false;
}
if (!state_element_->MatchAndExplain(next->state_element(), listener)) {
*listener << " has incorrect state element";
return false;
}
}
if (label_.has_value() && !label_->MatchAndExplain(next->label(), listener)) {
*listener << " has incorrect label";
return false;
}
Expand All @@ -664,6 +674,12 @@ bool NextMatcher::MatchAndExplain(

void NextMatcher::DescribeTo(::std::ostream* os) const {
std::vector<std::string> additional_fields;
if (state_element_.has_value()) {
std::stringstream ss;
ss << "state_element=";
state_element_->DescribeTo(&ss);
additional_fields.push_back(ss.str());
}
if (label_.has_value()) {
std::stringstream ss;
ss << "label=";
Expand Down
49 changes: 43 additions & 6 deletions xls/ir/ir_matcher.h
Original file line number Diff line number Diff line change
Expand Up @@ -1328,24 +1328,36 @@ inline ::testing::Matcher<const ::xls::Node*> StateRead() {
//
// EXPECT_THAT(x, m::Next());
// EXPECT_THAT(x, m::Next(m::StateRead("foo"), m::Literal(1)));
// EXPECT_THAT(x, m::Next(m::StateRead("foo"), m::Literal(1),
// m::Literal(1)))
// EXPECT_THAT(x, m::Next(m::StateRead("foo"), m::Literal(1),
// m::Literal(1), "some_label"))
// EXPECT_THAT(x, m::Next(m::StateRead("foo"), m::Literal(1), m::Literal(1)));
// EXPECT_THAT(x, m::NextWithLabel(m::StateRead("foo"), m::Literal(1),
// "label"));
//
// Decoupled forms (asserts that the next node has no StateRead operand):
// EXPECT_THAT(x, m::NextWithStateElement(se, m::Literal(1)));
// EXPECT_THAT(x, m::NextWithStateElement(se, m::Literal(1), m::Literal(1)));
// EXPECT_THAT(x, m::NextWithStateElementWithLabel(se, m::Literal(1),
// "label"));
// EXPECT_THAT(x, m::NextWithStateElementWithLabel(se, m::Literal(1),
// m::Literal(1), "label"));
class NextMatcher : public NodeMatcher {
public:
explicit NextMatcher(
absl::Span<const ::testing::Matcher<const Node*>> operands = {},
std::optional<::testing::Matcher<const std::optional<std::string>&>>
label = std::nullopt)
: NodeMatcher(Op::kNext, operands), label_(std::move(label)) {}
label = std::nullopt,
std::optional<::testing::Matcher<const ::xls::StateElement*>>
state_element = std::nullopt)
: NodeMatcher(Op::kNext, operands),
label_(std::move(label)),
state_element_(std::move(state_element)) {}

bool MatchAndExplain(const Node* node,
::testing::MatchResultListener* listener) const override;
void DescribeTo(::std::ostream* os) const override;

private:
std::optional<::testing::Matcher<const std::optional<std::string>&>> label_;
std::optional<::testing::Matcher<const ::xls::StateElement*>> state_element_;
};

inline ::testing::Matcher<const ::xls::Node*> Next() { return NextMatcher(); }
Expand Down Expand Up @@ -1374,6 +1386,31 @@ inline ::testing::Matcher<const ::xls::Node*> NextWithLabel(
return NextMatcher({state_read, value, predicate}, label);
}

inline ::testing::Matcher<const ::xls::Node*> NextWithStateElement(
::testing::Matcher<const ::xls::StateElement*> state_element,
::testing::Matcher<const Node*> value) {
return NextMatcher({value}, std::nullopt, state_element);
}
inline ::testing::Matcher<const ::xls::Node*> NextWithStateElement(
::testing::Matcher<const ::xls::StateElement*> state_element,
::testing::Matcher<const Node*> value,
::testing::Matcher<const Node*> predicate) {
return NextMatcher({value, predicate}, std::nullopt, state_element);
}
inline ::testing::Matcher<const ::xls::Node*> NextWithStateElementWithLabel(
::testing::Matcher<const ::xls::StateElement*> state_element,
::testing::Matcher<const Node*> value,
::testing::Matcher<const std::optional<std::string>&> label) {
return NextMatcher({value}, label, state_element);
}
inline ::testing::Matcher<const ::xls::Node*> NextWithStateElementWithLabel(
::testing::Matcher<const ::xls::StateElement*> state_element,
::testing::Matcher<const Node*> value,
::testing::Matcher<const Node*> predicate,
::testing::Matcher<const std::optional<std::string>&> label) {
return NextMatcher({value, predicate}, label, state_element);
}

// RegisterRead matcher. Matches register name only. Supported forms:
//
// EXPECT_THAT(x, m::RegisterRead());
Expand Down
38 changes: 29 additions & 9 deletions xls/ir/proc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1002,22 +1002,42 @@ absl::StatusOr<StateElement*> Proc::TransformStateElement(

// Identity-ify the old next nodes and create new ones.
for (const NextTransformation& nt : transforms) {
// Make the next
XLS_ASSIGN_OR_RETURN(
Next * nxt,
MakeNodeWithName<Next>(nt.old_next->loc(), new_state_read, nt.new_value,
nt.new_predicate, nt.old_next->label(),
nt.old_next->GetName()));
Next* nxt;
if (nt.old_next->has_state_read()) {
// Coupled: use the matched new_state_read
XLS_ASSIGN_OR_RETURN(
nxt,
MakeNodeWithName<Next>(nt.old_next->loc(), new_state_read,
nt.new_value, nt.new_predicate,
nt.old_next->label(), nt.old_next->GetName()));
} else {
// Decoupled: use new_state_element directly
XLS_ASSIGN_OR_RETURN(
nxt,
MakeNodeWithName<Next>(nt.old_next->loc(), new_state_element,
nt.new_value, nt.new_predicate,
nt.old_next->label(), nt.old_next->GetName()));
}
to_replace.push_back({nt.old_next, nxt});
// Identity-ify the old next.
XLS_RETURN_IF_ERROR(nt.old_next->ReplaceOperandNumber(
Next::kValueOperand, nt.old_next->state_read()));
if (nt.old_next->has_state_read()) {
XLS_RETURN_IF_ERROR(nt.old_next->ReplaceOperandNumber(
Next::kValueOperand, nt.old_next->state_read()));
} else {
XLS_ASSIGN_OR_RETURN(
Node * placeholder,
MakeNode<Literal>(nt.old_next->loc(),
ZeroOfType(old_state_element->type())));
XLS_RET_CHECK(
nt.old_next->ReplaceOperand(nt.old_next->value(), placeholder));
}
}
for (const auto& [old_n, new_n] : to_replace) {
XLS_RETURN_IF_ERROR(old_n->ReplaceUsesWith(
new_n,
[&](Node* n) {
if (n->Is<Next>() && n->As<Next>()->state_read() == old_n) {
if (n->Is<Next>() && n->As<Next>()->has_state_read() &&
n->As<Next>()->state_read() == old_n) {
return false;
}
return true;
Expand Down
4 changes: 2 additions & 2 deletions xls/ir/proc.h
Original file line number Diff line number Diff line change
Expand Up @@ -200,8 +200,8 @@ class Proc : public FunctionBase {
// switch users of the old param to the new one.
//
// The old state element will continue to exist with a new name and all
// identity next nodes. It should be cleaned up using the
// NextValueOptimizationPass.
// identity next nodes with all users removed. It should be cleaned up by
// ProcStateOptimizationPass in RemoveUnobservableStateElements.
//
// The proc must only use 'next' nodes to call this function.
absl::StatusOr<StateElement*> TransformStateElement(
Expand Down
136 changes: 108 additions & 28 deletions xls/ir/proc_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include <optional>
#include <string>
#include <string_view>
#include <vector>

#include "gmock/gmock.h"
#include "gtest/gtest.h"
Expand Down Expand Up @@ -57,7 +58,36 @@ using ::testing::HasSubstr;
using ::testing::IsEmpty;
using ::testing::UnorderedElementsAre;

class ProcTest : public IrTestBase {};
class ProcTest : public IrTestBase {
protected:
struct TestTransformer : public Proc::StateElementTransformer {
public:
absl::StatusOr<Node*> TransformStateRead(
Proc* proc, StateRead* new_state_read,
StateRead* old_state_read) override {
return proc->MakeNode<UnOp>(new_state_read->loc(), new_state_read,
Op::kNeg);
}
absl::StatusOr<Node*> TransformNextValue(Proc* proc,
StateRead* new_state_read,
Next* old_next) override {
return proc->MakeNode<UnOp>(old_next->value()->loc(), old_next->value(),
Op::kNeg);
}
absl::StatusOr<std::optional<Node*>> TransformNextPredicate(
Proc* proc, StateRead* new_state_read, Next* old_next) override {
XLS_ASSIGN_OR_RETURN(
Node * true_const,
proc->MakeNode<Literal>(old_next->loc(), Value::Bool(true)));
if (old_next->predicate()) {
return proc->MakeNode<NaryOp>(
old_next->predicate().value()->loc(),
std::array<Node*, 2>{true_const, *old_next->predicate()}, Op::kAnd);
}
return true_const;
}
};
};

TEST_F(ProcTest, SimpleProc) {
auto p = CreatePackage();
Expand Down Expand Up @@ -508,33 +538,6 @@ TEST_F(ProcTest, TransformStateElement) {

XLS_ASSERT_OK_AND_ASSIGN(Proc * proc, pb.Build());
// Test transformer that inverts the param.
struct TestTransformer : public Proc::StateElementTransformer {
public:
absl::StatusOr<Node*> TransformStateRead(
Proc* proc, StateRead* new_state_read,
StateRead* old_state_read) override {
return proc->MakeNode<UnOp>(new_state_read->loc(), new_state_read,
Op::kNeg);
}
absl::StatusOr<Node*> TransformNextValue(Proc* proc,
StateRead* new_state_read,
Next* old_next) override {
return proc->MakeNode<UnOp>(old_next->value()->loc(), old_next->value(),
Op::kNeg);
}
absl::StatusOr<std::optional<Node*>> TransformNextPredicate(
Proc* proc, StateRead* new_state_read, Next* old_next) override {
XLS_ASSIGN_OR_RETURN(
Node * true_const,
proc->MakeNode<Literal>(old_next->loc(), Value::Bool(true)));
if (old_next->predicate()) {
return proc->MakeNode<NaryOp>(
old_next->predicate().value()->loc(),
std::array<Node*, 2>{true_const, *old_next->predicate()}, Op::kAnd);
}
return true_const;
}
};
TestTransformer tt;
ScopedRecordIr sri(p.get());
XLS_ASSERT_OK_AND_ASSIGN(
Expand Down Expand Up @@ -568,6 +571,83 @@ TEST_F(ProcTest, TransformStateElement) {
EXPECT_THAT(user.node(), m::Tuple(m::Neg(new_st)));
}

TEST_F(ProcTest, TransformStateElementDecoupled) {
auto p = CreatePackage();
TokenlessProcBuilder pb(TestName(), "tkn", p.get());
auto cond = pb.StateElement("cond", UBits(0, 1));

XLS_ASSERT_OK_AND_ASSIGN(
StateElement * state_element,
pb.UnreadStateElement("st", Value(UBits(0b1010, 4))));

BValue st_read = pb.StateRead(state_element, std::nullopt, "my_read_label");

// Labeled next node
BValue add_st =
pb.Next(state_element, pb.Add(st_read, pb.Literal(UBits(1, 4))), cond,
/*label=*/"my_next_label");
// Unlabeled next node
BValue sub_st =
pb.Next(state_element, pb.Subtract(st_read, pb.Literal(UBits(1, 4))),
pb.Not(cond));

XLS_ASSERT_OK_AND_ASSIGN(Proc * proc, pb.Build());
// Test transformer that inverts the param.
TestTransformer tt;
ScopedRecordIr sri(p.get());
XLS_ASSERT_OK_AND_ASSIGN(
StateElement * new_st_element,
proc->TransformStateElement(state_element, Value(UBits(0b0101, 4)), tt));
StateRead* new_st = proc->GetStateReadByStateElement(new_st_element);

// Make sure the st next has been identity-ified (dummy value Zero is set
// for decoupled)
EXPECT_THAT(st_read.node(), m::StateRead(testing::Not("st")));
EXPECT_THAT(st_read.node()->users(), ::testing::IsEmpty());

// Verify old next nodes are identity-ified (labeled and unlabeled)
EXPECT_THAT(add_st.node(), m::NextWithStateElementWithLabel(
state_element, m::Literal(0), cond.node(),
std::optional<std::string>("my_next_label")));
EXPECT_THAT(sub_st.node(),
m::NextWithStateElement(state_element, m::Literal(0),
m::Not(cond.node())));

// Make sure that 'new_state_read' takes over the name and everything.
EXPECT_THAT(new_st,
m::StateRead("st", std::optional<std::string>("my_read_label")));
EXPECT_THAT(new_st->users(), UnorderedElementsAre(m::Neg(new_st)));

// Verify new next nodes (labeled and unlabeled)
EXPECT_THAT(proc->next_values(new_st_element),
UnorderedElementsAre(
m::NextWithStateElementWithLabel(
new_st_element,
m::Neg(m::Add(m::Neg(new_st), m::Literal(UBits(1, 4)))),
m::And(m::Literal(UBits(1, 1)), cond.node()),
std::optional<std::string>("my_next_label")),
m::NextWithStateElement(
new_st_element,
m::Neg(m::Sub(m::Neg(new_st), m::Literal(UBits(1, 4)))),
m::And(m::Literal(UBits(1, 1)), m::Not(cond.node())))));

XLS_ASSERT_OK_AND_ASSIGN(int64_t old_idx,
proc->GetStateElementIndex(state_element));

// Remove the old Next nodes.
std::vector<Next*> old_nexts(proc->next_values(state_element).begin(),
proc->next_values(state_element).end());
for (Next* next : old_nexts) {
XLS_ASSERT_OK(proc->RemoveNode(next));
}

// Remove the old state element (and its StateRead) with no users.
XLS_EXPECT_OK(proc->RemoveStateElement(old_idx));

// Verify only 'cond' and 'new_st' remain.
EXPECT_EQ(proc->GetStateElementCount(), 2);
}

class ScheduledProcTest : public IrTestBase {
protected:
absl::StatusOr<ScheduledProc*> CreateScheduledProc(Package* p) {
Expand Down
Loading