Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions xls/ir/node_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -636,6 +636,9 @@ class GenericSelect {
GenericSelect& operator=(const GenericSelect&) = default;
GenericSelect& operator=(GenericSelect&&) = default;
static absl::StatusOr<GenericSelect> From(Node* n);
static bool IsSelect(const Node* n) {
return n->OpIn({Op::kSel, Op::kPrioritySel, Op::kOneHotSel});
}

// Assignment operators from underlying types.
GenericSelect& operator=(Select* select) {
Expand Down
7 changes: 7 additions & 0 deletions xls/passes/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1144,9 +1144,15 @@ xls_pass(
hdrs = ["select_lifting_pass.h"],
pass_class = "SelectLiftingPass",
deps = [
":bdd_query_engine",
":critical_path_delay_analysis",
":lazy_ternary_query_engine",
":optimization_pass",
":partial_info_query_engine",
":pass_base",
":query_engine",
":stateless_query_engine",
":union_query_engine",
"//xls/common/status:ret_check",
"//xls/common/status:status_macros",
"//xls/estimators/delay_model:delay_estimator",
Expand Down Expand Up @@ -4602,6 +4608,7 @@ cc_test(
"//xls/ir",
"//xls/ir:bits",
"//xls/ir:function_builder",
"//xls/ir:ir_matcher",
"//xls/ir:ir_test_base",
"//xls/ir:op",
"//xls/solvers:z3_ir_equivalence_testutils",
Expand Down
71 changes: 18 additions & 53 deletions xls/passes/array_simplification_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1381,14 +1381,13 @@ absl::StatusOr<SimplifyResult> SimplifyConditionalAssign(
Node* select, const QueryEngine& query_engine) {
absl::Span<Node* const> original_cases;
std::optional<Node*> original_default_value;
if (select->Is<Select>()) {
original_cases = select->As<Select>()->cases();
original_default_value = select->As<Select>()->default_value();
} else {
XLS_RET_CHECK(select->Is<PrioritySelect>());
original_cases = select->As<PrioritySelect>()->cases();
original_default_value = select->As<PrioritySelect>()->default_value();
XLS_ASSIGN_OR_RETURN(GenericSelect sel, GenericSelect::From(select));
if (select->Is<OneHotSelect>() &&
!query_engine.ExactlyOneBitTrue(sel.selector())) {
return SimplifyResult{.changed = false};
}
original_cases = sel.cases();
original_default_value = sel.default_value();

struct IdentityValue : std::monostate {};
Node* array_to_update = nullptr;
Expand Down Expand Up @@ -1497,24 +1496,9 @@ absl::StatusOr<SimplifyResult> SimplifyConditionalAssign(
}

Node* selected_value;
if (select->Is<Select>()) {
XLS_ASSIGN_OR_RETURN(selected_value,
select->function_base()->MakeNode<Select>(
select->loc(), select->As<Select>()->selector(),
/*cases=*/
case_values,
/*default_value=*/default_value));
} else {
XLS_RET_CHECK(select->Is<PrioritySelect>());
XLS_RET_CHECK(default_value.has_value());
XLS_ASSIGN_OR_RETURN(
selected_value,
select->function_base()->MakeNode<PrioritySelect>(
select->loc(), select->As<PrioritySelect>()->selector(),
/*cases=*/
case_values,
/*default_value=*/*default_value));
}
XLS_ASSIGN_OR_RETURN(
selected_value,
sel.CloneSelectLike(sel.selector(), case_values, default_value));

XLS_ASSIGN_OR_RETURN(
ArrayUpdate * overall_update,
Expand All @@ -1535,16 +1519,9 @@ absl::StatusOr<SimplifyResult> SimplifyConditionalAssign(
// further optimization. On the other hand, this can replicate the select
// logic, which can be expensive in area, so we limit by the number of cases.
absl::StatusOr<SimplifyResult> SimplifySelectOfArrays(Node* select) {
absl::Span<Node* const> original_cases;
std::optional<Node*> original_default_value;
if (select->Is<Select>()) {
original_cases = select->As<Select>()->cases();
original_default_value = select->As<Select>()->default_value();
} else {
XLS_RET_CHECK(select->Is<PrioritySelect>());
original_cases = select->As<PrioritySelect>()->cases();
original_default_value = select->As<PrioritySelect>()->default_value();
}
XLS_ASSIGN_OR_RETURN(GenericSelect sel, GenericSelect::From(select));
absl::Span<Node* const> original_cases = sel.cases();
std::optional<Node*> original_default_value = sel.default_value();

for (Node* sel_case : original_cases) {
if (!sel_case->Is<Array>()) {
Expand Down Expand Up @@ -1578,21 +1555,9 @@ absl::StatusOr<SimplifyResult> SimplifySelectOfArrays(Node* select) {
default_element = original_default_value.value()->operand(i);
}
Node* selected_element;
if (select->Is<Select>()) {
XLS_ASSIGN_OR_RETURN(
selected_element,
select->function_base()->MakeNode<Select>(
select->loc(), select->As<Select>()->selector(),
/*cases=*/elements, /*default=*/default_element));
} else {
XLS_RET_CHECK(select->Is<PrioritySelect>());
XLS_RET_CHECK(default_element.has_value());
XLS_ASSIGN_OR_RETURN(
selected_element,
select->function_base()->MakeNode<PrioritySelect>(
select->loc(), select->As<PrioritySelect>()->selector(),
/*cases=*/elements, /*default=*/*default_element));
}
XLS_ASSIGN_OR_RETURN(
selected_element,
sel.CloneSelectLike(sel.selector(), elements, default_element));
selected_elements.push_back(selected_element);
}
XLS_ASSIGN_OR_RETURN(Array * new_array,
Expand Down Expand Up @@ -1655,7 +1620,7 @@ absl::StatusOr<SimplifyResult> SimplifyArraySlice(
// Simplify various forms of a select of array-typed values.
absl::StatusOr<SimplifyResult> SimplifySelect(Node* select,
const QueryEngine& query_engine) {
XLS_RET_CHECK(select->Is<Select>() || select->Is<PrioritySelect>());
XLS_RET_CHECK(GenericSelect::IsSelect(select));

XLS_ASSIGN_OR_RETURN(SimplifyResult conditional_assign_result,
SimplifyConditionalAssign(select, query_engine));
Expand Down Expand Up @@ -1740,7 +1705,7 @@ absl::StatusOr<bool> ArraySimplificationPass::RunOnFunctionBaseInternal(
for (Node* node : reverse_topo_sort_nodes) {
if (!node->IsDead() &&
node->OpIn({Op::kArray, Op::kArrayIndex, Op::kArrayUpdate, Op::kSel,
Op::kPrioritySel, Op::kArraySlice})) {
Op::kPrioritySel, Op::kOneHotSel, Op::kArraySlice})) {
add_to_worklist(node, false);
}
}
Expand All @@ -1766,7 +1731,7 @@ absl::StatusOr<bool> ArraySimplificationPass::RunOnFunctionBaseInternal(
} else if (node->Is<Array>()) {
XLS_ASSIGN_OR_RETURN(result,
SimplifyArray(node->As<Array>(), query_engine));
} else if (node->Is<Select>() || node->Is<PrioritySelect>()) {
} else if (GenericSelect::IsSelect(node)) {
XLS_ASSIGN_OR_RETURN(result, SimplifySelect(node, query_engine));
} else if (node->Is<ArraySlice>()) {
XLS_ASSIGN_OR_RETURN(
Expand Down
77 changes: 77 additions & 0 deletions xls/passes/array_simplification_pass_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1722,6 +1722,83 @@ TEST_F(ArraySimplificationPassTest, IndexOfOneHotSelect) {
m::ArrayIndex(m::Param("b"), {m::Param("i"), m::Param("j")})}));
}

TEST_F(ArraySimplificationPassTest,
SimplifyConditionalAssignWithOneHotSelectAndOneHotSelector) {
Package* p = GetPackage();
FunctionBuilder fb(TestName(), p);
Type* u32 = p->GetBitsType(32);
BValue a_param = fb.Param("a", u32);
BValue b_param = fb.Param("b", u32);
BValue c_param = fb.Param("c", u32);
BValue A = fb.Array({a_param, b_param, c_param}, u32);

BValue x = fb.Param("x", p->GetBitsType(1));
BValue selector = fb.Concat({x, fb.Not(x)});

BValue v = fb.Param("v", u32);
BValue idx = fb.Literal(UBits(1, 32));
BValue A_updated = fb.ArrayUpdate(A, v, {idx});

BValue ohs = fb.OneHotSelect(selector, {A, A_updated});

XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.BuildWithReturnValue(ohs));
ASSERT_THAT(Run(f), IsOkAndHolds(true));

EXPECT_THAT(
f->return_value(),
m::Array(m::Param("a"),
m::OneHotSelect(m::Concat(m::Param("x"), m::Not(m::Param("x"))),
{m::Param("b"), m::Param("v")}),
m::Param("c")));
}

TEST_F(ArraySimplificationPassTest,
NoSimplifyConditionalAssignWithOneHotSelectAndNonOneHotSelector) {
Package* p = GetPackage();
FunctionBuilder fb(TestName(), p);
Type* u32 = p->GetBitsType(32);
BValue a_param = fb.Param("a", u32);
BValue b_param = fb.Param("b", u32);
BValue c_param = fb.Param("c", u32);
BValue A = fb.Array({a_param, b_param, c_param}, u32);

BValue selector = fb.Param("selector", p->GetBitsType(2));

BValue v = fb.Param("v", u32);
BValue idx = fb.Literal(UBits(1, 32));
BValue A_updated = fb.ArrayUpdate(A, v, {idx});

BValue ohs = fb.OneHotSelect(selector, {A, A_updated});

XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.BuildWithReturnValue(ohs));
ASSERT_THAT(Run(f), IsOkAndHolds(false));
}

TEST_F(ArraySimplificationPassTest, SimplifySelectOfArraysWithOneHotSelect) {
Package* p = GetPackage();
FunctionBuilder fb(TestName(), p);
Type* u32 = p->GetBitsType(32);
BValue a = fb.Param("a", u32);
BValue b = fb.Param("b", u32);
BValue c = fb.Param("c", u32);
BValue d = fb.Param("d", u32);

BValue A = fb.Array({a, b}, u32);
BValue B = fb.Array({c, d}, u32);

BValue selector = fb.Param("selector", p->GetBitsType(2));
BValue ohs = fb.OneHotSelect(selector, {A, B});

XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.BuildWithReturnValue(ohs));
ASSERT_THAT(Run(f), IsOkAndHolds(true));

EXPECT_THAT(f->return_value(),
m::Array(m::OneHotSelect(m::Param("selector"),
{m::Param("a"), m::Param("c")}),
m::OneHotSelect(m::Param("selector"),
{m::Param("b"), m::Param("d")})));
}

void IrFuzzArraySimplification(FuzzPackageWithArgs fuzz_package_with_args) {
ArraySimplificationPass pass;
OptimizationPassChangesOutputs(std::move(fuzz_package_with_args), pass);
Expand Down
Loading
Loading