From b08d6a152ddc3c3c63e320fe75d386adbbdce0d2 Mon Sep 17 00:00:00 2001 From: XLS Team Date: Fri, 19 Jun 2026 01:28:20 -0700 Subject: [PATCH] Adding option to pass a vector of `tops` to `optimizeUsingXls`. - Before: `optimizeUsingXls` took as top either the module name (when the module is named) or defaulted to `_package`, which triggers an error if not func inside module is named like that. This design was problematic in the context of broader code optimization when we typically want to extract multiple structures, optimize them with XLS, and inline the resulting optimization back into the original structures. Two work arounds were explored: i) wrapping all extracted funcs into the same module, but didn't work (as optimized functions couldn't be inlined back into their original structure), ii) wrapping each extracted func inside its own module, but possibly costly to run. - current proosal: passing a list of tops to `optimizeUsingXls` to call `mlirXlsToXls` over the entire module, and then loop over tops for `OptimizeIrForTop` and `XlsToMlirXlsTranslate` PiperOrigin-RevId: 934791923 --- xls/contrib/mlir/BUILD | 2 + .../mlir/transforms/optimize_using_xls.cc | 96 +++++++++++++++---- xls/contrib/mlir/transforms/passes.h | 8 +- 3 files changed, 85 insertions(+), 21 deletions(-) diff --git a/xls/contrib/mlir/BUILD b/xls/contrib/mlir/BUILD index a420bc2d0c..989b126142 100644 --- a/xls/contrib/mlir/BUILD +++ b/xls/contrib/mlir/BUILD @@ -367,11 +367,13 @@ cc_library( ":xls_transforms_passes", ":xls_transforms_passes_inc_gen", ":xls_translate_lib", + "//xls/ir:clone_package", "//xls/passes:pass_pipeline_cc_proto", "//xls/tools:opt", "@com_google_absl//absl/status", "@com_google_protobuf//:protobuf", "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", "@llvm-project//mlir:Support", diff --git a/xls/contrib/mlir/transforms/optimize_using_xls.cc b/xls/contrib/mlir/transforms/optimize_using_xls.cc index a97f68fa90..09b37eba72 100644 --- a/xls/contrib/mlir/transforms/optimize_using_xls.cc +++ b/xls/contrib/mlir/transforms/optimize_using_xls.cc @@ -12,12 +12,15 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include #include #include #include +#include #include "absl/status/status.h" #include "llvm/include/llvm/Support/DebugLog.h" +#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/include/mlir/IR/BuiltinOps.h" #include "mlir/include/mlir/IR/Diagnostics.h" #include "mlir/include/mlir/IR/MLIRContext.h" @@ -29,6 +32,7 @@ #include "xls/contrib/mlir/tools/xls_translate/xls_translate_from_mlir.h" #include "xls/contrib/mlir/tools/xls_translate/xls_translate_to_mlir.h" #include "xls/contrib/mlir/transforms/passes.h" // IWYU pragma: keep +#include "xls/ir/clone_package.h" #include "xls/passes/pass_pipeline.pb.h" #include "xls/tools/opt.h" @@ -64,16 +68,15 @@ void OptimizeUsingXlsPass::runOnOperation() { } LogicalResult optimizeUsingXls(ModuleOp module, DslxPackageCache& dslx_cache, - std::optional xls_pipeline) { + std::optional xls_pipeline, + ArrayRef tops) { FailureOr> package = - mlirXlsToXls(module, - /*dslx_search_path=*/"", dslx_cache); + mlirXlsToXls(module, /*dslx_search_path=*/"", dslx_cache); if (failed(package)) { return failure(); } ::xls::tools::OptOptions opt_options; - opt_options.top = module.getName().value_or("_package"); if (xls_pipeline.has_value()) { ::xls::PassPipelineProto pass_pipeline; if (!google::protobuf::TextFormat::ParseFromString(*xls_pipeline, &pass_pipeline)) { @@ -83,25 +86,80 @@ LogicalResult optimizeUsingXls(ModuleOp module, DslxPackageCache& dslx_cache, opt_options.pass_pipeline = pass_pipeline; } - if (!xls_pipeline.has_value() || !xls_pipeline->empty()) { - LDBG() << "Optimizing IR for top: '" << opt_options.top << "using \n\t" - << xls_pipeline.value_or("default pipeline"); + bool is_empty_tops = tops.empty(); + std::string default_top; + SmallVector actual_tops(tops.begin(), tops.end()); - absl::Status status = - ::xls::tools::OptimizeIrForTop(package->get(), opt_options); - if (!status.ok()) { - return module.emitError("failed to optimize IR: ") << status.ToString(); - } + if (is_empty_tops) { + default_top = module.getName().value_or("_package").str(); + actual_tops.push_back(default_top); } - OwningOpRef new_module_op = - XlsToMlirXlsTranslate(**package, module.getContext()); - if (!new_module_op) { - return module.emitError( - "failed to translate optimized XLS IR back to MLIR"); + for (size_t i = 0; i < actual_tops.size(); ++i) { + StringRef top = actual_tops[i]; + bool is_last = (i == actual_tops.size() - 1); + + ::xls::Package* pkg_to_optimize; + std::unique_ptr<::xls::Package> pkg_clone; + + if (is_last) { + // For the last (or only) top, we don't need to clone. We can just mutate + // the original package directly. + pkg_to_optimize = package->get(); + } else { + auto pkg_clone_or = ::xls::ClonePackage((*package).get()); + if (!pkg_clone_or.ok()) { + return module.emitError("failed to clone package: ") + << pkg_clone_or.status().ToString(); + } + pkg_clone = std::move(pkg_clone_or).value(); + pkg_to_optimize = pkg_clone.get(); + } + + opt_options.top = top.str(); + + if (!xls_pipeline.has_value() || !xls_pipeline->empty()) { + LDBG() << "Optimizing IR for top: '" << opt_options.top << "' using \n\t" + << xls_pipeline.value_or("default pipeline"); + + absl::Status status = + ::xls::tools::OptimizeIrForTop(pkg_to_optimize, opt_options); + if (!status.ok()) { + return module.emitError("failed to optimize IR: ") << status.ToString(); + } + } + + OwningOpRef new_module_op = + XlsToMlirXlsTranslate(*pkg_to_optimize, module.getContext()); + if (!new_module_op) { + return module.emitError( + "failed to translate optimized XLS IR back to MLIR"); + } + + if (is_empty_tops) { + // If no explicit tops were given, preserve the original + // behavior of replacing the entire module body. + module.getBodyRegion().takeBody( + cast(*new_module_op).getBodyRegion()); + } else { + // Merge the optimized top function back into the module. + // NOTE: This assumes XLS optimization preserves the function signature. + // If the function type were to change, we would also need to update all + // call sites in the module to avoid producing invalid IR. + ModuleOp new_module = cast(*new_module_op); + auto optimized_func = new_module.lookupSymbol(top); + if (!optimized_func) { + return module.emitError("could not find optimized func ") << top; + } + auto original_func = module.lookupSymbol(top); + if (!original_func) { + return module.emitError("could not find original func ") << top; + } + original_func.setType(optimized_func.getFunctionType()); + original_func.getBody().takeBody(optimized_func.getBody()); + } } - module.getBodyRegion().takeBody( - cast(*new_module_op).getBodyRegion()); + return success(); } diff --git a/xls/contrib/mlir/transforms/passes.h b/xls/contrib/mlir/transforms/passes.h index ecf76b6747..397a7b2952 100644 --- a/xls/contrib/mlir/transforms/passes.h +++ b/xls/contrib/mlir/transforms/passes.h @@ -35,10 +35,14 @@ class DslxPackageCache; #define GEN_PASS_REGISTRATION #include "xls/contrib/mlir/transforms/passes.h.inc" // IWYU pragma: export -// Optimizes the given MLIR module using XLS. +// Optimizes the given MLIR module using XLS. When tops is empty (default), +// the module name is used as the single top and the entire module body is +// replaced. When tops is non-empty, only the specified functions are optimized +// and spliced back into the module. LogicalResult optimizeUsingXls( ModuleOp module, DslxPackageCache& dslx_cache, - std::optional xls_pipeline = std::nullopt); + std::optional xls_pipeline = std::nullopt, + ArrayRef tops = {}); } // namespace mlir::xls