From 335b2034fa8d9f58e66fafb0b9bf33309e9f188b Mon Sep 17 00:00:00 2001 From: Daan de Graaf Date: Tue, 12 May 2026 15:14:00 +0000 Subject: [PATCH 01/15] Get started. --- compiler/include/garel/GARelAttr.td | 2 +- compiler/include/garel/GARelOps.td | 2 +- compiler/include/garel/GARelSQL.h | 11 ++++++++ compiler/src/garel/CMakeLists.txt | 10 +++++++ compiler/src/garel/GARelSQL.cpp | 10 +++++++ .../test/graphalg-to-rel/deferred-reduce.mlir | 2 +- compiler/tools/CMakeLists.txt | 9 +++++++ compiler/tools/garel-translate.cpp | 26 +++++++++++++++++++ compiler/tools/graphalg-translate.cpp | 2 +- 9 files changed, 70 insertions(+), 4 deletions(-) create mode 100644 compiler/include/garel/GARelSQL.h create mode 100644 compiler/src/garel/GARelSQL.cpp create mode 100644 compiler/tools/garel-translate.cpp diff --git a/compiler/include/garel/GARelAttr.td b/compiler/include/garel/GARelAttr.td index 4dac0ca..e34d9b7 100644 --- a/compiler/include/garel/GARelAttr.td +++ b/compiler/include/garel/GARelAttr.td @@ -53,7 +53,7 @@ def Aggregator : GARel_Attr<"Aggregator", "aggregator"> { OptionalArrayRefParameter<"ColumnIdx">:$inputs); let assemblyFormat = [{ - `<` $func $inputs `>` + `<` $func (`>`) : ($inputs^ `>`)? }]; let genVerifyDecl = 1; diff --git a/compiler/include/garel/GARelOps.td b/compiler/include/garel/GARelOps.td index e1e94ec..fbc0e00 100644 --- a/compiler/include/garel/GARelOps.td +++ b/compiler/include/garel/GARelOps.td @@ -113,7 +113,7 @@ def UnionOp : GARel_Op<"union", [SameOperandsAndResultType]> { let results = (outs Relation:$result); let assemblyFormat = [{ - $inputs `:` type($inputs) + $inputs `:` type($result) attr-dict }]; diff --git a/compiler/include/garel/GARelSQL.h b/compiler/include/garel/GARelSQL.h new file mode 100644 index 0000000..5aa2ebc --- /dev/null +++ b/compiler/include/garel/GARelSQL.h @@ -0,0 +1,11 @@ +#pragma once + +#include +#include +#include + +namespace garel { + +mlir::LogicalResult translateToSQL(mlir::Operation *op, llvm::raw_ostream &os); + +} // namespace garel diff --git a/compiler/src/garel/CMakeLists.txt b/compiler/src/garel/CMakeLists.txt index 6c8622e..12e3d5b 100644 --- a/compiler/src/garel/CMakeLists.txt +++ b/compiler/src/garel/CMakeLists.txt @@ -35,3 +35,13 @@ target_link_libraries( GARelIR MLIRPass ) + +add_library(GARelSQL + GARelSQL.cpp +) +target_link_libraries( + GARelSQL + PRIVATE + GraphAlgIR + GARelIR +) diff --git a/compiler/src/garel/GARelSQL.cpp b/compiler/src/garel/GARelSQL.cpp new file mode 100644 index 0000000..b6a0ef9 --- /dev/null +++ b/compiler/src/garel/GARelSQL.cpp @@ -0,0 +1,10 @@ +#include "mlir/IR/Operation.h" +#include "mlir/Support/LLVM.h" + +namespace garel { + +mlir::LogicalResult translateToSQL(mlir::Operation *op, llvm::raw_ostream &os) { + return mlir::success(); +} + +} // namespace garel diff --git a/compiler/test/graphalg-to-rel/deferred-reduce.mlir b/compiler/test/graphalg-to-rel/deferred-reduce.mlir index e6c27c9..2e2edc8 100644 --- a/compiler/test/graphalg-to-rel/deferred-reduce.mlir +++ b/compiler/test/graphalg-to-rel/deferred-reduce.mlir @@ -70,7 +70,7 @@ func.func @ReduceMultiple( %arg0 : !graphalg.mat<1 x 43 x i64>, %arg1 : !graphalg.mat<42 x 1 x i64>) -> !graphalg.mat<1 x 1 x i64> { - // CHECK: %[[#UNION:]] = garel.union %arg0, %arg1 : !garel.rel, !garel.rel + // CHECK: %[[#UNION:]] = garel.union %arg0, %arg1 : // CHECK: %[[#AGG:]] = garel.aggregate %0 : group_by=[] aggregators=[] %0 = graphalg.deferred_reduce %arg0, %arg1 : !graphalg.mat<1 x 43 x i64>, !graphalg.mat<42 x 1 x i64> -> <1 x 1 x i64> return %0 : !graphalg.mat<1 x 1 x i64> diff --git a/compiler/tools/CMakeLists.txt b/compiler/tools/CMakeLists.txt index 6efaafd..a97e550 100644 --- a/compiler/tools/CMakeLists.txt +++ b/compiler/tools/CMakeLists.txt @@ -44,3 +44,12 @@ if (GRAPHALG_ENABLE_FUZZER) target_compile_options(fuzz-parser PRIVATE -fsanitize=fuzzer) target_link_options(fuzz-parser PRIVATE -fsanitize=fuzzer) endif () + +add_executable(garel-translate garel-translate.cpp) +target_link_libraries(garel-translate PRIVATE + ${llvm_libs} + GARelSQL + GraphAlgIR + MLIRTranslateLib + MLIRFuncDialect +) diff --git a/compiler/tools/garel-translate.cpp b/compiler/tools/garel-translate.cpp new file mode 100644 index 0000000..679ff7c --- /dev/null +++ b/compiler/tools/garel-translate.cpp @@ -0,0 +1,26 @@ +#include +#include +#include +#include +#include +#include +#include +#include + +#include "garel/GARelDialect.h" +#include "garel/GARelSQL.h" +#include "graphalg/GraphAlgDialect.h" + +int main(int argc, char *argv[]) { + mlir::TranslateFromMLIRRegistration exportSQL( + "export-sql", "export to SQL", garel::translateToSQL, + [](mlir::DialectRegistry ®istry) { + registry.insert(); + registry.insert(); + registry.insert(); + registry.insert(); + }); + + return failed( + mlir::mlirTranslateMain(argc, argv, "garel translation testing tool")); +} diff --git a/compiler/tools/graphalg-translate.cpp b/compiler/tools/graphalg-translate.cpp index 89a45f8..b740b9c 100644 --- a/compiler/tools/graphalg-translate.cpp +++ b/compiler/tools/graphalg-translate.cpp @@ -31,5 +31,5 @@ int main(int argc, char *argv[]) { }); return failed( - mlir::mlirTranslateMain(argc, argv, "Graphalg Translation Testing Tool")); + mlir::mlirTranslateMain(argc, argv, "Graphalg translation testing tool")); } From 0253a15e5400e6e3a747408860f88a0e06027f64 Mon Sep 17 00:00:00 2001 From: Daan de Graaf Date: Wed, 13 May 2026 11:10:45 +0000 Subject: [PATCH 02/15] Rough translations. --- compiler/src/garel/GARelSQL.cpp | 240 +++++++++++++++++++++++++++++++- 1 file changed, 239 insertions(+), 1 deletion(-) diff --git a/compiler/src/garel/GARelSQL.cpp b/compiler/src/garel/GARelSQL.cpp index b6a0ef9..09c7080 100644 --- a/compiler/src/garel/GARelSQL.cpp +++ b/compiler/src/garel/GARelSQL.cpp @@ -1,10 +1,248 @@ +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Diagnostics.h" #include "mlir/IR/Operation.h" +#include "mlir/Interfaces/DataLayoutInterfaces.h" #include "mlir/Support/LLVM.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/raw_ostream.h" + +#include "garel/GARelAttr.h" +#include "garel/GARelOps.h" namespace garel { -mlir::LogicalResult translateToSQL(mlir::Operation *op, llvm::raw_ostream &os) { +namespace { + +class SQLTranslator { +private: + llvm::raw_ostream &_os; + llvm::DenseMap _valMap; + std::size_t _tempCount; + + std::string newTemp() { + return std::string("temp") + std::to_string(_tempCount++); + } + + mlir::LogicalResult translate(mlir::func::FuncOp op); + mlir::LogicalResult translate(mlir::Value val); + mlir::LogicalResult translate(mlir::Operation *op); + mlir::LogicalResult translate(ForOp op); + mlir::LogicalResult translate(ConstantOp op); + mlir::LogicalResult translate(AggregateOp op); + mlir::LogicalResult translate(ProjectOp op); + + mlir::LogicalResult translate(ExtractOp op); + mlir::LogicalResult translate(mlir::arith::SelectOp op); + mlir::LogicalResult translate(mlir::arith::ConstantOp op); + +public: + SQLTranslator(llvm::raw_ostream &os) : _os(os) {} + + mlir::LogicalResult translate(mlir::ModuleOp op); +}; + +} // namespace + +mlir::LogicalResult SQLTranslator::translate(mlir::ModuleOp op) { + for (auto &op : *op.getBody()) { + auto funcOp = llvm::dyn_cast(op); + if (!funcOp) { + return op.emitOpError("expected function"); + } + + if (mlir::failed(translate(funcOp))) { + return mlir::failure(); + } + } + + return mlir::success(); +} + +mlir::LogicalResult SQLTranslator::translate(mlir::func::FuncOp op) { + auto name = op.getSymName(); + _os << "def " << name << "("; + for (auto [i, arg] : llvm::enumerate(op.getBody().getArguments())) { + auto varName = std::string("farg") + std::to_string(i); + _valMap[arg] = varName; + + if (i != 0) { + _os << ", "; + } + + _os << varName; + } + + _os << "):\n"; + + auto retOp = + llvm::cast(op.getBody().front().getTerminator()); + if (retOp.getNumOperands() != 1) { + return retOp.emitOpError("expected a single return value"); + } + + return translate(retOp.getOperand(0)); +} + +mlir::LogicalResult SQLTranslator::translate(mlir::Value val) { + if (_valMap.contains(val)) { + _os << _valMap[val]; + return mlir::success(); + } + + auto op = val.getDefiningOp(); + if (!op) { + return mlir::emitError(val.getLoc()) + << val << " is not a known variable or an operation result"; + } + + return translate(op); +} + +mlir::LogicalResult SQLTranslator::translate(mlir::Operation *op) { +#define CASE(OP) \ + if (auto o = llvm::dyn_cast(op)) { \ + return translate(o); \ + } + + CASE(ForOp) + CASE(ConstantOp) + CASE(AggregateOp) + CASE(ProjectOp) + CASE(ExtractOp) + CASE(mlir::arith::SelectOp) + CASE(mlir::arith::ConstantOp) +#undef CASE + + return op->emitOpError("no SQL translation defined for this op"); +} + +mlir::LogicalResult SQLTranslator::translate(ForOp op) { + // Initialize temporary tables for loop state + for (auto i : llvm::seq(op.getInit().size())) { + auto temp = newTemp(); + _os << "CREATE TABLE " << temp << " AS "; + if (mlir::failed(translate(op.getInit()[i]))) { + return mlir::failure(); + } + _os << ";\n"; + + _valMap[op.getBody().getArgument(i)] = temp; + } + // In the loop: + // - update the loop variables + // - check break condition + // Signal where to write output + return mlir::success(); +} + +mlir::LogicalResult SQLTranslator::translate(ConstantOp op) { + _os << "(SELECT " << op.getValue() << " AS c0)"; + return mlir::success(); +} + +mlir::LogicalResult SQLTranslator::translate(AggregateOp op) { + _os << "(SELECT "; + std::size_t colOut = 0; + for (auto key : op.getGroupBy()) { + if (colOut > 0) { + _os << ", "; + } + + _os << "c" << key << " AS c" << colOut++; + } + + for (auto agg : op.getAggregators()) { + if (colOut > 0) { + _os << ", "; + } + + _os << stringifyAggregateFunc(agg.getFunc()) << "("; + llvm::interleaveComma(agg.getInputs(), _os, + [&](ColumnIdx idx) { _os << "c" << idx; }); + _os << ") AS c" << colOut++; + } + + _os << " FROM "; + if (mlir::failed(translate(op.getInput()))) { + return mlir::failure(); + } + + if (!op.getGroupBy().empty()) { + _os << " GROUP BY "; + llvm::interleaveComma(op.getGroupBy(), _os, + [&](ColumnIdx idx) { _os << "c" << idx; }); + } + + _os << ")"; + return mlir::success(); +} + +mlir::LogicalResult SQLTranslator::translate(ProjectOp op) { + _os << "(SELECT "; + auto retOp = op.getTerminator(); + for (auto [i, val] : llvm::enumerate(retOp.getProjections())) { + if (i != 0) { + _os << ", "; + } + + if (mlir::failed(translate(val))) { + return mlir::failure(); + } + + _os << " AS c" << i; + } + + _os << " FROM "; + if (mlir::failed(translate(op.getInput()))) { + return mlir::failure(); + } + + _os << ")"; return mlir::success(); } +mlir::LogicalResult SQLTranslator::translate(ExtractOp op) { + _os << "c" << op.getColumn(); + return mlir::success(); +} + +mlir::LogicalResult SQLTranslator::translate(mlir::arith::SelectOp op) { + _os << "(CASE WHEN "; + if (mlir::failed(translate(op.getCondition()))) { + return mlir::failure(); + } + + _os << " THEN "; + if (mlir::failed(translate(op.getTrueValue()))) { + return mlir::failure(); + } + + _os << " ELSE "; + if (mlir::failed(translate(op.getFalseValue()))) { + return mlir::failure(); + } + + _os << ")"; + return mlir::success(); +} + +mlir::LogicalResult SQLTranslator::translate(mlir::arith::ConstantOp op) { + _os << op.getValue(); + return mlir::success(); +} + +mlir::LogicalResult translateToSQL(mlir::Operation *op, llvm::raw_ostream &os) { + SQLTranslator translator(os); + auto moduleOp = llvm::dyn_cast(op); + if (!moduleOp) { + return op->emitOpError("expected a module"); + } + + return translator.translate(moduleOp); +} + } // namespace garel From 08bb72b5ac1d5c6a4ed81d61a9ce5139824a27bc Mon Sep 17 00:00:00 2001 From: Daan de Graaf Date: Wed, 13 May 2026 12:32:10 +0000 Subject: [PATCH 03/15] Generate some SQL that looks sort of reasonable. --- compiler/src/garel/GARelSQL.cpp | 158 ++++++++++++++++++++++++++++---- 1 file changed, 141 insertions(+), 17 deletions(-) diff --git a/compiler/src/garel/GARelSQL.cpp b/compiler/src/garel/GARelSQL.cpp index 09c7080..a06dbb6 100644 --- a/compiler/src/garel/GARelSQL.cpp +++ b/compiler/src/garel/GARelSQL.cpp @@ -1,17 +1,18 @@ -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/Diagnostics.h" -#include "mlir/IR/Operation.h" -#include "mlir/Interfaces/DataLayoutInterfaces.h" -#include "mlir/Support/LLVM.h" -#include "llvm/ADT/DenseMap.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/Support/Casting.h" -#include "llvm/Support/raw_ostream.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include #include "garel/GARelAttr.h" #include "garel/GARelOps.h" +#include "garel/GARelTypes.h" namespace garel { @@ -34,10 +35,16 @@ class SQLTranslator { mlir::LogicalResult translate(ConstantOp op); mlir::LogicalResult translate(AggregateOp op); mlir::LogicalResult translate(ProjectOp op); + mlir::LogicalResult translate(UnionOp op); + mlir::LogicalResult translate(JoinOp op); mlir::LogicalResult translate(ExtractOp op); mlir::LogicalResult translate(mlir::arith::SelectOp op); mlir::LogicalResult translate(mlir::arith::ConstantOp op); + mlir::LogicalResult translate(mlir::arith::AddIOp op); + mlir::LogicalResult translate(mlir::arith::AddFOp op); + + mlir::LogicalResult translateAdd(mlir::Operation *op); public: SQLTranslator(llvm::raw_ostream &os) : _os(os) {} @@ -89,7 +96,7 @@ mlir::LogicalResult SQLTranslator::translate(mlir::func::FuncOp op) { mlir::LogicalResult SQLTranslator::translate(mlir::Value val) { if (_valMap.contains(val)) { - _os << _valMap[val]; + _os << "(SELECT * FROM " << _valMap[val] << ")"; return mlir::success(); } @@ -112,16 +119,22 @@ mlir::LogicalResult SQLTranslator::translate(mlir::Operation *op) { CASE(ConstantOp) CASE(AggregateOp) CASE(ProjectOp) + CASE(UnionOp) + CASE(JoinOp) CASE(ExtractOp) CASE(mlir::arith::SelectOp) CASE(mlir::arith::ConstantOp) + CASE(mlir::arith::AddIOp) + CASE(mlir::arith::AddFOp) #undef CASE return op->emitOpError("no SQL translation defined for this op"); } mlir::LogicalResult SQLTranslator::translate(ForOp op) { + auto &body = op.getBody().front(); // Initialize temporary tables for loop state + llvm::SmallVector stateTables; for (auto i : llvm::seq(op.getInit().size())) { auto temp = newTemp(); _os << "CREATE TABLE " << temp << " AS "; @@ -130,12 +143,41 @@ mlir::LogicalResult SQLTranslator::translate(ForOp op) { } _os << ";\n"; - _valMap[op.getBody().getArgument(i)] = temp; + stateTables.push_back(temp); + _valMap[body.getArgument(i)] = temp; + } + + _os << "iters = "; + if (mlir::failed(translate(op.getIters()))) { + return mlir::failure(); } - // In the loop: - // - update the loop variables - // - check break condition - // Signal where to write output + _os << "\n"; + + _os << "for i in iters:\n"; + auto yieldOp = llvm::cast(op.getBody().front().getTerminator()); + llvm::SmallVector newStateTables; + for (auto i : llvm::seq(stateTables.size())) { + auto temp = newTemp(); + _os << " CREATE TABLE " << temp << " AS ("; + if (mlir::failed(translate(yieldOp.getInputs()[i]))) { + return mlir::failure(); + } + _os << ");\n"; + + newStateTables.push_back(temp); + } + + if (!op.getUntil().empty()) { + return op.emitOpError("'until' not implemented"); + } + + // TODO: convergence check? + // Swap to new tables + for (auto [table, newTable] : llvm::zip_equal(stateTables, newStateTables)) { + _os << " DROP TABLE " << table << ";\n"; + _os << " ALTER TABLE " << newTable << " RENAME TO " << table << ";\n"; + } + return mlir::success(); } @@ -205,6 +247,68 @@ mlir::LogicalResult SQLTranslator::translate(ProjectOp op) { return mlir::success(); } +mlir::LogicalResult SQLTranslator::translate(UnionOp op) { + if (op.getInputs().empty()) { + return op.emitOpError("union with zero inputs"); + } + + _os << "("; + for (auto [i, input] : llvm::enumerate(op.getInputs())) { + if (i != 0) { + _os << " UNION ALL "; + } + + if (mlir::failed(translate(input))) { + return mlir::failure(); + } + } + _os << ")"; + return mlir::success(); +} + +mlir::LogicalResult SQLTranslator::translate(JoinOp op) { + _os << "(SELECT "; + std::size_t outIdx = 0; + for (auto [i, input] : llvm::enumerate(op.getInputs())) { + auto type = llvm::cast(input.getType()); + for (auto c : llvm::seq(type.getColumns().size())) { + if (i != 0 || c != 0) { + _os << ", "; + } + + _os << "i" << i << ".c" << c << " AS c" << outIdx++; + } + } + + _os << " FROM "; + for (auto [i, input] : llvm::enumerate(op.getInputs())) { + if (i != 0) { + _os << ", "; + } + + _os << "("; + if (mlir::failed(translate(input))) { + return mlir::failure(); + } + _os << ") i" << i; + } + + if (!op.getPredicates().empty()) { + _os << " WHERE "; + for (auto [i, pred] : llvm::enumerate(op.getPredicates())) { + if (i != 0) { + _os << " AND "; + } + + _os << "i" << pred.getLhsRelIdx() << ".c" << pred.getLhsColIdx() << " = " + << "i" << pred.getRhsRelIdx() << ".c" << pred.getRhsColIdx(); + } + } + + _os << ")"; + return mlir::success(); +} + mlir::LogicalResult SQLTranslator::translate(ExtractOp op) { _os << "c" << op.getColumn(); return mlir::success(); @@ -235,6 +339,26 @@ mlir::LogicalResult SQLTranslator::translate(mlir::arith::ConstantOp op) { return mlir::success(); } +mlir::LogicalResult SQLTranslator::translateAdd(mlir::Operation *op) { + _os << "("; + if (mlir::failed(translate(op->getOperand(0)))) { + return mlir::failure(); + } + _os << " + "; + if (mlir::failed(translate(op->getOperand(1)))) { + return mlir::failure(); + } + _os << ")"; + return mlir::success(); +} + +mlir::LogicalResult SQLTranslator::translate(mlir::arith::AddIOp op) { + return translateAdd(op); +} +mlir::LogicalResult SQLTranslator::translate(mlir::arith::AddFOp op) { + return translateAdd(op); +} + mlir::LogicalResult translateToSQL(mlir::Operation *op, llvm::raw_ostream &os) { SQLTranslator translator(os); auto moduleOp = llvm::dyn_cast(op); From 38a9c401cffd4010f9533f7e0fc72c6abc3ece89 Mon Sep 17 00:00:00 2001 From: Daan de Graaf Date: Wed, 13 May 2026 12:47:07 +0000 Subject: [PATCH 04/15] Script for testing. --- compiler/sql.sh | 13 +++++++++++++ compiler/src/garel/GARelSQL.cpp | 25 +++++++++++++++++++++++++ 2 files changed, 38 insertions(+) create mode 100755 compiler/sql.sh diff --git a/compiler/sql.sh b/compiler/sql.sh new file mode 100755 index 0000000..ab31e54 --- /dev/null +++ b/compiler/sql.sh @@ -0,0 +1,13 @@ +#!/bin/bash +set -e + +cmake --build compiler/build --target graphalg-translate graphalg-opt garel-translate +compiler/build/tools/graphalg-translate --import-graphalg $1 | \ +compiler/build/tools/graphalg-opt \ + --graphalg-to-core-pipeline \ + --graphalg-verify-loop-bounds \ + --graphalg-explicate-sparsity \ + --graphalg-split-aggregate \ + --graphalg-loop-aggregate \ + --graphalg-to-rel | \ +compiler/build/tools/garel-translate --export-sql diff --git a/compiler/src/garel/GARelSQL.cpp b/compiler/src/garel/GARelSQL.cpp index a06dbb6..82fe46c 100644 --- a/compiler/src/garel/GARelSQL.cpp +++ b/compiler/src/garel/GARelSQL.cpp @@ -43,8 +43,11 @@ class SQLTranslator { mlir::LogicalResult translate(mlir::arith::ConstantOp op); mlir::LogicalResult translate(mlir::arith::AddIOp op); mlir::LogicalResult translate(mlir::arith::AddFOp op); + mlir::LogicalResult translate(mlir::arith::MulIOp op); + mlir::LogicalResult translate(mlir::arith::MulFOp op); mlir::LogicalResult translateAdd(mlir::Operation *op); + mlir::LogicalResult translateMul(mlir::Operation *op); public: SQLTranslator(llvm::raw_ostream &os) : _os(os) {} @@ -126,6 +129,8 @@ mlir::LogicalResult SQLTranslator::translate(mlir::Operation *op) { CASE(mlir::arith::ConstantOp) CASE(mlir::arith::AddIOp) CASE(mlir::arith::AddFOp) + CASE(mlir::arith::MulIOp) + CASE(mlir::arith::MulFOp) #undef CASE return op->emitOpError("no SQL translation defined for this op"); @@ -359,6 +364,26 @@ mlir::LogicalResult SQLTranslator::translate(mlir::arith::AddFOp op) { return translateAdd(op); } +mlir::LogicalResult SQLTranslator::translateMul(mlir::Operation *op) { + _os << "("; + if (mlir::failed(translate(op->getOperand(0)))) { + return mlir::failure(); + } + _os << " * "; + if (mlir::failed(translate(op->getOperand(1)))) { + return mlir::failure(); + } + _os << ")"; + return mlir::success(); +} + +mlir::LogicalResult SQLTranslator::translate(mlir::arith::MulIOp op) { + return translateMul(op); +} +mlir::LogicalResult SQLTranslator::translate(mlir::arith::MulFOp op) { + return translateMul(op); +} + mlir::LogicalResult translateToSQL(mlir::Operation *op, llvm::raw_ostream &os) { SQLTranslator translator(os); auto moduleOp = llvm::dyn_cast(op); From 154aa265a3f724957a279e8745269a16dfdd64e3 Mon Sep 17 00:00:00 2001 From: Daan de Graaf Date: Wed, 13 May 2026 12:57:29 +0000 Subject: [PATCH 05/15] Handle infinity values. --- compiler/src/garel/GARelSQL.cpp | 36 ++++++++++++++++++++++++++++++--- 1 file changed, 33 insertions(+), 3 deletions(-) diff --git a/compiler/src/garel/GARelSQL.cpp b/compiler/src/garel/GARelSQL.cpp index 82fe46c..7d3f497 100644 --- a/compiler/src/garel/GARelSQL.cpp +++ b/compiler/src/garel/GARelSQL.cpp @@ -13,6 +13,8 @@ #include "garel/GARelAttr.h" #include "garel/GARelOps.h" #include "garel/GARelTypes.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/Location.h" namespace garel { @@ -46,6 +48,8 @@ class SQLTranslator { mlir::LogicalResult translate(mlir::arith::MulIOp op); mlir::LogicalResult translate(mlir::arith::MulFOp op); + mlir::LogicalResult translateConstant(mlir::Location loc, + mlir::Attribute attr); mlir::LogicalResult translateAdd(mlir::Operation *op); mlir::LogicalResult translateMul(mlir::Operation *op); @@ -186,8 +190,35 @@ mlir::LogicalResult SQLTranslator::translate(ForOp op) { return mlir::success(); } +mlir::LogicalResult SQLTranslator::translateConstant(mlir::Location loc, + mlir::Attribute attr) { + if (auto boolAttr = llvm::dyn_cast(attr)) { + _os << (boolAttr.getValue() ? "true" : "false"); + } else if (auto intAttr = llvm::dyn_cast(attr)) { + _os << intAttr.getValue(); + } else if (auto floatAttr = llvm::dyn_cast(attr)) { + auto value = floatAttr.getValue(); + if (value.isNegInfinity()) { + _os << "'-Infinity'"; + } else if (value.isPosInfinity()) { + _os << "'Infinity'"; + } else { + _os << value; + } + } else { + return mlir::emitError(loc) << "cannot convert constant " << attr; + } + + return mlir::success(); +} + mlir::LogicalResult SQLTranslator::translate(ConstantOp op) { - _os << "(SELECT " << op.getValue() << " AS c0)"; + _os << "(SELECT "; + if (mlir::failed(translateConstant(op.getLoc(), op.getValue()))) { + return mlir::failure(); + } + + _os << " AS c0)"; return mlir::success(); } @@ -340,8 +371,7 @@ mlir::LogicalResult SQLTranslator::translate(mlir::arith::SelectOp op) { } mlir::LogicalResult SQLTranslator::translate(mlir::arith::ConstantOp op) { - _os << op.getValue(); - return mlir::success(); + return translateConstant(op.getLoc(), op.getValue()); } mlir::LogicalResult SQLTranslator::translateAdd(mlir::Operation *op) { From b009aebf1179cca738f27ad722ae8e2df40b5b77 Mon Sep 17 00:00:00 2001 From: Daan de Graaf Date: Wed, 13 May 2026 14:40:08 +0000 Subject: [PATCH 06/15] A first integration test for SQL conversions. --- compiler/sql.sh | 13 ------- compiler/src/garel/GARelSQL.cpp | 69 ++++++++++++++++++++++++--------- compiler/test/CMakeLists.txt | 1 + compiler/test/lit.cfg.py | 1 + compiler/test/sql/sssp.gr | 60 ++++++++++++++++++++++++++++ 5 files changed, 113 insertions(+), 31 deletions(-) delete mode 100755 compiler/sql.sh create mode 100644 compiler/test/sql/sssp.gr diff --git a/compiler/sql.sh b/compiler/sql.sh deleted file mode 100755 index ab31e54..0000000 --- a/compiler/sql.sh +++ /dev/null @@ -1,13 +0,0 @@ -#!/bin/bash -set -e - -cmake --build compiler/build --target graphalg-translate graphalg-opt garel-translate -compiler/build/tools/graphalg-translate --import-graphalg $1 | \ -compiler/build/tools/graphalg-opt \ - --graphalg-to-core-pipeline \ - --graphalg-verify-loop-bounds \ - --graphalg-explicate-sparsity \ - --graphalg-split-aggregate \ - --graphalg-loop-aggregate \ - --graphalg-to-rel | \ -compiler/build/tools/garel-translate --export-sql diff --git a/compiler/src/garel/GARelSQL.cpp b/compiler/src/garel/GARelSQL.cpp index 7d3f497..6a6bbaa 100644 --- a/compiler/src/garel/GARelSQL.cpp +++ b/compiler/src/garel/GARelSQL.cpp @@ -23,9 +23,17 @@ namespace { class SQLTranslator { private: llvm::raw_ostream &_os; + std::size_t _indentLevel = 0; + llvm::DenseMap _valMap; std::size_t _tempCount; + void indent() { + for (auto i : llvm::seq(_indentLevel)) { + _os << " "; + } + } + std::string newTemp() { return std::string("temp") + std::to_string(_tempCount++); } @@ -78,19 +86,23 @@ mlir::LogicalResult SQLTranslator::translate(mlir::ModuleOp op) { mlir::LogicalResult SQLTranslator::translate(mlir::func::FuncOp op) { auto name = op.getSymName(); - _os << "def " << name << "("; + _os << "def " << name << "(conn"; for (auto [i, arg] : llvm::enumerate(op.getBody().getArguments())) { auto varName = std::string("farg") + std::to_string(i); _valMap[arg] = varName; - - if (i != 0) { - _os << ", "; - } - + _os << ", "; _os << varName; } _os << "):\n"; + _indentLevel++; + + // Visit loops first, as they cannot be done with pure SQL + for (auto op : op.getOps()) { + if (mlir::failed(translate(op))) { + return mlir::failure(); + } + } auto retOp = llvm::cast(op.getBody().front().getTerminator()); @@ -98,7 +110,15 @@ mlir::LogicalResult SQLTranslator::translate(mlir::func::FuncOp op) { return retOp.emitOpError("expected a single return value"); } - return translate(retOp.getOperand(0)); + indent(); + _os << "return conn.sql(\"\"\""; + if (mlir::failed(translate(retOp.getOperand(0)))) { + return mlir::failure(); + } + + _os << "\"\"\")\n"; + _indentLevel--; + return mlir::success(); } mlir::LogicalResult SQLTranslator::translate(mlir::Value val) { @@ -146,32 +166,37 @@ mlir::LogicalResult SQLTranslator::translate(ForOp op) { llvm::SmallVector stateTables; for (auto i : llvm::seq(op.getInit().size())) { auto temp = newTemp(); - _os << "CREATE TABLE " << temp << " AS "; + indent(); + _os << "conn.execute(\"\"\"CREATE TABLE " << temp << " AS "; if (mlir::failed(translate(op.getInit()[i]))) { return mlir::failure(); } - _os << ";\n"; + _os << "\"\"\")\n"; stateTables.push_back(temp); _valMap[body.getArgument(i)] = temp; } - _os << "iters = "; + indent(); + _os << "iters, = conn.sql(\"\"\""; if (mlir::failed(translate(op.getIters()))) { return mlir::failure(); } - _os << "\n"; + _os << "\"\"\").fetchone()\n"; - _os << "for i in iters:\n"; + indent(); + _os << "for i in range(iters):\n"; + _indentLevel++; auto yieldOp = llvm::cast(op.getBody().front().getTerminator()); llvm::SmallVector newStateTables; for (auto i : llvm::seq(stateTables.size())) { auto temp = newTemp(); - _os << " CREATE TABLE " << temp << " AS ("; + indent(); + _os << "conn.execute(\"\"\"CREATE TABLE " << temp << " AS "; if (mlir::failed(translate(yieldOp.getInputs()[i]))) { return mlir::failure(); } - _os << ");\n"; + _os << "\"\"\")\n"; newStateTables.push_back(temp); } @@ -183,10 +208,18 @@ mlir::LogicalResult SQLTranslator::translate(ForOp op) { // TODO: convergence check? // Swap to new tables for (auto [table, newTable] : llvm::zip_equal(stateTables, newStateTables)) { - _os << " DROP TABLE " << table << ";\n"; - _os << " ALTER TABLE " << newTable << " RENAME TO " << table << ";\n"; + indent(); + _os << "conn.execute(\"DROP TABLE " << table << "\")\n"; + indent(); + _os << "conn.execute(\"ALTER TABLE " << newTable << " RENAME TO " << table + << "\")\n"; } + _indentLevel--; + + // Bind result of the loop to state table. + _valMap[op] = stateTables[op.getResultIdx()]; + return mlir::success(); } @@ -203,7 +236,7 @@ mlir::LogicalResult SQLTranslator::translateConstant(mlir::Location loc, } else if (value.isPosInfinity()) { _os << "'Infinity'"; } else { - _os << value; + _os << "CAST(" << value << " AS DOUBLE PRECISION)"; } } else { return mlir::emitError(loc) << "cannot convert constant " << attr; @@ -366,7 +399,7 @@ mlir::LogicalResult SQLTranslator::translate(mlir::arith::SelectOp op) { return mlir::failure(); } - _os << ")"; + _os << " END)"; return mlir::success(); } diff --git a/compiler/test/CMakeLists.txt b/compiler/test/CMakeLists.txt index 7d6fe86..4863dd3 100644 --- a/compiler/test/CMakeLists.txt +++ b/compiler/test/CMakeLists.txt @@ -20,6 +20,7 @@ configure_lit_site_cfg( add_lit_testsuite(check "Run integration tests" ${CMAKE_CURRENT_BINARY_DIR} DEPENDS + garel-translate graphalg-exec graphalg-opt graphalg-translate diff --git a/compiler/test/lit.cfg.py b/compiler/test/lit.cfg.py index 0ad7b62..f739d7a 100644 --- a/compiler/test/lit.cfg.py +++ b/compiler/test/lit.cfg.py @@ -14,6 +14,7 @@ # The tools we want to use in lit test (inside RUN) tools = [ + "garel-translate", "graphalg-exec", "graphalg-opt", "graphalg-translate", diff --git a/compiler/test/sql/sssp.gr b/compiler/test/sql/sssp.gr new file mode 100644 index 0000000..7b97206 --- /dev/null +++ b/compiler/test/sql/sssp.gr @@ -0,0 +1,60 @@ +// RUN: split-file %s %t +// RUN: graphalg-translate --import-graphalg %t/input.gr > %t/parsed.mlir +// RUN: graphalg-opt %t/parsed.mlir --graphalg-to-core-pipeline --graphalg-verify-loop-bounds --graphalg-explicate-sparsity --graphalg-split-aggregate --graphalg-loop-aggregate --graphalg-to-rel > %t/rel.mlir +// RUN: garel-translate --export-sql %t/rel.mlir > %t/sssp.py +// RUN: touch %t/__init__.py +// RUN: python3 %t/driver.py +// RUN: diff %t/reference.csv %t/output.csv + +//--- edges.csv +c0, c1, c2 +0, 1, 0.5 +0, 2, 5.0 +0, 3, 5.0 +1, 4, 0.5 +2, 3, 2.0 +4, 5, 0.5 +5, 2, 0.5 +5, 9, 23.0 +6, 0, 1.0 +6, 7, 3.2 +7, 9, 0.2 +8, 9, 0.1 +9, 6, 8.0 + +//--- input.gr +func SSSP( + graph: Matrix, + source: Vector) -> Vector { + v = cast(source); + for i in graph.nrows { + v += v * graph; + } + return v; +} + +//--- driver.py +from sssp import SSSP +import duckdb +from pathlib import Path + +dir = Path(__file__).parent +conn = duckdb.connect() +vertices = conn.sql("SELECT generate_series AS c0 FROM generate_series(0, 10)") +edges = conn.read_csv(dir / 'edges.csv') +source = conn.sql("SELECT 0 AS c0, true AS c1") + +res = SSSP(conn, edges, source, vertices) +res.sort('c0').to_csv(str(dir / 'output.csv')) + +//--- reference.csv +c0,c1 +0,0.0 +1,0.5 +2,2.0 +3,4.0 +4,1.0 +5,1.5 +6,32.5 +7,35.7 +9,24.5 From 958438fcd748c441abdff34353abb150a8642732 Mon Sep 17 00:00:00 2001 From: Daan de Graaf Date: Wed, 13 May 2026 19:33:06 +0000 Subject: [PATCH 07/15] More stuff. --- compiler/src/garel/GARelSQL.cpp | 92 ++++++++++++++++++- compiler/src/garel/GraphAlgToRel.cpp | 48 ++++++++-- .../test/graphalg-to-rel/deferred-reduce.mlir | 20 +++- compiler/test/sql/bfs.gr | 76 +++++++++++++++ 4 files changed, 219 insertions(+), 17 deletions(-) create mode 100644 compiler/test/sql/bfs.gr diff --git a/compiler/src/garel/GARelSQL.cpp b/compiler/src/garel/GARelSQL.cpp index 6a6bbaa..72331e2 100644 --- a/compiler/src/garel/GARelSQL.cpp +++ b/compiler/src/garel/GARelSQL.cpp @@ -15,6 +15,7 @@ #include "garel/GARelTypes.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/Location.h" +#include "llvm/ADT/StringRef.h" namespace garel { @@ -55,6 +56,8 @@ class SQLTranslator { mlir::LogicalResult translate(mlir::arith::AddFOp op); mlir::LogicalResult translate(mlir::arith::MulIOp op); mlir::LogicalResult translate(mlir::arith::MulFOp op); + mlir::LogicalResult translate(mlir::arith::AndIOp op); + mlir::LogicalResult translate(mlir::arith::CmpIOp op); mlir::LogicalResult translateConstant(mlir::Location loc, mlir::Attribute attr); @@ -155,6 +158,8 @@ mlir::LogicalResult SQLTranslator::translate(mlir::Operation *op) { CASE(mlir::arith::AddFOp) CASE(mlir::arith::MulIOp) CASE(mlir::arith::MulFOp) + CASE(mlir::arith::AndIOp) + CASE(mlir::arith::CmpIOp) #undef CASE return op->emitOpError("no SQL translation defined for this op"); @@ -201,10 +206,6 @@ mlir::LogicalResult SQLTranslator::translate(ForOp op) { newStateTables.push_back(temp); } - if (!op.getUntil().empty()) { - return op.emitOpError("'until' not implemented"); - } - // TODO: convergence check? // Swap to new tables for (auto [table, newTable] : llvm::zip_equal(stateTables, newStateTables)) { @@ -215,6 +216,22 @@ mlir::LogicalResult SQLTranslator::translate(ForOp op) { << "\")\n"; } + if (!op.getUntil().empty()) { + indent(); + _os << "until, = conn.sql(\"\"\""; + if (mlir::failed(translate(op.getIters()))) { + return mlir::failure(); + } + _os << "\"\"\").fetchone()\n"; + + indent(); + _os << "if until:\n"; + _indentLevel++; + indent(); + _os << "break\n"; + _indentLevel--; + } + _indentLevel--; // Bind result of the loop to state table. @@ -255,6 +272,23 @@ mlir::LogicalResult SQLTranslator::translate(ConstantOp op) { return mlir::success(); } +static llvm::StringLiteral translateAggregateFunc(AggregateFunc f) { + switch (f) { + case AggregateFunc::SUM: + return "SUM"; + case AggregateFunc::MIN: + return "MIN"; + case AggregateFunc::MAX: + return "MAX"; + case AggregateFunc::LOR: + return "BOOL_OR"; + case AggregateFunc::ARGMIN: + return "ARG_MIN"; + case AggregateFunc::COUNT: + return "COUNT"; + } +} + mlir::LogicalResult SQLTranslator::translate(AggregateOp op) { _os << "(SELECT "; std::size_t colOut = 0; @@ -271,7 +305,7 @@ mlir::LogicalResult SQLTranslator::translate(AggregateOp op) { _os << ", "; } - _os << stringifyAggregateFunc(agg.getFunc()) << "("; + _os << translateAggregateFunc(agg.getFunc()) << "("; llvm::interleaveComma(agg.getInputs(), _os, [&](ColumnIdx idx) { _os << "c" << idx; }); _os << ") AS c" << colOut++; @@ -447,6 +481,54 @@ mlir::LogicalResult SQLTranslator::translate(mlir::arith::MulFOp op) { return translateMul(op); } +mlir::LogicalResult SQLTranslator::translate(mlir::arith::AndIOp op) { + _os << "("; + if (mlir::failed(translate(op->getOperand(0)))) { + return mlir::failure(); + } + _os << " AND "; + if (mlir::failed(translate(op->getOperand(1)))) { + return mlir::failure(); + } + _os << ")"; + return mlir::success(); +} + +static llvm::StringLiteral translatePredicate(mlir::arith::CmpIPredicate pred) { + switch (pred) { + case mlir::arith::CmpIPredicate::eq: + return "="; + case mlir::arith::CmpIPredicate::ne: + return "<>"; + case mlir::arith::CmpIPredicate::slt: + case mlir::arith::CmpIPredicate::ult: + return "<"; + case mlir::arith::CmpIPredicate::sle: + case mlir::arith::CmpIPredicate::ule: + return "<="; + case mlir::arith::CmpIPredicate::sgt: + case mlir::arith::CmpIPredicate::ugt: + return ">"; + case mlir::arith::CmpIPredicate::sge: + case mlir::arith::CmpIPredicate::uge: + return ">="; + } +} + +mlir::LogicalResult SQLTranslator::translate(mlir::arith::CmpIOp op) { + _os << "("; + if (mlir::failed(translate(op->getOperand(0)))) { + return mlir::failure(); + } + _os << " " << translatePredicate(op.getPredicate()) << " "; + + if (mlir::failed(translate(op->getOperand(1)))) { + return mlir::failure(); + } + _os << ")"; + return mlir::success(); +} + mlir::LogicalResult translateToSQL(mlir::Operation *op, llvm::raw_ostream &os) { SQLTranslator translator(os); auto moduleOp = llvm::dyn_cast(op); diff --git a/compiler/src/garel/GraphAlgToRel.cpp b/compiler/src/garel/GraphAlgToRel.cpp index cc94417..f2cc533 100644 --- a/compiler/src/garel/GraphAlgToRel.cpp +++ b/compiler/src/garel/GraphAlgToRel.cpp @@ -852,31 +852,61 @@ template <> mlir::LogicalResult OpConversion::matchAndRewrite( graphalg::DeferredReduceOp op, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const { - MatrixAdaptor input(op.getInputs()[0], adaptor.getInputs()[0]); MatrixAdaptor output(op, typeConverter->convertType(op.getType())); + // Do we have row and column universally for all inputs? + llvm::SmallVector inputAdaptors; + bool haveRowColumn = true; + bool haveColColumn = true; + for (auto [matrix, rel] : + llvm::zip_equal(op.getInputs(), adaptor.getInputs())) { + auto &adaptor = inputAdaptors.emplace_back(matrix, rel); + haveRowColumn &= adaptor.hasRowColumn(); + haveColColumn &= adaptor.hasColColumn(); + } + + // Drop row/col columns if they are not universal in order to unify the types + // of all inputs, as required to build UnionOp. + llvm::SmallVector inputs; + for (auto &input : inputAdaptors) { + llvm::SmallVector remap; + if (haveRowColumn) { + remap.push_back(input.rowColumn()); + } + + if (haveColColumn) { + remap.push_back(input.colColumn()); + } + + remap.push_back(input.valColumn()); + inputs.push_back( + rewriter.createOrFold(op.getLoc(), input.relation(), remap)); + } + + auto unionOp = rewriter.createOrFold(op.getLoc(), inputs); + // Group by keys llvm::SmallVector groupBy; if (output.hasRowColumn()) { - groupBy.push_back(input.rowColumn()); + // Always column 0 + groupBy.push_back(0); } if (output.hasColColumn()) { - groupBy.push_back(input.colColumn()); + // Follows row column + std::size_t colIdx = haveRowColumn ? 1 : 0; + groupBy.push_back(colIdx); } // Aggregators - auto aggregator = - createAggregator(op, input.semiring(), input.valColumn(), rewriter); + // Follows row and col columns + auto valIdx = haveRowColumn + haveColColumn; + auto aggregator = createAggregator(op, output.semiring(), valIdx, rewriter); if (mlir::failed(aggregator)) { return mlir::failure(); } std::array aggregators{*aggregator}; - - // union the inputs and then aggregate. - auto unionOp = - rewriter.createOrFold(op.getLoc(), adaptor.getInputs()); rewriter.replaceOpWithNewOp(op, unionOp, groupBy, aggregators); return mlir::success(); } diff --git a/compiler/test/graphalg-to-rel/deferred-reduce.mlir b/compiler/test/graphalg-to-rel/deferred-reduce.mlir index 2e2edc8..3102a88 100644 --- a/compiler/test/graphalg-to-rel/deferred-reduce.mlir +++ b/compiler/test/graphalg-to-rel/deferred-reduce.mlir @@ -66,16 +66,30 @@ func.func @ReduceScalarScalar(%arg0: !graphalg.mat<1 x 1 x i64>) -> !graphalg.ma } // CHECK-LABEL: @ReduceMultiple -func.func @ReduceMultiple( +func.func @ReduceMultipleScalar( %arg0 : !graphalg.mat<1 x 43 x i64>, %arg1 : !graphalg.mat<42 x 1 x i64>) -> !graphalg.mat<1 x 1 x i64> { - // CHECK: %[[#UNION:]] = garel.union %arg0, %arg1 : - // CHECK: %[[#AGG:]] = garel.aggregate %0 : group_by=[] aggregators=[] + // CHECK: %[[#REMAP0:]] = garel.remap %arg0 : [1] + // CHECK: %[[#REMAP1:]] = garel.remap %arg1 : [1] + // CHECK: %[[#UNION:]] = garel.union %[[#REMAP0]], %[[#REMAP1]] : + // CHECK: %[[#AGG:]] = garel.aggregate %[[#UNION]] : group_by=[] aggregators=[] %0 = graphalg.deferred_reduce %arg0, %arg1 : !graphalg.mat<1 x 43 x i64>, !graphalg.mat<42 x 1 x i64> -> <1 x 1 x i64> return %0 : !graphalg.mat<1 x 1 x i64> } +// CHECK-LABEL: @ReduceMultipleVec +func.func @ReduceMultipleVec( + %arg0 : !graphalg.mat<1 x 42 x i64>, + %arg1 : !graphalg.mat<43 x 42 x i64>) + -> !graphalg.mat<1 x 42 x i64> { + // CHECK: %[[#REMAP:]] = garel.remap %arg1 : [1, 2] + // CHECK: %[[#UNION:]] = garel.union %arg0, %[[#REMAP]] : + // CHECK: %[[#AGG:]] = garel.aggregate %[[#UNION]] : group_by=[0] aggregators=[] + %0 = graphalg.deferred_reduce %arg0, %arg1 : !graphalg.mat<1 x 42 x i64>, !graphalg.mat<43 x 42 x i64> -> <1 x 42 x i64> + return %0 : !graphalg.mat<1 x 42 x i64> +} + // === Semirings // CHECK-LABEL: @ReduceBool diff --git a/compiler/test/sql/bfs.gr b/compiler/test/sql/bfs.gr new file mode 100644 index 0000000..0a1616a --- /dev/null +++ b/compiler/test/sql/bfs.gr @@ -0,0 +1,76 @@ +// RUN: split-file %s %t +// RUN: graphalg-translate --import-graphalg %t/input.gr > %t/parsed.mlir +// RUN: graphalg-opt %t/parsed.mlir --graphalg-to-core-pipeline --graphalg-verify-loop-bounds --graphalg-explicate-sparsity --graphalg-split-aggregate --graphalg-loop-aggregate --graphalg-to-rel > %t/rel.mlir +// RUN: garel-translate --export-sql %t/rel.mlir > %t/bfs.py +// RUN: touch %t/__init__.py +// RUN: python3 %t/driver.py +// RUN: diff %t/reference.csv %t/output.csv + +//--- edges.csv +c0, c1, c2 +0,1,TRUE +0,2,TRUE +1,2,TRUE +1,3,TRUE +1,4,TRUE +2,0,TRUE +3,5,TRUE +3,6,TRUE +3,7,TRUE +4,0,TRUE +4,1,TRUE +5,3,TRUE +5,7,TRUE +7,0,TRUE +7,1,TRUE +7,2,TRUE +8,9,TRUE + +//--- input.gr +func setDepth(b:bool, iter:int) -> int { + return cast(b) * (iter + int(2)); +} + +func BFS(graph: Matrix, source: Vector) -> Vector { + v = Vector(graph.nrows); + v[:] = int(1); + + frontier = source; + reach = source; + + for i in graph.nrows { + step = Vector(graph.nrows); + step = frontier * graph; + + v += apply(setDepth, step, i); + + frontier = step; + reach += step; + } until frontier.nvals == int(0); + + return v; +} + +//--- driver.py +from bfs import BFS +import duckdb +from pathlib import Path + +dir = Path(__file__).parent +conn = duckdb.connect() +vertices = conn.sql("SELECT generate_series AS c0 FROM generate_series(0, 10)") +edges = conn.read_csv(dir / 'edges.csv', dtype=['int64', 'int64', 'bool']) +source = conn.sql("SELECT 0 AS c0, true::BOOL AS c1") + +res = BFS(conn, edges, source, vertices) +res.sort('c0').to_csv(str(dir / 'output.csv')) + +//--- reference.m +0 0 1 : i64 +1 0 2 : i64 +2 0 2 : i64 +3 0 3 : i64 +4 0 3 : i64 +5 0 4 : i64 +6 0 4 : i64 +7 0 4 : i64 From 54e18ceabc3afe5b192f2e4f19e5d47b3c2e9281 Mon Sep 17 00:00:00 2001 From: Daan de Graaf Date: Thu, 14 May 2026 07:32:40 +0000 Subject: [PATCH 08/15] BFS works. --- compiler/src/garel/GARelSQL.cpp | 32 +++++++++++++++++++++++++++++++- compiler/test/sql/bfs.gr | 19 ++++++++++--------- 2 files changed, 41 insertions(+), 10 deletions(-) diff --git a/compiler/src/garel/GARelSQL.cpp b/compiler/src/garel/GARelSQL.cpp index 72331e2..8c19dc3 100644 --- a/compiler/src/garel/GARelSQL.cpp +++ b/compiler/src/garel/GARelSQL.cpp @@ -48,6 +48,7 @@ class SQLTranslator { mlir::LogicalResult translate(ProjectOp op); mlir::LogicalResult translate(UnionOp op); mlir::LogicalResult translate(JoinOp op); + mlir::LogicalResult translate(RemapOp op); mlir::LogicalResult translate(ExtractOp op); mlir::LogicalResult translate(mlir::arith::SelectOp op); @@ -151,6 +152,7 @@ mlir::LogicalResult SQLTranslator::translate(mlir::Operation *op) { CASE(ProjectOp) CASE(UnionOp) CASE(JoinOp) + CASE(RemapOp) CASE(ExtractOp) CASE(mlir::arith::SelectOp) CASE(mlir::arith::ConstantOp) @@ -217,9 +219,16 @@ mlir::LogicalResult SQLTranslator::translate(ForOp op) { } if (!op.getUntil().empty()) { + auto &body = op.getUntil().front(); + auto yieldOp = llvm::cast(body.getTerminator()); + // Map block arguments + for (auto i : llvm::seq(stateTables.size())) { + _valMap[body.getArgument(i)] = stateTables[i]; + } + indent(); _os << "until, = conn.sql(\"\"\""; - if (mlir::failed(translate(op.getIters()))) { + if (mlir::failed(translate(yieldOp.getInputs()[0]))) { return mlir::failure(); } _os << "\"\"\").fetchone()\n"; @@ -412,6 +421,27 @@ mlir::LogicalResult SQLTranslator::translate(JoinOp op) { return mlir::success(); } +mlir::LogicalResult SQLTranslator::translate(RemapOp op) { + _os << "(SELECT "; + + ColumnIdx outIdx = 0; + for (ColumnIdx inIdx : op.getRemap()) { + if (outIdx != 0) { + _os << ", "; + } + + _os << "c" << inIdx << " AS c" << outIdx; + } + + _os << " FROM "; + if (mlir::failed(translate(op.getInput()))) { + return mlir::failure(); + } + + _os << ")"; + return mlir::success(); +} + mlir::LogicalResult SQLTranslator::translate(ExtractOp op) { _os << "c" << op.getColumn(); return mlir::success(); diff --git a/compiler/test/sql/bfs.gr b/compiler/test/sql/bfs.gr index 0a1616a..3f9342f 100644 --- a/compiler/test/sql/bfs.gr +++ b/compiler/test/sql/bfs.gr @@ -65,12 +65,13 @@ source = conn.sql("SELECT 0 AS c0, true::BOOL AS c1") res = BFS(conn, edges, source, vertices) res.sort('c0').to_csv(str(dir / 'output.csv')) -//--- reference.m -0 0 1 : i64 -1 0 2 : i64 -2 0 2 : i64 -3 0 3 : i64 -4 0 3 : i64 -5 0 4 : i64 -6 0 4 : i64 -7 0 4 : i64 +//--- reference.csv +c0,c1 +0,1 +1,2 +2,2 +3,3 +4,3 +5,4 +6,4 +7,4 From f3b33b2fd0c6288377a05c2042f63f628a4dd35a Mon Sep 17 00:00:00 2001 From: Daan de Graaf Date: Thu, 14 May 2026 07:57:57 +0000 Subject: [PATCH 09/15] CDLP. --- compiler/src/garel/GARelSQL.cpp | 43 ++++++++++++++- compiler/test/sql/bfs.gr | 2 +- compiler/test/sql/cdlp.gr | 93 +++++++++++++++++++++++++++++++++ compiler/test/sql/sssp.gr | 2 +- 4 files changed, 137 insertions(+), 3 deletions(-) create mode 100644 compiler/test/sql/cdlp.gr diff --git a/compiler/src/garel/GARelSQL.cpp b/compiler/src/garel/GARelSQL.cpp index 8c19dc3..5aa66ef 100644 --- a/compiler/src/garel/GARelSQL.cpp +++ b/compiler/src/garel/GARelSQL.cpp @@ -49,6 +49,7 @@ class SQLTranslator { mlir::LogicalResult translate(UnionOp op); mlir::LogicalResult translate(JoinOp op); mlir::LogicalResult translate(RemapOp op); + mlir::LogicalResult translate(SelectOp op); mlir::LogicalResult translate(ExtractOp op); mlir::LogicalResult translate(mlir::arith::SelectOp op); @@ -58,6 +59,7 @@ class SQLTranslator { mlir::LogicalResult translate(mlir::arith::MulIOp op); mlir::LogicalResult translate(mlir::arith::MulFOp op); mlir::LogicalResult translate(mlir::arith::AndIOp op); + mlir::LogicalResult translate(mlir::arith::OrIOp op); mlir::LogicalResult translate(mlir::arith::CmpIOp op); mlir::LogicalResult translateConstant(mlir::Location loc, @@ -153,6 +155,7 @@ mlir::LogicalResult SQLTranslator::translate(mlir::Operation *op) { CASE(UnionOp) CASE(JoinOp) CASE(RemapOp) + CASE(SelectOp) CASE(ExtractOp) CASE(mlir::arith::SelectOp) CASE(mlir::arith::ConstantOp) @@ -161,6 +164,7 @@ mlir::LogicalResult SQLTranslator::translate(mlir::Operation *op) { CASE(mlir::arith::MulIOp) CASE(mlir::arith::MulFOp) CASE(mlir::arith::AndIOp) + CASE(mlir::arith::OrIOp) CASE(mlir::arith::CmpIOp) #undef CASE @@ -430,7 +434,7 @@ mlir::LogicalResult SQLTranslator::translate(RemapOp op) { _os << ", "; } - _os << "c" << inIdx << " AS c" << outIdx; + _os << "c" << inIdx << " AS c" << outIdx++; } _os << " FROM "; @@ -442,6 +446,30 @@ mlir::LogicalResult SQLTranslator::translate(RemapOp op) { return mlir::success(); } +mlir::LogicalResult SQLTranslator::translate(SelectOp op) { + _os << "(SELECT * FROM "; + if (mlir::failed(translate(op.getInput()))) { + return mlir::failure(); + } + + _os << " WHERE "; + auto yieldOp = op.getTerminator(); + for (auto [i, pred] : llvm::enumerate(yieldOp.getPredicates())) { + if (i != 0) { + _os << " AND "; + } + + _os << "("; + if (mlir::failed(translate(pred))) { + return mlir::failure(); + } + _os << ")"; + } + + _os << ")"; + return mlir::success(); +} + mlir::LogicalResult SQLTranslator::translate(ExtractOp op) { _os << "c" << op.getColumn(); return mlir::success(); @@ -524,6 +552,19 @@ mlir::LogicalResult SQLTranslator::translate(mlir::arith::AndIOp op) { return mlir::success(); } +mlir::LogicalResult SQLTranslator::translate(mlir::arith::OrIOp op) { + _os << "("; + if (mlir::failed(translate(op->getOperand(0)))) { + return mlir::failure(); + } + _os << " OR "; + if (mlir::failed(translate(op->getOperand(1)))) { + return mlir::failure(); + } + _os << ")"; + return mlir::success(); +} + static llvm::StringLiteral translatePredicate(mlir::arith::CmpIPredicate pred) { switch (pred) { case mlir::arith::CmpIPredicate::eq: diff --git a/compiler/test/sql/bfs.gr b/compiler/test/sql/bfs.gr index 3f9342f..1f54ceb 100644 --- a/compiler/test/sql/bfs.gr +++ b/compiler/test/sql/bfs.gr @@ -58,7 +58,7 @@ from pathlib import Path dir = Path(__file__).parent conn = duckdb.connect() -vertices = conn.sql("SELECT generate_series AS c0 FROM generate_series(0, 10)") +vertices = conn.sql("SELECT generate_series AS c0 FROM generate_series(0, 7)") edges = conn.read_csv(dir / 'edges.csv', dtype=['int64', 'int64', 'bool']) source = conn.sql("SELECT 0 AS c0, true::BOOL AS c1") diff --git a/compiler/test/sql/cdlp.gr b/compiler/test/sql/cdlp.gr new file mode 100644 index 0000000..d586aaa --- /dev/null +++ b/compiler/test/sql/cdlp.gr @@ -0,0 +1,93 @@ +// RUN: split-file %s %t +// RUN: graphalg-translate --import-graphalg %t/input.gr > %t/parsed.mlir +// RUN: graphalg-opt %t/parsed.mlir --graphalg-to-core-pipeline --graphalg-verify-loop-bounds --graphalg-explicate-sparsity --graphalg-split-aggregate --graphalg-loop-aggregate --graphalg-to-rel > %t/rel.mlir +// RUN: garel-translate --export-sql %t/rel.mlir > %t/cdlp.py +// RUN: touch %t/__init__.py +// RUN: python3 %t/driver.py +// RUN: diff %t/reference.csv %t/output.csv + +//--- edges.csv +c0, c1, c2 +0,1,TRUE +0,2,TRUE +0,6,TRUE +1,0,TRUE +1,2,TRUE +2,0,TRUE +2,1,TRUE +3,4,TRUE +3,5,TRUE +4,3,TRUE +4,5,TRUE +4,6,TRUE +5,4,TRUE +5,6,TRUE +6,4,TRUE +6,5,TRUE +6,7,TRUE +7,5,TRUE + +//--- input.gr +func isMax(v: int, max: trop_max_int) -> bool { + return (cast(v) == max) + * (v != zero(int)); +} + +func CDLP(graph: Matrix) -> Matrix { + id = Vector(graph.nrows); + id[:] = bool(true); + L = diag(id); + + for i in int(0):int(5) { + step_forward = cast(graph) * cast(L); + step_backward = cast(graph.T) * cast(L); + step = step_forward (.+) step_backward; + + // Max per row + max = reduceRows(cast(step)); + + // Broadcast to all columns + b = Vector(graph.ncols); + b[:] = one(trop_max_int); + max_broadcast = max * b.T; + + // Matrix with true at every position where L has max element. + step_max = step (.isMax) max_broadcast; + + // Keep only one assigned label per vertex. + // The implementation always picks the one with the lowest id. + L = pickAny(step_max); + } + + // Map isolated nodes to their own label. + connected = reduceRows(graph) (.+) reduceRows(graph.T); + isolated = Vector(graph.nrows); + isolated[:] = bool(true); + L = diag(isolated) (.+) L; + + return L; +} + +//--- driver.py +from cdlp import CDLP +import duckdb +from pathlib import Path + +dir = Path(__file__).parent +conn = duckdb.connect() +vertices = conn.sql("SELECT generate_series AS c0 FROM generate_series(0, 7)") +edges = conn.read_csv(dir / 'edges.csv', dtype=['int64', 'int64', 'bool']) + +res = CDLP(conn, edges, vertices) +res.filter('c2').sort('c0').to_csv(str(dir / 'output.csv')) + +//--- reference.csv +c0,c1,c2 +0,0,true +1,0,true +2,0,true +3,4,true +4,3,true +5,3,true +6,3,true +7,3,true diff --git a/compiler/test/sql/sssp.gr b/compiler/test/sql/sssp.gr index 7b97206..213805a 100644 --- a/compiler/test/sql/sssp.gr +++ b/compiler/test/sql/sssp.gr @@ -40,7 +40,7 @@ from pathlib import Path dir = Path(__file__).parent conn = duckdb.connect() -vertices = conn.sql("SELECT generate_series AS c0 FROM generate_series(0, 10)") +vertices = conn.sql("SELECT generate_series AS c0 FROM generate_series(0, 9)") edges = conn.read_csv(dir / 'edges.csv') source = conn.sql("SELECT 0 AS c0, true AS c1") From 978e9bfcf235ac85f39a27d8d68ce941f3596e91 Mon Sep 17 00:00:00 2001 From: Daan de Graaf Date: Thu, 14 May 2026 08:15:30 +0000 Subject: [PATCH 10/15] PR. --- compiler/src/garel/GARelSQL.cpp | 26 +++ compiler/test/sql/bfs.gr | 34 +-- compiler/test/sql/pr.gr | 359 ++++++++++++++++++++++++++++++++ 3 files changed, 402 insertions(+), 17 deletions(-) create mode 100644 compiler/test/sql/pr.gr diff --git a/compiler/src/garel/GARelSQL.cpp b/compiler/src/garel/GARelSQL.cpp index 5aa66ef..4e6c103 100644 --- a/compiler/src/garel/GARelSQL.cpp +++ b/compiler/src/garel/GARelSQL.cpp @@ -61,6 +61,8 @@ class SQLTranslator { mlir::LogicalResult translate(mlir::arith::AndIOp op); mlir::LogicalResult translate(mlir::arith::OrIOp op); mlir::LogicalResult translate(mlir::arith::CmpIOp op); + mlir::LogicalResult translate(mlir::arith::DivFOp op); + mlir::LogicalResult translate(mlir::arith::SIToFPOp op); mlir::LogicalResult translateConstant(mlir::Location loc, mlir::Attribute attr); @@ -166,6 +168,8 @@ mlir::LogicalResult SQLTranslator::translate(mlir::Operation *op) { CASE(mlir::arith::AndIOp) CASE(mlir::arith::OrIOp) CASE(mlir::arith::CmpIOp) + CASE(mlir::arith::DivFOp) + CASE(mlir::arith::SIToFPOp) #undef CASE return op->emitOpError("no SQL translation defined for this op"); @@ -600,6 +604,28 @@ mlir::LogicalResult SQLTranslator::translate(mlir::arith::CmpIOp op) { return mlir::success(); } +mlir::LogicalResult SQLTranslator::translate(mlir::arith::DivFOp op) { + _os << "("; + if (mlir::failed(translate(op.getLhs()))) { + return mlir::failure(); + } + _os << " / "; + if (mlir::failed(translate(op.getRhs()))) { + return mlir::failure(); + } + _os << ")"; + return mlir::success(); +} + +mlir::LogicalResult SQLTranslator::translate(mlir::arith::SIToFPOp op) { + _os << "CAST("; + if (mlir::failed(translate(op.getIn()))) { + return mlir::failure(); + } + _os << " AS DOUBLE PRECISION)"; + return mlir::success(); +} + mlir::LogicalResult translateToSQL(mlir::Operation *op, llvm::raw_ostream &os) { SQLTranslator translator(os); auto moduleOp = llvm::dyn_cast(op); diff --git a/compiler/test/sql/bfs.gr b/compiler/test/sql/bfs.gr index 1f54ceb..1421ab5 100644 --- a/compiler/test/sql/bfs.gr +++ b/compiler/test/sql/bfs.gr @@ -8,23 +8,23 @@ //--- edges.csv c0, c1, c2 -0,1,TRUE -0,2,TRUE -1,2,TRUE -1,3,TRUE -1,4,TRUE -2,0,TRUE -3,5,TRUE -3,6,TRUE -3,7,TRUE -4,0,TRUE -4,1,TRUE -5,3,TRUE -5,7,TRUE -7,0,TRUE -7,1,TRUE -7,2,TRUE -8,9,TRUE +0,1,true +0,2,true +1,2,true +1,3,true +1,4,true +2,0,true +3,5,true +3,6,true +3,7,true +4,0,true +4,1,true +5,3,true +5,7,true +7,0,true +7,1,true +7,2,true +8,9,true //--- input.gr func setDepth(b:bool, iter:int) -> int { diff --git a/compiler/test/sql/pr.gr b/compiler/test/sql/pr.gr new file mode 100644 index 0000000..95076c5 --- /dev/null +++ b/compiler/test/sql/pr.gr @@ -0,0 +1,359 @@ +// RUN: split-file %s %t +// RUN: graphalg-translate --import-graphalg %t/input.gr > %t/parsed.mlir +// RUN: graphalg-opt %t/parsed.mlir --graphalg-to-core-pipeline --graphalg-verify-loop-bounds --graphalg-explicate-sparsity --graphalg-split-aggregate --graphalg-loop-aggregate --graphalg-to-rel > %t/rel.mlir +// RUN: garel-translate --export-sql %t/rel.mlir > %t/pr.py +// RUN: touch %t/__init__.py +// RUN: python3 %t/driver.py +// RUN: diff %t/reference.csv %t/output.csv + +//--- edges.csv +c0,c1,c2 +0,18,true +0,20,true +0,21,true +0,26,true +0,30,true +0,36,true +0,44,true +0,47,true +1,2,true +1,19,true +1,38,true +1,45,true +2,5,true +2,9,true +2,31,true +2,40,true +2,44,true +3,14,true +4,14,true +4,15,true +4,17,true +4,27,true +4,46,true +5,48,true +6,5,true +6,26,true +6,42,true +6,45,true +7,4,true +7,20,true +7,28,true +7,29,true +7,31,true +7,42,true +8,15,true +8,17,true +8,20,true +8,27,true +8,29,true +8,34,true +8,39,true +9,8,true +9,12,true +9,27,true +9,28,true +9,32,true +10,2,true +10,38,true +11,46,true +11,49,true +12,6,true +12,11,true +12,16,true +12,31,true +12,47,true +13,3,true +13,19,true +13,20,true +13,34,true +13,37,true +13,39,true +14,7,true +14,23,true +14,30,true +14,34,true +14,43,true +16,4,true +16,8,true +16,10,true +16,15,true +16,25,true +16,36,true +17,0,true +17,11,true +17,27,true +17,29,true +17,43,true +17,44,true +17,46,true +17,49,true +18,9,true +18,10,true +18,12,true +18,26,true +18,37,true +19,14,true +19,24,true +20,21,true +20,26,true +20,30,true +20,31,true +20,39,true +21,18,true +21,25,true +21,26,true +21,30,true +22,21,true +22,34,true +22,35,true +22,37,true +22,39,true +22,45,true +22,46,true +23,8,true +23,12,true +23,14,true +23,33,true +23,35,true +23,49,true +24,7,true +24,23,true +24,29,true +24,33,true +24,40,true +24,46,true +25,6,true +25,30,true +25,36,true +25,39,true +25,43,true +25,46,true +26,30,true +26,32,true +26,42,true +27,7,true +27,31,true +27,41,true +27,44,true +28,0,true +28,1,true +28,11,true +28,13,true +28,15,true +28,18,true +28,19,true +28,35,true +29,8,true +29,23,true +29,33,true +29,43,true +30,10,true +30,16,true +30,31,true +30,38,true +30,45,true +30,46,true +31,1,true +31,27,true +31,28,true +31,29,true +31,30,true +32,6,true +32,7,true +32,8,true +32,9,true +32,31,true +32,33,true +32,36,true +33,25,true +33,47,true +34,2,true +34,9,true +34,16,true +34,23,true +34,25,true +34,27,true +34,32,true +34,40,true +35,19,true +35,20,true +35,28,true +35,31,true +35,45,true +36,0,true +36,4,true +36,8,true +36,12,true +36,22,true +36,23,true +37,1,true +37,21,true +37,49,true +38,5,true +38,7,true +38,19,true +38,27,true +38,29,true +38,46,true +38,47,true +39,4,true +39,6,true +39,7,true +39,10,true +39,32,true +39,33,true +39,36,true +39,48,true +40,23,true +40,42,true +42,0,true +42,1,true +42,10,true +42,14,true +42,16,true +42,28,true +42,37,true +42,46,true +43,10,true +43,12,true +43,14,true +44,4,true +44,10,true +44,11,true +44,20,true +44,23,true +45,22,true +45,23,true +45,25,true +45,30,true +45,35,true +45,40,true +46,7,true +46,13,true +46,15,true +46,27,true +46,28,true +46,33,true +46,34,true +46,39,true +46,41,true +46,45,true +46,49,true +47,7,true +47,18,true +47,29,true +47,34,true +47,37,true +47,42,true +47,49,true +48,6,true +48,7,true +48,16,true +48,17,true +49,3,true +49,27,true +49,46,true + +//--- input.gr +func withDamping(degree:int, damping:real) -> real { + return cast(degree) / damping; +} + +func PR(graph: Matrix) -> Vector { + iterations = int(10); + damping = real(0.85); + n = graph.nrows; + teleport = (real(1.0) - damping) / cast(n); + rdiff = real(1.0); + + d_out = reduceRows(cast(graph)); + + d = apply(withDamping, d_out, damping); + + connected = reduceRows(graph); + sinks = Vector(n); + sinks[:] = bool(true); + + pr = Vector(n); + pr[:] = real(1.0) / cast(n); + + for i in int(0):iterations { + sink_pr = Vector(n); + sink_pr = pr; + redist = (damping / cast(n)) * reduce(sink_pr); + + w = pr (./) d; + + pr[:] = teleport + redist; + pr += cast(graph).T * w; + } + + return pr; +} + +//--- driver.py +from pr import PR +import duckdb +from pathlib import Path + +dir = Path(__file__).parent +conn = duckdb.connect() +vertices = conn.sql("SELECT generate_series AS c0 FROM generate_series(0, 49)") +edges = conn.read_csv(dir / 'edges.csv') + +res = PR(conn, edges, vertices) +conn.sql("SELECT c0, round(c1, 4) FROM res ORDER BY c0").to_csv(str(dir / 'output.csv')) + +//--- reference.csv +c0,"round(c1, 4)" +0,0.0123 +1,0.0185 +2,0.0209 +3,0.0118 +4,0.0181 +5,0.0137 +6,0.0177 +7,0.034 +8,0.0225 +9,0.0132 +10,0.0265 +11,0.0139 +12,0.0203 +13,0.009 +14,0.0367 +15,0.0177 +16,0.0203 +17,0.013 +18,0.0127 +19,0.0168 +20,0.0191 +21,0.0129 +22,0.0088 +23,0.0329 +24,0.0107 +25,0.0237 +26,0.0167 +27,0.0338 +28,0.0247 +29,0.0253 +30,0.0343 +31,0.035 +32,0.0146 +33,0.0216 +34,0.0202 +35,0.0151 +36,0.0148 +37,0.0132 +38,0.0236 +39,0.0181 +40,0.0139 +41,0.0136 +42,0.0252 +43,0.0199 +44,0.0169 +45,0.0226 +46,0.0372 +47,0.0204 +48,0.0171 +49,0.0245 From a1e01ebe276a29033e6665b15d972d5f0c1e48d5 Mon Sep 17 00:00:00 2001 From: Daan de Graaf Date: Thu, 14 May 2026 10:42:56 +0000 Subject: [PATCH 11/15] Add WCC test. --- compiler/test/sql/wcc.gr | 67 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 67 insertions(+) create mode 100644 compiler/test/sql/wcc.gr diff --git a/compiler/test/sql/wcc.gr b/compiler/test/sql/wcc.gr new file mode 100644 index 0000000..ac2e80b --- /dev/null +++ b/compiler/test/sql/wcc.gr @@ -0,0 +1,67 @@ +// RUN: split-file %s %t +// RUN: graphalg-translate --import-graphalg %t/input.gr > %t/parsed.mlir +// RUN: graphalg-opt %t/parsed.mlir --graphalg-to-core-pipeline --graphalg-verify-loop-bounds --graphalg-explicate-sparsity --graphalg-split-aggregate --graphalg-loop-aggregate --graphalg-to-rel > %t/rel.mlir +// RUN: garel-translate --export-sql %t/rel.mlir > %t/wcc.py +// RUN: touch %t/__init__.py +// RUN: python3 %t/driver.py +// RUN: diff %t/reference.csv %t/output.csv + +//--- edges.csv +c0,c1,c2 +0,1,true +0,2,true +1,0,true +1,2,true +1,3,true +3,1,true +5,6,true +5,7,true +6,5,true +8,2,true + +//--- input.gr +func WCC(graph: Matrix) -> Matrix { + id = Vector(graph.nrows); + id[:] = bool(true); + label = diag(id); + + for i in graph.nrows { + // Keep current label + alternatives = label; + // Labels reachable with a forward step + alternatives += graph * label; + // Labels reachable with a backward step + alternatives += graph.T * label; + + // Select a new label + label = pickAny(alternatives); + } + + return label; +} + +//--- driver.py +from wcc import WCC +import duckdb +from pathlib import Path + +dir = Path(__file__).parent +conn = duckdb.connect() +vertices = conn.sql("SELECT generate_series AS c0 FROM generate_series(0, 8)") +edges = conn.read_csv(dir / 'edges.csv') +source = conn.sql("SELECT 0 AS c0, true AS c1") + +res = WCC(conn, edges, vertices) +res.sort('c0').to_csv(str(dir / 'output.csv')) + +//--- reference.csv +c0,c1,c2 +0,0,true +1,0,true +2,0,true +3,0,true +4,4,true +5,5,true +6,5,true +7,5,true +8,0,true From c9512fac95d039d81ae0c6b59552758285e5776f Mon Sep 17 00:00:00 2001 From: Daan de Graaf Date: Mon, 18 May 2026 21:23:36 +0200 Subject: [PATCH 12/15] Add DuckDB dependency. --- docker/Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docker/Dockerfile b/docker/Dockerfile index 0d60e0e..ecbdb05 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -36,7 +36,7 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y \ zlib1g-dev \ && rm -rf /var/lib/apt/lists/* -RUN pip install --break-system-packages lit +RUN pip install --break-system-packages lit duckdb # - Use clang-20 as default compiler # - Allow user 'ubuntu' to run commands as root From a04a1803606e56cd304994cc0f8d616683824631 Mon Sep 17 00:00:00 2001 From: Daan de Graaf Date: Thu, 28 May 2026 15:52:48 +0000 Subject: [PATCH 13/15] WIP: dialect flag. --- compiler/tools/garel-translate.cpp | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/compiler/tools/garel-translate.cpp b/compiler/tools/garel-translate.cpp index 679ff7c..aac8a4b 100644 --- a/compiler/tools/garel-translate.cpp +++ b/compiler/tools/garel-translate.cpp @@ -11,7 +11,25 @@ #include "garel/GARelSQL.h" #include "graphalg/GraphAlgDialect.h" +enum class SQLDialect { + DUCKDB_PYTHON, + UMBRA_ITERATE, +}; + +namespace cmd { + +using namespace llvm; + +cl::opt sqlDialect( + "sql-dialect", cl::desc("The SQL dialect to export"), + cl::init(SQLDialect::DUCKDB_PYTHON), + cl::values(clEnumValN(SQLDialect::DUCKDB_PYTHON, "duckdb_python", + "DuckDB (with Python driver for control flow)"), + clEnumValN(SQLDialect::UMBRA_ITERATE, "umbra", "Umbra"))); +} // namespace cmd + int main(int argc, char *argv[]) { + // TODO: Use dialect flag. mlir::TranslateFromMLIRRegistration exportSQL( "export-sql", "export to SQL", garel::translateToSQL, [](mlir::DialectRegistry ®istry) { From 53a0104b99268996ca13c3fa9bc758bb20ae12cc Mon Sep 17 00:00:00 2001 From: Daan de Graaf Date: Fri, 29 May 2026 12:09:45 +0000 Subject: [PATCH 14/15] Umbra dialect. --- compiler/include/garel/GARelSQL.h | 8 +- compiler/src/garel/GARelSQL.cpp | 148 ++++++++++++++++++++++++++--- compiler/tools/garel-translate.cpp | 23 +++-- 3 files changed, 155 insertions(+), 24 deletions(-) diff --git a/compiler/include/garel/GARelSQL.h b/compiler/include/garel/GARelSQL.h index 5aa2ebc..9395255 100644 --- a/compiler/include/garel/GARelSQL.h +++ b/compiler/include/garel/GARelSQL.h @@ -6,6 +6,12 @@ namespace garel { -mlir::LogicalResult translateToSQL(mlir::Operation *op, llvm::raw_ostream &os); +enum class SQLDialect { + DUCKDB_PYTHON, + UMBRA_ITERATE, +}; + +mlir::LogicalResult translateToSQL(mlir::Operation *op, llvm::raw_ostream &os, + SQLDialect dialect); } // namespace garel diff --git a/compiler/src/garel/GARelSQL.cpp b/compiler/src/garel/GARelSQL.cpp index 4e6c103..cf95a4e 100644 --- a/compiler/src/garel/GARelSQL.cpp +++ b/compiler/src/garel/GARelSQL.cpp @@ -1,21 +1,22 @@ #include #include #include +#include #include #include #include #include +#include #include #include +#include #include #include #include "garel/GARelAttr.h" #include "garel/GARelOps.h" +#include "garel/GARelSQL.h" #include "garel/GARelTypes.h" -#include "mlir/IR/BuiltinAttributes.h" -#include "mlir/IR/Location.h" -#include "llvm/ADT/StringRef.h" namespace garel { @@ -24,8 +25,9 @@ namespace { class SQLTranslator { private: llvm::raw_ostream &_os; - std::size_t _indentLevel = 0; + SQLDialect _dialect; + std::size_t _indentLevel = 0; llvm::DenseMap _valMap; std::size_t _tempCount; @@ -40,9 +42,13 @@ class SQLTranslator { } mlir::LogicalResult translate(mlir::func::FuncOp op); + mlir::LogicalResult translateDuckDB(mlir::func::FuncOp op); + mlir::LogicalResult translateUmbra(mlir::func::FuncOp op); mlir::LogicalResult translate(mlir::Value val); mlir::LogicalResult translate(mlir::Operation *op); mlir::LogicalResult translate(ForOp op); + mlir::LogicalResult translateDuckDB(ForOp op); + mlir::LogicalResult translateUmbra(ForOp op); mlir::LogicalResult translate(ConstantOp op); mlir::LogicalResult translate(AggregateOp op); mlir::LogicalResult translate(ProjectOp op); @@ -70,7 +76,8 @@ class SQLTranslator { mlir::LogicalResult translateMul(mlir::Operation *op); public: - SQLTranslator(llvm::raw_ostream &os) : _os(os) {} + SQLTranslator(llvm::raw_ostream &os, SQLDialect dialect) + : _os(os), _dialect(dialect) {} mlir::LogicalResult translate(mlir::ModuleOp op); }; @@ -93,6 +100,15 @@ mlir::LogicalResult SQLTranslator::translate(mlir::ModuleOp op) { } mlir::LogicalResult SQLTranslator::translate(mlir::func::FuncOp op) { + switch (_dialect) { + case SQLDialect::DUCKDB_PYTHON: + return translateDuckDB(op); + case SQLDialect::UMBRA_ITERATE: + return translateUmbra(op); + } +} + +mlir::LogicalResult SQLTranslator::translateDuckDB(mlir::func::FuncOp op) { auto name = op.getSymName(); _os << "def " << name << "(conn"; for (auto [i, arg] : llvm::enumerate(op.getBody().getArguments())) { @@ -129,6 +145,30 @@ mlir::LogicalResult SQLTranslator::translate(mlir::func::FuncOp op) { return mlir::success(); } +mlir::LogicalResult SQLTranslator::translateUmbra(mlir::func::FuncOp op) { + auto name = op.getSymName(); + _os << "-- " << name << "\n"; + + // Assume arguments are tables named farg0, farg1, ... + for (auto [i, arg] : llvm::enumerate(op.getBody().getArguments())) { + auto varName = std::string("farg") + std::to_string(i); + _valMap[arg] = varName; + } + + auto retOp = + llvm::cast(op.getBody().front().getTerminator()); + if (retOp.getNumOperands() != 1) { + return retOp.emitOpError("expected a single return value"); + } + + _os << "SELECT * FROM "; + if (mlir::failed(translate(retOp.getOperand(0)))) { + return mlir::failure(); + } + + return mlir::success(); +} + mlir::LogicalResult SQLTranslator::translate(mlir::Value val) { if (_valMap.contains(val)) { _os << "(SELECT * FROM " << _valMap[val] << ")"; @@ -176,14 +216,23 @@ mlir::LogicalResult SQLTranslator::translate(mlir::Operation *op) { } mlir::LogicalResult SQLTranslator::translate(ForOp op) { + switch (_dialect) { + case SQLDialect::DUCKDB_PYTHON: + return translateDuckDB(op); + case SQLDialect::UMBRA_ITERATE: + return translateUmbra(op); + } +} + +mlir::LogicalResult SQLTranslator::translateDuckDB(ForOp op) { auto &body = op.getBody().front(); // Initialize temporary tables for loop state llvm::SmallVector stateTables; - for (auto i : llvm::seq(op.getInit().size())) { + for (auto [i, init] : llvm::enumerate(op.getInit())) { auto temp = newTemp(); indent(); _os << "conn.execute(\"\"\"CREATE TABLE " << temp << " AS "; - if (mlir::failed(translate(op.getInit()[i]))) { + if (mlir::failed(translate(init))) { return mlir::failure(); } _os << "\"\"\")\n"; @@ -257,6 +306,77 @@ mlir::LogicalResult SQLTranslator::translate(ForOp op) { return mlir::success(); } +mlir::LogicalResult SQLTranslator::translateUmbra(ForOp op) { + auto &body = op.getBody().front(); + _os << "(SELECT * FROM umbra.iterate(\n"; + _indentLevel++; + + // Define the order in which we add the states. This matters because the first + // state becomes the result of the op. + llvm::SmallVector stateOrder{op.getResultIdx()}; + for (auto i : llvm::seq(op.getInit().size())) { + if (i != op.getResultIdx()) { + stateOrder.push_back(i); + } + } + + // Initial state + llvm::SmallVector stateTables(op.getInit().size()); + bool first = true; + for (auto i : stateOrder) { + if (first) { + first = false; + } else { + _os << ",\n"; + } + + auto temp = newTemp(); + indent(); + _os << temp << "_init => TABLE"; + if (mlir::failed(translate(op.getInit()[i]))) { + return mlir::failure(); + } + + stateTables[i] = temp; + _valMap[body.getArgument(i)] = temp; + } + + // Next state + auto yieldOp = llvm::cast(op.getBody().front().getTerminator()); + for (auto i : stateOrder) { + _os << ",\n"; + + indent(); + _os << stateTables[i] << "_next => TABLE"; + if (mlir::failed(translate(yieldOp.getInputs()[i]))) { + return mlir::failure(); + } + } + + // Loop counter + _os << ",\n"; + indent(); + _os << "counter_init => TABLE(SELECT 0 AS i),\n"; + indent(); + _os << "counter_next => TABLE(SELECT i + 1 AS i FROM counter)"; + + if (!op.getUntil().empty()) { + return op->emitOpError("until not yet supported"); + } else { + _os << ",\n"; + indent(); + _os << "until => TABLE(SELECT i = c0 FROM counter,"; + if (mlir::failed(translate(op.getIters()))) { + return mlir::failure(); + } + _os << ")"; + } + + _os << "))\n"; + _indentLevel--; + return mlir::success(); +} + mlir::LogicalResult SQLTranslator::translateConstant(mlir::Location loc, mlir::Attribute attr) { if (auto boolAttr = llvm::dyn_cast(attr)) { @@ -323,8 +443,13 @@ mlir::LogicalResult SQLTranslator::translate(AggregateOp op) { } _os << translateAggregateFunc(agg.getFunc()) << "("; - llvm::interleaveComma(agg.getInputs(), _os, - [&](ColumnIdx idx) { _os << "c" << idx; }); + if (agg.getFunc() == AggregateFunc::COUNT) { + assert(agg.getInputs().empty()); + _os << "*"; + } else { + llvm::interleaveComma(agg.getInputs(), _os, + [&](ColumnIdx idx) { _os << "c" << idx; }); + } _os << ") AS c" << colOut++; } @@ -626,8 +751,9 @@ mlir::LogicalResult SQLTranslator::translate(mlir::arith::SIToFPOp op) { return mlir::success(); } -mlir::LogicalResult translateToSQL(mlir::Operation *op, llvm::raw_ostream &os) { - SQLTranslator translator(os); +mlir::LogicalResult translateToSQL(mlir::Operation *op, llvm::raw_ostream &os, + SQLDialect dialect) { + SQLTranslator translator(os, dialect); auto moduleOp = llvm::dyn_cast(op); if (!moduleOp) { return op->emitOpError("expected a module"); diff --git a/compiler/tools/garel-translate.cpp b/compiler/tools/garel-translate.cpp index aac8a4b..e330f03 100644 --- a/compiler/tools/garel-translate.cpp +++ b/compiler/tools/garel-translate.cpp @@ -11,27 +11,26 @@ #include "garel/GARelSQL.h" #include "graphalg/GraphAlgDialect.h" -enum class SQLDialect { - DUCKDB_PYTHON, - UMBRA_ITERATE, -}; - -namespace cmd { +namespace { using namespace llvm; -cl::opt sqlDialect( +cl::opt sqlDialect( "sql-dialect", cl::desc("The SQL dialect to export"), - cl::init(SQLDialect::DUCKDB_PYTHON), - cl::values(clEnumValN(SQLDialect::DUCKDB_PYTHON, "duckdb_python", + cl::init(garel::SQLDialect::DUCKDB_PYTHON), + cl::values(clEnumValN(garel::SQLDialect::DUCKDB_PYTHON, "duckdb_python", "DuckDB (with Python driver for control flow)"), - clEnumValN(SQLDialect::UMBRA_ITERATE, "umbra", "Umbra"))); -} // namespace cmd + clEnumValN(garel::SQLDialect::UMBRA_ITERATE, "umbra", "Umbra"))); + +} // namespace int main(int argc, char *argv[]) { // TODO: Use dialect flag. mlir::TranslateFromMLIRRegistration exportSQL( - "export-sql", "export to SQL", garel::translateToSQL, + "export-sql", "export to SQL", + [](mlir::Operation *op, llvm::raw_ostream &os) { + return garel::translateToSQL(op, os, sqlDialect); + }, [](mlir::DialectRegistry ®istry) { registry.insert(); registry.insert(); From c24cd553b312f35467fab986cf5e2490e4b637fe Mon Sep 17 00:00:00 2001 From: Daan de Graaf Date: Fri, 29 May 2026 12:48:21 +0000 Subject: [PATCH 15/15] Support until clauses. --- compiler/src/garel/GARelSQL.cpp | 29 ++++++++++++++++++++++------- 1 file changed, 22 insertions(+), 7 deletions(-) diff --git a/compiler/src/garel/GARelSQL.cpp b/compiler/src/garel/GARelSQL.cpp index cf95a4e..ec85d26 100644 --- a/compiler/src/garel/GARelSQL.cpp +++ b/compiler/src/garel/GARelSQL.cpp @@ -94,6 +94,7 @@ mlir::LogicalResult SQLTranslator::translate(mlir::ModuleOp op) { if (mlir::failed(translate(funcOp))) { return mlir::failure(); } + _os << "\n"; } return mlir::success(); @@ -360,18 +361,32 @@ mlir::LogicalResult SQLTranslator::translateUmbra(ForOp op) { indent(); _os << "counter_next => TABLE(SELECT i + 1 AS i FROM counter)"; + // until clause for iteration count + _os << ",\n"; + indent(); + _os << "until => TABLE(SELECT i = c0 FROM counter,"; + if (mlir::failed(translate(op.getIters()))) { + return mlir::failure(); + } + if (!op.getUntil().empty()) { - return op->emitOpError("until not yet supported"); - } else { - _os << ",\n"; - indent(); - _os << "until => TABLE(SELECT i = c0 FROM counter,"; - if (mlir::failed(translate(op.getIters()))) { + // Additional until clause + _os << " UNION ALL SELECT c0 FROM "; + + auto &body = op.getUntil().front(); + auto yieldOp = llvm::cast(body.getTerminator()); + // Map block arguments + for (auto i : stateOrder) { + _valMap[body.getArgument(i)] = stateTables[i]; + } + + if (mlir::failed(translate(yieldOp.getInputs()[0]))) { return mlir::failure(); } - _os << ")"; } + _os << ")"; + _os << "))\n"; _indentLevel--; return mlir::success();