# Batched-write additions to commcare_export/writers.py, reconstructed from
# the collapsed patch.  `_commit`, `bulk_upsert`, `_flush_batch` and
# `write_table` are methods of the SQL table writer class (presumably
# `SqlTableWriter` -- its header is not part of this chunk), which is why
# they take `self`.  `_batched` is a module-level helper.
#
# NOTE(review): the same patch also annotates two module-level globals in
# other files: `JSONPATH_CACHE: dict[str, Any] = {}` in
# commcare_export/env.py and `blacklisted_tables: list[str] = []` in
# commcare_export/excel_query.py.
import itertools
import logging
from itertools import zip_longest

logger = logging.getLogger(__name__)

MAX_COLUMN_SIZE = 2000
# Number of leading rows that write_table() writes one at a time, checking
# (and if necessary altering) the table schema before each row.
SCHEMA_CHECK_ROWS = 10
# Maximum number of rows sent per bulk statement after the schema check.
BATCH_SIZE = 1000


def _commit(self):
    """Commit the current transaction on ``self.connection``."""
    # Explicit commit works for all DB types. Replace with explicit
    # transactions when upgrading to SQLAlchemy 2.0
    self.connection.execute(sqlalchemy.text('COMMIT'))


def bulk_upsert(self, table, batch):
    """Insert-or-update every row of ``batch`` into ``table``.

    On PostgreSQL and MySQL the whole batch is sent as one multi-row
    ``INSERT`` with an on-conflict / on-duplicate-key update keyed on the
    ``id`` column.  On any other database each row is handed to
    ``self.upsert()`` instead.

    ``table`` is the table object to write to (its ``columns`` are read
    and it is passed to the dialect ``insert()``).  ``batch`` is a list of
    ``{column name: value}`` dicts; an empty batch is a no-op.
    """
    if not batch:
        return
    # SQLAlchemy requires all dicts in `batch` to have the same keys
    # for `insert(table).values(batch)`. We need to drop the columns
    # whose values are always `None` to reproduce the behavior of
    # `SqlTableWriter.insert()`. `batch_keys` are the columns where
    # _any_ row has a value set.  A dict (rather than a set) keeps them
    # in first-seen order, so the generated statement is deterministic.
    batch_keys = dict.fromkeys(
        key
        for row_dict in batch
        for key, value in row_dict.items()
        if value is not None
    )
    # `.get()` rather than indexing: a row built with
    # `dict(zip(headings, row))` from a short row lacks the trailing keys.
    batch = [
        {key: row_dict.get(key) for key in batch_keys} for row_dict in batch
    ]
    if self.is_postgres:
        from sqlalchemy.dialects.postgresql import insert

        stmt = insert(table).values(batch)
        new_row = stmt.excluded
    elif self.is_mysql:
        from sqlalchemy.dialects.mysql import insert

        stmt = insert(table).values(batch)
        new_row = stmt.inserted
    else:
        # MSSQL and others: fall back to row-by-row
        for row_dict in batch:
            self.upsert(table, row_dict)
        return

    # Use COALESCE so that a None in the inserted row preserves the
    # existing column value, matching the per-row upsert() which
    # strips Nones before building the UPDATE.
    # Only columns that already exist on the table are referenced.  A
    # column in batch_keys that the table lacks is simply left out here;
    # the INSERT itself is then expected to fail, and _flush_batch
    # retries after fixing the schema.
    update_cols = {
        c.name: sqlalchemy.func.coalesce(new_row[c.name], c)
        for c in table.columns
        if c.name != 'id' and c.name in batch_keys
    }
    # Both dialect helpers reject an empty update mapping, so when no
    # non-id column carries a value there is nothing to update and a
    # conflicting row must be left untouched.
    if self.is_postgres:
        if update_cols:
            stmt = stmt.on_conflict_do_update(
                index_elements=['id'],
                set_=update_cols,
            )
        else:
            stmt = stmt.on_conflict_do_nothing(index_elements=['id'])
    elif update_cols:
        stmt = stmt.on_duplicate_key_update(**update_cols)
    else:
        # No-op update: re-assign the id that caused the conflict.
        stmt = stmt.on_duplicate_key_update(id=new_row['id'])
    self.connection.execute(stmt)


def _flush_batch(self, table, batch, data_type_dict):
    """Bulk-upsert ``batch``, repairing the schema and retrying once.

    If the first attempt raises one of the SQLAlchemy errors listed
    below, every row of the batch is run through
    ``self.make_table_compatible()`` and the bulk upsert is attempted a
    second time; a failure of that second attempt propagates.  The
    transaction is committed afterwards.

    Returns the table object that was finally written to, so the caller
    can keep using the updated schema for later batches.
    """
    try:
        self.bulk_upsert(table, batch)
    except (
        sqlalchemy.exc.CompileError,
        sqlalchemy.exc.OperationalError,
        sqlalchemy.exc.ProgrammingError,
        sqlalchemy.exc.DataError,
    ):
        # Likely a schema mismatch; fix schema and retry once
        logger.debug(
            'Bulk upsert of %s rows failed; fixing schema and retrying',
            len(batch),
            exc_info=True,
        )
        for row_dict in batch:
            table = self.make_table_compatible(
                table,
                row_dict,
                data_type_dict,
            )
        self.bulk_upsert(table, batch)
    self._commit()
    return table


