Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions xls/ir/node_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -636,6 +636,9 @@ class GenericSelect {
GenericSelect& operator=(const GenericSelect&) = default;
GenericSelect& operator=(GenericSelect&&) = default;
static absl::StatusOr<GenericSelect> From(Node* n);
static bool IsSelect(const Node* n) {
return n->OpIn({Op::kSel, Op::kPrioritySel, Op::kOneHotSel});
}

// Assignment operators from underlying types.
GenericSelect& operator=(Select* select) {
Expand Down
7 changes: 7 additions & 0 deletions xls/passes/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1144,9 +1144,15 @@ xls_pass(
hdrs = ["select_lifting_pass.h"],
pass_class = "SelectLiftingPass",
deps = [
":bdd_query_engine",
":critical_path_delay_analysis",
":lazy_ternary_query_engine",
":optimization_pass",
":partial_info_query_engine",
":pass_base",
":query_engine",
":stateless_query_engine",
":union_query_engine",
"//xls/common/status:ret_check",
"//xls/common/status:status_macros",
"//xls/estimators/delay_model:delay_estimator",
Expand Down Expand Up @@ -4602,6 +4608,7 @@ cc_test(
"//xls/ir",
"//xls/ir:bits",
"//xls/ir:function_builder",
"//xls/ir:ir_matcher",
"//xls/ir:ir_test_base",
"//xls/ir:op",
"//xls/solvers:z3_ir_equivalence_testutils",
Expand Down
71 changes: 18 additions & 53 deletions xls/passes/array_simplification_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1381,14 +1381,13 @@ absl::StatusOr<SimplifyResult> SimplifyConditionalAssign(
Node* select, const QueryEngine& query_engine) {
absl::Span<Node* const> original_cases;
std::optional<Node*> original_default_value;
if (select->Is<Select>()) {
original_cases = select->As<Select>()->cases();
original_default_value = select->As<Select>()->default_value();
} else {
XLS_RET_CHECK(select->Is<PrioritySelect>());
original_cases = select->As<PrioritySelect>()->cases();
original_default_value = select->As<PrioritySelect>()->default_value();
XLS_ASSIGN_OR_RETURN(GenericSelect sel, GenericSelect::From(select));
if (select->Is<OneHotSelect>() &&
!query_engine.ExactlyOneBitTrue(sel.selector())) {
return SimplifyResult{.changed = false};
}
original_cases = sel.cases();
original_default_value = sel.default_value();

struct IdentityValue : std::monostate {};
Node* array_to_update = nullptr;
Expand Down Expand Up @@ -1497,24 +1496,9 @@ absl::StatusOr<SimplifyResult> SimplifyConditionalAssign(
}

Node* selected_value;
if (select->Is<Select>()) {
XLS_ASSIGN_OR_RETURN(selected_value,
select->function_base()->MakeNode<Select>(
select->loc(), select->As<Select>()->selector(),
/*cases=*/
case_values,
/*default_value=*/default_value));
} else {
XLS_RET_CHECK(select->Is<PrioritySelect>());
XLS_RET_CHECK(default_value.has_value());
XLS_ASSIGN_OR_RETURN(
selected_value,
select->function_base()->MakeNode<PrioritySelect>(
select->loc(), select->As<PrioritySelect>()->selector(),
/*cases=*/
case_values,
/*default_value=*/*default_value));
}
XLS_ASSIGN_OR_RETURN(
selected_value,
sel.CloneSelectLike(sel.selector(), case_values, default_value));

XLS_ASSIGN_OR_RETURN(
ArrayUpdate * overall_update,
Expand All @@ -1535,16 +1519,9 @@ absl::StatusOr<SimplifyResult> SimplifyConditionalAssign(
// further optimization. On the other hand, this can replicate the select
// logic, which can be expensive in area, so we limit by the number of cases.
absl::StatusOr<SimplifyResult> SimplifySelectOfArrays(Node* select) {
absl::Span<Node* const> original_cases;
std::optional<Node*> original_default_value;
if (select->Is<Select>()) {
original_cases = select->As<Select>()->cases();
original_default_value = select->As<Select>()->default_value();
} else {
XLS_RET_CHECK(select->Is<PrioritySelect>());
original_cases = select->As<PrioritySelect>()->cases();
original_default_value = select->As<PrioritySelect>()->default_value();
}
XLS_ASSIGN_OR_RETURN(GenericSelect sel, GenericSelect::From(select));
absl::Span<Node* const> original_cases = sel.cases();
std::optional<Node*> original_default_value = sel.default_value();

for (Node* sel_case : original_cases) {
if (!sel_case->Is<Array>()) {
Expand Down Expand Up @@ -1578,21 +1555,9 @@ absl::StatusOr<SimplifyResult> SimplifySelectOfArrays(Node* select) {
default_element = original_default_value.value()->operand(i);
}
Node* selected_element;
if (select->Is<Select>()) {
XLS_ASSIGN_OR_RETURN(
selected_element,
select->function_base()->MakeNode<Select>(
select->loc(), select->As<Select>()->selector(),
/*cases=*/elements, /*default=*/default_element));
} else {
XLS_RET_CHECK(select->Is<PrioritySelect>());
XLS_RET_CHECK(default_element.has_value());
XLS_ASSIGN_OR_RETURN(
selected_element,
select->function_base()->MakeNode<PrioritySelect>(
select->loc(), select->As<PrioritySelect>()->selector(),
/*cases=*/elements, /*default=*/*default_element));
}
XLS_ASSIGN_OR_RETURN(
selected_element,
sel.CloneSelectLike(sel.selector(), elements, default_element));
selected_elements.push_back(selected_element);
}
XLS_ASSIGN_OR_RETURN(Array * new_array,
Expand Down Expand Up @@ -1655,7 +1620,7 @@ absl::StatusOr<SimplifyResult> SimplifyArraySlice(
// Simplify various forms of a select of array-typed values.
absl::StatusOr<SimplifyResult> SimplifySelect(Node* select,
const QueryEngine& query_engine) {
XLS_RET_CHECK(select->Is<Select>() || select->Is<PrioritySelect>());
XLS_RET_CHECK(GenericSelect::IsSelect(select));

XLS_ASSIGN_OR_RETURN(SimplifyResult conditional_assign_result,
SimplifyConditionalAssign(select, query_engine));
Expand Down Expand Up @@ -1740,7 +1705,7 @@ absl::StatusOr<bool> ArraySimplificationPass::RunOnFunctionBaseInternal(
for (Node* node : reverse_topo_sort_nodes) {
if (!node->IsDead() &&
node->OpIn({Op::kArray, Op::kArrayIndex, Op::kArrayUpdate, Op::kSel,
Op::kPrioritySel, Op::kArraySlice})) {
Op::kPrioritySel, Op::kOneHotSel, Op::kArraySlice})) {
add_to_worklist(node, false);
}
}
Expand All @@ -1766,7 +1731,7 @@ absl::StatusOr<bool> ArraySimplificationPass::RunOnFunctionBaseInternal(
} else if (node->Is<Array>()) {
XLS_ASSIGN_OR_RETURN(result,
SimplifyArray(node->As<Array>(), query_engine));
} else if (node->Is<Select>() || node->Is<PrioritySelect>()) {
} else if (GenericSelect::IsSelect(node)) {
XLS_ASSIGN_OR_RETURN(result, SimplifySelect(node, query_engine));
} else if (node->Is<ArraySlice>()) {
XLS_ASSIGN_OR_RETURN(
Expand Down
77 changes: 77 additions & 0 deletions xls/passes/array_simplification_pass_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1722,6 +1722,83 @@ TEST_F(ArraySimplificationPassTest, IndexOfOneHotSelect) {
m::ArrayIndex(m::Param("b"), {m::Param("i"), m::Param("j")})}));
}

TEST_F(ArraySimplificationPassTest,
SimplifyConditionalAssignWithOneHotSelectAndOneHotSelector) {
Package* p = GetPackage();
FunctionBuilder fb(TestName(), p);
Type* u32 = p->GetBitsType(32);
BValue a_param = fb.Param("a", u32);
BValue b_param = fb.Param("b", u32);
BValue c_param = fb.Param("c", u32);
BValue A = fb.Array({a_param, b_param, c_param}, u32);

BValue x = fb.Param("x", p->GetBitsType(1));
BValue selector = fb.Concat({x, fb.Not(x)});

BValue v = fb.Param("v", u32);
BValue idx = fb.Literal(UBits(1, 32));
BValue A_updated = fb.ArrayUpdate(A, v, {idx});

BValue ohs = fb.OneHotSelect(selector, {A, A_updated});

XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.BuildWithReturnValue(ohs));
ASSERT_THAT(Run(f), IsOkAndHolds(true));

EXPECT_THAT(
f->return_value(),
m::Array(m::Param("a"),
m::OneHotSelect(m::Concat(m::Param("x"), m::Not(m::Param("x"))),
{m::Param("b"), m::Param("v")}),
m::Param("c")));
}

TEST_F(ArraySimplificationPassTest,
NoSimplifyConditionalAssignWithOneHotSelectAndNonOneHotSelector) {
Package* p = GetPackage();
FunctionBuilder fb(TestName(), p);
Type* u32 = p->GetBitsType(32);
BValue a_param = fb.Param("a", u32);
BValue b_param = fb.Param("b", u32);
BValue c_param = fb.Param("c", u32);
BValue A = fb.Array({a_param, b_param, c_param}, u32);

BValue selector = fb.Param("selector", p->GetBitsType(2));

BValue v = fb.Param("v", u32);
BValue idx = fb.Literal(UBits(1, 32));
BValue A_updated = fb.ArrayUpdate(A, v, {idx});

BValue ohs = fb.OneHotSelect(selector, {A, A_updated});

XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.BuildWithReturnValue(ohs));
ASSERT_THAT(Run(f), IsOkAndHolds(false));
}

TEST_F(ArraySimplificationPassTest, SimplifySelectOfArraysWithOneHotSelect) {
Package* p = GetPackage();
FunctionBuilder fb(TestName(), p);
Type* u32 = p->GetBitsType(32);
BValue a = fb.Param("a", u32);
BValue b = fb.Param("b", u32);
BValue c = fb.Param("c", u32);
BValue d = fb.Param("d", u32);

BValue A = fb.Array({a, b}, u32);
BValue B = fb.Array({c, d}, u32);

BValue selector = fb.Param("selector", p->GetBitsType(2));
BValue ohs = fb.OneHotSelect(selector, {A, B});

XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.BuildWithReturnValue(ohs));
ASSERT_THAT(Run(f), IsOkAndHolds(true));

EXPECT_THAT(f->return_value(),
m::Array(m::OneHotSelect(m::Param("selector"),
{m::Param("a"), m::Param("c")}),
m::OneHotSelect(m::Param("selector"),
{m::Param("b"), m::Param("d")})));
}

void IrFuzzArraySimplification(FuzzPackageWithArgs fuzz_package_with_args) {
ArraySimplificationPass pass;
OptimizationPassChangesOutputs(std::move(fuzz_package_with_args), pass);
Expand Down
97 changes: 41 additions & 56 deletions xls/passes/bit_slice_simplification_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -625,27 +625,34 @@ absl::StatusOr<bool> SimplifyLiteralBitSliceUpdate(BitSliceUpdate* update,
return true;
}

bool IsSelectOfLiterals(Node* node, QueryEngine* query_engine) {
bool IsSelectOfLiterals(GenericSelect sel, QueryEngine* query_engine) {
// Only allow OneHotSelect if the selector is actually one-hot.
if (sel.kind() == GenericSelect::Kind::kOneHotSel &&
!query_engine->ExactlyOneBitTrue(sel.selector())) {
return false;
}

auto is_literal_or_select_of_literals = [&](Node* node) {
return query_engine->IsFullyKnown(node) ||
IsSelectOfLiterals(node, query_engine);
if (query_engine->IsFullyKnown(node)) {
return true;
}
if (!GenericSelect::IsSelect(node)) {
return false;
}
absl::StatusOr<GenericSelect> node_sel = GenericSelect::From(node);
CHECK_OK(node_sel);
return IsSelectOfLiterals(*std::move(node_sel), query_engine);
};

if (node->Is<Select>()) {
Select* sel = node->As<Select>();
return sel->AllCases([&](Node* case_value) {
return is_literal_or_select_of_literals(case_value);
});
if (!absl::c_all_of(sel.cases(), [&](Node* case_value) {
return is_literal_or_select_of_literals(case_value);
})) {
return false;
}
if (node->Is<PrioritySelect>()) {
PrioritySelect* sel = node->As<PrioritySelect>();
return absl::c_all_of(sel->cases(),
[&](Node* case_value) {
return is_literal_or_select_of_literals(case_value);
}) &&
is_literal_or_select_of_literals(sel->default_value());
if (sel.default_value().has_value()) {
return is_literal_or_select_of_literals(*sel.default_value());
}
return false;
return true;
}

absl::StatusOr<Node*> LiftThroughSelectsOfLiterals(
Expand All @@ -656,62 +663,36 @@ absl::StatusOr<Node*> LiftThroughSelectsOfLiterals(
return lift_to_literal(*known_value);
}

Node* selector;
absl::Span<Node* const> cases;
std::optional<Node*> default_value;
if (node->Is<Select>()) {
Select* sel = node->As<Select>();
selector = sel->selector();
cases = sel->cases();
default_value = sel->default_value();
} else if (node->Is<PrioritySelect>()) {
PrioritySelect* sel = node->As<PrioritySelect>();
selector = sel->selector();
cases = sel->cases();
default_value = sel->default_value();
} else {
return absl::InternalError(
absl::StrCat("LiftThroughSelectsOfLiterals invoked on a node that was "
"not a select of literals: ",
node->ToString()));
}
XLS_ASSIGN_OR_RETURN(GenericSelect sel, GenericSelect::From(node));

std::vector<Node*> new_cases;
std::optional<Node*> new_default_value = std::nullopt;
new_cases.reserve(cases.size());
for (Node* case_value : cases) {
new_cases.reserve(sel.cases().size());
for (Node* case_value : sel.cases()) {
XLS_ASSIGN_OR_RETURN(Node * new_case_value,
LiftThroughSelectsOfLiterals(case_value, query_engine,
lift_to_literal));
new_cases.push_back(new_case_value);
}
if (default_value.has_value()) {
XLS_ASSIGN_OR_RETURN(new_default_value,
LiftThroughSelectsOfLiterals(
*default_value, query_engine, lift_to_literal));
if (sel.default_value().has_value()) {
XLS_ASSIGN_OR_RETURN(new_default_value, LiftThroughSelectsOfLiterals(
*sel.default_value(),
query_engine, lift_to_literal));
}

if (node->Is<Select>()) {
return node->function_base()->MakeNode<Select>(
node->loc(), selector, new_cases, new_default_value);
}
if (node->Is<PrioritySelect>()) {
XLS_RET_CHECK(new_default_value.has_value());
return node->function_base()->MakeNode<PrioritySelect>(
node->loc(), selector, new_cases, *new_default_value);
}
XLS_RET_CHECK(node->Is<OneHotSelect>());
XLS_RET_CHECK(!new_default_value.has_value());
return node->function_base()->MakeNode<OneHotSelect>(node->loc(), selector,
new_cases);
return sel.CloneSelectLike(sel.selector(), new_cases, new_default_value);
}

// Hoist bit slice updates above selects of literals, where they can be turned
// into static operations.
absl::StatusOr<bool> SimplifySelectOfLiteralsBitSliceUpdate(
BitSliceUpdate* update, QueryEngine* query_engine) {
Node* start = update->start();
if (!IsSelectOfLiterals(start, query_engine)) {
if (!GenericSelect::IsSelect(start)) {
return false;
}
XLS_ASSIGN_OR_RETURN(GenericSelect sel, GenericSelect::From(start));
if (!IsSelectOfLiterals(sel, query_engine)) {
return false;
}

Expand Down Expand Up @@ -822,7 +803,11 @@ absl::StatusOr<bool> SimplifyLiteralDynamicBitSlice(DynamicBitSlice* bit_slice,
absl::StatusOr<bool> SimplifySelectOfLiteralsDynamicBitSlice(
DynamicBitSlice* bit_slice, QueryEngine* query_engine) {
Node* start = bit_slice->start();
if (!IsSelectOfLiterals(start, query_engine)) {
if (!GenericSelect::IsSelect(start)) {
return false;
}
XLS_ASSIGN_OR_RETURN(GenericSelect sel, GenericSelect::From(start));
if (!IsSelectOfLiterals(sel, query_engine)) {
return false;
}

Expand Down
Loading
Loading