diff --git a/xls/contrib/mlir/BUILD b/xls/contrib/mlir/BUILD index a420bc2d0c..989b126142 100644 --- a/xls/contrib/mlir/BUILD +++ b/xls/contrib/mlir/BUILD @@ -367,11 +367,13 @@ cc_library( ":xls_transforms_passes", ":xls_transforms_passes_inc_gen", ":xls_translate_lib", + "//xls/ir:clone_package", "//xls/passes:pass_pipeline_cc_proto", "//xls/tools:opt", "@com_google_absl//absl/status", "@com_google_protobuf//:protobuf", "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", "@llvm-project//mlir:Support", diff --git a/xls/contrib/mlir/transforms/optimize_using_xls.cc b/xls/contrib/mlir/transforms/optimize_using_xls.cc index a97f68fa90..09b37eba72 100644 --- a/xls/contrib/mlir/transforms/optimize_using_xls.cc +++ b/xls/contrib/mlir/transforms/optimize_using_xls.cc @@ -12,12 +12,15 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include #include #include #include +#include #include "absl/status/status.h" #include "llvm/include/llvm/Support/DebugLog.h" +#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/include/mlir/IR/BuiltinOps.h" #include "mlir/include/mlir/IR/Diagnostics.h" #include "mlir/include/mlir/IR/MLIRContext.h" @@ -29,6 +32,7 @@ #include "xls/contrib/mlir/tools/xls_translate/xls_translate_from_mlir.h" #include "xls/contrib/mlir/tools/xls_translate/xls_translate_to_mlir.h" #include "xls/contrib/mlir/transforms/passes.h" // IWYU pragma: keep +#include "xls/ir/clone_package.h" #include "xls/passes/pass_pipeline.pb.h" #include "xls/tools/opt.h" @@ -64,16 +68,15 @@ void OptimizeUsingXlsPass::runOnOperation() { } LogicalResult optimizeUsingXls(ModuleOp module, DslxPackageCache& dslx_cache, - std::optional xls_pipeline) { + std::optional xls_pipeline, + ArrayRef tops) { FailureOr> package = - mlirXlsToXls(module, - /*dslx_search_path=*/"", dslx_cache); + mlirXlsToXls(module, /*dslx_search_path=*/"", dslx_cache); if (failed(package)) { return failure(); } ::xls::tools::OptOptions opt_options; - opt_options.top = module.getName().value_or("_package"); if (xls_pipeline.has_value()) { ::xls::PassPipelineProto pass_pipeline; if (!google::protobuf::TextFormat::ParseFromString(*xls_pipeline, &pass_pipeline)) { @@ -83,25 +86,80 @@ LogicalResult optimizeUsingXls(ModuleOp module, DslxPackageCache& dslx_cache, opt_options.pass_pipeline = pass_pipeline; } - if (!xls_pipeline.has_value() || !xls_pipeline->empty()) { - LDBG() << "Optimizing IR for top: '" << opt_options.top << "using \n\t" - << xls_pipeline.value_or("default pipeline"); + bool is_empty_tops = tops.empty(); + std::string default_top; + SmallVector actual_tops(tops.begin(), tops.end()); - absl::Status status = - ::xls::tools::OptimizeIrForTop(package->get(), opt_options); - if (!status.ok()) { - return module.emitError("failed to optimize IR: ") << status.ToString(); - } + if (is_empty_tops) { + default_top = module.getName().value_or("_package").str(); + actual_tops.push_back(default_top); } - OwningOpRef new_module_op = - XlsToMlirXlsTranslate(**package, module.getContext()); - if (!new_module_op) { - return module.emitError( - "failed to translate optimized XLS IR back to MLIR"); + for (size_t i = 0; i < actual_tops.size(); ++i) { + StringRef top = actual_tops[i]; + bool is_last = (i == actual_tops.size() - 1); + + ::xls::Package* pkg_to_optimize; + std::unique_ptr<::xls::Package> pkg_clone; + + if (is_last) { + // For the last (or only) top, we don't need to clone. We can just mutate + // the original package directly. + pkg_to_optimize = package->get(); + } else { + auto pkg_clone_or = ::xls::ClonePackage((*package).get()); + if (!pkg_clone_or.ok()) { + return module.emitError("failed to clone package: ") + << pkg_clone_or.status().ToString(); + } + pkg_clone = std::move(pkg_clone_or).value(); + pkg_to_optimize = pkg_clone.get(); + } + + opt_options.top = top.str(); + + if (!xls_pipeline.has_value() || !xls_pipeline->empty()) { + LDBG() << "Optimizing IR for top: '" << opt_options.top << "' using \n\t" + << xls_pipeline.value_or("default pipeline"); + + absl::Status status = + ::xls::tools::OptimizeIrForTop(pkg_to_optimize, opt_options); + if (!status.ok()) { + return module.emitError("failed to optimize IR: ") << status.ToString(); + } + } + + OwningOpRef new_module_op = + XlsToMlirXlsTranslate(*pkg_to_optimize, module.getContext()); + if (!new_module_op) { + return module.emitError( + "failed to translate optimized XLS IR back to MLIR"); + } + + if (is_empty_tops) { + // If no explicit tops were given, preserve the original + // behavior of replacing the entire module body. + module.getBodyRegion().takeBody( + cast(*new_module_op).getBodyRegion()); + } else { + // Merge the optimized top function back into the module. + // NOTE: This assumes XLS optimization preserves the function signature. + // If the function type were to change, we would also need to update all + // call sites in the module to avoid producing invalid IR. + ModuleOp new_module = cast(*new_module_op); + auto optimized_func = new_module.lookupSymbol(top); + if (!optimized_func) { + return module.emitError("could not find optimized func ") << top; + } + auto original_func = module.lookupSymbol(top); + if (!original_func) { + return module.emitError("could not find original func ") << top; + } + original_func.setType(optimized_func.getFunctionType()); + original_func.getBody().takeBody(optimized_func.getBody()); + } } - module.getBodyRegion().takeBody( - cast(*new_module_op).getBodyRegion()); + return success(); } diff --git a/xls/contrib/mlir/transforms/passes.h b/xls/contrib/mlir/transforms/passes.h index ecf76b6747..397a7b2952 100644 --- a/xls/contrib/mlir/transforms/passes.h +++ b/xls/contrib/mlir/transforms/passes.h @@ -35,10 +35,14 @@ class DslxPackageCache; #define GEN_PASS_REGISTRATION #include "xls/contrib/mlir/transforms/passes.h.inc" // IWYU pragma: export -// Optimizes the given MLIR module using XLS. +// Optimizes the given MLIR module using XLS. When tops is empty (default), +// the module name is used as the single top and the entire module body is +// replaced. When tops is non-empty, only the specified functions are optimized +// and spliced back into the module. LogicalResult optimizeUsingXls( ModuleOp module, DslxPackageCache& dslx_cache, - std::optional xls_pipeline = std::nullopt); + std::optional xls_pipeline = std::nullopt, + ArrayRef tops = {}); } // namespace mlir::xls