Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 30 additions & 10 deletions datacache/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,19 @@

METADATA_TABLE_NAME = "_datacache_metadata"


def quote_identifier(identifier):
"""
Wrap a SQL identifier (table, column, or index name) in double quotes so
that names which contain spaces or punctuation, or which collide with SQL
keywords, don't need to be sanitized down to a restricted character set.
Embedded double quotes are escaped by doubling them, per the SQL standard.

See https://github.com/openvax/datacache/issues/17
"""
return '"%s"' % str(identifier).replace('"', '""')


class Database(object):
"""
Wrapper object for sqlite3 database which provides helpers for
Expand Down Expand Up @@ -66,7 +79,7 @@ def has_table(self, table_name):
def drop_all_tables(self):
"""Drop all tables in the database"""
for table_name in self.table_names():
self.execute_sql("DROP TABLE %s" % table_name)
self.execute_sql("DROP TABLE %s" % quote_identifier(table_name))
self.connection.commit()

def execute_sql(self, sql, commit=False):
Expand All @@ -91,7 +104,9 @@ def has_version(self):
def version(self):
"""What's the version of this database? Found in metadata attached
by datacache when creating this database."""
query = "SELECT version FROM %s" % METADATA_TABLE_NAME
query = "SELECT %s FROM %s" % (
quote_identifier("version"),
quote_identifier(METADATA_TABLE_NAME))
cursor = self.connection.execute(query)
version = cursor.fetchone()
if not version:
Expand All @@ -110,10 +125,13 @@ def _finalize_database(self, version):
"""
require_integer(version, "version")
create_metadata_sql = \
"CREATE TABLE %s (version INT)" % METADATA_TABLE_NAME
"CREATE TABLE %s (%s INT)" % (
quote_identifier(METADATA_TABLE_NAME),
quote_identifier("version"))
self.execute_sql(create_metadata_sql)
insert_version_sql = \
"INSERT INTO %s VALUES (%s)" % (METADATA_TABLE_NAME, version)
"INSERT INTO %s VALUES (%s)" % (
quote_identifier(METADATA_TABLE_NAME), version)
self.execute_sql(insert_version_sql)

def _create_table(self, table_name, column_types, primary=None, nullable=()):
Expand All @@ -139,15 +157,16 @@ def _create_table(self, table_name, column_types, primary=None, nullable=()):

column_decls = []
for column_name, column_type in column_types:
decl = "%s %s" % (column_name, column_type)
decl = "%s %s" % (quote_identifier(column_name), column_type)
if column_name == primary:
decl += " UNIQUE PRIMARY KEY"
if column_name not in nullable:
decl += " NOT NULL"
column_decls.append(decl)
column_decl_str = ", ".join(column_decls)
create_table_sql = \
"CREATE TABLE %s (%s)" % (table_name, column_decl_str)
"CREATE TABLE %s (%s)" % (
quote_identifier(table_name), column_decl_str)
self.execute_sql(create_table_sql)

def _fill_table(self, table_name, rows):
Expand All @@ -166,7 +185,8 @@ def _fill_table(self, table_name, rows):
raise ValueError("Rows must all have %d values" % n_columns)
blank_slots = ", ".join("?" for _ in range(n_columns))
logger.info("Inserting %d rows into table %s", len(rows), table_name)
sql = "INSERT INTO %s VALUES (%s)" % (table_name, blank_slots)
sql = "INSERT INTO %s VALUES (%s)" % (
quote_identifier(table_name), blank_slots)
self.connection.executemany(sql, rows)

def create(self, tables, version):
Expand Down Expand Up @@ -212,9 +232,9 @@ def _create_index(self, table_name, index_columns):
"_".join(index_columns))
self.connection.execute(
"CREATE INDEX IF NOT EXISTS %s ON %s (%s)" % (
index_name,
table_name,
", ".join(index_columns)))
quote_identifier(index_name),
quote_identifier(table_name),
", ".join(quote_identifier(c) for c in index_columns)))

def _create_indices(self, table_name, indices):
"""
Expand Down
22 changes: 22 additions & 0 deletions tests/test_database_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,3 +73,25 @@ def query_in_thread():
thread.start()
thread.join()
eq_(results["value"], 2)

def test_create_db_with_reserved_and_spaced_identifiers():
# regression test for https://github.com/openvax/datacache/issues/17:
# table/column names that are SQL keywords or contain spaces must be
# quoted so they don't have to be sanitized to alphanumeric + underscore.
table_name = "weird table"
column_types = [("order", "INT"), ("group by", "STR")]
rows = [(1, "a"), (2, "b")]
table = datacache.database_table.DatabaseTable(
name=table_name,
column_types=column_types,
make_rows=lambda: rows,
indices=[["group by"]],
nullable={"group by"},
primary_key="order")
with tempfile.NamedTemporaryFile(suffix="test.db") as f:
db = datacache.database.Database(f.name)
db.create(tables=[table], version=VERSION)
assert db.has_table(table_name)
cursor = db.connection.execute(
'SELECT "order" FROM "weird table" WHERE "group by" = ?', ("b",))
eq_(cursor.fetchone()[0], 2)
Loading