def write_table(self, table_spec: "TableSpec") -> None:
    """Write all rows of ``table_spec`` to its SQL table.

    The table is created from the first row if it does not exist.  The
    first ``SCHEMA_CHECK_ROWS`` rows are schema-checked and upserted one
    at a time; the remaining rows are written in batches of up to
    ``BATCH_SIZE`` through ``_flush_batch()``.  A spec with no rows does
    nothing.
    """
    table_name = table_spec.name
    headings = table_spec.headings
    data_type_dict = dict(zip_longest(headings, table_spec.data_types))

    rows = (dict(zip(headings, row)) for row in table_spec.rows)
    first_row = next(rows, None)
    if first_row is None:
        return
    # Put the peeked row back in front of the remaining rows.
    row_stream = itertools.chain([first_row], rows)

    table = self.get_table(table_name)
    if table is None:
        table = self.create_table(table_name, first_row, data_type_dict)

    checked_rows = 0
    for row_dict in itertools.islice(row_stream, SCHEMA_CHECK_ROWS):
        table = self.make_table_compatible(table, row_dict, data_type_dict)
        self.upsert(table, row_dict)
        checked_rows += 1
    self._commit()

    logger.debug(
        "Schema check complete for %s rows in table '%s'. "
        'Final columns: %s',
        checked_rows,
        table_name,
        [c.name for c in table.columns],
    )

    for batch in _batched(row_stream, BATCH_SIZE):
        # Keep the table returned by a schema-fixing retry so later
        # batches do not fail again on the stale definition.  Compare
        # with `is not None`: a table object may not support truth tests.
        flushed_table = self._flush_batch(table, batch, data_type_dict)
        if flushed_table is not None:
            table = flushed_table


# Use itertools.batched when Python is always >= 3.12
def _batched(iterable, n):
    """Yield successive lists of up to ``n`` items from ``iterable``.

    The last list may be shorter.  Raises ``ValueError`` (on first
    iteration) if ``n`` is less than one, like ``itertools.batched``.
    """
    if n < 1:
        raise ValueError('n must be at least one')
    # Take one iterator up front: slicing a re-iterable such as a list
    # directly would restart from its first item on every pass.
    iterator = iter(iterable)
    while batch := list(itertools.islice(iterator, n)):
        yield batch
# SQL-writer tests added to tests/test_writers.py by this patch.  In the
# test module these are methods of a test class whose header is not part
# of this chunk (hence the `self` parameter); `writer` is presumably a
# pytest fixture defined elsewhere in that module.


def _select_rows(writer, query):
    """Run ``query`` on the writer's connection; return the rows as dicts."""
    with writer:
        return [dict(row) for row in writer.connection.execute(query)]


def _by_id(rows):
    """Index row dicts by their ``id`` value."""
    return {row['id']: row for row in rows}


def test_bulk_upsert(self, writer):
    """bulk_upsert() updates an existing id and inserts a new one."""
    # Seed the table through the regular write_table path.
    seed = TableSpec(
        name='foo_bulk_upsert',
        headings=['id', 'a', 'b'],
        rows=[
            ['row1', 'val1', 'x'],
            ['row2', 'val2', 'y'],
        ],
    )
    with writer:
        writer.write_table(seed)

    # row1 already exists (update); row3 does not (insert).
    with writer:
        writer.bulk_upsert(
            writer.get_table('foo_bulk_upsert'),
            [
                {'id': 'row1', 'a': 'updated1', 'b': 'ux'},
                {'id': 'row3', 'a': 'val3', 'b': 'z'},
            ],
        )
        writer._commit()

    stored = _by_id(
        _select_rows(writer, 'SELECT id, a, b FROM foo_bulk_upsert')
    )
    assert len(stored) == 3
    assert stored == {
        'row1': {'id': 'row1', 'a': 'updated1', 'b': 'ux'},
        'row2': {'id': 'row2', 'a': 'val2', 'b': 'y'},
        'row3': {'id': 'row3', 'a': 'val3', 'b': 'z'},
    }


