diff --git a/datacache/database.py b/datacache/database.py index 937c63c..3d59d0f 100644 --- a/datacache/database.py +++ b/datacache/database.py @@ -24,6 +24,19 @@ METADATA_TABLE_NAME = "_datacache_metadata" + +def quote_identifier(identifier): + """ + Wrap a SQL identifier (table, column, or index name) in double quotes so + that names which contain spaces or punctuation, or which collide with SQL + keywords, don't need to be sanitized down to a restricted character set. + Embedded double quotes are escaped by doubling them, per the SQL standard. + + See https://github.com/openvax/datacache/issues/17 + """ + return '"%s"' % str(identifier).replace('"', '""') + + class Database(object): """ Wrapper object for sqlite3 database which provides helpers for @@ -66,7 +79,7 @@ def has_table(self, table_name): def drop_all_tables(self): """Drop all tables in the database""" for table_name in self.table_names(): - self.execute_sql("DROP TABLE %s" % table_name) + self.execute_sql("DROP TABLE %s" % quote_identifier(table_name)) self.connection.commit() def execute_sql(self, sql, commit=False): @@ -91,7 +104,9 @@ def has_version(self): def version(self): """What's the version of this database? Found in metadata attached by datacache when creating this database.""" - query = "SELECT version FROM %s" % METADATA_TABLE_NAME + query = "SELECT %s FROM %s" % ( + quote_identifier("version"), + quote_identifier(METADATA_TABLE_NAME)) cursor = self.connection.execute(query) version = cursor.fetchone() if not version: @@ -110,10 +125,13 @@ def _finalize_database(self, version): """ require_integer(version, "version") create_metadata_sql = \ - "CREATE TABLE %s (version INT)" % METADATA_TABLE_NAME + "CREATE TABLE %s (%s INT)" % ( + quote_identifier(METADATA_TABLE_NAME), + quote_identifier("version")) self.execute_sql(create_metadata_sql) insert_version_sql = \ - "INSERT INTO %s VALUES (%s)" % (METADATA_TABLE_NAME, version) + "INSERT INTO %s VALUES (%s)" % ( + quote_identifier(METADATA_TABLE_NAME), version) self.execute_sql(insert_version_sql) def _create_table(self, table_name, column_types, primary=None, nullable=()): @@ -139,7 +157,7 @@ def _create_table(self, table_name, column_types, primary=None, nullable=()): column_decls = [] for column_name, column_type in column_types: - decl = "%s %s" % (column_name, column_type) + decl = "%s %s" % (quote_identifier(column_name), column_type) if column_name == primary: decl += " UNIQUE PRIMARY KEY" if column_name not in nullable: @@ -147,7 +165,8 @@ def _create_table(self, table_name, column_types, primary=None, nullable=()): column_decls.append(decl) column_decl_str = ", ".join(column_decls) create_table_sql = \ - "CREATE TABLE %s (%s)" % (table_name, column_decl_str) + "CREATE TABLE %s (%s)" % ( + quote_identifier(table_name), column_decl_str) self.execute_sql(create_table_sql) def _fill_table(self, table_name, rows): @@ -166,7 +185,8 @@ def _fill_table(self, table_name, rows): raise ValueError("Rows must all have %d values" % n_columns) blank_slots = ", ".join("?" for _ in range(n_columns)) logger.info("Inserting %d rows into table %s", len(rows), table_name) - sql = "INSERT INTO %s VALUES (%s)" % (table_name, blank_slots) + sql = "INSERT INTO %s VALUES (%s)" % ( + quote_identifier(table_name), blank_slots) self.connection.executemany(sql, rows) def create(self, tables, version): @@ -212,9 +232,9 @@ def _create_index(self, table_name, index_columns): "_".join(index_columns)) self.connection.execute( "CREATE INDEX IF NOT EXISTS %s ON %s (%s)" % ( - index_name, - table_name, - ", ".join(index_columns))) + quote_identifier(index_name), + quote_identifier(table_name), + ", ".join(quote_identifier(c) for c in index_columns))) def _create_indices(self, table_name, indices): """ diff --git a/tests/test_database_objects.py b/tests/test_database_objects.py index 75ac0d9..340a214 100644 --- a/tests/test_database_objects.py +++ b/tests/test_database_objects.py @@ -73,3 +73,25 @@ def query_in_thread(): thread.start() thread.join() eq_(results["value"], 2) + +def test_create_db_with_reserved_and_spaced_identifiers(): + # regression test for https://github.com/openvax/datacache/issues/17: + # table/column names that are SQL keywords or contain spaces must be + # quoted so they don't have to be sanitized to alphanumeric + underscore. + table_name = "weird table" + column_types = [("order", "INT"), ("group by", "STR")] + rows = [(1, "a"), (2, "b")] + table = datacache.database_table.DatabaseTable( + name=table_name, + column_types=column_types, + make_rows=lambda: rows, + indices=[["group by"]], + nullable={"group by"}, + primary_key="order") + with tempfile.NamedTemporaryFile(suffix="test.db") as f: + db = datacache.database.Database(f.name) + db.create(tables=[table], version=VERSION) + assert db.has_table(table_name) + cursor = db.connection.execute( + 'SELECT "order" FROM "weird table" WHERE "group by" = ?', ("b",)) + eq_(cursor.fetchone()[0], 2)