From 2a4d7ddce9b5f05b8ed4b014b598d2a431c3521b Mon Sep 17 00:00:00 2001
From: Jay Phan
Date: Sat, 2 May 2026 10:55:35 -0400
Subject: [PATCH] analyzer completed

---
 Makefile                    |   7 +-
 src/sql/analyzer.cpp        | 250 +++++++++++++++++++++
 src/sql/analyzer.h          | 107 +++++++++
 tests/sql/test_analyzer.cpp | 435 ++++++++++++++++++++++++++++++++++++
 4 files changed, 797 insertions(+), 2 deletions(-)
 create mode 100644 src/sql/analyzer.cpp
 create mode 100644 src/sql/analyzer.h
 create mode 100644 tests/sql/test_analyzer.cpp

diff --git a/Makefile b/Makefile
index 6c586a9..da3546a 100644
--- a/Makefile
+++ b/Makefile
@@ -10,7 +10,8 @@ DBMS_OBJS = $(BUILD_DIR)/main.o \
             $(BUILD_DIR)/src/storage/slotted_page.o \
             $(BUILD_DIR)/src/storage/heap_file.o \
             $(BUILD_DIR)/src/sql/tuple.o \
-            $(BUILD_DIR)/src/sql/catalog.o
+            $(BUILD_DIR)/src/sql/catalog.o \
+            $(BUILD_DIR)/src/sql/analyzer.o
 TEST_OBJS = $(BUILD_DIR)/tests/test_parser.o \
             $(BUILD_DIR)/tests/storage/test_disk_manager.o \
             $(BUILD_DIR)/tests/storage/test_buffer_pool.o \
@@ -19,13 +20,15 @@ TEST_OBJS = $(BUILD_DIR)/tests/test_parser.o \
             $(BUILD_DIR)/tests/storage/test_integration.o \
             $(BUILD_DIR)/tests/sql/test_tuple.o \
             $(BUILD_DIR)/tests/sql/test_catalog.o \
+            $(BUILD_DIR)/tests/sql/test_analyzer.o \
             $(BUILD_DIR)/src/parser.o \
             $(BUILD_DIR)/src/storage/disk_manager.o \
             $(BUILD_DIR)/src/storage/buffer_pool.o \
             $(BUILD_DIR)/src/storage/slotted_page.o \
             $(BUILD_DIR)/src/storage/heap_file.o \
             $(BUILD_DIR)/src/sql/tuple.o \
-            $(BUILD_DIR)/src/sql/catalog.o
+            $(BUILD_DIR)/src/sql/catalog.o \
+            $(BUILD_DIR)/src/sql/analyzer.o
 
 dbms: $(DBMS_OBJS)
 	$(CXX) $(CXXFLAGS) -o $@ $^