def test_bulk_upsert_preserves_existing_values_for_none(self, writer):
    """A None in a bulk-upserted row must not overwrite a stored value."""
    seed = TableSpec(
        name='foo_bulk_upsert_none',
        headings=['id', 'a', 'b', 'c'],
        rows=[
            ['row1', 'val1', 'x', 'keep1'],
            ['row2', 'val2', 'y', 'keep2'],
        ],
    )
    with writer:
        writer.write_table(seed)

    # row1 carries None for b, and column c is None across the whole
    # batch.  Both stored values must survive.  row3 is new, so its None
    # simply lands as NULL because there is no earlier value to keep.
    with writer:
        writer.bulk_upsert(
            writer.get_table('foo_bulk_upsert_none'),
            [
                {'id': 'row1', 'a': 'updated1', 'b': None, 'c': None},
                {'id': 'row3', 'a': 'val3', 'b': 'z', 'c': None},
            ],
        )
        writer._commit()

    stored = _by_id(
        _select_rows(writer, 'SELECT id, a, b, c FROM foo_bulk_upsert_none')
    )
    assert len(stored) == 3
    assert stored == {
        'row1': {'id': 'row1', 'a': 'updated1', 'b': 'x', 'c': 'keep1'},
        'row2': {'id': 'row2', 'a': 'val2', 'b': 'y', 'c': 'keep2'},
        'row3': {'id': 'row3', 'a': 'val3', 'b': 'z', 'c': None},
    }


def test_flush_batch_retry_on_new_column(self, writer):
    """_flush_batch() copes with a batch that introduces a new column."""
    # The table starts out with columns [id, a] only.
    with writer:
        writer.write_table(
            TableSpec(
                name='foo_flush_retry',
                headings=['id', 'a'],
                rows=[['row1', 'val1']],
            )
        )

    # The batch below carries an extra column 'b'; no data types are given.
    untyped = dict.fromkeys(['id', 'a', 'b'])
    with writer:
        writer._flush_batch(
            writer.get_table('foo_flush_retry'),
            [{'id': 'row2', 'a': 'val2', 'b': 'new_col_val'}],
            untyped,
        )

    stored = _by_id(
        _select_rows(writer, 'SELECT id, a, b FROM foo_flush_retry')
    )
    assert len(stored) == 2
    assert stored['row2'] == {
        'id': 'row2',
        'a': 'val2',
        'b': 'new_col_val',
    }


def test_batched_write(self, writer):
    """write_table() stores every row when there are more than SCHEMA_CHECK_ROWS."""
    num_rows = SCHEMA_CHECK_ROWS + 15
    spec = TableSpec(
        name='foo_batched_write',
        headings=['id', 'a', 'b'],
        rows=[[f'id_{i}', f'a_{i}', i] for i in range(num_rows)],
    )
    with writer:
        writer.write_table(spec)

    fetched = _select_rows(writer, 'SELECT id, a, b FROM foo_batched_write')
    assert len(fetched) == num_rows
    assert _by_id(fetched) == {
        f'id_{i}': {'id': f'id_{i}', 'a': f'a_{i}', 'b': i}
        for i in range(num_rows)
    }


def test_batched_upsert(self, writer):
    """A second write_table() updates existing rows and adds new ones."""
    num_rows = SCHEMA_CHECK_ROWS + 5
    with writer:
        writer.write_table(
            TableSpec(
                name='foo_batched_upsert',
                headings=['id', 'a', 'b'],
                rows=[[f'id_{i}', f'a_{i}', i] for i in range(num_rows)],
            )
        )

    # Second write: every existing id gets new values, plus 5 new ids.
    final_count = num_rows + 5
    with writer:
        writer.write_table(
            TableSpec(
                name='foo_batched_upsert',
                headings=['id', 'a', 'b'],
                rows=[
                    [f'id_{i}', f'updated_{i}', i + 100]
                    for i in range(final_count)
                ],
            )
        )

    fetched = _select_rows(writer, 'SELECT id, a, b FROM foo_batched_upsert')
    assert len(fetched) == final_count
    assert _by_id(fetched) == {
        f'id_{i}': {'id': f'id_{i}', 'a': f'updated_{i}', 'b': i + 100}
        for i in range(final_count)
    }


def test_late_schema_change_via_write_table(self, writer):
    """Column b first gets a value only after the schema-check rows."""
    late_ids = range(SCHEMA_CHECK_ROWS, SCHEMA_CHECK_ROWS + 5)
    rows = [[f'id_{i}', f'a_{i}', None] for i in range(SCHEMA_CHECK_ROWS)]
    rows += [[f'id_{i}', f'a_{i}', f'b_{i}'] for i in late_ids]

    with writer:
        writer.write_table(
            TableSpec(
                name='foo_late_schema',
                headings=['id', 'a', 'b'],
                rows=rows,
            )
        )

    fetched = _select_rows(writer, 'SELECT id, a, b FROM foo_late_schema')
    assert len(fetched) == SCHEMA_CHECK_ROWS + 5
    stored = _by_id(fetched)
    assert {i: stored[f'id_{i}']['b'] for i in late_ids} == {
        i: f'b_{i}' for i in late_ids
    }