From 24ced2933834ced2fdc77df4b0febcb112bcc623 Mon Sep 17 00:00:00 2001 From: Alex Rubinsteyn Date: Thu, 18 Jun 2026 13:34:26 -0400 Subject: [PATCH] Fix #17: quote SQL table/column/index identifiers Wrap every table, column, and index name in double quotes when building SQL in database.py, via a new quote_identifier() helper that also escapes embedded double quotes. This means identifiers containing spaces or punctuation, or which collide with SQL keywords (e.g. "order", "group"), no longer have to be sanitized down to alphanumeric + underscore. Adds a regression test that creates and queries a table whose name and columns contain spaces and reserved words. Claude-Session: https://claude.ai/code/session_011bzfZPTzWnhAMVD7msyMg1 --- datacache/database.py | 40 +++++++++++++++++++++++++--------- tests/test_database_objects.py | 22 +++++++++++++++++++ 2 files changed, 52 insertions(+), 10 deletions(-) diff --git a/datacache/database.py b/datacache/database.py index 937c63c..3d59d0f 100644 --- a/datacache/database.py +++ b/datacache/database.py @@ -24,6 +24,19 @@ METADATA_TABLE_NAME = "_datacache_metadata" + +def quote_identifier(identifier): + """ + Wrap a SQL identifier (table, column, or index name) in double quotes so + that names which contain spaces or punctuation, or which collide with SQL + keywords, don't need to be sanitized down to a restricted character set. + Embedded double quotes are escaped by doubling them, per the SQL standard. + + See https://github.com/openvax/datacache/issues/17 + """ + return '"%s"' % str(identifier).replace('"', '""') + + class Database(object): """ Wrapper object for sqlite3 database which provides helpers for @@ -66,7 +79,7 @@ def has_table(self, table_name): def drop_all_tables(self): """Drop all tables in the database""" for table_name in self.table_names(): - self.execute_sql("DROP TABLE %s" % table_name) + self.execute_sql("DROP TABLE %s" % quote_identifier(table_name)) self.connection.commit() def execute_sql(self, sql, commit=False): @@ -91,7 +104,9 @@ def has_version(self): def version(self): """What's the version of this database? Found in metadata attached by datacache when creating this database.""" - query = "SELECT version FROM %s" % METADATA_TABLE_NAME + query = "SELECT %s FROM %s" % ( + quote_identifier("version"), + quote_identifier(METADATA_TABLE_NAME)) cursor = self.connection.execute(query) version = cursor.fetchone() if not version: @@ -110,10 +125,13 @@ def _finalize_database(self, version): """ require_integer(version, "version") create_metadata_sql = \ - "CREATE TABLE %s (version INT)" % METADATA_TABLE_NAME + "CREATE TABLE %s (%s INT)" % ( + quote_identifier(METADATA_TABLE_NAME), + quote_identifier("version")) self.execute_sql(create_metadata_sql) insert_version_sql = \ - "INSERT INTO %s VALUES (%s)" % (METADATA_TABLE_NAME, version) + "INSERT INTO %s VALUES (%s)" % ( + quote_identifier(METADATA_TABLE_NAME), version) self.execute_sql(insert_version_sql) def _create_table(self, table_name, column_types, primary=None, nullable=()): @@ -139,7 +157,7 @@ def _create_table(self, table_name, column_types, primary=None, nullable=()): column_decls = [] for column_name, column_type in column_types: - decl = "%s %s" % (column_name, column_type) + decl = "%s %s" % (quote_identifier(column_name), column_type) if column_name == primary: decl += " UNIQUE PRIMARY KEY" if column_name not in nullable: @@ -147,7 +165,8 @@ def _create_table(self, table_name, column_types, primary=None, nullable=()): column_decls.append(decl) column_decl_str = ", ".join(column_decls) create_table_sql = \ - "CREATE TABLE %s (%s)" % (table_name, column_decl_str) + "CREATE TABLE %s (%s)" % ( + quote_identifier(table_name), column_decl_str) self.execute_sql(create_table_sql) def _fill_table(self, table_name, rows): @@ -166,7 +185,8 @@ def _fill_table(self, table_name, rows): raise ValueError("Rows must all have %d values" % n_columns) blank_slots = ", ".join("?" for _ in range(n_columns)) logger.info("Inserting %d rows into table %s", len(rows), table_name) - sql = "INSERT INTO %s VALUES (%s)" % (table_name, blank_slots) + sql = "INSERT INTO %s VALUES (%s)" % ( + quote_identifier(table_name), blank_slots) self.connection.executemany(sql, rows) def create(self, tables, version): @@ -212,9 +232,9 @@ def _create_index(self, table_name, index_columns): "_".join(index_columns)) self.connection.execute( "CREATE INDEX IF NOT EXISTS %s ON %s (%s)" % ( - index_name, - table_name, - ", ".join(index_columns))) + quote_identifier(index_name), + quote_identifier(table_name), + ", ".join(quote_identifier(c) for c in index_columns))) def _create_indices(self, table_name, indices): """ diff --git a/tests/test_database_objects.py b/tests/test_database_objects.py index 75ac0d9..340a214 100644 --- a/tests/test_database_objects.py +++ b/tests/test_database_objects.py @@ -73,3 +73,25 @@ def query_in_thread(): thread.start() thread.join() eq_(results["value"], 2) + +def test_create_db_with_reserved_and_spaced_identifiers(): + # regression test for https://github.com/openvax/datacache/issues/17: + # table/column names that are SQL keywords or contain spaces must be + # quoted so they don't have to be sanitized to alphanumeric + underscore. + table_name = "weird table" + column_types = [("order", "INT"), ("group by", "STR")] + rows = [(1, "a"), (2, "b")] + table = datacache.database_table.DatabaseTable( + name=table_name, + column_types=column_types, + make_rows=lambda: rows, + indices=[["group by"]], + nullable={"group by"}, + primary_key="order") + with tempfile.NamedTemporaryFile(suffix="test.db") as f: + db = datacache.database.Database(f.name) + db.create(tables=[table], version=VERSION) + assert db.has_table(table_name) + cursor = db.connection.execute( + 'SELECT "order" FROM "weird table" WHERE "group by" = ?', ("b",)) + eq_(cursor.fetchone()[0], 2)