diff --git a/src/sql/analyzer.cpp b/src/sql/analyzer.cpp
new file mode 100644
index 0000000..65bcee5
--- /dev/null
+++ b/src/sql/analyzer.cpp
@@ -0,0 +1,250 @@
+#include "src/sql/analyzer.h"
+
+#include <limits>
+#include <stdexcept>
+#include <string>
+#include <type_traits>
+#include <utility>
+
+namespace {
+
+// Parses `text` as a base-10 signed integer and requires the whole string to
+// be consumed. std::stoll on its own stops at the first non-digit, so inputs
+// such as "12abc" or "1.5" would otherwise be accepted as 12 and 1.
+// Throws std::runtime_error for malformed or out-of-range input.
+long long parseIntegerLiteral(const std::string& text) {
+    try {
+        size_t consumed = 0;
+        const long long v = std::stoll(text, &consumed);
+        if (consumed != text.size()) {
+            throw std::invalid_argument("trailing characters");
+        }
+        return v;
+    } catch (const std::invalid_argument&) {
+        throw std::runtime_error("invalid integer literal: " + text);
+    } catch (const std::out_of_range&) {
+        throw std::runtime_error("integer literal out of range: " + text);
+    }
+}
+
+}  // namespace
+
+// Returns the result_type stored in whichever alternative `e` holds.
+Type resultTypeOf(const BoundExpr& e) {
+    return std::visit(
+        [](const auto& v) -> Type {
+            using T = std::decay_t<decltype(v)>;
+            if constexpr (std::is_same_v<T, BoundColumnRef>) return v.result_type;
+            else if constexpr (std::is_same_v<T, BoundLiteral>) return v.result_type;
+            else /* std::unique_ptr<BoundBinaryOp> */ return v->result_type;
+        },
+        e);
+}
+
+// Resolves `name` (either "table.column" or a bare column name) against the
+// tables in `scope`. Throws std::runtime_error if the table or column is
+// unknown, or if a bare name matches a column in more than one table.
+BoundColumnRef Analyzer::resolveColumn(const std::string& name,
+                                       const Scope& scope) const {
+    // "table.column" form: pin to one specific FROM table.
+    const auto dot = name.find('.');
+    if (dot != std::string::npos) {
+        const std::string tbl = name.substr(0, dot);
+        const std::string col = name.substr(dot + 1);
+        for (size_t t = 0; t < scope.tables.size(); ++t) {
+            if (scope.tables[t]->name == tbl) {
+                const size_t c = scope.tables[t]->schema.indexOf(col);
+                if (c == Schema::kNotFound) {
+                    throw std::runtime_error("no such column '" + col +
+                                             "' in table '" + tbl + "'");
+                }
+                return {t, c, scope.tables[t]->schema.columns[c].type};
+            }
+        }
+        throw std::runtime_error("no such table '" + tbl + "' in FROM clause");
+    }
+
+    // Bare name: search every FROM table; reject if it appears in two.
+    size_t found_t = 0;
+    size_t found_c = 0;
+    bool found = false;
+    for (size_t t = 0; t < scope.tables.size(); ++t) {
+        const size_t c = scope.tables[t]->schema.indexOf(name);
+        if (c != Schema::kNotFound) {
+            if (found) {
+                throw std::runtime_error("ambiguous column '" + name +
+                                         "': appears in multiple FROM tables");
+            }
+            found = true;
+            found_t = t;
+            found_c = c;
+        }
+    }
+    if (!found) throw std::runtime_error("no such column '" + name + "'");
+    return {found_t, found_c, scope.tables[found_t]->schema.columns[found_c].type};
+}
+
+// Converts the literal `text` into a BoundLiteral of type `expected` (the type
+// of the column it is compared against). Throws std::runtime_error when the
+// literal cannot be represented as that type.
+BoundLiteral Analyzer::analyzeLiteral(const std::string& text,
+                                      bool value_is_string,
+                                      Type expected) const {
+    if (value_is_string) {
+        if (expected != Type::Text) {
+            throw std::runtime_error(
+                "string literal cannot be compared with a non-Text column");
+        }
+        return BoundLiteral{Value::Text(text), Type::Text};
+    }
+
+    // Numeric literal. Parse under the column's expected type.
+    switch (expected) {
+        case Type::Int32: {
+            const long long v = parseIntegerLiteral(text);
+            if (v < std::numeric_limits<int32_t>::min() ||
+                v > std::numeric_limits<int32_t>::max()) {
+                throw std::runtime_error(
+                    "integer literal out of Int32 range: " + text);
+            }
+            return BoundLiteral{Value::Int32(static_cast<int32_t>(v)),
+                                Type::Int32};
+        }
+        case Type::Int64: {
+            const long long v = parseIntegerLiteral(text);
+            return BoundLiteral{Value::Int64(static_cast<int64_t>(v)),
+                                Type::Int64};
+        }
+        case Type::Bool: {
+            // Parser has no TRUE/FALSE keyword; allow 0/1 against Bool columns.
+            if (text == "0") return BoundLiteral{Value::Bool(false), Type::Bool};
+            if (text == "1") return BoundLiteral{Value::Bool(true), Type::Bool};
+            throw std::runtime_error(
+                "expected Bool literal (0 or 1), got: " + text);
+        }
+        case Type::Text:
+            throw std::runtime_error(
+                "Text column cannot be compared with a numeric literal");
+    }
+    throw std::runtime_error("analyzeLiteral: unknown expected type");
+}
+
+// Type-checks `lhs op rhs` and returns the result type (always Bool). Throws
+// std::runtime_error if the operand types differ or an ordering operator is
+// applied to a non-integer type.
+Type Analyzer::checkBinaryOp(Op op, Type lhs, Type rhs) const {
+    if (lhs != rhs) {
+        throw std::runtime_error("binary operator: type mismatch on operands");
+    }
+    switch (op) {
+        case Op::Eq:
+        case Op::Neq:
+            // Equality is defined for every primitive type we have.
+            return Type::Bool;
+        case Op::Lt:
+        case Op::Gt:
+        case Op::Leq:
+        case Op::Geq:
+            // Ordering is numeric-only for now (no lexicographic Text yet).
+            if (lhs != Type::Int32 && lhs != Type::Int64) {
+                throw std::runtime_error(
+                    "ordering operator requires numeric operands");
+            }
+            return Type::Bool;
+    }
+    throw std::runtime_error("checkBinaryOp: unknown operator");
+}
+
+// Binds a WHERE condition into a BoundBinaryOp whose lhs is the resolved
+// column and whose rhs is the literal typed after that column.
+BoundExpr Analyzer::analyzeCondition(const Condition& c,
+                                     const Scope& scope) const {
+    // Parser's WHERE shape is always <column> <op> <literal>. Resolve the
+    // column first so we can use its type as the literal's expected type.
+    BoundColumnRef col = resolveColumn(c.column, scope);
+    BoundLiteral lit = analyzeLiteral(c.value, c.value_is_string, col.result_type);
+    const Type rt = checkBinaryOp(c.op, col.result_type, lit.result_type);
+
+    auto bop = std::make_unique<BoundBinaryOp>();
+    bop->op = c.op;
+    bop->lhs = std::move(col);
+    bop->rhs = std::move(lit);
+    bop->result_type = rt;
+    return BoundExpr{std::move(bop)};
+}
+
+// Binds a parsed SELECT: resolves FROM/JOIN tables, ON columns, the SELECT
+// list and the optional WHERE clause. Throws std::runtime_error on any
+// unknown name, ambiguous column, repeated table, or type mismatch.
+BoundSelect Analyzer::analyze(const SelectQuery& q) const {
+    BoundSelect out;
+
+    // FROM clause + JOIN tables. Resolve every named table up front so the
+    // ON clauses (and SELECT list, and WHERE) all see the full scope; this
+    // matches SQL semantics for inner joins and keeps resolution uniform.
+    const Catalog::TableInfo* info = cat_.getTable(q.table);
+    if (info == nullptr) {
+        throw std::runtime_error("no such table '" + q.table + "'");
+    }
+    out.from_tables.push_back(info);
+
+    for (const auto& j : q.joins) {
+        const Catalog::TableInfo* ji = cat_.getTable(j.table);
+        if (ji == nullptr) {
+            throw std::runtime_error("no such table '" + j.table + "'");
+        }
+        // Without aliases, the same table cannot appear twice — every
+        // qualified reference would be ambiguous and bare references
+        // would silently bind to the first copy.
+        for (const auto* t : out.from_tables) {
+            if (t->name == j.table) {
+                throw std::runtime_error(
+                    "table '" + j.table +
+                    "' appears more than once in FROM/JOIN (aliases unsupported)");
+            }
+        }
+        out.from_tables.push_back(ji);
+    }
+
+    Scope scope;
+    scope.tables = out.from_tables;
+
+    // ON clauses. Each join is `<left> = <right>`; both columns are resolved
+    // against the full scope and must share a result type.
+    for (const auto& j : q.joins) {
+        BoundColumnRef l = resolveColumn(j.left, scope);
+        BoundColumnRef r = resolveColumn(j.right, scope);
+        if (l.result_type != r.result_type) {
+            throw std::runtime_error(
+                "JOIN ON: type mismatch between '" + j.left +
+                "' and '" + j.right + "'");
+        }
+        out.joins.push_back(BoundJoin{l, r});
+    }
+
+    // SELECT list.
+    if (q.select_all) {
+        out.select_all = true;
+        for (size_t t = 0; t < out.from_tables.size(); ++t) {
+            const auto& cols = out.from_tables[t]->schema.columns;
+            for (size_t c = 0; c < cols.size(); ++c) {
+                out.select_list.push_back(BoundColumnRef{t, c, cols[c].type});
+            }
+        }
+    } else {
+        for (const auto& name : q.columns) {
+            out.select_list.push_back(resolveColumn(name, scope));
+        }
+    }
+
+    // WHERE clause.
+    if (q.where) {
+        BoundExpr w = analyzeCondition(*q.where, scope);
+        if (resultTypeOf(w) != Type::Bool) {
+            throw std::runtime_error("WHERE clause must produce a Bool");
+        }
+        out.where = std::move(w);
+    }
+
+    return out;
+}
diff --git a/src/sql/analyzer.h b/src/sql/analyzer.h
new file mode 100644
index 0000000..879649d
--- /dev/null
+++ b/src/sql/analyzer.h
@@ -0,0 +1,107 @@
+#pragma once
+
+#include "src/parser.h"
+#include "src/sql/catalog.h"
+#include "src/sql/tuple.h"
+
+#include <cstddef>
+#include <memory>
+#include <optional>
+#include <string>
+#include <variant>
+#include <vector>
+
+// Bound (resolved) parallel of the parser's AST. Strings — column names,
+// table names — have all been replaced by numeric indices into a schema,
+// so the executor can look up values by integer indexing rather than by
+// re-resolving names per row.
+
+struct BoundColumnRef {
+    size_t table_index;   // index into BoundSelect::from_tables
+    size_t column_index;  // index into that table's schema.columns
+    Type result_type;
+};
+
+struct BoundLiteral {
+    Value value;
+    Type result_type;  // == value.type
+};
+
+struct BoundBinaryOp;  // forward; needed because BoundExpr can recurse
+
+// std::unique_ptr<BoundBinaryOp> works as a variant alternative even when
+// BoundBinaryOp is incomplete here, because unique_ptr only needs the
+// pointee's size to be known at construction/destruction (not declaration).
+using BoundExpr = std::variant<BoundColumnRef,
+                               BoundLiteral,
+                               std::unique_ptr<BoundBinaryOp>>;
+
+struct BoundBinaryOp {
+    Op op;
+    BoundExpr lhs;
+    BoundExpr rhs;
+    Type result_type;
+};
+
+// Result type accessor that handles every variant alternative uniformly.
+Type resultTypeOf(const BoundExpr& e);
+
+// Resolved `JOIN <table> ON <left> = <right>`. Both columns are bound
+// against the full FROM scope, so either side may refer to any joined
+// table (not only the two flanking this clause). The newly-joined table
+// itself sits at from_tables[i + 1] for the i-th element of `joins`.
+struct BoundJoin {
+    BoundColumnRef left;
+    BoundColumnRef right;
+};
+
+struct BoundSelect {
+    // Resolved FROM tables, in declaration order: index 0 is the FROM
+    // table, indices 1..n are JOIN tables in the order they appeared.
+    // A BoundColumnRef::table_index indexes into this vector.
+    std::vector<const Catalog::TableInfo*> from_tables;
+
+    // One per JOIN clause, in source order. joins[i] introduces
+    // from_tables[i + 1].
+    std::vector<BoundJoin> joins;
+
+    // One BoundExpr per output column. For SELECT *, populated with one
+    // BoundColumnRef per (table, column) across every from_tables entry.
+    std::vector<BoundExpr> select_list;
+
+    // Bound WHERE predicate, if the query had one. result_type is Bool.
+    std::optional<BoundExpr> where;
+
+    // Tracks whether the parsed query was SELECT *. Mainly informational
+    // — select_list is already expanded.
+    bool select_all = false;
+};
+
+// Walks a parsed SelectQuery against a Catalog, resolving every name and
+// type-checking every operator. Throws std::runtime_error on any name
+// resolution failure or type mismatch. Pure: no I/O of its own beyond the
+// catalog lookups it performs.
+class Analyzer {
+public:
+    explicit Analyzer(const Catalog& cat) : cat_(cat) {}
+
+    BoundSelect analyze(const SelectQuery& q) const;
+
+private:
+    // Set of tables visible to expression resolution. For now, just the
+    // FROM clause's tables in order.
+    struct Scope {
+        std::vector<const Catalog::TableInfo*> tables;
+    };
+
+    BoundColumnRef resolveColumn(const std::string& name,
+                                 const Scope& scope) const;
+    BoundLiteral analyzeLiteral(const std::string& text,
+                                bool value_is_string,
+                                Type expected) const;
+    BoundExpr analyzeCondition(const Condition& c,
+                               const Scope& scope) const;
+    Type checkBinaryOp(Op op, Type lhs, Type rhs) const;
+
+    const Catalog& cat_;
+};
diff --git a/tests/sql/test_analyzer.cpp b/tests/sql/test_analyzer.cpp
new file mode 100644
index 0000000..4245a8e
--- /dev/null
+++ b/tests/sql/test_analyzer.cpp
@@ -0,0 +1,435 @@
+#include "tests/vendor/doctest.h"
+
+#include "src/parser.h"
+#include "src/sql/analyzer.h"
+#include "src/sql/catalog.h"
+#include "src/sql/tuple.h"
+#include "src/storage/buffer_pool.h"
+#include "src/storage/disk_manager.h"
+#include "tests/test_util.h"
+
+#include <stdexcept>
+#include <string>
+#include <variant>
+
+namespace {
+
+Schema usersSchema() {
+    return Schema{{
+        {"id", Type::Int32, false},
+        {"name", Type::Text, false},
+        {"age", Type::Int32, true},
+        {"salary", Type::Int64, false},
+        {"active", Type::Bool, false},
+    }};
+}
+
+Schema postsSchema() {
+    return Schema{{
+        {"id", Type::Int32, false},
+        {"title", Type::Text, false},
+        {"user_id", Type::Int32, false},
+    }};
+}
+
+// Convenience: parse + analyze in one go.
+BoundSelect parseAnalyze(const Analyzer& az, const std::string& sql) {
+    Parser p(sql);
+    SelectQuery q = p.parse();
+    return az.analyze(q);
+}
+
+// Pull out a BoundColumnRef from a select_list slot, asserting it's that
+// alternative. Lets test cases stay terse.
+const BoundColumnRef& asColumn(const BoundExpr& e) {
+    return std::get<BoundColumnRef>(e);
+}
+
+// Pull out the BoundBinaryOp from a BoundExpr.
+const BoundBinaryOp& asBinaryOp(const BoundExpr& e) {
+    return *std::get<std::unique_ptr<BoundBinaryOp>>(e);
+}
+
+}  // namespace
+
+TEST_CASE("SELECT * expands every column of the FROM table") {
+    TempFile tf;
+    DiskManager dm(tf.path());
+    BufferPool bp(4, &dm);
+    Catalog cat = Catalog::create(&bp);
+    cat.createTable("users", usersSchema());
+    Analyzer az(cat);
+
+    auto b = parseAnalyze(az, "SELECT * FROM users");
+
+    REQUIRE(b.from_tables.size() == 1);
+    CHECK(b.from_tables[0]->name == "users");
+    CHECK(b.select_all);
+    REQUIRE(b.select_list.size() == 5);
+    for (size_t i = 0; i < 5; ++i) {
+        const auto& c = asColumn(b.select_list[i]);
+        CHECK(c.table_index == 0);
+        CHECK(c.column_index == i);
+    }
+    CHECK_FALSE(b.where.has_value());
+}
+
+TEST_CASE("SELECT with explicit columns resolves names to indices") {
+    TempFile tf;
+    DiskManager dm(tf.path());
+    BufferPool bp(4, &dm);
+    Catalog cat = Catalog::create(&bp);
+    cat.createTable("users", usersSchema());
+    Analyzer az(cat);
+
+    auto b = parseAnalyze(az, "SELECT name, salary FROM users");
+
+    CHECK_FALSE(b.select_all);
+    REQUIRE(b.select_list.size() == 2);
+    CHECK(asColumn(b.select_list[0]).column_index == 1);
+    CHECK(asColumn(b.select_list[0]).result_type == Type::Text);
+    CHECK(asColumn(b.select_list[1]).column_index == 3);
+    CHECK(asColumn(b.select_list[1]).result_type == Type::Int64);
+}
+
+TEST_CASE("WHERE on Int32 column with a numeric literal binds correctly") {
+    TempFile tf;
+    DiskManager dm(tf.path());
+    BufferPool bp(4, &dm);
+    Catalog cat = Catalog::create(&bp);
+    cat.createTable("users", usersSchema());
+    Analyzer az(cat);
+
+    auto b = parseAnalyze(az, "SELECT * FROM users WHERE age > 18");
+
+    REQUIRE(b.where.has_value());
+    const auto& bin = asBinaryOp(*b.where);
+    CHECK(bin.op == Op::Gt);
+    CHECK(bin.result_type == Type::Bool);
+
+    const auto& col = std::get<BoundColumnRef>(bin.lhs);
+    CHECK(col.column_index == 2);  // "age"
+    CHECK(col.result_type == Type::Int32);
+
+    const auto& lit = std::get<BoundLiteral>(bin.rhs);
+    CHECK(lit.result_type == Type::Int32);
+    CHECK(lit.value.i32 == 18);
+}
+
+TEST_CASE("WHERE with a string literal targets a Text column") {
+    TempFile tf;
+    DiskManager dm(tf.path());
+    BufferPool bp(4, &dm);
+    Catalog cat = Catalog::create(&bp);
+    cat.createTable("users", usersSchema());
+    Analyzer az(cat);
+
+    auto b = parseAnalyze(az, "SELECT * FROM users WHERE name = 'alice'");
+
+    REQUIRE(b.where.has_value());
+    const auto& bin = asBinaryOp(*b.where);
+    CHECK(bin.op == Op::Eq);
+    CHECK(std::get<BoundLiteral>(bin.rhs).value.text == "alice");
+    CHECK(std::get<BoundLiteral>(bin.rhs).result_type == Type::Text);
+}
+
+TEST_CASE("WHERE on Int64 column parses the literal as Int64") {
+    TempFile tf;
+    DiskManager dm(tf.path());
+    BufferPool bp(4, &dm);
+    Catalog cat = Catalog::create(&bp);
+    cat.createTable("users", usersSchema());
+    Analyzer az(cat);
+
+    auto b = parseAnalyze(az, "SELECT * FROM users WHERE salary >= 50000");
+    const auto& bin = asBinaryOp(*b.where);
+    CHECK(bin.op == Op::Geq);
+    const auto& lit = std::get<BoundLiteral>(bin.rhs);
+    CHECK(lit.result_type == Type::Int64);
+    CHECK(lit.value.i64 == 50000);
+}
+
+TEST_CASE("WHERE on Bool column accepts 0 / 1 as the literal") {
+    TempFile tf;
+    DiskManager dm(tf.path());
+    BufferPool bp(4, &dm);
+    Catalog cat = Catalog::create(&bp);
+    cat.createTable("users", usersSchema());
+    Analyzer az(cat);
+
+    auto b = parseAnalyze(az, "SELECT * FROM users WHERE active = 1");
+    const auto& bin = asBinaryOp(*b.where);
+    const auto& lit = std::get<BoundLiteral>(bin.rhs);
+    CHECK(lit.result_type == Type::Bool);
+    CHECK(lit.value.b == true);
+}
+
+TEST_CASE("Qualified column reference (table.column) resolves") {
+    TempFile tf;
+    DiskManager dm(tf.path());
+    BufferPool bp(4, &dm);
+    Catalog cat = Catalog::create(&bp);
+    cat.createTable("users", usersSchema());
+    Analyzer az(cat);
+
+    auto b = parseAnalyze(az, "SELECT users.id FROM users");
+    REQUIRE(b.select_list.size() == 1);
+    const auto& col = asColumn(b.select_list[0]);
+    CHECK(col.table_index == 0);
+    CHECK(col.column_index == 0);
+    CHECK(col.result_type == Type::Int32);
+}
+
+TEST_CASE("unknown table throws") {
+    TempFile tf;
+    DiskManager dm(tf.path());
+    BufferPool bp(4, &dm);
+    Catalog cat = Catalog::create(&bp);
+    cat.createTable("users", usersSchema());
+    Analyzer az(cat);
+
+    CHECK_THROWS_AS(parseAnalyze(az, "SELECT * FROM missing"),
+                    std::runtime_error);
+}
+
+TEST_CASE("unknown column throws") {
+    TempFile tf;
+    DiskManager dm(tf.path());
+    BufferPool bp(4, &dm);
+    Catalog cat = Catalog::create(&bp);
+    cat.createTable("users", usersSchema());
+    Analyzer az(cat);
+
+    CHECK_THROWS_AS(parseAnalyze(az, "SELECT bogus FROM users"),
+                    std::runtime_error);
+    CHECK_THROWS_AS(parseAnalyze(az, "SELECT * FROM users WHERE bogus = 1"),
+                    std::runtime_error);
+}
+
+TEST_CASE("qualified column with wrong table name throws") {
+    TempFile tf;
+    DiskManager dm(tf.path());
+    BufferPool bp(4, &dm);
+    Catalog cat = Catalog::create(&bp);
+    cat.createTable("users", usersSchema());
+    Analyzer az(cat);
+
+    CHECK_THROWS_AS(parseAnalyze(az, "SELECT posts.id FROM users"),
+                    std::runtime_error);
+}
+
+TEST_CASE("type-mismatched WHERE throws") {
+    TempFile tf;
+    DiskManager dm(tf.path());
+    BufferPool bp(4, &dm);
+    Catalog cat = Catalog::create(&bp);
+    cat.createTable("users", usersSchema());
+    Analyzer az(cat);
+
+    // Numeric column compared against a string literal.
+    CHECK_THROWS_AS(parseAnalyze(az, "SELECT * FROM users WHERE age = 'foo'"),
+                    std::runtime_error);
+    // Text column compared against a numeric literal.
+    CHECK_THROWS_AS(parseAnalyze(az, "SELECT * FROM users WHERE name = 5"),
+                    std::runtime_error);
+}
+
+TEST_CASE("ordering operator on Text column throws") {
+    TempFile tf;
+    DiskManager dm(tf.path());
+    BufferPool bp(4, &dm);
+    Catalog cat = Catalog::create(&bp);
+    cat.createTable("users", usersSchema());
+    Analyzer az(cat);
+
+    // Even though both sides become Text, < is rejected on Text for now.
+    CHECK_THROWS_AS(parseAnalyze(az, "SELECT * FROM users WHERE name < 'm'"),
+                    std::runtime_error);
+}
+
+TEST_CASE("integer literal that overflows Int32 column throws") {
+    TempFile tf;
+    DiskManager dm(tf.path());
+    BufferPool bp(4, &dm);
+    Catalog cat = Catalog::create(&bp);
+    cat.createTable("users", usersSchema());
+    Analyzer az(cat);
+
+    // 9999999999 > INT32_MAX (~2.1e9) so should be rejected against age (Int32),
+    // but a similar literal against salary (Int64) succeeds.
+    CHECK_THROWS_AS(parseAnalyze(az, "SELECT * FROM users WHERE age > 9999999999"),
+                    std::runtime_error);
+    auto b = parseAnalyze(az, "SELECT * FROM users WHERE salary > 9999999999");
+    REQUIRE(b.where.has_value());
+}
+
+TEST_CASE("Single JOIN resolves both ON columns against the full scope") {
+    TempFile tf;
+    DiskManager dm(tf.path());
+    BufferPool bp(4, &dm);
+    Catalog cat = Catalog::create(&bp);
+    cat.createTable("users", usersSchema());
+    cat.createTable("posts", postsSchema());
+    Analyzer az(cat);
+
+    auto b = parseAnalyze(az,
+        "SELECT users.name, posts.title FROM users "
+        "JOIN posts ON users.id = posts.user_id");
+
+    REQUIRE(b.from_tables.size() == 2);
+    CHECK(b.from_tables[0]->name == "users");
+    CHECK(b.from_tables[1]->name == "posts");
+
+    REQUIRE(b.joins.size() == 1);
+    const auto& jn = b.joins[0];
+    CHECK(jn.left.table_index == 0);    // users
+    CHECK(jn.left.column_index == 0);   // id
+    CHECK(jn.left.result_type == Type::Int32);
+    CHECK(jn.right.table_index == 1);   // posts
+    CHECK(jn.right.column_index == 2);  // user_id
+    CHECK(jn.right.result_type == Type::Int32);
+
+    REQUIRE(b.select_list.size() == 2);
+    CHECK(asColumn(b.select_list[0]).table_index == 0);  // users.name
+    CHECK(asColumn(b.select_list[0]).column_index == 1);
+    CHECK(asColumn(b.select_list[1]).table_index == 1);  // posts.title
+    CHECK(asColumn(b.select_list[1]).column_index == 1);
+}
+
+TEST_CASE("SELECT * across a JOIN expands every column of every FROM table") {
+    TempFile tf;
+    DiskManager dm(tf.path());
+    BufferPool bp(4, &dm);
+    Catalog cat = Catalog::create(&bp);
+    cat.createTable("users", usersSchema());
+    cat.createTable("posts", postsSchema());
+    Analyzer az(cat);
+
+    auto b = parseAnalyze(az,
+        "SELECT * FROM users JOIN posts ON users.id = posts.user_id");
+
+    // 5 (users) + 3 (posts) = 8 expanded columns, in table order.
+    REQUIRE(b.select_list.size() == 8);
+    CHECK(asColumn(b.select_list[0]).table_index == 0);
+    CHECK(asColumn(b.select_list[4]).table_index == 0);
+    CHECK(asColumn(b.select_list[5]).table_index == 1);
+    CHECK(asColumn(b.select_list[7]).table_index == 1);
+}
+
+TEST_CASE("WHERE after a JOIN can reference either table") {
+    TempFile tf;
+    DiskManager dm(tf.path());
+    BufferPool bp(4, &dm);
+    Catalog cat = Catalog::create(&bp);
+    cat.createTable("users", usersSchema());
+    cat.createTable("posts", postsSchema());
+    Analyzer az(cat);
+
+    auto b = parseAnalyze(az,
+        "SELECT users.name FROM users JOIN posts ON users.id = posts.user_id "
+        "WHERE posts.title = 'hi'");
+
+    REQUIRE(b.where.has_value());
+    const auto& bin = asBinaryOp(*b.where);
+    const auto& col = std::get<BoundColumnRef>(bin.lhs);
+    CHECK(col.table_index == 1);   // posts
+    CHECK(col.column_index == 1);  // title
+}
+
+TEST_CASE("JOIN on an unknown table throws") {
+    TempFile tf;
+    DiskManager dm(tf.path());
+    BufferPool bp(4, &dm);
+    Catalog cat = Catalog::create(&bp);
+    cat.createTable("users", usersSchema());
+    Analyzer az(cat);
+
+    CHECK_THROWS_AS(
+        parseAnalyze(az,
+            "SELECT * FROM users JOIN posts ON users.id = posts.user_id"),
+        std::runtime_error);
+}
+
+TEST_CASE("JOIN ON with an unknown column throws") {
+    TempFile tf;
+    DiskManager dm(tf.path());
+    BufferPool bp(4, &dm);
+    Catalog cat = Catalog::create(&bp);
+    cat.createTable("users", usersSchema());
+    cat.createTable("posts", postsSchema());
+    Analyzer az(cat);
+
+    CHECK_THROWS_AS(
+        parseAnalyze(az,
+            "SELECT * FROM users JOIN posts ON users.id = posts.bogus"),
+        std::runtime_error);
+}
+
+TEST_CASE("JOIN ON with mismatched column types throws") {
+    TempFile tf;
+    DiskManager dm(tf.path());
+    BufferPool bp(4, &dm);
+    Catalog cat = Catalog::create(&bp);
+    cat.createTable("users", usersSchema());
+    cat.createTable("posts", postsSchema());
+    Analyzer az(cat);
+
+    // users.name is Text, posts.user_id is Int32.
+    CHECK_THROWS_AS(
+        parseAnalyze(az,
+            "SELECT * FROM users JOIN posts ON users.name = posts.user_id"),
+        std::runtime_error);
+}
+
+TEST_CASE("Repeating the same table in FROM/JOIN throws (no aliases yet)") {
+    TempFile tf;
+    DiskManager dm(tf.path());
+    BufferPool bp(4, &dm);
+    Catalog cat = Catalog::create(&bp);
+    cat.createTable("users", usersSchema());
+    Analyzer az(cat);
+
+    CHECK_THROWS_AS(
+        parseAnalyze(az,
+            "SELECT * FROM users JOIN users ON users.id = users.id"),
+        std::runtime_error);
+}
+
+TEST_CASE("Bare column appearing in two joined tables is ambiguous") {
+    TempFile tf;
+    DiskManager dm(tf.path());
+    BufferPool bp(4, &dm);
+    Catalog cat = Catalog::create(&bp);
+    cat.createTable("users", usersSchema());
+    cat.createTable("posts", postsSchema());
+    Analyzer az(cat);
+
+    // Both users and posts have an "id" column; bare "id" must be rejected.
+    CHECK_THROWS_AS(
+        parseAnalyze(az,
+            "SELECT id FROM users JOIN posts ON users.id = posts.user_id"),
+        std::runtime_error);
+}
+
+TEST_CASE("Multiple chained JOINs all resolve") {
+    TempFile tf;
+    DiskManager dm(tf.path());
+    BufferPool bp(4, &dm);
+    Catalog cat = Catalog::create(&bp);
+    cat.createTable("a", Schema{{{"x", Type::Int32, false}}});
+    cat.createTable("b", Schema{{{"x", Type::Int32, false},
+                                 {"y", Type::Int32, false}}});
+    cat.createTable("c", Schema{{{"y", Type::Int32, false},
+                                 {"z", Type::Int32, false}}});
+    Analyzer az(cat);
+
+    auto bs = parseAnalyze(az,
+        "SELECT a.x, c.z FROM a JOIN b ON a.x = b.x JOIN c ON b.y = c.y");
+
+    REQUIRE(bs.from_tables.size() == 3);
+    REQUIRE(bs.joins.size() == 2);
+    CHECK(bs.joins[0].left.table_index == 0);   // a
+    CHECK(bs.joins[0].right.table_index == 1);  // b
+    CHECK(bs.joins[1].left.table_index == 1);   // b
+    CHECK(bs.joins[1].right.table_index == 2);  // c
+}