Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions xls/contrib/mlir/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -367,11 +367,13 @@ cc_library(
":xls_transforms_passes",
":xls_transforms_passes_inc_gen",
":xls_translate_lib",
"//xls/ir:clone_package",
"//xls/passes:pass_pipeline_cc_proto",
"//xls/tools:opt",
"@com_google_absl//absl/status",
"@com_google_protobuf//:protobuf",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
Expand Down
96 changes: 77 additions & 19 deletions xls/contrib/mlir/transforms/optimize_using_xls.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,15 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include <cstddef>
#include <memory>
#include <optional>
#include <string>
#include <utility>

#include "absl/status/status.h"
#include "llvm/include/llvm/Support/DebugLog.h"
#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/include/mlir/IR/BuiltinOps.h"
#include "mlir/include/mlir/IR/Diagnostics.h"
#include "mlir/include/mlir/IR/MLIRContext.h"
Expand All @@ -29,6 +32,7 @@
#include "xls/contrib/mlir/tools/xls_translate/xls_translate_from_mlir.h"
#include "xls/contrib/mlir/tools/xls_translate/xls_translate_to_mlir.h"
#include "xls/contrib/mlir/transforms/passes.h" // IWYU pragma: keep
#include "xls/ir/clone_package.h"
#include "xls/passes/pass_pipeline.pb.h"
#include "xls/tools/opt.h"

Expand Down Expand Up @@ -64,16 +68,15 @@ void OptimizeUsingXlsPass::runOnOperation() {
}

LogicalResult optimizeUsingXls(ModuleOp module, DslxPackageCache& dslx_cache,
std::optional<std::string> xls_pipeline) {
std::optional<std::string> xls_pipeline,
ArrayRef<StringRef> tops) {
FailureOr<std::unique_ptr<::xls::Package>> package =
mlirXlsToXls(module,
/*dslx_search_path=*/"", dslx_cache);
mlirXlsToXls(module, /*dslx_search_path=*/"", dslx_cache);
if (failed(package)) {
return failure();
}

::xls::tools::OptOptions opt_options;
opt_options.top = module.getName().value_or("_package");
if (xls_pipeline.has_value()) {
::xls::PassPipelineProto pass_pipeline;
if (!google::protobuf::TextFormat::ParseFromString(*xls_pipeline, &pass_pipeline)) {
Expand All @@ -83,25 +86,80 @@ LogicalResult optimizeUsingXls(ModuleOp module, DslxPackageCache& dslx_cache,
opt_options.pass_pipeline = pass_pipeline;
}

if (!xls_pipeline.has_value() || !xls_pipeline->empty()) {
LDBG() << "Optimizing IR for top: '" << opt_options.top << "using \n\t"
<< xls_pipeline.value_or("default pipeline");
bool is_empty_tops = tops.empty();
std::string default_top;
SmallVector<StringRef> actual_tops(tops.begin(), tops.end());

absl::Status status =
::xls::tools::OptimizeIrForTop(package->get(), opt_options);
if (!status.ok()) {
return module.emitError("failed to optimize IR: ") << status.ToString();
}
if (is_empty_tops) {
default_top = module.getName().value_or("_package").str();
actual_tops.push_back(default_top);
}

OwningOpRef<Operation*> new_module_op =
XlsToMlirXlsTranslate(**package, module.getContext());
if (!new_module_op) {
return module.emitError(
"failed to translate optimized XLS IR back to MLIR");
for (size_t i = 0; i < actual_tops.size(); ++i) {
StringRef top = actual_tops[i];
bool is_last = (i == actual_tops.size() - 1);

::xls::Package* pkg_to_optimize;
std::unique_ptr<::xls::Package> pkg_clone;

if (is_last) {
// For the last (or only) top, we don't need to clone. We can just mutate
// the original package directly.
pkg_to_optimize = package->get();
} else {
auto pkg_clone_or = ::xls::ClonePackage((*package).get());
if (!pkg_clone_or.ok()) {
return module.emitError("failed to clone package: ")
<< pkg_clone_or.status().ToString();
}
pkg_clone = std::move(pkg_clone_or).value();
pkg_to_optimize = pkg_clone.get();
}

opt_options.top = top.str();

if (!xls_pipeline.has_value() || !xls_pipeline->empty()) {
LDBG() << "Optimizing IR for top: '" << opt_options.top << "' using \n\t"
<< xls_pipeline.value_or("default pipeline");

absl::Status status =
::xls::tools::OptimizeIrForTop(pkg_to_optimize, opt_options);
if (!status.ok()) {
return module.emitError("failed to optimize IR: ") << status.ToString();
}
}

OwningOpRef<Operation*> new_module_op =
XlsToMlirXlsTranslate(*pkg_to_optimize, module.getContext());
if (!new_module_op) {
return module.emitError(
"failed to translate optimized XLS IR back to MLIR");
}

if (is_empty_tops) {
// If no explicit tops were given, preserve the original
// behavior of replacing the entire module body.
module.getBodyRegion().takeBody(
cast<ModuleOp>(*new_module_op).getBodyRegion());
} else {
// Merge the optimized top function back into the module.
// NOTE: This assumes XLS optimization preserves the function signature.
// If the function type were to change, we would also need to update all
// call sites in the module to avoid producing invalid IR.
ModuleOp new_module = cast<ModuleOp>(*new_module_op);
auto optimized_func = new_module.lookupSymbol<mlir::func::FuncOp>(top);
if (!optimized_func) {
return module.emitError("could not find optimized func ") << top;
}
auto original_func = module.lookupSymbol<mlir::func::FuncOp>(top);
if (!original_func) {
return module.emitError("could not find original func ") << top;
}
original_func.setType(optimized_func.getFunctionType());
original_func.getBody().takeBody(optimized_func.getBody());
}
}
module.getBodyRegion().takeBody(
cast<ModuleOp>(*new_module_op).getBodyRegion());

return success();
}

Expand Down
8 changes: 6 additions & 2 deletions xls/contrib/mlir/transforms/passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,14 @@ class DslxPackageCache;
#define GEN_PASS_REGISTRATION
#include "xls/contrib/mlir/transforms/passes.h.inc" // IWYU pragma: export

// Optimizes the given MLIR module using XLS.
// Optimizes the given MLIR module using XLS. When tops is empty (default),
// the module name is used as the single top and the entire module body is
// replaced. When tops is non-empty, only the specified functions are optimized
// and spliced back into the module.
LogicalResult optimizeUsingXls(
ModuleOp module, DslxPackageCache& dslx_cache,
std::optional<std::string> xls_pipeline = std::nullopt);
std::optional<std::string> xls_pipeline = std::nullopt,
ArrayRef<StringRef> tops = {});

} // namespace mlir::xls

Expand Down
Loading