From 5fed31dd96c683c904dcc33925fa512586151fea Mon Sep 17 00:00:00 2001 From: Peter Adams <18162810+Maxteabag@users.noreply.github.com> Date: Wed, 6 May 2026 19:27:58 +0200 Subject: [PATCH] Switch MariaDB driver from mariadb-connector to PyMySQL PyMySQL is a pure-Python MySQL/MariaDB client, so it works on any platform without a system C library. This removes the macOS install friction (Homebrew keg-only mariadb-connector-c, MARIADB_CONFIG env var) and lets the adapter handle legacy charsets like TIS-620 and Latin1, which the C connector cannot read. Add MariaDB TIS-620 and Latin1 docker services and charset tests mirroring the MySQL ones. --- README.md | 2 +- aur/.SRCINFO | 3 +- aur/PKGBUILD | 3 +- infra/docker/docker-compose.test.yml | 42 +++ pyproject.toml | 4 +- .../connections/app/install_strategy.py | 1 - .../connections/providers/mariadb/adapter.py | 283 +++-------------- tests/conftest.py | 1 + tests/fixtures/mariadb.py | 8 +- tests/fixtures/mariadb_charset.py | 249 +++++++++++++++ .../test_stale_connection_reconnect.py | 6 +- tests/test_charset_support.py | 300 +++++++++++++++++- 12 files changed, 652 insertions(+), 250 deletions(-) create mode 100644 tests/fixtures/mariadb_charset.py diff --git a/README.md b/README.md index c62d9e5f..49c0bee9 100644 --- a/README.md +++ b/README.md @@ -278,7 +278,7 @@ Most of the time you can just run `sqlit` and connect. If a Python driver is mis | PostgreSQL / CockroachDB / Supabase | `psycopg2-binary` | `pipx inject sqlit-tui psycopg2-binary` | `python -m pip install psycopg2-binary` | | SQL Server | `mssql-python` | `pipx inject sqlit-tui mssql-python` | `python -m pip install mssql-python` | | MySQL | `PyMySQL` | `pipx inject sqlit-tui PyMySQL` | `python -m pip install PyMySQL` | -| MariaDB | `mariadb` | `pipx inject sqlit-tui mariadb` | `python -m pip install mariadb` | +| MariaDB | `PyMySQL` | `pipx inject sqlit-tui PyMySQL` | `python -m pip install PyMySQL` | | Oracle | `oracledb` | `pipx inject sqlit-tui oracledb` | `python -m pip install oracledb` | | DuckDB | `duckdb` | `pipx inject sqlit-tui duckdb` | `python -m pip install duckdb` | | ClickHouse | `clickhouse-connect` | `pipx inject sqlit-tui clickhouse-connect` | `python -m pip install clickhouse-connect` | diff --git a/aur/.SRCINFO b/aur/.SRCINFO index 23335685..5d988dbd 100644 --- a/aur/.SRCINFO +++ b/aur/.SRCINFO @@ -17,8 +17,7 @@ pkgbase = sqlit depends = python-docker optdepends = python-psycopg2: PostgreSQL, CockroachDB and Supabase support optdepends = python-pyodbc: SQL Server support - optdepends = python-pymysql: MySQL support - optdepends = python-mariadb-connector: MariaDB support + optdepends = python-pymysql: MySQL and MariaDB support optdepends = python-oracledb: Oracle support optdepends = python-duckdb: DuckDB support optdepends = python-clickhouse-connect: ClickHouse support diff --git a/aur/PKGBUILD b/aur/PKGBUILD index fd3e14bd..c2453fab 100644 --- a/aur/PKGBUILD +++ b/aur/PKGBUILD @@ -20,8 +20,7 @@ depends=( optdepends=( 'python-psycopg2: PostgreSQL, CockroachDB and Supabase support' 'python-pyodbc: SQL Server support' - 'python-pymysql: MySQL support' - 'python-mariadb-connector: MariaDB support' + 'python-pymysql: MySQL and MariaDB support' 'python-oracledb: Oracle support' 'python-duckdb: DuckDB support' 'python-clickhouse-connect: ClickHouse support' diff --git a/infra/docker/docker-compose.test.yml b/infra/docker/docker-compose.test.yml index fd648e74..b14db063 100644 --- a/infra/docker/docker-compose.test.yml +++ b/infra/docker/docker-compose.test.yml @@ -289,6 +289,48 @@ services: tmpfs: - /var/lib/mysql + # MariaDB with TIS-620 charset (Thai) - for charset testing + mariadb-tis620: + image: mariadb:11 + container_name: sqlit-test-mariadb-tis620 + command: --character-set-server=tis620 --collation-server=tis620_thai_ci + environment: + MARIADB_ROOT_PASSWORD: "TestPassword123!" + MARIADB_USER: "testuser" + MARIADB_PASSWORD: "TestPassword123!" + MARIADB_DATABASE: "test_sqlit" + ports: + - "${MARIADB_TIS620_PORT:-3310}:3306" + healthcheck: + test: ["CMD", "healthcheck.sh", "--connect", "--innodb_initialized"] + interval: 5s + timeout: 5s + retries: 10 + start_period: 30s + tmpfs: + - /var/lib/mysql + + # MariaDB with Latin1 charset - for charset testing + mariadb-latin1: + image: mariadb:11 + container_name: sqlit-test-mariadb-latin1 + command: --character-set-server=latin1 --collation-server=latin1_swedish_ci + environment: + MARIADB_ROOT_PASSWORD: "TestPassword123!" + MARIADB_USER: "testuser" + MARIADB_PASSWORD: "TestPassword123!" + MARIADB_DATABASE: "test_sqlit" + ports: + - "${MARIADB_LATIN1_PORT:-3311}:3306" + healthcheck: + test: ["CMD", "healthcheck.sh", "--connect", "--innodb_initialized"] + interval: 5s + timeout: 5s + retries: 10 + start_period: 30s + tmpfs: + - /var/lib/mysql + cockroachdb: image: cockroachdb/cockroach:latest container_name: sqlit-test-cockroachdb diff --git a/pyproject.toml b/pyproject.toml index 83e45d01..7940b1ef 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,7 +40,6 @@ all = [ "psycopg2-binary>=2.9.0", "mssql-python>=1.1.0", "PyMySQL>=1.1.0", - "mariadb>=1.1.0", "oracledb>=2.0.0", "ibm_db>=3.2.0", "hdbcli>=2.20.0", @@ -68,7 +67,7 @@ postgres = ["psycopg2-binary>=2.9.0"] cockroachdb = ["psycopg2-binary>=2.9.0"] mssql = ["mssql-python>=1.1.0"] mysql = ["PyMySQL>=1.1.0"] -mariadb = ["mariadb>=1.1.0"] +mariadb = ["PyMySQL>=1.1.0"] oracle = ["oracledb>=2.0.0"] db2 = ["ibm_db>=3.2.0"] hana = ["hdbcli>=2.20.0"] @@ -225,7 +224,6 @@ namespace_packages = true [[tool.mypy.overrides]] module = [ "pymysql", - "mariadb", "oracledb", "duckdb", "clickhouse_connect", diff --git a/sqlit/domains/connections/app/install_strategy.py b/sqlit/domains/connections/app/install_strategy.py index 0eaf9ea8..1e9e9aa8 100644 --- a/sqlit/domains/connections/app/install_strategy.py +++ b/sqlit/domains/connections/app/install_strategy.py @@ -80,7 +80,6 @@ def _get_arch_package_name(package_name: str) -> str | None: "mssql-python": "python-mssql", "PyMySQL": "python-pymysql", "mysql-connector-python": "python-mysql-connector", - "mariadb": "python-mariadb-connector", "oracledb": "python-oracledb", "duckdb": "python-duckdb", "clickhouse-connect": "python-clickhouse-connect", diff --git a/sqlit/domains/connections/providers/mariadb/adapter.py b/sqlit/domains/connections/providers/mariadb/adapter.py index ba9459d2..60821645 100644 --- a/sqlit/domains/connections/providers/mariadb/adapter.py +++ b/sqlit/domains/connections/providers/mariadb/adapter.py @@ -1,17 +1,11 @@ -"""MariaDB adapter using mariadb connector.""" +"""MariaDB adapter using PyMySQL (pure Python).""" from __future__ import annotations import re from typing import TYPE_CHECKING, Any -from sqlit.domains.connections.providers.adapters.base import ( - ColumnInfo, - IndexInfo, - SequenceInfo, - TableInfo, - TriggerInfo, -) +from sqlit.domains.connections.providers.adapters.base import SequenceInfo from sqlit.domains.connections.providers.mysql.base import MySQLBaseAdapter from sqlit.domains.connections.providers.registry import get_default_port from sqlit.domains.connections.providers.tls import ( @@ -28,10 +22,12 @@ class MariaDBAdapter(MySQLBaseAdapter): - """Adapter for MariaDB using mariadb connector. + """Adapter for MariaDB using PyMySQL. - MariaDB uses ? placeholders instead of %s, so we override the - introspection methods that use parameterized queries. + PyMySQL speaks the MySQL/MariaDB wire protocol and is pure Python, so it + works on any platform without a system-level C library. It also handles + legacy charsets like TIS-620 and Latin1 that the MariaDB C connector + cannot read. """ @property @@ -44,11 +40,11 @@ def install_extra(self) -> str: @property def install_package(self) -> str: - return "mariadb" + return "PyMySQL" @property def driver_import_names(self) -> tuple[str, ...]: - return ("mariadb",) + return ("pymysql",) @property def supports_sequences(self) -> bool: @@ -65,8 +61,8 @@ def get_post_connect_warnings(self, config: ConnectionConfig) -> list[str]: def connect(self, config: ConnectionConfig) -> Any: """Connect to MariaDB database.""" - mariadb = self._import_driver_module( - "mariadb", + pymysql = self._import_driver_module( + "pymysql", driver_name=self.name, extra_name=self.install_extra, package_name=self.install_package, @@ -76,36 +72,57 @@ def connect(self, config: ConnectionConfig) -> Any: if endpoint is None: raise ValueError("MariaDB connections require a TCP-style endpoint.") port = int(endpoint.port or get_default_port("mariadb")) - mariadb_any: Any = mariadb + host = endpoint.host + if host and host.lower() == "localhost": + host = "127.0.0.1" connect_args: dict[str, Any] = { - "host": endpoint.host, + "host": host, "port": port, "database": endpoint.database or None, "user": endpoint.username, "password": endpoint.password, "connect_timeout": 10, + "autocommit": True, + "charset": "utf8mb4", } tls_mode = get_tls_mode(config) tls_ca, tls_cert, tls_key, _ = get_tls_files(config) has_tls_files = any([tls_ca, tls_cert, tls_key]) if tls_mode != TLS_MODE_DISABLE and (tls_mode != TLS_MODE_DEFAULT or has_tls_files): + import ssl + + ssl_params: dict[str, Any] = {} if tls_ca: - connect_args["ssl_ca"] = tls_ca + ssl_params["ca"] = tls_ca if tls_cert: - connect_args["ssl_cert"] = tls_cert + ssl_params["cert"] = tls_cert if tls_key: - connect_args["ssl_key"] = tls_key - if tls_mode != TLS_MODE_DEFAULT: - connect_args["ssl_verify_cert"] = tls_mode_verifies_cert(tls_mode) - connect_args["ssl_verify_identity"] = tls_mode_verifies_hostname(tls_mode) + ssl_params["key"] = tls_key + + if tls_mode_verifies_cert(tls_mode): + ssl_params["cert_reqs"] = ssl.CERT_REQUIRED + else: + ssl_params["cert_reqs"] = ssl.CERT_NONE + + ssl_params["check_hostname"] = tls_mode_verifies_hostname(tls_mode) + connect_args["ssl"] = ssl_params connect_args.update(config.extra_options) - conn = mariadb_any.connect(**connect_args) + conn = pymysql.connect(**connect_args) - # Note: The MariaDB Python connector only supports UTF-8 family charsets. - # Legacy charsets like TIS-620 or Latin1 are not supported. For databases - # using legacy charsets, use the MySQL provider with PyMySQL instead. + # Auto-sync charset with server to handle legacy encodings (e.g., TIS-620, Latin1). + try: + cursor = conn.cursor() + cursor.execute("SELECT @@character_set_database") + row = cursor.fetchone() + if row: + server_charset = row[0] + if server_charset and server_charset.lower() != "utf8mb4": + conn.set_charset(server_charset) + cursor.close() + except Exception: + pass self._supports_sequences = self._detect_sequences_support(conn) return conn @@ -123,7 +140,7 @@ def _detect_sequences_support(self, conn: Any) -> bool: return True self._server_version_str = row[0] - match = re.match(r"^(\\d+)\\.(\\d+)(?:\\.(\\d+))?", row[0]) + match = re.match(r"^(\d+)\.(\d+)(?:\.(\d+))?", row[0]) if not match: return True @@ -132,142 +149,15 @@ def _detect_sequences_support(self, conn: Any) -> bool: patch = int(match.group(3) or 0) return (major, minor, patch) >= (10, 3, 0) - # MariaDB connector uses ? placeholders instead of %s, so override methods with params - - def get_tables(self, conn: Any, database: str | None = None) -> list[TableInfo]: - """Get list of tables from MariaDB. Returns (schema, name) with empty schema.""" - cursor = conn.cursor() - if database: - cursor.execute( - "SELECT table_name FROM information_schema.tables " - "WHERE table_schema = ? AND table_type = 'BASE TABLE' " - "ORDER BY table_name", - (database,), - ) - else: - cursor.execute("SHOW TABLES") - return [("", row[0]) for row in cursor.fetchall()] - - def get_views(self, conn: Any, database: str | None = None) -> list[TableInfo]: - """Get list of views from MariaDB. Returns (schema, name) with empty schema.""" - cursor = conn.cursor() - if database: - cursor.execute( - "SELECT table_name FROM information_schema.views " "WHERE table_schema = ? ORDER BY table_name", - (database,), - ) - else: - cursor.execute( - "SELECT table_name FROM information_schema.views " "WHERE table_schema = DATABASE() ORDER BY table_name" - ) - return [("", row[0]) for row in cursor.fetchall()] - - def get_columns( - self, conn: Any, table: str, database: str | None = None, schema: str | None = None - ) -> list[ColumnInfo]: - """Get columns for a table from MariaDB. Schema parameter is ignored.""" - cursor = conn.cursor() - - # Get primary key columns - if database: - cursor.execute( - "SELECT column_name FROM information_schema.key_column_usage " - "WHERE table_schema = ? AND table_name = ? AND constraint_name = 'PRIMARY'", - (database, table), - ) - else: - cursor.execute( - "SELECT column_name FROM information_schema.key_column_usage " - "WHERE table_schema = DATABASE() AND table_name = ? AND constraint_name = 'PRIMARY'", - (table,), - ) - pk_columns = {row[0] for row in cursor.fetchall()} - - # Get all columns - if database: - cursor.execute( - "SELECT column_name, data_type FROM information_schema.columns " - "WHERE table_schema = ? AND table_name = ? " - "ORDER BY ordinal_position", - (database, table), - ) - else: - cursor.execute( - "SELECT column_name, data_type FROM information_schema.columns " - "WHERE table_schema = DATABASE() AND table_name = ? " - "ORDER BY ordinal_position", - (table,), - ) - return [ColumnInfo(name=row[0], data_type=row[1], is_primary_key=row[0] in pk_columns) for row in cursor.fetchall()] - - def get_procedures(self, conn: Any, database: str | None = None) -> list[str]: - """Get stored procedures from MariaDB.""" - cursor = conn.cursor() - if database: - cursor.execute( - "SELECT routine_name FROM information_schema.routines " - "WHERE routine_schema = ? AND routine_type = 'PROCEDURE' " - "ORDER BY routine_name", - (database,), - ) - else: - cursor.execute( - "SELECT routine_name FROM information_schema.routines " - "WHERE routine_schema = DATABASE() AND routine_type = 'PROCEDURE' " - "ORDER BY routine_name" - ) - return [row[0] for row in cursor.fetchall()] - - def get_indexes(self, conn: Any, database: str | None = None) -> list[IndexInfo]: - """Get indexes from MariaDB (uses ? placeholders).""" - cursor = conn.cursor() - if database: - cursor.execute( - "SELECT DISTINCT index_name, table_name, non_unique " - "FROM information_schema.statistics " - "WHERE table_schema = ? AND index_name != 'PRIMARY' " - "ORDER BY table_name, index_name", - (database,), - ) - else: - cursor.execute( - "SELECT DISTINCT index_name, table_name, non_unique " - "FROM information_schema.statistics " - "WHERE table_schema = DATABASE() AND index_name != 'PRIMARY' " - "ORDER BY table_name, index_name" - ) - return [ - IndexInfo(name=row[0], table_name=row[1], is_unique=row[2] == 0) - for row in cursor.fetchall() - ] - - def get_triggers(self, conn: Any, database: str | None = None) -> list[TriggerInfo]: - """Get triggers from MariaDB (uses ? placeholders).""" - cursor = conn.cursor() - if database: - cursor.execute( - "SELECT trigger_name, event_object_table " - "FROM information_schema.triggers " - "WHERE trigger_schema = ? " - "ORDER BY event_object_table, trigger_name", - (database,), - ) - else: - cursor.execute( - "SELECT trigger_name, event_object_table " - "FROM information_schema.triggers " - "WHERE trigger_schema = DATABASE() " - "ORDER BY event_object_table, trigger_name" - ) - return [TriggerInfo(name=row[0], table_name=row[1]) for row in cursor.fetchall()] - def get_sequences(self, conn: Any, database: str | None = None) -> list[SequenceInfo]: """Get sequences from MariaDB 10.3+.""" + if not self.supports_sequences: + return [] cursor = conn.cursor() if database: cursor.execute( "SELECT sequence_name FROM information_schema.sequences " - "WHERE sequence_schema = ? " + "WHERE sequence_schema = %s " "ORDER BY sequence_name", (database,), ) @@ -279,94 +169,23 @@ def get_sequences(self, conn: Any, database: str | None = None) -> list[Sequence ) return [SequenceInfo(name=row[0]) for row in cursor.fetchall()] - def get_index_definition( - self, conn: Any, index_name: str, table_name: str, database: str | None = None - ) -> dict[str, Any]: - """Get detailed information about a MariaDB index (uses ? placeholders).""" - cursor = conn.cursor() - if database: - cursor.execute( - "SELECT column_name, non_unique, index_type " - "FROM information_schema.statistics " - "WHERE table_schema = ? AND table_name = ? AND index_name = ? " - "ORDER BY seq_in_index", - (database, table_name, index_name), - ) - else: - cursor.execute( - "SELECT column_name, non_unique, index_type " - "FROM information_schema.statistics " - "WHERE table_schema = DATABASE() AND table_name = ? AND index_name = ? " - "ORDER BY seq_in_index", - (table_name, index_name), - ) - rows = cursor.fetchall() - columns = [row[0] for row in rows] - is_unique = rows[0][1] == 0 if rows else False - index_type = rows[0][2] if rows else "BTREE" - - return { - "name": index_name, - "table_name": table_name, - "columns": columns, - "is_unique": is_unique, - "type": index_type, - "definition": f"CREATE {'UNIQUE ' if is_unique else ''}INDEX {index_name} ON {table_name} ({', '.join(columns)})", - } - - def get_trigger_definition( - self, conn: Any, trigger_name: str, table_name: str, database: str | None = None - ) -> dict[str, Any]: - """Get detailed information about a MariaDB trigger (uses ? placeholders).""" - cursor = conn.cursor() - if database: - cursor.execute( - "SELECT action_timing, event_manipulation, action_statement " - "FROM information_schema.triggers " - "WHERE trigger_schema = ? AND trigger_name = ?", - (database, trigger_name), - ) - else: - cursor.execute( - "SELECT action_timing, event_manipulation, action_statement " - "FROM information_schema.triggers " - "WHERE trigger_schema = DATABASE() AND trigger_name = ?", - (trigger_name,), - ) - row = cursor.fetchone() - if row: - return { - "name": trigger_name, - "table_name": table_name, - "timing": row[0], - "event": row[1], - "definition": row[2], - } - return { - "name": trigger_name, - "table_name": table_name, - "timing": None, - "event": None, - "definition": None, - } - def get_sequence_definition( self, conn: Any, sequence_name: str, database: str | None = None ) -> dict[str, Any]: - """Get detailed information about a MariaDB sequence (uses ? placeholders).""" + """Get detailed information about a MariaDB sequence.""" cursor = conn.cursor() if database: cursor.execute( "SELECT start_value, increment, minimum_value, maximum_value, cycle_option " "FROM information_schema.sequences " - "WHERE sequence_schema = ? AND sequence_name = ?", + "WHERE sequence_schema = %s AND sequence_name = %s", (database, sequence_name), ) else: cursor.execute( "SELECT start_value, increment, minimum_value, maximum_value, cycle_option " "FROM information_schema.sequences " - "WHERE sequence_schema = DATABASE() AND sequence_name = ?", + "WHERE sequence_schema = DATABASE() AND sequence_name = %s", (sequence_name,), ) row = cursor.fetchone() diff --git a/tests/conftest.py b/tests/conftest.py index 583e2fb5..52016bc5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -13,6 +13,7 @@ from tests.fixtures.flight import * from tests.fixtures.impala import * from tests.fixtures.mariadb import * +from tests.fixtures.mariadb_charset import * from tests.fixtures.mssql import * from tests.fixtures.mysql import * from tests.fixtures.mysql_charset import * diff --git a/tests/fixtures/mariadb.py b/tests/fixtures/mariadb.py index 7046abbd..c1dec699 100644 --- a/tests/fixtures/mariadb.py +++ b/tests/fixtures/mariadb.py @@ -39,12 +39,12 @@ def mariadb_db(mariadb_server_ready: bool) -> str: pytest.skip("MariaDB is not available") try: - import mariadb + import pymysql except ImportError: - pytest.skip("mariadb is not installed") + pytest.skip("PyMySQL is not installed") try: - conn = mariadb.connect( + conn = pymysql.connect( host=MARIADB_HOST, port=MARIADB_PORT, database=MARIADB_DATABASE, @@ -120,7 +120,7 @@ def mariadb_db(mariadb_server_ready: bool) -> str: yield MARIADB_DATABASE try: - conn = mariadb.connect( + conn = pymysql.connect( host=MARIADB_HOST, port=MARIADB_PORT, database=MARIADB_DATABASE, diff --git a/tests/fixtures/mariadb_charset.py b/tests/fixtures/mariadb_charset.py new file mode 100644 index 00000000..410ef8e2 --- /dev/null +++ b/tests/fixtures/mariadb_charset.py @@ -0,0 +1,249 @@ +"""MariaDB charset fixtures for testing legacy encodings (TIS-620, Latin1, etc.).""" + +from __future__ import annotations + +import os +import time + +import pytest + +from tests.fixtures.utils import cleanup_connection, is_port_open, run_cli + +# TIS-620 (Thai) MariaDB +MARIADB_TIS620_HOST = os.environ.get("MARIADB_TIS620_HOST", "127.0.0.1") +MARIADB_TIS620_PORT = int(os.environ.get("MARIADB_TIS620_PORT", "3310")) + +# Latin1 MariaDB +MARIADB_LATIN1_HOST = os.environ.get("MARIADB_LATIN1_HOST", "127.0.0.1") +MARIADB_LATIN1_PORT = int(os.environ.get("MARIADB_LATIN1_PORT", "3311")) + +# Common credentials (same as other MariaDB containers) +MARIADB_CHARSET_USER = os.environ.get("MARIADB_CHARSET_USER", "root") +MARIADB_CHARSET_PASSWORD = os.environ.get("MARIADB_CHARSET_PASSWORD", "TestPassword123!") +MARIADB_CHARSET_DATABASE = os.environ.get("MARIADB_CHARSET_DATABASE", "test_sqlit") + + +def mariadb_tis620_available() -> bool: + """Check if MariaDB TIS-620 is available.""" + return is_port_open(MARIADB_TIS620_HOST, MARIADB_TIS620_PORT) + + +def mariadb_latin1_available() -> bool: + """Check if MariaDB Latin1 is available.""" + return is_port_open(MARIADB_LATIN1_HOST, MARIADB_LATIN1_PORT) + + +@pytest.fixture(scope="session") +def mariadb_tis620_server_ready() -> bool: + """Check if MariaDB TIS-620 is ready and return True/False.""" + if not mariadb_tis620_available(): + return False + time.sleep(1) + return True + + +@pytest.fixture(scope="session") +def mariadb_latin1_server_ready() -> bool: + """Check if MariaDB Latin1 is ready and return True/False.""" + if not mariadb_latin1_available(): + return False + time.sleep(1) + return True + + +@pytest.fixture(scope="function") +def mariadb_tis620_db(mariadb_tis620_server_ready: bool) -> str: + """Set up MariaDB TIS-620 test database with Thai data.""" + if not mariadb_tis620_server_ready: + pytest.skip("MariaDB TIS-620 is not available") + + try: + import pymysql + except ImportError: + pytest.skip("PyMySQL is not installed") + + try: + conn = pymysql.connect( + host=MARIADB_TIS620_HOST, + port=MARIADB_TIS620_PORT, + database=MARIADB_CHARSET_DATABASE, + user=MARIADB_CHARSET_USER, + password=MARIADB_CHARSET_PASSWORD, + connect_timeout=10, + charset="tis620", + ) + cursor = conn.cursor() + + cursor.execute("DROP TABLE IF EXISTS charset_test") + cursor.execute( + "CREATE TABLE charset_test (id INT PRIMARY KEY, content TEXT) " + "CHARACTER SET tis620 COLLATE tis620_thai_ci" + ) + + cursor.execute("INSERT INTO charset_test VALUES (1, 'สวัสดีครับ')") + cursor.execute("INSERT INTO charset_test VALUES (2, 'ภาษาไทย')") + cursor.execute("INSERT INTO charset_test VALUES (3, 'กรุงเทพมหานคร')") + + conn.commit() + conn.close() + + except Exception as e: + pytest.skip(f"Failed to setup MariaDB TIS-620 database: {e}") + + yield MARIADB_CHARSET_DATABASE + + try: + conn = pymysql.connect( + host=MARIADB_TIS620_HOST, + port=MARIADB_TIS620_PORT, + database=MARIADB_CHARSET_DATABASE, + user=MARIADB_CHARSET_USER, + password=MARIADB_CHARSET_PASSWORD, + connect_timeout=10, + charset="tis620", + ) + cursor = conn.cursor() + cursor.execute("DROP TABLE IF EXISTS charset_test") + conn.commit() + conn.close() + except Exception: + pass + + +@pytest.fixture(scope="function") +def mariadb_latin1_db(mariadb_latin1_server_ready: bool) -> str: + """Set up MariaDB Latin1 test database with accented Latin characters.""" + if not mariadb_latin1_server_ready: + pytest.skip("MariaDB Latin1 is not available") + + try: + import pymysql + except ImportError: + pytest.skip("PyMySQL is not installed") + + try: + conn = pymysql.connect( + host=MARIADB_LATIN1_HOST, + port=MARIADB_LATIN1_PORT, + database=MARIADB_CHARSET_DATABASE, + user=MARIADB_CHARSET_USER, + password=MARIADB_CHARSET_PASSWORD, + connect_timeout=10, + charset="latin1", + ) + cursor = conn.cursor() + + cursor.execute("DROP TABLE IF EXISTS charset_test") + cursor.execute( + "CREATE TABLE charset_test (id INT PRIMARY KEY, content TEXT) " + "CHARACTER SET latin1 COLLATE latin1_swedish_ci" + ) + + cursor.execute("INSERT INTO charset_test VALUES (1, 'café')") + cursor.execute("INSERT INTO charset_test VALUES (2, 'naïve')") + cursor.execute("INSERT INTO charset_test VALUES (3, 'Müller')") + cursor.execute("INSERT INTO charset_test VALUES (4, 'señor')") + + conn.commit() + conn.close() + + except Exception as e: + pytest.skip(f"Failed to setup MariaDB Latin1 database: {e}") + + yield MARIADB_CHARSET_DATABASE + + try: + conn = pymysql.connect( + host=MARIADB_LATIN1_HOST, + port=MARIADB_LATIN1_PORT, + database=MARIADB_CHARSET_DATABASE, + user=MARIADB_CHARSET_USER, + password=MARIADB_CHARSET_PASSWORD, + connect_timeout=10, + charset="latin1", + ) + cursor = conn.cursor() + cursor.execute("DROP TABLE IF EXISTS charset_test") + conn.commit() + conn.close() + except Exception: + pass + + +@pytest.fixture(scope="function") +def mariadb_tis620_connection(mariadb_tis620_db: str) -> str: + """Create a sqlit CLI connection for MariaDB TIS-620.""" + connection_name = f"test_mariadb_tis620_{os.getpid()}" + + cleanup_connection(connection_name) + + run_cli( + "connections", + "add", + "mariadb", + "--name", + connection_name, + "--server", + MARIADB_TIS620_HOST, + "--port", + str(MARIADB_TIS620_PORT), + "--database", + mariadb_tis620_db, + "--username", + MARIADB_CHARSET_USER, + "--password", + MARIADB_CHARSET_PASSWORD, + ) + + yield connection_name + + cleanup_connection(connection_name) + + +@pytest.fixture(scope="function") +def mariadb_latin1_connection(mariadb_latin1_db: str) -> str: + """Create a sqlit CLI connection for MariaDB Latin1.""" + connection_name = f"test_mariadb_latin1_{os.getpid()}" + + cleanup_connection(connection_name) + + run_cli( + "connections", + "add", + "mariadb", + "--name", + connection_name, + "--server", + MARIADB_LATIN1_HOST, + "--port", + str(MARIADB_LATIN1_PORT), + "--database", + mariadb_latin1_db, + "--username", + MARIADB_CHARSET_USER, + "--password", + MARIADB_CHARSET_PASSWORD, + ) + + yield connection_name + + cleanup_connection(connection_name) + + +__all__ = [ + "MARIADB_CHARSET_DATABASE", + "MARIADB_CHARSET_PASSWORD", + "MARIADB_CHARSET_USER", + "MARIADB_LATIN1_HOST", + "MARIADB_LATIN1_PORT", + "MARIADB_TIS620_HOST", + "MARIADB_TIS620_PORT", + "mariadb_latin1_available", + "mariadb_latin1_connection", + "mariadb_latin1_db", + "mariadb_latin1_server_ready", + "mariadb_tis620_available", + "mariadb_tis620_connection", + "mariadb_tis620_db", + "mariadb_tis620_server_ready", +] diff --git a/tests/integration/test_stale_connection_reconnect.py b/tests/integration/test_stale_connection_reconnect.py index d6d189c5..46713e51 100644 --- a/tests/integration/test_stale_connection_reconnect.py +++ b/tests/integration/test_stale_connection_reconnect.py @@ -444,12 +444,12 @@ def work() -> None: if spec.key == "mariadb": try: - import mariadb + import pymysql except ImportError: - pytest.skip("mariadb is not installed") + pytest.skip("PyMySQL is not installed") def work() -> None: - conn = mariadb.connect( + conn = pymysql.connect( host=spec.host, port=int(spec.port), database=spec.database, diff --git a/tests/test_charset_support.py b/tests/test_charset_support.py index ebbcdfe2..3fd3fe01 100644 --- a/tests/test_charset_support.py +++ b/tests/test_charset_support.py @@ -8,8 +8,6 @@ import json -import pytest - from tests.fixtures.utils import run_cli @@ -315,3 +313,301 @@ def test_unicode_insert_and_read(self, mysql_connection: str) -> None: f"DROP TABLE IF EXISTS {table_name}", check=False, ) + + +class TestMariaDBTIS620Charset: + """Test Thai TIS-620 charset support in MariaDB. + + Now possible because MariaDB uses PyMySQL (pure Python wire protocol), + which handles legacy charsets unlike the old C `mariadb` connector. + """ + + def test_thai_characters_read_correctly(self, mariadb_tis620_connection: str) -> None: + """Test reading Thai data from TIS-620 MariaDB database.""" + result = run_cli( + "query", + "-c", + mariadb_tis620_connection, + "-q", + "SELECT content FROM charset_test WHERE id = 1", + "--format", + "json", + check=False, + ) + + assert result.returncode == 0, f"Query failed: {result.stderr}" + data = json.loads(result.stdout) + + expected = "สวัสดีครับ" + actual = data[0]["content"] + + assert actual == expected, ( + f"Thai charset mismatch!\n" + f"Expected: {expected}\n" + f"Got: {actual}\n" + f"This indicates charset auto-detection is not working." + ) + + def test_thai_language_name(self, mariadb_tis620_connection: str) -> None: + """Test reading 'Thai language' in Thai script.""" + result = run_cli( + "query", + "-c", + mariadb_tis620_connection, + "-q", + "SELECT content FROM charset_test WHERE id = 2", + "--format", + "json", + check=False, + ) + + assert result.returncode == 0, f"Query failed: {result.stderr}" + data = json.loads(result.stdout) + + expected = "ภาษาไทย" + actual = data[0]["content"] + + assert actual == expected, f"Expected: {expected}, Got: {actual}" + + def test_bangkok_city_name(self, mariadb_tis620_connection: str) -> None: + """Test reading Bangkok city name in Thai script.""" + result = run_cli( + "query", + "-c", + mariadb_tis620_connection, + "-q", + "SELECT content FROM charset_test WHERE id = 3", + "--format", + "json", + check=False, + ) + + assert result.returncode == 0, f"Query failed: {result.stderr}" + data = json.loads(result.stdout) + + expected = "กรุงเทพมหานคร" + actual = data[0]["content"] + + assert actual == expected, f"Expected: {expected}, Got: {actual}" + + def test_all_thai_rows(self, mariadb_tis620_connection: str) -> None: + """Test reading all Thai rows in a single query.""" + result = run_cli( + "query", + "-c", + mariadb_tis620_connection, + "-q", + "SELECT id, content FROM charset_test ORDER BY id", + "--format", + "json", + check=False, + ) + + assert result.returncode == 0, f"Query failed: {result.stderr}" + data = json.loads(result.stdout) + + expected_values = { + 1: "สวัสดีครับ", + 2: "ภาษาไทย", + 3: "กรุงเทพมหานคร", + } + + assert len(data) == 3, f"Expected 3 rows, got {len(data)}" + + for row in data: + row_id = row["id"] + expected = expected_values[row_id] + actual = row["content"] + assert actual == expected, ( + f"Row {row_id}: Expected '{expected}', Got '{actual}'" + ) + + +class TestMariaDBLatin1Charset: + """Test Latin1 charset support in MariaDB.""" + + def test_french_accents(self, mariadb_latin1_connection: str) -> None: + """Test reading French accented characters from Latin1 database.""" + result = run_cli( + "query", + "-c", + mariadb_latin1_connection, + "-q", + "SELECT content FROM charset_test WHERE id = 1", + "--format", + "json", + check=False, + ) + + assert result.returncode == 0, f"Query failed: {result.stderr}" + data = json.loads(result.stdout) + + expected = "café" + actual = data[0]["content"] + + assert actual == expected, f"Expected: {expected}, Got: {actual}" + + def test_french_diaeresis(self, mariadb_latin1_connection: str) -> None: + """Test reading French word with diaeresis from Latin1 database.""" + result = run_cli( + "query", + "-c", + mariadb_latin1_connection, + "-q", + "SELECT content FROM charset_test WHERE id = 2", + "--format", + "json", + check=False, + ) + + assert result.returncode == 0, f"Query failed: {result.stderr}" + data = json.loads(result.stdout) + + expected = "naïve" + actual = data[0]["content"] + + assert actual == expected, f"Expected: {expected}, Got: {actual}" + + def test_german_umlaut(self, mariadb_latin1_connection: str) -> None: + """Test reading German name with umlaut from Latin1 database.""" + result = run_cli( + "query", + "-c", + mariadb_latin1_connection, + "-q", + "SELECT content FROM charset_test WHERE id = 3", + "--format", + "json", + check=False, + ) + + assert result.returncode == 0, f"Query failed: {result.stderr}" + data = json.loads(result.stdout) + + expected = "Müller" + actual = data[0]["content"] + + assert actual == expected, f"Expected: {expected}, Got: {actual}" + + def test_spanish_tilde(self, mariadb_latin1_connection: str) -> None: + """Test reading Spanish word with tilde from Latin1 database.""" + result = run_cli( + "query", + "-c", + mariadb_latin1_connection, + "-q", + "SELECT content FROM charset_test WHERE id = 4", + "--format", + "json", + check=False, + ) + + assert result.returncode == 0, f"Query failed: {result.stderr}" + data = json.loads(result.stdout) + + expected = "señor" + actual = data[0]["content"] + + assert actual == expected, f"Expected: {expected}, Got: {actual}" + + def test_all_latin1_rows(self, mariadb_latin1_connection: str) -> None: + """Test reading all Latin1 rows in a single query.""" + result = run_cli( + "query", + "-c", + mariadb_latin1_connection, + "-q", + "SELECT id, content FROM charset_test ORDER BY id", + "--format", + "json", + check=False, + ) + + assert result.returncode == 0, f"Query failed: {result.stderr}" + data = json.loads(result.stdout) + + expected_values = { + 1: "café", + 2: "naïve", + 3: "Müller", + 4: "señor", + } + + assert len(data) == 4, f"Expected 4 rows, got {len(data)}" + + for row in data: + row_id = row["id"] + expected = expected_values[row_id] + actual = row["content"] + assert actual == expected, ( + f"Row {row_id}: Expected '{expected}', Got '{actual}'" + ) + + +class TestMariaDBUTF8Baseline: + """Baseline tests: UTF-8 MariaDB should always work correctly.""" + + def test_unicode_insert_and_read(self, mariadb_connection: str) -> None: + """Test that UTF-8 MariaDB handles Unicode correctly (baseline).""" + import uuid + + table_name = f"unicode_test_{uuid.uuid4().hex[:8]}" + + try: + create_result = run_cli( + "query", + "-c", + mariadb_connection, + "-q", + f"CREATE TABLE {table_name} (id INT, content TEXT)", + check=False, + ) + assert create_result.returncode == 0, f"Create failed: {create_result.stderr}" + + insert_result = run_cli( + "query", + "-c", + mariadb_connection, + "-q", + f"INSERT INTO {table_name} VALUES (1, 'Hello'), (2, 'café'), (3, 'สวัสดี'), (4, '你好')", + check=False, + ) + assert insert_result.returncode == 0, f"Insert failed: {insert_result.stderr}" + + result = run_cli( + "query", + "-c", + mariadb_connection, + "-q", + f"SELECT id, content FROM {table_name} ORDER BY id", + "--format", + "json", + check=False, + ) + + assert result.returncode == 0, f"Query failed: {result.stderr}" + data = json.loads(result.stdout) + + expected_values = { + 1: "Hello", + 2: "café", + 3: "สวัสดี", + 4: "你好", + } + + for row in data: + row_id = row["id"] + expected = expected_values[row_id] + actual = row["content"] + assert actual == expected, ( + f"UTF-8 baseline failed! Row {row_id}: Expected '{expected}', Got '{actual}'" + ) + finally: + run_cli( + "query", + "-c", + mariadb_connection, + "-q", + f"DROP TABLE IF EXISTS {table_name}", + check=False, + )