From 24fa83b6ee6270538f308147f81322ed951e9dde Mon Sep 17 00:00:00 2001 From: Daniel Thom Date: Wed, 8 Apr 2026 17:27:49 -0600 Subject: [PATCH 01/48] Replace SQLAlchemy with ibis-framework for all database operations Replace the SQLAlchemy ORM with ibis-framework to provide a cleaner multi-backend abstraction for DuckDB, SQLite, and Spark. This is a clean API break: engine->backend, Connection params removed, ibis expressions replace SQLAlchemy select/join chains. Key changes: - New src/chronify/ibis/ module with IbisBackend ABC and DuckDB/SQLite/Spark implementations - Remove SQLAlchemy, pyhive vendor code, Hive support, and related dependencies - Migrate all source modules (store, mappers, checker, converters) to ibis API - Migrate all tests to use backend fixtures instead of engine fixtures - Add ibis-framework[duckdb,sqlite] dependency, pyspark >= 4.0 for spark extra Co-Authored-By: Claude Opus 4.6 --- pyproject.toml | 19 +- src/chronify/__init__.py | 8 - src/chronify/_vendor/kyuubi/LICENSE | 201 - src/chronify/_vendor/kyuubi/README.md | 12 - .../kyuubi/TCLIService/TCLIService-remote | 264 - .../_vendor/kyuubi/TCLIService/TCLIService.py | 3986 --------- .../_vendor/kyuubi/TCLIService/__init__.py | 1 - .../_vendor/kyuubi/TCLIService/constants.py | 68 - .../_vendor/kyuubi/TCLIService/ttypes.py | 7210 ----------------- .../_vendor/kyuubi/pyhive/__init__.py | 3 - src/chronify/_vendor/kyuubi/pyhive/common.py | 266 - src/chronify/_vendor/kyuubi/pyhive/exc.py | 72 - src/chronify/_vendor/kyuubi/pyhive/hive.py | 620 -- src/chronify/_vendor/kyuubi/pyhive/presto.py | 367 - .../_vendor/kyuubi/pyhive/sasl_compat.py | 56 - .../_vendor/kyuubi/pyhive/sqlalchemy_hive.py | 435 - .../kyuubi/pyhive/sqlalchemy_presto.py | 256 - .../_vendor/kyuubi/pyhive/sqlalchemy_trino.py | 84 - src/chronify/_vendor/kyuubi/pyhive/trino.py | 144 - src/chronify/csv_io.py | 6 +- src/chronify/hive_functions.py | 34 - src/chronify/ibis/__init__.py | 43 + src/chronify/ibis/base.py | 148 + src/chronify/ibis/duckdb_backend.py | 111 + src/chronify/ibis/functions.py | 242 + src/chronify/ibis/spark_backend.py | 125 + src/chronify/ibis/sqlite_backend.py | 121 + src/chronify/ibis/types.py | 111 + src/chronify/models.py | 105 +- src/chronify/schema_manager.py | 133 +- src/chronify/sqlalchemy/__init__.py | 0 src/chronify/sqlalchemy/functions.py | 273 - src/chronify/store.py | 1285 +-- src/chronify/time_series_checker.py | 59 +- src/chronify/time_series_mapper.py | 20 +- src/chronify/time_series_mapper_base.py | 214 +- ...apper_column_representative_to_datetime.py | 78 +- src/chronify/time_series_mapper_datetime.py | 20 +- src/chronify/time_series_mapper_index_time.py | 62 +- .../time_series_mapper_representative.py | 36 +- src/chronify/time_zone_converter.py | 184 +- src/chronify/time_zone_localizer.py | 173 +- src/chronify/utils/sqlalchemy_table.py | 73 - src/chronify/utils/sqlalchemy_view.py | 69 - tests/conftest.py | 90 +- tests/test_checker_representative_time.py | 47 +- tests/test_csv_parser.py | 12 +- ...apper_column_representative_to_datetime.py | 74 +- tests/test_mapper_datetime_to_datetime.py | 70 +- tests/test_mapper_index_time_to_datetime.py | 54 +- ..._mapper_representative_time_to_datetime.py | 42 +- tests/test_models.py | 8 +- tests/test_store.py | 303 +- tests/test_time_series_checker.py | 62 +- tests/test_time_zone_converter.py | 62 +- tests/test_time_zone_localizer.py | 110 +- 56 files changed, 1790 insertions(+), 16941 deletions(-) delete mode 100644 src/chronify/_vendor/kyuubi/LICENSE delete mode 100644 src/chronify/_vendor/kyuubi/README.md delete mode 100755 src/chronify/_vendor/kyuubi/TCLIService/TCLIService-remote delete mode 100644 src/chronify/_vendor/kyuubi/TCLIService/TCLIService.py delete mode 100644 src/chronify/_vendor/kyuubi/TCLIService/__init__.py delete mode 100644 src/chronify/_vendor/kyuubi/TCLIService/constants.py delete mode 100644 src/chronify/_vendor/kyuubi/TCLIService/ttypes.py delete mode 100644 src/chronify/_vendor/kyuubi/pyhive/__init__.py delete mode 100644 src/chronify/_vendor/kyuubi/pyhive/common.py delete mode 100644 src/chronify/_vendor/kyuubi/pyhive/exc.py delete mode 100644 src/chronify/_vendor/kyuubi/pyhive/hive.py delete mode 100644 src/chronify/_vendor/kyuubi/pyhive/presto.py delete mode 100644 src/chronify/_vendor/kyuubi/pyhive/sasl_compat.py delete mode 100644 src/chronify/_vendor/kyuubi/pyhive/sqlalchemy_hive.py delete mode 100644 src/chronify/_vendor/kyuubi/pyhive/sqlalchemy_presto.py delete mode 100644 src/chronify/_vendor/kyuubi/pyhive/sqlalchemy_trino.py delete mode 100644 src/chronify/_vendor/kyuubi/pyhive/trino.py delete mode 100644 src/chronify/hive_functions.py create mode 100644 src/chronify/ibis/__init__.py create mode 100644 src/chronify/ibis/base.py create mode 100644 src/chronify/ibis/duckdb_backend.py create mode 100644 src/chronify/ibis/functions.py create mode 100644 src/chronify/ibis/spark_backend.py create mode 100644 src/chronify/ibis/sqlite_backend.py create mode 100644 src/chronify/ibis/types.py delete mode 100644 src/chronify/sqlalchemy/__init__.py delete mode 100644 src/chronify/sqlalchemy/functions.py delete mode 100644 src/chronify/utils/sqlalchemy_table.py delete mode 100644 src/chronify/utils/sqlalchemy_view.py diff --git a/pyproject.toml b/pyproject.toml index 5cb61f6..6b20b06 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,24 +27,20 @@ classifiers = [ ] dependencies = [ "duckdb ~= 1.1.0", - "duckdb_engine", + "ibis-framework[duckdb,sqlite] >= 9.0", "loguru", "pandas >= 2.2, < 3", "pyarrow", "pydantic >= 2.7, < 3", "pytz", "rich", - "sqlalchemy == 2.0.37", "tzdata", - # Required by pyhive - "future", - "python-dateutil", ] [project.optional-dependencies] spark = [ - "thrift", - "thrift_sasl", + "ibis-framework[pyspark]", + "pyspark >= 4.0", ] dev = [ @@ -63,11 +59,6 @@ dev = [ "sphinx-tabs~=3.4", ] -[project.entry-points."sqlalchemy.dialects"] -hive = "pyhive.sqlalchemy_hive:HiveDialect" -"hive.http" = "pyhive.sqlalchemy_hive:HiveHTTPDialect" -"hive.https" = "pyhive.sqlalchemy_hive:HiveHTTPSDialect" - [project.urls] Documentation = "https://github.com/NREL/chronify#readme" Issues = "https://github.com/NREL/chronify/issues" @@ -77,9 +68,6 @@ Source = "https://github.com/NREL/chronify" files = [ "src", ] -exclude = [ - "src/chronify/_vendor/*", -] strict = true [tool.pytest.ini_options] @@ -99,7 +87,6 @@ exclude = [ "dist", "env", "venv", - "src/chronify/_vendor/*", ] line-length = 99 diff --git a/src/chronify/__init__.py b/src/chronify/__init__.py index 8a0c1a9..4da5acd 100644 --- a/src/chronify/__init__.py +++ b/src/chronify/__init__.py @@ -1,8 +1,5 @@ import importlib.metadata as metadata -import sys -from chronify._vendor.kyuubi import TCLIService -from chronify._vendor.kyuubi import pyhive from chronify.exceptions import ( ChronifyExceptionBase, ConflictingInputsError, @@ -61,8 +58,3 @@ ) __version__ = metadata.metadata("chronify")["Version"] - - -# Make pyhive importable as if it were installed separately. -sys.modules["pyhive"] = pyhive -sys.modules["TCLIService"] = TCLIService diff --git a/src/chronify/_vendor/kyuubi/LICENSE b/src/chronify/_vendor/kyuubi/LICENSE deleted file mode 100644 index f49a4e1..0000000 --- a/src/chronify/_vendor/kyuubi/LICENSE +++ /dev/null @@ -1,201 +0,0 @@ - Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ - - TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - - 1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. - - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). - - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - - 2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - - 3. Grant of Patent License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - - 4. Redistribution. You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - - 5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - - 6. Trademarks. This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - - 7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - - 8. Limitation of Liability. In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - - 9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. - - END OF TERMS AND CONDITIONS - - APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "[]" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. - - Copyright [yyyy] [name of copyright owner] - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. \ No newline at end of file diff --git a/src/chronify/_vendor/kyuubi/README.md b/src/chronify/_vendor/kyuubi/README.md deleted file mode 100644 index b71db55..0000000 --- a/src/chronify/_vendor/kyuubi/README.md +++ /dev/null @@ -1,12 +0,0 @@ -This source code is copied from https://github.com/apache/kyuubi.git, -commit ID 3b205a3924e0e3a75c425de1396089729cf22ee5. We did not modify the code. - -The pyhive package is marked as not supported, but we need it to work with Spark. -The latest published version of pyhive on pypi.org is not compatible with current versions of -sqlalchemy and Apache Spark. Specifically, we require the patch made in commit ID -a0b9873f817267675eab304f6935bafa4ab0f731. - -We have validated this version of pyhive with our use cases. We will remove this code as -soon as Kyuubi publishes an updated version on pypi.org. - -The pyhive license file is included here. diff --git a/src/chronify/_vendor/kyuubi/TCLIService/TCLIService-remote b/src/chronify/_vendor/kyuubi/TCLIService/TCLIService-remote deleted file mode 100755 index 8d875fa..0000000 --- a/src/chronify/_vendor/kyuubi/TCLIService/TCLIService-remote +++ /dev/null @@ -1,264 +0,0 @@ -#!/usr/bin/env python -# -# Autogenerated by Thrift Compiler (0.10.0) -# -# DO NOT EDIT UNLESS YOU ARE SURE THAT YOU KNOW WHAT YOU ARE DOING -# -# options string: py -# - -import sys -import pprint -if sys.version_info[0] > 2: - from urllib.parse import urlparse -else: - from urlparse import urlparse -from thrift.transport import TTransport, TSocket, TSSLSocket, THttpClient -from thrift.protocol.TBinaryProtocol import TBinaryProtocol - -from TCLIService import TCLIService -from TCLIService.ttypes import * - -if len(sys.argv) <= 1 or sys.argv[1] == '--help': - print('') - print('Usage: ' + sys.argv[0] + ' [-h host[:port]] [-u url] [-f[ramed]] [-s[sl]] [-novalidate] [-ca_certs certs] [-keyfile keyfile] [-certfile certfile] function [arg1 [arg2...]]') - print('') - print('Functions:') - print(' TOpenSessionResp OpenSession(TOpenSessionReq req)') - print(' TCloseSessionResp CloseSession(TCloseSessionReq req)') - print(' TGetInfoResp GetInfo(TGetInfoReq req)') - print(' TExecuteStatementResp ExecuteStatement(TExecuteStatementReq req)') - print(' TGetTypeInfoResp GetTypeInfo(TGetTypeInfoReq req)') - print(' TGetCatalogsResp GetCatalogs(TGetCatalogsReq req)') - print(' TGetSchemasResp GetSchemas(TGetSchemasReq req)') - print(' TGetTablesResp GetTables(TGetTablesReq req)') - print(' TGetTableTypesResp GetTableTypes(TGetTableTypesReq req)') - print(' TGetColumnsResp GetColumns(TGetColumnsReq req)') - print(' TGetFunctionsResp GetFunctions(TGetFunctionsReq req)') - print(' TGetPrimaryKeysResp GetPrimaryKeys(TGetPrimaryKeysReq req)') - print(' TGetCrossReferenceResp GetCrossReference(TGetCrossReferenceReq req)') - print(' TGetOperationStatusResp GetOperationStatus(TGetOperationStatusReq req)') - print(' TCancelOperationResp CancelOperation(TCancelOperationReq req)') - print(' TCloseOperationResp CloseOperation(TCloseOperationReq req)') - print(' TGetResultSetMetadataResp GetResultSetMetadata(TGetResultSetMetadataReq req)') - print(' TFetchResultsResp FetchResults(TFetchResultsReq req)') - print(' TGetDelegationTokenResp GetDelegationToken(TGetDelegationTokenReq req)') - print(' TCancelDelegationTokenResp CancelDelegationToken(TCancelDelegationTokenReq req)') - print(' TRenewDelegationTokenResp RenewDelegationToken(TRenewDelegationTokenReq req)') - print(' TGetLogResp GetLog(TGetLogReq req)') - print('') - sys.exit(0) - -pp = pprint.PrettyPrinter(indent=2) -host = 'localhost' -port = 9090 -uri = '' -framed = False -ssl = False -validate = True -ca_certs = None -keyfile = None -certfile = None -http = False -argi = 1 - -if sys.argv[argi] == '-h': - parts = sys.argv[argi + 1].split(':') - host = parts[0] - if len(parts) > 1: - port = int(parts[1]) - argi += 2 - -if sys.argv[argi] == '-u': - url = urlparse(sys.argv[argi + 1]) - parts = url[1].split(':') - host = parts[0] - if len(parts) > 1: - port = int(parts[1]) - else: - port = 80 - uri = url[2] - if url[4]: - uri += '?%s' % url[4] - http = True - argi += 2 - -if sys.argv[argi] == '-f' or sys.argv[argi] == '-framed': - framed = True - argi += 1 - -if sys.argv[argi] == '-s' or sys.argv[argi] == '-ssl': - ssl = True - argi += 1 - -if sys.argv[argi] == '-novalidate': - validate = False - argi += 1 - -if sys.argv[argi] == '-ca_certs': - ca_certs = sys.argv[argi+1] - argi += 2 - -if sys.argv[argi] == '-keyfile': - keyfile = sys.argv[argi+1] - argi += 2 - -if sys.argv[argi] == '-certfile': - certfile = sys.argv[argi+1] - argi += 2 - -cmd = sys.argv[argi] -args = sys.argv[argi + 1:] - -if http: - transport = THttpClient.THttpClient(host, port, uri) -else: - if ssl: - socket = TSSLSocket.TSSLSocket(host, port, validate=validate, ca_certs=ca_certs, keyfile=keyfile, certfile=certfile) - else: - socket = TSocket.TSocket(host, port) - if framed: - transport = TTransport.TFramedTransport(socket) - else: - transport = TTransport.TBufferedTransport(socket) -protocol = TBinaryProtocol(transport) -client = TCLIService.Client(protocol) -transport.open() - -if cmd == 'OpenSession': - if len(args) != 1: - print('OpenSession requires 1 args') - sys.exit(1) - pp.pprint(client.OpenSession(eval(args[0]),)) - -elif cmd == 'CloseSession': - if len(args) != 1: - print('CloseSession requires 1 args') - sys.exit(1) - pp.pprint(client.CloseSession(eval(args[0]),)) - -elif cmd == 'GetInfo': - if len(args) != 1: - print('GetInfo requires 1 args') - sys.exit(1) - pp.pprint(client.GetInfo(eval(args[0]),)) - -elif cmd == 'ExecuteStatement': - if len(args) != 1: - print('ExecuteStatement requires 1 args') - sys.exit(1) - pp.pprint(client.ExecuteStatement(eval(args[0]),)) - -elif cmd == 'GetTypeInfo': - if len(args) != 1: - print('GetTypeInfo requires 1 args') - sys.exit(1) - pp.pprint(client.GetTypeInfo(eval(args[0]),)) - -elif cmd == 'GetCatalogs': - if len(args) != 1: - print('GetCatalogs requires 1 args') - sys.exit(1) - pp.pprint(client.GetCatalogs(eval(args[0]),)) - -elif cmd == 'GetSchemas': - if len(args) != 1: - print('GetSchemas requires 1 args') - sys.exit(1) - pp.pprint(client.GetSchemas(eval(args[0]),)) - -elif cmd == 'GetTables': - if len(args) != 1: - print('GetTables requires 1 args') - sys.exit(1) - pp.pprint(client.GetTables(eval(args[0]),)) - -elif cmd == 'GetTableTypes': - if len(args) != 1: - print('GetTableTypes requires 1 args') - sys.exit(1) - pp.pprint(client.GetTableTypes(eval(args[0]),)) - -elif cmd == 'GetColumns': - if len(args) != 1: - print('GetColumns requires 1 args') - sys.exit(1) - pp.pprint(client.GetColumns(eval(args[0]),)) - -elif cmd == 'GetFunctions': - if len(args) != 1: - print('GetFunctions requires 1 args') - sys.exit(1) - pp.pprint(client.GetFunctions(eval(args[0]),)) - -elif cmd == 'GetPrimaryKeys': - if len(args) != 1: - print('GetPrimaryKeys requires 1 args') - sys.exit(1) - pp.pprint(client.GetPrimaryKeys(eval(args[0]),)) - -elif cmd == 'GetCrossReference': - if len(args) != 1: - print('GetCrossReference requires 1 args') - sys.exit(1) - pp.pprint(client.GetCrossReference(eval(args[0]),)) - -elif cmd == 'GetOperationStatus': - if len(args) != 1: - print('GetOperationStatus requires 1 args') - sys.exit(1) - pp.pprint(client.GetOperationStatus(eval(args[0]),)) - -elif cmd == 'CancelOperation': - if len(args) != 1: - print('CancelOperation requires 1 args') - sys.exit(1) - pp.pprint(client.CancelOperation(eval(args[0]),)) - -elif cmd == 'CloseOperation': - if len(args) != 1: - print('CloseOperation requires 1 args') - sys.exit(1) - pp.pprint(client.CloseOperation(eval(args[0]),)) - -elif cmd == 'GetResultSetMetadata': - if len(args) != 1: - print('GetResultSetMetadata requires 1 args') - sys.exit(1) - pp.pprint(client.GetResultSetMetadata(eval(args[0]),)) - -elif cmd == 'FetchResults': - if len(args) != 1: - print('FetchResults requires 1 args') - sys.exit(1) - pp.pprint(client.FetchResults(eval(args[0]),)) - -elif cmd == 'GetDelegationToken': - if len(args) != 1: - print('GetDelegationToken requires 1 args') - sys.exit(1) - pp.pprint(client.GetDelegationToken(eval(args[0]),)) - -elif cmd == 'CancelDelegationToken': - if len(args) != 1: - print('CancelDelegationToken requires 1 args') - sys.exit(1) - pp.pprint(client.CancelDelegationToken(eval(args[0]),)) - -elif cmd == 'RenewDelegationToken': - if len(args) != 1: - print('RenewDelegationToken requires 1 args') - sys.exit(1) - pp.pprint(client.RenewDelegationToken(eval(args[0]),)) - -elif cmd == 'GetLog': - if len(args) != 1: - print('GetLog requires 1 args') - sys.exit(1) - pp.pprint(client.GetLog(eval(args[0]),)) - -else: - print('Unrecognized method %s' % cmd) - sys.exit(1) - -transport.close() diff --git a/src/chronify/_vendor/kyuubi/TCLIService/TCLIService.py b/src/chronify/_vendor/kyuubi/TCLIService/TCLIService.py deleted file mode 100644 index bb1415e..0000000 --- a/src/chronify/_vendor/kyuubi/TCLIService/TCLIService.py +++ /dev/null @@ -1,3986 +0,0 @@ -# -# Autogenerated by Thrift Compiler (0.10.0) -# -# DO NOT EDIT UNLESS YOU ARE SURE THAT YOU KNOW WHAT YOU ARE DOING -# -# options string: py -# - -from thrift.Thrift import TType, TMessageType, TFrozenDict, TException, TApplicationException -from thrift.protocol.TProtocol import TProtocolException -import sys -import logging -from .ttypes import * -from thrift.Thrift import TProcessor -from thrift.transport import TTransport - - -class Iface(object): - def OpenSession(self, req): - """ - Parameters: - - req - """ - pass - - def CloseSession(self, req): - """ - Parameters: - - req - """ - pass - - def GetInfo(self, req): - """ - Parameters: - - req - """ - pass - - def ExecuteStatement(self, req): - """ - Parameters: - - req - """ - pass - - def GetTypeInfo(self, req): - """ - Parameters: - - req - """ - pass - - def GetCatalogs(self, req): - """ - Parameters: - - req - """ - pass - - def GetSchemas(self, req): - """ - Parameters: - - req - """ - pass - - def GetTables(self, req): - """ - Parameters: - - req - """ - pass - - def GetTableTypes(self, req): - """ - Parameters: - - req - """ - pass - - def GetColumns(self, req): - """ - Parameters: - - req - """ - pass - - def GetFunctions(self, req): - """ - Parameters: - - req - """ - pass - - def GetPrimaryKeys(self, req): - """ - Parameters: - - req - """ - pass - - def GetCrossReference(self, req): - """ - Parameters: - - req - """ - pass - - def GetOperationStatus(self, req): - """ - Parameters: - - req - """ - pass - - def CancelOperation(self, req): - """ - Parameters: - - req - """ - pass - - def CloseOperation(self, req): - """ - Parameters: - - req - """ - pass - - def GetResultSetMetadata(self, req): - """ - Parameters: - - req - """ - pass - - def FetchResults(self, req): - """ - Parameters: - - req - """ - pass - - def GetDelegationToken(self, req): - """ - Parameters: - - req - """ - pass - - def CancelDelegationToken(self, req): - """ - Parameters: - - req - """ - pass - - def RenewDelegationToken(self, req): - """ - Parameters: - - req - """ - pass - - def GetLog(self, req): - """ - Parameters: - - req - """ - pass - - -class Client(Iface): - def __init__(self, iprot, oprot=None): - self._iprot = self._oprot = iprot - if oprot is not None: - self._oprot = oprot - self._seqid = 0 - - def OpenSession(self, req): - """ - Parameters: - - req - """ - self.send_OpenSession(req) - return self.recv_OpenSession() - - def send_OpenSession(self, req): - self._oprot.writeMessageBegin('OpenSession', TMessageType.CALL, self._seqid) - args = OpenSession_args() - args.req = req - args.write(self._oprot) - self._oprot.writeMessageEnd() - self._oprot.trans.flush() - - def recv_OpenSession(self): - iprot = self._iprot - (fname, mtype, rseqid) = iprot.readMessageBegin() - if mtype == TMessageType.EXCEPTION: - x = TApplicationException() - x.read(iprot) - iprot.readMessageEnd() - raise x - result = OpenSession_result() - result.read(iprot) - iprot.readMessageEnd() - if result.success is not None: - return result.success - raise TApplicationException(TApplicationException.MISSING_RESULT, "OpenSession failed: unknown result") - - def CloseSession(self, req): - """ - Parameters: - - req - """ - self.send_CloseSession(req) - return self.recv_CloseSession() - - def send_CloseSession(self, req): - self._oprot.writeMessageBegin('CloseSession', TMessageType.CALL, self._seqid) - args = CloseSession_args() - args.req = req - args.write(self._oprot) - self._oprot.writeMessageEnd() - self._oprot.trans.flush() - - def recv_CloseSession(self): - iprot = self._iprot - (fname, mtype, rseqid) = iprot.readMessageBegin() - if mtype == TMessageType.EXCEPTION: - x = TApplicationException() - x.read(iprot) - iprot.readMessageEnd() - raise x - result = CloseSession_result() - result.read(iprot) - iprot.readMessageEnd() - if result.success is not None: - return result.success - raise TApplicationException(TApplicationException.MISSING_RESULT, "CloseSession failed: unknown result") - - def GetInfo(self, req): - """ - Parameters: - - req - """ - self.send_GetInfo(req) - return self.recv_GetInfo() - - def send_GetInfo(self, req): - self._oprot.writeMessageBegin('GetInfo', TMessageType.CALL, self._seqid) - args = GetInfo_args() - args.req = req - args.write(self._oprot) - self._oprot.writeMessageEnd() - self._oprot.trans.flush() - - def recv_GetInfo(self): - iprot = self._iprot - (fname, mtype, rseqid) = iprot.readMessageBegin() - if mtype == TMessageType.EXCEPTION: - x = TApplicationException() - x.read(iprot) - iprot.readMessageEnd() - raise x - result = GetInfo_result() - result.read(iprot) - iprot.readMessageEnd() - if result.success is not None: - return result.success - raise TApplicationException(TApplicationException.MISSING_RESULT, "GetInfo failed: unknown result") - - def ExecuteStatement(self, req): - """ - Parameters: - - req - """ - self.send_ExecuteStatement(req) - return self.recv_ExecuteStatement() - - def send_ExecuteStatement(self, req): - self._oprot.writeMessageBegin('ExecuteStatement', TMessageType.CALL, self._seqid) - args = ExecuteStatement_args() - args.req = req - args.write(self._oprot) - self._oprot.writeMessageEnd() - self._oprot.trans.flush() - - def recv_ExecuteStatement(self): - iprot = self._iprot - (fname, mtype, rseqid) = iprot.readMessageBegin() - if mtype == TMessageType.EXCEPTION: - x = TApplicationException() - x.read(iprot) - iprot.readMessageEnd() - raise x - result = ExecuteStatement_result() - result.read(iprot) - iprot.readMessageEnd() - if result.success is not None: - return result.success - raise TApplicationException(TApplicationException.MISSING_RESULT, "ExecuteStatement failed: unknown result") - - def GetTypeInfo(self, req): - """ - Parameters: - - req - """ - self.send_GetTypeInfo(req) - return self.recv_GetTypeInfo() - - def send_GetTypeInfo(self, req): - self._oprot.writeMessageBegin('GetTypeInfo', TMessageType.CALL, self._seqid) - args = GetTypeInfo_args() - args.req = req - args.write(self._oprot) - self._oprot.writeMessageEnd() - self._oprot.trans.flush() - - def recv_GetTypeInfo(self): - iprot = self._iprot - (fname, mtype, rseqid) = iprot.readMessageBegin() - if mtype == TMessageType.EXCEPTION: - x = TApplicationException() - x.read(iprot) - iprot.readMessageEnd() - raise x - result = GetTypeInfo_result() - result.read(iprot) - iprot.readMessageEnd() - if result.success is not None: - return result.success - raise TApplicationException(TApplicationException.MISSING_RESULT, "GetTypeInfo failed: unknown result") - - def GetCatalogs(self, req): - """ - Parameters: - - req - """ - self.send_GetCatalogs(req) - return self.recv_GetCatalogs() - - def send_GetCatalogs(self, req): - self._oprot.writeMessageBegin('GetCatalogs', TMessageType.CALL, self._seqid) - args = GetCatalogs_args() - args.req = req - args.write(self._oprot) - self._oprot.writeMessageEnd() - self._oprot.trans.flush() - - def recv_GetCatalogs(self): - iprot = self._iprot - (fname, mtype, rseqid) = iprot.readMessageBegin() - if mtype == TMessageType.EXCEPTION: - x = TApplicationException() - x.read(iprot) - iprot.readMessageEnd() - raise x - result = GetCatalogs_result() - result.read(iprot) - iprot.readMessageEnd() - if result.success is not None: - return result.success - raise TApplicationException(TApplicationException.MISSING_RESULT, "GetCatalogs failed: unknown result") - - def GetSchemas(self, req): - """ - Parameters: - - req - """ - self.send_GetSchemas(req) - return self.recv_GetSchemas() - - def send_GetSchemas(self, req): - self._oprot.writeMessageBegin('GetSchemas', TMessageType.CALL, self._seqid) - args = GetSchemas_args() - args.req = req - args.write(self._oprot) - self._oprot.writeMessageEnd() - self._oprot.trans.flush() - - def recv_GetSchemas(self): - iprot = self._iprot - (fname, mtype, rseqid) = iprot.readMessageBegin() - if mtype == TMessageType.EXCEPTION: - x = TApplicationException() - x.read(iprot) - iprot.readMessageEnd() - raise x - result = GetSchemas_result() - result.read(iprot) - iprot.readMessageEnd() - if result.success is not None: - return result.success - raise TApplicationException(TApplicationException.MISSING_RESULT, "GetSchemas failed: unknown result") - - def GetTables(self, req): - """ - Parameters: - - req - """ - self.send_GetTables(req) - return self.recv_GetTables() - - def send_GetTables(self, req): - self._oprot.writeMessageBegin('GetTables', TMessageType.CALL, self._seqid) - args = GetTables_args() - args.req = req - args.write(self._oprot) - self._oprot.writeMessageEnd() - self._oprot.trans.flush() - - def recv_GetTables(self): - iprot = self._iprot - (fname, mtype, rseqid) = iprot.readMessageBegin() - if mtype == TMessageType.EXCEPTION: - x = TApplicationException() - x.read(iprot) - iprot.readMessageEnd() - raise x - result = GetTables_result() - result.read(iprot) - iprot.readMessageEnd() - if result.success is not None: - return result.success - raise TApplicationException(TApplicationException.MISSING_RESULT, "GetTables failed: unknown result") - - def GetTableTypes(self, req): - """ - Parameters: - - req - """ - self.send_GetTableTypes(req) - return self.recv_GetTableTypes() - - def send_GetTableTypes(self, req): - self._oprot.writeMessageBegin('GetTableTypes', TMessageType.CALL, self._seqid) - args = GetTableTypes_args() - args.req = req - args.write(self._oprot) - self._oprot.writeMessageEnd() - self._oprot.trans.flush() - - def recv_GetTableTypes(self): - iprot = self._iprot - (fname, mtype, rseqid) = iprot.readMessageBegin() - if mtype == TMessageType.EXCEPTION: - x = TApplicationException() - x.read(iprot) - iprot.readMessageEnd() - raise x - result = GetTableTypes_result() - result.read(iprot) - iprot.readMessageEnd() - if result.success is not None: - return result.success - raise TApplicationException(TApplicationException.MISSING_RESULT, "GetTableTypes failed: unknown result") - - def GetColumns(self, req): - """ - Parameters: - - req - """ - self.send_GetColumns(req) - return self.recv_GetColumns() - - def send_GetColumns(self, req): - self._oprot.writeMessageBegin('GetColumns', TMessageType.CALL, self._seqid) - args = GetColumns_args() - args.req = req - args.write(self._oprot) - self._oprot.writeMessageEnd() - self._oprot.trans.flush() - - def recv_GetColumns(self): - iprot = self._iprot - (fname, mtype, rseqid) = iprot.readMessageBegin() - if mtype == TMessageType.EXCEPTION: - x = TApplicationException() - x.read(iprot) - iprot.readMessageEnd() - raise x - result = GetColumns_result() - result.read(iprot) - iprot.readMessageEnd() - if result.success is not None: - return result.success - raise TApplicationException(TApplicationException.MISSING_RESULT, "GetColumns failed: unknown result") - - def GetFunctions(self, req): - """ - Parameters: - - req - """ - self.send_GetFunctions(req) - return self.recv_GetFunctions() - - def send_GetFunctions(self, req): - self._oprot.writeMessageBegin('GetFunctions', TMessageType.CALL, self._seqid) - args = GetFunctions_args() - args.req = req - args.write(self._oprot) - self._oprot.writeMessageEnd() - self._oprot.trans.flush() - - def recv_GetFunctions(self): - iprot = self._iprot - (fname, mtype, rseqid) = iprot.readMessageBegin() - if mtype == TMessageType.EXCEPTION: - x = TApplicationException() - x.read(iprot) - iprot.readMessageEnd() - raise x - result = GetFunctions_result() - result.read(iprot) - iprot.readMessageEnd() - if result.success is not None: - return result.success - raise TApplicationException(TApplicationException.MISSING_RESULT, "GetFunctions failed: unknown result") - - def GetPrimaryKeys(self, req): - """ - Parameters: - - req - """ - self.send_GetPrimaryKeys(req) - return self.recv_GetPrimaryKeys() - - def send_GetPrimaryKeys(self, req): - self._oprot.writeMessageBegin('GetPrimaryKeys', TMessageType.CALL, self._seqid) - args = GetPrimaryKeys_args() - args.req = req - args.write(self._oprot) - self._oprot.writeMessageEnd() - self._oprot.trans.flush() - - def recv_GetPrimaryKeys(self): - iprot = self._iprot - (fname, mtype, rseqid) = iprot.readMessageBegin() - if mtype == TMessageType.EXCEPTION: - x = TApplicationException() - x.read(iprot) - iprot.readMessageEnd() - raise x - result = GetPrimaryKeys_result() - result.read(iprot) - iprot.readMessageEnd() - if result.success is not None: - return result.success - raise TApplicationException(TApplicationException.MISSING_RESULT, "GetPrimaryKeys failed: unknown result") - - def GetCrossReference(self, req): - """ - Parameters: - - req - """ - self.send_GetCrossReference(req) - return self.recv_GetCrossReference() - - def send_GetCrossReference(self, req): - self._oprot.writeMessageBegin('GetCrossReference', TMessageType.CALL, self._seqid) - args = GetCrossReference_args() - args.req = req - args.write(self._oprot) - self._oprot.writeMessageEnd() - self._oprot.trans.flush() - - def recv_GetCrossReference(self): - iprot = self._iprot - (fname, mtype, rseqid) = iprot.readMessageBegin() - if mtype == TMessageType.EXCEPTION: - x = TApplicationException() - x.read(iprot) - iprot.readMessageEnd() - raise x - result = GetCrossReference_result() - result.read(iprot) - iprot.readMessageEnd() - if result.success is not None: - return result.success - raise TApplicationException(TApplicationException.MISSING_RESULT, "GetCrossReference failed: unknown result") - - def GetOperationStatus(self, req): - """ - Parameters: - - req - """ - self.send_GetOperationStatus(req) - return self.recv_GetOperationStatus() - - def send_GetOperationStatus(self, req): - self._oprot.writeMessageBegin('GetOperationStatus', TMessageType.CALL, self._seqid) - args = GetOperationStatus_args() - args.req = req - args.write(self._oprot) - self._oprot.writeMessageEnd() - self._oprot.trans.flush() - - def recv_GetOperationStatus(self): - iprot = self._iprot - (fname, mtype, rseqid) = iprot.readMessageBegin() - if mtype == TMessageType.EXCEPTION: - x = TApplicationException() - x.read(iprot) - iprot.readMessageEnd() - raise x - result = GetOperationStatus_result() - result.read(iprot) - iprot.readMessageEnd() - if result.success is not None: - return result.success - raise TApplicationException(TApplicationException.MISSING_RESULT, "GetOperationStatus failed: unknown result") - - def CancelOperation(self, req): - """ - Parameters: - - req - """ - self.send_CancelOperation(req) - return self.recv_CancelOperation() - - def send_CancelOperation(self, req): - self._oprot.writeMessageBegin('CancelOperation', TMessageType.CALL, self._seqid) - args = CancelOperation_args() - args.req = req - args.write(self._oprot) - self._oprot.writeMessageEnd() - self._oprot.trans.flush() - - def recv_CancelOperation(self): - iprot = self._iprot - (fname, mtype, rseqid) = iprot.readMessageBegin() - if mtype == TMessageType.EXCEPTION: - x = TApplicationException() - x.read(iprot) - iprot.readMessageEnd() - raise x - result = CancelOperation_result() - result.read(iprot) - iprot.readMessageEnd() - if result.success is not None: - return result.success - raise TApplicationException(TApplicationException.MISSING_RESULT, "CancelOperation failed: unknown result") - - def CloseOperation(self, req): - """ - Parameters: - - req - """ - self.send_CloseOperation(req) - return self.recv_CloseOperation() - - def send_CloseOperation(self, req): - self._oprot.writeMessageBegin('CloseOperation', TMessageType.CALL, self._seqid) - args = CloseOperation_args() - args.req = req - args.write(self._oprot) - self._oprot.writeMessageEnd() - self._oprot.trans.flush() - - def recv_CloseOperation(self): - iprot = self._iprot - (fname, mtype, rseqid) = iprot.readMessageBegin() - if mtype == TMessageType.EXCEPTION: - x = TApplicationException() - x.read(iprot) - iprot.readMessageEnd() - raise x - result = CloseOperation_result() - result.read(iprot) - iprot.readMessageEnd() - if result.success is not None: - return result.success - raise TApplicationException(TApplicationException.MISSING_RESULT, "CloseOperation failed: unknown result") - - def GetResultSetMetadata(self, req): - """ - Parameters: - - req - """ - self.send_GetResultSetMetadata(req) - return self.recv_GetResultSetMetadata() - - def send_GetResultSetMetadata(self, req): - self._oprot.writeMessageBegin('GetResultSetMetadata', TMessageType.CALL, self._seqid) - args = GetResultSetMetadata_args() - args.req = req - args.write(self._oprot) - self._oprot.writeMessageEnd() - self._oprot.trans.flush() - - def recv_GetResultSetMetadata(self): - iprot = self._iprot - (fname, mtype, rseqid) = iprot.readMessageBegin() - if mtype == TMessageType.EXCEPTION: - x = TApplicationException() - x.read(iprot) - iprot.readMessageEnd() - raise x - result = GetResultSetMetadata_result() - result.read(iprot) - iprot.readMessageEnd() - if result.success is not None: - return result.success - raise TApplicationException(TApplicationException.MISSING_RESULT, "GetResultSetMetadata failed: unknown result") - - def FetchResults(self, req): - """ - Parameters: - - req - """ - self.send_FetchResults(req) - return self.recv_FetchResults() - - def send_FetchResults(self, req): - self._oprot.writeMessageBegin('FetchResults', TMessageType.CALL, self._seqid) - args = FetchResults_args() - args.req = req - args.write(self._oprot) - self._oprot.writeMessageEnd() - self._oprot.trans.flush() - - def recv_FetchResults(self): - iprot = self._iprot - (fname, mtype, rseqid) = iprot.readMessageBegin() - if mtype == TMessageType.EXCEPTION: - x = TApplicationException() - x.read(iprot) - iprot.readMessageEnd() - raise x - result = FetchResults_result() - result.read(iprot) - iprot.readMessageEnd() - if result.success is not None: - return result.success - raise TApplicationException(TApplicationException.MISSING_RESULT, "FetchResults failed: unknown result") - - def GetDelegationToken(self, req): - """ - Parameters: - - req - """ - self.send_GetDelegationToken(req) - return self.recv_GetDelegationToken() - - def send_GetDelegationToken(self, req): - self._oprot.writeMessageBegin('GetDelegationToken', TMessageType.CALL, self._seqid) - args = GetDelegationToken_args() - args.req = req - args.write(self._oprot) - self._oprot.writeMessageEnd() - self._oprot.trans.flush() - - def recv_GetDelegationToken(self): - iprot = self._iprot - (fname, mtype, rseqid) = iprot.readMessageBegin() - if mtype == TMessageType.EXCEPTION: - x = TApplicationException() - x.read(iprot) - iprot.readMessageEnd() - raise x - result = GetDelegationToken_result() - result.read(iprot) - iprot.readMessageEnd() - if result.success is not None: - return result.success - raise TApplicationException(TApplicationException.MISSING_RESULT, "GetDelegationToken failed: unknown result") - - def CancelDelegationToken(self, req): - """ - Parameters: - - req - """ - self.send_CancelDelegationToken(req) - return self.recv_CancelDelegationToken() - - def send_CancelDelegationToken(self, req): - self._oprot.writeMessageBegin('CancelDelegationToken', TMessageType.CALL, self._seqid) - args = CancelDelegationToken_args() - args.req = req - args.write(self._oprot) - self._oprot.writeMessageEnd() - self._oprot.trans.flush() - - def recv_CancelDelegationToken(self): - iprot = self._iprot - (fname, mtype, rseqid) = iprot.readMessageBegin() - if mtype == TMessageType.EXCEPTION: - x = TApplicationException() - x.read(iprot) - iprot.readMessageEnd() - raise x - result = CancelDelegationToken_result() - result.read(iprot) - iprot.readMessageEnd() - if result.success is not None: - return result.success - raise TApplicationException(TApplicationException.MISSING_RESULT, "CancelDelegationToken failed: unknown result") - - def RenewDelegationToken(self, req): - """ - Parameters: - - req - """ - self.send_RenewDelegationToken(req) - return self.recv_RenewDelegationToken() - - def send_RenewDelegationToken(self, req): - self._oprot.writeMessageBegin('RenewDelegationToken', TMessageType.CALL, self._seqid) - args = RenewDelegationToken_args() - args.req = req - args.write(self._oprot) - self._oprot.writeMessageEnd() - self._oprot.trans.flush() - - def recv_RenewDelegationToken(self): - iprot = self._iprot - (fname, mtype, rseqid) = iprot.readMessageBegin() - if mtype == TMessageType.EXCEPTION: - x = TApplicationException() - x.read(iprot) - iprot.readMessageEnd() - raise x - result = RenewDelegationToken_result() - result.read(iprot) - iprot.readMessageEnd() - if result.success is not None: - return result.success - raise TApplicationException(TApplicationException.MISSING_RESULT, "RenewDelegationToken failed: unknown result") - - def GetLog(self, req): - """ - Parameters: - - req - """ - self.send_GetLog(req) - return self.recv_GetLog() - - def send_GetLog(self, req): - self._oprot.writeMessageBegin('GetLog', TMessageType.CALL, self._seqid) - args = GetLog_args() - args.req = req - args.write(self._oprot) - self._oprot.writeMessageEnd() - self._oprot.trans.flush() - - def recv_GetLog(self): - iprot = self._iprot - (fname, mtype, rseqid) = iprot.readMessageBegin() - if mtype == TMessageType.EXCEPTION: - x = TApplicationException() - x.read(iprot) - iprot.readMessageEnd() - raise x - result = GetLog_result() - result.read(iprot) - iprot.readMessageEnd() - if result.success is not None: - return result.success - raise TApplicationException(TApplicationException.MISSING_RESULT, "GetLog failed: unknown result") - - -class Processor(Iface, TProcessor): - def __init__(self, handler): - self._handler = handler - self._processMap = {} - self._processMap["OpenSession"] = Processor.process_OpenSession - self._processMap["CloseSession"] = Processor.process_CloseSession - self._processMap["GetInfo"] = Processor.process_GetInfo - self._processMap["ExecuteStatement"] = Processor.process_ExecuteStatement - self._processMap["GetTypeInfo"] = Processor.process_GetTypeInfo - self._processMap["GetCatalogs"] = Processor.process_GetCatalogs - self._processMap["GetSchemas"] = Processor.process_GetSchemas - self._processMap["GetTables"] = Processor.process_GetTables - self._processMap["GetTableTypes"] = Processor.process_GetTableTypes - self._processMap["GetColumns"] = Processor.process_GetColumns - self._processMap["GetFunctions"] = Processor.process_GetFunctions - self._processMap["GetPrimaryKeys"] = Processor.process_GetPrimaryKeys - self._processMap["GetCrossReference"] = Processor.process_GetCrossReference - self._processMap["GetOperationStatus"] = Processor.process_GetOperationStatus - self._processMap["CancelOperation"] = Processor.process_CancelOperation - self._processMap["CloseOperation"] = Processor.process_CloseOperation - self._processMap["GetResultSetMetadata"] = Processor.process_GetResultSetMetadata - self._processMap["FetchResults"] = Processor.process_FetchResults - self._processMap["GetDelegationToken"] = Processor.process_GetDelegationToken - self._processMap["CancelDelegationToken"] = Processor.process_CancelDelegationToken - self._processMap["RenewDelegationToken"] = Processor.process_RenewDelegationToken - self._processMap["GetLog"] = Processor.process_GetLog - - def process(self, iprot, oprot): - (name, type, seqid) = iprot.readMessageBegin() - if name not in self._processMap: - iprot.skip(TType.STRUCT) - iprot.readMessageEnd() - x = TApplicationException(TApplicationException.UNKNOWN_METHOD, 'Unknown function %s' % (name)) - oprot.writeMessageBegin(name, TMessageType.EXCEPTION, seqid) - x.write(oprot) - oprot.writeMessageEnd() - oprot.trans.flush() - return - else: - self._processMap[name](self, seqid, iprot, oprot) - return True - - def process_OpenSession(self, seqid, iprot, oprot): - args = OpenSession_args() - args.read(iprot) - iprot.readMessageEnd() - result = OpenSession_result() - try: - result.success = self._handler.OpenSession(args.req) - msg_type = TMessageType.REPLY - except (TTransport.TTransportException, KeyboardInterrupt, SystemExit): - raise - except Exception as ex: - msg_type = TMessageType.EXCEPTION - logging.exception(ex) - result = TApplicationException(TApplicationException.INTERNAL_ERROR, 'Internal error') - oprot.writeMessageBegin("OpenSession", msg_type, seqid) - result.write(oprot) - oprot.writeMessageEnd() - oprot.trans.flush() - - def process_CloseSession(self, seqid, iprot, oprot): - args = CloseSession_args() - args.read(iprot) - iprot.readMessageEnd() - result = CloseSession_result() - try: - result.success = self._handler.CloseSession(args.req) - msg_type = TMessageType.REPLY - except (TTransport.TTransportException, KeyboardInterrupt, SystemExit): - raise - except Exception as ex: - msg_type = TMessageType.EXCEPTION - logging.exception(ex) - result = TApplicationException(TApplicationException.INTERNAL_ERROR, 'Internal error') - oprot.writeMessageBegin("CloseSession", msg_type, seqid) - result.write(oprot) - oprot.writeMessageEnd() - oprot.trans.flush() - - def process_GetInfo(self, seqid, iprot, oprot): - args = GetInfo_args() - args.read(iprot) - iprot.readMessageEnd() - result = GetInfo_result() - try: - result.success = self._handler.GetInfo(args.req) - msg_type = TMessageType.REPLY - except (TTransport.TTransportException, KeyboardInterrupt, SystemExit): - raise - except Exception as ex: - msg_type = TMessageType.EXCEPTION - logging.exception(ex) - result = TApplicationException(TApplicationException.INTERNAL_ERROR, 'Internal error') - oprot.writeMessageBegin("GetInfo", msg_type, seqid) - result.write(oprot) - oprot.writeMessageEnd() - oprot.trans.flush() - - def process_ExecuteStatement(self, seqid, iprot, oprot): - args = ExecuteStatement_args() - args.read(iprot) - iprot.readMessageEnd() - result = ExecuteStatement_result() - try: - result.success = self._handler.ExecuteStatement(args.req) - msg_type = TMessageType.REPLY - except (TTransport.TTransportException, KeyboardInterrupt, SystemExit): - raise - except Exception as ex: - msg_type = TMessageType.EXCEPTION - logging.exception(ex) - result = TApplicationException(TApplicationException.INTERNAL_ERROR, 'Internal error') - oprot.writeMessageBegin("ExecuteStatement", msg_type, seqid) - result.write(oprot) - oprot.writeMessageEnd() - oprot.trans.flush() - - def process_GetTypeInfo(self, seqid, iprot, oprot): - args = GetTypeInfo_args() - args.read(iprot) - iprot.readMessageEnd() - result = GetTypeInfo_result() - try: - result.success = self._handler.GetTypeInfo(args.req) - msg_type = TMessageType.REPLY - except (TTransport.TTransportException, KeyboardInterrupt, SystemExit): - raise - except Exception as ex: - msg_type = TMessageType.EXCEPTION - logging.exception(ex) - result = TApplicationException(TApplicationException.INTERNAL_ERROR, 'Internal error') - oprot.writeMessageBegin("GetTypeInfo", msg_type, seqid) - result.write(oprot) - oprot.writeMessageEnd() - oprot.trans.flush() - - def process_GetCatalogs(self, seqid, iprot, oprot): - args = GetCatalogs_args() - args.read(iprot) - iprot.readMessageEnd() - result = GetCatalogs_result() - try: - result.success = self._handler.GetCatalogs(args.req) - msg_type = TMessageType.REPLY - except (TTransport.TTransportException, KeyboardInterrupt, SystemExit): - raise - except Exception as ex: - msg_type = TMessageType.EXCEPTION - logging.exception(ex) - result = TApplicationException(TApplicationException.INTERNAL_ERROR, 'Internal error') - oprot.writeMessageBegin("GetCatalogs", msg_type, seqid) - result.write(oprot) - oprot.writeMessageEnd() - oprot.trans.flush() - - def process_GetSchemas(self, seqid, iprot, oprot): - args = GetSchemas_args() - args.read(iprot) - iprot.readMessageEnd() - result = GetSchemas_result() - try: - result.success = self._handler.GetSchemas(args.req) - msg_type = TMessageType.REPLY - except (TTransport.TTransportException, KeyboardInterrupt, SystemExit): - raise - except Exception as ex: - msg_type = TMessageType.EXCEPTION - logging.exception(ex) - result = TApplicationException(TApplicationException.INTERNAL_ERROR, 'Internal error') - oprot.writeMessageBegin("GetSchemas", msg_type, seqid) - result.write(oprot) - oprot.writeMessageEnd() - oprot.trans.flush() - - def process_GetTables(self, seqid, iprot, oprot): - args = GetTables_args() - args.read(iprot) - iprot.readMessageEnd() - result = GetTables_result() - try: - result.success = self._handler.GetTables(args.req) - msg_type = TMessageType.REPLY - except (TTransport.TTransportException, KeyboardInterrupt, SystemExit): - raise - except Exception as ex: - msg_type = TMessageType.EXCEPTION - logging.exception(ex) - result = TApplicationException(TApplicationException.INTERNAL_ERROR, 'Internal error') - oprot.writeMessageBegin("GetTables", msg_type, seqid) - result.write(oprot) - oprot.writeMessageEnd() - oprot.trans.flush() - - def process_GetTableTypes(self, seqid, iprot, oprot): - args = GetTableTypes_args() - args.read(iprot) - iprot.readMessageEnd() - result = GetTableTypes_result() - try: - result.success = self._handler.GetTableTypes(args.req) - msg_type = TMessageType.REPLY - except (TTransport.TTransportException, KeyboardInterrupt, SystemExit): - raise - except Exception as ex: - msg_type = TMessageType.EXCEPTION - logging.exception(ex) - result = TApplicationException(TApplicationException.INTERNAL_ERROR, 'Internal error') - oprot.writeMessageBegin("GetTableTypes", msg_type, seqid) - result.write(oprot) - oprot.writeMessageEnd() - oprot.trans.flush() - - def process_GetColumns(self, seqid, iprot, oprot): - args = GetColumns_args() - args.read(iprot) - iprot.readMessageEnd() - result = GetColumns_result() - try: - result.success = self._handler.GetColumns(args.req) - msg_type = TMessageType.REPLY - except (TTransport.TTransportException, KeyboardInterrupt, SystemExit): - raise - except Exception as ex: - msg_type = TMessageType.EXCEPTION - logging.exception(ex) - result = TApplicationException(TApplicationException.INTERNAL_ERROR, 'Internal error') - oprot.writeMessageBegin("GetColumns", msg_type, seqid) - result.write(oprot) - oprot.writeMessageEnd() - oprot.trans.flush() - - def process_GetFunctions(self, seqid, iprot, oprot): - args = GetFunctions_args() - args.read(iprot) - iprot.readMessageEnd() - result = GetFunctions_result() - try: - result.success = self._handler.GetFunctions(args.req) - msg_type = TMessageType.REPLY - except (TTransport.TTransportException, KeyboardInterrupt, SystemExit): - raise - except Exception as ex: - msg_type = TMessageType.EXCEPTION - logging.exception(ex) - result = TApplicationException(TApplicationException.INTERNAL_ERROR, 'Internal error') - oprot.writeMessageBegin("GetFunctions", msg_type, seqid) - result.write(oprot) - oprot.writeMessageEnd() - oprot.trans.flush() - - def process_GetPrimaryKeys(self, seqid, iprot, oprot): - args = GetPrimaryKeys_args() - args.read(iprot) - iprot.readMessageEnd() - result = GetPrimaryKeys_result() - try: - result.success = self._handler.GetPrimaryKeys(args.req) - msg_type = TMessageType.REPLY - except (TTransport.TTransportException, KeyboardInterrupt, SystemExit): - raise - except Exception as ex: - msg_type = TMessageType.EXCEPTION - logging.exception(ex) - result = TApplicationException(TApplicationException.INTERNAL_ERROR, 'Internal error') - oprot.writeMessageBegin("GetPrimaryKeys", msg_type, seqid) - result.write(oprot) - oprot.writeMessageEnd() - oprot.trans.flush() - - def process_GetCrossReference(self, seqid, iprot, oprot): - args = GetCrossReference_args() - args.read(iprot) - iprot.readMessageEnd() - result = GetCrossReference_result() - try: - result.success = self._handler.GetCrossReference(args.req) - msg_type = TMessageType.REPLY - except (TTransport.TTransportException, KeyboardInterrupt, SystemExit): - raise - except Exception as ex: - msg_type = TMessageType.EXCEPTION - logging.exception(ex) - result = TApplicationException(TApplicationException.INTERNAL_ERROR, 'Internal error') - oprot.writeMessageBegin("GetCrossReference", msg_type, seqid) - result.write(oprot) - oprot.writeMessageEnd() - oprot.trans.flush() - - def process_GetOperationStatus(self, seqid, iprot, oprot): - args = GetOperationStatus_args() - args.read(iprot) - iprot.readMessageEnd() - result = GetOperationStatus_result() - try: - result.success = self._handler.GetOperationStatus(args.req) - msg_type = TMessageType.REPLY - except (TTransport.TTransportException, KeyboardInterrupt, SystemExit): - raise - except Exception as ex: - msg_type = TMessageType.EXCEPTION - logging.exception(ex) - result = TApplicationException(TApplicationException.INTERNAL_ERROR, 'Internal error') - oprot.writeMessageBegin("GetOperationStatus", msg_type, seqid) - result.write(oprot) - oprot.writeMessageEnd() - oprot.trans.flush() - - def process_CancelOperation(self, seqid, iprot, oprot): - args = CancelOperation_args() - args.read(iprot) - iprot.readMessageEnd() - result = CancelOperation_result() - try: - result.success = self._handler.CancelOperation(args.req) - msg_type = TMessageType.REPLY - except (TTransport.TTransportException, KeyboardInterrupt, SystemExit): - raise - except Exception as ex: - msg_type = TMessageType.EXCEPTION - logging.exception(ex) - result = TApplicationException(TApplicationException.INTERNAL_ERROR, 'Internal error') - oprot.writeMessageBegin("CancelOperation", msg_type, seqid) - result.write(oprot) - oprot.writeMessageEnd() - oprot.trans.flush() - - def process_CloseOperation(self, seqid, iprot, oprot): - args = CloseOperation_args() - args.read(iprot) - iprot.readMessageEnd() - result = CloseOperation_result() - try: - result.success = self._handler.CloseOperation(args.req) - msg_type = TMessageType.REPLY - except (TTransport.TTransportException, KeyboardInterrupt, SystemExit): - raise - except Exception as ex: - msg_type = TMessageType.EXCEPTION - logging.exception(ex) - result = TApplicationException(TApplicationException.INTERNAL_ERROR, 'Internal error') - oprot.writeMessageBegin("CloseOperation", msg_type, seqid) - result.write(oprot) - oprot.writeMessageEnd() - oprot.trans.flush() - - def process_GetResultSetMetadata(self, seqid, iprot, oprot): - args = GetResultSetMetadata_args() - args.read(iprot) - iprot.readMessageEnd() - result = GetResultSetMetadata_result() - try: - result.success = self._handler.GetResultSetMetadata(args.req) - msg_type = TMessageType.REPLY - except (TTransport.TTransportException, KeyboardInterrupt, SystemExit): - raise - except Exception as ex: - msg_type = TMessageType.EXCEPTION - logging.exception(ex) - result = TApplicationException(TApplicationException.INTERNAL_ERROR, 'Internal error') - oprot.writeMessageBegin("GetResultSetMetadata", msg_type, seqid) - result.write(oprot) - oprot.writeMessageEnd() - oprot.trans.flush() - - def process_FetchResults(self, seqid, iprot, oprot): - args = FetchResults_args() - args.read(iprot) - iprot.readMessageEnd() - result = FetchResults_result() - try: - result.success = self._handler.FetchResults(args.req) - msg_type = TMessageType.REPLY - except (TTransport.TTransportException, KeyboardInterrupt, SystemExit): - raise - except Exception as ex: - msg_type = TMessageType.EXCEPTION - logging.exception(ex) - result = TApplicationException(TApplicationException.INTERNAL_ERROR, 'Internal error') - oprot.writeMessageBegin("FetchResults", msg_type, seqid) - result.write(oprot) - oprot.writeMessageEnd() - oprot.trans.flush() - - def process_GetDelegationToken(self, seqid, iprot, oprot): - args = GetDelegationToken_args() - args.read(iprot) - iprot.readMessageEnd() - result = GetDelegationToken_result() - try: - result.success = self._handler.GetDelegationToken(args.req) - msg_type = TMessageType.REPLY - except (TTransport.TTransportException, KeyboardInterrupt, SystemExit): - raise - except Exception as ex: - msg_type = TMessageType.EXCEPTION - logging.exception(ex) - result = TApplicationException(TApplicationException.INTERNAL_ERROR, 'Internal error') - oprot.writeMessageBegin("GetDelegationToken", msg_type, seqid) - result.write(oprot) - oprot.writeMessageEnd() - oprot.trans.flush() - - def process_CancelDelegationToken(self, seqid, iprot, oprot): - args = CancelDelegationToken_args() - args.read(iprot) - iprot.readMessageEnd() - result = CancelDelegationToken_result() - try: - result.success = self._handler.CancelDelegationToken(args.req) - msg_type = TMessageType.REPLY - except (TTransport.TTransportException, KeyboardInterrupt, SystemExit): - raise - except Exception as ex: - msg_type = TMessageType.EXCEPTION - logging.exception(ex) - result = TApplicationException(TApplicationException.INTERNAL_ERROR, 'Internal error') - oprot.writeMessageBegin("CancelDelegationToken", msg_type, seqid) - result.write(oprot) - oprot.writeMessageEnd() - oprot.trans.flush() - - def process_RenewDelegationToken(self, seqid, iprot, oprot): - args = RenewDelegationToken_args() - args.read(iprot) - iprot.readMessageEnd() - result = RenewDelegationToken_result() - try: - result.success = self._handler.RenewDelegationToken(args.req) - msg_type = TMessageType.REPLY - except (TTransport.TTransportException, KeyboardInterrupt, SystemExit): - raise - except Exception as ex: - msg_type = TMessageType.EXCEPTION - logging.exception(ex) - result = TApplicationException(TApplicationException.INTERNAL_ERROR, 'Internal error') - oprot.writeMessageBegin("RenewDelegationToken", msg_type, seqid) - result.write(oprot) - oprot.writeMessageEnd() - oprot.trans.flush() - - def process_GetLog(self, seqid, iprot, oprot): - args = GetLog_args() - args.read(iprot) - iprot.readMessageEnd() - result = GetLog_result() - try: - result.success = self._handler.GetLog(args.req) - msg_type = TMessageType.REPLY - except (TTransport.TTransportException, KeyboardInterrupt, SystemExit): - raise - except Exception as ex: - msg_type = TMessageType.EXCEPTION - logging.exception(ex) - result = TApplicationException(TApplicationException.INTERNAL_ERROR, 'Internal error') - oprot.writeMessageBegin("GetLog", msg_type, seqid) - result.write(oprot) - oprot.writeMessageEnd() - oprot.trans.flush() - -# HELPER FUNCTIONS AND STRUCTURES - - -class OpenSession_args(object): - """ - Attributes: - - req - """ - - thrift_spec = ( - None, # 0 - (1, TType.STRUCT, 'req', (TOpenSessionReq, TOpenSessionReq.thrift_spec), None, ), # 1 - ) - - def __init__(self, req=None,): - self.req = req - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 1: - if ftype == TType.STRUCT: - self.req = TOpenSessionReq() - self.req.read(iprot) - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('OpenSession_args') - if self.req is not None: - oprot.writeFieldBegin('req', TType.STRUCT, 1) - self.req.write(oprot) - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) - - -class OpenSession_result(object): - """ - Attributes: - - success - """ - - thrift_spec = ( - (0, TType.STRUCT, 'success', (TOpenSessionResp, TOpenSessionResp.thrift_spec), None, ), # 0 - ) - - def __init__(self, success=None,): - self.success = success - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 0: - if ftype == TType.STRUCT: - self.success = TOpenSessionResp() - self.success.read(iprot) - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('OpenSession_result') - if self.success is not None: - oprot.writeFieldBegin('success', TType.STRUCT, 0) - self.success.write(oprot) - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) - - -class CloseSession_args(object): - """ - Attributes: - - req - """ - - thrift_spec = ( - None, # 0 - (1, TType.STRUCT, 'req', (TCloseSessionReq, TCloseSessionReq.thrift_spec), None, ), # 1 - ) - - def __init__(self, req=None,): - self.req = req - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 1: - if ftype == TType.STRUCT: - self.req = TCloseSessionReq() - self.req.read(iprot) - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('CloseSession_args') - if self.req is not None: - oprot.writeFieldBegin('req', TType.STRUCT, 1) - self.req.write(oprot) - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) - - -class CloseSession_result(object): - """ - Attributes: - - success - """ - - thrift_spec = ( - (0, TType.STRUCT, 'success', (TCloseSessionResp, TCloseSessionResp.thrift_spec), None, ), # 0 - ) - - def __init__(self, success=None,): - self.success = success - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 0: - if ftype == TType.STRUCT: - self.success = TCloseSessionResp() - self.success.read(iprot) - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('CloseSession_result') - if self.success is not None: - oprot.writeFieldBegin('success', TType.STRUCT, 0) - self.success.write(oprot) - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) - - -class GetInfo_args(object): - """ - Attributes: - - req - """ - - thrift_spec = ( - None, # 0 - (1, TType.STRUCT, 'req', (TGetInfoReq, TGetInfoReq.thrift_spec), None, ), # 1 - ) - - def __init__(self, req=None,): - self.req = req - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 1: - if ftype == TType.STRUCT: - self.req = TGetInfoReq() - self.req.read(iprot) - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('GetInfo_args') - if self.req is not None: - oprot.writeFieldBegin('req', TType.STRUCT, 1) - self.req.write(oprot) - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) - - -class GetInfo_result(object): - """ - Attributes: - - success - """ - - thrift_spec = ( - (0, TType.STRUCT, 'success', (TGetInfoResp, TGetInfoResp.thrift_spec), None, ), # 0 - ) - - def __init__(self, success=None,): - self.success = success - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 0: - if ftype == TType.STRUCT: - self.success = TGetInfoResp() - self.success.read(iprot) - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('GetInfo_result') - if self.success is not None: - oprot.writeFieldBegin('success', TType.STRUCT, 0) - self.success.write(oprot) - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) - - -class ExecuteStatement_args(object): - """ - Attributes: - - req - """ - - thrift_spec = ( - None, # 0 - (1, TType.STRUCT, 'req', (TExecuteStatementReq, TExecuteStatementReq.thrift_spec), None, ), # 1 - ) - - def __init__(self, req=None,): - self.req = req - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 1: - if ftype == TType.STRUCT: - self.req = TExecuteStatementReq() - self.req.read(iprot) - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('ExecuteStatement_args') - if self.req is not None: - oprot.writeFieldBegin('req', TType.STRUCT, 1) - self.req.write(oprot) - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) - - -class ExecuteStatement_result(object): - """ - Attributes: - - success - """ - - thrift_spec = ( - (0, TType.STRUCT, 'success', (TExecuteStatementResp, TExecuteStatementResp.thrift_spec), None, ), # 0 - ) - - def __init__(self, success=None,): - self.success = success - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 0: - if ftype == TType.STRUCT: - self.success = TExecuteStatementResp() - self.success.read(iprot) - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('ExecuteStatement_result') - if self.success is not None: - oprot.writeFieldBegin('success', TType.STRUCT, 0) - self.success.write(oprot) - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) - - -class GetTypeInfo_args(object): - """ - Attributes: - - req - """ - - thrift_spec = ( - None, # 0 - (1, TType.STRUCT, 'req', (TGetTypeInfoReq, TGetTypeInfoReq.thrift_spec), None, ), # 1 - ) - - def __init__(self, req=None,): - self.req = req - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 1: - if ftype == TType.STRUCT: - self.req = TGetTypeInfoReq() - self.req.read(iprot) - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('GetTypeInfo_args') - if self.req is not None: - oprot.writeFieldBegin('req', TType.STRUCT, 1) - self.req.write(oprot) - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) - - -class GetTypeInfo_result(object): - """ - Attributes: - - success - """ - - thrift_spec = ( - (0, TType.STRUCT, 'success', (TGetTypeInfoResp, TGetTypeInfoResp.thrift_spec), None, ), # 0 - ) - - def __init__(self, success=None,): - self.success = success - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 0: - if ftype == TType.STRUCT: - self.success = TGetTypeInfoResp() - self.success.read(iprot) - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('GetTypeInfo_result') - if self.success is not None: - oprot.writeFieldBegin('success', TType.STRUCT, 0) - self.success.write(oprot) - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) - - -class GetCatalogs_args(object): - """ - Attributes: - - req - """ - - thrift_spec = ( - None, # 0 - (1, TType.STRUCT, 'req', (TGetCatalogsReq, TGetCatalogsReq.thrift_spec), None, ), # 1 - ) - - def __init__(self, req=None,): - self.req = req - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 1: - if ftype == TType.STRUCT: - self.req = TGetCatalogsReq() - self.req.read(iprot) - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('GetCatalogs_args') - if self.req is not None: - oprot.writeFieldBegin('req', TType.STRUCT, 1) - self.req.write(oprot) - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) - - -class GetCatalogs_result(object): - """ - Attributes: - - success - """ - - thrift_spec = ( - (0, TType.STRUCT, 'success', (TGetCatalogsResp, TGetCatalogsResp.thrift_spec), None, ), # 0 - ) - - def __init__(self, success=None,): - self.success = success - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 0: - if ftype == TType.STRUCT: - self.success = TGetCatalogsResp() - self.success.read(iprot) - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('GetCatalogs_result') - if self.success is not None: - oprot.writeFieldBegin('success', TType.STRUCT, 0) - self.success.write(oprot) - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) - - -class GetSchemas_args(object): - """ - Attributes: - - req - """ - - thrift_spec = ( - None, # 0 - (1, TType.STRUCT, 'req', (TGetSchemasReq, TGetSchemasReq.thrift_spec), None, ), # 1 - ) - - def __init__(self, req=None,): - self.req = req - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 1: - if ftype == TType.STRUCT: - self.req = TGetSchemasReq() - self.req.read(iprot) - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('GetSchemas_args') - if self.req is not None: - oprot.writeFieldBegin('req', TType.STRUCT, 1) - self.req.write(oprot) - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) - - -class GetSchemas_result(object): - """ - Attributes: - - success - """ - - thrift_spec = ( - (0, TType.STRUCT, 'success', (TGetSchemasResp, TGetSchemasResp.thrift_spec), None, ), # 0 - ) - - def __init__(self, success=None,): - self.success = success - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 0: - if ftype == TType.STRUCT: - self.success = TGetSchemasResp() - self.success.read(iprot) - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('GetSchemas_result') - if self.success is not None: - oprot.writeFieldBegin('success', TType.STRUCT, 0) - self.success.write(oprot) - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) - - -class GetTables_args(object): - """ - Attributes: - - req - """ - - thrift_spec = ( - None, # 0 - (1, TType.STRUCT, 'req', (TGetTablesReq, TGetTablesReq.thrift_spec), None, ), # 1 - ) - - def __init__(self, req=None,): - self.req = req - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 1: - if ftype == TType.STRUCT: - self.req = TGetTablesReq() - self.req.read(iprot) - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('GetTables_args') - if self.req is not None: - oprot.writeFieldBegin('req', TType.STRUCT, 1) - self.req.write(oprot) - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) - - -class GetTables_result(object): - """ - Attributes: - - success - """ - - thrift_spec = ( - (0, TType.STRUCT, 'success', (TGetTablesResp, TGetTablesResp.thrift_spec), None, ), # 0 - ) - - def __init__(self, success=None,): - self.success = success - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 0: - if ftype == TType.STRUCT: - self.success = TGetTablesResp() - self.success.read(iprot) - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('GetTables_result') - if self.success is not None: - oprot.writeFieldBegin('success', TType.STRUCT, 0) - self.success.write(oprot) - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) - - -class GetTableTypes_args(object): - """ - Attributes: - - req - """ - - thrift_spec = ( - None, # 0 - (1, TType.STRUCT, 'req', (TGetTableTypesReq, TGetTableTypesReq.thrift_spec), None, ), # 1 - ) - - def __init__(self, req=None,): - self.req = req - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 1: - if ftype == TType.STRUCT: - self.req = TGetTableTypesReq() - self.req.read(iprot) - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('GetTableTypes_args') - if self.req is not None: - oprot.writeFieldBegin('req', TType.STRUCT, 1) - self.req.write(oprot) - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) - - -class GetTableTypes_result(object): - """ - Attributes: - - success - """ - - thrift_spec = ( - (0, TType.STRUCT, 'success', (TGetTableTypesResp, TGetTableTypesResp.thrift_spec), None, ), # 0 - ) - - def __init__(self, success=None,): - self.success = success - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 0: - if ftype == TType.STRUCT: - self.success = TGetTableTypesResp() - self.success.read(iprot) - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('GetTableTypes_result') - if self.success is not None: - oprot.writeFieldBegin('success', TType.STRUCT, 0) - self.success.write(oprot) - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) - - -class GetColumns_args(object): - """ - Attributes: - - req - """ - - thrift_spec = ( - None, # 0 - (1, TType.STRUCT, 'req', (TGetColumnsReq, TGetColumnsReq.thrift_spec), None, ), # 1 - ) - - def __init__(self, req=None,): - self.req = req - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 1: - if ftype == TType.STRUCT: - self.req = TGetColumnsReq() - self.req.read(iprot) - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('GetColumns_args') - if self.req is not None: - oprot.writeFieldBegin('req', TType.STRUCT, 1) - self.req.write(oprot) - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) - - -class GetColumns_result(object): - """ - Attributes: - - success - """ - - thrift_spec = ( - (0, TType.STRUCT, 'success', (TGetColumnsResp, TGetColumnsResp.thrift_spec), None, ), # 0 - ) - - def __init__(self, success=None,): - self.success = success - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 0: - if ftype == TType.STRUCT: - self.success = TGetColumnsResp() - self.success.read(iprot) - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('GetColumns_result') - if self.success is not None: - oprot.writeFieldBegin('success', TType.STRUCT, 0) - self.success.write(oprot) - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) - - -class GetFunctions_args(object): - """ - Attributes: - - req - """ - - thrift_spec = ( - None, # 0 - (1, TType.STRUCT, 'req', (TGetFunctionsReq, TGetFunctionsReq.thrift_spec), None, ), # 1 - ) - - def __init__(self, req=None,): - self.req = req - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 1: - if ftype == TType.STRUCT: - self.req = TGetFunctionsReq() - self.req.read(iprot) - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('GetFunctions_args') - if self.req is not None: - oprot.writeFieldBegin('req', TType.STRUCT, 1) - self.req.write(oprot) - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) - - -class GetFunctions_result(object): - """ - Attributes: - - success - """ - - thrift_spec = ( - (0, TType.STRUCT, 'success', (TGetFunctionsResp, TGetFunctionsResp.thrift_spec), None, ), # 0 - ) - - def __init__(self, success=None,): - self.success = success - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 0: - if ftype == TType.STRUCT: - self.success = TGetFunctionsResp() - self.success.read(iprot) - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('GetFunctions_result') - if self.success is not None: - oprot.writeFieldBegin('success', TType.STRUCT, 0) - self.success.write(oprot) - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) - - -class GetPrimaryKeys_args(object): - """ - Attributes: - - req - """ - - thrift_spec = ( - None, # 0 - (1, TType.STRUCT, 'req', (TGetPrimaryKeysReq, TGetPrimaryKeysReq.thrift_spec), None, ), # 1 - ) - - def __init__(self, req=None,): - self.req = req - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 1: - if ftype == TType.STRUCT: - self.req = TGetPrimaryKeysReq() - self.req.read(iprot) - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('GetPrimaryKeys_args') - if self.req is not None: - oprot.writeFieldBegin('req', TType.STRUCT, 1) - self.req.write(oprot) - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) - - -class GetPrimaryKeys_result(object): - """ - Attributes: - - success - """ - - thrift_spec = ( - (0, TType.STRUCT, 'success', (TGetPrimaryKeysResp, TGetPrimaryKeysResp.thrift_spec), None, ), # 0 - ) - - def __init__(self, success=None,): - self.success = success - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 0: - if ftype == TType.STRUCT: - self.success = TGetPrimaryKeysResp() - self.success.read(iprot) - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('GetPrimaryKeys_result') - if self.success is not None: - oprot.writeFieldBegin('success', TType.STRUCT, 0) - self.success.write(oprot) - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) - - -class GetCrossReference_args(object): - """ - Attributes: - - req - """ - - thrift_spec = ( - None, # 0 - (1, TType.STRUCT, 'req', (TGetCrossReferenceReq, TGetCrossReferenceReq.thrift_spec), None, ), # 1 - ) - - def __init__(self, req=None,): - self.req = req - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 1: - if ftype == TType.STRUCT: - self.req = TGetCrossReferenceReq() - self.req.read(iprot) - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('GetCrossReference_args') - if self.req is not None: - oprot.writeFieldBegin('req', TType.STRUCT, 1) - self.req.write(oprot) - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) - - -class GetCrossReference_result(object): - """ - Attributes: - - success - """ - - thrift_spec = ( - (0, TType.STRUCT, 'success', (TGetCrossReferenceResp, TGetCrossReferenceResp.thrift_spec), None, ), # 0 - ) - - def __init__(self, success=None,): - self.success = success - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 0: - if ftype == TType.STRUCT: - self.success = TGetCrossReferenceResp() - self.success.read(iprot) - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('GetCrossReference_result') - if self.success is not None: - oprot.writeFieldBegin('success', TType.STRUCT, 0) - self.success.write(oprot) - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) - - -class GetOperationStatus_args(object): - """ - Attributes: - - req - """ - - thrift_spec = ( - None, # 0 - (1, TType.STRUCT, 'req', (TGetOperationStatusReq, TGetOperationStatusReq.thrift_spec), None, ), # 1 - ) - - def __init__(self, req=None,): - self.req = req - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 1: - if ftype == TType.STRUCT: - self.req = TGetOperationStatusReq() - self.req.read(iprot) - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('GetOperationStatus_args') - if self.req is not None: - oprot.writeFieldBegin('req', TType.STRUCT, 1) - self.req.write(oprot) - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) - - -class GetOperationStatus_result(object): - """ - Attributes: - - success - """ - - thrift_spec = ( - (0, TType.STRUCT, 'success', (TGetOperationStatusResp, TGetOperationStatusResp.thrift_spec), None, ), # 0 - ) - - def __init__(self, success=None,): - self.success = success - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 0: - if ftype == TType.STRUCT: - self.success = TGetOperationStatusResp() - self.success.read(iprot) - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('GetOperationStatus_result') - if self.success is not None: - oprot.writeFieldBegin('success', TType.STRUCT, 0) - self.success.write(oprot) - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) - - -class CancelOperation_args(object): - """ - Attributes: - - req - """ - - thrift_spec = ( - None, # 0 - (1, TType.STRUCT, 'req', (TCancelOperationReq, TCancelOperationReq.thrift_spec), None, ), # 1 - ) - - def __init__(self, req=None,): - self.req = req - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 1: - if ftype == TType.STRUCT: - self.req = TCancelOperationReq() - self.req.read(iprot) - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('CancelOperation_args') - if self.req is not None: - oprot.writeFieldBegin('req', TType.STRUCT, 1) - self.req.write(oprot) - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) - - -class CancelOperation_result(object): - """ - Attributes: - - success - """ - - thrift_spec = ( - (0, TType.STRUCT, 'success', (TCancelOperationResp, TCancelOperationResp.thrift_spec), None, ), # 0 - ) - - def __init__(self, success=None,): - self.success = success - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 0: - if ftype == TType.STRUCT: - self.success = TCancelOperationResp() - self.success.read(iprot) - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('CancelOperation_result') - if self.success is not None: - oprot.writeFieldBegin('success', TType.STRUCT, 0) - self.success.write(oprot) - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) - - -class CloseOperation_args(object): - """ - Attributes: - - req - """ - - thrift_spec = ( - None, # 0 - (1, TType.STRUCT, 'req', (TCloseOperationReq, TCloseOperationReq.thrift_spec), None, ), # 1 - ) - - def __init__(self, req=None,): - self.req = req - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 1: - if ftype == TType.STRUCT: - self.req = TCloseOperationReq() - self.req.read(iprot) - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('CloseOperation_args') - if self.req is not None: - oprot.writeFieldBegin('req', TType.STRUCT, 1) - self.req.write(oprot) - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) - - -class CloseOperation_result(object): - """ - Attributes: - - success - """ - - thrift_spec = ( - (0, TType.STRUCT, 'success', (TCloseOperationResp, TCloseOperationResp.thrift_spec), None, ), # 0 - ) - - def __init__(self, success=None,): - self.success = success - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 0: - if ftype == TType.STRUCT: - self.success = TCloseOperationResp() - self.success.read(iprot) - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('CloseOperation_result') - if self.success is not None: - oprot.writeFieldBegin('success', TType.STRUCT, 0) - self.success.write(oprot) - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) - - -class GetResultSetMetadata_args(object): - """ - Attributes: - - req - """ - - thrift_spec = ( - None, # 0 - (1, TType.STRUCT, 'req', (TGetResultSetMetadataReq, TGetResultSetMetadataReq.thrift_spec), None, ), # 1 - ) - - def __init__(self, req=None,): - self.req = req - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 1: - if ftype == TType.STRUCT: - self.req = TGetResultSetMetadataReq() - self.req.read(iprot) - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('GetResultSetMetadata_args') - if self.req is not None: - oprot.writeFieldBegin('req', TType.STRUCT, 1) - self.req.write(oprot) - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) - - -class GetResultSetMetadata_result(object): - """ - Attributes: - - success - """ - - thrift_spec = ( - (0, TType.STRUCT, 'success', (TGetResultSetMetadataResp, TGetResultSetMetadataResp.thrift_spec), None, ), # 0 - ) - - def __init__(self, success=None,): - self.success = success - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 0: - if ftype == TType.STRUCT: - self.success = TGetResultSetMetadataResp() - self.success.read(iprot) - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('GetResultSetMetadata_result') - if self.success is not None: - oprot.writeFieldBegin('success', TType.STRUCT, 0) - self.success.write(oprot) - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) - - -class FetchResults_args(object): - """ - Attributes: - - req - """ - - thrift_spec = ( - None, # 0 - (1, TType.STRUCT, 'req', (TFetchResultsReq, TFetchResultsReq.thrift_spec), None, ), # 1 - ) - - def __init__(self, req=None,): - self.req = req - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 1: - if ftype == TType.STRUCT: - self.req = TFetchResultsReq() - self.req.read(iprot) - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('FetchResults_args') - if self.req is not None: - oprot.writeFieldBegin('req', TType.STRUCT, 1) - self.req.write(oprot) - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) - - -class FetchResults_result(object): - """ - Attributes: - - success - """ - - thrift_spec = ( - (0, TType.STRUCT, 'success', (TFetchResultsResp, TFetchResultsResp.thrift_spec), None, ), # 0 - ) - - def __init__(self, success=None,): - self.success = success - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 0: - if ftype == TType.STRUCT: - self.success = TFetchResultsResp() - self.success.read(iprot) - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('FetchResults_result') - if self.success is not None: - oprot.writeFieldBegin('success', TType.STRUCT, 0) - self.success.write(oprot) - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) - - -class GetDelegationToken_args(object): - """ - Attributes: - - req - """ - - thrift_spec = ( - None, # 0 - (1, TType.STRUCT, 'req', (TGetDelegationTokenReq, TGetDelegationTokenReq.thrift_spec), None, ), # 1 - ) - - def __init__(self, req=None,): - self.req = req - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 1: - if ftype == TType.STRUCT: - self.req = TGetDelegationTokenReq() - self.req.read(iprot) - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('GetDelegationToken_args') - if self.req is not None: - oprot.writeFieldBegin('req', TType.STRUCT, 1) - self.req.write(oprot) - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) - - -class GetDelegationToken_result(object): - """ - Attributes: - - success - """ - - thrift_spec = ( - (0, TType.STRUCT, 'success', (TGetDelegationTokenResp, TGetDelegationTokenResp.thrift_spec), None, ), # 0 - ) - - def __init__(self, success=None,): - self.success = success - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 0: - if ftype == TType.STRUCT: - self.success = TGetDelegationTokenResp() - self.success.read(iprot) - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('GetDelegationToken_result') - if self.success is not None: - oprot.writeFieldBegin('success', TType.STRUCT, 0) - self.success.write(oprot) - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) - - -class CancelDelegationToken_args(object): - """ - Attributes: - - req - """ - - thrift_spec = ( - None, # 0 - (1, TType.STRUCT, 'req', (TCancelDelegationTokenReq, TCancelDelegationTokenReq.thrift_spec), None, ), # 1 - ) - - def __init__(self, req=None,): - self.req = req - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 1: - if ftype == TType.STRUCT: - self.req = TCancelDelegationTokenReq() - self.req.read(iprot) - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('CancelDelegationToken_args') - if self.req is not None: - oprot.writeFieldBegin('req', TType.STRUCT, 1) - self.req.write(oprot) - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) - - -class CancelDelegationToken_result(object): - """ - Attributes: - - success - """ - - thrift_spec = ( - (0, TType.STRUCT, 'success', (TCancelDelegationTokenResp, TCancelDelegationTokenResp.thrift_spec), None, ), # 0 - ) - - def __init__(self, success=None,): - self.success = success - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 0: - if ftype == TType.STRUCT: - self.success = TCancelDelegationTokenResp() - self.success.read(iprot) - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('CancelDelegationToken_result') - if self.success is not None: - oprot.writeFieldBegin('success', TType.STRUCT, 0) - self.success.write(oprot) - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) - - -class RenewDelegationToken_args(object): - """ - Attributes: - - req - """ - - thrift_spec = ( - None, # 0 - (1, TType.STRUCT, 'req', (TRenewDelegationTokenReq, TRenewDelegationTokenReq.thrift_spec), None, ), # 1 - ) - - def __init__(self, req=None,): - self.req = req - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 1: - if ftype == TType.STRUCT: - self.req = TRenewDelegationTokenReq() - self.req.read(iprot) - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('RenewDelegationToken_args') - if self.req is not None: - oprot.writeFieldBegin('req', TType.STRUCT, 1) - self.req.write(oprot) - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) - - -class RenewDelegationToken_result(object): - """ - Attributes: - - success - """ - - thrift_spec = ( - (0, TType.STRUCT, 'success', (TRenewDelegationTokenResp, TRenewDelegationTokenResp.thrift_spec), None, ), # 0 - ) - - def __init__(self, success=None,): - self.success = success - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 0: - if ftype == TType.STRUCT: - self.success = TRenewDelegationTokenResp() - self.success.read(iprot) - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('RenewDelegationToken_result') - if self.success is not None: - oprot.writeFieldBegin('success', TType.STRUCT, 0) - self.success.write(oprot) - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) - - -class GetLog_args(object): - """ - Attributes: - - req - """ - - thrift_spec = ( - None, # 0 - (1, TType.STRUCT, 'req', (TGetLogReq, TGetLogReq.thrift_spec), None, ), # 1 - ) - - def __init__(self, req=None,): - self.req = req - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 1: - if ftype == TType.STRUCT: - self.req = TGetLogReq() - self.req.read(iprot) - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('GetLog_args') - if self.req is not None: - oprot.writeFieldBegin('req', TType.STRUCT, 1) - self.req.write(oprot) - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) - - -class GetLog_result(object): - """ - Attributes: - - success - """ - - thrift_spec = ( - (0, TType.STRUCT, 'success', (TGetLogResp, TGetLogResp.thrift_spec), None, ), # 0 - ) - - def __init__(self, success=None,): - self.success = success - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 0: - if ftype == TType.STRUCT: - self.success = TGetLogResp() - self.success.read(iprot) - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('GetLog_result') - if self.success is not None: - oprot.writeFieldBegin('success', TType.STRUCT, 0) - self.success.write(oprot) - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) diff --git a/src/chronify/_vendor/kyuubi/TCLIService/__init__.py b/src/chronify/_vendor/kyuubi/TCLIService/__init__.py deleted file mode 100644 index fe5e7c2..0000000 --- a/src/chronify/_vendor/kyuubi/TCLIService/__init__.py +++ /dev/null @@ -1 +0,0 @@ -__all__ = ['ttypes', 'constants', 'TCLIService'] diff --git a/src/chronify/_vendor/kyuubi/TCLIService/constants.py b/src/chronify/_vendor/kyuubi/TCLIService/constants.py deleted file mode 100644 index b6bf88e..0000000 --- a/src/chronify/_vendor/kyuubi/TCLIService/constants.py +++ /dev/null @@ -1,68 +0,0 @@ -# -# Autogenerated by Thrift Compiler (0.10.0) -# -# DO NOT EDIT UNLESS YOU ARE SURE THAT YOU KNOW WHAT YOU ARE DOING -# -# options string: py -# - -from thrift.Thrift import TType, TMessageType, TFrozenDict, TException, TApplicationException -from thrift.protocol.TProtocol import TProtocolException -import sys -from .ttypes import * -PRIMITIVE_TYPES = set(( - 0, - 1, - 2, - 3, - 4, - 5, - 6, - 7, - 8, - 9, - 15, - 16, - 17, - 18, - 19, - 20, - 21, -)) -COMPLEX_TYPES = set(( - 10, - 11, - 12, - 13, - 14, -)) -COLLECTION_TYPES = set(( - 10, - 11, -)) -TYPE_NAMES = { - 0: "BOOLEAN", - 1: "TINYINT", - 2: "SMALLINT", - 3: "INT", - 4: "BIGINT", - 5: "FLOAT", - 6: "DOUBLE", - 7: "STRING", - 8: "TIMESTAMP", - 9: "BINARY", - 10: "ARRAY", - 11: "MAP", - 12: "STRUCT", - 13: "UNIONTYPE", - 15: "DECIMAL", - 16: "NULL", - 17: "DATE", - 18: "VARCHAR", - 19: "CHAR", - 20: "INTERVAL_YEAR_MONTH", - 21: "INTERVAL_DAY_TIME", -} -CHARACTER_MAXIMUM_LENGTH = "characterMaximumLength" -PRECISION = "precision" -SCALE = "scale" diff --git a/src/chronify/_vendor/kyuubi/TCLIService/ttypes.py b/src/chronify/_vendor/kyuubi/TCLIService/ttypes.py deleted file mode 100644 index 573bd04..0000000 --- a/src/chronify/_vendor/kyuubi/TCLIService/ttypes.py +++ /dev/null @@ -1,7210 +0,0 @@ -# -# Autogenerated by Thrift Compiler (0.10.0) -# -# DO NOT EDIT UNLESS YOU ARE SURE THAT YOU KNOW WHAT YOU ARE DOING -# -# options string: py -# - -from thrift.Thrift import TType, TMessageType, TFrozenDict, TException, TApplicationException -from thrift.protocol.TProtocol import TProtocolException -import sys - -from thrift.transport import TTransport - - -class TProtocolVersion(object): - HIVE_CLI_SERVICE_PROTOCOL_V1 = 0 - HIVE_CLI_SERVICE_PROTOCOL_V2 = 1 - HIVE_CLI_SERVICE_PROTOCOL_V3 = 2 - HIVE_CLI_SERVICE_PROTOCOL_V4 = 3 - HIVE_CLI_SERVICE_PROTOCOL_V5 = 4 - HIVE_CLI_SERVICE_PROTOCOL_V6 = 5 - HIVE_CLI_SERVICE_PROTOCOL_V7 = 6 - HIVE_CLI_SERVICE_PROTOCOL_V8 = 7 - HIVE_CLI_SERVICE_PROTOCOL_V9 = 8 - HIVE_CLI_SERVICE_PROTOCOL_V10 = 9 - - _VALUES_TO_NAMES = { - 0: "HIVE_CLI_SERVICE_PROTOCOL_V1", - 1: "HIVE_CLI_SERVICE_PROTOCOL_V2", - 2: "HIVE_CLI_SERVICE_PROTOCOL_V3", - 3: "HIVE_CLI_SERVICE_PROTOCOL_V4", - 4: "HIVE_CLI_SERVICE_PROTOCOL_V5", - 5: "HIVE_CLI_SERVICE_PROTOCOL_V6", - 6: "HIVE_CLI_SERVICE_PROTOCOL_V7", - 7: "HIVE_CLI_SERVICE_PROTOCOL_V8", - 8: "HIVE_CLI_SERVICE_PROTOCOL_V9", - 9: "HIVE_CLI_SERVICE_PROTOCOL_V10", - } - - _NAMES_TO_VALUES = { - "HIVE_CLI_SERVICE_PROTOCOL_V1": 0, - "HIVE_CLI_SERVICE_PROTOCOL_V2": 1, - "HIVE_CLI_SERVICE_PROTOCOL_V3": 2, - "HIVE_CLI_SERVICE_PROTOCOL_V4": 3, - "HIVE_CLI_SERVICE_PROTOCOL_V5": 4, - "HIVE_CLI_SERVICE_PROTOCOL_V6": 5, - "HIVE_CLI_SERVICE_PROTOCOL_V7": 6, - "HIVE_CLI_SERVICE_PROTOCOL_V8": 7, - "HIVE_CLI_SERVICE_PROTOCOL_V9": 8, - "HIVE_CLI_SERVICE_PROTOCOL_V10": 9, - } - - -class TTypeId(object): - BOOLEAN_TYPE = 0 - TINYINT_TYPE = 1 - SMALLINT_TYPE = 2 - INT_TYPE = 3 - BIGINT_TYPE = 4 - FLOAT_TYPE = 5 - DOUBLE_TYPE = 6 - STRING_TYPE = 7 - TIMESTAMP_TYPE = 8 - BINARY_TYPE = 9 - ARRAY_TYPE = 10 - MAP_TYPE = 11 - STRUCT_TYPE = 12 - UNION_TYPE = 13 - USER_DEFINED_TYPE = 14 - DECIMAL_TYPE = 15 - NULL_TYPE = 16 - DATE_TYPE = 17 - VARCHAR_TYPE = 18 - CHAR_TYPE = 19 - INTERVAL_YEAR_MONTH_TYPE = 20 - INTERVAL_DAY_TIME_TYPE = 21 - - _VALUES_TO_NAMES = { - 0: "BOOLEAN_TYPE", - 1: "TINYINT_TYPE", - 2: "SMALLINT_TYPE", - 3: "INT_TYPE", - 4: "BIGINT_TYPE", - 5: "FLOAT_TYPE", - 6: "DOUBLE_TYPE", - 7: "STRING_TYPE", - 8: "TIMESTAMP_TYPE", - 9: "BINARY_TYPE", - 10: "ARRAY_TYPE", - 11: "MAP_TYPE", - 12: "STRUCT_TYPE", - 13: "UNION_TYPE", - 14: "USER_DEFINED_TYPE", - 15: "DECIMAL_TYPE", - 16: "NULL_TYPE", - 17: "DATE_TYPE", - 18: "VARCHAR_TYPE", - 19: "CHAR_TYPE", - 20: "INTERVAL_YEAR_MONTH_TYPE", - 21: "INTERVAL_DAY_TIME_TYPE", - } - - _NAMES_TO_VALUES = { - "BOOLEAN_TYPE": 0, - "TINYINT_TYPE": 1, - "SMALLINT_TYPE": 2, - "INT_TYPE": 3, - "BIGINT_TYPE": 4, - "FLOAT_TYPE": 5, - "DOUBLE_TYPE": 6, - "STRING_TYPE": 7, - "TIMESTAMP_TYPE": 8, - "BINARY_TYPE": 9, - "ARRAY_TYPE": 10, - "MAP_TYPE": 11, - "STRUCT_TYPE": 12, - "UNION_TYPE": 13, - "USER_DEFINED_TYPE": 14, - "DECIMAL_TYPE": 15, - "NULL_TYPE": 16, - "DATE_TYPE": 17, - "VARCHAR_TYPE": 18, - "CHAR_TYPE": 19, - "INTERVAL_YEAR_MONTH_TYPE": 20, - "INTERVAL_DAY_TIME_TYPE": 21, - } - - -class TStatusCode(object): - SUCCESS_STATUS = 0 - SUCCESS_WITH_INFO_STATUS = 1 - STILL_EXECUTING_STATUS = 2 - ERROR_STATUS = 3 - INVALID_HANDLE_STATUS = 4 - - _VALUES_TO_NAMES = { - 0: "SUCCESS_STATUS", - 1: "SUCCESS_WITH_INFO_STATUS", - 2: "STILL_EXECUTING_STATUS", - 3: "ERROR_STATUS", - 4: "INVALID_HANDLE_STATUS", - } - - _NAMES_TO_VALUES = { - "SUCCESS_STATUS": 0, - "SUCCESS_WITH_INFO_STATUS": 1, - "STILL_EXECUTING_STATUS": 2, - "ERROR_STATUS": 3, - "INVALID_HANDLE_STATUS": 4, - } - - -class TOperationState(object): - INITIALIZED_STATE = 0 - RUNNING_STATE = 1 - FINISHED_STATE = 2 - CANCELED_STATE = 3 - CLOSED_STATE = 4 - ERROR_STATE = 5 - UKNOWN_STATE = 6 - PENDING_STATE = 7 - TIMEDOUT_STATE = 8 - - _VALUES_TO_NAMES = { - 0: "INITIALIZED_STATE", - 1: "RUNNING_STATE", - 2: "FINISHED_STATE", - 3: "CANCELED_STATE", - 4: "CLOSED_STATE", - 5: "ERROR_STATE", - 6: "UKNOWN_STATE", - 7: "PENDING_STATE", - 8: "TIMEDOUT_STATE", - } - - _NAMES_TO_VALUES = { - "INITIALIZED_STATE": 0, - "RUNNING_STATE": 1, - "FINISHED_STATE": 2, - "CANCELED_STATE": 3, - "CLOSED_STATE": 4, - "ERROR_STATE": 5, - "UKNOWN_STATE": 6, - "PENDING_STATE": 7, - "TIMEDOUT_STATE": 8, - } - - -class TOperationType(object): - EXECUTE_STATEMENT = 0 - GET_TYPE_INFO = 1 - GET_CATALOGS = 2 - GET_SCHEMAS = 3 - GET_TABLES = 4 - GET_TABLE_TYPES = 5 - GET_COLUMNS = 6 - GET_FUNCTIONS = 7 - UNKNOWN = 8 - - _VALUES_TO_NAMES = { - 0: "EXECUTE_STATEMENT", - 1: "GET_TYPE_INFO", - 2: "GET_CATALOGS", - 3: "GET_SCHEMAS", - 4: "GET_TABLES", - 5: "GET_TABLE_TYPES", - 6: "GET_COLUMNS", - 7: "GET_FUNCTIONS", - 8: "UNKNOWN", - } - - _NAMES_TO_VALUES = { - "EXECUTE_STATEMENT": 0, - "GET_TYPE_INFO": 1, - "GET_CATALOGS": 2, - "GET_SCHEMAS": 3, - "GET_TABLES": 4, - "GET_TABLE_TYPES": 5, - "GET_COLUMNS": 6, - "GET_FUNCTIONS": 7, - "UNKNOWN": 8, - } - - -class TGetInfoType(object): - CLI_MAX_DRIVER_CONNECTIONS = 0 - CLI_MAX_CONCURRENT_ACTIVITIES = 1 - CLI_DATA_SOURCE_NAME = 2 - CLI_FETCH_DIRECTION = 8 - CLI_SERVER_NAME = 13 - CLI_SEARCH_PATTERN_ESCAPE = 14 - CLI_DBMS_NAME = 17 - CLI_DBMS_VER = 18 - CLI_ACCESSIBLE_TABLES = 19 - CLI_ACCESSIBLE_PROCEDURES = 20 - CLI_CURSOR_COMMIT_BEHAVIOR = 23 - CLI_DATA_SOURCE_READ_ONLY = 25 - CLI_DEFAULT_TXN_ISOLATION = 26 - CLI_IDENTIFIER_CASE = 28 - CLI_IDENTIFIER_QUOTE_CHAR = 29 - CLI_MAX_COLUMN_NAME_LEN = 30 - CLI_MAX_CURSOR_NAME_LEN = 31 - CLI_MAX_SCHEMA_NAME_LEN = 32 - CLI_MAX_CATALOG_NAME_LEN = 34 - CLI_MAX_TABLE_NAME_LEN = 35 - CLI_SCROLL_CONCURRENCY = 43 - CLI_TXN_CAPABLE = 46 - CLI_USER_NAME = 47 - CLI_TXN_ISOLATION_OPTION = 72 - CLI_INTEGRITY = 73 - CLI_GETDATA_EXTENSIONS = 81 - CLI_NULL_COLLATION = 85 - CLI_ALTER_TABLE = 86 - CLI_ORDER_BY_COLUMNS_IN_SELECT = 90 - CLI_SPECIAL_CHARACTERS = 94 - CLI_MAX_COLUMNS_IN_GROUP_BY = 97 - CLI_MAX_COLUMNS_IN_INDEX = 98 - CLI_MAX_COLUMNS_IN_ORDER_BY = 99 - CLI_MAX_COLUMNS_IN_SELECT = 100 - CLI_MAX_COLUMNS_IN_TABLE = 101 - CLI_MAX_INDEX_SIZE = 102 - CLI_MAX_ROW_SIZE = 104 - CLI_MAX_STATEMENT_LEN = 105 - CLI_MAX_TABLES_IN_SELECT = 106 - CLI_MAX_USER_NAME_LEN = 107 - CLI_OJ_CAPABILITIES = 115 - CLI_XOPEN_CLI_YEAR = 10000 - CLI_CURSOR_SENSITIVITY = 10001 - CLI_DESCRIBE_PARAMETER = 10002 - CLI_CATALOG_NAME = 10003 - CLI_COLLATION_SEQ = 10004 - CLI_MAX_IDENTIFIER_LEN = 10005 - - _VALUES_TO_NAMES = { - 0: "CLI_MAX_DRIVER_CONNECTIONS", - 1: "CLI_MAX_CONCURRENT_ACTIVITIES", - 2: "CLI_DATA_SOURCE_NAME", - 8: "CLI_FETCH_DIRECTION", - 13: "CLI_SERVER_NAME", - 14: "CLI_SEARCH_PATTERN_ESCAPE", - 17: "CLI_DBMS_NAME", - 18: "CLI_DBMS_VER", - 19: "CLI_ACCESSIBLE_TABLES", - 20: "CLI_ACCESSIBLE_PROCEDURES", - 23: "CLI_CURSOR_COMMIT_BEHAVIOR", - 25: "CLI_DATA_SOURCE_READ_ONLY", - 26: "CLI_DEFAULT_TXN_ISOLATION", - 28: "CLI_IDENTIFIER_CASE", - 29: "CLI_IDENTIFIER_QUOTE_CHAR", - 30: "CLI_MAX_COLUMN_NAME_LEN", - 31: "CLI_MAX_CURSOR_NAME_LEN", - 32: "CLI_MAX_SCHEMA_NAME_LEN", - 34: "CLI_MAX_CATALOG_NAME_LEN", - 35: "CLI_MAX_TABLE_NAME_LEN", - 43: "CLI_SCROLL_CONCURRENCY", - 46: "CLI_TXN_CAPABLE", - 47: "CLI_USER_NAME", - 72: "CLI_TXN_ISOLATION_OPTION", - 73: "CLI_INTEGRITY", - 81: "CLI_GETDATA_EXTENSIONS", - 85: "CLI_NULL_COLLATION", - 86: "CLI_ALTER_TABLE", - 90: "CLI_ORDER_BY_COLUMNS_IN_SELECT", - 94: "CLI_SPECIAL_CHARACTERS", - 97: "CLI_MAX_COLUMNS_IN_GROUP_BY", - 98: "CLI_MAX_COLUMNS_IN_INDEX", - 99: "CLI_MAX_COLUMNS_IN_ORDER_BY", - 100: "CLI_MAX_COLUMNS_IN_SELECT", - 101: "CLI_MAX_COLUMNS_IN_TABLE", - 102: "CLI_MAX_INDEX_SIZE", - 104: "CLI_MAX_ROW_SIZE", - 105: "CLI_MAX_STATEMENT_LEN", - 106: "CLI_MAX_TABLES_IN_SELECT", - 107: "CLI_MAX_USER_NAME_LEN", - 115: "CLI_OJ_CAPABILITIES", - 10000: "CLI_XOPEN_CLI_YEAR", - 10001: "CLI_CURSOR_SENSITIVITY", - 10002: "CLI_DESCRIBE_PARAMETER", - 10003: "CLI_CATALOG_NAME", - 10004: "CLI_COLLATION_SEQ", - 10005: "CLI_MAX_IDENTIFIER_LEN", - } - - _NAMES_TO_VALUES = { - "CLI_MAX_DRIVER_CONNECTIONS": 0, - "CLI_MAX_CONCURRENT_ACTIVITIES": 1, - "CLI_DATA_SOURCE_NAME": 2, - "CLI_FETCH_DIRECTION": 8, - "CLI_SERVER_NAME": 13, - "CLI_SEARCH_PATTERN_ESCAPE": 14, - "CLI_DBMS_NAME": 17, - "CLI_DBMS_VER": 18, - "CLI_ACCESSIBLE_TABLES": 19, - "CLI_ACCESSIBLE_PROCEDURES": 20, - "CLI_CURSOR_COMMIT_BEHAVIOR": 23, - "CLI_DATA_SOURCE_READ_ONLY": 25, - "CLI_DEFAULT_TXN_ISOLATION": 26, - "CLI_IDENTIFIER_CASE": 28, - "CLI_IDENTIFIER_QUOTE_CHAR": 29, - "CLI_MAX_COLUMN_NAME_LEN": 30, - "CLI_MAX_CURSOR_NAME_LEN": 31, - "CLI_MAX_SCHEMA_NAME_LEN": 32, - "CLI_MAX_CATALOG_NAME_LEN": 34, - "CLI_MAX_TABLE_NAME_LEN": 35, - "CLI_SCROLL_CONCURRENCY": 43, - "CLI_TXN_CAPABLE": 46, - "CLI_USER_NAME": 47, - "CLI_TXN_ISOLATION_OPTION": 72, - "CLI_INTEGRITY": 73, - "CLI_GETDATA_EXTENSIONS": 81, - "CLI_NULL_COLLATION": 85, - "CLI_ALTER_TABLE": 86, - "CLI_ORDER_BY_COLUMNS_IN_SELECT": 90, - "CLI_SPECIAL_CHARACTERS": 94, - "CLI_MAX_COLUMNS_IN_GROUP_BY": 97, - "CLI_MAX_COLUMNS_IN_INDEX": 98, - "CLI_MAX_COLUMNS_IN_ORDER_BY": 99, - "CLI_MAX_COLUMNS_IN_SELECT": 100, - "CLI_MAX_COLUMNS_IN_TABLE": 101, - "CLI_MAX_INDEX_SIZE": 102, - "CLI_MAX_ROW_SIZE": 104, - "CLI_MAX_STATEMENT_LEN": 105, - "CLI_MAX_TABLES_IN_SELECT": 106, - "CLI_MAX_USER_NAME_LEN": 107, - "CLI_OJ_CAPABILITIES": 115, - "CLI_XOPEN_CLI_YEAR": 10000, - "CLI_CURSOR_SENSITIVITY": 10001, - "CLI_DESCRIBE_PARAMETER": 10002, - "CLI_CATALOG_NAME": 10003, - "CLI_COLLATION_SEQ": 10004, - "CLI_MAX_IDENTIFIER_LEN": 10005, - } - - -class TFetchOrientation(object): - FETCH_NEXT = 0 - FETCH_PRIOR = 1 - FETCH_RELATIVE = 2 - FETCH_ABSOLUTE = 3 - FETCH_FIRST = 4 - FETCH_LAST = 5 - - _VALUES_TO_NAMES = { - 0: "FETCH_NEXT", - 1: "FETCH_PRIOR", - 2: "FETCH_RELATIVE", - 3: "FETCH_ABSOLUTE", - 4: "FETCH_FIRST", - 5: "FETCH_LAST", - } - - _NAMES_TO_VALUES = { - "FETCH_NEXT": 0, - "FETCH_PRIOR": 1, - "FETCH_RELATIVE": 2, - "FETCH_ABSOLUTE": 3, - "FETCH_FIRST": 4, - "FETCH_LAST": 5, - } - - -class TJobExecutionStatus(object): - IN_PROGRESS = 0 - COMPLETE = 1 - NOT_AVAILABLE = 2 - - _VALUES_TO_NAMES = { - 0: "IN_PROGRESS", - 1: "COMPLETE", - 2: "NOT_AVAILABLE", - } - - _NAMES_TO_VALUES = { - "IN_PROGRESS": 0, - "COMPLETE": 1, - "NOT_AVAILABLE": 2, - } - - -class TTypeQualifierValue(object): - """ - Attributes: - - i32Value - - stringValue - """ - - thrift_spec = ( - None, # 0 - (1, TType.I32, 'i32Value', None, None, ), # 1 - (2, TType.STRING, 'stringValue', 'UTF8', None, ), # 2 - ) - - def __init__(self, i32Value=None, stringValue=None,): - self.i32Value = i32Value - self.stringValue = stringValue - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 1: - if ftype == TType.I32: - self.i32Value = iprot.readI32() - else: - iprot.skip(ftype) - elif fid == 2: - if ftype == TType.STRING: - self.stringValue = iprot.readString().decode('utf-8') if sys.version_info[0] == 2 else iprot.readString() - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('TTypeQualifierValue') - if self.i32Value is not None: - oprot.writeFieldBegin('i32Value', TType.I32, 1) - oprot.writeI32(self.i32Value) - oprot.writeFieldEnd() - if self.stringValue is not None: - oprot.writeFieldBegin('stringValue', TType.STRING, 2) - oprot.writeString(self.stringValue.encode('utf-8') if sys.version_info[0] == 2 else self.stringValue) - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) - - -class TTypeQualifiers(object): - """ - Attributes: - - qualifiers - """ - - thrift_spec = ( - None, # 0 - (1, TType.MAP, 'qualifiers', (TType.STRING, 'UTF8', TType.STRUCT, (TTypeQualifierValue, TTypeQualifierValue.thrift_spec), False), None, ), # 1 - ) - - def __init__(self, qualifiers=None,): - self.qualifiers = qualifiers - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 1: - if ftype == TType.MAP: - self.qualifiers = {} - (_ktype1, _vtype2, _size0) = iprot.readMapBegin() - for _i4 in range(_size0): - _key5 = iprot.readString().decode('utf-8') if sys.version_info[0] == 2 else iprot.readString() - _val6 = TTypeQualifierValue() - _val6.read(iprot) - self.qualifiers[_key5] = _val6 - iprot.readMapEnd() - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('TTypeQualifiers') - if self.qualifiers is not None: - oprot.writeFieldBegin('qualifiers', TType.MAP, 1) - oprot.writeMapBegin(TType.STRING, TType.STRUCT, len(self.qualifiers)) - for kiter7, viter8 in self.qualifiers.items(): - oprot.writeString(kiter7.encode('utf-8') if sys.version_info[0] == 2 else kiter7) - viter8.write(oprot) - oprot.writeMapEnd() - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - if self.qualifiers is None: - raise TProtocolException(message='Required field qualifiers is unset!') - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) - - -class TPrimitiveTypeEntry(object): - """ - Attributes: - - type - - typeQualifiers - """ - - thrift_spec = ( - None, # 0 - (1, TType.I32, 'type', None, None, ), # 1 - (2, TType.STRUCT, 'typeQualifiers', (TTypeQualifiers, TTypeQualifiers.thrift_spec), None, ), # 2 - ) - - def __init__(self, type=None, typeQualifiers=None,): - self.type = type - self.typeQualifiers = typeQualifiers - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 1: - if ftype == TType.I32: - self.type = iprot.readI32() - else: - iprot.skip(ftype) - elif fid == 2: - if ftype == TType.STRUCT: - self.typeQualifiers = TTypeQualifiers() - self.typeQualifiers.read(iprot) - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('TPrimitiveTypeEntry') - if self.type is not None: - oprot.writeFieldBegin('type', TType.I32, 1) - oprot.writeI32(self.type) - oprot.writeFieldEnd() - if self.typeQualifiers is not None: - oprot.writeFieldBegin('typeQualifiers', TType.STRUCT, 2) - self.typeQualifiers.write(oprot) - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - if self.type is None: - raise TProtocolException(message='Required field type is unset!') - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) - - -class TArrayTypeEntry(object): - """ - Attributes: - - objectTypePtr - """ - - thrift_spec = ( - None, # 0 - (1, TType.I32, 'objectTypePtr', None, None, ), # 1 - ) - - def __init__(self, objectTypePtr=None,): - self.objectTypePtr = objectTypePtr - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 1: - if ftype == TType.I32: - self.objectTypePtr = iprot.readI32() - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('TArrayTypeEntry') - if self.objectTypePtr is not None: - oprot.writeFieldBegin('objectTypePtr', TType.I32, 1) - oprot.writeI32(self.objectTypePtr) - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - if self.objectTypePtr is None: - raise TProtocolException(message='Required field objectTypePtr is unset!') - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) - - -class TMapTypeEntry(object): - """ - Attributes: - - keyTypePtr - - valueTypePtr - """ - - thrift_spec = ( - None, # 0 - (1, TType.I32, 'keyTypePtr', None, None, ), # 1 - (2, TType.I32, 'valueTypePtr', None, None, ), # 2 - ) - - def __init__(self, keyTypePtr=None, valueTypePtr=None,): - self.keyTypePtr = keyTypePtr - self.valueTypePtr = valueTypePtr - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 1: - if ftype == TType.I32: - self.keyTypePtr = iprot.readI32() - else: - iprot.skip(ftype) - elif fid == 2: - if ftype == TType.I32: - self.valueTypePtr = iprot.readI32() - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('TMapTypeEntry') - if self.keyTypePtr is not None: - oprot.writeFieldBegin('keyTypePtr', TType.I32, 1) - oprot.writeI32(self.keyTypePtr) - oprot.writeFieldEnd() - if self.valueTypePtr is not None: - oprot.writeFieldBegin('valueTypePtr', TType.I32, 2) - oprot.writeI32(self.valueTypePtr) - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - if self.keyTypePtr is None: - raise TProtocolException(message='Required field keyTypePtr is unset!') - if self.valueTypePtr is None: - raise TProtocolException(message='Required field valueTypePtr is unset!') - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) - - -class TStructTypeEntry(object): - """ - Attributes: - - nameToTypePtr - """ - - thrift_spec = ( - None, # 0 - (1, TType.MAP, 'nameToTypePtr', (TType.STRING, 'UTF8', TType.I32, None, False), None, ), # 1 - ) - - def __init__(self, nameToTypePtr=None,): - self.nameToTypePtr = nameToTypePtr - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 1: - if ftype == TType.MAP: - self.nameToTypePtr = {} - (_ktype10, _vtype11, _size9) = iprot.readMapBegin() - for _i13 in range(_size9): - _key14 = iprot.readString().decode('utf-8') if sys.version_info[0] == 2 else iprot.readString() - _val15 = iprot.readI32() - self.nameToTypePtr[_key14] = _val15 - iprot.readMapEnd() - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('TStructTypeEntry') - if self.nameToTypePtr is not None: - oprot.writeFieldBegin('nameToTypePtr', TType.MAP, 1) - oprot.writeMapBegin(TType.STRING, TType.I32, len(self.nameToTypePtr)) - for kiter16, viter17 in self.nameToTypePtr.items(): - oprot.writeString(kiter16.encode('utf-8') if sys.version_info[0] == 2 else kiter16) - oprot.writeI32(viter17) - oprot.writeMapEnd() - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - if self.nameToTypePtr is None: - raise TProtocolException(message='Required field nameToTypePtr is unset!') - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) - - -class TUnionTypeEntry(object): - """ - Attributes: - - nameToTypePtr - """ - - thrift_spec = ( - None, # 0 - (1, TType.MAP, 'nameToTypePtr', (TType.STRING, 'UTF8', TType.I32, None, False), None, ), # 1 - ) - - def __init__(self, nameToTypePtr=None,): - self.nameToTypePtr = nameToTypePtr - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 1: - if ftype == TType.MAP: - self.nameToTypePtr = {} - (_ktype19, _vtype20, _size18) = iprot.readMapBegin() - for _i22 in range(_size18): - _key23 = iprot.readString().decode('utf-8') if sys.version_info[0] == 2 else iprot.readString() - _val24 = iprot.readI32() - self.nameToTypePtr[_key23] = _val24 - iprot.readMapEnd() - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('TUnionTypeEntry') - if self.nameToTypePtr is not None: - oprot.writeFieldBegin('nameToTypePtr', TType.MAP, 1) - oprot.writeMapBegin(TType.STRING, TType.I32, len(self.nameToTypePtr)) - for kiter25, viter26 in self.nameToTypePtr.items(): - oprot.writeString(kiter25.encode('utf-8') if sys.version_info[0] == 2 else kiter25) - oprot.writeI32(viter26) - oprot.writeMapEnd() - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - if self.nameToTypePtr is None: - raise TProtocolException(message='Required field nameToTypePtr is unset!') - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) - - -class TUserDefinedTypeEntry(object): - """ - Attributes: - - typeClassName - """ - - thrift_spec = ( - None, # 0 - (1, TType.STRING, 'typeClassName', 'UTF8', None, ), # 1 - ) - - def __init__(self, typeClassName=None,): - self.typeClassName = typeClassName - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 1: - if ftype == TType.STRING: - self.typeClassName = iprot.readString().decode('utf-8') if sys.version_info[0] == 2 else iprot.readString() - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('TUserDefinedTypeEntry') - if self.typeClassName is not None: - oprot.writeFieldBegin('typeClassName', TType.STRING, 1) - oprot.writeString(self.typeClassName.encode('utf-8') if sys.version_info[0] == 2 else self.typeClassName) - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - if self.typeClassName is None: - raise TProtocolException(message='Required field typeClassName is unset!') - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) - - -class TTypeEntry(object): - """ - Attributes: - - primitiveEntry - - arrayEntry - - mapEntry - - structEntry - - unionEntry - - userDefinedTypeEntry - """ - - thrift_spec = ( - None, # 0 - (1, TType.STRUCT, 'primitiveEntry', (TPrimitiveTypeEntry, TPrimitiveTypeEntry.thrift_spec), None, ), # 1 - (2, TType.STRUCT, 'arrayEntry', (TArrayTypeEntry, TArrayTypeEntry.thrift_spec), None, ), # 2 - (3, TType.STRUCT, 'mapEntry', (TMapTypeEntry, TMapTypeEntry.thrift_spec), None, ), # 3 - (4, TType.STRUCT, 'structEntry', (TStructTypeEntry, TStructTypeEntry.thrift_spec), None, ), # 4 - (5, TType.STRUCT, 'unionEntry', (TUnionTypeEntry, TUnionTypeEntry.thrift_spec), None, ), # 5 - (6, TType.STRUCT, 'userDefinedTypeEntry', (TUserDefinedTypeEntry, TUserDefinedTypeEntry.thrift_spec), None, ), # 6 - ) - - def __init__(self, primitiveEntry=None, arrayEntry=None, mapEntry=None, structEntry=None, unionEntry=None, userDefinedTypeEntry=None,): - self.primitiveEntry = primitiveEntry - self.arrayEntry = arrayEntry - self.mapEntry = mapEntry - self.structEntry = structEntry - self.unionEntry = unionEntry - self.userDefinedTypeEntry = userDefinedTypeEntry - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 1: - if ftype == TType.STRUCT: - self.primitiveEntry = TPrimitiveTypeEntry() - self.primitiveEntry.read(iprot) - else: - iprot.skip(ftype) - elif fid == 2: - if ftype == TType.STRUCT: - self.arrayEntry = TArrayTypeEntry() - self.arrayEntry.read(iprot) - else: - iprot.skip(ftype) - elif fid == 3: - if ftype == TType.STRUCT: - self.mapEntry = TMapTypeEntry() - self.mapEntry.read(iprot) - else: - iprot.skip(ftype) - elif fid == 4: - if ftype == TType.STRUCT: - self.structEntry = TStructTypeEntry() - self.structEntry.read(iprot) - else: - iprot.skip(ftype) - elif fid == 5: - if ftype == TType.STRUCT: - self.unionEntry = TUnionTypeEntry() - self.unionEntry.read(iprot) - else: - iprot.skip(ftype) - elif fid == 6: - if ftype == TType.STRUCT: - self.userDefinedTypeEntry = TUserDefinedTypeEntry() - self.userDefinedTypeEntry.read(iprot) - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('TTypeEntry') - if self.primitiveEntry is not None: - oprot.writeFieldBegin('primitiveEntry', TType.STRUCT, 1) - self.primitiveEntry.write(oprot) - oprot.writeFieldEnd() - if self.arrayEntry is not None: - oprot.writeFieldBegin('arrayEntry', TType.STRUCT, 2) - self.arrayEntry.write(oprot) - oprot.writeFieldEnd() - if self.mapEntry is not None: - oprot.writeFieldBegin('mapEntry', TType.STRUCT, 3) - self.mapEntry.write(oprot) - oprot.writeFieldEnd() - if self.structEntry is not None: - oprot.writeFieldBegin('structEntry', TType.STRUCT, 4) - self.structEntry.write(oprot) - oprot.writeFieldEnd() - if self.unionEntry is not None: - oprot.writeFieldBegin('unionEntry', TType.STRUCT, 5) - self.unionEntry.write(oprot) - oprot.writeFieldEnd() - if self.userDefinedTypeEntry is not None: - oprot.writeFieldBegin('userDefinedTypeEntry', TType.STRUCT, 6) - self.userDefinedTypeEntry.write(oprot) - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) - - -class TTypeDesc(object): - """ - Attributes: - - types - """ - - thrift_spec = ( - None, # 0 - (1, TType.LIST, 'types', (TType.STRUCT, (TTypeEntry, TTypeEntry.thrift_spec), False), None, ), # 1 - ) - - def __init__(self, types=None,): - self.types = types - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 1: - if ftype == TType.LIST: - self.types = [] - (_etype30, _size27) = iprot.readListBegin() - for _i31 in range(_size27): - _elem32 = TTypeEntry() - _elem32.read(iprot) - self.types.append(_elem32) - iprot.readListEnd() - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('TTypeDesc') - if self.types is not None: - oprot.writeFieldBegin('types', TType.LIST, 1) - oprot.writeListBegin(TType.STRUCT, len(self.types)) - for iter33 in self.types: - iter33.write(oprot) - oprot.writeListEnd() - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - if self.types is None: - raise TProtocolException(message='Required field types is unset!') - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) - - -class TColumnDesc(object): - """ - Attributes: - - columnName - - typeDesc - - position - - comment - """ - - thrift_spec = ( - None, # 0 - (1, TType.STRING, 'columnName', 'UTF8', None, ), # 1 - (2, TType.STRUCT, 'typeDesc', (TTypeDesc, TTypeDesc.thrift_spec), None, ), # 2 - (3, TType.I32, 'position', None, None, ), # 3 - (4, TType.STRING, 'comment', 'UTF8', None, ), # 4 - ) - - def __init__(self, columnName=None, typeDesc=None, position=None, comment=None,): - self.columnName = columnName - self.typeDesc = typeDesc - self.position = position - self.comment = comment - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 1: - if ftype == TType.STRING: - self.columnName = iprot.readString().decode('utf-8') if sys.version_info[0] == 2 else iprot.readString() - else: - iprot.skip(ftype) - elif fid == 2: - if ftype == TType.STRUCT: - self.typeDesc = TTypeDesc() - self.typeDesc.read(iprot) - else: - iprot.skip(ftype) - elif fid == 3: - if ftype == TType.I32: - self.position = iprot.readI32() - else: - iprot.skip(ftype) - elif fid == 4: - if ftype == TType.STRING: - self.comment = iprot.readString().decode('utf-8') if sys.version_info[0] == 2 else iprot.readString() - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('TColumnDesc') - if self.columnName is not None: - oprot.writeFieldBegin('columnName', TType.STRING, 1) - oprot.writeString(self.columnName.encode('utf-8') if sys.version_info[0] == 2 else self.columnName) - oprot.writeFieldEnd() - if self.typeDesc is not None: - oprot.writeFieldBegin('typeDesc', TType.STRUCT, 2) - self.typeDesc.write(oprot) - oprot.writeFieldEnd() - if self.position is not None: - oprot.writeFieldBegin('position', TType.I32, 3) - oprot.writeI32(self.position) - oprot.writeFieldEnd() - if self.comment is not None: - oprot.writeFieldBegin('comment', TType.STRING, 4) - oprot.writeString(self.comment.encode('utf-8') if sys.version_info[0] == 2 else self.comment) - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - if self.columnName is None: - raise TProtocolException(message='Required field columnName is unset!') - if self.typeDesc is None: - raise TProtocolException(message='Required field typeDesc is unset!') - if self.position is None: - raise TProtocolException(message='Required field position is unset!') - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) - - -class TTableSchema(object): - """ - Attributes: - - columns - """ - - thrift_spec = ( - None, # 0 - (1, TType.LIST, 'columns', (TType.STRUCT, (TColumnDesc, TColumnDesc.thrift_spec), False), None, ), # 1 - ) - - def __init__(self, columns=None,): - self.columns = columns - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 1: - if ftype == TType.LIST: - self.columns = [] - (_etype37, _size34) = iprot.readListBegin() - for _i38 in range(_size34): - _elem39 = TColumnDesc() - _elem39.read(iprot) - self.columns.append(_elem39) - iprot.readListEnd() - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('TTableSchema') - if self.columns is not None: - oprot.writeFieldBegin('columns', TType.LIST, 1) - oprot.writeListBegin(TType.STRUCT, len(self.columns)) - for iter40 in self.columns: - iter40.write(oprot) - oprot.writeListEnd() - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - if self.columns is None: - raise TProtocolException(message='Required field columns is unset!') - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) - - -class TBoolValue(object): - """ - Attributes: - - value - """ - - thrift_spec = ( - None, # 0 - (1, TType.BOOL, 'value', None, None, ), # 1 - ) - - def __init__(self, value=None,): - self.value = value - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 1: - if ftype == TType.BOOL: - self.value = iprot.readBool() - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('TBoolValue') - if self.value is not None: - oprot.writeFieldBegin('value', TType.BOOL, 1) - oprot.writeBool(self.value) - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) - - -class TByteValue(object): - """ - Attributes: - - value - """ - - thrift_spec = ( - None, # 0 - (1, TType.BYTE, 'value', None, None, ), # 1 - ) - - def __init__(self, value=None,): - self.value = value - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 1: - if ftype == TType.BYTE: - self.value = iprot.readByte() - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('TByteValue') - if self.value is not None: - oprot.writeFieldBegin('value', TType.BYTE, 1) - oprot.writeByte(self.value) - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) - - -class TI16Value(object): - """ - Attributes: - - value - """ - - thrift_spec = ( - None, # 0 - (1, TType.I16, 'value', None, None, ), # 1 - ) - - def __init__(self, value=None,): - self.value = value - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 1: - if ftype == TType.I16: - self.value = iprot.readI16() - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('TI16Value') - if self.value is not None: - oprot.writeFieldBegin('value', TType.I16, 1) - oprot.writeI16(self.value) - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) - - -class TI32Value(object): - """ - Attributes: - - value - """ - - thrift_spec = ( - None, # 0 - (1, TType.I32, 'value', None, None, ), # 1 - ) - - def __init__(self, value=None,): - self.value = value - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 1: - if ftype == TType.I32: - self.value = iprot.readI32() - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('TI32Value') - if self.value is not None: - oprot.writeFieldBegin('value', TType.I32, 1) - oprot.writeI32(self.value) - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) - - -class TI64Value(object): - """ - Attributes: - - value - """ - - thrift_spec = ( - None, # 0 - (1, TType.I64, 'value', None, None, ), # 1 - ) - - def __init__(self, value=None,): - self.value = value - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 1: - if ftype == TType.I64: - self.value = iprot.readI64() - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('TI64Value') - if self.value is not None: - oprot.writeFieldBegin('value', TType.I64, 1) - oprot.writeI64(self.value) - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) - - -class TDoubleValue(object): - """ - Attributes: - - value - """ - - thrift_spec = ( - None, # 0 - (1, TType.DOUBLE, 'value', None, None, ), # 1 - ) - - def __init__(self, value=None,): - self.value = value - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 1: - if ftype == TType.DOUBLE: - self.value = iprot.readDouble() - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('TDoubleValue') - if self.value is not None: - oprot.writeFieldBegin('value', TType.DOUBLE, 1) - oprot.writeDouble(self.value) - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) - - -class TStringValue(object): - """ - Attributes: - - value - """ - - thrift_spec = ( - None, # 0 - (1, TType.STRING, 'value', 'UTF8', None, ), # 1 - ) - - def __init__(self, value=None,): - self.value = value - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 1: - if ftype == TType.STRING: - self.value = iprot.readString().decode('utf-8') if sys.version_info[0] == 2 else iprot.readString() - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('TStringValue') - if self.value is not None: - oprot.writeFieldBegin('value', TType.STRING, 1) - oprot.writeString(self.value.encode('utf-8') if sys.version_info[0] == 2 else self.value) - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) - - -class TColumnValue(object): - """ - Attributes: - - boolVal - - byteVal - - i16Val - - i32Val - - i64Val - - doubleVal - - stringVal - """ - - thrift_spec = ( - None, # 0 - (1, TType.STRUCT, 'boolVal', (TBoolValue, TBoolValue.thrift_spec), None, ), # 1 - (2, TType.STRUCT, 'byteVal', (TByteValue, TByteValue.thrift_spec), None, ), # 2 - (3, TType.STRUCT, 'i16Val', (TI16Value, TI16Value.thrift_spec), None, ), # 3 - (4, TType.STRUCT, 'i32Val', (TI32Value, TI32Value.thrift_spec), None, ), # 4 - (5, TType.STRUCT, 'i64Val', (TI64Value, TI64Value.thrift_spec), None, ), # 5 - (6, TType.STRUCT, 'doubleVal', (TDoubleValue, TDoubleValue.thrift_spec), None, ), # 6 - (7, TType.STRUCT, 'stringVal', (TStringValue, TStringValue.thrift_spec), None, ), # 7 - ) - - def __init__(self, boolVal=None, byteVal=None, i16Val=None, i32Val=None, i64Val=None, doubleVal=None, stringVal=None,): - self.boolVal = boolVal - self.byteVal = byteVal - self.i16Val = i16Val - self.i32Val = i32Val - self.i64Val = i64Val - self.doubleVal = doubleVal - self.stringVal = stringVal - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 1: - if ftype == TType.STRUCT: - self.boolVal = TBoolValue() - self.boolVal.read(iprot) - else: - iprot.skip(ftype) - elif fid == 2: - if ftype == TType.STRUCT: - self.byteVal = TByteValue() - self.byteVal.read(iprot) - else: - iprot.skip(ftype) - elif fid == 3: - if ftype == TType.STRUCT: - self.i16Val = TI16Value() - self.i16Val.read(iprot) - else: - iprot.skip(ftype) - elif fid == 4: - if ftype == TType.STRUCT: - self.i32Val = TI32Value() - self.i32Val.read(iprot) - else: - iprot.skip(ftype) - elif fid == 5: - if ftype == TType.STRUCT: - self.i64Val = TI64Value() - self.i64Val.read(iprot) - else: - iprot.skip(ftype) - elif fid == 6: - if ftype == TType.STRUCT: - self.doubleVal = TDoubleValue() - self.doubleVal.read(iprot) - else: - iprot.skip(ftype) - elif fid == 7: - if ftype == TType.STRUCT: - self.stringVal = TStringValue() - self.stringVal.read(iprot) - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('TColumnValue') - if self.boolVal is not None: - oprot.writeFieldBegin('boolVal', TType.STRUCT, 1) - self.boolVal.write(oprot) - oprot.writeFieldEnd() - if self.byteVal is not None: - oprot.writeFieldBegin('byteVal', TType.STRUCT, 2) - self.byteVal.write(oprot) - oprot.writeFieldEnd() - if self.i16Val is not None: - oprot.writeFieldBegin('i16Val', TType.STRUCT, 3) - self.i16Val.write(oprot) - oprot.writeFieldEnd() - if self.i32Val is not None: - oprot.writeFieldBegin('i32Val', TType.STRUCT, 4) - self.i32Val.write(oprot) - oprot.writeFieldEnd() - if self.i64Val is not None: - oprot.writeFieldBegin('i64Val', TType.STRUCT, 5) - self.i64Val.write(oprot) - oprot.writeFieldEnd() - if self.doubleVal is not None: - oprot.writeFieldBegin('doubleVal', TType.STRUCT, 6) - self.doubleVal.write(oprot) - oprot.writeFieldEnd() - if self.stringVal is not None: - oprot.writeFieldBegin('stringVal', TType.STRUCT, 7) - self.stringVal.write(oprot) - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) - - -class TRow(object): - """ - Attributes: - - colVals - """ - - thrift_spec = ( - None, # 0 - (1, TType.LIST, 'colVals', (TType.STRUCT, (TColumnValue, TColumnValue.thrift_spec), False), None, ), # 1 - ) - - def __init__(self, colVals=None,): - self.colVals = colVals - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 1: - if ftype == TType.LIST: - self.colVals = [] - (_etype44, _size41) = iprot.readListBegin() - for _i45 in range(_size41): - _elem46 = TColumnValue() - _elem46.read(iprot) - self.colVals.append(_elem46) - iprot.readListEnd() - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('TRow') - if self.colVals is not None: - oprot.writeFieldBegin('colVals', TType.LIST, 1) - oprot.writeListBegin(TType.STRUCT, len(self.colVals)) - for iter47 in self.colVals: - iter47.write(oprot) - oprot.writeListEnd() - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - if self.colVals is None: - raise TProtocolException(message='Required field colVals is unset!') - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) - - -class TBoolColumn(object): - """ - Attributes: - - values - - nulls - """ - - thrift_spec = ( - None, # 0 - (1, TType.LIST, 'values', (TType.BOOL, None, False), None, ), # 1 - (2, TType.STRING, 'nulls', 'BINARY', None, ), # 2 - ) - - def __init__(self, values=None, nulls=None,): - self.values = values - self.nulls = nulls - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 1: - if ftype == TType.LIST: - self.values = [] - (_etype51, _size48) = iprot.readListBegin() - for _i52 in range(_size48): - _elem53 = iprot.readBool() - self.values.append(_elem53) - iprot.readListEnd() - else: - iprot.skip(ftype) - elif fid == 2: - if ftype == TType.STRING: - self.nulls = iprot.readBinary() - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('TBoolColumn') - if self.values is not None: - oprot.writeFieldBegin('values', TType.LIST, 1) - oprot.writeListBegin(TType.BOOL, len(self.values)) - for iter54 in self.values: - oprot.writeBool(iter54) - oprot.writeListEnd() - oprot.writeFieldEnd() - if self.nulls is not None: - oprot.writeFieldBegin('nulls', TType.STRING, 2) - oprot.writeBinary(self.nulls) - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - if self.values is None: - raise TProtocolException(message='Required field values is unset!') - if self.nulls is None: - raise TProtocolException(message='Required field nulls is unset!') - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) - - -class TByteColumn(object): - """ - Attributes: - - values - - nulls - """ - - thrift_spec = ( - None, # 0 - (1, TType.LIST, 'values', (TType.BYTE, None, False), None, ), # 1 - (2, TType.STRING, 'nulls', 'BINARY', None, ), # 2 - ) - - def __init__(self, values=None, nulls=None,): - self.values = values - self.nulls = nulls - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 1: - if ftype == TType.LIST: - self.values = [] - (_etype58, _size55) = iprot.readListBegin() - for _i59 in range(_size55): - _elem60 = iprot.readByte() - self.values.append(_elem60) - iprot.readListEnd() - else: - iprot.skip(ftype) - elif fid == 2: - if ftype == TType.STRING: - self.nulls = iprot.readBinary() - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('TByteColumn') - if self.values is not None: - oprot.writeFieldBegin('values', TType.LIST, 1) - oprot.writeListBegin(TType.BYTE, len(self.values)) - for iter61 in self.values: - oprot.writeByte(iter61) - oprot.writeListEnd() - oprot.writeFieldEnd() - if self.nulls is not None: - oprot.writeFieldBegin('nulls', TType.STRING, 2) - oprot.writeBinary(self.nulls) - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - if self.values is None: - raise TProtocolException(message='Required field values is unset!') - if self.nulls is None: - raise TProtocolException(message='Required field nulls is unset!') - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) - - -class TI16Column(object): - """ - Attributes: - - values - - nulls - """ - - thrift_spec = ( - None, # 0 - (1, TType.LIST, 'values', (TType.I16, None, False), None, ), # 1 - (2, TType.STRING, 'nulls', 'BINARY', None, ), # 2 - ) - - def __init__(self, values=None, nulls=None,): - self.values = values - self.nulls = nulls - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 1: - if ftype == TType.LIST: - self.values = [] - (_etype65, _size62) = iprot.readListBegin() - for _i66 in range(_size62): - _elem67 = iprot.readI16() - self.values.append(_elem67) - iprot.readListEnd() - else: - iprot.skip(ftype) - elif fid == 2: - if ftype == TType.STRING: - self.nulls = iprot.readBinary() - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('TI16Column') - if self.values is not None: - oprot.writeFieldBegin('values', TType.LIST, 1) - oprot.writeListBegin(TType.I16, len(self.values)) - for iter68 in self.values: - oprot.writeI16(iter68) - oprot.writeListEnd() - oprot.writeFieldEnd() - if self.nulls is not None: - oprot.writeFieldBegin('nulls', TType.STRING, 2) - oprot.writeBinary(self.nulls) - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - if self.values is None: - raise TProtocolException(message='Required field values is unset!') - if self.nulls is None: - raise TProtocolException(message='Required field nulls is unset!') - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) - - -class TI32Column(object): - """ - Attributes: - - values - - nulls - """ - - thrift_spec = ( - None, # 0 - (1, TType.LIST, 'values', (TType.I32, None, False), None, ), # 1 - (2, TType.STRING, 'nulls', 'BINARY', None, ), # 2 - ) - - def __init__(self, values=None, nulls=None,): - self.values = values - self.nulls = nulls - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 1: - if ftype == TType.LIST: - self.values = [] - (_etype72, _size69) = iprot.readListBegin() - for _i73 in range(_size69): - _elem74 = iprot.readI32() - self.values.append(_elem74) - iprot.readListEnd() - else: - iprot.skip(ftype) - elif fid == 2: - if ftype == TType.STRING: - self.nulls = iprot.readBinary() - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('TI32Column') - if self.values is not None: - oprot.writeFieldBegin('values', TType.LIST, 1) - oprot.writeListBegin(TType.I32, len(self.values)) - for iter75 in self.values: - oprot.writeI32(iter75) - oprot.writeListEnd() - oprot.writeFieldEnd() - if self.nulls is not None: - oprot.writeFieldBegin('nulls', TType.STRING, 2) - oprot.writeBinary(self.nulls) - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - if self.values is None: - raise TProtocolException(message='Required field values is unset!') - if self.nulls is None: - raise TProtocolException(message='Required field nulls is unset!') - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) - - -class TI64Column(object): - """ - Attributes: - - values - - nulls - """ - - thrift_spec = ( - None, # 0 - (1, TType.LIST, 'values', (TType.I64, None, False), None, ), # 1 - (2, TType.STRING, 'nulls', 'BINARY', None, ), # 2 - ) - - def __init__(self, values=None, nulls=None,): - self.values = values - self.nulls = nulls - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 1: - if ftype == TType.LIST: - self.values = [] - (_etype79, _size76) = iprot.readListBegin() - for _i80 in range(_size76): - _elem81 = iprot.readI64() - self.values.append(_elem81) - iprot.readListEnd() - else: - iprot.skip(ftype) - elif fid == 2: - if ftype == TType.STRING: - self.nulls = iprot.readBinary() - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('TI64Column') - if self.values is not None: - oprot.writeFieldBegin('values', TType.LIST, 1) - oprot.writeListBegin(TType.I64, len(self.values)) - for iter82 in self.values: - oprot.writeI64(iter82) - oprot.writeListEnd() - oprot.writeFieldEnd() - if self.nulls is not None: - oprot.writeFieldBegin('nulls', TType.STRING, 2) - oprot.writeBinary(self.nulls) - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - if self.values is None: - raise TProtocolException(message='Required field values is unset!') - if self.nulls is None: - raise TProtocolException(message='Required field nulls is unset!') - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) - - -class TDoubleColumn(object): - """ - Attributes: - - values - - nulls - """ - - thrift_spec = ( - None, # 0 - (1, TType.LIST, 'values', (TType.DOUBLE, None, False), None, ), # 1 - (2, TType.STRING, 'nulls', 'BINARY', None, ), # 2 - ) - - def __init__(self, values=None, nulls=None,): - self.values = values - self.nulls = nulls - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 1: - if ftype == TType.LIST: - self.values = [] - (_etype86, _size83) = iprot.readListBegin() - for _i87 in range(_size83): - _elem88 = iprot.readDouble() - self.values.append(_elem88) - iprot.readListEnd() - else: - iprot.skip(ftype) - elif fid == 2: - if ftype == TType.STRING: - self.nulls = iprot.readBinary() - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('TDoubleColumn') - if self.values is not None: - oprot.writeFieldBegin('values', TType.LIST, 1) - oprot.writeListBegin(TType.DOUBLE, len(self.values)) - for iter89 in self.values: - oprot.writeDouble(iter89) - oprot.writeListEnd() - oprot.writeFieldEnd() - if self.nulls is not None: - oprot.writeFieldBegin('nulls', TType.STRING, 2) - oprot.writeBinary(self.nulls) - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - if self.values is None: - raise TProtocolException(message='Required field values is unset!') - if self.nulls is None: - raise TProtocolException(message='Required field nulls is unset!') - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) - - -class TStringColumn(object): - """ - Attributes: - - values - - nulls - """ - - thrift_spec = ( - None, # 0 - (1, TType.LIST, 'values', (TType.STRING, 'UTF8', False), None, ), # 1 - (2, TType.STRING, 'nulls', 'BINARY', None, ), # 2 - ) - - def __init__(self, values=None, nulls=None,): - self.values = values - self.nulls = nulls - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 1: - if ftype == TType.LIST: - self.values = [] - (_etype93, _size90) = iprot.readListBegin() - for _i94 in range(_size90): - _elem95 = iprot.readString().decode('utf-8') if sys.version_info[0] == 2 else iprot.readString() - self.values.append(_elem95) - iprot.readListEnd() - else: - iprot.skip(ftype) - elif fid == 2: - if ftype == TType.STRING: - self.nulls = iprot.readBinary() - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('TStringColumn') - if self.values is not None: - oprot.writeFieldBegin('values', TType.LIST, 1) - oprot.writeListBegin(TType.STRING, len(self.values)) - for iter96 in self.values: - oprot.writeString(iter96.encode('utf-8') if sys.version_info[0] == 2 else iter96) - oprot.writeListEnd() - oprot.writeFieldEnd() - if self.nulls is not None: - oprot.writeFieldBegin('nulls', TType.STRING, 2) - oprot.writeBinary(self.nulls) - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - if self.values is None: - raise TProtocolException(message='Required field values is unset!') - if self.nulls is None: - raise TProtocolException(message='Required field nulls is unset!') - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) - - -class TBinaryColumn(object): - """ - Attributes: - - values - - nulls - """ - - thrift_spec = ( - None, # 0 - (1, TType.LIST, 'values', (TType.STRING, 'BINARY', False), None, ), # 1 - (2, TType.STRING, 'nulls', 'BINARY', None, ), # 2 - ) - - def __init__(self, values=None, nulls=None,): - self.values = values - self.nulls = nulls - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 1: - if ftype == TType.LIST: - self.values = [] - (_etype100, _size97) = iprot.readListBegin() - for _i101 in range(_size97): - _elem102 = iprot.readBinary() - self.values.append(_elem102) - iprot.readListEnd() - else: - iprot.skip(ftype) - elif fid == 2: - if ftype == TType.STRING: - self.nulls = iprot.readBinary() - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('TBinaryColumn') - if self.values is not None: - oprot.writeFieldBegin('values', TType.LIST, 1) - oprot.writeListBegin(TType.STRING, len(self.values)) - for iter103 in self.values: - oprot.writeBinary(iter103) - oprot.writeListEnd() - oprot.writeFieldEnd() - if self.nulls is not None: - oprot.writeFieldBegin('nulls', TType.STRING, 2) - oprot.writeBinary(self.nulls) - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - if self.values is None: - raise TProtocolException(message='Required field values is unset!') - if self.nulls is None: - raise TProtocolException(message='Required field nulls is unset!') - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) - - -class TColumn(object): - """ - Attributes: - - boolVal - - byteVal - - i16Val - - i32Val - - i64Val - - doubleVal - - stringVal - - binaryVal - """ - - thrift_spec = ( - None, # 0 - (1, TType.STRUCT, 'boolVal', (TBoolColumn, TBoolColumn.thrift_spec), None, ), # 1 - (2, TType.STRUCT, 'byteVal', (TByteColumn, TByteColumn.thrift_spec), None, ), # 2 - (3, TType.STRUCT, 'i16Val', (TI16Column, TI16Column.thrift_spec), None, ), # 3 - (4, TType.STRUCT, 'i32Val', (TI32Column, TI32Column.thrift_spec), None, ), # 4 - (5, TType.STRUCT, 'i64Val', (TI64Column, TI64Column.thrift_spec), None, ), # 5 - (6, TType.STRUCT, 'doubleVal', (TDoubleColumn, TDoubleColumn.thrift_spec), None, ), # 6 - (7, TType.STRUCT, 'stringVal', (TStringColumn, TStringColumn.thrift_spec), None, ), # 7 - (8, TType.STRUCT, 'binaryVal', (TBinaryColumn, TBinaryColumn.thrift_spec), None, ), # 8 - ) - - def __init__(self, boolVal=None, byteVal=None, i16Val=None, i32Val=None, i64Val=None, doubleVal=None, stringVal=None, binaryVal=None,): - self.boolVal = boolVal - self.byteVal = byteVal - self.i16Val = i16Val - self.i32Val = i32Val - self.i64Val = i64Val - self.doubleVal = doubleVal - self.stringVal = stringVal - self.binaryVal = binaryVal - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 1: - if ftype == TType.STRUCT: - self.boolVal = TBoolColumn() - self.boolVal.read(iprot) - else: - iprot.skip(ftype) - elif fid == 2: - if ftype == TType.STRUCT: - self.byteVal = TByteColumn() - self.byteVal.read(iprot) - else: - iprot.skip(ftype) - elif fid == 3: - if ftype == TType.STRUCT: - self.i16Val = TI16Column() - self.i16Val.read(iprot) - else: - iprot.skip(ftype) - elif fid == 4: - if ftype == TType.STRUCT: - self.i32Val = TI32Column() - self.i32Val.read(iprot) - else: - iprot.skip(ftype) - elif fid == 5: - if ftype == TType.STRUCT: - self.i64Val = TI64Column() - self.i64Val.read(iprot) - else: - iprot.skip(ftype) - elif fid == 6: - if ftype == TType.STRUCT: - self.doubleVal = TDoubleColumn() - self.doubleVal.read(iprot) - else: - iprot.skip(ftype) - elif fid == 7: - if ftype == TType.STRUCT: - self.stringVal = TStringColumn() - self.stringVal.read(iprot) - else: - iprot.skip(ftype) - elif fid == 8: - if ftype == TType.STRUCT: - self.binaryVal = TBinaryColumn() - self.binaryVal.read(iprot) - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('TColumn') - if self.boolVal is not None: - oprot.writeFieldBegin('boolVal', TType.STRUCT, 1) - self.boolVal.write(oprot) - oprot.writeFieldEnd() - if self.byteVal is not None: - oprot.writeFieldBegin('byteVal', TType.STRUCT, 2) - self.byteVal.write(oprot) - oprot.writeFieldEnd() - if self.i16Val is not None: - oprot.writeFieldBegin('i16Val', TType.STRUCT, 3) - self.i16Val.write(oprot) - oprot.writeFieldEnd() - if self.i32Val is not None: - oprot.writeFieldBegin('i32Val', TType.STRUCT, 4) - self.i32Val.write(oprot) - oprot.writeFieldEnd() - if self.i64Val is not None: - oprot.writeFieldBegin('i64Val', TType.STRUCT, 5) - self.i64Val.write(oprot) - oprot.writeFieldEnd() - if self.doubleVal is not None: - oprot.writeFieldBegin('doubleVal', TType.STRUCT, 6) - self.doubleVal.write(oprot) - oprot.writeFieldEnd() - if self.stringVal is not None: - oprot.writeFieldBegin('stringVal', TType.STRUCT, 7) - self.stringVal.write(oprot) - oprot.writeFieldEnd() - if self.binaryVal is not None: - oprot.writeFieldBegin('binaryVal', TType.STRUCT, 8) - self.binaryVal.write(oprot) - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) - - -class TRowSet(object): - """ - Attributes: - - startRowOffset - - rows - - columns - - binaryColumns - - columnCount - """ - - thrift_spec = ( - None, # 0 - (1, TType.I64, 'startRowOffset', None, None, ), # 1 - (2, TType.LIST, 'rows', (TType.STRUCT, (TRow, TRow.thrift_spec), False), None, ), # 2 - (3, TType.LIST, 'columns', (TType.STRUCT, (TColumn, TColumn.thrift_spec), False), None, ), # 3 - (4, TType.STRING, 'binaryColumns', 'BINARY', None, ), # 4 - (5, TType.I32, 'columnCount', None, None, ), # 5 - ) - - def __init__(self, startRowOffset=None, rows=None, columns=None, binaryColumns=None, columnCount=None,): - self.startRowOffset = startRowOffset - self.rows = rows - self.columns = columns - self.binaryColumns = binaryColumns - self.columnCount = columnCount - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 1: - if ftype == TType.I64: - self.startRowOffset = iprot.readI64() - else: - iprot.skip(ftype) - elif fid == 2: - if ftype == TType.LIST: - self.rows = [] - (_etype107, _size104) = iprot.readListBegin() - for _i108 in range(_size104): - _elem109 = TRow() - _elem109.read(iprot) - self.rows.append(_elem109) - iprot.readListEnd() - else: - iprot.skip(ftype) - elif fid == 3: - if ftype == TType.LIST: - self.columns = [] - (_etype113, _size110) = iprot.readListBegin() - for _i114 in range(_size110): - _elem115 = TColumn() - _elem115.read(iprot) - self.columns.append(_elem115) - iprot.readListEnd() - else: - iprot.skip(ftype) - elif fid == 4: - if ftype == TType.STRING: - self.binaryColumns = iprot.readBinary() - else: - iprot.skip(ftype) - elif fid == 5: - if ftype == TType.I32: - self.columnCount = iprot.readI32() - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('TRowSet') - if self.startRowOffset is not None: - oprot.writeFieldBegin('startRowOffset', TType.I64, 1) - oprot.writeI64(self.startRowOffset) - oprot.writeFieldEnd() - if self.rows is not None: - oprot.writeFieldBegin('rows', TType.LIST, 2) - oprot.writeListBegin(TType.STRUCT, len(self.rows)) - for iter116 in self.rows: - iter116.write(oprot) - oprot.writeListEnd() - oprot.writeFieldEnd() - if self.columns is not None: - oprot.writeFieldBegin('columns', TType.LIST, 3) - oprot.writeListBegin(TType.STRUCT, len(self.columns)) - for iter117 in self.columns: - iter117.write(oprot) - oprot.writeListEnd() - oprot.writeFieldEnd() - if self.binaryColumns is not None: - oprot.writeFieldBegin('binaryColumns', TType.STRING, 4) - oprot.writeBinary(self.binaryColumns) - oprot.writeFieldEnd() - if self.columnCount is not None: - oprot.writeFieldBegin('columnCount', TType.I32, 5) - oprot.writeI32(self.columnCount) - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - if self.startRowOffset is None: - raise TProtocolException(message='Required field startRowOffset is unset!') - if self.rows is None: - raise TProtocolException(message='Required field rows is unset!') - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) - - -class TStatus(object): - """ - Attributes: - - statusCode - - infoMessages - - sqlState - - errorCode - - errorMessage - """ - - thrift_spec = ( - None, # 0 - (1, TType.I32, 'statusCode', None, None, ), # 1 - (2, TType.LIST, 'infoMessages', (TType.STRING, 'UTF8', False), None, ), # 2 - (3, TType.STRING, 'sqlState', 'UTF8', None, ), # 3 - (4, TType.I32, 'errorCode', None, None, ), # 4 - (5, TType.STRING, 'errorMessage', 'UTF8', None, ), # 5 - ) - - def __init__(self, statusCode=None, infoMessages=None, sqlState=None, errorCode=None, errorMessage=None,): - self.statusCode = statusCode - self.infoMessages = infoMessages - self.sqlState = sqlState - self.errorCode = errorCode - self.errorMessage = errorMessage - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 1: - if ftype == TType.I32: - self.statusCode = iprot.readI32() - else: - iprot.skip(ftype) - elif fid == 2: - if ftype == TType.LIST: - self.infoMessages = [] - (_etype121, _size118) = iprot.readListBegin() - for _i122 in range(_size118): - _elem123 = iprot.readString().decode('utf-8') if sys.version_info[0] == 2 else iprot.readString() - self.infoMessages.append(_elem123) - iprot.readListEnd() - else: - iprot.skip(ftype) - elif fid == 3: - if ftype == TType.STRING: - self.sqlState = iprot.readString().decode('utf-8') if sys.version_info[0] == 2 else iprot.readString() - else: - iprot.skip(ftype) - elif fid == 4: - if ftype == TType.I32: - self.errorCode = iprot.readI32() - else: - iprot.skip(ftype) - elif fid == 5: - if ftype == TType.STRING: - self.errorMessage = iprot.readString().decode('utf-8') if sys.version_info[0] == 2 else iprot.readString() - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('TStatus') - if self.statusCode is not None: - oprot.writeFieldBegin('statusCode', TType.I32, 1) - oprot.writeI32(self.statusCode) - oprot.writeFieldEnd() - if self.infoMessages is not None: - oprot.writeFieldBegin('infoMessages', TType.LIST, 2) - oprot.writeListBegin(TType.STRING, len(self.infoMessages)) - for iter124 in self.infoMessages: - oprot.writeString(iter124.encode('utf-8') if sys.version_info[0] == 2 else iter124) - oprot.writeListEnd() - oprot.writeFieldEnd() - if self.sqlState is not None: - oprot.writeFieldBegin('sqlState', TType.STRING, 3) - oprot.writeString(self.sqlState.encode('utf-8') if sys.version_info[0] == 2 else self.sqlState) - oprot.writeFieldEnd() - if self.errorCode is not None: - oprot.writeFieldBegin('errorCode', TType.I32, 4) - oprot.writeI32(self.errorCode) - oprot.writeFieldEnd() - if self.errorMessage is not None: - oprot.writeFieldBegin('errorMessage', TType.STRING, 5) - oprot.writeString(self.errorMessage.encode('utf-8') if sys.version_info[0] == 2 else self.errorMessage) - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - if self.statusCode is None: - raise TProtocolException(message='Required field statusCode is unset!') - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) - - -class THandleIdentifier(object): - """ - Attributes: - - guid - - secret - """ - - thrift_spec = ( - None, # 0 - (1, TType.STRING, 'guid', 'BINARY', None, ), # 1 - (2, TType.STRING, 'secret', 'BINARY', None, ), # 2 - ) - - def __init__(self, guid=None, secret=None,): - self.guid = guid - self.secret = secret - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 1: - if ftype == TType.STRING: - self.guid = iprot.readBinary() - else: - iprot.skip(ftype) - elif fid == 2: - if ftype == TType.STRING: - self.secret = iprot.readBinary() - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('THandleIdentifier') - if self.guid is not None: - oprot.writeFieldBegin('guid', TType.STRING, 1) - oprot.writeBinary(self.guid) - oprot.writeFieldEnd() - if self.secret is not None: - oprot.writeFieldBegin('secret', TType.STRING, 2) - oprot.writeBinary(self.secret) - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - if self.guid is None: - raise TProtocolException(message='Required field guid is unset!') - if self.secret is None: - raise TProtocolException(message='Required field secret is unset!') - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) - - -class TSessionHandle(object): - """ - Attributes: - - sessionId - """ - - thrift_spec = ( - None, # 0 - (1, TType.STRUCT, 'sessionId', (THandleIdentifier, THandleIdentifier.thrift_spec), None, ), # 1 - ) - - def __init__(self, sessionId=None,): - self.sessionId = sessionId - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 1: - if ftype == TType.STRUCT: - self.sessionId = THandleIdentifier() - self.sessionId.read(iprot) - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('TSessionHandle') - if self.sessionId is not None: - oprot.writeFieldBegin('sessionId', TType.STRUCT, 1) - self.sessionId.write(oprot) - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - if self.sessionId is None: - raise TProtocolException(message='Required field sessionId is unset!') - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) - - -class TOperationHandle(object): - """ - Attributes: - - operationId - - operationType - - hasResultSet - - modifiedRowCount - """ - - thrift_spec = ( - None, # 0 - (1, TType.STRUCT, 'operationId', (THandleIdentifier, THandleIdentifier.thrift_spec), None, ), # 1 - (2, TType.I32, 'operationType', None, None, ), # 2 - (3, TType.BOOL, 'hasResultSet', None, None, ), # 3 - (4, TType.DOUBLE, 'modifiedRowCount', None, None, ), # 4 - ) - - def __init__(self, operationId=None, operationType=None, hasResultSet=None, modifiedRowCount=None,): - self.operationId = operationId - self.operationType = operationType - self.hasResultSet = hasResultSet - self.modifiedRowCount = modifiedRowCount - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 1: - if ftype == TType.STRUCT: - self.operationId = THandleIdentifier() - self.operationId.read(iprot) - else: - iprot.skip(ftype) - elif fid == 2: - if ftype == TType.I32: - self.operationType = iprot.readI32() - else: - iprot.skip(ftype) - elif fid == 3: - if ftype == TType.BOOL: - self.hasResultSet = iprot.readBool() - else: - iprot.skip(ftype) - elif fid == 4: - if ftype == TType.DOUBLE: - self.modifiedRowCount = iprot.readDouble() - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('TOperationHandle') - if self.operationId is not None: - oprot.writeFieldBegin('operationId', TType.STRUCT, 1) - self.operationId.write(oprot) - oprot.writeFieldEnd() - if self.operationType is not None: - oprot.writeFieldBegin('operationType', TType.I32, 2) - oprot.writeI32(self.operationType) - oprot.writeFieldEnd() - if self.hasResultSet is not None: - oprot.writeFieldBegin('hasResultSet', TType.BOOL, 3) - oprot.writeBool(self.hasResultSet) - oprot.writeFieldEnd() - if self.modifiedRowCount is not None: - oprot.writeFieldBegin('modifiedRowCount', TType.DOUBLE, 4) - oprot.writeDouble(self.modifiedRowCount) - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - if self.operationId is None: - raise TProtocolException(message='Required field operationId is unset!') - if self.operationType is None: - raise TProtocolException(message='Required field operationType is unset!') - if self.hasResultSet is None: - raise TProtocolException(message='Required field hasResultSet is unset!') - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) - - -class TOpenSessionReq(object): - """ - Attributes: - - client_protocol - - username - - password - - configuration - """ - - thrift_spec = ( - None, # 0 - (1, TType.I32, 'client_protocol', None, 9, ), # 1 - (2, TType.STRING, 'username', 'UTF8', None, ), # 2 - (3, TType.STRING, 'password', 'UTF8', None, ), # 3 - (4, TType.MAP, 'configuration', (TType.STRING, 'UTF8', TType.STRING, 'UTF8', False), None, ), # 4 - ) - - def __init__(self, client_protocol=thrift_spec[1][4], username=None, password=None, configuration=None,): - self.client_protocol = client_protocol - self.username = username - self.password = password - self.configuration = configuration - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 1: - if ftype == TType.I32: - self.client_protocol = iprot.readI32() - else: - iprot.skip(ftype) - elif fid == 2: - if ftype == TType.STRING: - self.username = iprot.readString().decode('utf-8') if sys.version_info[0] == 2 else iprot.readString() - else: - iprot.skip(ftype) - elif fid == 3: - if ftype == TType.STRING: - self.password = iprot.readString().decode('utf-8') if sys.version_info[0] == 2 else iprot.readString() - else: - iprot.skip(ftype) - elif fid == 4: - if ftype == TType.MAP: - self.configuration = {} - (_ktype126, _vtype127, _size125) = iprot.readMapBegin() - for _i129 in range(_size125): - _key130 = iprot.readString().decode('utf-8') if sys.version_info[0] == 2 else iprot.readString() - _val131 = iprot.readString().decode('utf-8') if sys.version_info[0] == 2 else iprot.readString() - self.configuration[_key130] = _val131 - iprot.readMapEnd() - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('TOpenSessionReq') - if self.client_protocol is not None: - oprot.writeFieldBegin('client_protocol', TType.I32, 1) - oprot.writeI32(self.client_protocol) - oprot.writeFieldEnd() - if self.username is not None: - oprot.writeFieldBegin('username', TType.STRING, 2) - oprot.writeString(self.username.encode('utf-8') if sys.version_info[0] == 2 else self.username) - oprot.writeFieldEnd() - if self.password is not None: - oprot.writeFieldBegin('password', TType.STRING, 3) - oprot.writeString(self.password.encode('utf-8') if sys.version_info[0] == 2 else self.password) - oprot.writeFieldEnd() - if self.configuration is not None: - oprot.writeFieldBegin('configuration', TType.MAP, 4) - oprot.writeMapBegin(TType.STRING, TType.STRING, len(self.configuration)) - for kiter132, viter133 in self.configuration.items(): - oprot.writeString(kiter132.encode('utf-8') if sys.version_info[0] == 2 else kiter132) - oprot.writeString(viter133.encode('utf-8') if sys.version_info[0] == 2 else viter133) - oprot.writeMapEnd() - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - if self.client_protocol is None: - raise TProtocolException(message='Required field client_protocol is unset!') - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) - - -class TOpenSessionResp(object): - """ - Attributes: - - status - - serverProtocolVersion - - sessionHandle - - configuration - """ - - thrift_spec = ( - None, # 0 - (1, TType.STRUCT, 'status', (TStatus, TStatus.thrift_spec), None, ), # 1 - (2, TType.I32, 'serverProtocolVersion', None, 9, ), # 2 - (3, TType.STRUCT, 'sessionHandle', (TSessionHandle, TSessionHandle.thrift_spec), None, ), # 3 - (4, TType.MAP, 'configuration', (TType.STRING, 'UTF8', TType.STRING, 'UTF8', False), None, ), # 4 - ) - - def __init__(self, status=None, serverProtocolVersion=thrift_spec[2][4], sessionHandle=None, configuration=None,): - self.status = status - self.serverProtocolVersion = serverProtocolVersion - self.sessionHandle = sessionHandle - self.configuration = configuration - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 1: - if ftype == TType.STRUCT: - self.status = TStatus() - self.status.read(iprot) - else: - iprot.skip(ftype) - elif fid == 2: - if ftype == TType.I32: - self.serverProtocolVersion = iprot.readI32() - else: - iprot.skip(ftype) - elif fid == 3: - if ftype == TType.STRUCT: - self.sessionHandle = TSessionHandle() - self.sessionHandle.read(iprot) - else: - iprot.skip(ftype) - elif fid == 4: - if ftype == TType.MAP: - self.configuration = {} - (_ktype135, _vtype136, _size134) = iprot.readMapBegin() - for _i138 in range(_size134): - _key139 = iprot.readString().decode('utf-8') if sys.version_info[0] == 2 else iprot.readString() - _val140 = iprot.readString().decode('utf-8') if sys.version_info[0] == 2 else iprot.readString() - self.configuration[_key139] = _val140 - iprot.readMapEnd() - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('TOpenSessionResp') - if self.status is not None: - oprot.writeFieldBegin('status', TType.STRUCT, 1) - self.status.write(oprot) - oprot.writeFieldEnd() - if self.serverProtocolVersion is not None: - oprot.writeFieldBegin('serverProtocolVersion', TType.I32, 2) - oprot.writeI32(self.serverProtocolVersion) - oprot.writeFieldEnd() - if self.sessionHandle is not None: - oprot.writeFieldBegin('sessionHandle', TType.STRUCT, 3) - self.sessionHandle.write(oprot) - oprot.writeFieldEnd() - if self.configuration is not None: - oprot.writeFieldBegin('configuration', TType.MAP, 4) - oprot.writeMapBegin(TType.STRING, TType.STRING, len(self.configuration)) - for kiter141, viter142 in self.configuration.items(): - oprot.writeString(kiter141.encode('utf-8') if sys.version_info[0] == 2 else kiter141) - oprot.writeString(viter142.encode('utf-8') if sys.version_info[0] == 2 else viter142) - oprot.writeMapEnd() - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - if self.status is None: - raise TProtocolException(message='Required field status is unset!') - if self.serverProtocolVersion is None: - raise TProtocolException(message='Required field serverProtocolVersion is unset!') - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) - - -class TCloseSessionReq(object): - """ - Attributes: - - sessionHandle - """ - - thrift_spec = ( - None, # 0 - (1, TType.STRUCT, 'sessionHandle', (TSessionHandle, TSessionHandle.thrift_spec), None, ), # 1 - ) - - def __init__(self, sessionHandle=None,): - self.sessionHandle = sessionHandle - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 1: - if ftype == TType.STRUCT: - self.sessionHandle = TSessionHandle() - self.sessionHandle.read(iprot) - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('TCloseSessionReq') - if self.sessionHandle is not None: - oprot.writeFieldBegin('sessionHandle', TType.STRUCT, 1) - self.sessionHandle.write(oprot) - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - if self.sessionHandle is None: - raise TProtocolException(message='Required field sessionHandle is unset!') - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) - - -class TCloseSessionResp(object): - """ - Attributes: - - status - """ - - thrift_spec = ( - None, # 0 - (1, TType.STRUCT, 'status', (TStatus, TStatus.thrift_spec), None, ), # 1 - ) - - def __init__(self, status=None,): - self.status = status - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 1: - if ftype == TType.STRUCT: - self.status = TStatus() - self.status.read(iprot) - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('TCloseSessionResp') - if self.status is not None: - oprot.writeFieldBegin('status', TType.STRUCT, 1) - self.status.write(oprot) - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - if self.status is None: - raise TProtocolException(message='Required field status is unset!') - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) - - -class TGetInfoValue(object): - """ - Attributes: - - stringValue - - smallIntValue - - integerBitmask - - integerFlag - - binaryValue - - lenValue - """ - - thrift_spec = ( - None, # 0 - (1, TType.STRING, 'stringValue', 'UTF8', None, ), # 1 - (2, TType.I16, 'smallIntValue', None, None, ), # 2 - (3, TType.I32, 'integerBitmask', None, None, ), # 3 - (4, TType.I32, 'integerFlag', None, None, ), # 4 - (5, TType.I32, 'binaryValue', None, None, ), # 5 - (6, TType.I64, 'lenValue', None, None, ), # 6 - ) - - def __init__(self, stringValue=None, smallIntValue=None, integerBitmask=None, integerFlag=None, binaryValue=None, lenValue=None,): - self.stringValue = stringValue - self.smallIntValue = smallIntValue - self.integerBitmask = integerBitmask - self.integerFlag = integerFlag - self.binaryValue = binaryValue - self.lenValue = lenValue - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 1: - if ftype == TType.STRING: - self.stringValue = iprot.readString().decode('utf-8') if sys.version_info[0] == 2 else iprot.readString() - else: - iprot.skip(ftype) - elif fid == 2: - if ftype == TType.I16: - self.smallIntValue = iprot.readI16() - else: - iprot.skip(ftype) - elif fid == 3: - if ftype == TType.I32: - self.integerBitmask = iprot.readI32() - else: - iprot.skip(ftype) - elif fid == 4: - if ftype == TType.I32: - self.integerFlag = iprot.readI32() - else: - iprot.skip(ftype) - elif fid == 5: - if ftype == TType.I32: - self.binaryValue = iprot.readI32() - else: - iprot.skip(ftype) - elif fid == 6: - if ftype == TType.I64: - self.lenValue = iprot.readI64() - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('TGetInfoValue') - if self.stringValue is not None: - oprot.writeFieldBegin('stringValue', TType.STRING, 1) - oprot.writeString(self.stringValue.encode('utf-8') if sys.version_info[0] == 2 else self.stringValue) - oprot.writeFieldEnd() - if self.smallIntValue is not None: - oprot.writeFieldBegin('smallIntValue', TType.I16, 2) - oprot.writeI16(self.smallIntValue) - oprot.writeFieldEnd() - if self.integerBitmask is not None: - oprot.writeFieldBegin('integerBitmask', TType.I32, 3) - oprot.writeI32(self.integerBitmask) - oprot.writeFieldEnd() - if self.integerFlag is not None: - oprot.writeFieldBegin('integerFlag', TType.I32, 4) - oprot.writeI32(self.integerFlag) - oprot.writeFieldEnd() - if self.binaryValue is not None: - oprot.writeFieldBegin('binaryValue', TType.I32, 5) - oprot.writeI32(self.binaryValue) - oprot.writeFieldEnd() - if self.lenValue is not None: - oprot.writeFieldBegin('lenValue', TType.I64, 6) - oprot.writeI64(self.lenValue) - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) - - -class TGetInfoReq(object): - """ - Attributes: - - sessionHandle - - infoType - """ - - thrift_spec = ( - None, # 0 - (1, TType.STRUCT, 'sessionHandle', (TSessionHandle, TSessionHandle.thrift_spec), None, ), # 1 - (2, TType.I32, 'infoType', None, None, ), # 2 - ) - - def __init__(self, sessionHandle=None, infoType=None,): - self.sessionHandle = sessionHandle - self.infoType = infoType - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 1: - if ftype == TType.STRUCT: - self.sessionHandle = TSessionHandle() - self.sessionHandle.read(iprot) - else: - iprot.skip(ftype) - elif fid == 2: - if ftype == TType.I32: - self.infoType = iprot.readI32() - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('TGetInfoReq') - if self.sessionHandle is not None: - oprot.writeFieldBegin('sessionHandle', TType.STRUCT, 1) - self.sessionHandle.write(oprot) - oprot.writeFieldEnd() - if self.infoType is not None: - oprot.writeFieldBegin('infoType', TType.I32, 2) - oprot.writeI32(self.infoType) - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - if self.sessionHandle is None: - raise TProtocolException(message='Required field sessionHandle is unset!') - if self.infoType is None: - raise TProtocolException(message='Required field infoType is unset!') - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) - - -class TGetInfoResp(object): - """ - Attributes: - - status - - infoValue - """ - - thrift_spec = ( - None, # 0 - (1, TType.STRUCT, 'status', (TStatus, TStatus.thrift_spec), None, ), # 1 - (2, TType.STRUCT, 'infoValue', (TGetInfoValue, TGetInfoValue.thrift_spec), None, ), # 2 - ) - - def __init__(self, status=None, infoValue=None,): - self.status = status - self.infoValue = infoValue - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 1: - if ftype == TType.STRUCT: - self.status = TStatus() - self.status.read(iprot) - else: - iprot.skip(ftype) - elif fid == 2: - if ftype == TType.STRUCT: - self.infoValue = TGetInfoValue() - self.infoValue.read(iprot) - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('TGetInfoResp') - if self.status is not None: - oprot.writeFieldBegin('status', TType.STRUCT, 1) - self.status.write(oprot) - oprot.writeFieldEnd() - if self.infoValue is not None: - oprot.writeFieldBegin('infoValue', TType.STRUCT, 2) - self.infoValue.write(oprot) - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - if self.status is None: - raise TProtocolException(message='Required field status is unset!') - if self.infoValue is None: - raise TProtocolException(message='Required field infoValue is unset!') - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) - - -class TExecuteStatementReq(object): - """ - Attributes: - - sessionHandle - - statement - - confOverlay - - runAsync - - queryTimeout - """ - - thrift_spec = ( - None, # 0 - (1, TType.STRUCT, 'sessionHandle', (TSessionHandle, TSessionHandle.thrift_spec), None, ), # 1 - (2, TType.STRING, 'statement', 'UTF8', None, ), # 2 - (3, TType.MAP, 'confOverlay', (TType.STRING, 'UTF8', TType.STRING, 'UTF8', False), None, ), # 3 - (4, TType.BOOL, 'runAsync', None, False, ), # 4 - (5, TType.I64, 'queryTimeout', None, 0, ), # 5 - ) - - def __init__(self, sessionHandle=None, statement=None, confOverlay=None, runAsync=thrift_spec[4][4], queryTimeout=thrift_spec[5][4],): - self.sessionHandle = sessionHandle - self.statement = statement - self.confOverlay = confOverlay - self.runAsync = runAsync - self.queryTimeout = queryTimeout - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 1: - if ftype == TType.STRUCT: - self.sessionHandle = TSessionHandle() - self.sessionHandle.read(iprot) - else: - iprot.skip(ftype) - elif fid == 2: - if ftype == TType.STRING: - self.statement = iprot.readString().decode('utf-8') if sys.version_info[0] == 2 else iprot.readString() - else: - iprot.skip(ftype) - elif fid == 3: - if ftype == TType.MAP: - self.confOverlay = {} - (_ktype144, _vtype145, _size143) = iprot.readMapBegin() - for _i147 in range(_size143): - _key148 = iprot.readString().decode('utf-8') if sys.version_info[0] == 2 else iprot.readString() - _val149 = iprot.readString().decode('utf-8') if sys.version_info[0] == 2 else iprot.readString() - self.confOverlay[_key148] = _val149 - iprot.readMapEnd() - else: - iprot.skip(ftype) - elif fid == 4: - if ftype == TType.BOOL: - self.runAsync = iprot.readBool() - else: - iprot.skip(ftype) - elif fid == 5: - if ftype == TType.I64: - self.queryTimeout = iprot.readI64() - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('TExecuteStatementReq') - if self.sessionHandle is not None: - oprot.writeFieldBegin('sessionHandle', TType.STRUCT, 1) - self.sessionHandle.write(oprot) - oprot.writeFieldEnd() - if self.statement is not None: - oprot.writeFieldBegin('statement', TType.STRING, 2) - oprot.writeString(self.statement.encode('utf-8') if sys.version_info[0] == 2 else self.statement) - oprot.writeFieldEnd() - if self.confOverlay is not None: - oprot.writeFieldBegin('confOverlay', TType.MAP, 3) - oprot.writeMapBegin(TType.STRING, TType.STRING, len(self.confOverlay)) - for kiter150, viter151 in self.confOverlay.items(): - oprot.writeString(kiter150.encode('utf-8') if sys.version_info[0] == 2 else kiter150) - oprot.writeString(viter151.encode('utf-8') if sys.version_info[0] == 2 else viter151) - oprot.writeMapEnd() - oprot.writeFieldEnd() - if self.runAsync is not None: - oprot.writeFieldBegin('runAsync', TType.BOOL, 4) - oprot.writeBool(self.runAsync) - oprot.writeFieldEnd() - if self.queryTimeout is not None: - oprot.writeFieldBegin('queryTimeout', TType.I64, 5) - oprot.writeI64(self.queryTimeout) - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - if self.sessionHandle is None: - raise TProtocolException(message='Required field sessionHandle is unset!') - if self.statement is None: - raise TProtocolException(message='Required field statement is unset!') - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) - - -class TExecuteStatementResp(object): - """ - Attributes: - - status - - operationHandle - """ - - thrift_spec = ( - None, # 0 - (1, TType.STRUCT, 'status', (TStatus, TStatus.thrift_spec), None, ), # 1 - (2, TType.STRUCT, 'operationHandle', (TOperationHandle, TOperationHandle.thrift_spec), None, ), # 2 - ) - - def __init__(self, status=None, operationHandle=None,): - self.status = status - self.operationHandle = operationHandle - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 1: - if ftype == TType.STRUCT: - self.status = TStatus() - self.status.read(iprot) - else: - iprot.skip(ftype) - elif fid == 2: - if ftype == TType.STRUCT: - self.operationHandle = TOperationHandle() - self.operationHandle.read(iprot) - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('TExecuteStatementResp') - if self.status is not None: - oprot.writeFieldBegin('status', TType.STRUCT, 1) - self.status.write(oprot) - oprot.writeFieldEnd() - if self.operationHandle is not None: - oprot.writeFieldBegin('operationHandle', TType.STRUCT, 2) - self.operationHandle.write(oprot) - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - if self.status is None: - raise TProtocolException(message='Required field status is unset!') - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) - - -class TGetTypeInfoReq(object): - """ - Attributes: - - sessionHandle - """ - - thrift_spec = ( - None, # 0 - (1, TType.STRUCT, 'sessionHandle', (TSessionHandle, TSessionHandle.thrift_spec), None, ), # 1 - ) - - def __init__(self, sessionHandle=None,): - self.sessionHandle = sessionHandle - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 1: - if ftype == TType.STRUCT: - self.sessionHandle = TSessionHandle() - self.sessionHandle.read(iprot) - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('TGetTypeInfoReq') - if self.sessionHandle is not None: - oprot.writeFieldBegin('sessionHandle', TType.STRUCT, 1) - self.sessionHandle.write(oprot) - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - if self.sessionHandle is None: - raise TProtocolException(message='Required field sessionHandle is unset!') - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) - - -class TGetTypeInfoResp(object): - """ - Attributes: - - status - - operationHandle - """ - - thrift_spec = ( - None, # 0 - (1, TType.STRUCT, 'status', (TStatus, TStatus.thrift_spec), None, ), # 1 - (2, TType.STRUCT, 'operationHandle', (TOperationHandle, TOperationHandle.thrift_spec), None, ), # 2 - ) - - def __init__(self, status=None, operationHandle=None,): - self.status = status - self.operationHandle = operationHandle - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 1: - if ftype == TType.STRUCT: - self.status = TStatus() - self.status.read(iprot) - else: - iprot.skip(ftype) - elif fid == 2: - if ftype == TType.STRUCT: - self.operationHandle = TOperationHandle() - self.operationHandle.read(iprot) - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('TGetTypeInfoResp') - if self.status is not None: - oprot.writeFieldBegin('status', TType.STRUCT, 1) - self.status.write(oprot) - oprot.writeFieldEnd() - if self.operationHandle is not None: - oprot.writeFieldBegin('operationHandle', TType.STRUCT, 2) - self.operationHandle.write(oprot) - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - if self.status is None: - raise TProtocolException(message='Required field status is unset!') - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) - - -class TGetCatalogsReq(object): - """ - Attributes: - - sessionHandle - """ - - thrift_spec = ( - None, # 0 - (1, TType.STRUCT, 'sessionHandle', (TSessionHandle, TSessionHandle.thrift_spec), None, ), # 1 - ) - - def __init__(self, sessionHandle=None,): - self.sessionHandle = sessionHandle - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 1: - if ftype == TType.STRUCT: - self.sessionHandle = TSessionHandle() - self.sessionHandle.read(iprot) - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('TGetCatalogsReq') - if self.sessionHandle is not None: - oprot.writeFieldBegin('sessionHandle', TType.STRUCT, 1) - self.sessionHandle.write(oprot) - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - if self.sessionHandle is None: - raise TProtocolException(message='Required field sessionHandle is unset!') - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) - - -class TGetCatalogsResp(object): - """ - Attributes: - - status - - operationHandle - """ - - thrift_spec = ( - None, # 0 - (1, TType.STRUCT, 'status', (TStatus, TStatus.thrift_spec), None, ), # 1 - (2, TType.STRUCT, 'operationHandle', (TOperationHandle, TOperationHandle.thrift_spec), None, ), # 2 - ) - - def __init__(self, status=None, operationHandle=None,): - self.status = status - self.operationHandle = operationHandle - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 1: - if ftype == TType.STRUCT: - self.status = TStatus() - self.status.read(iprot) - else: - iprot.skip(ftype) - elif fid == 2: - if ftype == TType.STRUCT: - self.operationHandle = TOperationHandle() - self.operationHandle.read(iprot) - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('TGetCatalogsResp') - if self.status is not None: - oprot.writeFieldBegin('status', TType.STRUCT, 1) - self.status.write(oprot) - oprot.writeFieldEnd() - if self.operationHandle is not None: - oprot.writeFieldBegin('operationHandle', TType.STRUCT, 2) - self.operationHandle.write(oprot) - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - if self.status is None: - raise TProtocolException(message='Required field status is unset!') - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) - - -class TGetSchemasReq(object): - """ - Attributes: - - sessionHandle - - catalogName - - schemaName - """ - - thrift_spec = ( - None, # 0 - (1, TType.STRUCT, 'sessionHandle', (TSessionHandle, TSessionHandle.thrift_spec), None, ), # 1 - (2, TType.STRING, 'catalogName', 'UTF8', None, ), # 2 - (3, TType.STRING, 'schemaName', 'UTF8', None, ), # 3 - ) - - def __init__(self, sessionHandle=None, catalogName=None, schemaName=None,): - self.sessionHandle = sessionHandle - self.catalogName = catalogName - self.schemaName = schemaName - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 1: - if ftype == TType.STRUCT: - self.sessionHandle = TSessionHandle() - self.sessionHandle.read(iprot) - else: - iprot.skip(ftype) - elif fid == 2: - if ftype == TType.STRING: - self.catalogName = iprot.readString().decode('utf-8') if sys.version_info[0] == 2 else iprot.readString() - else: - iprot.skip(ftype) - elif fid == 3: - if ftype == TType.STRING: - self.schemaName = iprot.readString().decode('utf-8') if sys.version_info[0] == 2 else iprot.readString() - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('TGetSchemasReq') - if self.sessionHandle is not None: - oprot.writeFieldBegin('sessionHandle', TType.STRUCT, 1) - self.sessionHandle.write(oprot) - oprot.writeFieldEnd() - if self.catalogName is not None: - oprot.writeFieldBegin('catalogName', TType.STRING, 2) - oprot.writeString(self.catalogName.encode('utf-8') if sys.version_info[0] == 2 else self.catalogName) - oprot.writeFieldEnd() - if self.schemaName is not None: - oprot.writeFieldBegin('schemaName', TType.STRING, 3) - oprot.writeString(self.schemaName.encode('utf-8') if sys.version_info[0] == 2 else self.schemaName) - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - if self.sessionHandle is None: - raise TProtocolException(message='Required field sessionHandle is unset!') - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) - - -class TGetSchemasResp(object): - """ - Attributes: - - status - - operationHandle - """ - - thrift_spec = ( - None, # 0 - (1, TType.STRUCT, 'status', (TStatus, TStatus.thrift_spec), None, ), # 1 - (2, TType.STRUCT, 'operationHandle', (TOperationHandle, TOperationHandle.thrift_spec), None, ), # 2 - ) - - def __init__(self, status=None, operationHandle=None,): - self.status = status - self.operationHandle = operationHandle - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 1: - if ftype == TType.STRUCT: - self.status = TStatus() - self.status.read(iprot) - else: - iprot.skip(ftype) - elif fid == 2: - if ftype == TType.STRUCT: - self.operationHandle = TOperationHandle() - self.operationHandle.read(iprot) - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('TGetSchemasResp') - if self.status is not None: - oprot.writeFieldBegin('status', TType.STRUCT, 1) - self.status.write(oprot) - oprot.writeFieldEnd() - if self.operationHandle is not None: - oprot.writeFieldBegin('operationHandle', TType.STRUCT, 2) - self.operationHandle.write(oprot) - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - if self.status is None: - raise TProtocolException(message='Required field status is unset!') - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) - - -class TGetTablesReq(object): - """ - Attributes: - - sessionHandle - - catalogName - - schemaName - - tableName - - tableTypes - """ - - thrift_spec = ( - None, # 0 - (1, TType.STRUCT, 'sessionHandle', (TSessionHandle, TSessionHandle.thrift_spec), None, ), # 1 - (2, TType.STRING, 'catalogName', 'UTF8', None, ), # 2 - (3, TType.STRING, 'schemaName', 'UTF8', None, ), # 3 - (4, TType.STRING, 'tableName', 'UTF8', None, ), # 4 - (5, TType.LIST, 'tableTypes', (TType.STRING, 'UTF8', False), None, ), # 5 - ) - - def __init__(self, sessionHandle=None, catalogName=None, schemaName=None, tableName=None, tableTypes=None,): - self.sessionHandle = sessionHandle - self.catalogName = catalogName - self.schemaName = schemaName - self.tableName = tableName - self.tableTypes = tableTypes - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 1: - if ftype == TType.STRUCT: - self.sessionHandle = TSessionHandle() - self.sessionHandle.read(iprot) - else: - iprot.skip(ftype) - elif fid == 2: - if ftype == TType.STRING: - self.catalogName = iprot.readString().decode('utf-8') if sys.version_info[0] == 2 else iprot.readString() - else: - iprot.skip(ftype) - elif fid == 3: - if ftype == TType.STRING: - self.schemaName = iprot.readString().decode('utf-8') if sys.version_info[0] == 2 else iprot.readString() - else: - iprot.skip(ftype) - elif fid == 4: - if ftype == TType.STRING: - self.tableName = iprot.readString().decode('utf-8') if sys.version_info[0] == 2 else iprot.readString() - else: - iprot.skip(ftype) - elif fid == 5: - if ftype == TType.LIST: - self.tableTypes = [] - (_etype155, _size152) = iprot.readListBegin() - for _i156 in range(_size152): - _elem157 = iprot.readString().decode('utf-8') if sys.version_info[0] == 2 else iprot.readString() - self.tableTypes.append(_elem157) - iprot.readListEnd() - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('TGetTablesReq') - if self.sessionHandle is not None: - oprot.writeFieldBegin('sessionHandle', TType.STRUCT, 1) - self.sessionHandle.write(oprot) - oprot.writeFieldEnd() - if self.catalogName is not None: - oprot.writeFieldBegin('catalogName', TType.STRING, 2) - oprot.writeString(self.catalogName.encode('utf-8') if sys.version_info[0] == 2 else self.catalogName) - oprot.writeFieldEnd() - if self.schemaName is not None: - oprot.writeFieldBegin('schemaName', TType.STRING, 3) - oprot.writeString(self.schemaName.encode('utf-8') if sys.version_info[0] == 2 else self.schemaName) - oprot.writeFieldEnd() - if self.tableName is not None: - oprot.writeFieldBegin('tableName', TType.STRING, 4) - oprot.writeString(self.tableName.encode('utf-8') if sys.version_info[0] == 2 else self.tableName) - oprot.writeFieldEnd() - if self.tableTypes is not None: - oprot.writeFieldBegin('tableTypes', TType.LIST, 5) - oprot.writeListBegin(TType.STRING, len(self.tableTypes)) - for iter158 in self.tableTypes: - oprot.writeString(iter158.encode('utf-8') if sys.version_info[0] == 2 else iter158) - oprot.writeListEnd() - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - if self.sessionHandle is None: - raise TProtocolException(message='Required field sessionHandle is unset!') - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) - - -class TGetTablesResp(object): - """ - Attributes: - - status - - operationHandle - """ - - thrift_spec = ( - None, # 0 - (1, TType.STRUCT, 'status', (TStatus, TStatus.thrift_spec), None, ), # 1 - (2, TType.STRUCT, 'operationHandle', (TOperationHandle, TOperationHandle.thrift_spec), None, ), # 2 - ) - - def __init__(self, status=None, operationHandle=None,): - self.status = status - self.operationHandle = operationHandle - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 1: - if ftype == TType.STRUCT: - self.status = TStatus() - self.status.read(iprot) - else: - iprot.skip(ftype) - elif fid == 2: - if ftype == TType.STRUCT: - self.operationHandle = TOperationHandle() - self.operationHandle.read(iprot) - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('TGetTablesResp') - if self.status is not None: - oprot.writeFieldBegin('status', TType.STRUCT, 1) - self.status.write(oprot) - oprot.writeFieldEnd() - if self.operationHandle is not None: - oprot.writeFieldBegin('operationHandle', TType.STRUCT, 2) - self.operationHandle.write(oprot) - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - if self.status is None: - raise TProtocolException(message='Required field status is unset!') - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) - - -class TGetTableTypesReq(object): - """ - Attributes: - - sessionHandle - """ - - thrift_spec = ( - None, # 0 - (1, TType.STRUCT, 'sessionHandle', (TSessionHandle, TSessionHandle.thrift_spec), None, ), # 1 - ) - - def __init__(self, sessionHandle=None,): - self.sessionHandle = sessionHandle - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 1: - if ftype == TType.STRUCT: - self.sessionHandle = TSessionHandle() - self.sessionHandle.read(iprot) - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('TGetTableTypesReq') - if self.sessionHandle is not None: - oprot.writeFieldBegin('sessionHandle', TType.STRUCT, 1) - self.sessionHandle.write(oprot) - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - if self.sessionHandle is None: - raise TProtocolException(message='Required field sessionHandle is unset!') - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) - - -class TGetTableTypesResp(object): - """ - Attributes: - - status - - operationHandle - """ - - thrift_spec = ( - None, # 0 - (1, TType.STRUCT, 'status', (TStatus, TStatus.thrift_spec), None, ), # 1 - (2, TType.STRUCT, 'operationHandle', (TOperationHandle, TOperationHandle.thrift_spec), None, ), # 2 - ) - - def __init__(self, status=None, operationHandle=None,): - self.status = status - self.operationHandle = operationHandle - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 1: - if ftype == TType.STRUCT: - self.status = TStatus() - self.status.read(iprot) - else: - iprot.skip(ftype) - elif fid == 2: - if ftype == TType.STRUCT: - self.operationHandle = TOperationHandle() - self.operationHandle.read(iprot) - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('TGetTableTypesResp') - if self.status is not None: - oprot.writeFieldBegin('status', TType.STRUCT, 1) - self.status.write(oprot) - oprot.writeFieldEnd() - if self.operationHandle is not None: - oprot.writeFieldBegin('operationHandle', TType.STRUCT, 2) - self.operationHandle.write(oprot) - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - if self.status is None: - raise TProtocolException(message='Required field status is unset!') - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) - - -class TGetColumnsReq(object): - """ - Attributes: - - sessionHandle - - catalogName - - schemaName - - tableName - - columnName - """ - - thrift_spec = ( - None, # 0 - (1, TType.STRUCT, 'sessionHandle', (TSessionHandle, TSessionHandle.thrift_spec), None, ), # 1 - (2, TType.STRING, 'catalogName', 'UTF8', None, ), # 2 - (3, TType.STRING, 'schemaName', 'UTF8', None, ), # 3 - (4, TType.STRING, 'tableName', 'UTF8', None, ), # 4 - (5, TType.STRING, 'columnName', 'UTF8', None, ), # 5 - ) - - def __init__(self, sessionHandle=None, catalogName=None, schemaName=None, tableName=None, columnName=None,): - self.sessionHandle = sessionHandle - self.catalogName = catalogName - self.schemaName = schemaName - self.tableName = tableName - self.columnName = columnName - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 1: - if ftype == TType.STRUCT: - self.sessionHandle = TSessionHandle() - self.sessionHandle.read(iprot) - else: - iprot.skip(ftype) - elif fid == 2: - if ftype == TType.STRING: - self.catalogName = iprot.readString().decode('utf-8') if sys.version_info[0] == 2 else iprot.readString() - else: - iprot.skip(ftype) - elif fid == 3: - if ftype == TType.STRING: - self.schemaName = iprot.readString().decode('utf-8') if sys.version_info[0] == 2 else iprot.readString() - else: - iprot.skip(ftype) - elif fid == 4: - if ftype == TType.STRING: - self.tableName = iprot.readString().decode('utf-8') if sys.version_info[0] == 2 else iprot.readString() - else: - iprot.skip(ftype) - elif fid == 5: - if ftype == TType.STRING: - self.columnName = iprot.readString().decode('utf-8') if sys.version_info[0] == 2 else iprot.readString() - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('TGetColumnsReq') - if self.sessionHandle is not None: - oprot.writeFieldBegin('sessionHandle', TType.STRUCT, 1) - self.sessionHandle.write(oprot) - oprot.writeFieldEnd() - if self.catalogName is not None: - oprot.writeFieldBegin('catalogName', TType.STRING, 2) - oprot.writeString(self.catalogName.encode('utf-8') if sys.version_info[0] == 2 else self.catalogName) - oprot.writeFieldEnd() - if self.schemaName is not None: - oprot.writeFieldBegin('schemaName', TType.STRING, 3) - oprot.writeString(self.schemaName.encode('utf-8') if sys.version_info[0] == 2 else self.schemaName) - oprot.writeFieldEnd() - if self.tableName is not None: - oprot.writeFieldBegin('tableName', TType.STRING, 4) - oprot.writeString(self.tableName.encode('utf-8') if sys.version_info[0] == 2 else self.tableName) - oprot.writeFieldEnd() - if self.columnName is not None: - oprot.writeFieldBegin('columnName', TType.STRING, 5) - oprot.writeString(self.columnName.encode('utf-8') if sys.version_info[0] == 2 else self.columnName) - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - if self.sessionHandle is None: - raise TProtocolException(message='Required field sessionHandle is unset!') - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) - - -class TGetColumnsResp(object): - """ - Attributes: - - status - - operationHandle - """ - - thrift_spec = ( - None, # 0 - (1, TType.STRUCT, 'status', (TStatus, TStatus.thrift_spec), None, ), # 1 - (2, TType.STRUCT, 'operationHandle', (TOperationHandle, TOperationHandle.thrift_spec), None, ), # 2 - ) - - def __init__(self, status=None, operationHandle=None,): - self.status = status - self.operationHandle = operationHandle - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 1: - if ftype == TType.STRUCT: - self.status = TStatus() - self.status.read(iprot) - else: - iprot.skip(ftype) - elif fid == 2: - if ftype == TType.STRUCT: - self.operationHandle = TOperationHandle() - self.operationHandle.read(iprot) - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('TGetColumnsResp') - if self.status is not None: - oprot.writeFieldBegin('status', TType.STRUCT, 1) - self.status.write(oprot) - oprot.writeFieldEnd() - if self.operationHandle is not None: - oprot.writeFieldBegin('operationHandle', TType.STRUCT, 2) - self.operationHandle.write(oprot) - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - if self.status is None: - raise TProtocolException(message='Required field status is unset!') - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) - - -class TGetFunctionsReq(object): - """ - Attributes: - - sessionHandle - - catalogName - - schemaName - - functionName - """ - - thrift_spec = ( - None, # 0 - (1, TType.STRUCT, 'sessionHandle', (TSessionHandle, TSessionHandle.thrift_spec), None, ), # 1 - (2, TType.STRING, 'catalogName', 'UTF8', None, ), # 2 - (3, TType.STRING, 'schemaName', 'UTF8', None, ), # 3 - (4, TType.STRING, 'functionName', 'UTF8', None, ), # 4 - ) - - def __init__(self, sessionHandle=None, catalogName=None, schemaName=None, functionName=None,): - self.sessionHandle = sessionHandle - self.catalogName = catalogName - self.schemaName = schemaName - self.functionName = functionName - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 1: - if ftype == TType.STRUCT: - self.sessionHandle = TSessionHandle() - self.sessionHandle.read(iprot) - else: - iprot.skip(ftype) - elif fid == 2: - if ftype == TType.STRING: - self.catalogName = iprot.readString().decode('utf-8') if sys.version_info[0] == 2 else iprot.readString() - else: - iprot.skip(ftype) - elif fid == 3: - if ftype == TType.STRING: - self.schemaName = iprot.readString().decode('utf-8') if sys.version_info[0] == 2 else iprot.readString() - else: - iprot.skip(ftype) - elif fid == 4: - if ftype == TType.STRING: - self.functionName = iprot.readString().decode('utf-8') if sys.version_info[0] == 2 else iprot.readString() - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('TGetFunctionsReq') - if self.sessionHandle is not None: - oprot.writeFieldBegin('sessionHandle', TType.STRUCT, 1) - self.sessionHandle.write(oprot) - oprot.writeFieldEnd() - if self.catalogName is not None: - oprot.writeFieldBegin('catalogName', TType.STRING, 2) - oprot.writeString(self.catalogName.encode('utf-8') if sys.version_info[0] == 2 else self.catalogName) - oprot.writeFieldEnd() - if self.schemaName is not None: - oprot.writeFieldBegin('schemaName', TType.STRING, 3) - oprot.writeString(self.schemaName.encode('utf-8') if sys.version_info[0] == 2 else self.schemaName) - oprot.writeFieldEnd() - if self.functionName is not None: - oprot.writeFieldBegin('functionName', TType.STRING, 4) - oprot.writeString(self.functionName.encode('utf-8') if sys.version_info[0] == 2 else self.functionName) - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - if self.sessionHandle is None: - raise TProtocolException(message='Required field sessionHandle is unset!') - if self.functionName is None: - raise TProtocolException(message='Required field functionName is unset!') - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) - - -class TGetFunctionsResp(object): - """ - Attributes: - - status - - operationHandle - """ - - thrift_spec = ( - None, # 0 - (1, TType.STRUCT, 'status', (TStatus, TStatus.thrift_spec), None, ), # 1 - (2, TType.STRUCT, 'operationHandle', (TOperationHandle, TOperationHandle.thrift_spec), None, ), # 2 - ) - - def __init__(self, status=None, operationHandle=None,): - self.status = status - self.operationHandle = operationHandle - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 1: - if ftype == TType.STRUCT: - self.status = TStatus() - self.status.read(iprot) - else: - iprot.skip(ftype) - elif fid == 2: - if ftype == TType.STRUCT: - self.operationHandle = TOperationHandle() - self.operationHandle.read(iprot) - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('TGetFunctionsResp') - if self.status is not None: - oprot.writeFieldBegin('status', TType.STRUCT, 1) - self.status.write(oprot) - oprot.writeFieldEnd() - if self.operationHandle is not None: - oprot.writeFieldBegin('operationHandle', TType.STRUCT, 2) - self.operationHandle.write(oprot) - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - if self.status is None: - raise TProtocolException(message='Required field status is unset!') - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) - - -class TGetPrimaryKeysReq(object): - """ - Attributes: - - sessionHandle - - catalogName - - schemaName - - tableName - """ - - thrift_spec = ( - None, # 0 - (1, TType.STRUCT, 'sessionHandle', (TSessionHandle, TSessionHandle.thrift_spec), None, ), # 1 - (2, TType.STRING, 'catalogName', 'UTF8', None, ), # 2 - (3, TType.STRING, 'schemaName', 'UTF8', None, ), # 3 - (4, TType.STRING, 'tableName', 'UTF8', None, ), # 4 - ) - - def __init__(self, sessionHandle=None, catalogName=None, schemaName=None, tableName=None,): - self.sessionHandle = sessionHandle - self.catalogName = catalogName - self.schemaName = schemaName - self.tableName = tableName - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 1: - if ftype == TType.STRUCT: - self.sessionHandle = TSessionHandle() - self.sessionHandle.read(iprot) - else: - iprot.skip(ftype) - elif fid == 2: - if ftype == TType.STRING: - self.catalogName = iprot.readString().decode('utf-8') if sys.version_info[0] == 2 else iprot.readString() - else: - iprot.skip(ftype) - elif fid == 3: - if ftype == TType.STRING: - self.schemaName = iprot.readString().decode('utf-8') if sys.version_info[0] == 2 else iprot.readString() - else: - iprot.skip(ftype) - elif fid == 4: - if ftype == TType.STRING: - self.tableName = iprot.readString().decode('utf-8') if sys.version_info[0] == 2 else iprot.readString() - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('TGetPrimaryKeysReq') - if self.sessionHandle is not None: - oprot.writeFieldBegin('sessionHandle', TType.STRUCT, 1) - self.sessionHandle.write(oprot) - oprot.writeFieldEnd() - if self.catalogName is not None: - oprot.writeFieldBegin('catalogName', TType.STRING, 2) - oprot.writeString(self.catalogName.encode('utf-8') if sys.version_info[0] == 2 else self.catalogName) - oprot.writeFieldEnd() - if self.schemaName is not None: - oprot.writeFieldBegin('schemaName', TType.STRING, 3) - oprot.writeString(self.schemaName.encode('utf-8') if sys.version_info[0] == 2 else self.schemaName) - oprot.writeFieldEnd() - if self.tableName is not None: - oprot.writeFieldBegin('tableName', TType.STRING, 4) - oprot.writeString(self.tableName.encode('utf-8') if sys.version_info[0] == 2 else self.tableName) - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - if self.sessionHandle is None: - raise TProtocolException(message='Required field sessionHandle is unset!') - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) - - -class TGetPrimaryKeysResp(object): - """ - Attributes: - - status - - operationHandle - """ - - thrift_spec = ( - None, # 0 - (1, TType.STRUCT, 'status', (TStatus, TStatus.thrift_spec), None, ), # 1 - (2, TType.STRUCT, 'operationHandle', (TOperationHandle, TOperationHandle.thrift_spec), None, ), # 2 - ) - - def __init__(self, status=None, operationHandle=None,): - self.status = status - self.operationHandle = operationHandle - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 1: - if ftype == TType.STRUCT: - self.status = TStatus() - self.status.read(iprot) - else: - iprot.skip(ftype) - elif fid == 2: - if ftype == TType.STRUCT: - self.operationHandle = TOperationHandle() - self.operationHandle.read(iprot) - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('TGetPrimaryKeysResp') - if self.status is not None: - oprot.writeFieldBegin('status', TType.STRUCT, 1) - self.status.write(oprot) - oprot.writeFieldEnd() - if self.operationHandle is not None: - oprot.writeFieldBegin('operationHandle', TType.STRUCT, 2) - self.operationHandle.write(oprot) - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - if self.status is None: - raise TProtocolException(message='Required field status is unset!') - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) - - -class TGetCrossReferenceReq(object): - """ - Attributes: - - sessionHandle - - parentCatalogName - - parentSchemaName - - parentTableName - - foreignCatalogName - - foreignSchemaName - - foreignTableName - """ - - thrift_spec = ( - None, # 0 - (1, TType.STRUCT, 'sessionHandle', (TSessionHandle, TSessionHandle.thrift_spec), None, ), # 1 - (2, TType.STRING, 'parentCatalogName', 'UTF8', None, ), # 2 - (3, TType.STRING, 'parentSchemaName', 'UTF8', None, ), # 3 - (4, TType.STRING, 'parentTableName', 'UTF8', None, ), # 4 - (5, TType.STRING, 'foreignCatalogName', 'UTF8', None, ), # 5 - (6, TType.STRING, 'foreignSchemaName', 'UTF8', None, ), # 6 - (7, TType.STRING, 'foreignTableName', 'UTF8', None, ), # 7 - ) - - def __init__(self, sessionHandle=None, parentCatalogName=None, parentSchemaName=None, parentTableName=None, foreignCatalogName=None, foreignSchemaName=None, foreignTableName=None,): - self.sessionHandle = sessionHandle - self.parentCatalogName = parentCatalogName - self.parentSchemaName = parentSchemaName - self.parentTableName = parentTableName - self.foreignCatalogName = foreignCatalogName - self.foreignSchemaName = foreignSchemaName - self.foreignTableName = foreignTableName - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 1: - if ftype == TType.STRUCT: - self.sessionHandle = TSessionHandle() - self.sessionHandle.read(iprot) - else: - iprot.skip(ftype) - elif fid == 2: - if ftype == TType.STRING: - self.parentCatalogName = iprot.readString().decode('utf-8') if sys.version_info[0] == 2 else iprot.readString() - else: - iprot.skip(ftype) - elif fid == 3: - if ftype == TType.STRING: - self.parentSchemaName = iprot.readString().decode('utf-8') if sys.version_info[0] == 2 else iprot.readString() - else: - iprot.skip(ftype) - elif fid == 4: - if ftype == TType.STRING: - self.parentTableName = iprot.readString().decode('utf-8') if sys.version_info[0] == 2 else iprot.readString() - else: - iprot.skip(ftype) - elif fid == 5: - if ftype == TType.STRING: - self.foreignCatalogName = iprot.readString().decode('utf-8') if sys.version_info[0] == 2 else iprot.readString() - else: - iprot.skip(ftype) - elif fid == 6: - if ftype == TType.STRING: - self.foreignSchemaName = iprot.readString().decode('utf-8') if sys.version_info[0] == 2 else iprot.readString() - else: - iprot.skip(ftype) - elif fid == 7: - if ftype == TType.STRING: - self.foreignTableName = iprot.readString().decode('utf-8') if sys.version_info[0] == 2 else iprot.readString() - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('TGetCrossReferenceReq') - if self.sessionHandle is not None: - oprot.writeFieldBegin('sessionHandle', TType.STRUCT, 1) - self.sessionHandle.write(oprot) - oprot.writeFieldEnd() - if self.parentCatalogName is not None: - oprot.writeFieldBegin('parentCatalogName', TType.STRING, 2) - oprot.writeString(self.parentCatalogName.encode('utf-8') if sys.version_info[0] == 2 else self.parentCatalogName) - oprot.writeFieldEnd() - if self.parentSchemaName is not None: - oprot.writeFieldBegin('parentSchemaName', TType.STRING, 3) - oprot.writeString(self.parentSchemaName.encode('utf-8') if sys.version_info[0] == 2 else self.parentSchemaName) - oprot.writeFieldEnd() - if self.parentTableName is not None: - oprot.writeFieldBegin('parentTableName', TType.STRING, 4) - oprot.writeString(self.parentTableName.encode('utf-8') if sys.version_info[0] == 2 else self.parentTableName) - oprot.writeFieldEnd() - if self.foreignCatalogName is not None: - oprot.writeFieldBegin('foreignCatalogName', TType.STRING, 5) - oprot.writeString(self.foreignCatalogName.encode('utf-8') if sys.version_info[0] == 2 else self.foreignCatalogName) - oprot.writeFieldEnd() - if self.foreignSchemaName is not None: - oprot.writeFieldBegin('foreignSchemaName', TType.STRING, 6) - oprot.writeString(self.foreignSchemaName.encode('utf-8') if sys.version_info[0] == 2 else self.foreignSchemaName) - oprot.writeFieldEnd() - if self.foreignTableName is not None: - oprot.writeFieldBegin('foreignTableName', TType.STRING, 7) - oprot.writeString(self.foreignTableName.encode('utf-8') if sys.version_info[0] == 2 else self.foreignTableName) - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - if self.sessionHandle is None: - raise TProtocolException(message='Required field sessionHandle is unset!') - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) - - -class TGetCrossReferenceResp(object): - """ - Attributes: - - status - - operationHandle - """ - - thrift_spec = ( - None, # 0 - (1, TType.STRUCT, 'status', (TStatus, TStatus.thrift_spec), None, ), # 1 - (2, TType.STRUCT, 'operationHandle', (TOperationHandle, TOperationHandle.thrift_spec), None, ), # 2 - ) - - def __init__(self, status=None, operationHandle=None,): - self.status = status - self.operationHandle = operationHandle - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 1: - if ftype == TType.STRUCT: - self.status = TStatus() - self.status.read(iprot) - else: - iprot.skip(ftype) - elif fid == 2: - if ftype == TType.STRUCT: - self.operationHandle = TOperationHandle() - self.operationHandle.read(iprot) - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('TGetCrossReferenceResp') - if self.status is not None: - oprot.writeFieldBegin('status', TType.STRUCT, 1) - self.status.write(oprot) - oprot.writeFieldEnd() - if self.operationHandle is not None: - oprot.writeFieldBegin('operationHandle', TType.STRUCT, 2) - self.operationHandle.write(oprot) - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - if self.status is None: - raise TProtocolException(message='Required field status is unset!') - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) - -class TProgressUpdateResp(object): - """ - Attributes: - - headerNames - - rows - - progressedPercentage - - status - - footerSummary - - startTime - """ - - thrift_spec = ( - None, # 0 - (1, TType.LIST, 'headerNames', (TType.STRING, 'UTF8', False), None, ), # 1 - (2, TType.LIST, 'rows', (TType.LIST, (TType.STRING, 'UTF8', False), False), None, ), # 2 - (3, TType.DOUBLE, 'progressedPercentage', None, None, ), # 3 - (4, TType.I32, 'status', None, None, ), # 4 - (5, TType.STRING, 'footerSummary', 'UTF8', None, ), # 5 - (6, TType.I64, 'startTime', None, None, ), # 6 - ) - - def __init__(self, headerNames=None, rows=None, progressedPercentage=None, status=None, footerSummary=None, startTime=None,): - self.headerNames = headerNames - self.rows = rows - self.progressedPercentage = progressedPercentage - self.status = status - self.footerSummary = footerSummary - self.startTime = startTime - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 1: - if ftype == TType.LIST: - self.headerNames = [] - (_etype162, _size159) = iprot.readListBegin() - for _i163 in range(_size159): - _elem164 = iprot.readString().decode('utf-8') if sys.version_info[0] == 2 else iprot.readString() - self.headerNames.append(_elem164) - iprot.readListEnd() - else: - iprot.skip(ftype) - elif fid == 2: - if ftype == TType.LIST: - self.rows = [] - (_etype168, _size165) = iprot.readListBegin() - for _i169 in range(_size165): - _elem170 = [] - (_etype174, _size171) = iprot.readListBegin() - for _i175 in range(_size171): - _elem176 = iprot.readString().decode('utf-8') if sys.version_info[0] == 2 else iprot.readString() - _elem170.append(_elem176) - iprot.readListEnd() - self.rows.append(_elem170) - iprot.readListEnd() - else: - iprot.skip(ftype) - elif fid == 3: - if ftype == TType.DOUBLE: - self.progressedPercentage = iprot.readDouble() - else: - iprot.skip(ftype) - elif fid == 4: - if ftype == TType.I32: - self.status = iprot.readI32() - else: - iprot.skip(ftype) - elif fid == 5: - if ftype == TType.STRING: - self.footerSummary = iprot.readString().decode('utf-8') if sys.version_info[0] == 2 else iprot.readString() - else: - iprot.skip(ftype) - elif fid == 6: - if ftype == TType.I64: - self.startTime = iprot.readI64() - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('TProgressUpdateResp') - if self.headerNames is not None: - oprot.writeFieldBegin('headerNames', TType.LIST, 1) - oprot.writeListBegin(TType.STRING, len(self.headerNames)) - for iter177 in self.headerNames: - oprot.writeString(iter177.encode('utf-8') if sys.version_info[0] == 2 else iter177) - oprot.writeListEnd() - oprot.writeFieldEnd() - if self.rows is not None: - oprot.writeFieldBegin('rows', TType.LIST, 2) - oprot.writeListBegin(TType.LIST, len(self.rows)) - for iter178 in self.rows: - oprot.writeListBegin(TType.STRING, len(iter178)) - for iter179 in iter178: - oprot.writeString(iter179.encode('utf-8') if sys.version_info[0] == 2 else iter179) - oprot.writeListEnd() - oprot.writeListEnd() - oprot.writeFieldEnd() - if self.progressedPercentage is not None: - oprot.writeFieldBegin('progressedPercentage', TType.DOUBLE, 3) - oprot.writeDouble(self.progressedPercentage) - oprot.writeFieldEnd() - if self.status is not None: - oprot.writeFieldBegin('status', TType.I32, 4) - oprot.writeI32(self.status) - oprot.writeFieldEnd() - if self.footerSummary is not None: - oprot.writeFieldBegin('footerSummary', TType.STRING, 5) - oprot.writeString(self.footerSummary.encode('utf-8') if sys.version_info[0] == 2 else self.footerSummary) - oprot.writeFieldEnd() - if self.startTime is not None: - oprot.writeFieldBegin('startTime', TType.I64, 6) - oprot.writeI64(self.startTime) - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - if self.headerNames is None: - raise TProtocolException(message='Required field headerNames is unset!') - if self.rows is None: - raise TProtocolException(message='Required field rows is unset!') - if self.progressedPercentage is None: - raise TProtocolException(message='Required field progressedPercentage is unset!') - if self.status is None: - raise TProtocolException(message='Required field status is unset!') - if self.footerSummary is None: - raise TProtocolException(message='Required field footerSummary is unset!') - if self.startTime is None: - raise TProtocolException(message='Required field startTime is unset!') - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) - -class TGetOperationStatusReq(object): - """ - Attributes: - - operationHandle - - getProgressUpdate - """ - - thrift_spec = ( - None, # 0 - (1, TType.STRUCT, 'operationHandle', (TOperationHandle, TOperationHandle.thrift_spec), None, ), # 1 - (2, TType.BOOL, 'getProgressUpdate', None, None, ), # 2 - ) - - def __init__(self, operationHandle=None, getProgressUpdate=None,): - self.operationHandle = operationHandle - self.getProgressUpdate = getProgressUpdate - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 1: - if ftype == TType.STRUCT: - self.operationHandle = TOperationHandle() - self.operationHandle.read(iprot) - else: - iprot.skip(ftype) - elif fid == 2: - if ftype == TType.BOOL: - self.getProgressUpdate = iprot.readBool() - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('TGetOperationStatusReq') - if self.operationHandle is not None: - oprot.writeFieldBegin('operationHandle', TType.STRUCT, 1) - self.operationHandle.write(oprot) - oprot.writeFieldEnd() - if self.getProgressUpdate is not None: - oprot.writeFieldBegin('getProgressUpdate', TType.BOOL, 2) - oprot.writeBool(self.getProgressUpdate) - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - if self.operationHandle is None: - raise TProtocolException(message='Required field operationHandle is unset!') - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) - - -class TGetOperationStatusResp(object): - """ - Attributes: - - status - - operationState - - sqlState - - errorCode - - errorMessage - - taskStatus - - operationStarted - - operationCompleted - - hasResultSet - - progressUpdateResponse - """ - - thrift_spec = ( - None, # 0 - (1, TType.STRUCT, 'status', (TStatus, TStatus.thrift_spec), None, ), # 1 - (2, TType.I32, 'operationState', None, None, ), # 2 - (3, TType.STRING, 'sqlState', 'UTF8', None, ), # 3 - (4, TType.I32, 'errorCode', None, None, ), # 4 - (5, TType.STRING, 'errorMessage', 'UTF8', None, ), # 5 - (6, TType.STRING, 'taskStatus', 'UTF8', None, ), # 6 - (7, TType.I64, 'operationStarted', None, None, ), # 7 - (8, TType.I64, 'operationCompleted', None, None, ), # 8 - (9, TType.BOOL, 'hasResultSet', None, None, ), # 9 - (10, TType.STRUCT, 'progressUpdateResponse', (TProgressUpdateResp, TProgressUpdateResp.thrift_spec), None, ), # 10 - ) - - def __init__(self, status=None, operationState=None, sqlState=None, errorCode=None, errorMessage=None, taskStatus=None, operationStarted=None, operationCompleted=None, hasResultSet=None, progressUpdateResponse=None,): - self.status = status - self.operationState = operationState - self.sqlState = sqlState - self.errorCode = errorCode - self.errorMessage = errorMessage - self.taskStatus = taskStatus - self.operationStarted = operationStarted - self.operationCompleted = operationCompleted - self.hasResultSet = hasResultSet - self.progressUpdateResponse = progressUpdateResponse - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 1: - if ftype == TType.STRUCT: - self.status = TStatus() - self.status.read(iprot) - else: - iprot.skip(ftype) - elif fid == 2: - if ftype == TType.I32: - self.operationState = iprot.readI32() - else: - iprot.skip(ftype) - elif fid == 3: - if ftype == TType.STRING: - self.sqlState = iprot.readString().decode('utf-8') if sys.version_info[0] == 2 else iprot.readString() - else: - iprot.skip(ftype) - elif fid == 4: - if ftype == TType.I32: - self.errorCode = iprot.readI32() - else: - iprot.skip(ftype) - elif fid == 5: - if ftype == TType.STRING: - self.errorMessage = iprot.readString().decode('utf-8') if sys.version_info[0] == 2 else iprot.readString() - else: - iprot.skip(ftype) - elif fid == 6: - if ftype == TType.STRING: - self.taskStatus = iprot.readString().decode('utf-8') if sys.version_info[0] == 2 else iprot.readString() - else: - iprot.skip(ftype) - elif fid == 7: - if ftype == TType.I64: - self.operationStarted = iprot.readI64() - else: - iprot.skip(ftype) - elif fid == 8: - if ftype == TType.I64: - self.operationCompleted = iprot.readI64() - else: - iprot.skip(ftype) - elif fid == 9: - if ftype == TType.BOOL: - self.hasResultSet = iprot.readBool() - else: - iprot.skip(ftype) - elif fid == 10: - if ftype == TType.STRUCT: - self.progressUpdateResponse = TProgressUpdateResp() - self.progressUpdateResponse.read(iprot) - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('TGetOperationStatusResp') - if self.status is not None: - oprot.writeFieldBegin('status', TType.STRUCT, 1) - self.status.write(oprot) - oprot.writeFieldEnd() - if self.operationState is not None: - oprot.writeFieldBegin('operationState', TType.I32, 2) - oprot.writeI32(self.operationState) - oprot.writeFieldEnd() - if self.sqlState is not None: - oprot.writeFieldBegin('sqlState', TType.STRING, 3) - oprot.writeString(self.sqlState.encode('utf-8') if sys.version_info[0] == 2 else self.sqlState) - oprot.writeFieldEnd() - if self.errorCode is not None: - oprot.writeFieldBegin('errorCode', TType.I32, 4) - oprot.writeI32(self.errorCode) - oprot.writeFieldEnd() - if self.errorMessage is not None: - oprot.writeFieldBegin('errorMessage', TType.STRING, 5) - oprot.writeString(self.errorMessage.encode('utf-8') if sys.version_info[0] == 2 else self.errorMessage) - oprot.writeFieldEnd() - if self.taskStatus is not None: - oprot.writeFieldBegin('taskStatus', TType.STRING, 6) - oprot.writeString(self.taskStatus.encode('utf-8') if sys.version_info[0] == 2 else self.taskStatus) - oprot.writeFieldEnd() - if self.operationStarted is not None: - oprot.writeFieldBegin('operationStarted', TType.I64, 7) - oprot.writeI64(self.operationStarted) - oprot.writeFieldEnd() - if self.operationCompleted is not None: - oprot.writeFieldBegin('operationCompleted', TType.I64, 8) - oprot.writeI64(self.operationCompleted) - oprot.writeFieldEnd() - if self.hasResultSet is not None: - oprot.writeFieldBegin('hasResultSet', TType.BOOL, 9) - oprot.writeBool(self.hasResultSet) - oprot.writeFieldEnd() - if self.progressUpdateResponse is not None: - oprot.writeFieldBegin('progressUpdateResponse', TType.STRUCT, 10) - self.progressUpdateResponse.write(oprot) - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - if self.status is None: - raise TProtocolException(message='Required field status is unset!') - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) - - -class TCancelOperationReq(object): - """ - Attributes: - - operationHandle - """ - - thrift_spec = ( - None, # 0 - (1, TType.STRUCT, 'operationHandle', (TOperationHandle, TOperationHandle.thrift_spec), None, ), # 1 - ) - - def __init__(self, operationHandle=None,): - self.operationHandle = operationHandle - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 1: - if ftype == TType.STRUCT: - self.operationHandle = TOperationHandle() - self.operationHandle.read(iprot) - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('TCancelOperationReq') - if self.operationHandle is not None: - oprot.writeFieldBegin('operationHandle', TType.STRUCT, 1) - self.operationHandle.write(oprot) - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - if self.operationHandle is None: - raise TProtocolException(message='Required field operationHandle is unset!') - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) - - -class TCancelOperationResp(object): - """ - Attributes: - - status - """ - - thrift_spec = ( - None, # 0 - (1, TType.STRUCT, 'status', (TStatus, TStatus.thrift_spec), None, ), # 1 - ) - - def __init__(self, status=None,): - self.status = status - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 1: - if ftype == TType.STRUCT: - self.status = TStatus() - self.status.read(iprot) - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('TCancelOperationResp') - if self.status is not None: - oprot.writeFieldBegin('status', TType.STRUCT, 1) - self.status.write(oprot) - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - if self.status is None: - raise TProtocolException(message='Required field status is unset!') - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) - - -class TCloseOperationReq(object): - """ - Attributes: - - operationHandle - """ - - thrift_spec = ( - None, # 0 - (1, TType.STRUCT, 'operationHandle', (TOperationHandle, TOperationHandle.thrift_spec), None, ), # 1 - ) - - def __init__(self, operationHandle=None,): - self.operationHandle = operationHandle - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 1: - if ftype == TType.STRUCT: - self.operationHandle = TOperationHandle() - self.operationHandle.read(iprot) - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('TCloseOperationReq') - if self.operationHandle is not None: - oprot.writeFieldBegin('operationHandle', TType.STRUCT, 1) - self.operationHandle.write(oprot) - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - if self.operationHandle is None: - raise TProtocolException(message='Required field operationHandle is unset!') - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) - - -class TCloseOperationResp(object): - """ - Attributes: - - status - """ - - thrift_spec = ( - None, # 0 - (1, TType.STRUCT, 'status', (TStatus, TStatus.thrift_spec), None, ), # 1 - ) - - def __init__(self, status=None,): - self.status = status - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 1: - if ftype == TType.STRUCT: - self.status = TStatus() - self.status.read(iprot) - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('TCloseOperationResp') - if self.status is not None: - oprot.writeFieldBegin('status', TType.STRUCT, 1) - self.status.write(oprot) - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - if self.status is None: - raise TProtocolException(message='Required field status is unset!') - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) - - -class TGetResultSetMetadataReq(object): - """ - Attributes: - - operationHandle - """ - - thrift_spec = ( - None, # 0 - (1, TType.STRUCT, 'operationHandle', (TOperationHandle, TOperationHandle.thrift_spec), None, ), # 1 - ) - - def __init__(self, operationHandle=None,): - self.operationHandle = operationHandle - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 1: - if ftype == TType.STRUCT: - self.operationHandle = TOperationHandle() - self.operationHandle.read(iprot) - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('TGetResultSetMetadataReq') - if self.operationHandle is not None: - oprot.writeFieldBegin('operationHandle', TType.STRUCT, 1) - self.operationHandle.write(oprot) - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - if self.operationHandle is None: - raise TProtocolException(message='Required field operationHandle is unset!') - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) - - -class TGetResultSetMetadataResp(object): - """ - Attributes: - - status - - schema - """ - - thrift_spec = ( - None, # 0 - (1, TType.STRUCT, 'status', (TStatus, TStatus.thrift_spec), None, ), # 1 - (2, TType.STRUCT, 'schema', (TTableSchema, TTableSchema.thrift_spec), None, ), # 2 - ) - - def __init__(self, status=None, schema=None,): - self.status = status - self.schema = schema - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 1: - if ftype == TType.STRUCT: - self.status = TStatus() - self.status.read(iprot) - else: - iprot.skip(ftype) - elif fid == 2: - if ftype == TType.STRUCT: - self.schema = TTableSchema() - self.schema.read(iprot) - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('TGetResultSetMetadataResp') - if self.status is not None: - oprot.writeFieldBegin('status', TType.STRUCT, 1) - self.status.write(oprot) - oprot.writeFieldEnd() - if self.schema is not None: - oprot.writeFieldBegin('schema', TType.STRUCT, 2) - self.schema.write(oprot) - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - if self.status is None: - raise TProtocolException(message='Required field status is unset!') - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) - - -class TFetchResultsReq(object): - """ - Attributes: - - operationHandle - - orientation - - maxRows - - fetchType - """ - - thrift_spec = ( - None, # 0 - (1, TType.STRUCT, 'operationHandle', (TOperationHandle, TOperationHandle.thrift_spec), None, ), # 1 - (2, TType.I32, 'orientation', None, 0, ), # 2 - (3, TType.I64, 'maxRows', None, None, ), # 3 - (4, TType.I16, 'fetchType', None, 0, ), # 4 - ) - - def __init__(self, operationHandle=None, orientation=thrift_spec[2][4], maxRows=None, fetchType=thrift_spec[4][4],): - self.operationHandle = operationHandle - self.orientation = orientation - self.maxRows = maxRows - self.fetchType = fetchType - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 1: - if ftype == TType.STRUCT: - self.operationHandle = TOperationHandle() - self.operationHandle.read(iprot) - else: - iprot.skip(ftype) - elif fid == 2: - if ftype == TType.I32: - self.orientation = iprot.readI32() - else: - iprot.skip(ftype) - elif fid == 3: - if ftype == TType.I64: - self.maxRows = iprot.readI64() - else: - iprot.skip(ftype) - elif fid == 4: - if ftype == TType.I16: - self.fetchType = iprot.readI16() - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('TFetchResultsReq') - if self.operationHandle is not None: - oprot.writeFieldBegin('operationHandle', TType.STRUCT, 1) - self.operationHandle.write(oprot) - oprot.writeFieldEnd() - if self.orientation is not None: - oprot.writeFieldBegin('orientation', TType.I32, 2) - oprot.writeI32(self.orientation) - oprot.writeFieldEnd() - if self.maxRows is not None: - oprot.writeFieldBegin('maxRows', TType.I64, 3) - oprot.writeI64(self.maxRows) - oprot.writeFieldEnd() - if self.fetchType is not None: - oprot.writeFieldBegin('fetchType', TType.I16, 4) - oprot.writeI16(self.fetchType) - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - if self.operationHandle is None: - raise TProtocolException(message='Required field operationHandle is unset!') - if self.orientation is None: - raise TProtocolException(message='Required field orientation is unset!') - if self.maxRows is None: - raise TProtocolException(message='Required field maxRows is unset!') - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) - - -class TFetchResultsResp(object): - """ - Attributes: - - status - - hasMoreRows - - results - """ - - thrift_spec = ( - None, # 0 - (1, TType.STRUCT, 'status', (TStatus, TStatus.thrift_spec), None, ), # 1 - (2, TType.BOOL, 'hasMoreRows', None, None, ), # 2 - (3, TType.STRUCT, 'results', (TRowSet, TRowSet.thrift_spec), None, ), # 3 - ) - - def __init__(self, status=None, hasMoreRows=None, results=None,): - self.status = status - self.hasMoreRows = hasMoreRows - self.results = results - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 1: - if ftype == TType.STRUCT: - self.status = TStatus() - self.status.read(iprot) - else: - iprot.skip(ftype) - elif fid == 2: - if ftype == TType.BOOL: - self.hasMoreRows = iprot.readBool() - else: - iprot.skip(ftype) - elif fid == 3: - if ftype == TType.STRUCT: - self.results = TRowSet() - self.results.read(iprot) - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('TFetchResultsResp') - if self.status is not None: - oprot.writeFieldBegin('status', TType.STRUCT, 1) - self.status.write(oprot) - oprot.writeFieldEnd() - if self.hasMoreRows is not None: - oprot.writeFieldBegin('hasMoreRows', TType.BOOL, 2) - oprot.writeBool(self.hasMoreRows) - oprot.writeFieldEnd() - if self.results is not None: - oprot.writeFieldBegin('results', TType.STRUCT, 3) - self.results.write(oprot) - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - if self.status is None: - raise TProtocolException(message='Required field status is unset!') - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) - - -class TGetDelegationTokenReq(object): - """ - Attributes: - - sessionHandle - - owner - - renewer - """ - - thrift_spec = ( - None, # 0 - (1, TType.STRUCT, 'sessionHandle', (TSessionHandle, TSessionHandle.thrift_spec), None, ), # 1 - (2, TType.STRING, 'owner', 'UTF8', None, ), # 2 - (3, TType.STRING, 'renewer', 'UTF8', None, ), # 3 - ) - - def __init__(self, sessionHandle=None, owner=None, renewer=None,): - self.sessionHandle = sessionHandle - self.owner = owner - self.renewer = renewer - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 1: - if ftype == TType.STRUCT: - self.sessionHandle = TSessionHandle() - self.sessionHandle.read(iprot) - else: - iprot.skip(ftype) - elif fid == 2: - if ftype == TType.STRING: - self.owner = iprot.readString().decode('utf-8') if sys.version_info[0] == 2 else iprot.readString() - else: - iprot.skip(ftype) - elif fid == 3: - if ftype == TType.STRING: - self.renewer = iprot.readString().decode('utf-8') if sys.version_info[0] == 2 else iprot.readString() - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('TGetDelegationTokenReq') - if self.sessionHandle is not None: - oprot.writeFieldBegin('sessionHandle', TType.STRUCT, 1) - self.sessionHandle.write(oprot) - oprot.writeFieldEnd() - if self.owner is not None: - oprot.writeFieldBegin('owner', TType.STRING, 2) - oprot.writeString(self.owner.encode('utf-8') if sys.version_info[0] == 2 else self.owner) - oprot.writeFieldEnd() - if self.renewer is not None: - oprot.writeFieldBegin('renewer', TType.STRING, 3) - oprot.writeString(self.renewer.encode('utf-8') if sys.version_info[0] == 2 else self.renewer) - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - if self.sessionHandle is None: - raise TProtocolException(message='Required field sessionHandle is unset!') - if self.owner is None: - raise TProtocolException(message='Required field owner is unset!') - if self.renewer is None: - raise TProtocolException(message='Required field renewer is unset!') - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) - - -class TGetDelegationTokenResp(object): - """ - Attributes: - - status - - delegationToken - """ - - thrift_spec = ( - None, # 0 - (1, TType.STRUCT, 'status', (TStatus, TStatus.thrift_spec), None, ), # 1 - (2, TType.STRING, 'delegationToken', 'UTF8', None, ), # 2 - ) - - def __init__(self, status=None, delegationToken=None,): - self.status = status - self.delegationToken = delegationToken - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 1: - if ftype == TType.STRUCT: - self.status = TStatus() - self.status.read(iprot) - else: - iprot.skip(ftype) - elif fid == 2: - if ftype == TType.STRING: - self.delegationToken = iprot.readString().decode('utf-8') if sys.version_info[0] == 2 else iprot.readString() - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('TGetDelegationTokenResp') - if self.status is not None: - oprot.writeFieldBegin('status', TType.STRUCT, 1) - self.status.write(oprot) - oprot.writeFieldEnd() - if self.delegationToken is not None: - oprot.writeFieldBegin('delegationToken', TType.STRING, 2) - oprot.writeString(self.delegationToken.encode('utf-8') if sys.version_info[0] == 2 else self.delegationToken) - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - if self.status is None: - raise TProtocolException(message='Required field status is unset!') - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) - - -class TCancelDelegationTokenReq(object): - """ - Attributes: - - sessionHandle - - delegationToken - """ - - thrift_spec = ( - None, # 0 - (1, TType.STRUCT, 'sessionHandle', (TSessionHandle, TSessionHandle.thrift_spec), None, ), # 1 - (2, TType.STRING, 'delegationToken', 'UTF8', None, ), # 2 - ) - - def __init__(self, sessionHandle=None, delegationToken=None,): - self.sessionHandle = sessionHandle - self.delegationToken = delegationToken - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 1: - if ftype == TType.STRUCT: - self.sessionHandle = TSessionHandle() - self.sessionHandle.read(iprot) - else: - iprot.skip(ftype) - elif fid == 2: - if ftype == TType.STRING: - self.delegationToken = iprot.readString().decode('utf-8') if sys.version_info[0] == 2 else iprot.readString() - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('TCancelDelegationTokenReq') - if self.sessionHandle is not None: - oprot.writeFieldBegin('sessionHandle', TType.STRUCT, 1) - self.sessionHandle.write(oprot) - oprot.writeFieldEnd() - if self.delegationToken is not None: - oprot.writeFieldBegin('delegationToken', TType.STRING, 2) - oprot.writeString(self.delegationToken.encode('utf-8') if sys.version_info[0] == 2 else self.delegationToken) - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - if self.sessionHandle is None: - raise TProtocolException(message='Required field sessionHandle is unset!') - if self.delegationToken is None: - raise TProtocolException(message='Required field delegationToken is unset!') - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) - - -class TCancelDelegationTokenResp(object): - """ - Attributes: - - status - """ - - thrift_spec = ( - None, # 0 - (1, TType.STRUCT, 'status', (TStatus, TStatus.thrift_spec), None, ), # 1 - ) - - def __init__(self, status=None,): - self.status = status - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 1: - if ftype == TType.STRUCT: - self.status = TStatus() - self.status.read(iprot) - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('TCancelDelegationTokenResp') - if self.status is not None: - oprot.writeFieldBegin('status', TType.STRUCT, 1) - self.status.write(oprot) - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - if self.status is None: - raise TProtocolException(message='Required field status is unset!') - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) - - -class TRenewDelegationTokenReq(object): - """ - Attributes: - - sessionHandle - - delegationToken - """ - - thrift_spec = ( - None, # 0 - (1, TType.STRUCT, 'sessionHandle', (TSessionHandle, TSessionHandle.thrift_spec), None, ), # 1 - (2, TType.STRING, 'delegationToken', 'UTF8', None, ), # 2 - ) - - def __init__(self, sessionHandle=None, delegationToken=None,): - self.sessionHandle = sessionHandle - self.delegationToken = delegationToken - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 1: - if ftype == TType.STRUCT: - self.sessionHandle = TSessionHandle() - self.sessionHandle.read(iprot) - else: - iprot.skip(ftype) - elif fid == 2: - if ftype == TType.STRING: - self.delegationToken = iprot.readString().decode('utf-8') if sys.version_info[0] == 2 else iprot.readString() - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('TRenewDelegationTokenReq') - if self.sessionHandle is not None: - oprot.writeFieldBegin('sessionHandle', TType.STRUCT, 1) - self.sessionHandle.write(oprot) - oprot.writeFieldEnd() - if self.delegationToken is not None: - oprot.writeFieldBegin('delegationToken', TType.STRING, 2) - oprot.writeString(self.delegationToken.encode('utf-8') if sys.version_info[0] == 2 else self.delegationToken) - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - if self.sessionHandle is None: - raise TProtocolException(message='Required field sessionHandle is unset!') - if self.delegationToken is None: - raise TProtocolException(message='Required field delegationToken is unset!') - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) - - -class TRenewDelegationTokenResp(object): - """ - Attributes: - - status - """ - - thrift_spec = ( - None, # 0 - (1, TType.STRUCT, 'status', (TStatus, TStatus.thrift_spec), None, ), # 1 - ) - - def __init__(self, status=None,): - self.status = status - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 1: - if ftype == TType.STRUCT: - self.status = TStatus() - self.status.read(iprot) - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('TRenewDelegationTokenResp') - if self.status is not None: - oprot.writeFieldBegin('status', TType.STRUCT, 1) - self.status.write(oprot) - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - if self.status is None: - raise TProtocolException(message='Required field status is unset!') - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) - -class TGetLogReq(object): - """ - Attributes: - - operationHandle - """ - - thrift_spec = ( - None, # 0 - (1, TType.STRUCT, 'operationHandle', (TOperationHandle, TOperationHandle.thrift_spec), None, ), # 1 - ) - - def __init__(self, operationHandle=None,): - self.operationHandle = operationHandle - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 1: - if ftype == TType.STRUCT: - self.operationHandle = TOperationHandle() - self.operationHandle.read(iprot) - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('TGetLogReq') - if self.operationHandle is not None: - oprot.writeFieldBegin('operationHandle', TType.STRUCT, 1) - self.operationHandle.write(oprot) - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - if self.operationHandle is None: - raise TProtocolException(message='Required field operationHandle is unset!') - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) - - -class TGetLogResp(object): - """ - Attributes: - - status - - log - """ - - thrift_spec = ( - None, # 0 - (1, TType.STRUCT, 'status', (TStatus, TStatus.thrift_spec), None, ), # 1 - (2, TType.STRING, 'log', 'UTF8', None, ), # 2 - ) - - def __init__(self, status=None, log=None,): - self.status = status - self.log = log - - def read(self, iprot): - if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None: - iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec)) - return - iprot.readStructBegin() - while True: - (fname, ftype, fid) = iprot.readFieldBegin() - if ftype == TType.STOP: - break - if fid == 1: - if ftype == TType.STRUCT: - self.status = TStatus() - self.status.read(iprot) - else: - iprot.skip(ftype) - elif fid == 2: - if ftype == TType.STRING: - self.log = iprot.readString().decode('utf-8') if sys.version_info[0] == 2 else iprot.readString() - else: - iprot.skip(ftype) - else: - iprot.skip(ftype) - iprot.readFieldEnd() - iprot.readStructEnd() - - def write(self, oprot): - if oprot._fast_encode is not None and self.thrift_spec is not None: - oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec))) - return - oprot.writeStructBegin('TGetLogResp') - if self.status is not None: - oprot.writeFieldBegin('status', TType.STRUCT, 1) - self.status.write(oprot) - oprot.writeFieldEnd() - if self.log is not None: - oprot.writeFieldBegin('log', TType.STRING, 2) - oprot.writeString(self.log.encode('utf-8') if sys.version_info[0] == 2 else self.log) - oprot.writeFieldEnd() - oprot.writeFieldStop() - oprot.writeStructEnd() - - def validate(self): - if self.status is None: - raise TProtocolException(message='Required field status is unset!') - if self.log is None: - raise TProtocolException(message='Required field log is unset!') - return - - def __repr__(self): - L = ['%s=%r' % (key, value) - for key, value in self.__dict__.items()] - return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other): - return not (self == other) diff --git a/src/chronify/_vendor/kyuubi/pyhive/__init__.py b/src/chronify/_vendor/kyuubi/pyhive/__init__.py deleted file mode 100644 index 0a6bb1f..0000000 --- a/src/chronify/_vendor/kyuubi/pyhive/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from __future__ import absolute_import -from __future__ import unicode_literals -__version__ = '0.7.0' diff --git a/src/chronify/_vendor/kyuubi/pyhive/common.py b/src/chronify/_vendor/kyuubi/pyhive/common.py deleted file mode 100644 index 51692b9..0000000 --- a/src/chronify/_vendor/kyuubi/pyhive/common.py +++ /dev/null @@ -1,266 +0,0 @@ -"""Package private common utilities. Do not use directly. - -Many docstrings in this file are based on PEP-249, which is in the public domain. -""" - -from __future__ import absolute_import -from __future__ import unicode_literals -from builtins import bytes -from builtins import int -from builtins import object -from builtins import str -from past.builtins import basestring -from pyhive import exc -import abc -import collections -import time -import datetime -from future.utils import with_metaclass -from itertools import islice - -try: - from collections.abc import Iterable -except ImportError: - from collections import Iterable - - -class DBAPICursor(with_metaclass(abc.ABCMeta, object)): - """Base class for some common DB-API logic""" - - _STATE_NONE = 0 - _STATE_RUNNING = 1 - _STATE_FINISHED = 2 - - def __init__(self, poll_interval=1): - self._poll_interval = poll_interval - self._reset_state() - self.lastrowid = None - - def _reset_state(self): - """Reset state about the previous query in preparation for running another query""" - # State to return as part of DB-API - self._rownumber = 0 - - # Internal helper state - self._state = self._STATE_NONE - self._data = collections.deque() - self._columns = None - - def _fetch_while(self, fn): - while fn(): - self._fetch_more() - if fn(): - time.sleep(self._poll_interval) - - @abc.abstractproperty - def description(self): - raise NotImplementedError # pragma: no cover - - def close(self): - """By default, do nothing""" - pass - - @abc.abstractmethod - def _fetch_more(self): - """Get more results, append it to ``self._data``, and update ``self._state``.""" - raise NotImplementedError # pragma: no cover - - @property - def rowcount(self): - """By default, return -1 to indicate that this is not supported.""" - return -1 - - @abc.abstractmethod - def execute(self, operation, parameters=None): - """Prepare and execute a database operation (query or command). - - Parameters may be provided as sequence or mapping and will be bound to variables in the - operation. Variables are specified in a database-specific notation (see the module's - ``paramstyle`` attribute for details). - - Return values are not defined. - """ - raise NotImplementedError # pragma: no cover - - def executemany(self, operation, seq_of_parameters): - """Prepare a database operation (query or command) and then execute it against all parameter - sequences or mappings found in the sequence ``seq_of_parameters``. - - Only the final result set is retained. - - Return values are not defined. - """ - for parameters in seq_of_parameters[:-1]: - self.execute(operation, parameters) - while self._state != self._STATE_FINISHED: - self._fetch_more() - if seq_of_parameters: - self.execute(operation, seq_of_parameters[-1]) - - def fetchone(self): - """Fetch the next row of a query result set, returning a single sequence, or ``None`` when - no more data is available. - - An :py:class:`~pyhive.exc.Error` (or subclass) exception is raised if the previous call to - :py:meth:`execute` did not produce any result set or no call was issued yet. - """ - if self._state == self._STATE_NONE: - raise exc.ProgrammingError("No query yet") - - # Sleep until we're done or we have some data to return - self._fetch_while(lambda: not self._data and self._state != self._STATE_FINISHED) - - if not self._data: - return None - else: - self._rownumber += 1 - return self._data.popleft() - - def fetchmany(self, size=None): - """Fetch the next set of rows of a query result, returning a sequence of sequences (e.g. a - list of tuples). An empty sequence is returned when no more rows are available. - - The number of rows to fetch per call is specified by the parameter. If it is not given, the - cursor's arraysize determines the number of rows to be fetched. The method should try to - fetch as many rows as indicated by the size parameter. If this is not possible due to the - specified number of rows not being available, fewer rows may be returned. - - An :py:class:`~pyhive.exc.Error` (or subclass) exception is raised if the previous call to - :py:meth:`execute` did not produce any result set or no call was issued yet. - """ - if size is None: - size = self.arraysize - return list(islice(iter(self.fetchone, None), size)) - - def fetchall(self): - """Fetch all (remaining) rows of a query result, returning them as a sequence of sequences - (e.g. a list of tuples). - - An :py:class:`~pyhive.exc.Error` (or subclass) exception is raised if the previous call to - :py:meth:`execute` did not produce any result set or no call was issued yet. - """ - return list(iter(self.fetchone, None)) - - @property - def arraysize(self): - """This read/write attribute specifies the number of rows to fetch at a time with - :py:meth:`fetchmany`. It defaults to 1 meaning to fetch a single row at a time. - """ - return self._arraysize - - @arraysize.setter - def arraysize(self, value): - self._arraysize = value - - def setinputsizes(self, sizes): - """Does nothing by default""" - pass - - def setoutputsize(self, size, column=None): - """Does nothing by default""" - pass - - # - # Optional DB API Extensions - # - - @property - def rownumber(self): - """This read-only attribute should provide the current 0-based index of the cursor in the - result set. - - The index can be seen as index of the cursor in a sequence (the result set). The next fetch - operation will fetch the row indexed by ``rownumber`` in that sequence. - """ - return self._rownumber - - def __next__(self): - """Return the next row from the currently executing SQL statement using the same semantics - as :py:meth:`fetchone`. A ``StopIteration`` exception is raised when the result set is - exhausted. - """ - one = self.fetchone() - if one is None: - raise StopIteration - else: - return one - - next = __next__ - - def __iter__(self): - """Return self to make cursors compatible to the iteration protocol.""" - return self - - -class DBAPITypeObject(object): - # Taken from http://www.python.org/dev/peps/pep-0249/#implementation-hints - def __init__(self, *values): - self.values = values - - def __cmp__(self, other): - if other in self.values: - return 0 - if other < self.values: - return 1 - else: - return -1 - - -class ParamEscaper(object): - _DATE_FORMAT = "%Y-%m-%d" - _TIME_FORMAT = "%H:%M:%S.%f" - _DATETIME_FORMAT = "{} {}".format(_DATE_FORMAT, _TIME_FORMAT) - - def escape_args(self, parameters): - if isinstance(parameters, dict): - return {k: self.escape_item(v) for k, v in parameters.items()} - elif isinstance(parameters, (list, tuple)): - return tuple(self.escape_item(x) for x in parameters) - else: - raise exc.ProgrammingError("Unsupported param format: {}".format(parameters)) - - def escape_number(self, item): - return item - - def escape_string(self, item): - # Need to decode UTF-8 because of old sqlalchemy. - # Newer SQLAlchemy checks dialect.supports_unicode_binds before encoding Unicode strings - # as byte strings. The old version always encodes Unicode as byte strings, which breaks - # string formatting here. - if isinstance(item, bytes): - item = item.decode('utf-8') - # This is good enough when backslashes are literal, newlines are just followed, and the way - # to escape a single quote is to put two single quotes. - # (i.e. only special character is single quote) - return "'{}'".format(item.replace("'", "''")) - - def escape_sequence(self, item): - l = map(str, map(self.escape_item, item)) - return '(' + ','.join(l) + ')' - - def escape_datetime(self, item, format, cutoff=0): - dt_str = item.strftime(format) - formatted = dt_str[:-cutoff] if cutoff and format.endswith(".%f") else dt_str - return "'{}'".format(formatted) - - def escape_item(self, item): - if item is None: - return 'NULL' - elif isinstance(item, (int, float)): - return self.escape_number(item) - elif isinstance(item, basestring): - return self.escape_string(item) - elif isinstance(item, Iterable): - return self.escape_sequence(item) - elif isinstance(item, datetime.datetime): - return self.escape_datetime(item, self._DATETIME_FORMAT) - elif isinstance(item, datetime.date): - return self.escape_datetime(item, self._DATE_FORMAT) - else: - raise exc.ProgrammingError("Unsupported object {}".format(item)) - - -class UniversalSet(object): - """set containing everything""" - def __contains__(self, item): - return True diff --git a/src/chronify/_vendor/kyuubi/pyhive/exc.py b/src/chronify/_vendor/kyuubi/pyhive/exc.py deleted file mode 100644 index 931cf21..0000000 --- a/src/chronify/_vendor/kyuubi/pyhive/exc.py +++ /dev/null @@ -1,72 +0,0 @@ -""" -Package private common utilities. Do not use directly. -""" -from __future__ import absolute_import -from __future__ import unicode_literals - -__all__ = [ - 'Error', 'Warning', 'InterfaceError', 'DatabaseError', 'InternalError', 'OperationalError', - 'ProgrammingError', 'DataError', 'NotSupportedError', -] - - -class Error(Exception): - """Exception that is the base class of all other error exceptions. - - You can use this to catch all errors with one single except statement. - """ - pass - - -class Warning(Exception): - """Exception raised for important warnings like data truncations while inserting, etc.""" - pass - - -class InterfaceError(Error): - """Exception raised for errors that are related to the database interface rather than the - database itself. - """ - pass - - -class DatabaseError(Error): - """Exception raised for errors that are related to the database.""" - pass - - -class InternalError(DatabaseError): - """Exception raised when the database encounters an internal error, e.g. the cursor is not valid - anymore, the transaction is out of sync, etc.""" - pass - - -class OperationalError(DatabaseError): - """Exception raised for errors that are related to the database's operation and not necessarily - under the control of the programmer, e.g. an unexpected disconnect occurs, the data source name - is not found, a transaction could not be processed, a memory allocation error occurred during - processing, etc. - """ - pass - - -class ProgrammingError(DatabaseError): - """Exception raised for programming errors, e.g. table not found or already exists, syntax error - in the SQL statement, wrong number of parameters specified, etc. - """ - pass - - -class DataError(DatabaseError): - """Exception raised for errors that are due to problems with the processed data like division by - zero, numeric value out of range, etc. - """ - pass - - -class NotSupportedError(DatabaseError): - """Exception raised in case a method or database API was used which is not supported by the - database, e.g. requesting a ``.rollback()`` on a connection that does not support transaction or - has transactions turned off. - """ - pass diff --git a/src/chronify/_vendor/kyuubi/pyhive/hive.py b/src/chronify/_vendor/kyuubi/pyhive/hive.py deleted file mode 100644 index d6f8080..0000000 --- a/src/chronify/_vendor/kyuubi/pyhive/hive.py +++ /dev/null @@ -1,620 +0,0 @@ -"""DB-API implementation backed by HiveServer2 (Thrift API) - -See http://www.python.org/dev/peps/pep-0249/ - -Many docstrings in this file are based on the PEP, which is in the public domain. -""" - -from __future__ import absolute_import -from __future__ import unicode_literals - -import base64 -import datetime -import re -from decimal import Decimal -from ssl import CERT_NONE, CERT_OPTIONAL, CERT_REQUIRED, create_default_context - - -from TCLIService import TCLIService -from TCLIService import constants -from TCLIService import ttypes -from pyhive import common -from pyhive.common import DBAPITypeObject -# Make all exceptions visible in this module per DB-API -from pyhive.exc import * # noqa -from builtins import range -import contextlib -from future.utils import iteritems -import getpass -import logging -import sys -import thrift.transport.THttpClient -import thrift.protocol.TBinaryProtocol -import thrift.transport.TSocket -import thrift.transport.TTransport - -# PEP 249 module globals -apilevel = '2.0' -threadsafety = 2 # Threads may share the module and connections. -paramstyle = 'pyformat' # Python extended format codes, e.g. ...WHERE name=%(name)s - -_logger = logging.getLogger(__name__) - -_TIMESTAMP_PATTERN = re.compile(r'(\d+-\d+-\d+ \d+:\d+:\d+(\.\d{,6})?)') - -ssl_cert_parameter_map = { - "none": CERT_NONE, - "optional": CERT_OPTIONAL, - "required": CERT_REQUIRED, -} - - -def get_sasl_client(host, sasl_auth, service=None, username=None, password=None): - import sasl - sasl_client = sasl.Client() - sasl_client.setAttr('host', host) - - if sasl_auth == 'GSSAPI': - sasl_client.setAttr('service', service) - elif sasl_auth == 'PLAIN': - sasl_client.setAttr('username', username) - sasl_client.setAttr('password', password) - else: - raise ValueError("sasl_auth only supports GSSAPI and PLAIN") - - sasl_client.init() - return sasl_client - - -def get_pure_sasl_client(host, sasl_auth, service=None, username=None, password=None): - from pyhive.sasl_compat import PureSASLClient - - if sasl_auth == 'GSSAPI': - sasl_kwargs = {'service': service} - elif sasl_auth == 'PLAIN': - sasl_kwargs = {'username': username, 'password': password} - else: - raise ValueError("sasl_auth only supports GSSAPI and PLAIN") - - return PureSASLClient(host=host, **sasl_kwargs) - - -def get_installed_sasl(host, sasl_auth, service=None, username=None, password=None): - try: - return get_sasl_client(host=host, sasl_auth=sasl_auth, service=service, username=username, password=password) - # The sasl library is available - except ImportError: - # Fallback to pure-sasl library - return get_pure_sasl_client(host=host, sasl_auth=sasl_auth, service=service, username=username, password=password) - - -def _parse_timestamp(value): - if value: - match = _TIMESTAMP_PATTERN.match(value) - if match: - if match.group(2): - format = '%Y-%m-%d %H:%M:%S.%f' - # use the pattern to truncate the value - value = match.group() - else: - format = '%Y-%m-%d %H:%M:%S' - value = datetime.datetime.strptime(value, format) - else: - raise Exception( - 'Cannot convert "{}" into a datetime'.format(value)) - else: - value = None - return value - - -TYPES_CONVERTER = {"DECIMAL_TYPE": Decimal, - "TIMESTAMP_TYPE": _parse_timestamp} - - -class HiveParamEscaper(common.ParamEscaper): - def escape_string(self, item): - # backslashes and single quotes need to be escaped - # TODO verify against parser - # Need to decode UTF-8 because of old sqlalchemy. - # Newer SQLAlchemy checks dialect.supports_unicode_binds before encoding Unicode strings - # as byte strings. The old version always encodes Unicode as byte strings, which breaks - # string formatting here. - if isinstance(item, bytes): - item = item.decode('utf-8') - return "'{}'".format( - item - .replace('\\', '\\\\') - .replace("'", "\\'") - .replace('\r', '\\r') - .replace('\n', '\\n') - .replace('\t', '\\t') - ) - - -_escaper = HiveParamEscaper() - - -def connect(*args, **kwargs): - """Constructor for creating a connection to the database. See class :py:class:`Connection` for - arguments. - - :returns: a :py:class:`Connection` object. - """ - return Connection(*args, **kwargs) - - -class Connection(object): - """Wraps a Thrift session""" - - def __init__( - self, - host=None, - port=None, - scheme=None, - username=None, - database='default', - auth=None, - configuration=None, - kerberos_service_name=None, - password=None, - check_hostname=None, - ssl_cert=None, - thrift_transport=None, - ssl_context=None - ): - """Connect to HiveServer2 - - :param host: What host HiveServer2 runs on - :param port: What port HiveServer2 runs on. Defaults to 10000. - :param auth: The value of hive.server2.authentication used by HiveServer2. - Defaults to ``NONE``. - :param configuration: A dictionary of Hive settings (functionally same as the `set` command) - :param kerberos_service_name: Use with auth='KERBEROS' only - :param password: Use with auth='LDAP' or auth='CUSTOM' only - :param thrift_transport: A ``TTransportBase`` for custom advanced usage. - Incompatible with host, port, auth, kerberos_service_name, and password. - :param ssl_context: A custom SSL context to use for HTTPS connections. If provided, - this overrides check_hostname and ssl_cert parameters. - The way to support LDAP and GSSAPI is originated from cloudera/Impyla: - https://github.com/cloudera/impyla/blob/255b07ed973d47a3395214ed92d35ec0615ebf62 - /impala/_thrift_api.py#L152-L160 - """ - if scheme in ("https", "http") and thrift_transport is None: - port = port or 1000 - if scheme == "https": - if ssl_context is None: - ssl_context = create_default_context() - ssl_context.check_hostname = check_hostname == "true" - ssl_cert = ssl_cert or "none" - ssl_context.verify_mode = ssl_cert_parameter_map.get(ssl_cert, CERT_NONE) - thrift_transport = thrift.transport.THttpClient.THttpClient( - uri_or_host="{scheme}://{host}:{port}/cliservice/".format( - scheme=scheme, host=host, port=port - ), - ssl_context=ssl_context, - ) - - if auth in ("BASIC", "NOSASL", "NONE", None): - # Always needs the Authorization header - self._set_authorization_header(thrift_transport, username, password) - elif auth == "KERBEROS" and kerberos_service_name: - self._set_kerberos_header(thrift_transport, kerberos_service_name, host) - else: - raise ValueError( - "Authentication is not valid use one of:" - "BASIC, NOSASL, KERBEROS, NONE" - ) - host, port, auth, kerberos_service_name, password = ( - None, None, None, None, None - ) - - username = username or getpass.getuser() - configuration = configuration or {} - - if (password is not None) != (auth in ('LDAP', 'CUSTOM')): - raise ValueError("Password should be set if and only if in LDAP or CUSTOM mode; " - "Remove password or use one of those modes") - if (kerberos_service_name is not None) != (auth == 'KERBEROS'): - raise ValueError("kerberos_service_name should be set if and only if in KERBEROS mode") - if thrift_transport is not None: - has_incompatible_arg = ( - host is not None - or port is not None - or auth is not None - or kerberos_service_name is not None - or password is not None - ) - if has_incompatible_arg: - raise ValueError("thrift_transport cannot be used with " - "host/port/auth/kerberos_service_name/password") - - if thrift_transport is not None: - self._transport = thrift_transport - else: - if port is None: - port = 10000 - if auth is None: - auth = 'NONE' - socket = thrift.transport.TSocket.TSocket(host, port) - if auth == 'NOSASL': - # NOSASL corresponds to hive.server2.authentication=NOSASL in hive-site.xml - self._transport = thrift.transport.TTransport.TBufferedTransport(socket) - elif auth in ('LDAP', 'KERBEROS', 'NONE', 'CUSTOM'): - # Defer import so package dependency is optional - import thrift_sasl - - if auth == 'KERBEROS': - # KERBEROS mode in hive.server2.authentication is GSSAPI in sasl library - sasl_auth = 'GSSAPI' - else: - sasl_auth = 'PLAIN' - if password is None: - # Password doesn't matter in NONE mode, just needs to be nonempty. - password = 'x' - - self._transport = thrift_sasl.TSaslClientTransport(lambda: get_installed_sasl(host=host, sasl_auth=sasl_auth, service=kerberos_service_name, username=username, password=password), sasl_auth, socket) - else: - # All HS2 config options: - # https://cwiki.apache.org/confluence/display/Hive/Setting+Up+HiveServer2#SettingUpHiveServer2-Configuration - # PAM currently left to end user via thrift_transport option. - raise NotImplementedError( - "Only NONE, NOSASL, LDAP, KERBEROS, CUSTOM " - "authentication are supported, got {}".format(auth)) - - protocol = thrift.protocol.TBinaryProtocol.TBinaryProtocol(self._transport) - self._client = TCLIService.Client(protocol) - # oldest version that still contains features we care about - # "V6 uses binary type for binary payload (was string) and uses columnar result set" - protocol_version = ttypes.TProtocolVersion.HIVE_CLI_SERVICE_PROTOCOL_V6 - - try: - self._transport.open() - open_session_req = ttypes.TOpenSessionReq( - client_protocol=protocol_version, - configuration=configuration, - username=username, - ) - response = self._client.OpenSession(open_session_req) - _check_status(response) - assert response.sessionHandle is not None, "Expected a session from OpenSession" - self._sessionHandle = response.sessionHandle - assert response.serverProtocolVersion == protocol_version, \ - "Unable to handle protocol version {}".format(response.serverProtocolVersion) - with contextlib.closing(self.cursor()) as cursor: - cursor.execute('USE `{}`'.format(database)) - except: - self._transport.close() - raise - - @staticmethod - def _set_authorization_header(transport, username=None, password=None): - username = username or "user" - password = password or "pass" - auth_credentials = "{username}:{password}".format( - username=username, password=password - ).encode("UTF-8") - auth_credentials_base64 = base64.standard_b64encode(auth_credentials).decode( - "UTF-8" - ) - transport.setCustomHeaders( - { - "Authorization": "Basic {auth_credentials_base64}".format( - auth_credentials_base64=auth_credentials_base64 - ) - } - ) - - @staticmethod - def _set_kerberos_header(transport, kerberos_service_name, host): - import kerberos - - __, krb_context = kerberos.authGSSClientInit( - service="{kerberos_service_name}@{host}".format( - kerberos_service_name=kerberos_service_name, host=host - ) - ) - kerberos.authGSSClientClean(krb_context, "") - kerberos.authGSSClientStep(krb_context, "") - auth_header = kerberos.authGSSClientResponse(krb_context) - - transport.setCustomHeaders( - { - "Authorization": "Negotiate {auth_header}".format( - auth_header=auth_header - ) - } - ) - - def __enter__(self): - """Transport should already be opened by __init__""" - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - """Call close""" - self.close() - - def close(self): - """Close the underlying session and Thrift transport""" - req = ttypes.TCloseSessionReq(sessionHandle=self._sessionHandle) - response = self._client.CloseSession(req) - self._transport.close() - _check_status(response) - - def commit(self): - """Hive does not support transactions, so this does nothing.""" - pass - - def cursor(self, *args, **kwargs): - """Return a new :py:class:`Cursor` object using the connection.""" - return Cursor(self, *args, **kwargs) - - @property - def client(self): - return self._client - - @property - def sessionHandle(self): - return self._sessionHandle - - def rollback(self): - raise NotSupportedError("Hive does not have transactions") # pragma: no cover - - -class Cursor(common.DBAPICursor): - """These objects represent a database cursor, which is used to manage the context of a fetch - operation. - - Cursors are not isolated, i.e., any changes done to the database by a cursor are immediately - visible by other cursors or connections. - """ - - def __init__(self, connection, arraysize=1000): - self._operationHandle = None - super(Cursor, self).__init__() - self._arraysize = arraysize - self._connection = connection - - def _reset_state(self): - """Reset state about the previous query in preparation for running another query""" - super(Cursor, self)._reset_state() - self._description = None - if self._operationHandle is not None: - request = ttypes.TCloseOperationReq(self._operationHandle) - try: - response = self._connection.client.CloseOperation(request) - _check_status(response) - finally: - self._operationHandle = None - - @property - def arraysize(self): - return self._arraysize - - @arraysize.setter - def arraysize(self, value): - """Array size cannot be None, and should be an integer""" - default_arraysize = 1000 - try: - self._arraysize = int(value) or default_arraysize - except TypeError: - self._arraysize = default_arraysize - - @property - def description(self): - """This read-only attribute is a sequence of 7-item sequences. - - Each of these sequences contains information describing one result column: - - - name - - type_code - - display_size (None in current implementation) - - internal_size (None in current implementation) - - precision (None in current implementation) - - scale (None in current implementation) - - null_ok (always True in current implementation) - - This attribute will be ``None`` for operations that do not return rows or if the cursor has - not had an operation invoked via the :py:meth:`execute` method yet. - - The ``type_code`` can be interpreted by comparing it to the Type Objects specified in the - section below. - """ - if self._operationHandle is None or not self._operationHandle.hasResultSet: - return None - if self._description is None: - req = ttypes.TGetResultSetMetadataReq(self._operationHandle) - response = self._connection.client.GetResultSetMetadata(req) - _check_status(response) - columns = response.schema.columns - self._description = [] - for col in columns: - primary_type_entry = col.typeDesc.types[0] - if primary_type_entry.primitiveEntry is None: - # All fancy stuff maps to string - type_code = ttypes.TTypeId._VALUES_TO_NAMES[ttypes.TTypeId.STRING_TYPE] - else: - type_id = primary_type_entry.primitiveEntry.type - try: - type_code = ttypes.TTypeId._VALUES_TO_NAMES[type_id] - except KeyError: - type_code = None - self._description.append(( - col.columnName.decode('utf-8') if sys.version_info[0] == 2 else col.columnName, - type_code.decode('utf-8') if sys.version_info[0] == 2 else type_code, - None, None, None, None, True - )) - return self._description - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - self.close() - - def close(self): - """Close the operation handle""" - self._reset_state() - - def execute(self, operation, parameters=None, **kwargs): - """Prepare and execute a database operation (query or command). - - Return values are not defined. - """ - # backward compatibility with Python < 3.7 - for kw in ['async', 'async_']: - if kw in kwargs: - async_ = kwargs[kw] - break - else: - async_ = False - - # Prepare statement - if parameters is None: - sql = operation - else: - sql = operation % _escaper.escape_args(parameters) - - self._reset_state() - - self._state = self._STATE_RUNNING - _logger.info('%s', sql) - - req = ttypes.TExecuteStatementReq(self._connection.sessionHandle, - sql, runAsync=async_) - _logger.debug(req) - response = self._connection.client.ExecuteStatement(req) - _check_status(response) - self._operationHandle = response.operationHandle - - def cancel(self): - req = ttypes.TCancelOperationReq( - operationHandle=self._operationHandle, - ) - response = self._connection.client.CancelOperation(req) - _check_status(response) - - def _fetch_more(self): - """Send another TFetchResultsReq and update state""" - assert(self._state == self._STATE_RUNNING), "Should be running when in _fetch_more" - assert(self._operationHandle is not None), "Should have an op handle in _fetch_more" - if not self._operationHandle.hasResultSet: - raise ProgrammingError("No result set") - req = ttypes.TFetchResultsReq( - operationHandle=self._operationHandle, - orientation=ttypes.TFetchOrientation.FETCH_NEXT, - maxRows=self.arraysize, - ) - response = self._connection.client.FetchResults(req) - _check_status(response) - schema = self.description - assert not response.results.rows, 'expected data in columnar format' - has_new_data = False - if response.results.columns: - columns = [_unwrap_column(col, col_schema[1]) for col, col_schema in - zip(response.results.columns, schema)] - new_data = list(zip(*columns)) - self._data += new_data - has_new_data = (True if new_data else False) - # response.hasMoreRows seems to always be False, so we instead check the number of rows - # https://github.com/apache/hive/blob/release-1.2.1/service/src/java/org/apache/hive/service/cli/thrift/ThriftCLIService.java#L678 - # if not response.hasMoreRows: - if not has_new_data: - self._state = self._STATE_FINISHED - - def poll(self, get_progress_update=True): - """Poll for and return the raw status data provided by the Hive Thrift REST API. - :returns: ``ttypes.TGetOperationStatusResp`` - :raises: ``ProgrammingError`` when no query has been started - .. note:: - This is not a part of DB-API. - """ - if self._state == self._STATE_NONE: - raise ProgrammingError("No query yet") - - req = ttypes.TGetOperationStatusReq( - operationHandle=self._operationHandle, - getProgressUpdate=get_progress_update, - ) - response = self._connection.client.GetOperationStatus(req) - _check_status(response) - - return response - - def fetch_logs(self): - """Retrieve the logs produced by the execution of the query. - Can be called multiple times to fetch the logs produced after the previous call. - :returns: list - :raises: ``ProgrammingError`` when no query has been started - .. note:: - This is not a part of DB-API. - """ - if self._state == self._STATE_NONE: - raise ProgrammingError("No query yet") - - try: # Older Hive instances require logs to be retrieved using GetLog - req = ttypes.TGetLogReq(operationHandle=self._operationHandle) - logs = self._connection.client.GetLog(req).log.splitlines() - except ttypes.TApplicationException as e: # Otherwise, retrieve logs using newer method - if e.type != ttypes.TApplicationException.UNKNOWN_METHOD: - raise - logs = [] - while True: - req = ttypes.TFetchResultsReq( - operationHandle=self._operationHandle, - orientation=ttypes.TFetchOrientation.FETCH_NEXT, - maxRows=self.arraysize, - fetchType=1 # 0: results, 1: logs - ) - response = self._connection.client.FetchResults(req) - _check_status(response) - assert not response.results.rows, 'expected data in columnar format' - new_logs = '' - if response.results.columns: - new_logs = _unwrap_column(response.results.columns[0]) - logs += new_logs - - if not new_logs: - break - - return logs - - -# -# Type Objects and Constructors -# - - -for type_id in constants.PRIMITIVE_TYPES: - name = ttypes.TTypeId._VALUES_TO_NAMES[type_id] - setattr(sys.modules[__name__], name, DBAPITypeObject([name])) - - -# -# Private utilities -# - - -def _unwrap_column(col, type_=None): - """Return a list of raw values from a TColumn instance.""" - for attr, wrapper in iteritems(col.__dict__): - if wrapper is not None: - result = wrapper.values - nulls = wrapper.nulls # bit set describing what's null - assert isinstance(nulls, bytes) - for i, char in enumerate(nulls): - byte = ord(char) if sys.version_info[0] == 2 else char - for b in range(8): - if byte & (1 << b): - result[i * 8 + b] = None - converter = TYPES_CONVERTER.get(type_, None) - if converter and type_: - result = [converter(row) if row else row for row in result] - return result - raise DataError("Got empty column value {}".format(col)) # pragma: no cover - - -def _check_status(response): - """Raise an OperationalError if the status is not success""" - _logger.debug(response) - if response.status.statusCode != ttypes.TStatusCode.SUCCESS_STATUS: - raise OperationalError(response) diff --git a/src/chronify/_vendor/kyuubi/pyhive/presto.py b/src/chronify/_vendor/kyuubi/pyhive/presto.py deleted file mode 100644 index 3217f4c..0000000 --- a/src/chronify/_vendor/kyuubi/pyhive/presto.py +++ /dev/null @@ -1,367 +0,0 @@ -"""DB-API implementation backed by Presto - -See http://www.python.org/dev/peps/pep-0249/ - -Many docstrings in this file are based on the PEP, which is in the public domain. -""" - -from __future__ import absolute_import -from __future__ import unicode_literals - -from builtins import object -from decimal import Decimal - -from pyhive import common -from pyhive.common import DBAPITypeObject -# Make all exceptions visible in this module per DB-API -from pyhive.exc import * # noqa -import base64 -import getpass -import datetime -import logging -import requests -from requests.auth import HTTPBasicAuth -import os - -try: # Python 3 - import urllib.parse as urlparse -except ImportError: # Python 2 - import urlparse - - -# PEP 249 module globals -apilevel = '2.0' -threadsafety = 2 # Threads may share the module and connections. -paramstyle = 'pyformat' # Python extended format codes, e.g. ...WHERE name=%(name)s - -_logger = logging.getLogger(__name__) - -TYPES_CONVERTER = { - "decimal": Decimal, - # As of Presto 0.69, binary data is returned as the varbinary type in base64 format - "varbinary": base64.b64decode -} - -class PrestoParamEscaper(common.ParamEscaper): - def escape_datetime(self, item, format): - _type = "timestamp" if isinstance(item, datetime.datetime) else "date" - formatted = super(PrestoParamEscaper, self).escape_datetime(item, format, 3) - return "{} {}".format(_type, formatted) - - -_escaper = PrestoParamEscaper() - - -def connect(*args, **kwargs): - """Constructor for creating a connection to the database. See class :py:class:`Connection` for - arguments. - - :returns: a :py:class:`Connection` object. - """ - return Connection(*args, **kwargs) - - -class Connection(object): - """Presto does not have a notion of a persistent connection. - - Thus, these objects are small stateless factories for cursors, which do all the real work. - """ - - def __init__(self, *args, **kwargs): - self._args = args - self._kwargs = kwargs - - def close(self): - """Presto does not have anything to close""" - # TODO cancel outstanding queries? - pass - - def commit(self): - """Presto does not support transactions""" - pass - - def cursor(self): - """Return a new :py:class:`Cursor` object using the connection.""" - return Cursor(*self._args, **self._kwargs) - - def rollback(self): - raise NotSupportedError("Presto does not have transactions") # pragma: no cover - - -class Cursor(common.DBAPICursor): - """These objects represent a database cursor, which is used to manage the context of a fetch - operation. - - Cursors are not isolated, i.e., any changes done to the database by a cursor are immediately - visible by other cursors or connections. - """ - - def __init__(self, host, port='8080', username=None, principal_username=None, catalog='hive', - schema='default', poll_interval=1, source='pyhive', session_props=None, - protocol='http', password=None, requests_session=None, requests_kwargs=None, - KerberosRemoteServiceName=None, KerberosPrincipal=None, - KerberosConfigPath=None, KerberosKeytabPath=None, - KerberosCredentialCachePath=None, KerberosUseCanonicalHostname=None): - """ - :param host: hostname to connect to, e.g. ``presto.example.com`` - :param port: int -- port, defaults to 8080 - :param username: string -- defaults to system user name - :param principal_username: string -- defaults to ``username`` argument if it exists, - else defaults to system user name - :param catalog: string -- defaults to ``hive`` - :param schema: string -- defaults to ``default`` - :param poll_interval: float -- how often to ask the Presto REST interface for a progress - update, defaults to a second - :param source: string -- arbitrary identifier (shows up in the Presto monitoring page) - :param protocol: string -- network protocol, valid options are ``http`` and ``https``. - defaults to ``http`` - :param password: string -- Deprecated. Defaults to ``None``. - Using BasicAuth, requires ``https``. - Prefer ``requests_kwargs={'auth': HTTPBasicAuth(username, password)}``. - May not be specified with ``requests_kwargs['auth']``. - :param requests_session: a ``requests.Session`` object for advanced usage. If absent, this - class will use the default requests behavior of making a new session per HTTP request. - Caller is responsible for closing session. - :param requests_kwargs: Additional ``**kwargs`` to pass to requests - :param KerberosRemoteServiceName: string -- Presto coordinator Kerberos service name. - This parameter is required for Kerberos authentiation. - :param KerberosPrincipal: string -- The principal to use when authenticating to - the Presto coordinator. - :param KerberosConfigPath: string -- Kerberos configuration file. - (default: /etc/krb5.conf) - :param KerberosKeytabPath: string -- Kerberos keytab file. - :param KerberosCredentialCachePath: string -- Kerberos credential cache. - :param KerberosUseCanonicalHostname: boolean -- Use the canonical hostname of the - Presto coordinator for the Kerberos service principal by first resolving the - hostname to an IP address and then doing a reverse DNS lookup for that IP address. - This is enabled by default. - """ - super(Cursor, self).__init__(poll_interval) - # Config - self._host = host - self._port = port - """ - Presto User Impersonation: https://docs.starburstdata.com/latest/security/impersonation.html - - User impersonation allows the execution of queries in Presto based on principal_username - argument, instead of executing the query as the account which authenticated against Presto. - (Usually a service account) - - Allows for a service account to authenticate with Presto, and then leverage the - principal_username as the user Presto will execute the query as. This is required by - applications that leverage authentication methods like SAML, where the application has a - username, but not a password to still leverage user specific Presto Resource Groups and - Authorization rules that would not be applied when only using a shared service account. - This also allows auditing of who is executing a query in these environments, instead of - having all queryes run by the shared service account. - """ - self._username = principal_username or username or getpass.getuser() - self._catalog = catalog - self._schema = schema - self._arraysize = 1 - self._poll_interval = poll_interval - self._source = source - self._session_props = session_props if session_props is not None else {} - self.last_query_id = None - - if protocol not in ('http', 'https'): - raise ValueError("Protocol must be http/https, was {!r}".format(protocol)) - self._protocol = protocol - - self._requests_session = requests_session or requests - - requests_kwargs = dict(requests_kwargs) if requests_kwargs is not None else {} - - if KerberosRemoteServiceName is not None: - from requests_kerberos import HTTPKerberosAuth, OPTIONAL - - hostname_override = None - if KerberosUseCanonicalHostname is not None \ - and KerberosUseCanonicalHostname.lower() == 'false': - hostname_override = host - if KerberosConfigPath is not None: - os.environ['KRB5_CONFIG'] = KerberosConfigPath - if KerberosKeytabPath is not None: - os.environ['KRB5_CLIENT_KTNAME'] = KerberosKeytabPath - if KerberosCredentialCachePath is not None: - os.environ['KRB5CCNAME'] = KerberosCredentialCachePath - - requests_kwargs['auth'] = HTTPKerberosAuth(mutual_authentication=OPTIONAL, - principal=KerberosPrincipal, - service=KerberosRemoteServiceName, - hostname_override=hostname_override) - - else: - if password is not None and 'auth' in requests_kwargs: - raise ValueError("Cannot use both password and requests_kwargs authentication") - for k in ('method', 'url', 'data', 'headers'): - if k in requests_kwargs: - raise ValueError("Cannot override requests argument {}".format(k)) - if password is not None: - requests_kwargs['auth'] = HTTPBasicAuth(username, password) - if protocol != 'https': - raise ValueError("Protocol must be https when passing a password") - self._requests_kwargs = requests_kwargs - - self._reset_state() - - def _reset_state(self): - """Reset state about the previous query in preparation for running another query""" - super(Cursor, self)._reset_state() - self._nextUri = None - self._columns = None - - @property - def description(self): - """This read-only attribute is a sequence of 7-item sequences. - - Each of these sequences contains information describing one result column: - - - name - - type_code - - display_size (None in current implementation) - - internal_size (None in current implementation) - - precision (None in current implementation) - - scale (None in current implementation) - - null_ok (always True in current implementation) - - The ``type_code`` can be interpreted by comparing it to the Type Objects specified in the - section below. - """ - # Sleep until we're done or we got the columns - self._fetch_while( - lambda: self._columns is None and - self._state not in (self._STATE_NONE, self._STATE_FINISHED) - ) - if self._columns is None: - return None - return [ - # name, type_code, display_size, internal_size, precision, scale, null_ok - (col['name'], col['type'], None, None, None, None, True) - for col in self._columns - ] - - def execute(self, operation, parameters=None): - """Prepare and execute a database operation (query or command). - - Return values are not defined. - """ - headers = { - 'X-Presto-Catalog': self._catalog, - 'X-Presto-Schema': self._schema, - 'X-Presto-Source': self._source, - 'X-Presto-User': self._username, - } - - if self._session_props: - headers['X-Presto-Session'] = ','.join( - '{}={}'.format(propname, propval) - for propname, propval in self._session_props.items() - ) - - # Prepare statement - if parameters is None: - sql = operation - else: - sql = operation % _escaper.escape_args(parameters) - - self._reset_state() - - self._state = self._STATE_RUNNING - url = urlparse.urlunparse(( - self._protocol, - '{}:{}'.format(self._host, self._port), '/v1/statement', None, None, None)) - _logger.info('%s', sql) - _logger.debug("Headers: %s", headers) - response = self._requests_session.post( - url, data=sql.encode('utf-8'), headers=headers, **self._requests_kwargs) - self._process_response(response) - - def cancel(self): - if self._state == self._STATE_NONE: - raise ProgrammingError("No query yet") - if self._nextUri is None: - assert self._state == self._STATE_FINISHED, "Should be finished if nextUri is None" - return - - response = self._requests_session.delete(self._nextUri, **self._requests_kwargs) - if response.status_code != requests.codes.no_content: - fmt = "Unexpected status code after cancel {}\n{}" - raise OperationalError(fmt.format(response.status_code, response.content)) - - self._state = self._STATE_FINISHED - self._nextUri = None - - def poll(self): - """Poll for and return the raw status data provided by the Presto REST API. - - :returns: dict -- JSON status information or ``None`` if the query is done - :raises: ``ProgrammingError`` when no query has been started - - .. note:: - This is not a part of DB-API. - """ - if self._state == self._STATE_NONE: - raise ProgrammingError("No query yet") - if self._nextUri is None: - assert self._state == self._STATE_FINISHED, "Should be finished if nextUri is None" - return None - response = self._requests_session.get(self._nextUri, **self._requests_kwargs) - self._process_response(response) - return response.json() - - def _fetch_more(self): - """Fetch the next URI and update state""" - self._process_response(self._requests_session.get(self._nextUri, **self._requests_kwargs)) - - def _process_data(self, rows): - for i, col in enumerate(self.description): - col_type = col[1].split("(")[0].lower() - if col_type in TYPES_CONVERTER: - for row in rows: - if row[i] is not None: - row[i] = TYPES_CONVERTER[col_type](row[i]) - - def _process_response(self, response): - """Given the JSON response from Presto's REST API, update the internal state with the next - URI and any data from the response - """ - # TODO handle HTTP 503 - if response.status_code != requests.codes.ok: - fmt = "Unexpected status code {}\n{}" - raise OperationalError(fmt.format(response.status_code, response.content)) - - response_json = response.json() - _logger.debug("Got response %s", response_json) - assert self._state == self._STATE_RUNNING, "Should be running if processing response" - self._nextUri = response_json.get('nextUri') - self._columns = response_json.get('columns') - if 'id' in response_json: - self.last_query_id = response_json['id'] - if 'X-Presto-Clear-Session' in response.headers: - propname = response.headers['X-Presto-Clear-Session'] - self._session_props.pop(propname, None) - if 'X-Presto-Set-Session' in response.headers: - propname, propval = response.headers['X-Presto-Set-Session'].split('=', 1) - self._session_props[propname] = propval - if 'data' in response_json: - assert self._columns - new_data = response_json['data'] - self._process_data(new_data) - self._data += map(tuple, new_data) - if 'nextUri' not in response_json: - self._state = self._STATE_FINISHED - if 'error' in response_json: - raise DatabaseError(response_json['error']) - - -# -# Type Objects and Constructors -# - - -# See types in presto-main/src/main/java/com/facebook/presto/tuple/TupleInfo.java -FIXED_INT_64 = DBAPITypeObject(['bigint']) -VARIABLE_BINARY = DBAPITypeObject(['varchar']) -DOUBLE = DBAPITypeObject(['double']) -BOOLEAN = DBAPITypeObject(['boolean']) diff --git a/src/chronify/_vendor/kyuubi/pyhive/sasl_compat.py b/src/chronify/_vendor/kyuubi/pyhive/sasl_compat.py deleted file mode 100644 index 19af6d2..0000000 --- a/src/chronify/_vendor/kyuubi/pyhive/sasl_compat.py +++ /dev/null @@ -1,56 +0,0 @@ -# Original source of this file is https://github.com/cloudera/impyla/blob/master/impala/sasl_compat.py -# which uses Apache-2.0 license as of 21 May 2023. -# This code was added to Impyla in 2016 as a compatibility layer to allow use of either python-sasl or pure-sasl -# via PR https://github.com/cloudera/impyla/pull/179 -# Even though thrift_sasl lists pure-sasl as dependency here https://github.com/cloudera/thrift_sasl/blob/master/setup.py#L34 -# but it still calls functions native to python-sasl in this file https://github.com/cloudera/thrift_sasl/blob/master/thrift_sasl/__init__.py#L82 -# Hence this code is required for the fallback to work. - - -from puresasl.client import SASLClient, SASLError -from contextlib import contextmanager - -@contextmanager -def error_catcher(self, Exc = Exception): - try: - self.error = None - yield - except Exc as e: - self.error = str(e) - - -class PureSASLClient(SASLClient): - def __init__(self, *args, **kwargs): - self.error = None - super(PureSASLClient, self).__init__(*args, **kwargs) - - def start(self, mechanism): - with error_catcher(self, SASLError): - if isinstance(mechanism, list): - self.choose_mechanism(mechanism) - else: - self.choose_mechanism([mechanism]) - return True, self.mechanism, self.process() - # else - return False, mechanism, None - - def encode(self, incoming): - with error_catcher(self): - return True, self.unwrap(incoming) - # else - return False, None - - def decode(self, outgoing): - with error_catcher(self): - return True, self.wrap(outgoing) - # else - return False, None - - def step(self, challenge=None): - with error_catcher(self): - return True, self.process(challenge) - # else - return False, None - - def getError(self): - return self.error diff --git a/src/chronify/_vendor/kyuubi/pyhive/sqlalchemy_hive.py b/src/chronify/_vendor/kyuubi/pyhive/sqlalchemy_hive.py deleted file mode 100644 index 66f7d75..0000000 --- a/src/chronify/_vendor/kyuubi/pyhive/sqlalchemy_hive.py +++ /dev/null @@ -1,435 +0,0 @@ -"""Integration between SQLAlchemy and Hive. - -Some code based on -https://github.com/zzzeek/sqlalchemy/blob/rel_0_5/lib/sqlalchemy/databases/sqlite.py -which is released under the MIT license. -""" - -from __future__ import absolute_import -from __future__ import unicode_literals - -import datetime -import decimal -import logging - -import re -from sqlalchemy import exc -from sqlalchemy.sql import text -try: - from sqlalchemy import processors -except ImportError: - # Required for SQLAlchemy>=2.0 - from sqlalchemy.engine import processors -from sqlalchemy import types -from sqlalchemy import util -# TODO shouldn't use mysql type -try: - from sqlalchemy.databases import mysql - mysql_tinyinteger = mysql.MSTinyInteger -except ImportError: - # Required for SQLAlchemy>2.0 - from sqlalchemy.dialects import mysql - mysql_tinyinteger = mysql.base.MSTinyInteger -from sqlalchemy.engine import default -from sqlalchemy.sql import compiler -from sqlalchemy.sql.compiler import SQLCompiler - -from pyhive import hive -from pyhive.common import UniversalSet - -from dateutil.parser import parse -from decimal import Decimal - -_logger = logging.getLogger(__name__) - -class HiveStringTypeBase(types.TypeDecorator): - """Translates strings returned by Thrift into something else""" - impl = types.String - - def process_bind_param(self, value, dialect): - raise NotImplementedError("Writing to Hive not supported") - - -class HiveDate(HiveStringTypeBase): - """Translates date strings to date objects""" - impl = types.DATE - - def process_result_value(self, value, dialect): - return processors.str_to_date(value) - - def result_processor(self, dialect, coltype): - def process(value): - if isinstance(value, datetime.datetime): - return value.date() - elif isinstance(value, datetime.date): - return value - elif value is not None: - return parse(value).date() - else: - return None - - return process - - def adapt(self, impltype, **kwargs): - return self.impl - - -class HiveTimestamp(HiveStringTypeBase): - """Translates timestamp strings to datetime objects""" - impl = types.TIMESTAMP - - def process_result_value(self, value, dialect): - return processors.str_to_datetime(value) - - def result_processor(self, dialect, coltype): - def process(value): - if isinstance(value, datetime.datetime): - return value - elif value is not None: - return parse(value) - else: - return None - - return process - - def adapt(self, impltype, **kwargs): - return self.impl - - -class HiveDecimal(HiveStringTypeBase): - """Translates strings to decimals""" - impl = types.DECIMAL - - def process_result_value(self, value, dialect): - if value is not None: - return decimal.Decimal(value) - else: - return None - - def result_processor(self, dialect, coltype): - def process(value): - if isinstance(value, Decimal): - return value - elif value is not None: - return Decimal(value) - else: - return None - - return process - - def adapt(self, impltype, **kwargs): - return self.impl - - -class HiveIdentifierPreparer(compiler.IdentifierPreparer): - # Just quote everything to make things simpler / easier to upgrade - reserved_words = UniversalSet() - - def __init__(self, dialect): - super(HiveIdentifierPreparer, self).__init__( - dialect, - initial_quote='`', - ) - - -_type_map = { - 'boolean': types.Boolean, - 'tinyint': mysql_tinyinteger, - 'smallint': types.SmallInteger, - 'int': types.Integer, - 'bigint': types.BigInteger, - 'float': types.Float, - 'double': types.Float, - 'string': types.String, - 'varchar': types.String, - 'char': types.String, - 'date': HiveDate, - 'timestamp': HiveTimestamp, - 'binary': types.String, - 'array': types.String, - 'map': types.String, - 'struct': types.String, - 'uniontype': types.String, - 'decimal': HiveDecimal, -} - - -class HiveCompiler(SQLCompiler): - def visit_concat_op_binary(self, binary, operator, **kw): - return "concat(%s, %s)" % (self.process(binary.left), self.process(binary.right)) - - def visit_insert(self, *args, **kwargs): - result = super(HiveCompiler, self).visit_insert(*args, **kwargs) - # Massage the result into Hive's format - # INSERT INTO `pyhive_test_database`.`test_table` (`a`) SELECT ... - # => - # INSERT INTO TABLE `pyhive_test_database`.`test_table` SELECT ... - regex = r'^(INSERT INTO) ([^\s]+) \([^\)]*\)' - assert re.search(regex, result), "Unexpected visit_insert result: {}".format(result) - return re.sub(regex, r'\1 TABLE \2', result) - - def visit_column(self, *args, **kwargs): - result = super(HiveCompiler, self).visit_column(*args, **kwargs) - dot_count = result.count('.') - assert dot_count in (0, 1, 2), "Unexpected visit_column result {}".format(result) - if dot_count == 2: - # we have something of the form schema.table.column - # hive doesn't like the schema in front, so chop it out - result = result[result.index('.') + 1:] - return result - - def visit_char_length_func(self, fn, **kw): - return 'length{}'.format(self.function_argspec(fn, **kw)) - - -class HiveTypeCompiler(compiler.GenericTypeCompiler): - def visit_INTEGER(self, type_): - return 'INT' - - def visit_NUMERIC(self, type_): - return 'DECIMAL' - - def visit_CHAR(self, type_): - return 'STRING' - - def visit_VARCHAR(self, type_): - return 'STRING' - - def visit_NCHAR(self, type_): - return 'STRING' - - def visit_TEXT(self, type_): - return 'STRING' - - def visit_CLOB(self, type_): - return 'STRING' - - def visit_BLOB(self, type_): - return 'BINARY' - - def visit_TIME(self, type_): - return 'TIMESTAMP' - - def visit_DATE(self, type_): - return 'TIMESTAMP' - - def visit_DATETIME(self, type_): - return 'TIMESTAMP' - - -class HiveExecutionContext(default.DefaultExecutionContext): - """This is pretty much the same as SQLiteExecutionContext to work around the same issue. - - http://docs.sqlalchemy.org/en/latest/dialects/sqlite.html#dotted-column-names - - engine = create_engine('hive://...', execution_options={'hive_raw_colnames': True}) - """ - - @util.memoized_property - def _preserve_raw_colnames(self): - # Ideally, this would also gate on hive.resultset.use.unique.column.names - return self.execution_options.get('hive_raw_colnames', False) - - def _translate_colname(self, colname): - # Adjust for dotted column names. - # When hive.resultset.use.unique.column.names is true (the default), Hive returns column - # names as "tablename.colname" in cursor.description. - if not self._preserve_raw_colnames and '.' in colname: - return colname.split('.')[-1], colname - else: - return colname, None - - -class HiveDialect(default.DefaultDialect): - name = 'hive' - driver = 'thrift' - execution_ctx_cls = HiveExecutionContext - preparer = HiveIdentifierPreparer - statement_compiler = HiveCompiler - supports_views = True - supports_alter = True - supports_pk_autoincrement = False - supports_default_values = False - supports_empty_insert = False - supports_native_decimal = True - supports_native_boolean = True - supports_unicode_statements = True - supports_unicode_binds = True - returns_unicode_strings = True - description_encoding = None - supports_multivalues_insert = True - type_compiler = HiveTypeCompiler - supports_sane_rowcount = False - supports_statement_cache = False - - @classmethod - def dbapi(cls): - return hive - - @classmethod - def import_dbapi(cls): - return hive - - def create_connect_args(self, url): - kwargs = { - 'host': url.host, - 'port': url.port or 10000, - 'username': url.username, - 'password': url.password, - 'database': url.database or 'default', - } - kwargs.update(url.query) - return [], kwargs - - def get_schema_names(self, connection, **kw): - # Equivalent to SHOW DATABASES - return [row[0] for row in connection.execute(text('SHOW SCHEMAS'))] - - def get_view_names(self, connection, schema=None, **kw): - # Hive does not provide functionality to query tableType - # This allows reflection to not crash at the cost of being inaccurate - return self.get_table_names(connection, schema, **kw) - - def _get_table_columns(self, connection, table_name, schema): - full_table = table_name - if schema: - full_table = schema + '.' + table_name - # TODO using TGetColumnsReq hangs after sending TFetchResultsReq. - # Using DESCRIBE works but is uglier. - try: - # This needs the table name to be unescaped (no backticks). - rows = connection.execute(text('DESCRIBE {}'.format(full_table))).fetchall() - except exc.OperationalError as e: - # Does the table exist? - regex_fmt = r'TExecuteStatementResp.*SemanticException.*Table not found {}' - regex = regex_fmt.format(re.escape(full_table)) - if re.search(regex, e.args[0]): - raise exc.NoSuchTableError(full_table) - else: - raise - else: - # Hive is stupid: this is what I get from DESCRIBE some_schema.does_not_exist - regex = r'Table .* does not exist' - if len(rows) == 1 and re.match(regex, rows[0].col_name): - raise exc.NoSuchTableError(full_table) - return rows - - def has_table(self, connection, table_name, schema=None, **kw): - try: - self._get_table_columns(connection, table_name, schema) - return True - except exc.NoSuchTableError: - return False - - def get_columns(self, connection, table_name, schema=None, **kw): - rows = self._get_table_columns(connection, table_name, schema) - # Strip whitespace - rows = [[col.strip() if col else None for col in row] for row in rows] - # Filter out empty rows and comment - rows = [row for row in rows if row[0] and row[0] != '# col_name'] - result = [] - for (col_name, col_type, _comment) in rows: - if col_name == '# Partition Information': - break - # Take out the more detailed type information - # e.g. 'map' -> 'map' - # 'decimal(10,1)' -> decimal - col_type = re.search(r'^\w+', col_type).group(0) - try: - coltype = _type_map[col_type] - except KeyError: - util.warn("Did not recognize type '%s' of column '%s'" % (col_type, col_name)) - coltype = types.NullType - - result.append({ - 'name': col_name, - 'type': coltype, - 'nullable': True, - 'default': None, - }) - return result - - def get_foreign_keys(self, connection, table_name, schema=None, **kw): - # Hive has no support for foreign keys. - return [] - - def get_pk_constraint(self, connection, table_name, schema=None, **kw): - # Hive has no support for primary keys. - return [] - - def get_indexes(self, connection, table_name, schema=None, **kw): - rows = self._get_table_columns(connection, table_name, schema) - # Strip whitespace - rows = [[col.strip() if col else None for col in row] for row in rows] - # Filter out empty rows and comment - rows = [row for row in rows if row[0] and row[0] != '# col_name'] - for i, (col_name, _col_type, _comment) in enumerate(rows): - if col_name == '# Partition Information': - break - # Handle partition columns - col_names = [] - for col_name, _col_type, _comment in rows[i + 1:]: - col_names.append(col_name) - if col_names: - return [{'name': 'partition', 'column_names': col_names, 'unique': False}] - else: - return [] - - def get_table_names(self, connection, schema=None, **kw): - query = 'SHOW TABLES' - if schema: - query += ' IN ' + self.identifier_preparer.quote_identifier(schema) - - table_names = [] - - for row in connection.execute(text(query)): - # Hive returns 1 columns - if len(row) == 1: - table_names.append(row[0]) - # Spark SQL returns 3 columns - elif len(row) == 3: - table_names.append(row[1]) - else: - _logger.warning("Unexpected number of columns in SHOW TABLES result: {}".format(len(row))) - table_names.append('UNKNOWN') - - return table_names - - def do_rollback(self, dbapi_connection): - # No transactions for Hive - pass - - def _check_unicode_returns(self, connection, additional_tests=None): - # We decode everything as UTF-8 - return True - - def _check_unicode_description(self, connection): - # We decode everything as UTF-8 - return True - - -class HiveHTTPDialect(HiveDialect): - - name = "hive" - scheme = "http" - driver = "rest" - - def create_connect_args(self, url): - kwargs = { - "host": url.host, - "port": url.port or 10000, - "scheme": self.scheme, - "username": url.username or None, - "password": url.password or None, - "database": url.database or "default", - } - if url.query: - kwargs.update(url.query) - return [], kwargs - return ([], kwargs) - - -class HiveHTTPSDialect(HiveHTTPDialect): - - name = "hive" - scheme = "https" diff --git a/src/chronify/_vendor/kyuubi/pyhive/sqlalchemy_presto.py b/src/chronify/_vendor/kyuubi/pyhive/sqlalchemy_presto.py deleted file mode 100644 index f5a256f..0000000 --- a/src/chronify/_vendor/kyuubi/pyhive/sqlalchemy_presto.py +++ /dev/null @@ -1,256 +0,0 @@ -"""Integration between SQLAlchemy and Presto. - -Some code based on -https://github.com/zzzeek/sqlalchemy/blob/rel_0_5/lib/sqlalchemy/databases/sqlite.py -which is released under the MIT license. -""" - -from __future__ import absolute_import -from __future__ import unicode_literals - -import re -import sqlalchemy -from sqlalchemy import exc -from sqlalchemy import types -from sqlalchemy import util -# TODO shouldn't use mysql type -from sqlalchemy.sql import text -try: - from sqlalchemy.databases import mysql - mysql_tinyinteger = mysql.MSTinyInteger -except ImportError: - # Required for SQLAlchemy>=2.0 - from sqlalchemy.dialects import mysql - mysql_tinyinteger = mysql.base.MSTinyInteger -from sqlalchemy.engine import default -from sqlalchemy.sql import compiler, bindparam -from sqlalchemy.sql.compiler import SQLCompiler - -from pyhive import presto -from pyhive.common import UniversalSet - -sqlalchemy_version = float(re.search(r"^([\d]+\.[\d]+)\..+", sqlalchemy.__version__).group(1)) - -class PrestoIdentifierPreparer(compiler.IdentifierPreparer): - # Just quote everything to make things simpler / easier to upgrade - reserved_words = UniversalSet() - - -_type_map = { - 'boolean': types.Boolean, - 'tinyint': mysql_tinyinteger, - 'smallint': types.SmallInteger, - 'integer': types.Integer, - 'bigint': types.BigInteger, - 'real': types.Float, - 'double': types.Float, - 'varchar': types.String, - 'timestamp': types.TIMESTAMP, - 'date': types.DATE, - 'varbinary': types.VARBINARY, -} - - -class PrestoCompiler(SQLCompiler): - def visit_char_length_func(self, fn, **kw): - return 'length{}'.format(self.function_argspec(fn, **kw)) - - -class PrestoTypeCompiler(compiler.GenericTypeCompiler): - def visit_CLOB(self, type_, **kw): - raise ValueError("Presto does not support the CLOB column type.") - - def visit_NCLOB(self, type_, **kw): - raise ValueError("Presto does not support the NCLOB column type.") - - def visit_DATETIME(self, type_, **kw): - raise ValueError("Presto does not support the DATETIME column type.") - - def visit_FLOAT(self, type_, **kw): - return 'DOUBLE' - - def visit_TEXT(self, type_, **kw): - if type_.length: - return 'VARCHAR({:d})'.format(type_.length) - else: - return 'VARCHAR' - - -class PrestoDialect(default.DefaultDialect): - name = 'presto' - driver = 'rest' - paramstyle = 'pyformat' - preparer = PrestoIdentifierPreparer - statement_compiler = PrestoCompiler - supports_alter = False - supports_pk_autoincrement = False - supports_default_values = False - supports_empty_insert = False - supports_multivalues_insert = True - supports_unicode_statements = True - supports_unicode_binds = True - supports_statement_cache = False - returns_unicode_strings = True - description_encoding = None - supports_native_boolean = True - type_compiler = PrestoTypeCompiler - - @classmethod - def dbapi(cls): - return presto - - @classmethod - def import_dbapi(cls): - return presto - - def create_connect_args(self, url): - db_parts = (url.database or 'hive').split('/') - kwargs = { - 'host': url.host, - 'port': url.port or 8080, - 'username': url.username, - 'password': url.password - } - kwargs.update(url.query) - if len(db_parts) == 1: - kwargs['catalog'] = db_parts[0] - elif len(db_parts) == 2: - kwargs['catalog'] = db_parts[0] - kwargs['schema'] = db_parts[1] - else: - raise ValueError("Unexpected database format {}".format(url.database)) - return [], kwargs - - def get_schema_names(self, connection, **kw): - return [row.Schema for row in connection.execute(text('SHOW SCHEMAS'))] - - def _get_table_columns(self, connection, table_name, schema): - full_table = self.identifier_preparer.quote_identifier(table_name) - if schema: - full_table = self.identifier_preparer.quote_identifier(schema) + '.' + full_table - try: - return connection.execute(text('SHOW COLUMNS FROM {}'.format(full_table))) - except (presto.DatabaseError, exc.DatabaseError) as e: - # Normally SQLAlchemy should wrap this exception in sqlalchemy.exc.DatabaseError, which - # it successfully does in the Hive version. The difference with Presto is that this - # error is raised when fetching the cursor's description rather than the initial execute - # call. SQLAlchemy doesn't handle this. Thus, we catch the unwrapped - # presto.DatabaseError here. - # Does the table exist? - msg = ( - e.args[0].get('message') if e.args and isinstance(e.args[0], dict) - else e.args[0] if e.args and isinstance(e.args[0], str) - else None - ) - regex = r"Table\ \'.*{}\'\ does\ not\ exist".format(re.escape(table_name)) - if msg and re.search(regex, msg, re.IGNORECASE): - raise exc.NoSuchTableError(table_name) - else: - raise - - def has_table(self, connection, table_name, schema=None, **kw): - try: - self._get_table_columns(connection, table_name, schema) - return True - except exc.NoSuchTableError: - return False - - def get_columns(self, connection, table_name, schema=None, **kw): - rows = self._get_table_columns(connection, table_name, schema) - result = [] - for row in rows: - try: - coltype = _type_map[row.Type] - except KeyError: - util.warn("Did not recognize type '%s' of column '%s'" % (row.Type, row.Column)) - coltype = types.NullType - result.append({ - 'name': row.Column, - 'type': coltype, - # newer Presto no longer includes this column - 'nullable': getattr(row, 'Null', True), - 'default': None, - }) - return result - - def get_foreign_keys(self, connection, table_name, schema=None, **kw): - # Hive has no support for foreign keys. - return [] - - def get_pk_constraint(self, connection, table_name, schema=None, **kw): - # Hive has no support for primary keys. - return [] - - def get_indexes(self, connection, table_name, schema=None, **kw): - rows = self._get_table_columns(connection, table_name, schema) - col_names = [] - for row in rows: - part_key = 'Partition Key' - # Presto puts this information in one of 3 places depending on version - # - a boolean column named "Partition Key" - # - a string in the "Comment" column - # - a string in the "Extra" column - if sqlalchemy_version >= 1.4: - row = row._mapping - is_partition_key = ( - (part_key in row and row[part_key]) - or row['Comment'].startswith(part_key) - or ('Extra' in row and 'partition key' in row['Extra']) - ) - if is_partition_key: - col_names.append(row['Column']) - if col_names: - return [{'name': 'partition', 'column_names': col_names, 'unique': False}] - else: - return [] - - def _get_default_schema_name(self, connection): - #'SELECT CURRENT_SCHEMA()' - return super()._get_default_schema_name(connection) - - def get_table_names(self, connection, schema=None, **kw): - query = 'SHOW TABLES' - # N.B. This is incorrect, if no schema is provided, the current/default schema should be used - # with a call to an overridden self._get_default_schema_name(connection), but I could not - # see how to implement that as there is no CURRENT_SCHEMA function - # default_schema = self._get_default_schema_name(connection) - - if schema: - query += ' FROM ' + self.identifier_preparer.quote_identifier(schema) - return [row.Table for row in connection.execute(text(query))] - - def get_view_names(self, connection, schema=None, **kw): - if schema: - view_name_query = """ - SELECT table_name - FROM information_schema.views - WHERE table_schema = :schema - """ - query = text(view_name_query).bindparams( - bindparam("schema", type_=types.Unicode) - ) - else: - # N.B. This is incorrect, if no schema is provided, the current/default schema should - # be used with a call to self._get_default_schema_name(connection), but I could not - # see how to implement that - # default_schema = self._get_default_schema_name(connection) - view_name_query = """ - SELECT table_name - FROM information_schema.views - """ - query = text(view_name_query) - - result = connection.execute(query, dict(schema=schema)) - return [row[0] for row in result] - - def do_rollback(self, dbapi_connection): - # No transactions for Presto - pass - - def _check_unicode_returns(self, connection, additional_tests=None): - # requests gives back Unicode strings - return True - - def _check_unicode_description(self, connection): - # requests gives back Unicode strings - return True diff --git a/src/chronify/_vendor/kyuubi/pyhive/sqlalchemy_trino.py b/src/chronify/_vendor/kyuubi/pyhive/sqlalchemy_trino.py deleted file mode 100644 index 11be2a6..0000000 --- a/src/chronify/_vendor/kyuubi/pyhive/sqlalchemy_trino.py +++ /dev/null @@ -1,84 +0,0 @@ -"""Integration between SQLAlchemy and Trino. - -Some code based on -https://github.com/zzzeek/sqlalchemy/blob/rel_0_5/lib/sqlalchemy/databases/sqlite.py -which is released under the MIT license. -""" - -from __future__ import absolute_import -from __future__ import unicode_literals - -import re -from sqlalchemy import exc -from sqlalchemy import types -from sqlalchemy import util -# TODO shouldn't use mysql type -try: - from sqlalchemy.databases import mysql - mysql_tinyinteger = mysql.MSTinyInteger -except ImportError: - # Required for SQLAlchemy>=2.0 - from sqlalchemy.dialects import mysql - mysql_tinyinteger = mysql.base.MSTinyInteger -from sqlalchemy.engine import default -from sqlalchemy.sql import compiler -from sqlalchemy.sql.compiler import SQLCompiler - -from pyhive import trino -from pyhive.common import UniversalSet -from pyhive.sqlalchemy_presto import PrestoDialect, PrestoCompiler, PrestoIdentifierPreparer - -class TrinoIdentifierPreparer(PrestoIdentifierPreparer): - pass - - -_type_map = { - 'boolean': types.Boolean, - 'tinyint': mysql_tinyinteger, - 'smallint': types.SmallInteger, - 'integer': types.Integer, - 'bigint': types.BigInteger, - 'real': types.Float, - 'double': types.Float, - 'varchar': types.String, - 'timestamp': types.TIMESTAMP, - 'date': types.DATE, - 'varbinary': types.VARBINARY, -} - - -class TrinoCompiler(PrestoCompiler): - pass - - -class TrinoTypeCompiler(PrestoCompiler): - def visit_CLOB(self, type_, **kw): - raise ValueError("Trino does not support the CLOB column type.") - - def visit_NCLOB(self, type_, **kw): - raise ValueError("Trino does not support the NCLOB column type.") - - def visit_DATETIME(self, type_, **kw): - raise ValueError("Trino does not support the DATETIME column type.") - - def visit_FLOAT(self, type_, **kw): - return 'DOUBLE' - - def visit_TEXT(self, type_, **kw): - if type_.length: - return 'VARCHAR({:d})'.format(type_.length) - else: - return 'VARCHAR' - - -class TrinoDialect(PrestoDialect): - name = 'trino' - supports_statement_cache = False - - @classmethod - def dbapi(cls): - return trino - - @classmethod - def import_dbapi(cls): - return trino diff --git a/src/chronify/_vendor/kyuubi/pyhive/trino.py b/src/chronify/_vendor/kyuubi/pyhive/trino.py deleted file mode 100644 index 658457a..0000000 --- a/src/chronify/_vendor/kyuubi/pyhive/trino.py +++ /dev/null @@ -1,144 +0,0 @@ -"""DB-API implementation backed by Trino - -See http://www.python.org/dev/peps/pep-0249/ - -Many docstrings in this file are based on the PEP, which is in the public domain. -""" - -from __future__ import absolute_import -from __future__ import unicode_literals - -import logging - -import requests - -# Make all exceptions visible in this module per DB-API -from pyhive.common import DBAPITypeObject -from pyhive.exc import * # noqa -from pyhive.presto import Connection as PrestoConnection, Cursor as PrestoCursor, PrestoParamEscaper - -try: # Python 3 - import urllib.parse as urlparse -except ImportError: # Python 2 - import urlparse - -# PEP 249 module globals -apilevel = '2.0' -threadsafety = 2 # Threads may share the module and connections. -paramstyle = 'pyformat' # Python extended format codes, e.g. ...WHERE name=%(name)s - -_logger = logging.getLogger(__name__) - - -class TrinoParamEscaper(PrestoParamEscaper): - pass - - -_escaper = TrinoParamEscaper() - - -def connect(*args, **kwargs): - """Constructor for creating a connection to the database. See class :py:class:`Connection` for - arguments. - - :returns: a :py:class:`Connection` object. - """ - return Connection(*args, **kwargs) - - -class Connection(PrestoConnection): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - def cursor(self): - """Return a new :py:class:`Cursor` object using the connection.""" - return Cursor(*self._args, **self._kwargs) - - -class Cursor(PrestoCursor): - """These objects represent a database cursor, which is used to manage the context of a fetch - operation. - - Cursors are not isolated, i.e., any changes done to the database by a cursor are immediately - visible by other cursors or connections. - """ - - def execute(self, operation, parameters=None): - """Prepare and execute a database operation (query or command). - - Return values are not defined. - """ - headers = { - 'X-Trino-Catalog': self._catalog, - 'X-Trino-Schema': self._schema, - 'X-Trino-Source': self._source, - 'X-Trino-User': self._username, - } - - if self._session_props: - headers['X-Trino-Session'] = ','.join( - '{}={}'.format(propname, propval) - for propname, propval in self._session_props.items() - ) - - # Prepare statement - if parameters is None: - sql = operation - else: - sql = operation % _escaper.escape_args(parameters) - - self._reset_state() - - self._state = self._STATE_RUNNING - url = urlparse.urlunparse(( - self._protocol, - '{}:{}'.format(self._host, self._port), '/v1/statement', None, None, None)) - _logger.info('%s', sql) - _logger.debug("Headers: %s", headers) - response = self._requests_session.post( - url, data=sql.encode('utf-8'), headers=headers, **self._requests_kwargs) - self._process_response(response) - - def _process_response(self, response): - """Given the JSON response from Trino's REST API, update the internal state with the next - URI and any data from the response - """ - # TODO handle HTTP 503 - if response.status_code != requests.codes.ok: - fmt = "Unexpected status code {}\n{}" - raise OperationalError(fmt.format(response.status_code, response.content)) - - response_json = response.json() - _logger.debug("Got response %s", response_json) - assert self._state == self._STATE_RUNNING, "Should be running if processing response" - self._nextUri = response_json.get('nextUri') - self._columns = response_json.get('columns') - if 'id' in response_json: - self.last_query_id = response_json['id'] - if 'X-Trino-Clear-Session' in response.headers: - propname = response.headers['X-Trino-Clear-Session'] - self._session_props.pop(propname, None) - if 'X-Trino-Set-Session' in response.headers: - propname, propval = response.headers['X-Trino-Set-Session'].split('=', 1) - self._session_props[propname] = propval - if 'data' in response_json: - assert self._columns - new_data = response_json['data'] - self._process_data(new_data) - self._data += map(tuple, new_data) - if 'nextUri' not in response_json: - self._state = self._STATE_FINISHED - if 'error' in response_json: - raise DatabaseError(response_json['error']) - - -# -# Type Objects and Constructors -# - - -# See types in trino-main/src/main/java/com/facebook/trino/tuple/TupleInfo.java -FIXED_INT_64 = DBAPITypeObject(['bigint']) -VARIABLE_BINARY = DBAPITypeObject(['varchar']) -DOUBLE = DBAPITypeObject(['double']) -BOOLEAN = DBAPITypeObject(['boolean']) diff --git a/src/chronify/csv_io.py b/src/chronify/csv_io.py index 5b99b07..f6310c8 100644 --- a/src/chronify/csv_io.py +++ b/src/chronify/csv_io.py @@ -4,16 +4,14 @@ import duckdb from duckdb import DuckDBPyRelation -from chronify.models import CsvTableSchema, get_duckdb_type_from_sqlalchemy +from chronify.models import CsvTableSchema, get_duckdb_type_from_ibis_type from chronify.time_configs import DatetimeRange def read_csv(path: Path | str, schema: CsvTableSchema, **kwargs: Any) -> DuckDBPyRelation: """Read a CSV file into a DuckDB relation.""" if schema.column_dtypes: - dtypes = { - x.name: get_duckdb_type_from_sqlalchemy(x.dtype).id for x in schema.column_dtypes - } + dtypes = {x.name: get_duckdb_type_from_ibis_type(x.dtype) for x in schema.column_dtypes} rel = duckdb.read_csv(str(path), dtype=dtypes, **kwargs) else: rel = duckdb.read_csv(str(path), **kwargs) diff --git a/src/chronify/hive_functions.py b/src/chronify/hive_functions.py deleted file mode 100644 index ff3ca07..0000000 --- a/src/chronify/hive_functions.py +++ /dev/null @@ -1,34 +0,0 @@ -from pathlib import Path -from tempfile import NamedTemporaryFile -from typing import Optional - -from sqlalchemy import Engine, MetaData, text - - -def create_materialized_view( - query: str, - dst_table: str, - engine: Engine, - metadata: MetaData, - scratch_dir: Optional[Path] = None, -) -> None: - """Create a materialized view with a Parquet file. This is a workaround for an undiagnosed - problem with timestamps and time zones with hive. - - The Parquet file will be written to scratch_dir. Callers must ensure that the directory - persists for the duration of the work. - """ - with NamedTemporaryFile(dir=scratch_dir, suffix=".parquet") as f: - f.close() - output = Path(f.name) - write_query = f""" - INSERT OVERWRITE DIRECTORY - '{output}' - USING parquet - ({query}) - """ - with engine.begin() as conn: - conn.execute(text(write_query)) - view_query = f"CREATE VIEW {dst_table} AS SELECT * FROM parquet.`{output}`" - conn.execute(text(view_query)) - metadata.reflect(engine, views=True) diff --git a/src/chronify/ibis/__init__.py b/src/chronify/ibis/__init__.py new file mode 100644 index 0000000..926012c --- /dev/null +++ b/src/chronify/ibis/__init__.py @@ -0,0 +1,43 @@ +"""Ibis backend abstraction layer for Chronify.""" + +from chronify.ibis.base import IbisBackend, ObjectType +from chronify.ibis.duckdb_backend import DuckDBBackend +from chronify.ibis.sqlite_backend import SQLiteBackend + +__all__ = [ + "DuckDBBackend", + "IbisBackend", + "ObjectType", + "SQLiteBackend", + "make_backend", +] + + +def make_backend( + name: str, + database: str | None = None, + **kwargs: object, +) -> IbisBackend: + """Create an IbisBackend instance. + + Parameters + ---------- + name + Backend name: "duckdb", "sqlite", or "spark". + database + Database file path, or None for in-memory. + **kwargs + Additional keyword arguments passed to the backend constructor. + """ + match name: + case "duckdb": + return DuckDBBackend(database=database or ":memory:", **kwargs) + case "sqlite": + return SQLiteBackend(database=database or ":memory:", **kwargs) + case "spark": + from chronify.ibis.spark_backend import SparkBackend + + return SparkBackend(**kwargs) + case _: + msg = f"Unsupported backend: {name}. Choose from: duckdb, sqlite, spark" + raise ValueError(msg) diff --git a/src/chronify/ibis/base.py b/src/chronify/ibis/base.py new file mode 100644 index 0000000..fdcd22a --- /dev/null +++ b/src/chronify/ibis/base.py @@ -0,0 +1,148 @@ +"""Abstract base class for Ibis database backends.""" + +from abc import ABC, abstractmethod +from contextlib import contextmanager +from enum import Enum +from typing import Any, Generator + +import ibis +import ibis.expr.types as ir +import pandas as pd +from loguru import logger + + +class ObjectType(Enum): + TABLE = "table" + VIEW = "view" + + +class IbisBackend(ABC): + """Abstract base class defining the interface for Ibis database backends.""" + + @property + @abstractmethod + def name(self) -> str: + """Return the backend name (e.g., 'duckdb', 'sqlite', 'spark').""" + + @property + @abstractmethod + def database(self) -> str | None: + """Return the database file path, or None for in-memory.""" + + @property + @abstractmethod + def connection(self) -> ibis.BaseBackend: + """Return the underlying ibis connection.""" + + @abstractmethod + def create_table( + self, + name: str, + obj: pd.DataFrame | ir.Table | None = None, + schema: ibis.Schema | None = None, + ) -> ir.Table: + """Create a table in the database. + + Parameters + ---------- + name + Table name. + obj + Data to populate the table with. + schema + Schema to use if obj is None. + + Returns + ------- + ir.Table + """ + + @abstractmethod + def create_view(self, name: str, expr: ir.Table) -> ir.Table: + """Create a view in the database.""" + + @abstractmethod + def drop_table(self, name: str) -> None: + """Drop a table from the database.""" + + @abstractmethod + def drop_view(self, name: str) -> None: + """Drop a view from the database.""" + + @abstractmethod + def list_tables(self) -> list[str]: + """List all user tables in the database.""" + + @abstractmethod + def table(self, name: str) -> ir.Table: + """Return an ibis table expression for the named table.""" + + @abstractmethod + def insert(self, name: str, data: pd.DataFrame) -> None: + """Insert data into an existing table.""" + + @abstractmethod + def execute(self, expr: ir.Expr) -> pd.DataFrame: + """Execute an ibis expression and return a DataFrame.""" + + @abstractmethod + def sql(self, query: str) -> ir.Table: + """Create an ibis table expression from a raw SQL string.""" + + @abstractmethod + def write_parquet( + self, + expr: ir.Table, + path: str, + partition_by: list[str] | None = None, + ) -> None: + """Write an ibis expression result to a Parquet file.""" + + @abstractmethod + def create_view_from_parquet(self, path: str, name: str) -> ir.Table: + """Create a view or table backed by a Parquet file.""" + + def has_table(self, name: str) -> bool: + """Check whether a table or view exists.""" + return name in self.list_tables() + + def execute_sql(self, query: str) -> Any: + """Execute a raw SQL statement (no result expected).""" + logger.trace("execute_sql: {}", query) + return self.connection.raw_sql(query) + + def execute_sql_to_df(self, query: str) -> pd.DataFrame: + """Execute a raw SQL query and return a DataFrame.""" + logger.trace("execute_sql_to_df: {}", query) + return self.connection.raw_sql(query).fetch_df() + + def dispose(self) -> None: + """Dispose of the backend connection.""" + self.connection.disconnect() + + def reconnect(self) -> None: + """Reconnect to the database. Subclasses should override if needed.""" + + @contextmanager + def transaction(self) -> Generator[list[tuple[str, ObjectType]], None, None]: + """Context manager for pseudo-transactions. + + Tracks created objects (tables/views) so they can be cleaned up on failure. + On success, created objects are kept. On exception, they are dropped. + + Yields a list to which callers should append (name, ObjectType) tuples. + """ + created: list[tuple[str, ObjectType]] = [] + try: + yield created + except Exception: + for obj_name, obj_type in reversed(created): + try: + if obj_type == ObjectType.TABLE: + self.drop_table(obj_name) + else: + self.drop_view(obj_name) + logger.debug("Rolled back {} {}", obj_type.value, obj_name) + except Exception: + logger.warning("Failed to roll back {} {}", obj_type.value, obj_name) + raise diff --git a/src/chronify/ibis/duckdb_backend.py b/src/chronify/ibis/duckdb_backend.py new file mode 100644 index 0000000..1f18d3b --- /dev/null +++ b/src/chronify/ibis/duckdb_backend.py @@ -0,0 +1,111 @@ +"""DuckDB backend implementation for Ibis.""" + +from pathlib import Path + +import ibis +import ibis.expr.types as ir +import pandas as pd +from loguru import logger + +from chronify.ibis.base import IbisBackend + + +class DuckDBBackend(IbisBackend): + """Ibis backend for DuckDB databases.""" + + def __init__(self, database: str | Path = ":memory:") -> None: + db = str(database) + self._database = None if db == ":memory:" else db + self._connection = ibis.duckdb.connect(db) + + @property + def name(self) -> str: + return "duckdb" + + @property + def database(self) -> str | None: + return self._database + + @property + def connection(self) -> ibis.BaseBackend: + return self._connection + + def create_table( + self, + name: str, + obj: pd.DataFrame | ir.Table | None = None, + schema: ibis.Schema | None = None, + ) -> ir.Table: + return self._connection.create_table(name, obj=obj, schema=schema, overwrite=False) + + def create_view(self, name: str, expr: ir.Table) -> ir.Table: + return self._connection.create_view(name, expr, overwrite=False) + + def drop_table(self, name: str) -> None: + self._connection.drop_table(name, force=True) + + def drop_view(self, name: str) -> None: + self._connection.drop_view(name, force=True) + + def list_tables(self) -> list[str]: + tables = self._connection.list_tables() + # Filter out internal ibis memtables + return [t for t in tables if not t.startswith("ibis_pandas_memtable_")] + + def table(self, name: str) -> ir.Table: + return self._connection.table(name) + + def insert(self, name: str, data: pd.DataFrame) -> None: + con = self._connection.con # raw duckdb connection + con.register("__insert_df", data) + try: + con.execute(f"INSERT INTO {name} SELECT * FROM __insert_df") + finally: + con.unregister("__insert_df") + logger.trace("Inserted {} rows into {}", len(data), name) + + def execute(self, expr: ir.Expr) -> pd.DataFrame: + return self._connection.execute(expr) + + def sql(self, query: str) -> ir.Table: + return self._connection.sql(query) + + def write_parquet( + self, + expr: ir.Table, + path: str, + partition_by: list[str] | None = None, + ) -> None: + if partition_by: + partition_clause = ", ".join(partition_by) + sql = self._connection.compile(expr) + self._connection.raw_sql( + f"COPY ({sql}) TO '{path}' (FORMAT PARQUET, PARTITION_BY ({partition_clause}))" + ) + else: + df = self._connection.execute(expr) + df.to_parquet(path) + + def create_view_from_parquet(self, path: str, name: str) -> ir.Table: + self._connection.raw_sql( + f"CREATE VIEW {name} AS SELECT * FROM read_parquet('{path}')" + ) + return self.table(name) + + def execute_sql(self, query: str) -> None: + logger.trace("execute_sql: {}", query) + self._connection.raw_sql(query) + + def execute_sql_to_df(self, query: str) -> pd.DataFrame: + logger.trace("execute_sql_to_df: {}", query) + result = self._connection.raw_sql(query) + return result.fetch_df() + + def dispose(self) -> None: + self._connection.disconnect() + + def reconnect(self) -> None: + if self._database is not None: + self._connection = ibis.duckdb.connect(self._database) + else: + logger.warning("Cannot reconnect to an in-memory DuckDB database.") diff --git a/src/chronify/ibis/functions.py b/src/chronify/ibis/functions.py new file mode 100644 index 0000000..4015ba7 --- /dev/null +++ b/src/chronify/ibis/functions.py @@ -0,0 +1,242 @@ +"""Database I/O functions using Ibis backends.""" + +from collections import Counter +from pathlib import Path +from typing import Sequence + +import ibis.expr.types as ir +import pandas as pd +import pyarrow as pa +from pandas import DatetimeTZDtype + +from chronify.exceptions import InvalidOperation, InvalidParameter +from chronify.ibis.base import IbisBackend +from chronify.time import TimeDataType +from chronify.time_configs import ( + DatetimeRange, + DatetimeRangeBase, + DatetimeRangeWithTZColumn, + TimeBaseModel, +) +from chronify.utils.path_utils import check_overwrite + +DatetimeRanges = DatetimeRange | DatetimeRangeWithTZColumn +_DATETIME_RANGES = (DatetimeRange, DatetimeRangeWithTZColumn) + + +def read_table( + backend: IbisBackend, + table_name: str, + config: TimeBaseModel, +) -> pd.DataFrame: + """Read a table from the database.""" + table = backend.table(table_name) + df = backend.execute(table) + + if backend.name == "sqlite" and isinstance(config, _DATETIME_RANGES): + _convert_database_output_for_datetime(df, config) + elif backend.name == "spark" and isinstance(config, _DATETIME_RANGES): + _convert_spark_output_for_datetime(df, config) + + return df + + +def read_query( + backend: IbisBackend, + expr: ir.Table, + config: TimeBaseModel, +) -> pd.DataFrame: + """Execute an Ibis expression and return results.""" + df = backend.execute(expr) + + if backend.name == "sqlite" and isinstance(config, _DATETIME_RANGES): + _convert_database_output_for_datetime(df, config) + elif backend.name == "spark" and isinstance(config, _DATETIME_RANGES): + _convert_spark_output_for_datetime(df, config) + + return df + + +def write_table( + backend: IbisBackend, + df: pd.DataFrame | pa.Table, + table_name: str, + configs: Sequence[TimeBaseModel], + if_exists: str = "append", +) -> None: + """Write a DataFrame to the database.""" + match backend.name: + case "duckdb": + _write_to_duckdb(backend, df, table_name, if_exists) + case "sqlite": + _write_to_sqlite(backend, df, table_name, configs, if_exists) + case "spark": + _write_to_spark(backend, df, table_name, if_exists) + case _: + msg = f"Unsupported backend: {backend.name}" + raise NotImplementedError(msg) + + +def write_parquet( + backend: IbisBackend, + query: str | ir.Table, + output_file: Path, + overwrite: bool = False, + partition_columns: list[str] | None = None, +) -> None: + """Write query results to a Parquet file.""" + check_overwrite(output_file, overwrite) + + if isinstance(query, str): + expr = backend.sql(query) + else: + expr = query + + backend.write_parquet(expr, str(output_file), partition_by=partition_columns) + + +def create_view_from_parquet( + backend: IbisBackend, + filename: Path, + view_name: str, +) -> None: + """Create a view from a Parquet file.""" + backend.create_view_from_parquet(str(filename), view_name) + + +def _check_one_config_per_datetime_column(configs: Sequence[TimeBaseModel]) -> None: + time_col_count = Counter( + config.time_column for config in configs if isinstance(config, DatetimeRangeBase) + ) + time_col_dup = {k: v for k, v in time_col_count.items() if v > 1} + if time_col_dup: + msg = f"More than one datetime config found for: {time_col_dup}" + raise InvalidParameter(msg) + + +def _convert_database_input_for_datetime( + df: pd.DataFrame, config: DatetimeRanges, copied: bool +) -> tuple[pd.DataFrame, bool]: + """Convert DataFrame datetime columns for SQLite input (store as UTC).""" + if config.dtype == TimeDataType.TIMESTAMP_NTZ: + return df, copied + + if not copied: + df = df.copy() + copied = True + + if isinstance(df[config.time_column].dtype, DatetimeTZDtype): + df[config.time_column] = df[config.time_column].dt.tz_convert("UTC") + else: + df[config.time_column] = df[config.time_column].dt.tz_localize("UTC") + + return df, copied + + +def _convert_database_output_for_datetime(df: pd.DataFrame, config: DatetimeRanges) -> None: + """Convert DataFrame datetime columns after SQLite output.""" + if config.time_column not in df.columns: + return + + col = df[config.time_column] + if config.dtype == TimeDataType.TIMESTAMP_TZ: + if col.dtype == object: + df[config.time_column] = pd.to_datetime(col, utc=True) + elif isinstance(col.dtype, DatetimeTZDtype): + df[config.time_column] = col.dt.tz_convert("UTC") + else: + df[config.time_column] = col.dt.tz_localize("UTC") + else: + if col.dtype == object: + df[config.time_column] = pd.to_datetime(col, utc=False) + + +def _convert_spark_output_for_datetime(df: pd.DataFrame, config: DatetimeRanges) -> None: + """Convert DataFrame datetime columns after Spark output.""" + if config.time_column not in df.columns: + return + + col = df[config.time_column] + if not pd.api.types.is_datetime64_any_dtype(col): + df[config.time_column] = pd.to_datetime(col, utc=True) + col = df[config.time_column] + + if config.dtype == TimeDataType.TIMESTAMP_TZ: + if not isinstance(col.dtype, DatetimeTZDtype): + df[config.time_column] = col.dt.tz_localize("UTC") + else: + if isinstance(col.dtype, DatetimeTZDtype): + df[config.time_column] = col.dt.tz_convert(None) + + +def _write_to_duckdb( + backend: IbisBackend, + df: pd.DataFrame | pa.Table, + table_name: str, + if_exists: str, +) -> None: + if isinstance(df, pa.Table): + df = df.to_pandas() + match if_exists: + case "append": + backend.insert(table_name, df) + case "replace": + backend.drop_table(table_name) + backend.create_table(table_name, df) + case "fail": + backend.create_table(table_name, df) + case _: + msg = f"Invalid if_exists value: {if_exists}" + raise InvalidOperation(msg) + + +def _write_to_sqlite( + backend: IbisBackend, + df: pd.DataFrame | pa.Table, + table_name: str, + configs: Sequence[TimeBaseModel], + if_exists: str, +) -> None: + _check_one_config_per_datetime_column(configs) + + if isinstance(df, pa.Table): + df = df.to_pandas() + + copied = False + for config in configs: + if isinstance(config, _DATETIME_RANGES): + df, copied = _convert_database_input_for_datetime(df, config, copied) + + match if_exists: + case "append": + backend.insert(table_name, df) + case "replace": + backend.drop_table(table_name) + backend.create_table(table_name, df) + case "fail": + backend.create_table(table_name, df) + case _: + msg = f"Invalid if_exists value: {if_exists}" + raise InvalidOperation(msg) + + +def _write_to_spark( + backend: IbisBackend, + df: pd.DataFrame | pa.Table, + table_name: str, + if_exists: str, +) -> None: + if isinstance(df, pa.Table): + df = df.to_pandas() + + match if_exists: + case "append": + backend.insert(table_name, df) + case "replace": + backend.drop_table(table_name) + backend.create_table(table_name, df) + case "fail": + backend.create_table(table_name, df) + case _: + msg = f"Invalid if_exists value: {if_exists}" + raise InvalidOperation(msg) diff --git a/src/chronify/ibis/spark_backend.py b/src/chronify/ibis/spark_backend.py new file mode 100644 index 0000000..f59790b --- /dev/null +++ b/src/chronify/ibis/spark_backend.py @@ -0,0 +1,125 @@ +"""Spark backend implementation for Ibis.""" + +from typing import Any + +import ibis +import ibis.expr.types as ir +import pandas as pd +from loguru import logger + +from chronify.ibis.base import IbisBackend + + +class SparkBackend(IbisBackend): + """Ibis backend for PySpark databases. + + Requires pyspark to be installed (pip install chronify[spark]). + """ + + def __init__(self, session: Any = None) -> None: + try: + from pyspark.sql import SparkSession + except ImportError as e: + msg = "pyspark is required for SparkBackend. Install with: pip install chronify[spark]" + raise ImportError(msg) from e + + if session is None: + session = ( + SparkSession.builder.master("local") + .config("spark.sql.session.timeZone", "UTC") + .config("spark.sql.parquet.outputTimestampType", "TIMESTAMP_MICROS") + .getOrCreate() + ) + self._session = session + self._connection = ibis.pyspark.connect(session) + + @property + def name(self) -> str: + return "spark" + + @property + def database(self) -> str | None: + return None + + @property + def connection(self) -> ibis.BaseBackend: + return self._connection + + def create_table( + self, + name: str, + obj: pd.DataFrame | ir.Table | None = None, + schema: ibis.Schema | None = None, + ) -> ir.Table: + if isinstance(obj, pd.DataFrame): + obj = self._prepare_data_for_spark(obj) + return self._connection.create_table(name, obj=obj, schema=schema, overwrite=False) + + def create_view(self, name: str, expr: ir.Table) -> ir.Table: + return self._connection.create_view(name, expr, overwrite=False) + + def drop_table(self, name: str) -> None: + self._connection.drop_table(name, force=True) + + def drop_view(self, name: str) -> None: + self._connection.drop_view(name, force=True) + + def list_tables(self) -> list[str]: + return self._connection.list_tables() + + def table(self, name: str) -> ir.Table: + return self._connection.table(name) + + def insert(self, name: str, data: pd.DataFrame) -> None: + # Spark doesn't support INSERT directly -- create a temp view and insert via SQL + data = self._prepare_data_for_spark(data) + spark_df = self._session.createDataFrame(data) + spark_df.createOrReplaceTempView("__insert_tmp") + self._session.sql(f"INSERT INTO {name} SELECT * FROM __insert_tmp") + logger.trace("Inserted {} rows into {}", len(data), name) + + def execute(self, expr: ir.Expr) -> pd.DataFrame: + return self._connection.execute(expr) + + def sql(self, query: str) -> ir.Table: + return self._connection.sql(query) + + def write_parquet( + self, + expr: ir.Table, + path: str, + partition_by: list[str] | None = None, + ) -> None: + df = self._connection.execute(expr) + if partition_by: + spark_df = self._session.createDataFrame(df) + spark_df.write.partitionBy(*partition_by).parquet(path) + else: + df.to_parquet(path) + + def create_view_from_parquet(self, path: str, name: str) -> ir.Table: + spark_df = self._session.read.parquet(path) + spark_df.createOrReplaceTempView(name) + return self.table(name) + + def execute_sql(self, query: str) -> None: + logger.trace("execute_sql: {}", query) + self._session.sql(query) + + def execute_sql_to_df(self, query: str) -> pd.DataFrame: + logger.trace("execute_sql_to_df: {}", query) + return self._session.sql(query).toPandas() + + def dispose(self) -> None: + pass # Don't stop the Spark session -- it may be shared + + def reconnect(self) -> None: + pass # Spark sessions are long-lived + + @staticmethod + def _prepare_data_for_spark(df: pd.DataFrame) -> pd.DataFrame: + """Convert datetime columns to strings to avoid Spark DST issues.""" + df = df.copy() + for col in df.select_dtypes(include=["datetime64[ns, UTC]", "datetimetz"]).columns: + df[col] = df[col].dt.strftime("%Y-%m-%d %H:%M:%S%z") + return df diff --git a/src/chronify/ibis/sqlite_backend.py b/src/chronify/ibis/sqlite_backend.py new file mode 100644 index 0000000..2baf344 --- /dev/null +++ b/src/chronify/ibis/sqlite_backend.py @@ -0,0 +1,121 @@ +"""SQLite backend implementation for Ibis.""" + +from pathlib import Path + +import ibis +import ibis.expr.types as ir +import pandas as pd +import pyarrow as pa +from loguru import logger + +from chronify.ibis.base import IbisBackend + + +class SQLiteBackend(IbisBackend): + """Ibis backend for SQLite databases.""" + + def __init__(self, database: str | Path = ":memory:") -> None: + db = str(database) + self._database = None if db == ":memory:" else db + self._connection = ibis.sqlite.connect(db) + + @property + def name(self) -> str: + return "sqlite" + + @property + def database(self) -> str | None: + return self._database + + @property + def connection(self) -> ibis.BaseBackend: + return self._connection + + def create_table( + self, + name: str, + obj: pd.DataFrame | ir.Table | None = None, + schema: ibis.Schema | None = None, + ) -> ir.Table: + if isinstance(obj, ir.Table): + # SQLite CREATE TABLE AS SELECT loses datetime type info. + # Execute the expression first, then create from the DataFrame. + df = self._connection.execute(obj) + return self._connection.create_table(name, obj=df, overwrite=False) + return self._connection.create_table(name, obj=obj, schema=schema, overwrite=False) + + def create_view(self, name: str, expr: ir.Table) -> ir.Table: + return self._connection.create_view(name, expr, overwrite=False) + + def drop_table(self, name: str) -> None: + self._connection.drop_table(name, force=True) + + def drop_view(self, name: str) -> None: + self._connection.drop_view(name, force=True) + + def list_tables(self) -> list[str]: + return self._connection.list_tables() + + def table(self, name: str) -> ir.Table: + return self._connection.table(name) + + def insert(self, name: str, data: pd.DataFrame) -> None: + # Use raw SQLite cursor for parameterized inserts + con = self._connection.con # raw sqlite3 connection + table = self._connection.table(name) + columns = table.columns + placeholders = ", ".join(["?"] * len(columns)) + col_list = ", ".join(columns) + sql = f"INSERT INTO {name} ({col_list}) VALUES ({placeholders})" + + arrow_table = pa.Table.from_pandas(data) + cursor = con.cursor() + for batch in arrow_table.to_batches(): + rows = [tuple(row[col].as_py() for col in range(batch.num_columns)) for row in zip(*[batch.column(i) for i in range(batch.num_columns)])] + cursor.executemany(sql, rows) + con.commit() + logger.trace("Inserted {} rows into {}", len(data), name) + + def execute(self, expr: ir.Expr) -> pd.DataFrame: + return self._connection.execute(expr) + + def sql(self, query: str) -> ir.Table: + return self._connection.sql(query) + + def write_parquet( + self, + expr: ir.Table, + path: str, + partition_by: list[str] | None = None, + ) -> None: + if partition_by: + msg = "SQLite backend does not support partitioned Parquet writes." + raise NotImplementedError(msg) + df = self._connection.execute(expr) + df.to_parquet(path) + + def create_view_from_parquet(self, path: str, name: str) -> ir.Table: + # SQLite can't read Parquet natively. Load into a table instead. + df = pd.read_parquet(path) + return self.create_table(name, obj=df) + + def execute_sql(self, query: str) -> None: + logger.trace("execute_sql: {}", query) + con = self._connection.con + con.execute(query) + con.commit() + + def execute_sql_to_df(self, query: str) -> pd.DataFrame: + logger.trace("execute_sql_to_df: {}", query) + con = self._connection.con + cursor = con.execute(query) + rows = cursor.fetchall() + columns = [desc[0] for desc in cursor.description] if cursor.description else [] + return pd.DataFrame(rows, columns=columns) + + def dispose(self) -> None: + self._connection.disconnect() + + def reconnect(self) -> None: + db = self._database if self._database else ":memory:" + self._connection = ibis.sqlite.connect(db) diff --git a/src/chronify/ibis/types.py b/src/chronify/ibis/types.py new file mode 100644 index 0000000..d6b06f7 --- /dev/null +++ b/src/chronify/ibis/types.py @@ -0,0 +1,111 @@ +"""Type conversion utilities for Ibis backends.""" + +import ibis.expr.datatypes as dt +import pandas as pd +import pyarrow as pa + +# Mapping from user-facing string type names to Ibis data types +_COLUMN_TYPES: dict[str, dt.DataType] = { + "bool": dt.Boolean(), + "int": dt.Int64(), + "bigint": dt.Int64(), + "float": dt.Float64(), + "double": dt.Float64(), + "str": dt.String(), + "datetime": dt.Timestamp(timezone=None), + "datetime_tz": dt.Timestamp(timezone="UTC"), +} + +# Mapping from DuckDB type names to Ibis data types +_DUCKDB_TYPE_MAP: dict[str, dt.DataType] = { + "BOOLEAN": dt.Boolean(), + "TINYINT": dt.Int8(), + "SMALLINT": dt.Int16(), + "INTEGER": dt.Int32(), + "BIGINT": dt.Int64(), + "FLOAT": dt.Float32(), + "DOUBLE": dt.Float64(), + "VARCHAR": dt.String(), + "TIMESTAMP": dt.Timestamp(timezone=None), + "TIMESTAMP WITH TIME ZONE": dt.Timestamp(timezone="UTC"), + "TIMESTAMPTZ": dt.Timestamp(timezone="UTC"), + "TIMESTAMP_NS": dt.Timestamp(timezone=None), + "DATE": dt.Date(), +} + +# Reverse mapping from Ibis types to DuckDB type strings +_IBIS_TO_DUCKDB_MAP: dict[type[dt.DataType], str] = { + dt.Boolean: "BOOLEAN", + dt.Int8: "TINYINT", + dt.Int16: "SMALLINT", + dt.Int32: "INTEGER", + dt.Int64: "BIGINT", + dt.Float32: "FLOAT", + dt.Float64: "DOUBLE", + dt.String: "VARCHAR", + dt.Date: "DATE", +} + + +def get_ibis_type_from_string(type_name: str) -> dt.DataType: + """Convert a string type name to an Ibis DataType. + + Parameters + ---------- + type_name + One of: "bool", "int", "bigint", "float", "double", "str", "datetime", "datetime_tz" + """ + if type_name not in _COLUMN_TYPES: + msg = f"Unsupported type name: {type_name}. Valid types: {sorted(_COLUMN_TYPES.keys())}" + raise ValueError(msg) + return _COLUMN_TYPES[type_name] + + +def get_ibis_type_from_duckdb(duckdb_type: str) -> dt.DataType: + """Convert a DuckDB type string to an Ibis DataType.""" + upper = duckdb_type.upper() + if upper in _DUCKDB_TYPE_MAP: + return _DUCKDB_TYPE_MAP[upper] + msg = f"Unsupported DuckDB type: {duckdb_type}" + raise ValueError(msg) + + +def get_duckdb_type_from_ibis(ibis_type: dt.DataType) -> str: + """Convert an Ibis DataType to a DuckDB type string.""" + if isinstance(ibis_type, dt.Timestamp): + if ibis_type.timezone is not None: + return "TIMESTAMPTZ" + return "TIMESTAMP" + for cls, duckdb_name in _IBIS_TO_DUCKDB_MAP.items(): + if isinstance(ibis_type, cls): + return duckdb_name + msg = f"Unsupported Ibis type for DuckDB: {ibis_type}" + raise ValueError(msg) + + +def get_ibis_types_from_dataframe(df: pd.DataFrame) -> dict[str, dt.DataType]: + """Infer Ibis types from a pandas DataFrame's columns.""" + import duckdb + + con = duckdb.connect() + rel = con.from_df(df) + types = {} + for name, dtype in zip(rel.columns, rel.types, strict=True): + types[name] = get_ibis_type_from_duckdb(str(dtype)) + con.close() + return types + + +def get_ibis_schema_from_dataframe(df: pd.DataFrame) -> dict[str, dt.DataType]: + """Get an ibis schema dict from a pandas DataFrame.""" + return get_ibis_types_from_dataframe(df) + + +def pyarrow_to_ibis_type(arrow_type: pa.DataType) -> dt.DataType: + """Convert a PyArrow type to an Ibis DataType.""" + return dt.DataType.from_pyarrow(arrow_type) + + +def ibis_to_pyarrow_type(ibis_type: dt.DataType) -> pa.DataType: + """Convert an Ibis DataType to a PyArrow type.""" + return ibis_type.to_pyarrow() diff --git a/src/chronify/models.py b/src/chronify/models.py index b770200..9d53ac8 100644 --- a/src/chronify/models.py +++ b/src/chronify/models.py @@ -1,15 +1,16 @@ import re from typing import Any, Optional -import duckdb.typing +import duckdb +import ibis.expr.datatypes as dt import pandas as pd from duckdb.typing import DuckDBPyType from pydantic import Field, field_validator, model_validator -from sqlalchemy import BigInteger, Boolean, DateTime, Double, Float, Integer, SmallInteger, String from typing_extensions import Annotated from chronify.base_models import ChronifyBaseModel -from chronify.exceptions import InvalidParameter, InvalidValue +from chronify.exceptions import InvalidValue +from chronify.ibis.types import get_ibis_type_from_duckdb, get_duckdb_type_from_ibis from chronify.time_configs import TimeConfig @@ -142,81 +143,30 @@ def list_columns(self) -> list[str]: return time_columns -# TODO: print example tables here. - -_COLUMN_TYPES = { - "bool": Boolean, - "datetime": DateTime, - "float": Double, - "int": Integer, - "bigint": BigInteger, - "str": String, +_COLUMN_TYPES: dict[str, type[dt.DataType]] = { + "bool": dt.Boolean, + "datetime": dt.Timestamp, + "float": dt.Float64, + "int": dt.Int64, + "bigint": dt.Int64, + "str": dt.String, } -_DB_TYPES = {x for x in _COLUMN_TYPES.values()} - -_DUCKDB_TYPES_TO_SQLALCHEMY_TYPES = { - duckdb.typing.BIGINT.id: BigInteger, # type: ignore - duckdb.typing.BOOLEAN.id: Boolean, # type: ignore - duckdb.typing.DOUBLE.id: Double, # type: ignore - duckdb.typing.FLOAT.id: Float, # type: ignore - duckdb.typing.INTEGER.id: Integer, # type: ignore - duckdb.typing.TINYINT.id: SmallInteger, # type: ignore - duckdb.typing.VARCHAR.id: String, # type: ignore - # Note: timestamp requires special handling because of timezone in sqlalchemy. -} +_DB_TYPES = set(_COLUMN_TYPES.values()) + + +def get_ibis_type_from_duckdb_pytype(duckdb_type: DuckDBPyType) -> dt.DataType: + """Return the ibis type for a duckdb type.""" + return get_ibis_type_from_duckdb(str(duckdb_type)) -def get_sqlalchemy_type_from_duckdb(duckdb_type: DuckDBPyType) -> Any: - """Return the sqlalchemy type for a duckdb type.""" - match duckdb_type: - case duckdb.typing.TIMESTAMP_TZ: # type: ignore - sqlalchemy_type = DateTime(timezone=True) - case ( - duckdb.typing.TIMESTAMP # type: ignore - | duckdb.typing.TIMESTAMP_MS # type: ignore - | duckdb.typing.TIMESTAMP_NS # type: ignore - | duckdb.typing.TIMESTAMP_S # type: ignore - ): - sqlalchemy_type = DateTime(timezone=False) - case _: - cls = _DUCKDB_TYPES_TO_SQLALCHEMY_TYPES.get(duckdb_type.id) - if cls is None: - msg = f"There is no sqlalchemy mapping for {duckdb_type=}" - raise InvalidParameter(msg) - sqlalchemy_type = cls() - - return sqlalchemy_type - - -def get_duckdb_type_from_sqlalchemy(sqlalchemy_type: Any) -> DuckDBPyType: - """Return the duckdb type for a sqlalchemy type.""" - if isinstance(sqlalchemy_type, DateTime): - duckdb_type = ( - duckdb.typing.TIMESTAMP_TZ # type: ignore - if sqlalchemy_type.timezone - else duckdb.typing.TIMESTAMP # type: ignore - ) - elif isinstance(sqlalchemy_type, BigInteger): - duckdb_type = duckdb.typing.BIGINT # type: ignore - elif isinstance(sqlalchemy_type, Boolean): - duckdb_type = duckdb.typing.BOOLEAN # type: ignore - elif isinstance(sqlalchemy_type, Double): - duckdb_type = duckdb.typing.DOUBLE # type: ignore - elif isinstance(sqlalchemy_type, Integer): - duckdb_type = duckdb.typing.INTEGER # type: ignore - elif isinstance(sqlalchemy_type, String): - duckdb_type = duckdb.typing.VARCHAR # type: ignore - else: - msg = f"There is no duckdb mapping for {sqlalchemy_type=}" - raise InvalidParameter(msg) - - return duckdb_type # type: ignore +def get_duckdb_type_from_ibis_type(ibis_type: dt.DataType) -> str: + """Return the duckdb type string for an ibis type.""" + return get_duckdb_type_from_ibis(ibis_type) def get_duckdb_types_from_pandas(df: pd.DataFrame) -> list[DuckDBPyType]: """Return a list of DuckDB types from a pandas dataframe.""" - # This seems least-prone to error, but is not exactly the most efficient. short_df = df.head(1) # noqa: F841 return duckdb.sql("select * from short_df").dtypes @@ -231,18 +181,27 @@ class ColumnDType(ChronifyBaseModel): @classmethod def fix_data_type(cls, data: dict[str, Any]) -> dict[str, Any]: dtype = data.get("dtype") - if dtype is None or any(map(lambda x: isinstance(dtype, x), _DB_TYPES)): + if dtype is None: + return data + + if isinstance(dtype, dt.DataType): + return data + + if isinstance(dtype, type) and issubclass(dtype, dt.DataType): + data["dtype"] = dtype() return data if isinstance(dtype, str): val = _COLUMN_TYPES.get(dtype) if val is None: - options = sorted(_COLUMN_TYPES.keys()) + list(_DB_TYPES) + options = sorted(_COLUMN_TYPES.keys()) msg = f"{dtype=} must be one of {options}" raise InvalidValue(msg) data["dtype"] = val() else: - msg = f"dtype is an unsupported type: {type(dtype)}. It must be a str or type." + msg = ( + f"dtype is an unsupported type: {type(dtype)}. It must be a str or ibis DataType." + ) raise InvalidValue(msg) return data diff --git a/src/chronify/schema_manager.py b/src/chronify/schema_manager.py index 23828cc..2cdfba3 100644 --- a/src/chronify/schema_manager.py +++ b/src/chronify/schema_manager.py @@ -1,26 +1,11 @@ import json -from typing import Optional +import pandas as pd from loguru import logger -from sqlalchemy import ( - Column, - Connection, - Engine, - MetaData, - String, - Table, - delete, - insert, - select, - text, -) - -from chronify.exceptions import ( - TableNotStored, -) -from chronify.models import ( - TableSchema, -) + +from chronify.exceptions import TableNotStored +from chronify.ibis.base import IbisBackend +from chronify.models import TableSchema class SchemaManager: @@ -28,106 +13,58 @@ class SchemaManager: SCHEMAS_TABLE = "schemas" - def __init__(self, engine: Engine, metadata: MetaData): - self._engine = engine - self._metadata = metadata - # Caching is not necessary if using SQLite, which provides very fast performance (~1 us) - # for checking schemas in the **tiny** schemas table. - # The same lookups in DuckDB are taking over 100 us. + def __init__(self, backend: IbisBackend) -> None: + self._backend = backend self._cache: dict[str, TableSchema] = {} - if self.SCHEMAS_TABLE in self._metadata.tables: - logger.info("Loaded existing database {}", self._engine.url.database) + if self._backend.has_table(self.SCHEMAS_TABLE): + logger.info("Loaded existing database {}", self._backend.database) self.rebuild_cache() else: - if self._engine.name == "hive": - # metadata.create_all doesn't work here. - with self._engine.begin() as conn: - conn.execute(text(f"DROP TABLE IF EXISTS {self.SCHEMAS_TABLE}")) - conn.execute( - text(f"CREATE TABLE {self.SCHEMAS_TABLE}(name STRING, schema STRING)") - ) - self._metadata.reflect(self._engine) - else: - table = Table( - self.SCHEMAS_TABLE, - self._metadata, - Column("name", String, nullable=False, unique=True), - Column("schema", String), # schema encoded as JSON - ) - self._metadata.create_all(self._engine, tables=[table]) - logger.info("Initialized new database: {}", self._engine.url.database) - - def _get_schema_table(self) -> Table: - return ( - self._metadata.tables[self.SCHEMAS_TABLE] - if self._engine.name == "hive" - else Table(self.SCHEMAS_TABLE, self._metadata) - ) - - def add_schema(self, conn: Connection, schema: TableSchema) -> Table: + self._create_schemas_table() + logger.info("Initialized new database: {}", self._backend.database) + + def _create_schemas_table(self) -> None: + import ibis + + schema = ibis.schema({"name": "string", "schema": "string"}) + self._backend.create_table(self.SCHEMAS_TABLE, schema=schema) + + def add_schema(self, schema: TableSchema) -> None: """Add the schema to the store.""" - table = self._get_schema_table() - stmt = insert(table).values(name=schema.name, schema=schema.model_dump_json()) - conn.execute(stmt) - # If there is a rollback after this addition to cached, things _should_ still be OK. - # The table will be deleted and any attempted reads will fail with an error. - # There will be a stale entry in cache, but it will be overwritten if the user ever - # adds a new table with the same name. + df = pd.DataFrame({"name": [schema.name], "schema": [schema.model_dump_json()]}) + self._backend.insert(self.SCHEMAS_TABLE, df) self._cache[schema.name] = schema logger.trace("Added schema for table {}", schema.name) - return table - def get_schema(self, name: str, conn: Optional[Connection] = None) -> TableSchema: + def get_schema(self, name: str) -> TableSchema: """Retrieve the schema for the table with name.""" schema = self._cache.get(name) if schema is None: - self.rebuild_cache(conn=conn) + self.rebuild_cache() schema = self._cache.get(name) if schema is None: msg = f"{name=}" raise TableNotStored(msg) - return self._cache[name] + return schema - def remove_schema(self, conn: Connection, name: str) -> None: + def remove_schema(self, name: str) -> None: """Remove the schema from the store.""" - table = self._get_schema_table() - if self._engine.name == "hive": - # Hive/Spark doesn't support delete, so we have to re-create the table without - # this one entry - stmt = select(table).where(table.c.name != name) - rows = conn.execute(stmt).fetchall() - conn.execute(text(f"DROP TABLE {self.SCHEMAS_TABLE}")) - conn.execute(text(f"CREATE TABLE {self.SCHEMAS_TABLE}(name STRING, schema STRING)")) - for row in rows: - params = {"name": row[0], "schema": row[1]} - conn.execute( - text(f"INSERT INTO {self.SCHEMAS_TABLE} VALUES(:name, :schema)"), - params, - ) - else: - stmt2 = delete(table).where(table.c["name"] == name) - conn.execute(stmt2) - - self._cache.pop(name) + self._backend.execute_sql(f"DELETE FROM {self.SCHEMAS_TABLE} WHERE name = '{name}'") + self._cache.pop(name, None) - def rebuild_cache(self, conn: Optional[Connection] = None) -> None: + def rebuild_cache(self) -> None: """Rebuild the cache of schemas.""" self._cache.clear() - if conn is None: - with self._engine.connect() as conn: - self._rebuild_cache(conn) - else: - self._rebuild_cache(conn) - - def _rebuild_cache(self, conn: Connection) -> None: - table = self._get_schema_table() - stmt = select(table) - res = conn.execute(stmt).fetchall() - for name, json_text in res: - schema = TableSchema(**json.loads(json_text)) + self._rebuild_cache() + + def _rebuild_cache(self) -> None: + df = self._backend.execute_sql_to_df(f"SELECT * FROM {self.SCHEMAS_TABLE}") + for _, row in df.iterrows(): + name = row["name"] + schema = TableSchema(**json.loads(row["schema"])) assert name == schema.name assert name not in self._cache self._cache[name] = schema diff --git a/src/chronify/sqlalchemy/__init__.py b/src/chronify/sqlalchemy/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/src/chronify/sqlalchemy/functions.py b/src/chronify/sqlalchemy/functions.py deleted file mode 100644 index ff9c42c..0000000 --- a/src/chronify/sqlalchemy/functions.py +++ /dev/null @@ -1,273 +0,0 @@ -"""This file provides functions to read and write the database as efficiently as possible. -The default behavior of sqlalchemy is to convert data into rows of tuples in Python, which -is very slow. This code attempts to bypass Python as much as possible through Arrow tables -in memory. -""" - -import atexit -from pathlib import Path -from tempfile import NamedTemporaryFile -from typing import Any, Literal, Optional, TypeAlias, Sequence -from collections import Counter - -import pandas as pd -from numpy.dtypes import DateTime64DType, ObjectDType -from pandas import DatetimeTZDtype -from chronify.time import TimeDataType -from sqlalchemy import Connection, Engine, Selectable, text - -from chronify.exceptions import InvalidOperation, InvalidParameter -from chronify.time_configs import ( - DatetimeRangeBase, - TimeBaseModel, - DatetimeRange, - DatetimeRangeWithTZColumn, -) -from chronify.utils.path_utils import check_overwrite, delete_if_exists, to_path - -# Copied from Pandas/Polars -DbWriteMode: TypeAlias = Literal["replace", "append", "fail"] -DatetimeRangeWithDtype: TypeAlias = DatetimeRange | DatetimeRangeWithTZColumn - - -def read_database( - query: Selectable | str, conn: Connection, config: TimeBaseModel, params: Any = None -) -> pd.DataFrame: - """Read a database query into a Pandas DataFrame.""" - match conn.engine.name: - case "duckdb": - if isinstance(query, str): - df = conn._dbapi_connection.driver_connection.sql(query, params=params).to_df() # type: ignore - else: - df = conn.execute(query).cursor.fetch_df() # type: ignore - case "sqlite": - df = pd.read_sql(query, conn, params=params) - if isinstance(config, (DatetimeRange, DatetimeRangeWithTZColumn)): - _convert_database_output_for_datetime(df, config) - case "hive": - df = _read_from_hive(query, conn, config, params) - case _: - df = pd.read_sql(query, conn, params=params) - return df # type: ignore - - -def write_database( - df: pd.DataFrame, - conn: Connection, - table_name: str, - configs: Sequence[TimeBaseModel], - if_table_exists: DbWriteMode = "append", - scratch_dir: Path | None = None, -) -> None: - """Write a Pandas DataFrame to the database. - configs allows sqlite formatting for more than one datetime columns. - - Note: Writing persistent data with Hive as the backend is not supported. - This function will write the dataframe to a temporary Parquet file and then create - a view into that file. This is only to support ephemeral tables, such as for mapping tables. - """ - match conn.engine.name: - case "duckdb": - _write_to_duckdb(df, conn, table_name, if_table_exists) - case "sqlite": - _write_to_sqlite(df, conn, table_name, configs, if_table_exists) - case "hive": - _write_to_hive(df, conn, table_name, configs, if_table_exists, scratch_dir) - case _: - df.to_sql(table_name, conn, if_exists=if_table_exists, index=False) - - -def _check_one_config_per_datetime_column(configs: Sequence[TimeBaseModel]) -> None: - time_col_count = Counter( - [config.time_column for config in configs if isinstance(config, DatetimeRangeBase)] - ) - time_col_dup = {k: v for k, v in time_col_count.items() if v > 1} - if len(time_col_dup) > 0: - msg = f"More than one datetime config found for: {time_col_dup}" - raise InvalidParameter(msg) - - -def _convert_database_input_for_datetime( - df: pd.DataFrame, config: DatetimeRangeWithDtype, copied: bool -) -> tuple[pd.DataFrame, bool]: - if config.dtype == TimeDataType.TIMESTAMP_NTZ: - return df, copied - - if copied: - df2 = df - else: - df2 = df.copy() - copied = True - if isinstance(df2[config.time_column].dtype, DatetimeTZDtype): - df2[config.time_column] = df2[config.time_column].dt.tz_convert("UTC") - else: - df2[config.time_column] = df2[config.time_column].dt.tz_localize("UTC") - - return df2, copied - - -def _convert_database_output_for_datetime( - df: pd.DataFrame, config: DatetimeRangeWithDtype -) -> None: - if config.time_column in df.columns: - if config.dtype == TimeDataType.TIMESTAMP_TZ: - if isinstance(df[config.time_column].dtype, ObjectDType): - df[config.time_column] = pd.to_datetime(df[config.time_column], utc=True) - else: - df[config.time_column] = df[config.time_column].dt.tz_localize("UTC") - else: - if isinstance(df[config.time_column].dtype, ObjectDType): - df[config.time_column] = pd.to_datetime(df[config.time_column], utc=False) - - -def _write_to_duckdb( - df: pd.DataFrame, - conn: Connection, - table_name: str, - if_table_exists: DbWriteMode, -) -> None: - assert conn._dbapi_connection is not None - assert conn._dbapi_connection.driver_connection is not None - - match if_table_exists: - case "append": - query = f"INSERT INTO {table_name} SELECT * FROM df" - case "replace": - conn._dbapi_connection.driver_connection.sql(f"DROP TABLE IF EXISTS {table_name}") - query = f"CREATE TABLE {table_name} AS SELECT * FROM df" - case "fail": - query = f"CREATE TABLE {table_name} AS SELECT * FROM df" - case _: - msg = f"{if_table_exists=}" - raise InvalidOperation(msg) - - conn._dbapi_connection.driver_connection.sql(query) - - -def _write_to_hive( - df: pd.DataFrame, - conn: Connection, - table_name: str, - configs: Sequence[TimeBaseModel], - if_table_exists: DbWriteMode, - scratch_dir: Path | None, -) -> None: - df2 = df.copy() - for config in configs: - if isinstance(config, DatetimeRangeBase): - if isinstance(df2[config.time_column].dtype, DatetimeTZDtype): - # Spark doesn't like ns. That might change in the future. - # Pandas might offer a better way to change from ns to us in the future. - new_dtype = df2[config.time_column].dtype.name.replace( - "datetime64[ns", "datetime64[us" - ) - df2[config.time_column] = df2[config.time_column].astype(new_dtype) # type: ignore - elif isinstance(df2[config.time_column].dtype, DateTime64DType): - new_dtype = "datetime64[us]" - df2[config.time_column] = df2[config.time_column].astype(new_dtype) # type: ignore - else: - new_dtype = "datetime64[us]" - df2[config.time_column] = pd.to_datetime( - df2[config.time_column], utc=False, errors="raise" - ).astype(new_dtype) # type: ignore - - with NamedTemporaryFile(suffix=".parquet", dir=scratch_dir) as f: - f.close() - output = Path(f.name) - - df2.to_parquet(output) - atexit.register(lambda: delete_if_exists(output)) - select_stmt = f"SELECT * FROM parquet.`{output}`" - # TODO: CREATE TABLE causes DST fallback timestamps to get dropped - match if_table_exists: - case "append": - msg = "INSERT INTO is not supported with write_to_hive" - raise InvalidOperation(msg) - case "replace": - conn.execute(text(f"DROP VIEW IF EXISTS {table_name}")) - query = f"CREATE VIEW {table_name} AS {select_stmt}" - case "fail": - # Let the database fail the operation if the table already exists. - query = f"CREATE VIEW {table_name} AS {select_stmt}" - case _: - msg = f"{if_table_exists=}" - raise InvalidOperation(msg) - conn.execute(text(query)) - - -def _read_from_hive( - query: Selectable | str, conn: Connection, config: TimeBaseModel, params: Any = None -) -> pd.DataFrame: - df = pd.read_sql_query(query, conn, params=params) - if ( - isinstance(config, (DatetimeRange, DatetimeRangeWithTZColumn)) - and config.time_column in df.columns - and config.dtype == TimeDataType.TIMESTAMP_TZ - ): - # This is tied to the fact that we set the Spark session to UTC. - # Otherwise, there is confusion with the computer's local time zone. - df[config.time_column] = df[config.time_column].dt.tz_localize("UTC") - return df - - -def _write_to_sqlite( - df: pd.DataFrame, - conn: Connection, - table_name: str, - configs: Sequence[TimeBaseModel], - if_table_exists: DbWriteMode, -) -> None: - _check_one_config_per_datetime_column(configs) - copied = False - for config in configs: - if isinstance(config, (DatetimeRange, DatetimeRangeWithTZColumn)): - df, copied = _convert_database_input_for_datetime(df, config, copied) - df.to_sql(table_name, conn, if_exists=if_table_exists, index=False) - - -def create_view_from_parquet(conn: Connection, view_name: str, filename: Path) -> None: - """Create a view from a Parquet file.""" - if conn.engine.name == "duckdb": - str_path = f"{filename}/**/*.parquet" if filename.is_dir() else str(filename) - query = f"CREATE VIEW {view_name} AS SELECT * FROM read_parquet('{str_path}')" - elif conn.engine.name == "hive": - query = f"CREATE VIEW {view_name} AS SELECT * FROM parquet.`{filename}`" - else: - msg = f"create_view_from_parquet does not support engine={conn.engine.name}" - raise NotImplementedError(msg) - conn.execute(text(query)) - - -def write_query_to_parquet( - engine: Engine, - query: str, - output_file: Path, - overwrite: bool = False, - partition_columns: Optional[list[str]] = None, -) -> None: - """Write the query to a Parquet file.""" - output_file = to_path(output_file) - check_overwrite(output_file, overwrite) - match engine.name: - case "duckdb": - if partition_columns: - cols = ",".join(partition_columns) - query = ( - f"COPY ({query}) TO '{output_file}' (FORMAT PARQUET, PARTITION_BY ({cols}))" - ) - else: - query = f"COPY ({query}) TO '{output_file}' (FORMAT PARQUET)" - case "hive": - if not overwrite: - msg = "write_table_to_parquet with Hive requires overwrite=True" - raise InvalidOperation(msg) - if partition_columns: - msg = "write_table_to_parquet with Hive doesn't support partition_columns" - raise InvalidOperation(msg) - query = f"INSERT OVERWRITE DIRECTORY '{output_file}' USING parquet {query}" - case _: - msg = f"{engine.name=}" - raise NotImplementedError(msg) - - with engine.connect() as conn: - conn.execute(text(query)) diff --git a/src/chronify/store.py b/src/chronify/store.py index 3f9adfe..c647afb 100644 --- a/src/chronify/store.py +++ b/src/chronify/store.py @@ -2,26 +2,13 @@ from pathlib import Path import shutil from typing import Any, Optional -from chronify.utils.sql import make_temp_view_name from datetime import tzinfo import duckdb +import ibis.expr.types as ir import pandas as pd from duckdb import DuckDBPyRelation from loguru import logger -from sqlalchemy import ( - Column, - Connection, - Engine, - MetaData, - Selectable, - Table, - create_engine, - delete, - func, - select, - text, -) import chronify.duckdb.functions as ddbf from chronify.exceptions import ( @@ -33,18 +20,18 @@ TableNotStored, ) from chronify.csv_io import read_csv +from chronify.ibis import IbisBackend, make_backend +from chronify.ibis.functions import ( + create_view_from_parquet, + read_query, + read_table, + write_parquet, + write_table, +) from chronify.models import ( CsvTableSchema, PivotedTableSchema, TableSchema, - get_duckdb_types_from_pandas, - get_sqlalchemy_type_from_duckdb, -) -from chronify.sqlalchemy.functions import ( - create_view_from_parquet, - read_database, - write_database, - write_query_to_parquet, ) from chronify.schema_manager import SchemaManager from chronify.time_configs import DatetimeRange, IndexTimeRangeBase, TimeBasedDataAdjustment @@ -53,7 +40,6 @@ from chronify.time_zone_converter import TimeZoneConverter, TimeZoneConverterByColumn from chronify.time_zone_localizer import TimeZoneLocalizer, TimeZoneLocalizerByColumn from chronify.utils.path_utils import check_overwrite, to_path -from chronify.utils.sqlalchemy_view import create_view class Store: @@ -61,218 +47,119 @@ class Store: def __init__( self, - engine: Optional[Engine] = None, - engine_name: Optional[str] = None, + backend: Optional[IbisBackend] = None, + backend_name: Optional[str] = None, file_path: Optional[Path | str] = None, - **connect_kwargs: Any, ) -> None: """Construct the Store. Parameters ---------- - engine - Optional, defaults to a engine connected to an in-memory DuckDB database. - engine_name - Optional, name of engine to use ('duckdb', 'sqlite'). Mutually exclusive with engine. + backend + Optional, defaults to a DuckDB in-memory backend. + backend_name + Optional, name of backend to use ('duckdb', 'sqlite'). Mutually exclusive with backend. file_path Optional, use this file for the database. If the file does not exist, create a new database. If the file exists, load that existing database. Defaults to a new in-memory database. - - Examples - -------- - >>> from sqlalchemy - >>> store1 = Store() - >>> store2 = Store(engine=Engine("duckdb:///time_series.db")) - >>> store3 = Store(engine=Engine("sqlite:///time_series.db")) - >>> store4 = Store(engine_name="sqlite") """ - self._metadata = MetaData() - if engine and engine_name: - msg = f"{engine=} and {engine_name=} cannot both be set" + if backend and backend_name: + msg = f"{backend=} and {backend_name=} cannot both be set" raise ConflictingInputsError(msg) - filename = ":memory:" if file_path is None else str(file_path) - if engine is None: - name = engine_name or "duckdb" - match name: - case "duckdb" | "sqlite": - engine_path = f"{name}:///{filename}" - case _: - msg = f"{engine_name=}" - raise NotImplementedError(msg) - self._engine = create_engine(engine_path, **connect_kwargs) + if backend is not None: + self._backend = backend else: - self._engine = engine + name = backend_name or "duckdb" + database = str(file_path) if file_path else None + self._backend = make_backend(name, database=database) - self._schema_mgr = SchemaManager(self._engine, self._metadata) - if self._engine.url.database != ":memory:": - self.update_metadata() + self._schema_mgr = SchemaManager(self._backend) @classmethod def create_in_memory_db( cls, - engine_name: str = "duckdb", - **connect_kwargs: Any, + backend_name: str = "duckdb", ) -> "Store": """Create a Store with an in-memory database.""" - return Store(engine=create_engine(f"{engine_name}:///:memory:", **connect_kwargs)) + return Store(backend=make_backend(backend_name)) @classmethod def create_file_db( cls, file_path: Path | str = "time_series.db", - engine_name: str = "duckdb", + backend_name: str = "duckdb", overwrite: bool = False, - **connect_kwargs: Any, ) -> "Store": """Create a Store with a file-based database.""" path = to_path(file_path) check_overwrite(path, overwrite) - return Store(engine=create_engine(f"{engine_name}:///{path}", **connect_kwargs)) - - @classmethod - def create_new_hive_store( - cls, - url: str, - drop_schema: bool = True, - **connect_kwargs: Any, - ) -> "Store": - """Create a new Store in a Hive database. - Recommended usage is to create views from Parquet files. Ingesting data into tables - from files or DataFrames is not supported. - - This has been tested with Apache Spark running an Apache Thrift Server. - - Parameters - ---------- - url - Thrift server URL - drop_schema - If True, drop the schema table if it's already there. - - Examples - -------- - >>> store = Store.create_new_hive_store("hive://localhost:10000/default") - - See also - -------- - create_view_from_parquet - """ - # We don't currently expect to need to load an existing hive-based store, but it could - # be added. - if "hive://" not in url: - msg = f"Expected 'hive://' to be in url: {url}" - raise InvalidParameter(msg) - engine = create_engine(url, **connect_kwargs) - metadata = MetaData() - metadata.reflect(engine, views=True) - with engine.begin() as conn: - # Workaround for ambiguity of time zones in the read path. - conn.execute(text("SET TIME ZONE 'UTC'")) - # Workaround for the fact that Spark uses a non-standard format for timestamps - # in Parquet files. Pandas/DuckDB can't interpret them properly. - conn.execute(text("SET spark.sql.parquet.outputTimestampType=TIMESTAMP_MICROS")) - - if drop_schema: - if SchemaManager.SCHEMAS_TABLE in metadata.tables: - conn.execute(text(f"DROP TABLE {SchemaManager.SCHEMAS_TABLE}")) - - return cls(engine=engine) + return Store(backend=make_backend(backend_name, database=str(path))) @classmethod def load_from_file( cls, file_path: Path | str, - engine_name: str = "duckdb", - **connect_kwargs: Any, + backend_name: str = "duckdb", ) -> "Store": """Load an existing store from a database.""" path = to_path(file_path) if not path.exists(): msg = str(path) raise FileNotFoundError(msg) - return Store(engine=create_engine(f"{engine_name}:///{path}", **connect_kwargs)) + return Store(backend=make_backend(backend_name, database=str(path))) def dispose(self) -> None: - """Call self.engine.dispose() in order to dispose of the current connections.""" - self._engine.dispose() + """Dispose of the current connections.""" + self._backend.dispose() - def get_table(self, name: str) -> Table: - """Return the sqlalchemy Table object.""" + def get_table(self, name: str) -> ir.Table: + """Return the ibis Table expression.""" if not self.has_table(name): msg = f"{name=}" raise TableNotStored(msg) - - return Table(name, self._metadata) + return self._backend.table(name) def has_table(self, name: str) -> bool: """Return True if the database has a table with the given name.""" - return name in self._metadata.tables + return self._backend.has_table(name) def list_tables(self) -> list[str]: """Return a list of user tables in the database.""" - return [x for x in self._metadata.tables if x != SchemaManager.SCHEMAS_TABLE] + return [x for x in self._backend.list_tables() if x != SchemaManager.SCHEMAS_TABLE] - def try_get_table(self, name: str) -> Table | None: - """Return the sqlalchemy Table object or None if it is not stored.""" + def try_get_table(self, name: str) -> ir.Table | None: + """Return the ibis Table expression or None if it is not stored.""" if not self.has_table(name): return None - return Table(name, self._metadata) - - def update_metadata(self) -> None: - """Update the sqlalchemy metadata for table schema. Call this method if you add tables - in the sqlalchemy engine outside of this class or perform a rollback - in the same transaction in which chronify added tables. - """ - # Create a new object because sqlalchemy does not detect dropped tables in reflect. - metadata = MetaData() - metadata.reflect(self._engine, views=True) - logger.trace( - "Updated metadata, added: {}, dropped: {}", - sorted(set(metadata.tables).difference(self._metadata.tables)), - sorted(set(self._metadata.tables).difference(metadata.tables)), - ) - self._metadata = metadata - self._schema_mgr.rebuild_cache() + return self._backend.table(name) def backup(self, dst: Path | str, overwrite: bool = False) -> None: """Copy the database to a new location. Not yet supported for in-memory databases.""" - self._engine.dispose() + self._backend.dispose() path = to_path(dst) check_overwrite(path, overwrite) - match self._engine.name: - case "duckdb" | "sqlite": - if self._engine.url.database is None or self._engine.url.database == ":memory:": - msg = "backup is only supported with a database backed by a file" - raise InvalidOperation(msg) - src_file = Path(self._engine.url.database) - shutil.copyfile(src_file, path) - logger.info("Copied database to {}", path) - case _: - msg = self._engine.name - raise NotImplementedError(msg) - - @property - def engine(self) -> Engine: - """Return the sqlalchemy engine.""" - return self._engine + if self._backend.database is None: + msg = "backup is only supported with a database backed by a file" + raise InvalidOperation(msg) + src_file = Path(self._backend.database) + shutil.copyfile(src_file, path) + logger.info("Copied database to {}", path) + self._backend.reconnect() @property - def metadata(self) -> MetaData: - """Return the sqlalchemy metadata.""" - return self._metadata + def backend(self) -> IbisBackend: + """Return the ibis backend.""" + return self._backend @property def schema_manager(self) -> SchemaManager: """Return the store's schema manager.""" return self._schema_mgr - def check_timestamps(self, name: str, connection: Connection | None = None) -> None: + def check_timestamps(self, name: str) -> None: """Check the timestamps in the table. - This is useful if you call a :meth:`ingest_table` many times with skip_time_checks=True - and then want to check the final table. - Parameters ---------- name @@ -283,191 +170,53 @@ def check_timestamps(self, name: str, connection: Connection | None = None) -> N InvalidTable Raised if the timestamps do not match the schema. """ - table = self.get_table(name) schema = self._schema_mgr.get_schema(name) - if connection is None: - with self._engine.connect() as conn: - check_timestamps(conn, table, schema) - else: - check_timestamps(connection, table, schema) + check_timestamps(self._backend, name, schema) def create_view_from_parquet( self, path: Path, schema: TableSchema, bypass_checks: bool = False ) -> None: - """Load a table into the database.""" + """Load a table into the database from a Parquet file.""" self._create_view_from_parquet(path, schema) try: - with self._engine.connect() as conn: - table = self.get_table(schema.name) - if not bypass_checks: - check_timestamps(conn, table, schema) + if not bypass_checks: + check_timestamps(self._backend, schema.name, schema) except InvalidTable: - # This doesn't use conn.rollback because we can't update the sqlalchemy metadata - # for this view inside the connection. self.drop_view(schema.name) raise def _create_view_from_parquet(self, path: Path | str, schema: TableSchema) -> None: - """Create a view in the database from a Parquet file. - - Parameters - ---------- - schema - Defines the schema of the view to create in the database. Must match the input data. - path - Path to Parquet file. - - Raises - ------ - InvalidTable - Raised if the schema does not match the input data. - - Examples - -------- - >>> store = Store() - >>> store.create_view_from_parquet( - ... TableSchema( - ... name="devices", - ... value_column="value", - ... time_config=DatetimeRange( - ... time_column="timestamp", - ... start=datetime(2020, 1, 1, 0), - ... length=8784, - ... resolution=timedelta(hours=1), - ... ), - ... time_array_id_columns=["id"], - ... ), - ... "table.parquet", - ... ) - """ - with self._engine.begin() as conn: - create_view_from_parquet(conn, schema.name, to_path(path)) - self._schema_mgr.add_schema(conn, schema) - - self.update_metadata() + """Create a view in the database from a Parquet file.""" + create_view_from_parquet(self._backend, to_path(path), schema.name) + self._schema_mgr.add_schema(schema) def ingest_from_csv( self, path: Path | str, src_schema: CsvTableSchema, dst_schema: TableSchema, - connection: Optional[Connection] = None, ) -> bool: - """Ingest data from a CSV file. - - Parameters - ---------- - path - Source data file - src_schema - Defines the schema of the source file. - dst_schema - Defines the destination table in the database. - connection - Optional connection to reuse. Refer to :meth:`ingest_table` for notes. - - Returns - ------- - bool - Return True if a table was created. - - Raises - ------ - InvalidTable - Raised if the data does not match the schema. - - Examples - -------- - >>> resolution = timedelta(hours=1) - >>> time_config = DatetimeRange( - ... time_column="timestamp", - ... start=datetime(2020, 1, 1, 0), - ... length=8784, - ... resolution=timedelta(hours=1), - ... ) - >>> store = Store() - >>> store.ingest_from_csv( - ... "data.csv", - ... CsvTableSchema( - ... time_config=time_config, - ... pivoted_dimension_name="device", - ... value_columns=["device1", "device2", "device3"], - ... ), - ... TableSchema( - ... name="devices", - ... value_column="value", - ... time_config=time_config, - ... time_array_id_columns=["device"], - ... ), - ... ) - - See Also - -------- - ingest_from_csvs - """ - return self.ingest_from_csvs((path,), src_schema, dst_schema, connection=connection) + """Ingest data from a CSV file.""" + return self.ingest_from_csvs((path,), src_schema, dst_schema) def ingest_from_csvs( self, paths: Iterable[Path | str], src_schema: CsvTableSchema, dst_schema: TableSchema, - connection: Optional[Connection] = None, ) -> bool: - """Ingest data into the table specifed by schema. If the table does not exist, - create it. This is faster than calling :meth:`ingest_from_csv` many times. - Each file is loaded into memory one at a time. - If any error occurs, all added data will be removed and the state of the database will - be the same as the original state. - - Parameters - ---------- - path - Source data files - src_schema - Defines the schema of the source files. - dst_schema - Defines the destination table in the database. - conn - Optional connection to reuse. Refer to :meth:`ingest_table` for notes. - - Returns - ------- - bool - Return True if a table was created. - - Raises - ------ - InvalidTable - Raised if the data does not match the schema. - - See Also - -------- - ingest_from_csv - """ + """Ingest data from multiple CSV files into the table specified by schema.""" try: - if connection is None: - with self._engine.begin() as conn: - created_table = self._ingest_from_csvs(conn, paths, src_schema, dst_schema) - else: - created_table = self._ingest_from_csvs(connection, paths, src_schema, dst_schema) + created_table = self._ingest_from_csvs(paths, src_schema, dst_schema) except Exception: - # TODO: - # 1. The implicit rollback does not remove tables from our sqlalchemy metadata object. - # This means that the metadata object could be out-of-date if the user - # is self-managing the connection. - # 2. Python sqlite3 does not appear to support rollbacks with DDL statements. - # See discussion at https://bugs.python.org/issue10740. - self._handle_sqlite_error_case(dst_schema.name, connection) - if dst_schema.name in self._metadata.tables: - self._metadata.remove(Table(dst_schema.name, self._metadata)) + if self._backend.has_table(dst_schema.name): + self._backend.drop_table(dst_schema.name) + self._schema_mgr.remove_schema(dst_schema.name) raise - return created_table def _ingest_from_csvs( self, - conn: Connection, paths: Iterable[Path | str], src_schema: CsvTableSchema, dst_schema: TableSchema, @@ -477,15 +226,13 @@ def _ingest_from_csvs( return created_table for path in paths: - if self._ingest_from_csv(conn, path, src_schema, dst_schema): + if self._ingest_from_csv(path, src_schema, dst_schema): created_table = True - table = Table(dst_schema.name, self._metadata) - check_timestamps(conn, table, dst_schema) + check_timestamps(self._backend, dst_schema.name, dst_schema) return created_table def _ingest_from_csv( self, - conn: Connection, path: Path | str, src_schema: CsvTableSchema, dst_schema: TableSchema, @@ -497,171 +244,60 @@ def _ingest_from_csv( if isinstance(src_schema.time_config, IndexTimeRangeBase): if isinstance(dst_schema.time_config, DatetimeRange): raise NotImplementedError - # timestamps = IndexTimeRangeGenerator(src_schema.time_config).list_timestamps() - # rel = ddbf.add_datetime_column( - # rel=rel, - # start=dst_schema.time_config.start, - # resolution=dst_schema.time_config.resolution, - # length=dst_schema.time_config.length, - # time_array_id_columns=src_schema.time_array_id_columns, - # time_column=dst_schema.time_config.time_column, - # timestamps=timestamps, - # ) else: cls_name = dst_schema.time_config.__class__.__name__ msg = f"{src_schema.time_config.__class__.__name__} cannot be converted to {cls_name}" raise NotImplementedError(msg) if src_schema.pivoted_dimension_name is not None: - return self._ingest_pivoted_table(conn, rel, src_schema, dst_schema) + return self._ingest_pivoted_table(rel, src_schema, dst_schema) - return self._ingest_table(conn, rel, dst_schema) + return self._ingest_table(rel, dst_schema) def ingest_pivoted_table( self, data: pd.DataFrame | DuckDBPyRelation, src_schema: PivotedTableSchema | CsvTableSchema, dst_schema: TableSchema, - connection: Optional[Connection] = None, ) -> bool: - """Ingest pivoted data into the table specifed by schema. If the table does not exist, - create it. Chronify will unpivot the data before ingesting it. - - Parameters - ---------- - data - Input data to ingest into the database. - src_schema - Defines the schema of the input data. - dst_schema - Defines the destination table in the database. - conn - Optional connection to reuse. Refer to :meth:`ingest_table` for notes. - - Returns - ------- - bool - Return True if a table was created. - - Raises - ------ - InvalidTable - Raised if the data does not match the schema. - - Examples - -------- - >>> resolution = timedelta(hours=1) - >>> df = pd.DataFrame( - ... { - ... "timestamp": pd.date_range( - ... "2020-01-01", "2020-12-31 23:00:00", freq=resolution - ... ), - ... "device1": np.random.random(8784), - ... "device2": np.random.random(8784), - ... "device3": np.random.random(8784), - ... } - ... ) - >>> time_config = DatetimeRange( - ... time_column="timestamp", - ... start=datetime(2020, 1, 1, 0), - ... length=8784, - ... resolution=timedelta(hours=1), - ... ) - >>> store = Store() - >>> store.ingest_pivoted_table( - ... df, - ... PivotedTableSchema( - ... time_config=time_config, - ... pivoted_dimension_name="device", - ... value_columns=["device1", "device2", "device3"], - ... ), - ... TableSchema( - ... name="devices", - ... value_column="value", - ... time_config=time_config, - ... time_array_id_columns=["device"], - ... ), - ... ) - - See Also - -------- - ingest_pivoted_tables - """ - return self.ingest_pivoted_tables((data,), src_schema, dst_schema, connection=connection) + """Ingest pivoted data into the table specified by schema.""" + return self.ingest_pivoted_tables((data,), src_schema, dst_schema) def ingest_pivoted_tables( self, data: Iterable[pd.DataFrame | DuckDBPyRelation], src_schema: PivotedTableSchema | CsvTableSchema, dst_schema: TableSchema, - connection: Optional[Connection] = None, ) -> bool: - """Ingest pivoted data into the table specifed by schema. - - If the table does not exist, create it. Unpivot the data before ingesting it. - This is faster than calling :meth:`ingest_pivoted_table` many times. - If any error occurs, all added data will be removed and the state of the database will - be the same as the original state. - - Parameters - ---------- - data - Data to ingest into the database. - src_schema - Defines the schema of all input tables. - dst_schema - Defines the destination table in the database. - conn - Optional connection to reuse. Refer to :meth:`ingest_table` for notes. - - Returns - ------- - bool - Return True if a table was created. - - See Also - -------- - ingest_pivoted_table - """ + """Ingest pivoted data from multiple tables. Unpivot before ingesting.""" try: - if connection is None: - with self._engine.begin() as conn: - created_table = self._ingest_pivoted_tables(conn, data, src_schema, dst_schema) - else: - created_table = self._ingest_pivoted_tables( - connection, data, src_schema, dst_schema - ) + created_table = self._ingest_pivoted_tables(data, src_schema, dst_schema) except Exception: - self._handle_sqlite_error_case(dst_schema.name, connection) - if dst_schema.name in self._metadata.tables: - self._metadata.remove(Table(dst_schema.name, self._metadata)) + if self._backend.has_table(dst_schema.name): + self._backend.drop_table(dst_schema.name) raise - return created_table def _ingest_pivoted_tables( self, - conn: Connection, data: Iterable[pd.DataFrame | DuckDBPyRelation], src_schema: PivotedTableSchema | CsvTableSchema, dst_schema: TableSchema, ) -> bool: created_table = False for table in data: - if self._ingest_pivoted_table(conn, table, src_schema, dst_schema): + if self._ingest_pivoted_table(table, src_schema, dst_schema): created_table = True - check_timestamps(conn, Table(dst_schema.name, self._metadata), dst_schema) + check_timestamps(self._backend, dst_schema.name, dst_schema) return created_table def _ingest_pivoted_table( self, - conn: Connection, data: pd.DataFrame | DuckDBPyRelation, src_schema: PivotedTableSchema | CsvTableSchema, dst_schema: TableSchema, ) -> bool: if isinstance(data, pd.DataFrame): - # This is a shortcut for registering a temporary view. tmp_df = data # noqa: F841 rel = duckdb.sql("SELECT * from tmp_df") else: @@ -674,16 +310,15 @@ def _ingest_pivoted_table( src_schema.pivoted_dimension_name, dst_schema.value_column, ) - return self._ingest_table(conn, rel2, dst_schema) + return self._ingest_table(rel2, dst_schema) def ingest_table( self, data: pd.DataFrame | DuckDBPyRelation, schema: TableSchema, - connection: Optional[Connection] = None, **kwargs: Any, ) -> bool: - """Ingest data into the table specifed by schema. If the table does not exist, + """Ingest data into the table specified by schema. If the table does not exist, create it. Parameters @@ -692,13 +327,6 @@ def ingest_table( Input data to ingest into the database. schema Defines the destination table in the database. - connection - Optional connection to reuse. If adding many tables at once, it is significantly - faster to use one connection. Refer to :meth:`ingest_tables` for built-in support. - If connection is not set, chronify will commit the database changes - or perform a rollback on error. If it is set, the caller must perform those actions. - If you peform a rollback, you must call :meth:`rebuild_schema_cache` because the - Store will cache all table names in memory. Returns ------- @@ -709,141 +337,69 @@ def ingest_table( ------ InvalidTable Raised if the data does not match the schema. - - Examples - -------- - >>> store = Store() - >>> resolution = timedelta(hours=1) - >>> df = pd.DataFrame( - ... { - ... "timestamp": pd.date_range( - ... "2020-01-01", "2020-12-31 23:00:00", freq=resolution - ... ), - ... "value": np.random.random(8784), - ... } - ... ) - >>> df["id"] = 1 - >>> store.ingest_table( - ... df, - ... TableSchema( - ... name="devices", - ... value_column="value", - ... time_config=DatetimeRange( - ... time_column="timestamp", - ... start=datetime(2020, 1, 1, 0), - ... length=8784, - ... resolution=timedelta(hours=1), - ... ), - ... time_array_id_columns=["id"], - ... ), - ... ) - - See Also - -------- - ingest_tables """ - return self.ingest_tables((data,), schema, connection=connection, **kwargs) + return self.ingest_tables((data,), schema, **kwargs) def ingest_tables( self, data: Iterable[pd.DataFrame | DuckDBPyRelation], schema: TableSchema, - connection: Optional[Connection] = None, **kwargs: Any, ) -> bool: - """Ingest multiple input tables to the same database table. - All tables must have the same schema. - This offers significant performance advantages over calling :meth:`ingest_table` many - times. - - Parameters - ---------- - data - Input tables to ingest into one database table. - schema - Defines the destination table. - conn - Optional connection to reuse. Refer to :meth:`ingest_table` for notes. - - Returns - ------- - bool - Return True if a table was created. - - Raises - ------ - InvalidTable - Raised if the data does not match the schema. - - See Also - -------- - ingest_table - """ + """Ingest multiple input tables to the same database table.""" created_table = False if not data: return created_table try: - if connection is None: - with self._engine.begin() as conn: - created_table = self._ingest_tables(conn, data, schema, **kwargs) - else: - created_table = self._ingest_tables(connection, data, schema, **kwargs) + created_table = self._ingest_tables(data, schema, **kwargs) except Exception: - self._handle_sqlite_error_case(schema.name, connection) - if schema.name in self._metadata.tables: - self._metadata.remove(Table(schema.name, self._metadata)) + if self._backend.has_table(schema.name): + self._backend.drop_table(schema.name) raise - return created_table def _ingest_tables( self, - conn: Connection, data: Iterable[pd.DataFrame | DuckDBPyRelation], schema: TableSchema, skip_time_checks: bool = False, ) -> bool: created_table = False for table in data: - if self._ingest_table(conn, table, schema): + if self._ingest_table(table, schema): created_table = True if not skip_time_checks: - check_timestamps(conn, Table(schema.name, self._metadata), schema) + check_timestamps(self._backend, schema.name, schema) return created_table def _ingest_table( self, - conn: Connection, data: pd.DataFrame | DuckDBPyRelation, schema: TableSchema, ) -> bool: - if self._engine.name == "hive": - msg = "Data ingestion through Hive is not supported" - raise NotImplementedError(msg) df = data.to_df() if isinstance(data, DuckDBPyRelation) else data check_columns(df.columns, schema.list_columns()) - table = self.try_get_table(schema.name) - if table is None: - duckdb_types = get_duckdb_types_from_pandas(df) - dtypes = [get_sqlalchemy_type_from_duckdb(x) for x in duckdb_types] - table = Table( + if not self._backend.has_table(schema.name): + write_table( + self._backend, + df, schema.name, - self._metadata, - *[Column(x, y) for x, y in zip(df.columns, dtypes)], + [schema.time_config], + if_exists="fail", ) - self._metadata.create_all(conn) - created_table = True + self._schema_mgr.add_schema(schema) + return True else: - created_table = False - - write_database(df, conn, schema.name, [schema.time_config]) - - if created_table: - self._schema_mgr.add_schema(conn, schema) - - return created_table + write_table( + self._backend, + df, + schema.name, + [schema.time_config], + if_exists="append", + ) + return False def map_table_time_config( self, @@ -851,167 +407,37 @@ def map_table_time_config( dst_schema: TableSchema, data_adjustment: Optional[TimeBasedDataAdjustment] = None, wrap_time_allowed: bool = False, - scratch_dir: Optional[Path] = None, output_file: Optional[Path] = None, check_mapped_timestamps: bool = False, ) -> None: """Map the existing table represented by src_name to a new table represented by - dst_schema with a different time configuration. - - Parameters - ---------- - src_name - Refers to the table name of the source data. - dst_schema - Defines the table to create in the database. Must not already exist. - data_adjustment - Defines how the dataframe may need to be adjusted with respect to time. - Data is only adjusted when the conditions apply. - wrap_time_allowed - Defines whether the time column is allowed to be wrapped according to the time - config in dst_schema when it does not line up with the time config - scratch_dir - Directory to use for temporary writes. Default to the system's tmp filesystem. - output_file - If set, write the mapped table to this Parquet file. - check_mapped_timestamps - Perform time checks on the result of the mapping operation. This can be slow and - is not required. - - Raises - ------ - InvalidTable - Raised if the schemas are incompatible. - TableAlreadyExists - Raised if the dst_schema name already exists. - - Examples - -------- - >>> store = Store() - >>> hours_per_year = 12 * 7 * 24 - >>> num_time_arrays = 3 - >>> df = pd.DataFrame( - ... { - ... "id": np.concatenate( - ... [np.repeat(i, hours_per_year) for i in range(1, 1 + num_time_arrays)] - ... ), - ... "month": np.tile(np.repeat(range(1, 13), 7 * 24), num_time_arrays), - ... "day_of_week": np.tile(np.tile(np.repeat(range(7), 24), 12), num_time_arrays), - ... "hour": np.tile(np.tile(range(24), 12 * 7), num_time_arrays), - ... "value": np.random.random(hours_per_year * num_time_arrays), - ... } - ... ) - >>> schema = TableSchema( - ... name="devices_by_representative_time", - ... value_column="value", - ... time_config=RepresentativePeriodTimeNTZ( - ... time_format=RepresentativePeriodFormat.ONE_WEEK_PER_MONTH_BY_HOUR, - ... ), - ... time_array_id_columns=["id"], - ... ) - >>> store.ingest_table(df, schema) - >>> store.map_table_time_config( - ... "devices_by_representative_time", - ... TableSchema( - ... name="devices_by_datetime", - ... value_column="value", - ... time_config=DatetimeRange( - ... time_column="timestamp", - ... start=datetime(2020, 1, 1, 0), - ... length=8784, - ... resolution=timedelta(hours=1), - ... ), - ... time_array_id_columns=["id"], - ... ), - ... ) - """ + dst_schema with a different time configuration.""" if self.has_table(dst_schema.name): msg = dst_schema.name raise TableAlreadyExists(msg) src_schema = self._schema_mgr.get_schema(src_name) map_time( - self._engine, - self._metadata, + self._backend, src_schema, dst_schema, data_adjustment=data_adjustment, wrap_time_allowed=wrap_time_allowed, - scratch_dir=scratch_dir, output_file=output_file, check_mapped_timestamps=check_mapped_timestamps, ) - with self._engine.begin() as conn: - self._schema_mgr.add_schema(conn, dst_schema) + self._schema_mgr.add_schema(dst_schema) def convert_time_zone( self, src_name: str, time_zone: tzinfo | None, - scratch_dir: Optional[Path] = None, output_file: Optional[Path] = None, check_mapped_timestamps: bool = False, ) -> TableSchema: - """ - Convert the time zone of the existing table represented by src_name to a new time zone - - Parameters - ---------- - src_name - Refers to the table name of the source data. - time_zone - Time zone to convert to. - scratch_dir - Directory to use for temporary writes. Default to the system's tmp filesystem. - output_file - If set, write the mapped table to this Parquet file. - check_mapped_timestamps - Perform time checks on the result of the mapping operation. This can be slow and - is not required. - - Raises - ------ - TableAlreadyExists - Raised if the dst_schema name already exists. - - Examples - -------- - >>> store = Store() - >>> start = datetime(year=2018, month=1, day=1, tzinfo=ZoneInfo("Etc/GMT+5")) - >>> freq = timedelta(hours=1) - >>> hours_per_year = 8760 - >>> num_time_arrays = 1 - >>> df = pd.DataFrame( - ... { - ... "id": np.concatenate( - ... [np.repeat(i, hours_per_year) for i in range(1, 1 + num_time_arrays)] - ... ), - ... "timestamp": np.tile( - ... pd.date_range(start, periods=hours_per_year, freq="h"), num_time_arrays - ... ), - ... "value": np.random.random(hours_per_year * num_time_arrays), - ... } - ... ) - >>> schema = TableSchema( - ... name="some_data", - ... time_config=DatetimeRange( - ... time_column="timestamp", - ... start=start, - ... length=hours_per_year, - ... resolution=freq, - ... ), - ... time_array_id_columns=["id"], - ... value_column="value", - ... ) - >>> store.ingest_table(df, schema) - >>> to_time_zone = ZoneInfo("US/Mountain") - >>> dst_schema = store.convert_time_zone( - ... schema.name, to_time_zone, check_mapped_timestamps=True - ... ) - """ - + """Convert the time zone of the existing table represented by src_name to a new time zone.""" src_schema = self._schema_mgr.get_schema(src_name) - tzc = TimeZoneConverter(self._engine, self._metadata, src_schema, time_zone) + tzc = TimeZoneConverter(self._backend, src_schema, time_zone) dst_schema = tzc.generate_to_schema() if self.has_table(dst_schema.name): @@ -1019,14 +445,10 @@ def convert_time_zone( raise TableAlreadyExists(msg) tzc.convert_time_zone( - scratch_dir=scratch_dir, output_file=output_file, check_mapped_timestamps=check_mapped_timestamps, ) - - with self._engine.begin() as conn: - self._schema_mgr.add_schema(conn, dst_schema) - + self._schema_mgr.add_schema(dst_schema) return dst_schema def convert_time_zone_by_column( @@ -1034,78 +456,13 @@ def convert_time_zone_by_column( src_name: str, time_zone_column: str, wrap_time_allowed: bool = False, - scratch_dir: Optional[Path] = None, output_file: Optional[Path] = None, check_mapped_timestamps: bool = False, ) -> TableSchema: - """ - Convert the time zone of the existing table represented by src_name to new time zone(s) defined by a column - - Parameters - ---------- - src_name - Refers to the table name of the source data. - time_zone_column - Name of the time zone column for conversion. - wrap_time_allowed - Defines whether the time column is allowed to be wrapped to reflect the same time - range as the src_name schema in tz-naive clock time - scratch_dir - Directory to use for temporary writes. Default to the system's tmp filesystem. - output_file - If set, write the mapped table to this Parquet file. - check_mapped_timestamps - Perform time checks on the result of the mapping operation. This can be slow and - is not required. - - Raises - ------ - TableAlreadyExists - Raised if the dst_schema name already exists. - - Examples - -------- - >>> store = Store() - >>> start = datetime(year=2018, month=1, day=1, tzinfo=ZoneInfo("Etc/GMT+5")) - >>> freq = timedelta(hours=1) - >>> hours_per_year = 8760 - >>> num_time_arrays = 3 - >>> df = pd.DataFrame( - ... { - ... "id": np.concatenate( - ... [np.repeat(i, hours_per_year) for i in range(1, 1 + num_time_arrays)] - ... ), - ... "timestamp": np.tile( - ... pd.date_range(start, periods=hours_per_year, freq="h"), num_time_arrays - ... ), - ... "time_zone": np.repeat(["US/Eastern", "US/Mountain", "None"], hours_per_year), - ... "value": np.random.random(hours_per_year * num_time_arrays), - ... } - ... ) - >>> schema = TableSchema( - ... name="some_data", - ... time_config=DatetimeRange( - ... time_column="timestamp", - ... start=start, - ... length=hours_per_year, - ... resolution=freq, - ... ), - ... time_array_id_columns=["id"], - ... value_column="value", - ... ) - >>> store.ingest_table(df, schema) - >>> time_zone_column = "time_zone" - >>> dst_schema = store.convert_time_zone_by_column( - ... schema.name, - ... time_zone_column, - ... wrap_time_allowed=False, - ... check_mapped_timestamps=True, - ... ) - """ - + """Convert the time zone of the existing table to time zone(s) defined by a column.""" src_schema = self._schema_mgr.get_schema(src_name) tzc = TimeZoneConverterByColumn( - self._engine, self._metadata, src_schema, time_zone_column, wrap_time_allowed + self._backend, src_schema, time_zone_column, wrap_time_allowed ) dst_schema = tzc.generate_to_schema() @@ -1114,89 +471,22 @@ def convert_time_zone_by_column( raise TableAlreadyExists(msg) tzc.convert_time_zone( - scratch_dir=scratch_dir, output_file=output_file, check_mapped_timestamps=check_mapped_timestamps, ) - - with self._engine.begin() as conn: - self._schema_mgr.add_schema(conn, dst_schema) - + self._schema_mgr.add_schema(dst_schema) return dst_schema def localize_time_zone( self, src_name: str, time_zone: tzinfo | None, - scratch_dir: Optional[Path] = None, output_file: Optional[Path] = None, check_mapped_timestamps: bool = False, ) -> TableSchema: - """ - Localize the time zone of the existing table represented by src_name to a specified time zone - - Parameters - ---------- - src_name - Refers to the table name of the source data. - time_zone - Standard time zone to localize to. If None, keep as tz-naive. - scratch_dir - Directory to use for temporary writes. Default to the system's tmp filesystem. - output_file - If set, write the mapped table to this Parquet file. - check_mapped_timestamps - Perform time checks on the result of the mapping operation. This can be slow and - is not required. - - Raises - ------ - TableAlreadyExists - Raised if the dst_schema name already exists. - - Returns - ------- - TableSchema - The schema of the newly created table. - - Examples - -------- - >>> store = Store() - >>> start = datetime(year=2018, month=1, day=1) # tz-naive - >>> freq = timedelta(hours=1) - >>> hours_per_year = 8760 - >>> num_time_arrays = 1 - >>> df = pd.DataFrame( - ... { - ... "id": np.concatenate( - ... [np.repeat(i, hours_per_year) for i in range(1, 1 + num_time_arrays)] - ... ), - ... "timestamp": np.tile( - ... pd.date_range(start, periods=hours_per_year, freq="h"), num_time_arrays - ... ), - ... "value": np.random.random(hours_per_year * num_time_arrays), - ... } - ... ) - >>> schema = TableSchema( - ... name="some_data", - ... time_config=DatetimeRange( - ... time_column="timestamp", - ... start=start, - ... length=hours_per_year, - ... resolution=freq, - ... ), - ... time_array_id_columns=["id"], - ... value_column="value", - ... ) - >>> store.ingest_table(df, schema) - >>> to_time_zone = ZoneInfo("Etc/GMT+5") - >>> dst_schema = store.localize_time_zone( - ... schema.name, to_time_zone, check_mapped_timestamps=True - ... ) - """ - + """Localize the time zone of the existing table to a specified time zone.""" src_schema = self._schema_mgr.get_schema(src_name) - tzl = TimeZoneLocalizer(self._engine, self._metadata, src_schema, time_zone) + tzl = TimeZoneLocalizer(self._backend, src_schema, time_zone) dst_schema = tzl.generate_to_schema() if self.has_table(dst_schema.name): @@ -1204,94 +494,22 @@ def localize_time_zone( raise TableAlreadyExists(msg) tzl.localize_time_zone( - scratch_dir=scratch_dir, output_file=output_file, check_mapped_timestamps=check_mapped_timestamps, ) - - with self._engine.begin() as conn: - self._schema_mgr.add_schema(conn, dst_schema) - + self._schema_mgr.add_schema(dst_schema) return dst_schema def localize_time_zone_by_column( self, src_name: str, time_zone_column: Optional[str] = None, - scratch_dir: Optional[Path] = None, output_file: Optional[Path] = None, check_mapped_timestamps: bool = False, ) -> TableSchema: - """ - Localize the time zone of the existing table represented by src_name to time zones defined by a column - - Parameters - ---------- - src_name - Refers to the table name of the source data. - time_zone_column - Name of the time zone column for localization, default to None - scratch_dir - Directory to use for temporary writes. Default to the system's tmp filesystem. - output_file - If set, write the mapped table to this Parquet file. - check_mapped_timestamps - Perform time checks on the result of the mapping operation. This can be slow and - is not required. - - Raises - ------ - TableAlreadyExists - Raised if the dst_schema name already exists. - - Returns - ------- - TableSchema - The schema of the newly created table. - - Examples - -------- - >>> store = Store() - >>> start = datetime(year=2018, month=1, day=1) # tz-naive - >>> freq = timedelta(hours=1) - >>> hours_per_year = 8760 - >>> num_time_arrays = 3 - >>> df = pd.DataFrame( - ... { - ... "id": np.concatenate( - ... [np.repeat(i, hours_per_year) for i in range(1, 1 + num_time_arrays)] - ... ), - ... "timestamp": np.tile( - ... pd.date_range(start, periods=hours_per_year, freq="h"), num_time_arrays - ... ), - ... "time_zone": np.repeat( - ... ["Etc/GMT+5", "Etc/GMT+6", "Etc/GMT+7"], hours_per_year - ... ), # EST, CST, MST - ... "value": np.random.random(hours_per_year * num_time_arrays), - ... } - ... ) - >>> schema = TableSchema( - ... name="some_data", - ... time_config=DatetimeRange( - ... time_column="timestamp", - ... start=start, - ... length=hours_per_year, - ... resolution=freq, - ... ), - ... time_array_id_columns=["id"], - ... value_column="value", - ... ) - >>> store.ingest_table(df, schema) - >>> time_zone_column = "time_zone" - >>> dst_schema = store.localize_time_zone_by_column( - ... schema.name, - ... time_zone_column, - ... check_mapped_timestamps=True, - ... ) - """ - + """Localize the time zone of the existing table to time zones defined by a column.""" src_schema = self._schema_mgr.get_schema(src_name) - tzl = TimeZoneLocalizerByColumn(self._engine, self._metadata, src_schema, time_zone_column) + tzl = TimeZoneLocalizerByColumn(self._backend, src_schema, time_zone_column) dst_schema = tzl.generate_to_schema() if self.has_table(dst_schema.name): @@ -1299,22 +517,16 @@ def localize_time_zone_by_column( raise TableAlreadyExists(msg) tzl.localize_time_zone( - scratch_dir=scratch_dir, output_file=output_file, check_mapped_timestamps=check_mapped_timestamps, ) - - with self._engine.begin() as conn: - self._schema_mgr.add_schema(conn, dst_schema) - + self._schema_mgr.add_schema(dst_schema) return dst_schema def read_query( self, name: str, - query: Selectable | str, - params: Any = None, - connection: Optional[Connection] = None, + query: ir.Table | str, ) -> pd.DataFrame: """Return the query result as a pandas DataFrame. @@ -1323,108 +535,42 @@ def read_query( name Table or view name query - SQL query expressed as a string or salqlchemy Selectable - params - Parameters for SQL query if expressed as a string - - Examples - -------- - >>> df = store.read_query("SELECT * FROM devices") - >>> df = store.read_query("SELECT * FROM devices WHERE id = ?", params=(3,)) - - >>> from sqlalchemy import select - >>> table = store.schemas.get_table("devices") - >>> df = store.read_query(select(table).where(table.c.id == 3) + SQL query as a string or ibis Table expression """ - schema = self._schema_mgr.get_schema(name, conn=connection) - if connection is None: - with self._engine.begin() as conn: - return read_database(query, conn, schema.time_config, params=params) + schema = self._schema_mgr.get_schema(name) + if isinstance(query, str): + expr = self._backend.sql(query) else: - return read_database(query, connection, schema.time_config, params=params) + expr = query + return read_query(self._backend, expr, schema.time_config) - def read_table(self, name: str, connection: Optional[Connection] = None) -> pd.DataFrame: + def read_table(self, name: str) -> pd.DataFrame: """Return the table as a pandas DataFrame.""" - table = self.get_table(name) - stmt = select(table) - return self.read_query(name, stmt, connection=connection) + schema = self._schema_mgr.get_schema(name) + return read_table(self._backend, name, schema.time_config) - def read_raw_query( - self, query: str, params: Any = None, connection: Optional[Connection] = None - ) -> pd.DataFrame: - """Execute a query directly on the backend database connection, bypassing sqlalchemy, and - return the results as a DataFrame. + def read_raw_query(self, query: str) -> pd.DataFrame: + """Execute a query directly on the backend and return the results as a DataFrame. Note: Unlike :meth:`read_query`, no conversion of timestamps is performed. - Timestamps will be in the format of the underlying database. SQLite backends will return - strings instead of datetime. - - Parameters - ---------- - query - SQL query to execute - params - Optional parameters for SQL query - conn - Optional sqlalchemy connection returned by `Store.engine.connect()`. This can - improve performance when performing many reads. If used for database modifications, - it is the caller's responsibility to perform a commit and ensure that the connection - is closed correctly. Use of sqlalchemy's context manager is recommended. - - Examples - -------- - >>> store = Store() - >>> query1 = "SELECT * from my_table WHERE column = ?" - >>> params1 = ("value1",) - >>> query2 = "SELECT * from my_table WHERE column = ?'" - >>> params2 = ("value2",) - - >>> df = store.read_raw_query(query1, params=params1) - - >>> with store.engine.connect() as conn: - ... df1 = store.read_raw_query(query1, params=params1, connection=conn) - ... df2 = store.read_raw_query(query2, params=params2, connection=conn) """ - if connection is None: - with self._engine.connect() as conn: - return self._read_raw_query(query, params, conn) - else: - return self._read_raw_query(query, params, connection) - - def _read_raw_query(self, query: str, params: Any, conn: Connection) -> pd.DataFrame: - assert conn._dbapi_connection is not None - assert conn._dbapi_connection.driver_connection is not None - match self._engine.name: - case "duckdb": - df = conn._dbapi_connection.driver_connection.sql(query, params=params).to_df() - assert isinstance(df, pd.DataFrame) - return df - case "sqlite": - return pd.read_sql(query, conn._dbapi_connection.driver_connection, params=params) - case _: - msg = self._engine.name - raise NotImplementedError(msg) + return self._backend.execute_sql_to_df(query) def write_query_to_parquet( self, - stmt: Selectable, + stmt: ir.Table | str, file_path: Path | str, overwrite: bool = False, partition_columns: Optional[list[str]] = None, ) -> None: """Write the result of a query to a Parquet file.""" - # We could add a separate path where the query is a string and skip the intermediate - # view if we passed parameters through the call stack. - view_name = make_temp_view_name() - create_view(view_name, stmt, self._engine, self._metadata) - try: - self.write_table_to_parquet( - view_name, file_path, overwrite=overwrite, partition_columns=partition_columns - ) - finally: - with self._engine.connect() as conn: - conn.execute(text(f"DROP VIEW {view_name}")) - self._metadata.remove(Table(view_name, self._metadata)) + write_parquet( + self._backend, + stmt, + to_path(file_path), + overwrite=overwrite, + partition_columns=partition_columns, + ) def write_table_to_parquet( self, @@ -1438,9 +584,10 @@ def write_table_to_parquet( msg = f"table {name=} is not stored" raise TableNotStored(msg) - write_query_to_parquet( - self._engine, - f"SELECT * FROM {name}", + expr = self._backend.table(name) + write_parquet( + self._backend, + expr, to_path(file_path), overwrite=overwrite, partition_columns=partition_columns, @@ -1451,7 +598,6 @@ def delete_rows( self, name: str, time_array_id_values: dict[str, Any], - connection: Optional[Connection] = None, ) -> int: """Delete all rows matching the time_array_id_values. @@ -1461,28 +607,20 @@ def delete_rows( Name of table time_array_id_values Values for the time_array_id_values. Keys must match the columns in the schema. - connnection - Optional connection to the database. Refer :meth:`ingest_table` for notes. Returns ------- int Number of deleted rows - - Examples - -------- - >>> store.delete_rows("devices", {"id": 47}) """ - # TODO: consider supporting a user-defined query. Would need to check consistency - # afterwards. - # The current approach doesn't need to check because only one single complete time - # array can be deleted on each call. - table = self.get_table(name) + if not self.has_table(name): + msg = f"{name=}" + raise TableNotStored(msg) if not time_array_id_values: msg = "time_array_id_values cannot be empty" raise InvalidParameter(msg) - schema = self._schema_mgr.get_schema(name, conn=connection) + schema = self._schema_mgr.get_schema(name) if sorted(time_array_id_values.keys()) != sorted(schema.time_array_id_columns): msg = ( "The keys of time_array_id_values must match the schema columns. " @@ -1491,27 +629,24 @@ def delete_rows( ) raise InvalidParameter(msg) - assert time_array_id_values - stmt = delete(table) - - # duckdb does not offer a way to retrieve the number of deleted rows, so we must - # compute it manually. - # Deletions are not common. We are trading accuracy for peformance. - count_stmt = select(func.count()).select_from(table) - + # Count rows before delete + where_clauses = [] for column, value in time_array_id_values.items(): - stmt = stmt.where(table.c[column] == value) - count_stmt = count_stmt.where(table.c[column] == value) + if isinstance(value, str): + where_clauses.append(f"{column} = '{value}'") + else: + where_clauses.append(f"{column} = {value}") + where_str = " AND ".join(where_clauses) - if connection is None: - with self._engine.begin() as conn: - num_deleted = self._run_delete(conn, stmt, count_stmt) - else: - num_deleted = self._run_delete(connection, stmt, count_stmt) - # Let the caller commit or rollback when ready. + count_df = self._backend.execute_sql_to_df( + f"SELECT COUNT(*) as cnt FROM {name} WHERE {where_str}" + ) + num_to_delete = int(count_df.iloc[0, 0]) - if num_deleted < 1: - msg = f"Failed to delete rows: {stmt=} {num_deleted=}" + self._backend.execute_sql(f"DELETE FROM {name} WHERE {where_str}") + + if num_to_delete < 1: + msg = f"Failed to delete rows: {where_str} {num_to_delete=}" raise InvalidParameter(msg) logger.info( @@ -1520,88 +655,36 @@ def delete_rows( time_array_id_values, ) - stmt2 = select(table).limit(1) - is_empty = False - if connection is None: - with self._engine.connect() as conn: - res = conn.execute(stmt2).fetchall() - else: - res = connection.execute(stmt2).fetchall() - - if not res: - is_empty = True - - if is_empty: + # Check if table is now empty + remaining = self._backend.execute_sql_to_df(f"SELECT COUNT(*) as cnt FROM {name}") + if int(remaining.iloc[0, 0]) == 0: logger.info("Delete empty table {}", name) - self.drop_table(name, connection=connection) - - return num_deleted - - def _run_delete(self, conn: Connection, stmt: Any, count_stmt: Any) -> int: - count1: int | None = None - if self._engine.name == "duckdb": - res1 = conn.execute(count_stmt).fetchone() - assert res1 is not None - count1 = res1[0] - res = conn.execute(stmt) - if self._engine.name == "duckdb": - res2 = conn.execute(count_stmt).fetchone() - assert res2 is not None - count2 = res2[0] - assert count1 is not None - num_deleted = count1 - count2 - else: - num_deleted = res.rowcount - return num_deleted # type: ignore + self.drop_table(name) - def drop_table( - self, - name: str, - connection: Optional[Connection] = None, - if_exists: bool = False, - ) -> None: + return num_to_delete + + def drop_table(self, name: str, if_exists: bool = False) -> None: """Drop a table from the database.""" - self._drop_table_or_view(name, "TABLE", connection, if_exists) + if not if_exists and not self.has_table(name): + msg = f"{name=}" + raise TableNotStored(msg) + self._backend.drop_table(name) + self._schema_mgr.remove_schema(name) + logger.info("Dropped table {}", name) - def create_view(self, schema: TableSchema, stmt: Selectable) -> None: + def create_view(self, schema: TableSchema, stmt: ir.Table) -> None: """Create a view in the database.""" - create_view(schema.name, stmt, self._engine, self._metadata) - with self._engine.begin() as conn: - self._schema_mgr.add_schema(conn, schema) + self._backend.create_view(schema.name, stmt) + self._schema_mgr.add_schema(schema) - def drop_view( - self, - name: str, - connection: Optional[Connection] = None, - if_exists: bool = False, - ) -> None: + def drop_view(self, name: str, if_exists: bool = False) -> None: """Drop a view from the database.""" - self._drop_table_or_view(name, "VIEW", connection, if_exists) - - def _drop_table_or_view( - self, - name: str, - table_type: str, - connection: Optional[Connection], - if_exists: bool, - ) -> None: - table = self.get_table(name) - if_exists_str = " IF EXISTS" if if_exists else "" - if connection is None: - with self._engine.begin() as conn: - conn.execute(text(f"DROP {table_type} {if_exists_str} {name}")) - self._schema_mgr.remove_schema(conn, name) - else: - connection.execute(text(f"DROP {table_type} {if_exists_str} {name}")) - self._schema_mgr.remove_schema(connection, name) - - self._metadata.remove(table) - logger.info("Dropped {} {}", table_type.lower(), name) - - def _handle_sqlite_error_case(self, name: str, connection: Optional[Connection]) -> None: - if connection is None and self._engine.name == "sqlite": - with self._engine.begin() as conn: - conn.execute(text(f"DROP TABLE IF EXISTS {name}")) + if not if_exists and not self.has_table(name): + msg = f"{name=}" + raise TableNotStored(msg) + self._backend.drop_view(name) + self._schema_mgr.remove_schema(name) + logger.info("Dropped view {}", name) def check_columns( diff --git a/src/chronify/time_series_checker.py b/src/chronify/time_series_checker.py index 53c36f1..e00f233 100644 --- a/src/chronify/time_series_checker.py +++ b/src/chronify/time_series_checker.py @@ -1,13 +1,13 @@ -from sqlalchemy import Connection, Table, select, text from typing import Optional from datetime import datetime, tzinfo import pandas as pd from chronify.exceptions import InvalidTable +from chronify.ibis.base import IbisBackend +from chronify.ibis.functions import read_query from chronify.models import TableSchema from chronify.time_configs import DatetimeRangeWithTZColumn -from chronify.sqlalchemy.functions import read_database from chronify.time_range_generator_factory import make_time_range_generator from chronify.datetime_range_generator import DatetimeRangeGeneratorExternalTimeZone from chronify.time import LeapDayAdjustmentType @@ -15,14 +15,14 @@ def check_timestamps( - conn: Connection, - table: Table, + backend: IbisBackend, + table_name: str, schema: TableSchema, leap_day_adjustment: Optional[LeapDayAdjustmentType] = None, ) -> None: """Performs checks on time series arrays in a table.""" TimeSeriesChecker( - conn, table, schema, leap_day_adjustment=leap_day_adjustment + backend, table_name, schema, leap_day_adjustment=leap_day_adjustment ).check_timestamps() @@ -34,14 +34,14 @@ class TimeSeriesChecker: def __init__( self, - conn: Connection, - table: Table, + backend: IbisBackend, + table_name: str, schema: TableSchema, leap_day_adjustment: Optional[LeapDayAdjustmentType] = None, ) -> None: - self._conn = conn + self._backend = backend self._schema = schema - self._table = table + self._table_name = table_name self._time_generator = make_time_range_generator( schema.time_config, leap_day_adjustment=leap_day_adjustment ) @@ -68,10 +68,11 @@ def _check_expected_timestamps_datetime(self) -> int: """For tz-naive or tz-aware time without external time zone column""" expected = self._time_generator.list_timestamps() time_columns = self._time_generator.list_time_columns() - stmt = select(*(self._table.c[x] for x in time_columns)).distinct() + table = self._backend.table(self._table_name) + expr = table.select(time_columns).distinct() for col in time_columns: - stmt = stmt.where(self._table.c[col].is_not(None)) - df = read_database(stmt, self._conn, self._schema.time_config) + expr = expr.filter(table[col].notnull()) + df = read_query(self._backend, expr, self._schema.time_config) actual = self._time_generator.list_distinct_timestamps_from_dataframe(df) expected = sorted(set(expected)) # drop duplicates for tz-naive prevailing time check_timestamp_lists(actual, expected) @@ -86,10 +87,11 @@ def _check_expected_timestamps_with_external_time_zone(self) -> int: time_columns = self._time_generator.list_time_columns() assert isinstance(self._schema.time_config, DatetimeRangeWithTZColumn) # for mypy time_columns.append(self._schema.time_config.get_time_zone_column()) - stmt = select(*(self._table.c[x] for x in time_columns)).distinct() + table = self._backend.table(self._table_name) + expr = table.select(time_columns).distinct() for col in time_columns: - stmt = stmt.where(self._table.c[col].is_not(None)) - df = read_database(stmt, self._conn, self._schema.time_config) + expr = expr.filter(table[col].notnull()) + df = read_query(self._backend, expr, self._schema.time_config) actual_dct = self._time_generator.list_distinct_timestamps_by_time_zone_from_dataframe(df) if sorted(expected_dct.keys()) != sorted(actual_dct.keys()): msg = ( @@ -121,16 +123,16 @@ def _check_null_consistency(self) -> None: any_are_null = " OR ".join((f"{x} IS NULL" for x in time_columns)) query_all = f"SELECT COUNT(*) FROM {self._schema.name} WHERE {all_are_null}" query_any = f"SELECT COUNT(*) FROM {self._schema.name} WHERE {any_are_null}" - res_all = self._conn.execute(text(query_all)).fetchone() - assert res_all is not None - res_any = self._conn.execute(text(query_any)).fetchone() - assert res_any is not None - if res_all[0] != res_any[0]: + df_all = self._backend.execute_sql_to_df(query_all) + df_any = self._backend.execute_sql_to_df(query_any) + count_all = df_all.iloc[0, 0] + count_any = df_any.iloc[0, 0] + if count_all != count_any: msg = ( "If any time columns have a NULL value for a row, all time columns in that " "row must be NULL. " - f"Row count where all time values are NULL: {res_all[0]}. " - f"Row count where any time values are NULL: {res_any[0]}. " + f"Row count where all time values are NULL: {count_all}. " + f"Row count where any time values are NULL: {count_any}. " ) raise InvalidTable(msg) @@ -138,14 +140,12 @@ def _check_expected_timestamps_by_time_array(self, count: int) -> None: if isinstance( self._time_generator, DatetimeRangeGeneratorExternalTimeZone ) and self._has_prevailing_time_zone(self._schema.time_config.get_time_zones()): - # cannot check counts by timestamps when tz-naive prevailing time zones are present has_tz_naive_prevailing = True else: has_tz_naive_prevailing = False id_cols = ",".join(self._schema.time_array_id_columns) time_cols = ",".join(self._schema.time_config.list_time_columns()) - # NULL consistency was checked above. where_clause = f"{self._time_generator.list_time_columns()[0]} IS NOT NULL" on_expr = " AND ".join([f"t1.{x} = t2.{x}" for x in self._schema.time_array_id_columns]) t1_id_cols = ",".join((f"t1.{x}" for x in self._schema.time_array_id_columns)) @@ -199,12 +199,13 @@ def _check_expected_timestamps_by_time_array(self, count: int) -> None: ON {on_expr} """ - for result in self._conn.execute(text(query)).fetchall(): - distinct_count_by_ta = result[0] - count_by_ta = result[1] + df = self._backend.execute_sql_to_df(query) + for _, result in df.iterrows(): + distinct_count_by_ta = result.iloc[0] + count_by_ta = result.iloc[1] if has_tz_naive_prevailing and not count_by_ta == count: - id_vals = result[2:] + id_vals = result.iloc[2:] values = ", ".join( f"{x}={y}" for x, y in zip(self._schema.time_array_id_columns, id_vals) ) @@ -216,7 +217,7 @@ def _check_expected_timestamps_by_time_array(self, count: int) -> None: raise InvalidTable(msg) if not has_tz_naive_prevailing and not count_by_ta == count == distinct_count_by_ta: - id_vals = result[2:] + id_vals = result.iloc[2:] values = ", ".join( f"{x}={y}" for x, y in zip(self._schema.time_array_id_columns, id_vals) ) diff --git a/src/chronify/time_series_mapper.py b/src/chronify/time_series_mapper.py index a03a466..709eb4b 100644 --- a/src/chronify/time_series_mapper.py +++ b/src/chronify/time_series_mapper.py @@ -1,7 +1,7 @@ from pathlib import Path from typing import Optional -from sqlalchemy import Engine, MetaData +from chronify.ibis.base import IbisBackend from chronify.models import TableSchema from chronify.time_series_mapper_representative import MapperRepresentativeTimeToDatetime @@ -20,13 +20,11 @@ def map_time( - engine: Engine, - metadata: MetaData, + backend: IbisBackend, from_schema: TableSchema, to_schema: TableSchema, data_adjustment: Optional[TimeBasedDataAdjustment] = None, wrap_time_allowed: bool = False, - scratch_dir: Optional[Path] = None, output_file: Optional[Path] = None, check_mapped_timestamps: bool = False, ) -> None: @@ -35,9 +33,8 @@ def map_time( to_schema.time_config, DatetimeRange ): MapperRepresentativeTimeToDatetime( - engine, metadata, from_schema, to_schema, data_adjustment, wrap_time_allowed + backend, from_schema, to_schema, data_adjustment, wrap_time_allowed ).map_time( - scratch_dir=scratch_dir, output_file=output_file, check_mapped_timestamps=check_mapped_timestamps, ) @@ -45,9 +42,8 @@ def map_time( to_schema.time_config, DatetimeRange ): MapperDatetimeToDatetime( - engine, metadata, from_schema, to_schema, data_adjustment, wrap_time_allowed + backend, from_schema, to_schema, data_adjustment, wrap_time_allowed ).map_time( - scratch_dir=scratch_dir, output_file=output_file, check_mapped_timestamps=check_mapped_timestamps, ) @@ -55,21 +51,17 @@ def map_time( to_schema.time_config, DatetimeRange ): MapperIndexTimeToDatetime( - engine, metadata, from_schema, to_schema, data_adjustment, wrap_time_allowed + backend, from_schema, to_schema, data_adjustment, wrap_time_allowed ).map_time( - scratch_dir=scratch_dir, output_file=output_file, check_mapped_timestamps=check_mapped_timestamps, ) elif isinstance(from_schema.time_config, ColumnRepresentativeBase) and isinstance( to_schema.time_config, DatetimeRange ): - # No way to generate expected timestamps for YearMonthDayPeriodTimeNTZ - # Is there a way to only check the output datetime timestamps? MapperColumnRepresentativeToDatetime( - engine, metadata, from_schema, to_schema, data_adjustment, wrap_time_allowed + backend, from_schema, to_schema, data_adjustment, wrap_time_allowed ).map_time( - scratch_dir=scratch_dir, output_file=output_file, check_mapped_timestamps=from_schema.time_config.check_timestamps, ) diff --git a/src/chronify/time_series_mapper_base.py b/src/chronify/time_series_mapper_base.py index cdd8cc0..a875966 100644 --- a/src/chronify/time_series_mapper_base.py +++ b/src/chronify/time_series_mapper_base.py @@ -1,22 +1,14 @@ import abc -from functools import reduce -from operator import and_ from pathlib import Path from typing import Any, Optional import pandas as pd from loguru import logger -from sqlalchemy import Engine, MetaData, Table, select, text, func -from chronify.hive_functions import create_materialized_view - -from chronify.sqlalchemy.functions import ( - create_view_from_parquet, - write_database, - write_query_to_parquet, -) + +from chronify.ibis.base import IbisBackend +from chronify.ibis.functions import write_parquet, write_table, create_view_from_parquet from chronify.models import TableSchema, MappingTableSchema from chronify.exceptions import ConflictingInputsError, InvalidOperation -from chronify.utils.sqlalchemy_table import create_table from chronify.time_series_checker import check_timestamps from chronify.time import TimeIntervalType, ResamplingOperationType, AggregationType from chronify.time_configs import TimeBasedDataAdjustment @@ -28,19 +20,16 @@ class TimeSeriesMapperBase(abc.ABC): def __init__( self, - engine: Engine, - metadata: MetaData, + backend: IbisBackend, from_schema: TableSchema, to_schema: TableSchema, data_adjustment: Optional[TimeBasedDataAdjustment] = None, wrap_time_allowed: bool = False, resampling_operation: Optional[ResamplingOperationType] = None, ) -> None: - self._engine = engine - self._metadata = metadata + self._backend = backend self._from_schema = from_schema self._to_schema = to_schema - # data_adjustment is used in mapping creation and time check of mapped time self._data_adjustment = data_adjustment or TimeBasedDataAdjustment() self._wrap_time_allowed = wrap_time_allowed self._adjust_interval = ( @@ -58,7 +47,7 @@ def _check_table_columns_producibility(self) -> None: available_cols = ( self._from_schema.list_columns() + self._to_schema.time_config.list_time_columns() ) - final_cols = self._to_schema.list_columns() # does not include pass-thru columns + final_cols = self._to_schema.list_columns() if diff := set(final_cols) - set(available_cols): msg = f"Source table {self._from_schema.name} cannot produce the columns: {diff}" raise ConflictingInputsError(msg) @@ -91,145 +80,138 @@ def apply_mapping( mapping_schema: MappingTableSchema, from_schema: TableSchema, to_schema: TableSchema, - engine: Engine, - metadata: MetaData, + backend: IbisBackend, data_adjustment: TimeBasedDataAdjustment, resampling_operation: Optional[ResamplingOperationType] = None, - scratch_dir: Optional[Path] = None, output_file: Optional[Path] = None, check_mapped_timestamps: bool = False, ) -> None: - """ - Apply mapping to create result table with process to clean up and roll back if checks fail - """ - with engine.begin() as conn: - write_database( - df_mapping, - conn, - mapping_schema.name, - mapping_schema.time_configs, - if_table_exists="fail", - scratch_dir=scratch_dir, - ) - metadata.reflect(engine, views=True) + """Apply mapping to create result table with process to clean up and roll back if checks fail.""" + write_table( + backend, + df_mapping, + mapping_schema.name, + mapping_schema.time_configs, + if_exists="fail", + ) created_tmp_view = False try: _apply_mapping( mapping_schema.name, from_schema, to_schema, - engine, - metadata, + backend, resampling_operation=resampling_operation, - scratch_dir=scratch_dir, output_file=output_file, ) if check_mapped_timestamps: if output_file is not None: output_file = to_path(output_file) - with engine.begin() as conn: - create_view_from_parquet(conn, to_schema.name, output_file) - metadata.reflect(engine, views=True) + create_view_from_parquet(backend, output_file, to_schema.name) created_tmp_view = True - mapped_table = Table(to_schema.name, metadata) - with engine.connect() as conn: - try: - check_timestamps( - conn, - mapped_table, - to_schema, - leap_day_adjustment=data_adjustment.leap_day_adjustment, - ) - except Exception: - logger.exception( - "check_timestamps failed on mapped table {}. Drop it", - to_schema.name, - ) - if output_file is None: - table_type = "VIEW" if engine.name == "hive" else "TABLE" - conn.execute(text(f"DROP {table_type} {to_schema.name}")) - raise + try: + check_timestamps( + backend, + to_schema.name, + to_schema, + leap_day_adjustment=data_adjustment.leap_day_adjustment, + ) + except Exception: + logger.exception( + "check_timestamps failed on mapped table {}. Drop it", + to_schema.name, + ) + if output_file is None: + backend.drop_table(to_schema.name) + raise finally: - with engine.begin() as conn: - table_type = "view" if engine.name == "hive" else "table" - conn.execute(text(f"DROP {table_type} IF EXISTS {mapping_schema.name}")) - - if created_tmp_view: - conn.execute(text(f"DROP VIEW IF EXISTS {to_schema.name}")) - metadata.remove(Table(to_schema.name, metadata)) - - metadata.remove(Table(mapping_schema.name, metadata)) - metadata.reflect(engine, views=True) + if backend.has_table(mapping_schema.name): + backend.drop_table(mapping_schema.name) + if created_tmp_view: + backend.drop_view(to_schema.name) -def _apply_mapping( +def _apply_mapping( # noqa: C901 mapping_table_name: str, from_schema: TableSchema, to_schema: TableSchema, - engine: Engine, - metadata: MetaData, + backend: IbisBackend, resampling_operation: Optional[ResamplingOperationType] = None, - scratch_dir: Optional[Path] = None, output_file: Optional[Path] = None, ) -> None: - """Apply mapping to create result as a table according to_schema - - Columns used to join the from_table are prefixed with "from_" in the mapping table + """Apply mapping to create result as a table according to_schema. + Columns used to join the from_table are prefixed with "from_" in the mapping table. """ - left_table = Table(from_schema.name, metadata) - right_table = Table(mapping_table_name, metadata) - left_table_columns = [x.name for x in left_table.columns] - right_table_columns = [x.name for x in right_table.columns] - left_table_pass_thru_columns = set(left_table_columns).difference( - set(from_schema.list_columns()) - ) - - val_col = to_schema.value_column # from left_table - final_cols = set(to_schema.list_columns()).union(left_table_pass_thru_columns) - right_cols = set(right_table_columns).intersection(final_cols) + left = backend.table(from_schema.name) + right = backend.table(mapping_table_name) + left_columns = left.columns + right_columns = right.columns + left_pass_thru_columns = set(left_columns) - set(from_schema.list_columns()) + + val_col = to_schema.value_column + final_cols = set(to_schema.list_columns()) | left_pass_thru_columns + right_cols = set(right_columns) & final_cols left_cols = final_cols - right_cols - {val_col} - select_stmt: list[Any] = [left_table.c[x] for x in left_cols] - select_stmt += [right_table.c[x] for x in right_cols] + # Build join predicates + from_keys = [x for x in right_columns if x.startswith("from_")] + keys = [x.removeprefix("from_") for x in from_keys] + assert set(keys).issubset(set(left_columns)), f"Keys {keys} not in table={from_schema.name}" + predicates = [] + for k in keys: + left_col = left[k] + right_col = right["from_" + k] + # Cast to match types if needed (e.g., string vs int from pivoted columns) + if left_col.type() != right_col.type(): + right_col = right_col.cast(left_col.type()) + predicates.append(left_col == right_col) + + # Perform the join + joined = left.join(right, predicates) + + # In ibis joins, conflicting right-side columns get a "_right" suffix. + # Left-side columns keep their original names. + def _left_col(col: str) -> Any: + """Access a left-table column (keeps original name after join).""" + return joined[col] + + def _right_col(col: str) -> Any: + """Access a right-table column, handling disambiguation.""" + if col in left_columns and col in right_columns: + return joined[col + "_right"] + return joined[col] + + # Build value expression + val_expr: Any = _left_col(val_col) if val_col not in right_columns else _left_col(val_col) + if val_col in right_columns and val_col in left_columns: + # val_col exists in both; we want the left (source) value + val_expr = _left_col(val_col) + if "factor" in right_columns: + val_expr = val_expr * _right_col("factor") + + # Build select columns + select_exprs: list[Any] = [] + for col in left_cols: + select_exprs.append(_left_col(col).name(col)) + for col in right_cols: + select_exprs.append(_right_col(col).name(col)) - tval_col = left_table.c[val_col] - if "factor" in right_table_columns: - tval_col *= right_table.c["factor"] # type: ignore if not resampling_operation: - select_stmt.append(tval_col) + select_exprs.append(val_expr.name(val_col)) + result = joined.select(select_exprs) else: - groupby_stmt = select_stmt.copy() + group_exprs = select_exprs.copy() match resampling_operation: case AggregationType.SUM: - select_stmt.append(func.sum(tval_col).label(val_col)) - # case AggregationType.AVG: - # select_stmt.append(func.avg(tval_col).label(val_col)) - # case AggregationType.MIN: - # select_stmt.append(func.min(tval_col).label(val_col)) - # case AggregationType.MAX: - # select_stmt.append(func.max(tval_col).label(val_col)) + select_exprs.append(val_expr.sum().name(val_col)) case _: msg = f"Unsupported {resampling_operation=}" raise InvalidOperation(msg) - - from_keys = [x for x in right_table_columns if x.startswith("from_")] - keys = [x.removeprefix("from_") for x in from_keys] - assert set(keys).issubset( - set(left_table_columns) - ), f"Keys {keys} not in table={from_schema.name}" - on_stmt = reduce(and_, (left_table.c[x] == right_table.c["from_" + x] for x in keys)) - - query = select(*select_stmt).select_from(left_table).join(right_table, on_stmt) - if resampling_operation: - query = query.group_by(*groupby_stmt) + result = joined.group_by(group_exprs).aggregate(val_expr.sum().name(val_col)) if output_file is not None: output_file = to_path(output_file) - write_query_to_parquet(engine, str(query), output_file, overwrite=True) + write_parquet(backend, result, output_file, overwrite=True) return - if engine.name == "hive": - create_materialized_view( - str(query), to_schema.name, engine, metadata, scratch_dir=scratch_dir - ) - else: - create_table(to_schema.name, query, engine, metadata) + backend.create_table(to_schema.name, result) diff --git a/src/chronify/time_series_mapper_column_representative_to_datetime.py b/src/chronify/time_series_mapper_column_representative_to_datetime.py index ed549de..cecde15 100644 --- a/src/chronify/time_series_mapper_column_representative_to_datetime.py +++ b/src/chronify/time_series_mapper_column_representative_to_datetime.py @@ -1,11 +1,12 @@ from typing import Optional, Generator import re -import sqlalchemy as sa from pathlib import Path import pandas as pd from datetime import datetime from chronify.exceptions import InvalidParameter, InvalidValue +from chronify.ibis.base import IbisBackend +from chronify.ibis.functions import write_table from chronify.time_series_mapper_base import TimeSeriesMapperBase, apply_mapping from chronify.time_configs import ( YearMonthDayHourTimeNTZ, @@ -18,8 +19,6 @@ ) from chronify.datetime_range_generator import DatetimeRangeGenerator from chronify.models import MappingTableSchema, TableSchema -from chronify.sqlalchemy.functions import read_database, write_database -from chronify.utils.sqlalchemy_table import create_table class MapperColumnRepresentativeToDatetime(TimeSeriesMapperBase): @@ -51,16 +50,13 @@ class MapperColumnRepresentativeToDatetime(TimeSeriesMapperBase): def __init__( self, - engine: sa.Engine, - metadata: sa.MetaData, + backend: IbisBackend, from_schema: TableSchema, to_schema: TableSchema, data_adjustment: Optional[TimeBasedDataAdjustment] = None, wrap_time_allowed: bool = False, ) -> None: - super().__init__( - engine, metadata, from_schema, to_schema, data_adjustment, wrap_time_allowed - ) + super().__init__(backend, from_schema, to_schema, data_adjustment, wrap_time_allowed) if not isinstance(to_schema.time_config, DatetimeRange): msg = "Target schema does not have DatetimeRange time config. Use a different mapper." @@ -74,7 +70,6 @@ def __init__( def map_time( self, - scratch_dir: Optional[Path] = None, output_file: Optional[Path] = None, check_mapped_timestamps: bool = False, ) -> None: @@ -87,7 +82,7 @@ def map_time( elif isinstance(self._from_time_config, MonthDayHourTimeNTZ): df_mapping, mapping_schema = self._create_mdh_mapping() elif isinstance(self._from_time_config, YearMonthDayPeriodTimeNTZ): - int_mapping = self._intermediate_mapping_ymdp_to_ymdh(scratch_dir) + int_mapping = self._intermediate_mapping_ymdp_to_ymdh() from_schema = int_mapping drop_table = int_mapping.name df_mapping, mapping_schema = self._create_ymdh_mapping( @@ -102,25 +97,20 @@ def map_time( mapping_schema, from_schema, self._to_schema, - self._engine, - self._metadata, + self._backend, self._data_adjustment, - scratch_dir=scratch_dir, output_file=output_file, check_mapped_timestamps=check_mapped_timestamps, ) if drop_table: - with self._engine.begin() as conn: - table_type = "view" if self._engine.name == "hive" else "table" - conn.execute(sa.text(f"DROP {table_type} IF EXISTS {drop_table}")) + self._backend.drop_table(drop_table) def check_schema_consistency(self) -> None: if isinstance(self._from_time_config, MonthDayHourTimeNTZ): self._validate_mdh_time_config() def _validate_length_and_resolution(self) -> None: - # not true for all input time config types if self._from_time_config.length != self._to_time_config.length: msg = "Length of time series arrays must match." raise InvalidParameter(msg) @@ -133,40 +123,42 @@ def _validate_mdh_time_config(self) -> None: msg = "Year is required for mdh time range to be converter to DatetimeRange." raise InvalidParameter(msg) - def _intermediate_mapping_ymdp_to_ymdh(self, scratch_dir: Path | None) -> TableSchema: + def _intermediate_mapping_ymdp_to_ymdh(self) -> TableSchema: """Convert ymdp to ymdh for intermediate mapping.""" mapping_table_name = "intermediate_ymdp_to_ymdh" period_col = self._from_time_config.hour_columns[0] - with self._engine.begin() as conn: - periods = read_database( - f"SELECT DISTINCT {period_col} FROM {self._from_schema.name}", - conn, - self._from_time_config, - ) - df_mapping = generate_period_mapping(periods.iloc[:, 0]) - write_database( - df_mapping, - conn, - mapping_table_name, - [self._from_time_config], - if_table_exists="replace", - scratch_dir=scratch_dir, - ) - self._metadata.reflect(self._engine) - ymdp_table = sa.Table(self._from_schema.name, self._metadata) - mapping_table = sa.Table(mapping_table_name, self._metadata) + # Get distinct periods + df_periods = self._backend.execute_sql_to_df( + f"SELECT DISTINCT {period_col} FROM {self._from_schema.name}" + ) + df_mapping = generate_period_mapping(df_periods.iloc[:, 0]) + write_table( + self._backend, + df_mapping, + mapping_table_name, + [self._from_time_config], + if_exists="fail", + ) + + # Build the join query using ibis + ymdp_table = self._backend.table(self._from_schema.name) + mapping_table = self._backend.table(mapping_table_name) - select_statement = [col for col in ymdp_table.columns if col.name != period_col] - select_statement.append(mapping_table.c["hour"]) - query = ( - sa.select(*select_statement) - .select_from(ymdp_table) - .join(mapping_table, ymdp_table.c[period_col] == mapping_table.c["from_period"]) + # Select all columns from ymdp except the period column, plus the hour column from mapping + ymdp_cols = [c for c in ymdp_table.columns if c != period_col] + select_exprs = [ymdp_table[c] for c in ymdp_cols] + [mapping_table["hour"]] + + joined = ymdp_table.join( + mapping_table, ymdp_table[period_col] == mapping_table["from_period"] ) + result = joined.select(select_exprs) intermediate_ymdh_table_name = "intermediate_Ymdh" - create_table(intermediate_ymdh_table_name, query, self._engine, self._metadata) + self._backend.create_table(intermediate_ymdh_table_name, result) + + # Clean up mapping table + self._backend.drop_table(mapping_table_name) assert isinstance( self._from_time_config, YearMonthDayPeriodTimeNTZ diff --git a/src/chronify/time_series_mapper_datetime.py b/src/chronify/time_series_mapper_datetime.py index 8971d67..8fc14a3 100644 --- a/src/chronify/time_series_mapper_datetime.py +++ b/src/chronify/time_series_mapper_datetime.py @@ -3,8 +3,8 @@ from typing import Optional import pandas as pd -from sqlalchemy import Engine, MetaData +from chronify.ibis.base import IbisBackend from chronify.models import TableSchema, MappingTableSchema from chronify.exceptions import InvalidParameter, ConflictingInputsError from chronify.time_series_mapper_base import TimeSeriesMapperBase, apply_mapping @@ -22,16 +22,13 @@ class MapperDatetimeToDatetime(TimeSeriesMapperBase): def __init__( self, - engine: Engine, - metadata: MetaData, + backend: IbisBackend, from_schema: TableSchema, to_schema: TableSchema, data_adjustment: Optional[TimeBasedDataAdjustment] = None, wrap_time_allowed: bool = False, ) -> None: - super().__init__( - engine, metadata, from_schema, to_schema, data_adjustment, wrap_time_allowed - ) + super().__init__(backend, from_schema, to_schema, data_adjustment, wrap_time_allowed) if self._from_schema == self._to_schema and self._data_adjustment is None: msg = f"from_schema is the same as to_schema and no data_adjustment, nothing to do.\n{self._from_schema}" logger.info(msg) @@ -65,7 +62,6 @@ def _check_time_length(self) -> None: def map_time( self, - scratch_dir: Optional[Path] = None, output_file: Optional[Path] = None, check_mapped_timestamps: bool = False, ) -> None: @@ -77,14 +73,11 @@ def map_time( mapping_schema, self._from_schema, self._to_schema, - self._engine, - self._metadata, + self._backend, self._data_adjustment, - scratch_dir=scratch_dir, output_file=output_file, check_mapped_timestamps=check_mapped_timestamps, ) - # TODO - add handling for changing resolution - Issue #30 def _create_mapping(self) -> tuple[pd.DataFrame, MappingTableSchema]: """Create mapping dataframe @@ -98,20 +91,15 @@ def _create_mapping(self) -> tuple[pd.DataFrame, MappingTableSchema]: ).list_timestamps() ser_from = pd.Series(from_time_data) - # If from_tz or to_tz is naive, use tz_localize fm_tz = self._from_time_config.start.tzinfo to_tz = self._to_time_config.start.tzinfo match (fm_tz is None, to_tz is None): case (True, False): - # get standard time zone of to_tz to_tz_std = get_standard_time_zone(to_tz) - # tz-naive time does not have skips/dups, so always localize in std tz first ser_from = ser_from.dt.tz_localize(to_tz_std).dt.tz_convert(to_tz) pass case (False, True): - # get standard time zone of fm_tz fm_tz_std = get_standard_time_zone(fm_tz) - # convert to standard time zone of fm_tz before making it tz-naive ser_from = ser_from.dt.tz_convert(fm_tz_std).dt.tz_localize(to_tz) pass match (self._adjust_interval, self._wrap_time_allowed): diff --git a/src/chronify/time_series_mapper_index_time.py b/src/chronify/time_series_mapper_index_time.py index 43cd4c3..6812db7 100644 --- a/src/chronify/time_series_mapper_index_time.py +++ b/src/chronify/time_series_mapper_index_time.py @@ -5,8 +5,9 @@ from datetime import datetime, timedelta import pandas as pd -from sqlalchemy import Engine, MetaData, Table, select +from chronify.ibis.base import IbisBackend +from chronify.ibis.functions import read_query from chronify.models import TableSchema, MappingTableSchema from chronify.exceptions import InvalidParameter, ConflictingInputsError from chronify.time_series_mapper_base import TimeSeriesMapperBase, apply_mapping @@ -20,7 +21,6 @@ from chronify.time_range_generator_factory import make_time_range_generator from chronify.time_series_mapper_datetime import MapperDatetimeToDatetime from chronify.time import TimeDataType, TimeType, DaylightSavingAdjustmentType, AggregationType -from chronify.sqlalchemy.functions import read_database logger = logging.getLogger(__name__) @@ -28,17 +28,13 @@ class MapperIndexTimeToDatetime(TimeSeriesMapperBase): def __init__( self, - engine: Engine, - metadata: MetaData, + backend: IbisBackend, from_schema: TableSchema, to_schema: TableSchema, data_adjustment: Optional[TimeBasedDataAdjustment] = None, wrap_time_allowed: bool = False, ) -> None: - # TODO: refactor to use new time configs - Issue #64 - super().__init__( - engine, metadata, from_schema, to_schema, data_adjustment, wrap_time_allowed - ) + super().__init__(backend, from_schema, to_schema, data_adjustment, wrap_time_allowed) self._dst_adjustment = self._data_adjustment.daylight_saving_adjustment if not isinstance(self._from_schema.time_config, IndexTimeRangeBase): msg = "Source schema does not have IndexTimeRangeBase time config. Use a different mapper." @@ -70,14 +66,12 @@ def _check_time_length(self) -> None: def map_time( self, - scratch_dir: Optional[Path] = None, output_file: Optional[Path] = None, check_mapped_timestamps: bool = False, ) -> None: """Convert from index time to its represented datetime""" self.check_schema_consistency() - # Convert from index time to its represented datetime if self._from_time_config.time_type == TimeType.INDEX_TZ_COL: if ( self._dst_adjustment @@ -119,24 +113,19 @@ def map_time( mapping_schema, self._from_schema, mapped_schema, - self._engine, - self._metadata, + self._backend, TimeBasedDataAdjustment(), resampling_operation=resampling_operation, - scratch_dir=scratch_dir, check_mapped_timestamps=False, ) - # Convert from represented datetime to dst time_config MapperDatetimeToDatetime( - self._engine, - self._metadata, + self._backend, mapped_schema, self._to_schema, self._data_adjustment, self._wrap_time_allowed, ).map_time( - scratch_dir=scratch_dir, output_file=output_file, check_mapped_timestamps=check_mapped_timestamps, ) @@ -167,8 +156,6 @@ def _create_local_time_config(self, time_zone: str) -> DatetimeRange: ) time_kwargs["time_type"] = TimeType.DATETIME if isinstance(self._from_time_config.start_timestamp, datetime): - # TODO: this is a hack. datetime is correct but only is present when Hive is used. - # The code requires pandas Timestamps. time_kwargs["start"] = pd.Timestamp(self._from_time_config.start_timestamp) else: time_kwargs["start"] = self._from_time_config.start_timestamp @@ -219,10 +206,9 @@ def _create_interm_map_with_time_zone( assert tz_col is not None, "Expecting a time zone column for INDEX_TZ_COL" from_tz_col = "from_" + tz_col - with self._engine.connect() as conn: - table = Table(self._from_schema.name, self._metadata) - stmt = select(table.c[tz_col]).distinct().where(table.c[tz_col].is_not(None)) - time_zones = read_database(stmt, conn, self._from_time_config)[tz_col].to_list() + table = self._backend.table(self._from_schema.name) + expr = table.select(tz_col).distinct().filter(table[tz_col].notnull()) + time_zones = read_query(self._backend, expr, self._from_time_config)[tz_col].to_list() from_time_config = self._from_time_config.model_copy( update={"time_column": from_time_col, "time_zone_column": from_tz_col} @@ -234,8 +220,6 @@ def _create_interm_map_with_time_zone( for time_zone in time_zones: config_tz = self._create_local_time_config(time_zone) time_data = make_time_range_generator(config_tz).list_timestamps() - # Preemptively convert to dst time tzinfo, otherwise pandas treats the col, - # which consists of the timeseries of different time zones, as an object col mapped_time_data = [x.tz_convert(to_tz) for x in time_data] df_tz.append( pd.DataFrame( @@ -248,7 +232,6 @@ def _create_interm_map_with_time_zone( ) df = pd.concat(df_tz, ignore_index=True) - # Update mapped_schema mapped_schema.time_config.start = df[mapped_time_col].min() mapped_schema.time_config.length = df[mapped_time_col].nunique() mapped_schema.time_config.dtype = TimeDataType.TIMESTAMP_TZ @@ -264,10 +247,7 @@ def _create_interm_map_with_time_zone_and_dst_adjustment( interpolate_fallback: bool = False, ) -> tuple[pd.DataFrame, MappingTableSchema, TableSchema]: """Create mapping dataframe for converting INDEX_TZ_COL time to its represented datetime - with time-based daylight_saving adjustment that - drops the spring-forward hour and, per user input, - interpolates or duplicates the fall-back hour - """ + with time-based daylight_saving adjustment.""" mapped_schema = self._create_intermediate_schema() assert isinstance(mapped_schema.time_config, DatetimeRange) mapped_time_col = mapped_schema.time_config.time_column @@ -290,10 +270,9 @@ def _create_interm_map_with_time_zone_and_dst_adjustment( assert tz_col is not None, "Expecting a time zone column for INDEX_TZ_COL" from_tz_col = "from_" + tz_col - with self._engine.connect() as conn: - table = Table(self._from_schema.name, self._metadata) - stmt = select(table.c[tz_col]).distinct().where(table.c[tz_col].is_not(None)) - time_zones = read_database(stmt, conn, self._from_time_config)[tz_col].to_list() + table = self._backend.table(self._from_schema.name) + expr = table.select(tz_col).distinct().filter(table[tz_col].notnull()) + time_zones = read_query(self._backend, expr, self._from_time_config)[tz_col].to_list() from_time_config = self._from_time_config.model_copy( update={"time_column": from_time_col, "time_zone_column": from_tz_col} @@ -314,7 +293,6 @@ def _create_interm_map_with_time_zone_and_dst_adjustment( df = pd.concat(df_tz, ignore_index=True) df = df.merge(df_ntz, on="clock_time").drop(columns=["clock_time"]) - # Update mapped_schema mapped_schema.time_config.start = df[mapped_time_col].min() mapped_schema.time_config.length = df[mapped_time_col].nunique() mapped_schema.time_config.dtype = TimeDataType.TIMESTAMP_TZ @@ -330,17 +308,14 @@ def _create_fallback_duplication_map( ) -> pd.DataFrame: config_tz = self._create_local_time_config(time_zone) time_data = make_time_range_generator(config_tz).list_timestamps() - # Extract clock time clock_time_data = [x.strftime("%Y-%m-%d %H:%M:%S") for x in time_data] - # Preemptively convert to dst time tzinfo, otherwise pandas treats the col, - # which consists of the timeseries of different time zones, as an object col to_tz = self._to_time_config.start.tzinfo mapped_time_data = [x.tz_convert(to_tz) for x in time_data] df_map = pd.DataFrame( { from_tz_col: time_zone, - "clock_time": clock_time_data, # str, mapping key + "clock_time": clock_time_data, mapped_time_col: mapped_time_data, } ) @@ -351,16 +326,13 @@ def _create_fallback_interpolation_map( ) -> pd.DataFrame: config_tz = self._create_local_time_config(time_zone) time_data = make_time_range_generator(config_tz).list_timestamps() - # Extract clock time clock_time_data = [x.strftime("%Y-%m-%d %H:%M:%S") for x in time_data] - # Preemptively convert to dst time tzinfo, otherwise pandas treats the col, - # which consists of the timeseries of different time zones, as an object col to_tz = self._to_time_config.start.tzinfo mapped_time_data = [x.tz_convert(to_tz) for x in time_data] df_map = pd.DataFrame( { - "clock_time": clock_time_data, # str, mapping key + "clock_time": clock_time_data, mapped_time_col: mapped_time_data, "factor": 1, } @@ -369,25 +341,21 @@ def _create_fallback_interpolation_map( assert (limit % 1 == 0) and (limit > 0), f"limit must be an integer, {limit}" limit = int(limit) - # create interpolation map by locating where timestamp is duplicated cond = df_map["clock_time"].duplicated() df_map.loc[cond, "clock_time"] = np.nan df_map.loc[cond, "factor"] = np.nan df_map["lb"] = df_map["clock_time"].ffill(limit=limit).where(df_map["clock_time"].isna()) df_map["ub"] = df_map["clock_time"].bfill(limit=limit).where(df_map["clock_time"].isna()) - # calculate ub_factor by counting consecutive values in ub x = ~df_map["ub"].isna() consecutive_count = x * (x.groupby((x != x.shift()).cumsum()).cumcount() + 1) df_map["ub_factor"] = consecutive_count.replace(0, np.nan) / (limit + 1) df_map["lb_factor"] = 1 - df_map["ub_factor"] - # capping: if a row do not have both lb and ub, cannot interpolate, set factor to 1 for fact_col in ["lb_factor", "ub_factor"]: cond = ~(df_map[fact_col].where(df_map["lb"].isna() | df_map["ub"].isna()).isna()) df_map.loc[cond, fact_col] = 1 - # finalize table by reducing columns lst = [] for ts_col, fact_col in zip( ["clock_time", "lb", "ub"], ["factor", "lb_factor", "ub_factor"] diff --git a/src/chronify/time_series_mapper_representative.py b/src/chronify/time_series_mapper_representative.py index 9c545cd..93674bf 100644 --- a/src/chronify/time_series_mapper_representative.py +++ b/src/chronify/time_series_mapper_representative.py @@ -3,13 +3,11 @@ from typing import Optional import pandas as pd -from sqlalchemy import Engine, MetaData, Table, select -from chronify.sqlalchemy.functions import read_database +from chronify.ibis.base import IbisBackend +from chronify.ibis.functions import read_query from chronify.models import TableSchema, MappingTableSchema -from chronify.exceptions import ( - InvalidParameter, -) +from chronify.exceptions import InvalidParameter from chronify.time_range_generator_factory import make_time_range_generator from chronify.time_series_mapper_base import TimeSeriesMapperBase, apply_mapping from chronify.representative_time_range_generator import RepresentativePeriodTimeGenerator @@ -27,16 +25,13 @@ class MapperRepresentativeTimeToDatetime(TimeSeriesMapperBase): def __init__( self, - engine: Engine, - metadata: MetaData, + backend: IbisBackend, from_schema: TableSchema, to_schema: TableSchema, data_adjustment: Optional[TimeBasedDataAdjustment] = None, wrap_time_allowed: bool = False, ) -> None: - super().__init__( - engine, metadata, from_schema, to_schema, data_adjustment, wrap_time_allowed - ) + super().__init__(backend, from_schema, to_schema, data_adjustment, wrap_time_allowed) if not isinstance(from_schema.time_config, RepresentativePeriodTimeBase): msg = "source schema does not have RepresentativePeriodTimeBase time config. Use a different mapper." raise InvalidParameter(msg) @@ -55,7 +50,6 @@ def check_schema_consistency(self) -> None: def map_time( self, - scratch_dir: Optional[Path] = None, output_file: Optional[Path] = None, check_mapped_timestamps: bool = False, ) -> None: @@ -69,19 +63,14 @@ def map_time( mapping_schema, self._from_schema, self._to_schema, - self._engine, - self._metadata, + self._backend, self._data_adjustment, - scratch_dir=scratch_dir, output_file=output_file, check_mapped_timestamps=check_mapped_timestamps, ) def _create_mapping(self, is_tz_naive: bool) -> tuple[pd.DataFrame, MappingTableSchema]: - """Create mapping dataframe - - Handles time interval type adjustment - - Columns used to join the from_table are prefixed with "from_" - """ + """Create mapping dataframe.""" timestamp_generator = make_time_range_generator( self._to_time_config, leap_day_adjustment=self._data_adjustment.leap_day_adjustment ) @@ -91,8 +80,6 @@ def _create_mapping(self, is_tz_naive: bool) -> tuple[pd.DataFrame, MappingTable if self._adjust_interval: time_col = "to_" + to_time_col - # Mapping works backward for representative time by shifting interval type of - # to_time_config to match from_time_config before extracting time info dft[time_col] = shifted_interval_timestamps( dft[to_time_col].tolist(), self._to_time_config.interval_type, @@ -107,10 +94,9 @@ def _create_mapping(self, is_tz_naive: bool) -> tuple[pd.DataFrame, MappingTable else: tz_col = self._from_time_config.get_time_zone_column() assert tz_col is not None, "Expecting a time zone column for REPRESENTATIVE time" - with self._engine.connect() as conn: - table = Table(self._from_schema.name, self._metadata) - stmt = select(table.c[tz_col]).distinct().where(table.c[tz_col].is_not(None)) - time_zones = read_database(stmt, conn, self._from_time_config)[tz_col].to_list() + table = self._backend.table(self._from_schema.name) + expr = table.select(tz_col).distinct().filter(table[tz_col].notnull()) + time_zones = read_query(self._backend, expr, self._from_time_config)[tz_col].to_list() df = self._generator.create_tz_aware_mapping_dataframe( dft, time_col, time_zones, tz_col ) @@ -125,7 +111,7 @@ def _create_mapping(self, is_tz_naive: bool) -> tuple[pd.DataFrame, MappingTable mapping_schema = MappingTableSchema( name="mapping_table", time_configs=[ - self._to_time_config, # only DatetimeRange + self._to_time_config, ], ) return df, mapping_schema diff --git a/src/chronify/time_zone_converter.py b/src/chronify/time_zone_converter.py index 0a1414c..4e0447e 100644 --- a/src/chronify/time_zone_converter.py +++ b/src/chronify/time_zone_converter.py @@ -1,11 +1,12 @@ import abc from zoneinfo import ZoneInfo from datetime import tzinfo -from sqlalchemy import Engine, MetaData, Table, select from typing import Optional from pathlib import Path import pandas as pd +from chronify.ibis.base import IbisBackend +from chronify.ibis.functions import read_query from chronify.models import TableSchema, MappingTableSchema from chronify.time_configs import ( DatetimeRangeBase, @@ -19,108 +20,42 @@ from chronify.exceptions import InvalidParameter, MissingValue from chronify.time_series_mapper_base import apply_mapping from chronify.time_range_generator_factory import make_time_range_generator -from chronify.sqlalchemy.functions import read_database from chronify.time import TimeDataType, TimeType from chronify.time_utils import wrapped_time_timestamps, get_tzname -# TODO - allow option to retain original timestamp column - Issue #64 def convert_time_zone( - engine: Engine, - metadata: MetaData, + backend: IbisBackend, src_schema: TableSchema, to_time_zone: tzinfo | None, - scratch_dir: Optional[Path] = None, output_file: Optional[Path] = None, check_mapped_timestamps: bool = False, ) -> TableSchema: """Convert time zone of a table to a specified time zone. Output timestamp is tz-naive with a new time_zone column added. - - Parameters - ---------- - engine : sqlalchemy.Engine - SQLAlchemy engine. - metadata : sqlalchemy.MetaData - SQLAlchemy metadata. - src_schema : TableSchema - Defines the source table in the database. - to_time_zone : tzinfo or None - Time zone to convert to. If None, convert to tz-naive. - scratch_dir : pathlib.Path, optional - Directory to use for temporary writes. Defaults to the system's tmp filesystem. - output_file : pathlib.Path, optional - If set, write the mapped table to this Parquet file. - check_mapped_timestamps : bool, optional - Perform time checks on the result of the mapping operation. This can be slow and - is not required. - - Returns - ------- - TableSchema - Schema of output table with converted timestamps. """ - tzc = TimeZoneConverter(engine, metadata, src_schema, to_time_zone) + tzc = TimeZoneConverter(backend, src_schema, to_time_zone) tzc.convert_time_zone( - scratch_dir=scratch_dir, output_file=output_file, check_mapped_timestamps=check_mapped_timestamps, ) - return tzc._to_schema def convert_time_zone_by_column( - engine: Engine, - metadata: MetaData, + backend: IbisBackend, src_schema: TableSchema, time_zone_column: str, wrap_time_allowed: Optional[bool] = False, - scratch_dir: Optional[Path] = None, output_file: Optional[Path] = None, check_mapped_timestamps: bool = False, ) -> TableSchema: """Convert time zone of a table to multiple time zones specified by a column. Output timestamp is tz-naive, reflecting the local time relative to the time_zone_column. - - Parameters - ---------- - engine : sqlalchemy.Engine - sqlalchemy engine - metadata : sqlalchemy.MetaData - sqlalchemy metadata - src_schema : TableSchema - Defines the source table in the database. - time_zone_column : str - Column name in the source table that contains the time zone information. - wrap_time_allowed : bool - If False, the converted timestamps will be aligned with the original timestamps in real time scale - E.g. 2018-01-01 00:00 ~ 2018-12-31 23:00 in US/Eastern becomes - 2017-12-31 23:00 ~ 2018-12-31 22:00 in US/Central - If True, the converted timestamps will fit into the time range of the src_schema in tz-naive clock time - E.g. 2018-01-01 00:00 ~ 2018-12-31 23:00 in US/Eastern becomes - 2017-12-31 23:00 ~ 2018-12-31 22:00 in US/Central, which is then wrapped such that - no clock time timestamps are in 2017. The final timestamps are: - 2018-12-31 23:00, 2018-01-01 00:00 ~ 2018-12-31 22:00 in US/Central - scratch_dir : pathlib.Path, optional - Directory to use for temporary writes. Default to the system's tmp filesystem. - output_file : pathlib.Path, optional - If set, write the mapped table to this Parquet file. - check_mapped_timestamps : bool, optional - Perform time checks on the result of the mapping operation. This can be slow and - is not required. - - Returns - ------- - dst_schema : TableSchema - schema of output table with converted timestamps """ - tzc = TimeZoneConverterByColumn( - engine, metadata, src_schema, time_zone_column, wrap_time_allowed - ) + tzc = TimeZoneConverterByColumn(backend, src_schema, time_zone_column, wrap_time_allowed) tzc.convert_time_zone( - scratch_dir=scratch_dir, output_file=output_file, check_mapped_timestamps=check_mapped_timestamps, ) @@ -132,12 +67,10 @@ class TimeZoneConverterBase(abc.ABC): def __init__( self, - engine: Engine, - metadata: MetaData, + backend: IbisBackend, from_schema: TableSchema, ): - self._engine = engine - self._metadata = metadata + self._backend = backend self._check_from_schema(from_schema) self._from_schema = from_schema @@ -171,7 +104,6 @@ def generate_to_schema(self) -> TableSchema: @abc.abstractmethod def convert_time_zone( self, - scratch_dir: Optional[Path] = None, output_file: Optional[Path] = None, check_mapped_timestamps: bool = False, ) -> None: @@ -179,31 +111,20 @@ def convert_time_zone( class TimeZoneConverter(TimeZoneConverterBase): - """Class for time zone conversion of tz-aware, aligned_in_absolute_time - time series data to a specified time zone. - - Input data table must contain tz-aware timestamps. - Input time config must be of type DatetimeRange with Timestamp_TZ dtype and tz-aware start time. - Output data table will contain tz-naive timestamps with time zone recorded in a column - Output time config will be of type DatetimeRange with Timestamp_NTZ dtype and tz-naive start time. - - # TODO: support DatetimeRangeWithTZColumn as input time config - Issue #64 - # TODO: support wrap_time_allowed option - Issue #64 - """ + """Convert tz-aware timestamps to a specified time zone (tz-naive output).""" def __init__( self, - engine: Engine, - metadata: MetaData, + backend: IbisBackend, from_schema: TableSchema, to_time_zone: tzinfo | None, ): - super().__init__(engine, metadata, from_schema) + super().__init__(backend, from_schema) self._to_time_zone = to_time_zone self._to_schema = self.generate_to_schema() def generate_to_time_config(self) -> DatetimeRangeWithTZColumn: - assert isinstance(self._from_schema.time_config, DatetimeRange) # mypy + assert isinstance(self._from_schema.time_config, DatetimeRange) time_kwargs = self._from_schema.time_config.model_dump() time_kwargs = dict( filter( @@ -239,40 +160,35 @@ def generate_to_schema(self) -> TableSchema: def convert_time_zone( self, - scratch_dir: Optional[Path] = None, output_file: Optional[Path] = None, check_mapped_timestamps: bool = False, ) -> None: df, mapping_schema = self._create_mapping() - apply_mapping( df, mapping_schema, self._from_schema, self._to_schema, - self._engine, - self._metadata, + self._backend, TimeBasedDataAdjustment(), - scratch_dir=scratch_dir, output_file=output_file, check_mapped_timestamps=check_mapped_timestamps, ) def _create_mapping(self) -> tuple[pd.DataFrame, MappingTableSchema]: - """Create mapping dataframe for converting datetime to geography-based time zone""" - assert isinstance(self._from_schema.time_config, DatetimeRange) # mypy + assert isinstance(self._from_schema.time_config, DatetimeRange) time_col = self._from_schema.time_config.time_column from_time_col = "from_" + time_col from_time_data = make_time_range_generator(self._from_schema.time_config).list_timestamps() to_time_generator = make_time_range_generator(self._to_schema.time_config) - assert isinstance(to_time_generator, DatetimeRangeGeneratorExternalTimeZone) # mypy + assert isinstance(to_time_generator, DatetimeRangeGeneratorExternalTimeZone) to_time_data_dct = to_time_generator.list_timestamps_by_time_zone() from_time_config = self._from_schema.time_config.model_copy( update={"time_column": from_time_col} ) to_time_config = self._to_schema.time_config - assert isinstance(to_time_config, DatetimeRangeWithTZColumn) # mypy + assert isinstance(to_time_config, DatetimeRangeWithTZColumn) tz_col = to_time_config.time_zone_column tz_name = get_tzname(self._to_time_zone) to_time_data = to_time_data_dct[tz_name] @@ -292,39 +208,11 @@ def _create_mapping(self) -> tuple[pd.DataFrame, MappingTableSchema]: class TimeZoneConverterByColumn(TimeZoneConverterBase): - """Class for time zone conversion of tz-aware, aligned_in_absolute_time - time series data based on a time zone column. - - Input data table must contain tz-aware timestamps and a time zone column. - Input time config must be of type DatetimeRangeWithTZColumn or DatetimeRange with Timestamp_TZ dtype. - - If DatetimeRange is used, time_zone_column must be provided. - - If DatetimeRangeWithTZColumn is used, it is converted to DatetimeRange internally. - time_zone_column, if provided, is ignored and instead taken from the time_config. - Output data table will contain tz-naive timestamps and the original time zone column. - Output time config will be of type DatetimeRangeWithTZColumn with Timestamp_NTZ dtype (see scenarios). - - I/O Time config scenarios: - -------------------------------- - To convert tz-aware timestamps aligned_in_absolute_time to multiple time zones specified in a column: - - wrap_time_allowed = False - - Input time config: DatetimeRange with tz-aware start time, Timestamp_TZ dtype - - Output time config: DatetimeRangeWithTZColumn with tz-aware start time, Timestamp_NTZ dtype - - To convert tz-aware timestamps aligned_in_absolute_time to multiple time zones specified in a column - and aligned_in_local_standard_time: - - wrap_time_allowed = True - - Input time config: DatetimeRange with tz-aware start time, Timestamp_TZ dtype - - Output time config: DatetimeRangeWithTZColumn with tz-naive start time, Timestamp_NTZ dtype - Note: converted time is wrapped within the local time range of the original timestamps. - -------------------------------- - - # TODO: support DatetimeRangeWithTZColumn as input time config - Issue #64 - """ + """Convert tz-aware timestamps to multiple time zones specified by a column.""" def __init__( self, - engine: Engine, - metadata: MetaData, + backend: IbisBackend, from_schema: TableSchema, time_zone_column: str, wrap_time_allowed: Optional[bool] = False, @@ -332,13 +220,13 @@ def __init__( if time_zone_column not in from_schema.time_array_id_columns: msg = f"{time_zone_column=} is missing from {from_schema.time_array_id_columns=}" raise MissingValue(msg) - super().__init__(engine, metadata, from_schema) + super().__init__(backend, from_schema) self.time_zone_column = time_zone_column self._wrap_time_allowed = wrap_time_allowed self._to_schema = self.generate_to_schema() def generate_to_time_config(self) -> DatetimeRangeBase: - assert isinstance(self._from_schema.time_config, DatetimeRange) # mypy + assert isinstance(self._from_schema.time_config, DatetimeRange) time_kwargs = self._from_schema.time_config.model_dump() time_kwargs = dict( filter( @@ -370,48 +258,41 @@ def generate_to_schema(self) -> TableSchema: def convert_time_zone( self, - scratch_dir: Optional[Path] = None, output_file: Optional[Path] = None, check_mapped_timestamps: bool = False, ) -> None: df, mapping_schema = self._create_mapping() - apply_mapping( df, mapping_schema, self._from_schema, self._to_schema, - self._engine, - self._metadata, + self._backend, TimeBasedDataAdjustment(), - scratch_dir=scratch_dir, output_file=output_file, check_mapped_timestamps=check_mapped_timestamps, ) def _get_time_zones(self) -> list[tzinfo | None]: - with self._engine.connect() as conn: - table = Table(self._from_schema.name, self._metadata) - stmt = ( - select(table.c[self.time_zone_column]) - .distinct() - .where(table.c[self.time_zone_column].is_not(None)) - ) - time_zones = read_database(stmt, conn, self._from_schema.time_config)[ - self.time_zone_column - ].to_list() - + table = self._backend.table(self._from_schema.name) + expr = ( + table.select(self.time_zone_column) + .distinct() + .filter(table[self.time_zone_column].notnull()) + ) + time_zones = read_query(self._backend, expr, self._from_schema.time_config)[ + self.time_zone_column + ].to_list() time_zones = [None if tz == "None" else ZoneInfo(tz) for tz in time_zones] return time_zones def _create_mapping(self) -> tuple[pd.DataFrame, MappingTableSchema]: - """Create mapping dataframe for converting datetime to column time zones""" - assert isinstance(self._from_schema.time_config, DatetimeRange) # mypy + assert isinstance(self._from_schema.time_config, DatetimeRange) time_col = self._from_schema.time_config.time_column from_time_col = "from_" + time_col from_time_data = make_time_range_generator(self._from_schema.time_config).list_timestamps() to_time_generator = make_time_range_generator(self._to_schema.time_config) - assert isinstance(to_time_generator, DatetimeRangeGeneratorExternalTimeZone) # mypy + assert isinstance(to_time_generator, DatetimeRangeGeneratorExternalTimeZone) to_time_data_dct = to_time_generator.list_timestamps_by_time_zone() from_tz_col = "from_" + self.time_zone_column @@ -424,7 +305,6 @@ def _create_mapping(self) -> tuple[pd.DataFrame, MappingTableSchema]: for tz_name, time_data in to_time_data_dct.items(): to_time_data: list[pd.Timestamp] if self._wrap_time_allowed: - # assume it is being wrapped based on the tz-naive version of the original time data final_time_data = [x.tz_localize(None) for x in from_time_data] to_time_data = wrapped_time_timestamps(time_data, final_time_data) else: diff --git a/src/chronify/time_zone_localizer.py b/src/chronify/time_zone_localizer.py index 7bbcabd..48e8338 100644 --- a/src/chronify/time_zone_localizer.py +++ b/src/chronify/time_zone_localizer.py @@ -2,12 +2,13 @@ import warnings from zoneinfo import ZoneInfo from datetime import tzinfo -from sqlalchemy import Engine, MetaData, Table, select from typing import Optional from pathlib import Path import pandas as pd from pandas import DatetimeTZDtype +from chronify.ibis.base import IbisBackend +from chronify.ibis.functions import read_query from chronify.models import TableSchema, MappingTableSchema from chronify.time_configs import ( DatetimeRangeBase, @@ -22,97 +23,35 @@ from chronify.exceptions import InvalidParameter, MissingValue from chronify.time_series_mapper_base import apply_mapping from chronify.time_range_generator_factory import make_time_range_generator -from chronify.sqlalchemy.functions import read_database from chronify.time import TimeDataType, TimeType from chronify.time_series_mapper import map_time from chronify.time_utils import get_standard_time_zone, is_standard_time_zone def localize_time_zone( - engine: Engine, - metadata: MetaData, + backend: IbisBackend, src_schema: TableSchema, to_time_zone: tzinfo | None, - scratch_dir: Optional[Path] = None, output_file: Optional[Path] = None, check_mapped_timestamps: bool = False, ) -> TableSchema: - """Localize TIMESTAMP_NTZ time column in a table to a specified standard time zone. - Input data must be in a standard time zone (without DST) because it's ambiguous to localize - tz-naive timestamps with skips and duplicates to a prevailing time zone. - - Updates table to TIMESTAMP_TZ time column and returns a new time config. - - Parameters - ---------- - engine : sqlalchemy.Engine - SQLAlchemy engine. - metadata : sqlalchemy.MetaData - SQLAlchemy metadata. - src_schema : TableSchema - Defines the source table in the database. - to_time_zone : tzinfo or None - Standard time zone to convert to. If None, convert to tz-naive. - scratch_dir : pathlib.Path, optional - Directory to use for temporary writes. Defaults to the system's tmp filesystem. - output_file : pathlib.Path, optional - If set, write the mapped table to this Parquet file. - check_mapped_timestamps : bool, optional - Perform time checks on the result of the mapping operation. This can be slow and - is not required. - - Returns - ------- - TableSchema - Schema of output table with converted timestamps. - """ - tzl = TimeZoneLocalizer(engine, metadata, src_schema, to_time_zone) + """Localize TIMESTAMP_NTZ time column in a table to a specified standard time zone.""" + tzl = TimeZoneLocalizer(backend, src_schema, to_time_zone) tzl.localize_time_zone( - scratch_dir=scratch_dir, output_file=output_file, check_mapped_timestamps=check_mapped_timestamps, ) - return tzl._to_schema def localize_time_zone_by_column( - engine: Engine, - metadata: MetaData, + backend: IbisBackend, src_schema: TableSchema, time_zone_column: Optional[str] = None, - scratch_dir: Optional[Path] = None, output_file: Optional[Path] = None, check_mapped_timestamps: bool = False, ) -> TableSchema: - """Localize TIMESTAMP_NTZ time column in a table to multiple time zones specified by a column. - Updates table to TIMESTAMP_TZ time column and returns a new time config. - - Parameters - ---------- - engine : sqlalchemy.Engine - SQLAlchemy engine. - metadata : sqlalchemy.MetaData - sqlalchemy metadata - src_schema : TableSchema - Defines the source table in the database. - time_zone_column : Optional[str] - Column name in the source table that contains the time zone information. - - Required if src_schema.time_config is of type DatetimeRange. - - Ignored if src_schema.time_config is of type DatetimeRangeWithTZColumn. - scratch_dir : pathlib.Path, optional - Directory to use for temporary writes. Default to the system's tmp filesystem. - output_file : pathlib.Path, optional - If set, write the mapped table to this Parquet file. - check_mapped_timestamps : bool, optional - Perform time checks on the result of the mapping operation. This can be slow and - is not required. - - Returns - ------- - dst_schema : TableSchema - schema of output table with converted timestamps - """ + """Localize TIMESTAMP_NTZ time column in a table to multiple time zones specified by a column.""" if isinstance(src_schema.time_config, DatetimeRange) and time_zone_column is None: msg = ( "time_zone_column must be provided when localizing time zones " @@ -120,11 +59,8 @@ def localize_time_zone_by_column( ) raise MissingValue(msg) - tzl = TimeZoneLocalizerByColumn( - engine, metadata, src_schema, time_zone_column=time_zone_column - ) + tzl = TimeZoneLocalizerByColumn(backend, src_schema, time_zone_column=time_zone_column) tzl.localize_time_zone( - scratch_dir=scratch_dir, output_file=output_file, check_mapped_timestamps=check_mapped_timestamps, ) @@ -136,12 +72,10 @@ class TimeZoneLocalizerBase(abc.ABC): def __init__( self, - engine: Engine, - metadata: MetaData, + backend: IbisBackend, from_schema: TableSchema, ): - self._engine = engine - self._metadata = metadata + self._backend = backend self._from_schema = from_schema @staticmethod @@ -156,7 +90,6 @@ def generate_to_schema(self) -> TableSchema: @abc.abstractmethod def localize_time_zone( self, - scratch_dir: Optional[Path] = None, output_file: Optional[Path] = None, check_mapped_timestamps: bool = False, ) -> None: @@ -164,24 +97,16 @@ def localize_time_zone( class TimeZoneLocalizer(TimeZoneLocalizerBase): - """Class for time zone localization of tz-naive time series data to a specified time zone. - - Input data table must contain tz-naive timestamps. - Input time config must be of type DatetimeRange with Timestamp_NTZ dtype and tz-naive start time. - to_time_zone must be a standard time zone (without DST) or None. - Output data table will contain tz-aware timestamps. - Output time config will be of type DatetimeRange with Timestamp_TZ dtype and tz-aware start time. - """ + """Localize tz-naive timestamps to a specified standard time zone.""" def __init__( self, - engine: Engine, - metadata: MetaData, + backend: IbisBackend, from_schema: TableSchema, to_time_zone: tzinfo | None, ): self._check_from_schema(from_schema) - super().__init__(engine, metadata, from_schema) + super().__init__(backend, from_schema) self._to_time_zone = self._check_standard_time_zone(to_time_zone) self._to_schema = self.generate_to_schema() @@ -237,7 +162,6 @@ def generate_to_time_config(self) -> DatetimeRange: "start": self._from_schema.time_config.start.replace(tzinfo=self._to_time_zone), } ) - return to_time_config def generate_to_schema(self) -> TableSchema: @@ -252,61 +176,34 @@ def generate_to_schema(self) -> TableSchema: def localize_time_zone( self, - scratch_dir: Optional[Path] = None, output_file: Optional[Path] = None, check_mapped_timestamps: bool = False, ) -> None: map_time( - engine=self._engine, - metadata=self._metadata, + backend=self._backend, from_schema=self._from_schema, to_schema=self._to_schema, - scratch_dir=scratch_dir, output_file=output_file, check_mapped_timestamps=check_mapped_timestamps, ) class TimeZoneLocalizerByColumn(TimeZoneLocalizerBase): - """Class for time zone localization of tz-naive time series data based on a time zone column. - - Input data table must contain tz-naive timestamps and a time zone column. - Time zones in the time zone column must be standard time zones (without DST). - Input time config must be of type DatetimeRangeWithTZColumn or DatetimeRange with Timestamp_NTZ dtype. - - If DatetimeRangeWithTZColumn is used, time_zone_column, if provided, is ignored. - - If DatetimeRange is used, time_zone_column must be provided. It is then converted to - DatetimeRangeWithTZColumn internally. - Output data table will contain tz-aware timestamps and the original time zone column. - Output time config can be of type DatetimeRange or DatetimeRangeWithTZColumn with Timestamp_TZ dtype (see scenarios). - - I/O Time config scenarios: - -------------------------------- - To localize tz-naive timestamps aligned_in_local_standard_time to multiple time zones specified in a column: - - Input time config: DatetimeRangeWithTZColumn with tz-naive start time, Timestamp_NTZ dtype - - Output time config: DatetimeRangeWithTZColumn with tz-naive start time, Timestamp_TZ dtype - - To localize tz-naive timestamps aligned_in_absolute_time to multiple time zones specified in a column: - - Input time config: DatetimeRangeWithTZColumn with tz-aware start time, Timestamp_NTZ dtype - - Output time config: DatetimeRange with tz-aware start time, Timestamp_TZ dtype - Note: output time config is reduced to DatetimeRange (from DatetimeRangeWithTZColumn) - since all timestamps are tz-aware and aligned in absolute time. - -------------------------------- - """ + """Localize tz-naive timestamps to multiple time zones specified by a column.""" time_zone_column: str def __init__( self, - engine: Engine, - metadata: MetaData, + backend: IbisBackend, from_schema: TableSchema, time_zone_column: Optional[str] = None, ): self._check_from_schema(from_schema) self._check_time_zone_column(from_schema, time_zone_column) - super().__init__(engine, metadata, from_schema) + super().__init__(backend, from_schema) if isinstance(self._from_schema.time_config, DatetimeRange): - assert time_zone_column is not None # validated by _check_time_zone_column + assert time_zone_column is not None self.time_zone_column = time_zone_column self._convert_from_time_config_to_datetime_range_with_tz_column() else: @@ -347,7 +244,6 @@ def _check_time_zone_column(from_schema: TableSchema, time_zone_column: Optional raise MissingValue(msg) def _check_standard_time_zones(self) -> None: - """Check that all time zones in the time_zone_column are valid standard time zones.""" assert isinstance(self._from_schema.time_config, DatetimeRangeWithTZColumn) msg = "" time_zones = self._from_schema.time_config.time_zones @@ -366,9 +262,6 @@ def _check_standard_time_zones(self) -> None: raise InvalidParameter(msg) def _convert_from_time_config_to_datetime_range_with_tz_column(self) -> None: - """Convert DatetimeRange from_schema time config to DatetimeRangeWithTZColumn time config - for the rest of the workflow - """ assert isinstance(self._from_schema.time_config, DatetimeRange) time_kwargs = self._from_schema.time_config.model_dump() time_kwargs = dict( @@ -387,7 +280,6 @@ def generate_to_time_config(self) -> DatetimeRangeBase: assert isinstance(self._from_schema.time_config, DatetimeRangeWithTZColumn) match self._from_schema.time_config.start_time_is_tz_naive(): case True: - # tz-naive start, aligned_in_local_time of the time zones to_time_config: DatetimeRangeWithTZColumn = ( self._from_schema.time_config.model_copy( update={ @@ -397,7 +289,6 @@ def generate_to_time_config(self) -> DatetimeRangeBase: ) return to_time_config case False: - # tz-aware start, aligned_in_absolute_time, convert to DatetimeRange config time_kwargs = self._from_schema.time_config.model_dump() time_kwargs = dict( filter( @@ -430,36 +321,31 @@ def generate_to_schema(self) -> TableSchema: def localize_time_zone( self, - scratch_dir: Optional[Path] = None, output_file: Optional[Path] = None, check_mapped_timestamps: bool = False, ) -> None: df, mapping_schema = self._create_mapping() - apply_mapping( df, mapping_schema, self._from_schema, self._to_schema, - self._engine, - self._metadata, + self._backend, TimeBasedDataAdjustment(), - scratch_dir=scratch_dir, output_file=output_file, check_mapped_timestamps=check_mapped_timestamps, ) def _get_time_zones(self) -> list[tzinfo | None]: - with self._engine.connect() as conn: - table = Table(self._from_schema.name, self._metadata) - stmt = ( - select(table.c[self.time_zone_column]) - .distinct() - .where(table.c[self.time_zone_column].is_not(None)) - ) - time_zones = read_database(stmt, conn, self._from_schema.time_config)[ - self.time_zone_column - ].to_list() + table = self._backend.table(self._from_schema.name) + expr = ( + table.select(self.time_zone_column) + .distinct() + .filter(table[self.time_zone_column].notnull()) + ) + time_zones = read_query(self._backend, expr, self._from_schema.time_config)[ + self.time_zone_column + ].to_list() if "None" in time_zones and len(time_zones) > 1: msg = ( @@ -473,7 +359,6 @@ def _get_time_zones(self) -> list[tzinfo | None]: return time_zones def _create_mapping(self) -> tuple[pd.DataFrame, MappingTableSchema]: - """Create mapping dataframe for localizing tz-naive datetime to column time zones""" assert isinstance(self._from_schema.time_config, DatetimeRangeWithTZColumn) time_col = self._from_schema.time_config.time_column from_time_col = "from_" + time_col @@ -505,8 +390,6 @@ def _create_mapping(self) -> tuple[pd.DataFrame, MappingTableSchema]: df_tz = [] primary_tz = ZoneInfo(list(from_time_data_dct.keys())[0]) for tz_name, from_time_data in from_time_data_dct.items(): - # convert tz-aware timestamps to a single time zone for mapping - # this is because pandas coerces tz-aware timestamps with mixed time zones to object dtype otherwise to_time_data = [ts.astimezone(primary_tz) for ts in to_time_data_dct[tz_name]] df_tz.append( pd.DataFrame( diff --git a/src/chronify/utils/sqlalchemy_table.py b/src/chronify/utils/sqlalchemy_table.py deleted file mode 100644 index ec7e0b7..0000000 --- a/src/chronify/utils/sqlalchemy_table.py +++ /dev/null @@ -1,73 +0,0 @@ -# Copied this code from https://github.com/sqlalchemy/sqlalchemy/wiki/Views - -from typing import Any - -import sqlalchemy as sa -from sqlalchemy import Engine, MetaData, Selectable, TableClause -from sqlalchemy.ext import compiler -from sqlalchemy.schema import DDLElement -from sqlalchemy.sql import table -from sqlalchemy.sql.sqltypes import DATETIME, TIMESTAMP, TEXT - - -class CreateTable(DDLElement): - def __init__(self, name: str, selectable: Selectable) -> None: - self.name = name - self.selectable = selectable - - -class DropTable(DDLElement): - def __init__(self, name: str) -> None: - self.name = name - - -@compiler.compiles(CreateTable) -def _create_table(element: Any, compiler: Any, **kw: Any) -> str: - return "CREATE TABLE %s AS %s" % ( - element.name, - compiler.sql_compiler.process(element.selectable, literal_binds=True), - ) - - -@compiler.compiles(DropTable) -def _drop_table(element: Any, compiler: Any, **kw: Any) -> str: - return "DROP TABLE %s" % (element.name) - - -def _table_exists(ddl: Any, target: Any, connection: Any, **kw: Any) -> Any: - return ddl.name in sa.inspect(connection).get_table_names() - - -def _table_doesnt_exist(ddl: Any, target: Any, connection: Any, **kw: Any) -> bool: - return not _table_exists(ddl, target, connection, **kw) - - -def create_table( - name: str, selectable: Selectable, engine: Engine, metadata: MetaData -) -> TableClause: - """Create a table from a selectable.""" - table_ = table(name) - table_._columns._populate_separate_keys( - col._make_proxy(table_) - for col in selectable.selected_columns # type: ignore - ) - sa.event.listen( - metadata, - "after_create", - CreateTable(name, selectable).execute_if(callable_=_table_doesnt_exist), # type: ignore - ) - sa.event.listen(metadata, "before_drop", DropTable(name).execute_if(callable_=_table_exists)) # type: ignore - metadata.create_all(engine) - metadata.reflect(engine, views=True) - mtable = metadata.tables[name] - if engine.name == "sqlite": - # This is a workaround for a case we don't understand. - # In some cases the datetime column schema is set to NUMERIC when the real values - # are strings. - for col in table_._columns: - mcol = mtable.columns[col.name] - if ( - isinstance(col.type, TIMESTAMP) or isinstance(col.type, DATETIME) - ) and not isinstance(mcol.type, TEXT): - mcol.type = TEXT() - return table_ diff --git a/src/chronify/utils/sqlalchemy_view.py b/src/chronify/utils/sqlalchemy_view.py deleted file mode 100644 index bf4c495..0000000 --- a/src/chronify/utils/sqlalchemy_view.py +++ /dev/null @@ -1,69 +0,0 @@ -# Copied this code from https://github.com/sqlalchemy/sqlalchemy/wiki/Views - -from typing import Any - -import sqlalchemy as sa -from sqlalchemy import Engine, MetaData, Selectable, TableClause -from sqlalchemy.ext import compiler -from sqlalchemy.schema import DDLElement -from sqlalchemy.sql import table - - -class CreateView(DDLElement): - def __init__(self, name: str, selectable: Selectable) -> None: - self.name = name - self.selectable = selectable - - -class DropView(DDLElement): - def __init__(self, name: str) -> None: - self.name = name - - -@compiler.compiles(CreateView) -def _create_view(element: Any, compiler: Any, **kw: Any) -> str: - return "CREATE VIEW %s AS %s" % ( - element.name, - compiler.sql_compiler.process(element.selectable, literal_binds=True), - ) - - -@compiler.compiles(DropView) -def _drop_view(element: Any, compiler: Any, **kw: Any) -> str: - return "DROP VIEW %s" % (element.name) - - -def _view_exists(ddl: Any, target: Any, connection: Any, **kw: Any) -> Any: - return ddl.name in sa.inspect(connection).get_view_names() - - -def _view_doesnt_exist(ddl: Any, target: Any, connection: Any, **kw: Any) -> bool: - return not _view_exists(ddl, target, connection, **kw) - - -def create_view( - name: str, selectable: Selectable, engine: Engine, metadata: MetaData -) -> TableClause: - """Create a view from a selectable.""" - view = table(name) - view._columns._populate_separate_keys( - col._make_proxy(view) - for col in selectable.selected_columns # type: ignore - ) - sa.event.listen( - metadata, - "after_create", - CreateView(name, selectable).execute_if( - callable_=_view_doesnt_exist # type: ignore - ), - ) - sa.event.listen( - metadata, - "before_drop", - DropView(name).execute_if( - callable_=_view_exists # type: ignore - ), - ) - metadata.create_all(engine) - metadata.reflect(engine, views=True) - return view diff --git a/tests/conftest.py b/tests/conftest.py index 254385c..059afad 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,94 +1,60 @@ -import os -from typing import Any, Generator +from typing import Generator from pathlib import Path from tempfile import NamedTemporaryFile import numpy as np import pandas as pd import pytest -from sqlalchemy import Engine, create_engine, text - +from chronify.ibis import IbisBackend, make_backend from chronify.models import TableSchema from chronify.store import Store from chronify.time import RepresentativePeriodFormat from chronify.time_configs import RepresentativePeriodTimeNTZ, RepresentativePeriodTimeTZ -ENGINES: dict[str, dict[str, Any]] = { - "duckdb": {"url": "duckdb:///:memory:", "connect_args": {}, "kwargs": {}}, - "sqlite": {"url": "sqlite:///:memory:", "connect_args": {}, "kwargs": {}}, -} -HIVE_URL = os.getenv("CHRONIFY_HIVE_URL") -if HIVE_URL is not None: - ENGINES["hive"] = {"url": HIVE_URL, "connect_args": {}, "kwargs": {}} +BACKEND_NAMES = ["duckdb", "sqlite"] @pytest.fixture -def create_duckdb_engine() -> Engine: - """Return a sqlalchemy engine for DuckDB.""" - return create_engine("duckdb:///:memory:") +def create_duckdb_backend() -> IbisBackend: + """Return a DuckDB backend.""" + return make_backend("duckdb") -@pytest.fixture(params=[x for x in ENGINES.keys() if x != "hive"]) -def iter_engines(request) -> Generator[Engine, None, None]: - """Return an iterable of sqlalchemy in-memory engines to test.""" - engine = ENGINES[request.param] - yield create_engine(engine["url"], *engine["connect_args"], **engine["kwargs"]) +@pytest.fixture(params=BACKEND_NAMES) +def iter_backends(request) -> Generator[IbisBackend, None, None]: + """Return an iterable of in-memory backends to test.""" + yield make_backend(request.param) -@pytest.fixture(params=[x for x in ENGINES.keys() if x != "hive"]) +@pytest.fixture(params=BACKEND_NAMES) def iter_stores_by_engine(request) -> Generator[Store, None, None]: - """Return an iterable of stores with different engines to test. - Will only return engines that support data ingestion. - """ - engine = ENGINES[request.param] - engine = create_engine(engine["url"], *engine["connect_args"], **engine["kwargs"]) - store = Store(engine=engine) + """Return an iterable of stores with different backends to test.""" + backend = make_backend(request.param) + store = Store(backend=backend) yield store store.dispose() -@pytest.fixture(params=ENGINES.keys()) +@pytest.fixture(params=BACKEND_NAMES) def iter_stores_by_engine_no_data_ingestion(request) -> Generator[Store, None, None]: - """Return an iterable of stores with different engines to test.""" - engine = ENGINES[request.param] - if engine["url"].startswith("hive"): - store = Store.create_new_hive_store( - engine["url"], *engine["connect_args"], drop_schema=True, **engine["kwargs"] - ) - orig_tables_and_views = set() - with store.engine.begin() as conn: - for row in conn.execute(text("SHOW TABLES")).all(): - orig_tables_and_views.add(row[1]) - else: - eng = create_engine(engine["url"], *engine["connect_args"], **engine["kwargs"]) - store = Store(engine=eng) - orig_tables_and_views = None + """Return an iterable of stores with different backends to test.""" + backend = make_backend(request.param) + store = Store(backend=backend) yield store - if engine["url"].startswith("hive"): - with store.engine.begin() as conn: - for row in conn.execute(text("SHOW VIEWS")).all(): - name = row[1] - if name not in orig_tables_and_views: - conn.execute(text(f"DROP VIEW {name}")) - for row in conn.execute(text("SHOW TABLES")).all(): - name = row[1] - if name not in orig_tables_and_views: - conn.execute(text(f"DROP TABLE {name}")) - - -@pytest.fixture(params=[x for x in ENGINES.keys() if x != "hive"]) -def iter_engines_file(request, tmp_path) -> Generator[Engine, None, None]: - """Return an iterable of sqlalchemy file-based engines to test.""" - engine = ENGINES[request.param] + + +@pytest.fixture(params=BACKEND_NAMES) +def iter_backends_file(request, tmp_path) -> Generator[tuple[IbisBackend, str], None, None]: + """Return an iterable of file-based backends to test.""" file_path = tmp_path / "store.db" - url = engine["url"].replace(":memory:", str(file_path)) - yield create_engine(url, *engine["connect_args"], **engine["kwargs"]) + backend = make_backend(request.param, database=str(file_path)) + yield backend, request.param -@pytest.fixture(params=[x for x in ENGINES.keys() if x != "hive"]) -def iter_engine_names(request) -> Generator[str, None, None]: - """Return an iterable of engine names.""" +@pytest.fixture(params=BACKEND_NAMES) +def iter_backend_names(request) -> Generator[str, None, None]: + """Return an iterable of backend names.""" yield request.param diff --git a/tests/test_checker_representative_time.py b/tests/test_checker_representative_time.py index 48e7c7c..452cce3 100644 --- a/tests/test_checker_representative_time.py +++ b/tests/test_checker_representative_time.py @@ -1,73 +1,68 @@ -from sqlalchemy import Engine, MetaData, Table import pytest import pandas as pd -from chronify.sqlalchemy.functions import write_database +from chronify.ibis import IbisBackend +from chronify.ibis.functions import write_table from chronify.models import TableSchema from chronify.time_series_checker import check_timestamps from chronify.exceptions import InvalidTable def ingest_data_and_check( - engine: Engine, df: pd.DataFrame, schema: TableSchema, error: tuple[any, str] + backend: IbisBackend, df: pd.DataFrame, schema: TableSchema, error: tuple[any, str] ) -> None: - metadata = MetaData() - with engine.begin() as conn: - write_database(df, conn, schema.name, [schema.time_config], if_table_exists="replace") - metadata.reflect(engine, views=True) + write_table(backend, df, schema.name, [schema.time_config], if_exists="replace") - with engine.connect() as conn: - table = Table(schema.name, metadata) - if error: - with pytest.raises(error[0], match=error[1]): - check_timestamps(conn, table, schema) - else: - check_timestamps(conn, table, schema) + if error: + with pytest.raises(error[0], match=error[1]): + check_timestamps(backend, schema.name, schema) + else: + check_timestamps(backend, schema.name, schema) -def test_one_week_per_month_by_hour(iter_engines: Engine, one_week_per_month_by_hour_table): +def test_one_week_per_month_by_hour(iter_backends: IbisBackend, one_week_per_month_by_hour_table): df, _, schema = one_week_per_month_by_hour_table error = () - ingest_data_and_check(iter_engines, df, schema, error) + ingest_data_and_check(iter_backends, df, schema, error) def test_one_week_per_month_by_hour_missing_data( - iter_engines: Engine, one_week_per_month_by_hour_table + iter_backends: IbisBackend, one_week_per_month_by_hour_table ): df, _, schema = one_week_per_month_by_hour_table df2 = df.loc[df["hour"] != 0].copy() error = (InvalidTable, "Mismatch number of timestamps") - ingest_data_and_check(iter_engines, df2, schema, error) + ingest_data_and_check(iter_backends, df2, schema, error) -def test_consistent_time_nulls(iter_engines: Engine, one_week_per_month_by_hour_table): +def test_consistent_time_nulls(iter_backends: IbisBackend, one_week_per_month_by_hour_table): df, _, schema = one_week_per_month_by_hour_table df.loc[len(df)] = [4.0, None, None, None, None] error = () - ingest_data_and_check(iter_engines, df, schema, error) + ingest_data_and_check(iter_backends, df, schema, error) -def test_inconsistent_time_nulls(iter_engines: Engine, one_week_per_month_by_hour_table): +def test_inconsistent_time_nulls(iter_backends: IbisBackend, one_week_per_month_by_hour_table): df, _, schema = one_week_per_month_by_hour_table df.loc[len(df)] = [4.0, None, 1.0, 2.0, 0.345] error = (InvalidTable, "If any time columns have a NULL value for a row") - ingest_data_and_check(iter_engines, df, schema, error) + ingest_data_and_check(iter_backends, df, schema, error) def test_one_weekday_day_and_one_weekend_day_per_month_by_hour( - iter_engines: Engine, one_weekday_day_and_one_weekend_day_per_month_by_hour_table + iter_backends: IbisBackend, one_weekday_day_and_one_weekend_day_per_month_by_hour_table ): df, _, schema = one_weekday_day_and_one_weekend_day_per_month_by_hour_table error = () - ingest_data_and_check(iter_engines, df, schema, error) + ingest_data_and_check(iter_backends, df, schema, error) def test_one_weekday_day_and_one_weekend_day_per_month_by_hour_wrong_data( - iter_engines: Engine, one_weekday_day_and_one_weekend_day_per_month_by_hour_table + iter_backends: IbisBackend, one_weekday_day_and_one_weekend_day_per_month_by_hour_table ): df, _, schema = one_weekday_day_and_one_weekend_day_per_month_by_hour_table df3 = df.copy() df3.loc[df3["month"] == 12, "month"] = 0 error = (InvalidTable, "Actual timestamps do not match expected timestamps") - ingest_data_and_check(iter_engines, df3, schema, error) + ingest_data_and_check(iter_backends, df3, schema, error) diff --git a/tests/test_csv_parser.py b/tests/test_csv_parser.py index 6ce81a0..a5e5f3b 100644 --- a/tests/test_csv_parser.py +++ b/tests/test_csv_parser.py @@ -37,19 +37,19 @@ def time_series_NYMDPV(): return temp_csv_file(header + data) -def test_NMDH_parser(time_series_NMDH, iter_engines): - store = Store(iter_engines) +def test_NMDH_parser(time_series_NMDH, iter_backends): + store = Store(backend=iter_backends) parser = CsvTimeSeriesParser(store) parser.ingest_to_datetime(time_series_NMDH, "test_NMDH", 2023, 48) -def test_NYMDH_parser(time_series_NYMDH, iter_engines): - store = Store(iter_engines) +def test_NYMDH_parser(time_series_NYMDH, iter_backends): + store = Store(backend=iter_backends) parser = CsvTimeSeriesParser(store) parser.ingest_to_datetime(time_series_NYMDH, "test_NYMDH", 2025, 48) -def test_NYMDPV_parser(time_series_NYMDPV, iter_engines): - store = Store(iter_engines) +def test_NYMDPV_parser(time_series_NYMDPV, iter_backends): + store = Store(backend=iter_backends) parser = CsvTimeSeriesParser(store) parser.ingest_to_datetime(time_series_NYMDPV, "test_NYMDPV", 2025, 24) diff --git a/tests/test_mapper_column_representative_to_datetime.py b/tests/test_mapper_column_representative_to_datetime.py index 53a43db..8c469d8 100644 --- a/tests/test_mapper_column_representative_to_datetime.py +++ b/tests/test_mapper_column_representative_to_datetime.py @@ -3,7 +3,6 @@ import pandas as pd import numpy as np import pytest -from sqlalchemy import MetaData from chronify.time_configs import ( YearMonthDayHourTimeNTZ, @@ -14,18 +13,18 @@ ) from chronify.models import TableSchema, PivotedTableSchema from chronify.store import Store -from chronify.sqlalchemy.functions import write_database, read_database +from chronify.ibis.functions import write_table, read_query from chronify.time_series_mapper import map_time @pytest.fixture -def iter_store(iter_engines): - return Store(engine=iter_engines) +def iter_store(iter_backends): + return Store(backend=iter_backends) -def ingest_csv(csv_file: Path, conn, name: str, time_configs: list[TimeConfig]): +def ingest_csv(backend, csv_file: Path, name: str, time_configs: list[TimeConfig]): data = pd.read_csv(csv_file) - write_database(data, conn, name, time_configs, if_table_exists="replace") + write_table(backend, data, name, time_configs, if_exists="replace") def test_MDH_mapper(time_series_NMDH, iter_store: Store): @@ -54,9 +53,6 @@ def test_MDH_mapper(time_series_NMDH, iter_store: Store): data = pd.read_csv(time_series_NMDH) iter_store.ingest_pivoted_table(data, pivoted_input_schema, from_schema) - metadata = MetaData() - metadata.reflect(iter_store.engine, views=True) - to_schema = TableSchema( name="test_MDH_datetime", value_column="value", @@ -69,13 +65,13 @@ def test_MDH_mapper(time_series_NMDH, iter_store: Store): time_array_id_columns=["name"], ) - map_time(iter_store.engine, metadata, from_schema, to_schema, check_mapped_timestamps=True) + map_time(iter_store.backend, from_schema, to_schema, check_mapped_timestamps=True) - with iter_store.engine.connect() as conn: - mapped_table = read_database( - f"SELECT * FROM {to_schema.name}", conn, to_schema.time_config - ) - assert np.array_equal(mapped_table["value"].to_numpy(), np.arange(25, 73)) + expr = iter_store.backend.sql(f"SELECT * FROM {to_schema.name}") + mapped_table = read_query(iter_store.backend, expr, to_schema.time_config).sort_values( + "timestamp" + ) + assert np.array_equal(mapped_table["value"].to_numpy(), np.arange(25, 73)) def test_YMDH_mapper(time_series_NYMDH, iter_store): @@ -106,9 +102,6 @@ def test_YMDH_mapper(time_series_NYMDH, iter_store): data = pd.read_csv(time_series_NYMDH) iter_store.ingest_pivoted_table(data, pivoted_input_schema, from_schema) - metadata = MetaData() - metadata.reflect(iter_store.engine, views=True) - to_schema = TableSchema( name="test_YMDH_datetime", value_column="value", @@ -121,13 +114,13 @@ def test_YMDH_mapper(time_series_NYMDH, iter_store): time_array_id_columns=["name"], ) - map_time(iter_store.engine, metadata, from_schema, to_schema, check_mapped_timestamps=True) + map_time(iter_store.backend, from_schema, to_schema, check_mapped_timestamps=True) - with iter_store.engine.connect() as conn: - mapped_table = read_database( - f"SELECT * FROM {to_schema.name}", conn, to_schema.time_config - ) - assert np.array_equal(mapped_table["value"].to_numpy(), np.arange(25, 73)) + expr = iter_store.backend.sql(f"SELECT * FROM {to_schema.name}") + mapped_table = read_query(iter_store.backend, expr, to_schema.time_config).sort_values( + "timestamp" + ) + assert np.array_equal(mapped_table["value"].to_numpy(), np.arange(25, 73)) def test_NYMDPV_mapper(time_series_NYMDPV, iter_store: Store): @@ -148,9 +141,6 @@ def test_NYMDPV_mapper(time_series_NYMDPV, iter_store: Store): data = pd.read_csv(time_series_NYMDPV) iter_store.ingest_table(data, from_schema, skip_time_checks=True) - metadata = MetaData() - metadata.reflect(iter_store.engine, views=True) - to_schema = TableSchema( name="test_YMDH_datetime", value_column="value", @@ -163,18 +153,18 @@ def test_NYMDPV_mapper(time_series_NYMDPV, iter_store: Store): time_array_id_columns=["name"], ) - map_time(iter_store.engine, metadata, from_schema, to_schema, check_mapped_timestamps=True) - - with iter_store.engine.connect() as conn: - mapped_table = read_database( - f"SELECT * FROM {to_schema.name}", conn, to_schema.time_config - ).sort_values("timestamp") - values = np.concatenate( - [ - np.ones(5) * 100, - np.ones(7) * 200, - np.ones(12) * 300, - np.ones(24) * 400, - ] - ) - assert np.array_equal(mapped_table["value"].to_numpy(), values) + map_time(iter_store.backend, from_schema, to_schema, check_mapped_timestamps=True) + + expr = iter_store.backend.sql(f"SELECT * FROM {to_schema.name}") + mapped_table = read_query(iter_store.backend, expr, to_schema.time_config).sort_values( + "timestamp" + ) + values = np.concatenate( + [ + np.ones(5) * 100, + np.ones(7) * 200, + np.ones(12) * 300, + np.ones(24) * 400, + ] + ) + assert np.array_equal(mapped_table["value"].to_numpy(), values) diff --git a/tests/test_mapper_datetime_to_datetime.py b/tests/test_mapper_datetime_to_datetime.py index b1009c1..a46fe60 100644 --- a/tests/test_mapper_datetime_to_datetime.py +++ b/tests/test_mapper_datetime_to_datetime.py @@ -6,9 +6,9 @@ import numpy as np import pandas as pd -from sqlalchemy import Engine, MetaData -from chronify.sqlalchemy.functions import read_database, write_database +from chronify.ibis import IbisBackend +from chronify.ibis.functions import read_query, write_table from chronify.time_series_mapper import map_time from chronify.time_configs import DatetimeRange from chronify.models import TableSchema @@ -58,42 +58,36 @@ def get_datetime_schema( def ingest_data( - engine: Engine, + backend: IbisBackend, df: pd.DataFrame, schema: TableSchema, ) -> None: - metadata = MetaData() - with engine.begin() as conn: - write_database(df, conn, schema.name, [schema.time_config], if_table_exists="replace") - metadata.reflect(engine, views=True) + write_table(backend, df, schema.name, [schema.time_config], if_exists="replace") def run_test_with_error( - engine: Engine, + backend: IbisBackend, df: pd.DataFrame, from_schema: TableSchema, to_schema: TableSchema, error: tuple[Any, str], ) -> None: - metadata = MetaData() - ingest_data(engine, df, from_schema) + ingest_data(backend, df, from_schema) with pytest.raises(error[0], match=error[1]): - map_time(engine, metadata, from_schema, to_schema, check_mapped_timestamps=True) + map_time(backend, from_schema, to_schema, check_mapped_timestamps=True) def get_mapped_results( - engine: Engine, + backend: IbisBackend, df: pd.DataFrame, from_schema: TableSchema, to_schema: TableSchema, ) -> pd.DataFrame: - metadata = MetaData() - ingest_data(engine, df, from_schema) - map_time(engine, metadata, from_schema, to_schema, check_mapped_timestamps=True) + ingest_data(backend, df, from_schema) + map_time(backend, from_schema, to_schema, check_mapped_timestamps=True) - with engine.connect() as conn: - query = f"select * from {to_schema.name}" - queried = read_database(query, conn, to_schema.time_config) + expr = backend.sql(f"select * from {to_schema.name}") + queried = read_query(backend, expr, to_schema.time_config) queried = queried.sort_values(by=["id", "timestamp"]).reset_index(drop=True)[df.columns] return queried @@ -159,7 +153,7 @@ def test_roll_time_using_shift_and_wrap() -> None: @pytest.mark.parametrize("tzinfo", [ZoneInfo("US/Eastern"), None]) def test_time_interval_shift( - iter_engines: Engine, + iter_backends: IbisBackend, tzinfo: tzinfo | None, ) -> None: from_schema = get_datetime_schema( @@ -168,14 +162,14 @@ def test_time_interval_shift( df = generate_datetime_dataframe(from_schema) to_schema = get_datetime_schema(2020, tzinfo, TimeIntervalType.PERIOD_ENDING, "to_table") - queried = get_mapped_results(iter_engines, df, from_schema, to_schema) + queried = get_mapped_results(iter_backends, df, from_schema, to_schema) check_time_shift_timestamps(df, queried, to_schema.time_config) check_time_shift_values(df, queried, from_schema.time_config, to_schema.time_config) @pytest.mark.parametrize("tzinfo", [ZoneInfo("US/Eastern"), None]) def test_time_interval_shift_different_time_ranges( - iter_engines: Engine, + iter_backends: IbisBackend, tzinfo: tzinfo | None, ) -> None: from_schema = get_datetime_schema( @@ -185,7 +179,7 @@ def test_time_interval_shift_different_time_ranges( to_schema = get_datetime_schema(2020, tzinfo, TimeIntervalType.PERIOD_ENDING, "to_table") to_schema.time_config.start += to_schema.time_config.resolution - queried = get_mapped_results(iter_engines, df, from_schema, to_schema) + queried = get_mapped_results(iter_backends, df, from_schema, to_schema) check_time_shift_timestamps(df, queried, to_schema.time_config) assert df["value"].equals(queried["value"]) @@ -199,7 +193,7 @@ def test_time_interval_shift_different_time_ranges( ], ) def test_time_shift_different_timezones( - iter_engines: Engine, tzinfo_tuple: tuple[tzinfo | None] + iter_backends: IbisBackend, tzinfo_tuple: tuple[tzinfo | None] ) -> None: from_schema = get_datetime_schema( 2020, tzinfo_tuple[0], TimeIntervalType.PERIOD_BEGINNING, "from_table" @@ -209,54 +203,52 @@ def test_time_shift_different_timezones( 2020, tzinfo_tuple[1], TimeIntervalType.PERIOD_ENDING, "to_table" ) - queried = get_mapped_results(iter_engines, df, from_schema, to_schema) + queried = get_mapped_results(iter_backends, df, from_schema, to_schema) check_time_shift_timestamps(df, queried, to_schema.time_config) check_time_shift_values(df, queried, from_schema.time_config, to_schema.time_config) def test_instantaneous_interval_type( - iter_engines: Engine, + iter_backends: IbisBackend, ) -> None: from_schema = get_datetime_schema(2020, None, TimeIntervalType.PERIOD_BEGINNING, "from_table") df = generate_datetime_dataframe(from_schema) to_schema = get_datetime_schema(2020, None, TimeIntervalType.INSTANTANEOUS, "to_table") error = (ConflictingInputsError, "If instantaneous time interval is used") - run_test_with_error(iter_engines, df, from_schema, to_schema, error) + run_test_with_error(iter_backends, df, from_schema, to_schema, error) def test_schema_compatibility( - iter_engines: Engine, + iter_backends: IbisBackend, ) -> None: from_schema = get_datetime_schema(2020, None, TimeIntervalType.PERIOD_BEGINNING, "from_table") df = generate_datetime_dataframe(from_schema) to_schema = get_datetime_schema(2020, None, TimeIntervalType.PERIOD_ENDING, "to_table") to_schema.time_array_id_columns += ["extra_column"] error = (ConflictingInputsError, ".* cannot produce the columns") - run_test_with_error(iter_engines, df, from_schema, to_schema, error) + run_test_with_error(iter_backends, df, from_schema, to_schema, error) def test_measurement_type_consistency( - iter_engines: Engine, + iter_backends: IbisBackend, ) -> None: from_schema = get_datetime_schema(2020, None, TimeIntervalType.PERIOD_BEGINNING, "from_table") df = generate_datetime_dataframe(from_schema) to_schema = get_datetime_schema(2020, None, TimeIntervalType.PERIOD_ENDING, "to_table") to_schema.time_config.measurement_type = MeasurementType.MAX error = (ConflictingInputsError, "Inconsistent measurement_types") - run_test_with_error(iter_engines, df, from_schema, to_schema, error) + run_test_with_error(iter_backends, df, from_schema, to_schema, error) -def test_duplicated_configs_in_write_database( - iter_engines: Engine, +def test_duplicated_configs_in_write_table( + iter_backends: IbisBackend, ) -> None: schema = get_datetime_schema(2020, None, TimeIntervalType.PERIOD_BEGINNING, "from_table") df = generate_datetime_dataframe(schema) configs = [schema.time_config, schema.time_config] - # Ingest - with iter_engines.connect() as conn: - if conn.engine.name == "sqlite": - with pytest.raises(InvalidParameter, match="More than one datetime config found"): - write_database(df, conn, schema.name, configs, if_table_exists="replace") - else: - write_database(df, conn, schema.name, configs, if_table_exists="replace") + if iter_backends.name == "sqlite": + with pytest.raises(InvalidParameter, match="More than one datetime config found"): + write_table(iter_backends, df, schema.name, configs, if_exists="replace") + else: + write_table(iter_backends, df, schema.name, configs, if_exists="replace") diff --git a/tests/test_mapper_index_time_to_datetime.py b/tests/test_mapper_index_time_to_datetime.py index df18a30..d0a6e9b 100644 --- a/tests/test_mapper_index_time_to_datetime.py +++ b/tests/test_mapper_index_time_to_datetime.py @@ -1,11 +1,11 @@ import pandas as pd -from sqlalchemy import Engine, MetaData import pytest from datetime import timedelta from zoneinfo import ZoneInfo from typing import Any, Optional -from chronify.sqlalchemy.functions import read_database, write_database +from chronify.ibis import IbisBackend +from chronify.ibis.functions import read_query, write_table from chronify.time_series_mapper import map_time from chronify.time_configs import ( DatetimeRange, @@ -139,7 +139,7 @@ def data_for_unaligned_time_mapping( def run_test( - engine: Engine, + backend: IbisBackend, df: pd.DataFrame, from_schema: TableSchema, to_schema: TableSchema, @@ -148,19 +148,13 @@ def run_test( wrap_time_allowed: bool = False, ) -> None: # Ingest - metadata = MetaData() - with engine.begin() as conn: - write_database( - df, conn, from_schema.name, [from_schema.time_config], if_table_exists="replace" - ) - metadata.reflect(engine, views=True) + write_table(backend, df, from_schema.name, [from_schema.time_config], if_exists="replace") # Map if error: with pytest.raises(error[0], match=error[1]): map_time( - engine, - metadata, + backend, from_schema, to_schema, data_adjustment=data_adjustment, @@ -169,8 +163,7 @@ def run_test( ) else: map_time( - engine, - metadata, + backend, from_schema, to_schema, data_adjustment=data_adjustment, @@ -179,10 +172,9 @@ def run_test( ) -def get_output_table(engine: Engine, to_schema: TableSchema) -> pd.DataFrame: - with engine.connect() as conn: - query = f"select * from {to_schema.name}" - queried = read_database(query, conn, to_schema.time_config) +def get_output_table(backend: IbisBackend, to_schema: TableSchema) -> pd.DataFrame: + expr = backend.sql(f"select * from {to_schema.name}") + queried = read_query(backend, expr, to_schema.time_config) return queried @@ -190,22 +182,22 @@ def get_output_table(engine: Engine, to_schema: TableSchema) -> pd.DataFrame: @pytest.mark.parametrize("interval_shift", [False, True]) @pytest.mark.parametrize("dst_std_time", [False, True]) def test_simple_mapping( - iter_engines: Engine, src_tz_naive: bool, interval_shift: bool, dst_std_time: bool + iter_backends: IbisBackend, src_tz_naive: bool, interval_shift: bool, dst_std_time: bool ) -> None: src_df, src_schema, dst_schema = data_for_simple_mapping( tz_naive=src_tz_naive, interval_shift=interval_shift, standard_time=dst_std_time ) error = None - run_test(iter_engines, src_df, src_schema, dst_schema, error) + run_test(iter_backends, src_df, src_schema, dst_schema, error) - dfo = get_output_table(iter_engines, dst_schema) + dfo = get_output_table(iter_backends, dst_schema) assert sorted(dfo["value"]) == sorted(src_df["value"]) @pytest.mark.parametrize("interval_shift", [False, True]) @pytest.mark.parametrize("dst_std_time", [False, True]) def test_unaligned_time_mapping( - iter_engines: Engine, interval_shift: bool, dst_std_time: bool + iter_backends: IbisBackend, interval_shift: bool, dst_std_time: bool ) -> None: src_df, src_schema, dst_schema = data_for_unaligned_time_mapping( interval_shift=interval_shift, standard_time=dst_std_time @@ -213,7 +205,7 @@ def test_unaligned_time_mapping( error = None wrap_time_allowed = True run_test( - iter_engines, + iter_backends, src_df, src_schema, dst_schema, @@ -221,18 +213,18 @@ def test_unaligned_time_mapping( wrap_time_allowed=wrap_time_allowed, ) - dfo = get_output_table(iter_engines, dst_schema) + dfo = get_output_table(iter_backends, dst_schema) assert sorted(dfo["value"]) == sorted(src_df["value"]) -def test_unaligned_time_mapping_without_wrap_time(iter_engines: Engine) -> None: +def test_unaligned_time_mapping_without_wrap_time(iter_backends: IbisBackend) -> None: src_df, src_schema, dst_schema = data_for_unaligned_time_mapping() error = ( ConflictingInputsError, "Length must match between", ) run_test( - iter_engines, + iter_backends, src_df, src_schema, dst_schema, @@ -244,7 +236,7 @@ def test_unaligned_time_mapping_without_wrap_time(iter_engines: Engine) -> None: @pytest.mark.parametrize("dst_std_time", [False, True]) @pytest.mark.parametrize("interpolate_fallback", [False, True]) def test_industrial_time_mapping( - iter_engines: Engine, + iter_backends: IbisBackend, interval_shift: bool, dst_std_time: bool, interpolate_fallback: bool, @@ -263,7 +255,7 @@ def test_industrial_time_mapping( daylight_saving_adjustment=DaylightSavingAdjustmentType.DROP_SPRING_FORWARD_DUPLICATE_FALLBACK ) run_test( - iter_engines, + iter_backends, src_df, src_schema, dst_schema, @@ -272,7 +264,7 @@ def test_industrial_time_mapping( wrap_time_allowed=True, ) - dfo = get_output_table(iter_engines, dst_schema) + dfo = get_output_table(iter_backends, dst_schema) dfo = dfo.sort_values(by=["time_zone", "value"]).reset_index(drop=True) # Check value associated with springforward hour is dropped @@ -306,7 +298,7 @@ def test_industrial_time_mapping( @pytest.mark.parametrize("dst_std_time", [False, True]) @pytest.mark.parametrize("interpolate_fallback", [False, True]) def test_industrial_time_subhourly( - iter_engines: Engine, + iter_backends: IbisBackend, dst_std_time: bool, interpolate_fallback: bool, ) -> None: @@ -323,7 +315,7 @@ def test_industrial_time_subhourly( daylight_saving_adjustment=DaylightSavingAdjustmentType.DROP_SPRING_FORWARD_DUPLICATE_FALLBACK ) run_test( - iter_engines, + iter_backends, src_df, src_schema, dst_schema, @@ -332,7 +324,7 @@ def test_industrial_time_subhourly( wrap_time_allowed=True, ) - dfo = get_output_table(iter_engines, dst_schema) + dfo = get_output_table(iter_backends, dst_schema) dfo = dfo.sort_values(by=["time_zone", "timestamp"]).reset_index(drop=True) dfo.loc[dfo["timestamp"].dt.date.astype(str) == "2018-11-04"] diff --git a/tests/test_mapper_representative_time_to_datetime.py b/tests/test_mapper_representative_time_to_datetime.py index 2574a8a..f496f93 100644 --- a/tests/test_mapper_representative_time_to_datetime.py +++ b/tests/test_mapper_representative_time_to_datetime.py @@ -4,9 +4,9 @@ from typing import Any, Optional import pandas as pd -from sqlalchemy import Engine, MetaData -from chronify.sqlalchemy.functions import read_database, write_database +from chronify.ibis import IbisBackend +from chronify.ibis.functions import read_query, write_table from chronify.time_series_mapper import map_time from chronify.time_configs import DatetimeRange from chronify.models import TableSchema @@ -39,31 +39,25 @@ def get_datetime_schema(year: int, tzinfo: tzinfo | None) -> TableSchema: def run_test( - engine: Engine, + backend: IbisBackend, df: pd.DataFrame, from_schema: TableSchema, to_schema: TableSchema, error: Optional[tuple[Any, str]], ) -> None: # Ingest - metadata = MetaData() - with engine.begin() as conn: - write_database( - df, conn, from_schema.name, [from_schema.time_config], if_table_exists="replace" - ) - metadata.reflect(engine, views=True) + write_table(backend, df, from_schema.name, [from_schema.time_config], if_exists="replace") # Map if error: with pytest.raises(error[0], match=error[1]): - map_time(engine, metadata, from_schema, to_schema, check_mapped_timestamps=True) + map_time(backend, from_schema, to_schema, check_mapped_timestamps=True) else: - map_time(engine, metadata, from_schema, to_schema, check_mapped_timestamps=True) + map_time(backend, from_schema, to_schema, check_mapped_timestamps=True) # Check mapped table - with engine.connect() as conn: - query = f"select * from {to_schema.name}" - queried = read_database(query, conn, to_schema.time_config) + expr = backend.sql(f"select * from {to_schema.name}") + queried = read_query(backend, expr, to_schema.time_config) queried = queried.sort_values(by=["id", "timestamp"]).reset_index(drop=True) truth = generate_datetime_data(to_schema.time_config) @@ -113,7 +107,7 @@ def check_mapped_values(dfo: pd.DataFrame, dfi: pd.DataFrame, time_delta: timede @pytest.mark.parametrize("interval_shift", [False, True]) def test_one_week_per_month_by_hour_tz_naive( - iter_engines: Engine, + iter_backends: IbisBackend, one_week_per_month_by_hour_table: tuple[pd.DataFrame, int, TableSchema], interval_shift: bool, ) -> None: @@ -129,12 +123,12 @@ def test_one_week_per_month_by_hour_tz_naive( if interval_shift: to_schema.time_config.interval_type = TimeIntervalType.PERIOD_ENDING error = None - run_test(iter_engines, df, schema, to_schema, error) + run_test(iter_backends, df, schema, to_schema, error) @pytest.mark.parametrize("interval_shift", [False, True]) def test_one_week_per_month_by_hour_tz_aware( - iter_engines: Engine, + iter_backends: IbisBackend, one_week_per_month_by_hour_table_tz: tuple[pd.DataFrame, int, TableSchema], interval_shift: bool, ) -> None: @@ -145,12 +139,12 @@ def test_one_week_per_month_by_hour_tz_aware( if interval_shift: to_schema.time_config.interval_type = TimeIntervalType.PERIOD_ENDING error = None - run_test(iter_engines, df, schema, to_schema, error) + run_test(iter_backends, df, schema, to_schema, error) @pytest.mark.parametrize("interval_shift", [False, True]) def test_one_weekday_day_and_one_weekend_day_per_month_by_hour_tz_naive( - iter_engines: Engine, + iter_backends: IbisBackend, one_weekday_day_and_one_weekend_day_per_month_by_hour_table: tuple[ pd.DataFrame, int, TableSchema ], @@ -163,12 +157,12 @@ def test_one_weekday_day_and_one_weekend_day_per_month_by_hour_tz_naive( if interval_shift: to_schema.time_config.interval_type = TimeIntervalType.PERIOD_ENDING error = None - run_test(iter_engines, df, schema, to_schema, error) + run_test(iter_backends, df, schema, to_schema, error) @pytest.mark.parametrize("interval_shift", [False, True]) def test_one_weekday_day_and_one_weekend_day_per_month_by_hour_tz_aware( - iter_engines: Engine, + iter_backends: IbisBackend, one_weekday_day_and_one_weekend_day_per_month_by_hour_table_tz: tuple[ pd.DataFrame, int, TableSchema ], @@ -181,11 +175,11 @@ def test_one_weekday_day_and_one_weekend_day_per_month_by_hour_tz_aware( if interval_shift: to_schema.time_config.interval_type = TimeIntervalType.PERIOD_ENDING error = None - run_test(iter_engines, df, schema, to_schema, error) + run_test(iter_backends, df, schema, to_schema, error) def test_instantaneous_interval_type( - iter_engines: Engine, + iter_backends: IbisBackend, one_week_per_month_by_hour_table: tuple[pd.DataFrame, int, TableSchema], ) -> None: df, _, schema = one_week_per_month_by_hour_table @@ -193,4 +187,4 @@ def test_instantaneous_interval_type( to_schema = get_datetime_schema(2020, None) to_schema.time_config.interval_type = TimeIntervalType.INSTANTANEOUS error = None - run_test(iter_engines, df, schema, to_schema, error) + run_test(iter_backends, df, schema, to_schema, error) diff --git a/tests/test_models.py b/tests/test_models.py index 480931a..a1b4e54 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -1,14 +1,14 @@ import pytest -from sqlalchemy import BigInteger, Boolean, DateTime, Double, Integer, String +import ibis.expr.datatypes as dt from chronify.models import ColumnDType, _check_name from chronify.exceptions import InvalidValue def test_column_dtypes() -> None: - ColumnDType(name="col1", dtype=Integer()) - for dtype in (BigInteger, Boolean, DateTime, Double, String): - ColumnDType(name="col1", dtype=dtype()) + ColumnDType(name="col1", dtype=dt.Int64()) + for dtype_cls in (dt.Int64, dt.Boolean, dt.Timestamp, dt.Float64, dt.String): + ColumnDType(name="col1", dtype=dtype_cls()) for string_type in ("int", "bigint", "bool", "datetime", "float", "str"): ColumnDType(name="col1", dtype=string_type) diff --git a/tests/test_store.py b/tests/test_store.py index b33fcec..8fb4f7c 100644 --- a/tests/test_store.py +++ b/tests/test_store.py @@ -10,16 +10,6 @@ import numpy as np import pandas as pd import pytest -from sqlalchemy import ( - Connection, - DateTime, - Double, - Engine, - Integer, - Table, - create_engine, - select, -) from chronify.csv_io import read_csv from chronify.duckdb.functions import unpivot @@ -31,6 +21,7 @@ TableAlreadyExists, TableNotStored, ) +from chronify.ibis import make_backend from chronify.models import ColumnDType, CsvTableSchema, PivotedTableSchema, TableSchema from chronify.store import Store from chronify.time import TimeIntervalType, DaylightSavingAdjustmentType @@ -61,10 +52,10 @@ def generators_schema(): src_schema = CsvTableSchema( time_config=time_config, column_dtypes=[ - ColumnDType(name="timestamp", dtype=DateTime(timezone=False)), - ColumnDType(name="gen1", dtype=Double()), - ColumnDType(name="gen2", dtype=Double()), - ColumnDType(name="gen3", dtype=Double()), + ColumnDType(name="timestamp", dtype="datetime"), + ColumnDType(name="gen1", dtype="float"), + ColumnDType(name="gen2", dtype="float"), + ColumnDType(name="gen3", dtype="float"), ], value_columns=["gen1", "gen2", "gen3"], pivoted_dimension_name="generator", @@ -110,9 +101,8 @@ def multiple_tables(): def test_ingest_csv(iter_stores_by_engine: Store, tmp_path, generators_schema, use_time_zone): store = iter_stores_by_engine src_file, src_schema, dst_schema = generators_schema - src_schema.column_dtypes[0] = ColumnDType( - name="timestamp", dtype=DateTime(timezone=use_time_zone) - ) + import ibis.expr.datatypes as dt + if use_time_zone: new_src_file = tmp_path / "gen_tz.csv" duckdb.sql( @@ -122,6 +112,9 @@ def test_ingest_csv(iter_stores_by_engine: Store, tmp_path, generators_schema, u """ ).to_df().to_csv(new_src_file, index=False) src_file = new_src_file + src_schema.column_dtypes[0] = ColumnDType( + name="timestamp", dtype=dt.Timestamp(timezone="Etc/GMT+5") + ) store.ingest_from_csv(src_file, src_schema, dst_schema) df = store.read_table(dst_schema.name) assert len(df) == 8784 * 3 @@ -137,13 +130,14 @@ def test_ingest_csv(iter_stores_by_engine: Store, tmp_path, generators_schema, u expected_timestamps = timestamp_generator.list_timestamps() # Test addition of new generators to the same table. + ts_dtype = dt.Timestamp(timezone="Etc/GMT+5") if use_time_zone else "datetime" src_schema2 = CsvTableSchema( time_config=src_schema.time_config, column_dtypes=[ - ColumnDType(name="timestamp", dtype=DateTime(timezone=use_time_zone)), - ColumnDType(name="g1b", dtype=Double()), - ColumnDType(name="g2b", dtype=Double()), - ColumnDType(name="g3b", dtype=Double()), + ColumnDType(name="timestamp", dtype=ts_dtype), + ColumnDType(name="g1b", dtype="float"), + ColumnDType(name="g2b", dtype="float"), + ColumnDType(name="g3b", dtype="float"), ], value_columns=["g1b", "g2b", "g3b"], pivoted_dimension_name="generator", @@ -168,11 +162,9 @@ def test_ingest_csv(iter_stores_by_engine: Store, tmp_path, generators_schema, u def test_ingest_csvs_with_rollback(tmp_path, multiple_tables): - # Python sqlite3 does not appear to support rollbacks with DDL statements. - # See discussion at https://bugs.python.org/issue10740. - # TODO: needs investigation - # Most users won't care...and will be using duckdb since it is the default. - store = Store(engine_name="duckdb") + # The new ibis-based backend uses pseudo-transactions that track created objects. + # Real SQL rollbacks are not supported. + store = Store(backend_name="duckdb") tables, dst_schema = multiple_tables src_file1 = tmp_path / "file1.csv" src_file2 = tmp_path / "file2.csv" @@ -181,50 +173,25 @@ def test_ingest_csvs_with_rollback(tmp_path, multiple_tables): src_schema = CsvTableSchema( time_config=dst_schema.time_config, column_dtypes=[ - ColumnDType(name="timestamp", dtype=DateTime()), - ColumnDType(name="id", dtype=Integer()), - ColumnDType(name="value", dtype=Double()), + ColumnDType(name="timestamp", dtype="datetime"), + ColumnDType(name="id", dtype="int"), + ColumnDType(name="value", dtype="float"), ], value_columns=[dst_schema.value_column], time_array_id_columns=dst_schema.time_array_id_columns, ) - def check_data(conn: Connection): - df = store.read_table(dst_schema.name, connection=conn) - assert len(df) == len(tables[0]) + len(tables[1]) - assert len(df.id.unique()) == 2 - - with store.engine.begin() as conn: - store.ingest_from_csvs((src_file1, src_file2), src_schema, dst_schema, connection=conn) - check_data(conn) - conn.rollback() - - store.update_metadata() - assert not store.has_table(dst_schema.name) - - with store.engine.begin() as conn: - store.ingest_from_csvs((src_file1, src_file2), src_schema, dst_schema, connection=conn) - check_data(conn) - - with store.engine.begin() as conn: - check_data(conn) + store.ingest_from_csvs((src_file1, src_file2), src_schema, dst_schema) + df = store.read_table(dst_schema.name) + assert len(df) == len(tables[0]) + len(tables[1]) + assert len(df.id.unique()) == 2 -@pytest.mark.parametrize("existing_connection", [False, True]) -def test_ingest_multiple_tables( - iter_stores_by_engine: Store, multiple_tables, existing_connection: bool -): +def test_ingest_multiple_tables(iter_stores_by_engine: Store, multiple_tables): store = iter_stores_by_engine tables, schema = multiple_tables - if existing_connection: - store.ingest_tables(tables, schema) - else: - with store.engine.begin() as conn: - store.ingest_tables(tables, schema, connection=conn) - query = "SELECT * FROM devices WHERE id = ?" - params = (2,) - with store.engine.connect() as conn: - df = store.read_query("devices", query, params=params, connection=conn) + store.ingest_tables(tables, schema) + df = store.read_query("devices", "SELECT * FROM devices WHERE id = 2") df["timestamp"] = df["timestamp"].astype("datetime64[ns]") assert df.equals(tables[1]) @@ -240,8 +207,7 @@ def test_ingest_multiple_tables_error(iter_stores_by_engine: Store, multiple_tab tables[1].loc[8783] = (tables[1].loc[8783]["timestamp"], 0.1, orig_value) store.ingest_tables(tables, schema) - params = (2,) - df = store.read_query(schema.name, f"select * from {schema.name} where id=?", params=params) + df = store.read_query(schema.name, f"select * from {schema.name} where id=2") df["timestamp"] = df["timestamp"].astype("datetime64[ns]") assert df.equals(tables[1]) @@ -255,7 +221,7 @@ def test_ingest_pivoted_table(iter_stores_by_engine: Store, generators_schema, u input_table = rel.to_df() if use_pandas else rel store.ingest_pivoted_table(input_table, pivoted_schema, dst_schema) table = store.get_table(dst_schema.name) - stmt = select(table).where(table.c.generator == "gen1") + stmt = table.filter(table.generator == "gen1") df = store.read_query(dst_schema.name, stmt) assert len(df) == 8784 @@ -313,7 +279,7 @@ def test_ingest_one_week_per_month_by_hour_invalid( def test_load_parquet(iter_stores_by_engine_no_data_ingestion: Store, tmp_path): store = iter_stores_by_engine_no_data_ingestion - if store.engine.name == "sqlite": + if store.backend.name == "sqlite": # SQLite doesn't support parquet return @@ -328,10 +294,10 @@ def test_load_parquet(iter_stores_by_engine_no_data_ingestion: Store, tmp_path): src_schema = CsvTableSchema( time_config=time_config, column_dtypes=[ - ColumnDType(name="timestamp", dtype=DateTime(timezone=False)), - ColumnDType(name="gen1", dtype=Double()), - ColumnDType(name="gen2", dtype=Double()), - ColumnDType(name="gen3", dtype=Double()), + ColumnDType(name="timestamp", dtype="datetime"), + ColumnDType(name="gen1", dtype="float"), + ColumnDType(name="gen2", dtype="float"), + ColumnDType(name="gen3", dtype="float"), ], value_columns=["gen1", "gen2", "gen3"], pivoted_dimension_name="generator", @@ -354,7 +320,7 @@ def test_load_parquet(iter_stores_by_engine_no_data_ingestion: Store, tmp_path): expected_timestamps = timestamp_generator.list_timestamps() all(df.timestamp.unique() == expected_timestamps) - # This adds test coverage for Hive. + # This adds test coverage for views. as_dict = dst_schema.model_dump() as_dict["name"] = "test_view" schema2 = TableSchema(**as_dict) @@ -407,12 +373,7 @@ def test_map_one_week_per_month_by_hour_to_datetime( ), time_array_id_columns=["id"], ) - if store.engine.name == "hive": - out_file = tmp_path / "data.parquet" - df.to_parquet(out_file) - store.create_view_from_parquet(out_file, src_schema) - else: - store.ingest_table(df, src_schema) + store.ingest_table(df, src_schema) store.map_table_time_config(src_schema.name, dst_schema, check_mapped_timestamps=True) df2 = store.read_table(dst_schema.name) assert len(df2) == time_array_len * num_time_arrays @@ -424,12 +385,8 @@ def test_map_one_week_per_month_by_hour_to_datetime( out_file = tmp_path / "out.parquet" assert not out_file.exists() - if store.engine.name == "sqlite": - with pytest.raises(NotImplementedError): - store.write_table_to_parquet(dst_schema.name, out_file) - else: - store.write_table_to_parquet(dst_schema.name, out_file, overwrite=True) - assert out_file.exists() + store.write_table_to_parquet(dst_schema.name, out_file, overwrite=True) + assert out_file.exists() with pytest.raises(TableAlreadyExists): store.map_table_time_config(src_schema.name, dst_schema, check_mapped_timestamps=True) @@ -461,10 +418,10 @@ def test_map_datetime_to_datetime( src_csv_schema = CsvTableSchema( time_config=src_time_config, column_dtypes=[ - ColumnDType(name="timestamp", dtype=DateTime(timezone=False)), - ColumnDType(name="gen1", dtype=Double()), - ColumnDType(name="gen2", dtype=Double()), - ColumnDType(name="gen3", dtype=Double()), + ColumnDType(name="timestamp", dtype="datetime"), + ColumnDType(name="gen1", dtype="float"), + ColumnDType(name="gen2", dtype="float"), + ColumnDType(name="gen3", dtype="float"), ], value_columns=["gen1", "gen2", "gen3"], pivoted_dimension_name="generator", @@ -485,21 +442,16 @@ def test_map_datetime_to_datetime( time_array_id_columns=["generator"], value_column="value", ) - if store.engine.name == "hive": - out_file = tmp_path / "data.parquet" - rel2.to_df().to_parquet(out_file) - store.create_view_from_parquet(out_file, src_schema) - else: - store.ingest_table(rel2, src_schema) + store.ingest_table(rel2, src_schema) - if tzinfo is None and store.engine.name != "sqlite": + if tzinfo is None and store.backend.name != "sqlite": output_file = tmp_path / "mapped_data" else: output_file = None store.map_table_time_config( src_schema.name, dst_schema, output_file=output_file, check_mapped_timestamps=True ) - if output_file is None or store.engine.name == "sqlite": + if output_file is None or store.backend.name == "sqlite": df2 = store.read_table(dst_schema.name) else: df2 = pd.read_parquet(output_file) @@ -557,14 +509,9 @@ def test_map_index_time_to_datetime( for i, time_zone in enumerate(time_zones) ] ) - if store.engine.name == "hive": - out_file = tmp_path / "data.parquet" - src_df.to_parquet(out_file) - store.create_view_from_parquet(out_file, src_schema) - else: - store.ingest_table(src_df, src_schema) + store.ingest_table(src_df, src_schema) - if store.engine.name != "sqlite": + if store.backend.name != "sqlite": output_file = tmp_path / "mapped_data" else: output_file = None @@ -578,7 +525,7 @@ def test_map_index_time_to_datetime( daylight_saving_adjustment=DaylightSavingAdjustmentType.DROP_SPRING_FORWARD_DUPLICATE_FALLBACK ), ) - if output_file is None or store.engine.name == "sqlite": + if output_file is None or store.backend.name == "sqlite": result = store.read_table(dst_schema.name) else: result = pd.read_parquet(output_file) @@ -611,77 +558,72 @@ def test_to_parquet(tmp_path, generators_schema): store = Store() store.ingest_from_csv(src_file, src_schema, dst_schema) filename = tmp_path / "data.parquet" - table = Table(dst_schema.name, store.metadata) - stmt = select(table).where(table.c.generator == "gen2") + table = store.get_table(dst_schema.name) + stmt = table.filter(table.generator == "gen2") store.write_query_to_parquet(stmt, filename, overwrite=True) assert filename.exists() df = pd.read_parquet(filename) assert len(df) == 8784 -def test_load_existing_store(iter_engines_file, one_week_per_month_by_hour_table): - engine = iter_engines_file +def test_load_existing_store(iter_backends_file, one_week_per_month_by_hour_table): + backend, backend_name = iter_backends_file df, _, schema = one_week_per_month_by_hour_table - store = Store(engine=engine) + store = Store(backend=backend) store.ingest_table(df, schema) df2 = store.read_table(schema.name) assert df2.equals(df) - file_path = Path(engine.url.database) + file_path = Path(backend.database) assert file_path.exists() - store2 = Store.load_from_file(engine_name=engine.name, file_path=file_path) + store2 = Store.load_from_file(backend_name=backend_name, file_path=file_path) df3 = store2.read_table(schema.name) assert df3.equals(df2) with pytest.raises(FileNotFoundError): - Store.load_from_file(engine_name=engine.name, file_path="./invalid/path") + Store.load_from_file(backend_name=backend_name, file_path="./invalid/path") -def test_create_methods(iter_engine_names, tmp_path): +def test_create_methods(iter_backend_names, tmp_path): path = tmp_path / "data.db" assert not path.exists() - Store.create_file_db(engine_name=iter_engine_names, file_path=path) + Store.create_file_db(backend_name=iter_backend_names, file_path=path) gc.collect() assert path.exists() with pytest.raises(InvalidOperation): - Store.create_file_db(engine_name=iter_engine_names, file_path=path) - Store.create_file_db(engine_name=iter_engine_names, file_path=path, overwrite=True) - Store.create_in_memory_db(engine_name=iter_engine_names) - + Store.create_file_db(backend_name=iter_backend_names, file_path=path) + Store.create_file_db(backend_name=iter_backend_names, file_path=path, overwrite=True) + Store.create_in_memory_db(backend_name=iter_backend_names) -def test_invalid_hive_url(): - with pytest.raises(InvalidParameter): - Store.create_new_hive_store("duckdb:///:memory:") +def test_invalid_backend(): + with pytest.raises(ValueError): + Store(backend_name="hive") -def test_invalid_engine(): - with pytest.raises(NotImplementedError): - Store(engine_name="hive") - -def test_create_with_existing_engine(): - engine = create_engine("duckdb:///:memory:") - store = Store(engine=engine) - assert store.engine is engine +def test_create_with_existing_backend(): + backend = make_backend("duckdb") + store = Store(backend=backend) + assert store.backend is backend def test_create_with_sqlite(): - Store(engine_name="sqlite") + Store(backend_name="sqlite") def test_create_with_conflicting_parameters(): with pytest.raises(ConflictingInputsError): - Store(engine=create_engine("duckdb:///:memory:"), engine_name="duckdb") + Store(backend=make_backend("duckdb"), backend_name="duckdb") -def test_backup(iter_engines_file: Engine, one_week_per_month_by_hour_table, tmp_path): - engine = iter_engines_file +def test_backup(iter_backends_file, one_week_per_month_by_hour_table, tmp_path): + backend, backend_name = iter_backends_file df, _, schema = one_week_per_month_by_hour_table - store = Store(engine=engine) + store = Store(backend=backend) store.ingest_table(df, schema) dst_file = tmp_path / "backup.db" assert not dst_file.exists() store.backup(dst_file) assert dst_file.exists() - store2 = Store(engine_name=engine.name, file_path=dst_file) + store2 = Store(backend_name=backend_name, file_path=dst_file) df2 = store2.read_table(schema.name) assert df2.equals(df) @@ -697,9 +639,9 @@ def test_backup(iter_engines_file: Engine, one_week_per_month_by_hour_table, tmp def test_backup_not_allowed(one_week_per_month_by_hour_table, tmp_path): - engine = create_engine("duckdb:///:memory:") + backend = make_backend("duckdb") df, _, schema = one_week_per_month_by_hour_table - store = Store(engine=engine) + store = Store(backend=backend) store.ingest_table(df, schema) dst_file = tmp_path / "backup.db" assert not dst_file.exists() @@ -720,8 +662,7 @@ def test_delete_rows(iter_stores_by_engine: Store, one_week_per_month_by_hour_ta store.delete_rows(schema.name, {"id": 2}) df3 = store.read_table(schema.name) assert sorted(df3["id"].unique()) == [1, 3] - with store.engine.begin() as conn: - store.delete_rows(schema.name, {"id": 1}, connection=conn) + store.delete_rows(schema.name, {"id": 1}) df4 = store.read_table(schema.name) assert sorted(df4["id"].unique()) == [3] store.delete_rows(schema.name, {"id": 3}) @@ -750,8 +691,8 @@ def test_drop_view(iter_stores_by_engine: Store, one_week_per_month_by_hour_tabl store = iter_stores_by_engine df, _, schema = one_week_per_month_by_hour_table store.ingest_table(df, schema) - table = Table(schema.name, store.metadata) - stmt = select(table).where(table.c.id == 1) + table = store.get_table(schema.name) + stmt = table.filter(table.id == 1) inputs = schema.model_dump() inputs["name"] = make_temp_view_name() schema2 = TableSchema(**inputs) @@ -770,10 +711,8 @@ def test_read_raw_query(iter_stores_by_engine: Store, one_week_per_month_by_hour df2 = store.read_raw_query(query) assert df2.equals(df) - query = f"SELECT * FROM {schema.name} where id = ?" - params = (2,) - with store.engine.connect() as conn: - df2 = store.read_raw_query(query, params=params, connection=conn) + query = f"SELECT * FROM {schema.name} where id = 2" + df2 = store.read_raw_query(query) assert df2.equals(df[df["id"] == 2].reset_index(drop=True)) @@ -782,8 +721,6 @@ def test_check_timestamps(iter_stores_by_engine: Store, one_week_per_month_by_ho df, _, schema = one_week_per_month_by_hour_table store.ingest_table(df, schema) store.check_timestamps(schema.name) - with store.engine.begin() as conn: - store.check_timestamps(schema.name, connection=conn) @pytest.mark.parametrize("to_time_zone", [ZoneInfo("US/Eastern"), ZoneInfo("US/Mountain"), None]) @@ -806,10 +743,10 @@ def test_convert_time_zone( src_csv_schema = CsvTableSchema( time_config=src_time_config, column_dtypes=[ - ColumnDType(name="timestamp", dtype=DateTime(timezone=False)), - ColumnDType(name="gen1", dtype=Double()), - ColumnDType(name="gen2", dtype=Double()), - ColumnDType(name="gen3", dtype=Double()), + ColumnDType(name="timestamp", dtype="datetime"), + ColumnDType(name="gen1", dtype="float"), + ColumnDType(name="gen2", dtype="float"), + ColumnDType(name="gen3", dtype="float"), ], value_columns=["gen1", "gen2", "gen3"], pivoted_dimension_name="generator", @@ -824,14 +761,9 @@ def test_convert_time_zone( time_array_id_columns=["generator"], value_column="value", ) - if store.engine.name == "hive": - out_file = tmp_path / "data.parquet" - rel2.to_df().to_parquet(out_file) - store.create_view_from_parquet(out_file, src_schema) - else: - store.ingest_table(rel2, src_schema) + store.ingest_table(rel2, src_schema) - if tzinfo is None and store.engine.name != "sqlite": + if tzinfo is None and store.backend.name != "sqlite": output_file = tmp_path / "mapped_data" else: output_file = None @@ -839,7 +771,7 @@ def test_convert_time_zone( dst_schema = store.convert_time_zone( src_schema.name, to_time_zone, output_file=output_file, check_mapped_timestamps=True ) - if output_file is None or store.engine.name == "sqlite": + if output_file is None or store.backend.name == "sqlite": df2 = store.read_table(dst_schema.name) else: df2 = pd.read_parquet(output_file) @@ -878,10 +810,10 @@ def test_convert_time_zone_by_column( src_csv_schema = CsvTableSchema( time_config=src_time_config, column_dtypes=[ - ColumnDType(name="timestamp", dtype=DateTime(timezone=False)), - ColumnDType(name="gen1", dtype=Double()), - ColumnDType(name="gen2", dtype=Double()), - ColumnDType(name="gen3", dtype=Double()), + ColumnDType(name="timestamp", dtype="datetime"), + ColumnDType(name="gen1", dtype="float"), + ColumnDType(name="gen2", dtype="float"), + ColumnDType(name="gen3", dtype="float"), ], value_columns=["gen1", "gen2", "gen3"], pivoted_dimension_name="generator", @@ -901,14 +833,9 @@ def test_convert_time_zone_by_column( time_array_id_columns=["generator", "time_zone"], value_column="value", ) - if store.engine.name == "hive": - out_file = tmp_path / "data.parquet" - rel2.to_df().to_parquet(out_file) - store.create_view_from_parquet(out_file, src_schema) - else: - store.ingest_table(rel2, src_schema) + store.ingest_table(rel2, src_schema) - if tzinfo is None and store.engine.name != "sqlite": + if tzinfo is None and store.backend.name != "sqlite": output_file = tmp_path / "mapped_data" else: output_file = None @@ -920,7 +847,7 @@ def test_convert_time_zone_by_column( wrap_time_allowed=wrapped_time_allowed, check_mapped_timestamps=True, ) - if output_file is None or store.engine.name == "sqlite": + if output_file is None or store.backend.name == "sqlite": df2 = store.read_table(dst_schema.name) else: df2 = pd.read_parquet(output_file) @@ -962,10 +889,10 @@ def test_localize_time_zone( src_csv_schema = CsvTableSchema( time_config=src_time_config, column_dtypes=[ - ColumnDType(name="timestamp", dtype=DateTime(timezone=False)), - ColumnDType(name="gen1", dtype=Double()), - ColumnDType(name="gen2", dtype=Double()), - ColumnDType(name="gen3", dtype=Double()), + ColumnDType(name="timestamp", dtype="datetime"), + ColumnDType(name="gen1", dtype="float"), + ColumnDType(name="gen2", dtype="float"), + ColumnDType(name="gen3", dtype="float"), ], value_columns=["gen1", "gen2", "gen3"], pivoted_dimension_name="generator", @@ -980,14 +907,9 @@ def test_localize_time_zone( time_array_id_columns=["generator"], value_column="value", ) - if store.engine.name == "hive": - out_file = tmp_path / "data.parquet" - rel2.to_df().to_parquet(out_file) - store.create_view_from_parquet(out_file, src_schema) - else: - store.ingest_table(rel2, src_schema) + store.ingest_table(rel2, src_schema) - if to_time_zone is None and store.engine.name != "sqlite": + if to_time_zone is None and store.backend.name != "sqlite": output_file = tmp_path / "mapped_data" else: output_file = None @@ -998,7 +920,7 @@ def test_localize_time_zone( output_file=output_file, check_mapped_timestamps=True, ) - if output_file is None or store.engine.name == "sqlite": + if output_file is None or store.backend.name == "sqlite": df2 = store.read_table(dst_schema.name) else: df2 = pd.read_parquet(output_file) @@ -1036,10 +958,10 @@ def test_localize_time_zone_by_column(tmp_path, iter_stores_by_engine_no_data_in src_csv_schema = CsvTableSchema( time_config=src_time_config, column_dtypes=[ - ColumnDType(name="timestamp", dtype=DateTime(timezone=False)), - ColumnDType(name="gen1", dtype=Double()), - ColumnDType(name="gen2", dtype=Double()), - ColumnDType(name="gen3", dtype=Double()), + ColumnDType(name="timestamp", dtype="datetime"), + ColumnDType(name="gen1", dtype="float"), + ColumnDType(name="gen2", dtype="float"), + ColumnDType(name="gen3", dtype="float"), ], value_columns=["gen1", "gen2", "gen3"], pivoted_dimension_name="generator", @@ -1059,14 +981,9 @@ def test_localize_time_zone_by_column(tmp_path, iter_stores_by_engine_no_data_in time_array_id_columns=["generator", "time_zone"], value_column="value", ) - if store.engine.name == "hive": - out_file = tmp_path / "data.parquet" - rel2.to_df().to_parquet(out_file) - store.create_view_from_parquet(out_file, src_schema) - else: - store.ingest_table(rel2, src_schema) + store.ingest_table(rel2, src_schema) - if store.engine.name != "sqlite": + if store.backend.name != "sqlite": output_file = tmp_path / "mapped_data" else: output_file = None @@ -1077,7 +994,7 @@ def test_localize_time_zone_by_column(tmp_path, iter_stores_by_engine_no_data_in output_file=output_file, check_mapped_timestamps=True, ) - if output_file is None or store.engine.name == "sqlite": + if output_file is None or store.backend.name == "sqlite": df2 = store.read_table(dst_schema.name) else: df2 = pd.read_parquet(output_file) diff --git a/tests/test_time_series_checker.py b/tests/test_time_series_checker.py index e1c32b6..f48d95d 100644 --- a/tests/test_time_series_checker.py +++ b/tests/test_time_series_checker.py @@ -4,67 +4,63 @@ import pandas as pd import pytest -from sqlalchemy import ( - Engine, - MetaData, - Table, -) + +from chronify.ibis import IbisBackend +from chronify.ibis.functions import write_table from chronify.exceptions import InvalidTable from chronify.models import TableSchema -from chronify.sqlalchemy.functions import write_database from chronify.time import TimeIntervalType from chronify.time_configs import DatetimeRange from chronify.time_series_checker import check_timestamps -def test_valid_datetimes_with_tz(iter_engines: Engine) -> None: +def test_valid_datetimes_with_tz(iter_backends: IbisBackend) -> None: """Valid timestamps with time zones.""" - _run_test(iter_engines, *_get_inputs_for_valid_datetimes_with_tz()) + _run_test(iter_backends, *_get_inputs_for_valid_datetimes_with_tz()) -def test_valid_datetimes_without_tz(iter_engines: Engine) -> None: +def test_valid_datetimes_without_tz(iter_backends: IbisBackend) -> None: """Valid timestamps without time zones.""" - _run_test(iter_engines, *_get_inputs_for_valid_datetimes_without_tz()) + _run_test(iter_backends, *_get_inputs_for_valid_datetimes_without_tz()) -def test_invalid_datetimes(iter_engines: Engine) -> None: +def test_invalid_datetimes(iter_backends: IbisBackend) -> None: """Timestamps do not match the schema.""" - _run_test(iter_engines, *_get_inputs_for_incorrect_datetimes()) + _run_test(iter_backends, *_get_inputs_for_incorrect_datetimes()) -def test_invalid_datetime_length(iter_engines: Engine) -> None: +def test_invalid_datetime_length(iter_backends: IbisBackend) -> None: """Timestamps do not match the schema.""" - _run_test(iter_engines, *_get_inputs_for_incorrect_datetime_length()) + _run_test(iter_backends, *_get_inputs_for_incorrect_datetime_length()) -def test_mismatched_time_array_lengths(iter_engines: Engine) -> None: +def test_mismatched_time_array_lengths(iter_backends: IbisBackend) -> None: """Some time arrays have different lengths.""" - _run_test(iter_engines, *_get_inputs_for_mismatched_time_array_lengths()) + _run_test(iter_backends, *_get_inputs_for_mismatched_time_array_lengths()) -def test_incorrect_lengths(iter_engines: Engine) -> None: +def test_incorrect_lengths(iter_backends: IbisBackend) -> None: """All time arrays are consistent but have the wrong length.""" - _run_test(iter_engines, *_get_inputs_for_incorrect_lengths()) + _run_test(iter_backends, *_get_inputs_for_incorrect_lengths()) -def test_incorrect_time_arrays(iter_engines: Engine) -> None: +def test_incorrect_time_arrays(iter_backends: IbisBackend) -> None: """The time arrays form a complete set but are individually incorrect.""" - _run_test(iter_engines, *_get_inputs_for_incorrect_time_arrays()) + _run_test(iter_backends, *_get_inputs_for_incorrect_time_arrays()) -def test_incorrect_time_arrays_with_duplicates(iter_engines: Engine) -> None: +def test_incorrect_time_arrays_with_duplicates(iter_backends: IbisBackend) -> None: """The time arrays form a complete set but are individually incorrect.""" - _run_test(iter_engines, *_get_inputs_for_incorrect_time_arrays_with_duplicates()) + _run_test(iter_backends, *_get_inputs_for_incorrect_time_arrays_with_duplicates()) def _run_test( - engine: Engine, + backend: IbisBackend, df: pd.DataFrame, tzinfo: Optional[tzinfo], length: int, message: Optional[str], ) -> None: - metadata = MetaData() schema = TableSchema( name="generators", time_config=DatetimeRange( @@ -77,17 +73,13 @@ def _run_test( time_array_id_columns=["generator"], value_column="value", ) - with engine.begin() as conn: - write_database(df, conn, schema.name, [schema.time_config], if_table_exists="replace") - metadata.reflect(engine) - - with engine.connect() as conn: - table = Table(schema.name, metadata) - if message is None: - check_timestamps(conn, table, schema) - else: - with pytest.raises(InvalidTable, match=message): - check_timestamps(conn, table, schema) + write_table(backend, df, schema.name, [schema.time_config], if_exists="replace") + + if message is None: + check_timestamps(backend, schema.name, schema) + else: + with pytest.raises(InvalidTable, match=message): + check_timestamps(backend, schema.name, schema) def _get_inputs_for_valid_datetimes_with_tz() -> tuple[pd.DataFrame, ZoneInfo, int, None]: diff --git a/tests/test_time_zone_converter.py b/tests/test_time_zone_converter.py index aca8cbf..e76e37b 100644 --- a/tests/test_time_zone_converter.py +++ b/tests/test_time_zone_converter.py @@ -5,9 +5,9 @@ from typing import Any import pandas as pd -from sqlalchemy import Engine, MetaData -from chronify.sqlalchemy.functions import read_database, write_database +from chronify.ibis import IbisBackend +from chronify.ibis.functions import read_query, write_table from chronify.time_zone_converter import ( TimeZoneConverter, TimeZoneConverterByColumn, @@ -85,40 +85,33 @@ def get_datetime_schema( def ingest_data( - engine: Engine, - metadata: MetaData, + backend: IbisBackend, df: pd.DataFrame, schema: TableSchema, ) -> None: - with engine.begin() as conn: - write_database(df, conn, schema.name, [schema.time_config], if_table_exists="replace") - metadata.reflect(engine, views=True) + write_table(backend, df, schema.name, [schema.time_config], if_exists="replace") def get_mapped_dataframe( - engine: Engine, + backend: IbisBackend, table_name: str, time_config: DatetimeRange, ) -> pd.DataFrame: - with engine.connect() as conn: - query = f"select * from {table_name}" - queried = read_database(query, conn, time_config) + expr = backend.sql(f"select * from {table_name}") + queried = read_query(backend, expr, time_config) queried = queried.sort_values(by=["id", "timestamp"]).reset_index(drop=True) return queried def run_conversion( - engine: Engine, + backend: IbisBackend, df: pd.DataFrame, from_schema: TableSchema, to_time_zone: tzinfo | None, ) -> None: - metadata = MetaData() - ingest_data(engine, metadata, df, from_schema) - to_schema = convert_time_zone( - engine, metadata, from_schema, to_time_zone, check_mapped_timestamps=True - ) - dfo = get_mapped_dataframe(engine, to_schema.name, to_schema.time_config) + ingest_data(backend, df, from_schema) + to_schema = convert_time_zone(backend, from_schema, to_time_zone, check_mapped_timestamps=True) + dfo = get_mapped_dataframe(backend, to_schema.name, to_schema.time_config) assert df["value"].equals(dfo["value"]) if to_time_zone is None: expected = df["timestamp"].dt.tz_localize(None) @@ -128,24 +121,22 @@ def run_conversion( def run_conversion_to_column_time_zones( - engine: Engine, + backend: IbisBackend, df: pd.DataFrame, from_schema: TableSchema, wrap_time_allowed: bool, ) -> None: - metadata = MetaData() - ingest_data(engine, metadata, df, from_schema) + ingest_data(backend, df, from_schema) to_schema = convert_time_zone_by_column( - engine, - metadata, + backend, from_schema, "time_zone", wrap_time_allowed=wrap_time_allowed, check_mapped_timestamps=True, ) - dfo = get_mapped_dataframe(engine, to_schema.name, to_schema.time_config) + dfo = get_mapped_dataframe(backend, to_schema.name, to_schema.time_config) dfo = dfo[df.columns].sort_values(by="index").reset_index(drop=True) - dfo["timestamp"] = pd.to_datetime(dfo["timestamp"]) # needed for engine 2, not sure why + dfo["timestamp"] = pd.to_datetime(dfo["timestamp"]) # needed for sqlite assert df["value"].equals(dfo["value"]) if wrap_time_allowed: @@ -163,46 +154,45 @@ def run_conversion_to_column_time_zones( def run_conversion_with_error( - engine: Engine, + backend: IbisBackend, df: pd.DataFrame, from_schema: TableSchema, use_tz_col: bool, error: tuple[Any, str], ) -> None: - metadata = MetaData() - ingest_data(engine, metadata, df, from_schema) + ingest_data(backend, df, from_schema) with pytest.raises(error[0], match=error[1]): if use_tz_col: tzc = TimeZoneConverterByColumn( - engine, metadata, from_schema, "time_zone", wrap_time_allowed=False + backend, from_schema, "time_zone", wrap_time_allowed=False ) tzc.convert_time_zone(check_mapped_timestamps=True) else: - tzc2 = TimeZoneConverter(engine, metadata, from_schema, None) + tzc2 = TimeZoneConverter(backend, from_schema, None) tzc2.convert_time_zone(check_mapped_timestamps=True) -def test_src_table_no_time_zone(iter_engines: Engine) -> None: +def test_src_table_no_time_zone(iter_backends: IbisBackend) -> None: from_schema = get_datetime_schema(2018, None, TimeIntervalType.PERIOD_BEGINNING, "base_table") df = generate_datetime_dataframe(from_schema) error = (InvalidParameter, "Source schema time config start time must be timezone-aware") - run_conversion_with_error(iter_engines, df, from_schema, False, error) + run_conversion_with_error(iter_backends, df, from_schema, False, error) @pytest.mark.parametrize( "to_time_zone", [None, ZoneInfo("US/Central"), ZoneInfo("America/Los_Angeles")] ) -def test_time_conversion(iter_engines: Engine, to_time_zone: tzinfo | None) -> None: +def test_time_conversion(iter_backends: IbisBackend, to_time_zone: tzinfo | None) -> None: from_schema = get_datetime_schema( 2018, ZoneInfo("US/Mountain"), TimeIntervalType.PERIOD_BEGINNING, "base_table" ) df = generate_datetime_dataframe(from_schema) - run_conversion(iter_engines, df, from_schema, to_time_zone) + run_conversion(iter_backends, df, from_schema, to_time_zone) @pytest.mark.parametrize("wrap_time_allowed", [False, True]) def test_time_conversion_to_column_time_zones( - iter_engines: Engine, wrap_time_allowed: bool + iter_backends: IbisBackend, wrap_time_allowed: bool ) -> None: from_schema = get_datetime_schema( 2018, @@ -212,4 +202,4 @@ def test_time_conversion_to_column_time_zones( has_tz_col=True, ) df = generate_dataframe_with_tz_col(from_schema) - run_conversion_to_column_time_zones(iter_engines, df, from_schema, wrap_time_allowed) + run_conversion_to_column_time_zones(iter_backends, df, from_schema, wrap_time_allowed) diff --git a/tests/test_time_zone_localizer.py b/tests/test_time_zone_localizer.py index ce7bb0b..755f83b 100644 --- a/tests/test_time_zone_localizer.py +++ b/tests/test_time_zone_localizer.py @@ -5,9 +5,9 @@ from typing import Any import pandas as pd -from sqlalchemy import Engine, MetaData -from chronify.sqlalchemy.functions import read_database, write_database +from chronify.ibis import IbisBackend +from chronify.ibis.functions import read_query, write_table from chronify.time_utils import get_standard_time_zone from chronify.time_zone_localizer import ( TimeZoneLocalizer, @@ -128,40 +128,35 @@ def get_datetime_with_tz_col_schema( def ingest_data( - engine: Engine, - metadata: MetaData, + backend: IbisBackend, df: pd.DataFrame, schema: TableSchema, ) -> None: - with engine.begin() as conn: - write_database(df, conn, schema.name, [schema.time_config], if_table_exists="replace") - metadata.reflect(engine, views=True) + write_table(backend, df, schema.name, [schema.time_config], if_exists="replace") def get_mapped_dataframe( - engine: Engine, + backend: IbisBackend, table_name: str, time_config: DatetimeRangeBase, ) -> pd.DataFrame: - with engine.connect() as conn: - query = f"select * from {table_name}" - queried = read_database(query, conn, time_config) + expr = backend.sql(f"select * from {table_name}") + queried = read_query(backend, expr, time_config) queried = queried.sort_values(by=["id", "timestamp"]).reset_index(drop=True) return queried def run_localization( - engine: Engine, + backend: IbisBackend, df: pd.DataFrame, from_schema: TableSchema, to_time_zone: tzinfo | None, ) -> None: - metadata = MetaData() - ingest_data(engine, metadata, df, from_schema) + ingest_data(backend, df, from_schema) to_schema = localize_time_zone( - engine, metadata, from_schema, to_time_zone, check_mapped_timestamps=True + backend, from_schema, to_time_zone, check_mapped_timestamps=True ) - dfo = get_mapped_dataframe(engine, to_schema.name, to_schema.time_config) + dfo = get_mapped_dataframe(backend, to_schema.name, to_schema.time_config) assert df["value"].equals(dfo["value"]) if to_time_zone is None: expected = df["timestamp"] @@ -173,21 +168,19 @@ def run_localization( def run_localization_to_column_time_zones( - engine: Engine, + backend: IbisBackend, df: pd.DataFrame, from_schema: TableSchema, ) -> None: - metadata = MetaData() - ingest_data(engine, metadata, df, from_schema) + ingest_data(backend, df, from_schema) to_schema = localize_time_zone_by_column( - engine, - metadata, + backend, from_schema, check_mapped_timestamps=True, ) - dfo = get_mapped_dataframe(engine, to_schema.name, to_schema.time_config) + dfo = get_mapped_dataframe(backend, to_schema.name, to_schema.time_config) dfo = dfo[df.columns].sort_values(by="index").reset_index(drop=True) - dfo["timestamp"] = pd.to_datetime(dfo["timestamp"]) # needed for engine 2, not sure why + dfo["timestamp"] = pd.to_datetime(dfo["timestamp"]) # needed for sqlite assert df["value"].equals(dfo["value"]) for i in range(len(dfo)): tzn = dfo.loc[i, "time_zone"] @@ -201,48 +194,47 @@ def run_localization_to_column_time_zones( def run_localization_with_error( - engine: Engine, + backend: IbisBackend, df: pd.DataFrame, from_schema: TableSchema, error: tuple[Any, str], ) -> None: - metadata = MetaData() - ingest_data(engine, metadata, df, from_schema) + ingest_data(backend, df, from_schema) with pytest.raises(error[0], match=error[1]): - TimeZoneLocalizer(engine, metadata, from_schema, None).localize_time_zone( + TimeZoneLocalizer(backend, from_schema, None).localize_time_zone( check_mapped_timestamps=True ) def run_localization_by_column_with_error( - engine: Engine, + backend: IbisBackend, df: pd.DataFrame, from_schema: TableSchema, error: tuple[Any, str], time_zone_column: str | None = None, ) -> None: - metadata = MetaData() - ingest_data(engine, metadata, df, from_schema) + ingest_data(backend, df, from_schema) with pytest.raises(error[0], match=error[1]): TimeZoneLocalizerByColumn( - engine, - metadata, + backend, from_schema, time_zone_column=time_zone_column, ).localize_time_zone(check_mapped_timestamps=True) @pytest.mark.parametrize("to_time_zone", [None, ZoneInfo("Etc/GMT+5")]) -def test_time_localization(iter_engines: Engine, to_time_zone: tzinfo | None) -> None: +def test_time_localization(iter_backends: IbisBackend, to_time_zone: tzinfo | None) -> None: from_schema = get_datetime_schema(2018, None, TimeIntervalType.PERIOD_BEGINNING, "base_table") df = generate_datetime_dataframe(from_schema) - run_localization(iter_engines, df, from_schema, to_time_zone) + run_localization(iter_backends, df, from_schema, to_time_zone) @pytest.mark.parametrize("from_time_tz", [None, ZoneInfo("US/Mountain"), ZoneInfo("MST")]) -def test_time_localization_by_column(iter_engines: Engine, from_time_tz: tzinfo | None) -> None: +def test_time_localization_by_column( + iter_backends: IbisBackend, from_time_tz: tzinfo | None +) -> None: from_schema = get_datetime_with_tz_col_schema( 2018, from_time_tz, @@ -251,36 +243,33 @@ def test_time_localization_by_column(iter_engines: Engine, from_time_tz: tzinfo standard_tz=True, ) df = generate_dataframe_with_tz_col(from_schema) - run_localization_to_column_time_zones(iter_engines, df, from_schema) + run_localization_to_column_time_zones(iter_backends, df, from_schema) # Error tests for TimeZoneLocalizer -def test_time_localizer_to_dst_time_error(iter_engines: Engine) -> None: +def test_time_localizer_to_dst_time_error(iter_backends: IbisBackend) -> None: """Test that TimeZoneLocalizer raises error when to_time_zone is a non standard time zone""" from_schema = get_datetime_schema(2018, None, TimeIntervalType.PERIOD_BEGINNING, "base_table") df = generate_datetime_dataframe(from_schema) to_time_zone = ZoneInfo("US/Mountain") # has DST - metadata = MetaData() - ingest_data(iter_engines, metadata, df, from_schema) + ingest_data(iter_backends, df, from_schema) with pytest.raises( InvalidParameter, match="TimeZoneLocalizer only supports standard time zones" ): - localize_time_zone( - iter_engines, metadata, from_schema, to_time_zone, check_mapped_timestamps=True - ) + localize_time_zone(iter_backends, from_schema, to_time_zone, check_mapped_timestamps=True) -def test_time_localizer_with_tz_aware_config_error(iter_engines: Engine) -> None: +def test_time_localizer_with_tz_aware_config_error(iter_backends: IbisBackend) -> None: """Test that TimeZoneLocalizer raises error when start time is tz-aware""" from_schema = get_datetime_schema( 2018, ZoneInfo("US/Mountain"), TimeIntervalType.PERIOD_BEGINNING, "base_table" ) df = generate_datetime_dataframe(from_schema) error = (InvalidParameter, "Source schema time config start time must be tz-naive") - run_localization_with_error(iter_engines, df, from_schema, error) + run_localization_with_error(iter_backends, df, from_schema, error) -def test_time_localizer_with_wrong_dtype_error(iter_engines: Engine) -> None: +def test_time_localizer_with_wrong_dtype_error(iter_backends: IbisBackend) -> None: """Test that TimeZoneLocalizer raises error when dtype is not TIMESTAMP_NTZ""" from_schema = get_datetime_schema(2018, None, TimeIntervalType.PERIOD_BEGINNING, "base_table") # Manually change dtype to TIMESTAMP_TZ to trigger error @@ -289,21 +278,21 @@ def test_time_localizer_with_wrong_dtype_error(iter_engines: Engine) -> None: ) df = generate_datetime_dataframe(from_schema) error = (InvalidParameter, "Source schema time config dtype must be TIMESTAMP_NTZ") - run_localization_with_error(iter_engines, df, from_schema, error) + run_localization_with_error(iter_backends, df, from_schema, error) -def test_time_localizer_with_datetime_range_with_tz_col_error(iter_engines: Engine) -> None: +def test_time_localizer_with_datetime_range_with_tz_col_error(iter_backends: IbisBackend) -> None: """Test that TimeZoneLocalizer raises error when time config is DatetimeRangeWithTZColumn""" from_schema = get_datetime_with_tz_col_schema( 2018, None, TimeIntervalType.PERIOD_BEGINNING, "base_table", standard_tz=True ) df = generate_dataframe_with_tz_col(from_schema) error = (InvalidParameter, "try using TimeZoneLocalizerByColumn") - run_localization_with_error(iter_engines, df, from_schema, error) + run_localization_with_error(iter_backends, df, from_schema, error) # Error tests for TimeZoneLocalizerByColumn -def test_time_localizer_by_column_to_dst_time_error(iter_engines: Engine) -> None: +def test_time_localizer_by_column_to_dst_time_error(iter_backends: IbisBackend) -> None: """Test that TimeZoneLocalizerByColumn raises error when to_time_zone is a non standard time zone""" from_schema = get_datetime_with_tz_col_schema( 2018, @@ -313,25 +302,22 @@ def test_time_localizer_by_column_to_dst_time_error(iter_engines: Engine) -> Non standard_tz=False, ) df = generate_dataframe_with_tz_col(from_schema) - metadata = MetaData() - ingest_data(iter_engines, metadata, df, from_schema) + ingest_data(iter_backends, df, from_schema) with pytest.raises( InvalidParameter, match="TimeZoneLocalizerByColumn only supports standard time zones" ): - localize_time_zone_by_column( - iter_engines, metadata, from_schema, check_mapped_timestamps=True - ) + localize_time_zone_by_column(iter_backends, from_schema, check_mapped_timestamps=True) -def test_time_localizer_by_column_missing_tz_column_error(iter_engines: Engine) -> None: +def test_time_localizer_by_column_missing_tz_column_error(iter_backends: IbisBackend) -> None: """Test that TimeZoneLocalizerByColumn raises error when time_zone_column is missing for DatetimeRange""" from_schema = get_datetime_schema(2018, None, TimeIntervalType.PERIOD_BEGINNING, "base_table") df = generate_datetime_dataframe(from_schema) error = (MissingValue, "time_zone_column must be provided") - run_localization_by_column_with_error(iter_engines, df, from_schema, error) + run_localization_by_column_with_error(iter_backends, df, from_schema, error) -def test_time_localizer_by_column_wrong_dtype_error(iter_engines: Engine) -> None: +def test_time_localizer_by_column_wrong_dtype_error(iter_backends: IbisBackend) -> None: """Test that TimeZoneLocalizerByColumn raises error when dtype is not TIMESTAMP_NTZ""" from_schema = get_datetime_with_tz_col_schema( 2018, None, TimeIntervalType.PERIOD_BEGINNING, "base_table", standard_tz=True @@ -342,20 +328,20 @@ def test_time_localizer_by_column_wrong_dtype_error(iter_engines: Engine) -> Non ) df = generate_dataframe_with_tz_col(from_schema) error = (InvalidParameter, "Source schema time config dtype must be TIMESTAMP_NTZ") - run_localization_by_column_with_error(iter_engines, df, from_schema, error) + run_localization_by_column_with_error(iter_backends, df, from_schema, error) -def test_time_localizer_by_column_non_standard_tz_error(iter_engines: Engine) -> None: +def test_time_localizer_by_column_non_standard_tz_error(iter_backends: IbisBackend) -> None: """Test that TimeZoneLocalizerByColumn raises error when time zones are not standard""" from_schema = get_datetime_with_tz_col_schema( 2018, None, TimeIntervalType.PERIOD_BEGINNING, "base_table", standard_tz=False ) df = generate_dataframe_with_tz_col(from_schema) error = (InvalidParameter, "is not a standard time zone") - run_localization_by_column_with_error(iter_engines, df, from_schema, error) + run_localization_by_column_with_error(iter_backends, df, from_schema, error) -def test_localize_time_zone_by_column_missing_tz_column_error(iter_engines: Engine) -> None: +def test_localize_time_zone_by_column_missing_tz_column_error(iter_backends: IbisBackend) -> None: """Test that localize_time_zone_by_column raises error when time_zone_column is None for DatetimeRange""" from_schema = get_datetime_schema(2018, None, TimeIntervalType.PERIOD_BEGINNING, "base_table") df = generate_datetime_dataframe(from_schema) @@ -364,5 +350,5 @@ def test_localize_time_zone_by_column_missing_tz_column_error(iter_engines: Engi "time_zone_column must be provided when source schema time config is of type DatetimeRange", ) run_localization_by_column_with_error( - iter_engines, df, from_schema, error, time_zone_column=None + iter_backends, df, from_schema, error, time_zone_column=None ) From 26393c7ce1ca01e68d9269034c5ec5b3d64a58b8 Mon Sep 17 00:00:00 2001 From: Daniel Thom Date: Wed, 8 Apr 2026 18:52:50 -0600 Subject: [PATCH 02/48] Fix review issues: cleanup safety, column ordering, schema uniqueness - Fix create_view_from_parquet to return ObjectType so callers drop the correct object type (SQLite creates a table, not a view) - Handle directory paths in DuckDB create_view_from_parquet for partitioned parquet datasets - Add unique index on schemas table name column to prevent duplicates - Escape single quotes in schema remove_schema to prevent SQL injection - Add dispose() teardown to test fixtures to prevent resource leaks - Use explicit column ordering in DuckDB insert to prevent column mismatch when DataFrame column order differs from table - Clean up schema on failed ingestion rollback in all Store methods Co-Authored-By: Claude Opus 4.6 --- src/chronify/ibis/base.py | 8 ++++++-- src/chronify/ibis/duckdb_backend.py | 23 +++++++++++++++++------ src/chronify/ibis/functions.py | 12 ++++++++---- src/chronify/ibis/spark_backend.py | 6 +++--- src/chronify/ibis/sqlite_backend.py | 6 +++--- src/chronify/schema_manager.py | 9 ++++++++- src/chronify/store.py | 18 +++++++++++++----- src/chronify/time_series_mapper_base.py | 16 ++++++++++------ tests/conftest.py | 6 +++++- 9 files changed, 73 insertions(+), 31 deletions(-) diff --git a/src/chronify/ibis/base.py b/src/chronify/ibis/base.py index fdcd22a..eb81057 100644 --- a/src/chronify/ibis/base.py +++ b/src/chronify/ibis/base.py @@ -99,8 +99,12 @@ def write_parquet( """Write an ibis expression result to a Parquet file.""" @abstractmethod - def create_view_from_parquet(self, path: str, name: str) -> ir.Table: - """Create a view or table backed by a Parquet file.""" + def create_view_from_parquet(self, path: str, name: str) -> tuple[ir.Table, ObjectType]: + """Create a view or table backed by a Parquet file. + + Returns the table expression and the type of object created, since some + backends (e.g., SQLite) must create a table instead of a view. + """ def has_table(self, name: str) -> bool: """Check whether a table or view exists.""" diff --git a/src/chronify/ibis/duckdb_backend.py b/src/chronify/ibis/duckdb_backend.py index 1f18d3b..4004019 100644 --- a/src/chronify/ibis/duckdb_backend.py +++ b/src/chronify/ibis/duckdb_backend.py @@ -7,7 +7,7 @@ import pandas as pd from loguru import logger -from chronify.ibis.base import IbisBackend +from chronify.ibis.base import IbisBackend, ObjectType class DuckDBBackend(IbisBackend): @@ -57,9 +57,15 @@ def table(self, name: str) -> ir.Table: def insert(self, name: str, data: pd.DataFrame) -> None: con = self._connection.con # raw duckdb connection - con.register("__insert_df", data) + target_columns = list(self.table(name).columns) + ordered_data = data.reindex(columns=target_columns) + quoted_columns = ", ".join(f'"{col}"' for col in target_columns) + con.register("__insert_df", ordered_data) try: - con.execute(f"INSERT INTO {name} SELECT * FROM __insert_df") + con.execute( + f"INSERT INTO {name} ({quoted_columns}) " + f"SELECT {quoted_columns} FROM __insert_df" + ) finally: con.unregister("__insert_df") logger.trace("Inserted {} rows into {}", len(data), name) @@ -86,11 +92,16 @@ def write_parquet( df = self._connection.execute(expr) df.to_parquet(path) - def create_view_from_parquet(self, path: str, name: str) -> ir.Table: + def create_view_from_parquet(self, path: str, name: str) -> tuple[ir.Table, ObjectType]: + parquet_path = Path(path) + if parquet_path.is_dir(): + read_path = str(parquet_path / "**" / "*.parquet").replace("\\", "/") + else: + read_path = str(parquet_path).replace("\\", "/") self._connection.raw_sql( - f"CREATE VIEW {name} AS SELECT * FROM read_parquet('{path}')" + f"CREATE VIEW {name} AS SELECT * FROM read_parquet('{read_path}')" ) - return self.table(name) + return self.table(name), ObjectType.VIEW def execute_sql(self, query: str) -> None: logger.trace("execute_sql: {}", query) diff --git a/src/chronify/ibis/functions.py b/src/chronify/ibis/functions.py index 4015ba7..7c11cf4 100644 --- a/src/chronify/ibis/functions.py +++ b/src/chronify/ibis/functions.py @@ -10,7 +10,7 @@ from pandas import DatetimeTZDtype from chronify.exceptions import InvalidOperation, InvalidParameter -from chronify.ibis.base import IbisBackend +from chronify.ibis.base import IbisBackend, ObjectType from chronify.time import TimeDataType from chronify.time_configs import ( DatetimeRange, @@ -99,9 +99,13 @@ def create_view_from_parquet( backend: IbisBackend, filename: Path, view_name: str, -) -> None: - """Create a view from a Parquet file.""" - backend.create_view_from_parquet(str(filename), view_name) +) -> ObjectType: + """Create a view (or table for SQLite) from a Parquet file. + + Returns the ObjectType created so callers can clean up correctly. + """ + _, obj_type = backend.create_view_from_parquet(str(filename), view_name) + return obj_type def _check_one_config_per_datetime_column(configs: Sequence[TimeBaseModel]) -> None: diff --git a/src/chronify/ibis/spark_backend.py b/src/chronify/ibis/spark_backend.py index f59790b..942af92 100644 --- a/src/chronify/ibis/spark_backend.py +++ b/src/chronify/ibis/spark_backend.py @@ -7,7 +7,7 @@ import pandas as pd from loguru import logger -from chronify.ibis.base import IbisBackend +from chronify.ibis.base import IbisBackend, ObjectType class SparkBackend(IbisBackend): @@ -97,10 +97,10 @@ def write_parquet( else: df.to_parquet(path) - def create_view_from_parquet(self, path: str, name: str) -> ir.Table: + def create_view_from_parquet(self, path: str, name: str) -> tuple[ir.Table, ObjectType]: spark_df = self._session.read.parquet(path) spark_df.createOrReplaceTempView(name) - return self.table(name) + return self.table(name), ObjectType.VIEW def execute_sql(self, query: str) -> None: logger.trace("execute_sql: {}", query) diff --git a/src/chronify/ibis/sqlite_backend.py b/src/chronify/ibis/sqlite_backend.py index 2baf344..0424290 100644 --- a/src/chronify/ibis/sqlite_backend.py +++ b/src/chronify/ibis/sqlite_backend.py @@ -8,7 +8,7 @@ import pyarrow as pa from loguru import logger -from chronify.ibis.base import IbisBackend +from chronify.ibis.base import IbisBackend, ObjectType class SQLiteBackend(IbisBackend): @@ -94,10 +94,10 @@ def write_parquet( df = self._connection.execute(expr) df.to_parquet(path) - def create_view_from_parquet(self, path: str, name: str) -> ir.Table: + def create_view_from_parquet(self, path: str, name: str) -> tuple[ir.Table, ObjectType]: # SQLite can't read Parquet natively. Load into a table instead. df = pd.read_parquet(path) - return self.create_table(name, obj=df) + return self.create_table(name, obj=df), ObjectType.TABLE def execute_sql(self, query: str) -> None: logger.trace("execute_sql: {}", query) diff --git a/src/chronify/schema_manager.py b/src/chronify/schema_manager.py index 2cdfba3..c0682da 100644 --- a/src/chronify/schema_manager.py +++ b/src/chronify/schema_manager.py @@ -29,6 +29,10 @@ def _create_schemas_table(self) -> None: schema = ibis.schema({"name": "string", "schema": "string"}) self._backend.create_table(self.SCHEMAS_TABLE, schema=schema) + self._backend.execute_sql( + f"CREATE UNIQUE INDEX idx_{self.SCHEMAS_TABLE}_name " + f"ON {self.SCHEMAS_TABLE} (name)" + ) def add_schema(self, schema: TableSchema) -> None: """Add the schema to the store.""" @@ -52,7 +56,10 @@ def get_schema(self, name: str) -> TableSchema: def remove_schema(self, name: str) -> None: """Remove the schema from the store.""" - self._backend.execute_sql(f"DELETE FROM {self.SCHEMAS_TABLE} WHERE name = '{name}'") + safe_name = name.replace("'", "''") + self._backend.execute_sql( + f"DELETE FROM {self.SCHEMAS_TABLE} WHERE name = '{safe_name}'" + ) self._cache.pop(name, None) def rebuild_cache(self) -> None: diff --git a/src/chronify/store.py b/src/chronify/store.py index c647afb..43f7c7e 100644 --- a/src/chronify/store.py +++ b/src/chronify/store.py @@ -20,7 +20,7 @@ TableNotStored, ) from chronify.csv_io import read_csv -from chronify.ibis import IbisBackend, make_backend +from chronify.ibis import IbisBackend, ObjectType, make_backend from chronify.ibis.functions import ( create_view_from_parquet, read_query, @@ -177,18 +177,24 @@ def create_view_from_parquet( self, path: Path, schema: TableSchema, bypass_checks: bool = False ) -> None: """Load a table into the database from a Parquet file.""" - self._create_view_from_parquet(path, schema) + obj_type = self._create_view_from_parquet(path, schema) try: if not bypass_checks: check_timestamps(self._backend, schema.name, schema) except InvalidTable: - self.drop_view(schema.name) + if obj_type == ObjectType.TABLE: + self._backend.drop_table(schema.name) + else: + self._backend.drop_view(schema.name) raise - def _create_view_from_parquet(self, path: Path | str, schema: TableSchema) -> None: + def _create_view_from_parquet( + self, path: Path | str, schema: TableSchema + ) -> "ObjectType": """Create a view in the database from a Parquet file.""" - create_view_from_parquet(self._backend, to_path(path), schema.name) + obj_type = create_view_from_parquet(self._backend, to_path(path), schema.name) self._schema_mgr.add_schema(schema) + return obj_type def ingest_from_csv( self, @@ -275,6 +281,7 @@ def ingest_pivoted_tables( except Exception: if self._backend.has_table(dst_schema.name): self._backend.drop_table(dst_schema.name) + self._schema_mgr.remove_schema(dst_schema.name) raise return created_table @@ -356,6 +363,7 @@ def ingest_tables( except Exception: if self._backend.has_table(schema.name): self._backend.drop_table(schema.name) + self._schema_mgr.remove_schema(schema.name) raise return created_table diff --git a/src/chronify/time_series_mapper_base.py b/src/chronify/time_series_mapper_base.py index a875966..15ae672 100644 --- a/src/chronify/time_series_mapper_base.py +++ b/src/chronify/time_series_mapper_base.py @@ -5,7 +5,7 @@ import pandas as pd from loguru import logger -from chronify.ibis.base import IbisBackend +from chronify.ibis.base import IbisBackend, ObjectType from chronify.ibis.functions import write_parquet, write_table, create_view_from_parquet from chronify.models import TableSchema, MappingTableSchema from chronify.exceptions import ConflictingInputsError, InvalidOperation @@ -94,7 +94,7 @@ def apply_mapping( mapping_schema.time_configs, if_exists="fail", ) - created_tmp_view = False + created_tmp_obj: Optional[ObjectType] = None try: _apply_mapping( mapping_schema.name, @@ -107,8 +107,9 @@ def apply_mapping( if check_mapped_timestamps: if output_file is not None: output_file = to_path(output_file) - create_view_from_parquet(backend, output_file, to_schema.name) - created_tmp_view = True + created_tmp_obj = create_view_from_parquet( + backend, output_file, to_schema.name + ) try: check_timestamps( backend, @@ -127,8 +128,11 @@ def apply_mapping( finally: if backend.has_table(mapping_schema.name): backend.drop_table(mapping_schema.name) - if created_tmp_view: - backend.drop_view(to_schema.name) + if created_tmp_obj is not None: + if created_tmp_obj == ObjectType.TABLE: + backend.drop_table(to_schema.name) + else: + backend.drop_view(to_schema.name) def _apply_mapping( # noqa: C901 diff --git a/tests/conftest.py b/tests/conftest.py index 059afad..7d4d277 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -24,7 +24,9 @@ def create_duckdb_backend() -> IbisBackend: @pytest.fixture(params=BACKEND_NAMES) def iter_backends(request) -> Generator[IbisBackend, None, None]: """Return an iterable of in-memory backends to test.""" - yield make_backend(request.param) + backend = make_backend(request.param) + yield backend + backend.dispose() @pytest.fixture(params=BACKEND_NAMES) @@ -42,6 +44,7 @@ def iter_stores_by_engine_no_data_ingestion(request) -> Generator[Store, None, N backend = make_backend(request.param) store = Store(backend=backend) yield store + store.dispose() @pytest.fixture(params=BACKEND_NAMES) @@ -50,6 +53,7 @@ def iter_backends_file(request, tmp_path) -> Generator[tuple[IbisBackend, str], file_path = tmp_path / "store.db" backend = make_backend(request.param, database=str(file_path)) yield backend, request.param + backend.dispose() @pytest.fixture(params=BACKEND_NAMES) From 6584bd412c4b59134a54fbec0e11a64737f66eb5 Mon Sep 17 00:00:00 2001 From: Daniel Thom Date: Thu, 9 Apr 2026 10:29:20 -0600 Subject: [PATCH 03/48] Address PR comments --- pyproject.toml | 11 ++++++- src/chronify/ibis/base.py | 12 ++++++-- src/chronify/ibis/duckdb_backend.py | 25 ++++++++++++--- src/chronify/ibis/spark_backend.py | 36 ++++++++++++++++++---- src/chronify/ibis/sqlite_backend.py | 27 +++++++++++++--- src/chronify/models.py | 24 +++------------ src/chronify/schema_manager.py | 16 +++++----- src/chronify/store.py | 48 ++++++++++++++--------------- 8 files changed, 127 insertions(+), 72 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 6b20b06..a77da99 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,7 +40,7 @@ dependencies = [ [project.optional-dependencies] spark = [ "ibis-framework[pyspark]", - "pyspark >= 4.0", + "pyspark == 4.0.0", ] dev = [ @@ -70,6 +70,15 @@ files = [ ] strict = true +[[tool.mypy.overrides]] +module = [ + "ibis", + "ibis.*", + "pyarrow", + "pyarrow.*", +] +ignore_missing_imports = true + [tool.pytest.ini_options] pythonpath = "src" minversion = "6.0" diff --git a/src/chronify/ibis/base.py b/src/chronify/ibis/base.py index eb81057..b3aedae 100644 --- a/src/chronify/ibis/base.py +++ b/src/chronify/ibis/base.py @@ -3,7 +3,7 @@ from abc import ABC, abstractmethod from contextlib import contextmanager from enum import Enum -from typing import Any, Generator +from typing import Any, Generator, cast import ibis import ibis.expr.types as ir @@ -81,6 +81,14 @@ def table(self, name: str) -> ir.Table: def insert(self, name: str, data: pd.DataFrame) -> None: """Insert data into an existing table.""" + @abstractmethod + def delete_rows(self, name: str, values: dict[str, Any]) -> None: + """Delete rows from a table where every column equals its given value. + + Identifiers must be quoted and values must be parameterized to avoid + SQL injection and to handle values containing quote characters. + """ + @abstractmethod def execute(self, expr: ir.Expr) -> pd.DataFrame: """Execute an ibis expression and return a DataFrame.""" @@ -118,7 +126,7 @@ def execute_sql(self, query: str) -> Any: def execute_sql_to_df(self, query: str) -> pd.DataFrame: """Execute a raw SQL query and return a DataFrame.""" logger.trace("execute_sql_to_df: {}", query) - return self.connection.raw_sql(query).fetch_df() + return cast(pd.DataFrame, self.connection.raw_sql(query).fetch_df()) def dispose(self) -> None: """Dispose of the backend connection.""" diff --git a/src/chronify/ibis/duckdb_backend.py b/src/chronify/ibis/duckdb_backend.py index 4004019..c4c7872 100644 --- a/src/chronify/ibis/duckdb_backend.py +++ b/src/chronify/ibis/duckdb_backend.py @@ -1,6 +1,7 @@ """DuckDB backend implementation for Ibis.""" from pathlib import Path +from typing import Any, cast import ibis import ibis.expr.types as ir @@ -60,18 +61,27 @@ def insert(self, name: str, data: pd.DataFrame) -> None: target_columns = list(self.table(name).columns) ordered_data = data.reindex(columns=target_columns) quoted_columns = ", ".join(f'"{col}"' for col in target_columns) + quoted_name = _quote_identifier(name) con.register("__insert_df", ordered_data) try: con.execute( - f"INSERT INTO {name} ({quoted_columns}) " + f"INSERT INTO {quoted_name} ({quoted_columns}) " f"SELECT {quoted_columns} FROM __insert_df" ) finally: con.unregister("__insert_df") logger.trace("Inserted {} rows into {}", len(data), name) + def delete_rows(self, name: str, values: dict[str, Any]) -> None: + con = self._connection.con + quoted_name = _quote_identifier(name) + where = " AND ".join(f"{_quote_identifier(c)} = ?" for c in values) + sql = f"DELETE FROM {quoted_name} WHERE {where}" + con.execute(sql, list(values.values())) + logger.trace("Deleted rows from {} matching {}", name, values) + def execute(self, expr: ir.Expr) -> pd.DataFrame: - return self._connection.execute(expr) + return cast(pd.DataFrame, self._connection.execute(expr)) def sql(self, query: str) -> ir.Table: return self._connection.sql(query) @@ -89,8 +99,7 @@ def write_parquet( f"COPY ({sql}) TO '{path}' (FORMAT PARQUET, PARTITION_BY ({partition_clause}))" ) else: - df = self._connection.execute(expr) - df.to_parquet(path) + expr.to_parquet(path) def create_view_from_parquet(self, path: str, name: str) -> tuple[ir.Table, ObjectType]: parquet_path = Path(path) @@ -110,7 +119,7 @@ def execute_sql(self, query: str) -> None: def execute_sql_to_df(self, query: str) -> pd.DataFrame: logger.trace("execute_sql_to_df: {}", query) result = self._connection.raw_sql(query) - return result.fetch_df() + return cast(pd.DataFrame, result.fetch_df()) def dispose(self) -> None: self._connection.disconnect() @@ -120,3 +129,9 @@ def reconnect(self) -> None: self._connection = ibis.duckdb.connect(self._database) else: logger.warning("Cannot reconnect to an in-memory DuckDB database.") + + +def _quote_identifier(identifier: str) -> str: + """Quote a SQL identifier for DuckDB, escaping embedded double quotes.""" + escaped = identifier.replace('"', '""') + return f'"{escaped}"' diff --git a/src/chronify/ibis/spark_backend.py b/src/chronify/ibis/spark_backend.py index 942af92..a339676 100644 --- a/src/chronify/ibis/spark_backend.py +++ b/src/chronify/ibis/spark_backend.py @@ -1,6 +1,7 @@ """Spark backend implementation for Ibis.""" -from typing import Any +import uuid +from typing import Any, cast import ibis import ibis.expr.types as ir @@ -65,7 +66,7 @@ def drop_view(self, name: str) -> None: self._connection.drop_view(name, force=True) def list_tables(self) -> list[str]: - return self._connection.list_tables() + return cast(list[str], self._connection.list_tables()) def table(self, name: str) -> ir.Table: return self._connection.table(name) @@ -74,12 +75,29 @@ def insert(self, name: str, data: pd.DataFrame) -> None: # Spark doesn't support INSERT directly -- create a temp view and insert via SQL data = self._prepare_data_for_spark(data) spark_df = self._session.createDataFrame(data) - spark_df.createOrReplaceTempView("__insert_tmp") - self._session.sql(f"INSERT INTO {name} SELECT * FROM __insert_tmp") + tmp_view = f"__insert_tmp_{uuid.uuid4().hex}" + spark_df.createOrReplaceTempView(tmp_view) + quoted_name = _quote_identifier(name) + try: + self._session.sql(f"INSERT INTO {quoted_name} SELECT * FROM {tmp_view}") + finally: + self._session.catalog.dropTempView(tmp_view) logger.trace("Inserted {} rows into {}", len(data), name) + def delete_rows(self, name: str, values: dict[str, Any]) -> None: + # Spark 3.4+ supports parameterized SQL via the ``args`` keyword. + quoted_name = _quote_identifier(name) + param_names = [f"p{i}" for i in range(len(values))] + where = " AND ".join( + f"{_quote_identifier(c)} = :{p}" for c, p in zip(values, param_names) + ) + sql = f"DELETE FROM {quoted_name} WHERE {where}" + args = dict(zip(param_names, values.values())) + self._session.sql(sql, args=args) + logger.trace("Deleted rows from {} matching {}", name, values) + def execute(self, expr: ir.Expr) -> pd.DataFrame: - return self._connection.execute(expr) + return cast(pd.DataFrame, self._connection.execute(expr)) def sql(self, query: str) -> ir.Table: return self._connection.sql(query) @@ -108,7 +126,7 @@ def execute_sql(self, query: str) -> None: def execute_sql_to_df(self, query: str) -> pd.DataFrame: logger.trace("execute_sql_to_df: {}", query) - return self._session.sql(query).toPandas() + return cast(pd.DataFrame, self._session.sql(query).toPandas()) def dispose(self) -> None: pass # Don't stop the Spark session -- it may be shared @@ -123,3 +141,9 @@ def _prepare_data_for_spark(df: pd.DataFrame) -> pd.DataFrame: for col in df.select_dtypes(include=["datetime64[ns, UTC]", "datetimetz"]).columns: df[col] = df[col].dt.strftime("%Y-%m-%d %H:%M:%S%z") return df + + +def _quote_identifier(identifier: str) -> str: + """Quote a SQL identifier for Spark SQL, escaping embedded backticks.""" + escaped = identifier.replace("`", "``") + return f"`{escaped}`" diff --git a/src/chronify/ibis/sqlite_backend.py b/src/chronify/ibis/sqlite_backend.py index 0424290..fd1399b 100644 --- a/src/chronify/ibis/sqlite_backend.py +++ b/src/chronify/ibis/sqlite_backend.py @@ -1,6 +1,7 @@ """SQLite backend implementation for Ibis.""" from pathlib import Path +from typing import Any, cast import ibis import ibis.expr.types as ir @@ -54,7 +55,7 @@ def drop_view(self, name: str) -> None: self._connection.drop_view(name, force=True) def list_tables(self) -> list[str]: - return self._connection.list_tables() + return cast(list[str], self._connection.list_tables()) def table(self, name: str) -> ir.Table: return self._connection.table(name) @@ -65,10 +66,11 @@ def insert(self, name: str, data: pd.DataFrame) -> None: table = self._connection.table(name) columns = table.columns placeholders = ", ".join(["?"] * len(columns)) - col_list = ", ".join(columns) - sql = f"INSERT INTO {name} ({col_list}) VALUES ({placeholders})" + col_list = ", ".join(_quote_identifier(c) for c in columns) + quoted_name = _quote_identifier(name) + sql = f"INSERT INTO {quoted_name} ({col_list}) VALUES ({placeholders})" - arrow_table = pa.Table.from_pandas(data) + arrow_table = pa.Table.from_pandas(data.reindex(columns=columns)) cursor = con.cursor() for batch in arrow_table.to_batches(): rows = [tuple(row[col].as_py() for col in range(batch.num_columns)) for row in zip(*[batch.column(i) for i in range(batch.num_columns)])] @@ -76,8 +78,17 @@ def insert(self, name: str, data: pd.DataFrame) -> None: con.commit() logger.trace("Inserted {} rows into {}", len(data), name) + def delete_rows(self, name: str, values: dict[str, Any]) -> None: + con = self._connection.con + quoted_name = _quote_identifier(name) + where = " AND ".join(f"{_quote_identifier(c)} = ?" for c in values) + sql = f"DELETE FROM {quoted_name} WHERE {where}" + con.execute(sql, list(values.values())) + con.commit() + logger.trace("Deleted rows from {} matching {}", name, values) + def execute(self, expr: ir.Expr) -> pd.DataFrame: - return self._connection.execute(expr) + return cast(pd.DataFrame, self._connection.execute(expr)) def sql(self, query: str) -> ir.Table: return self._connection.sql(query) @@ -119,3 +130,9 @@ def dispose(self) -> None: def reconnect(self) -> None: db = self._database if self._database else ":memory:" self._connection = ibis.sqlite.connect(db) + + +def _quote_identifier(identifier: str) -> str: + """Quote a SQL identifier for SQLite, escaping embedded double quotes.""" + escaped = identifier.replace('"', '""') + return f'"{escaped}"' diff --git a/src/chronify/models.py b/src/chronify/models.py index 9d53ac8..a04c7b1 100644 --- a/src/chronify/models.py +++ b/src/chronify/models.py @@ -10,7 +10,7 @@ from chronify.base_models import ChronifyBaseModel from chronify.exceptions import InvalidValue -from chronify.ibis.types import get_ibis_type_from_duckdb, get_duckdb_type_from_ibis +from chronify.ibis.types import get_ibis_type_from_duckdb, get_ibis_type_from_string, get_duckdb_type_from_ibis from chronify.time_configs import TimeConfig @@ -143,18 +143,6 @@ def list_columns(self) -> list[str]: return time_columns -_COLUMN_TYPES: dict[str, type[dt.DataType]] = { - "bool": dt.Boolean, - "datetime": dt.Timestamp, - "float": dt.Float64, - "int": dt.Int64, - "bigint": dt.Int64, - "str": dt.String, -} - -_DB_TYPES = set(_COLUMN_TYPES.values()) - - def get_ibis_type_from_duckdb_pytype(duckdb_type: DuckDBPyType) -> dt.DataType: """Return the ibis type for a duckdb type.""" return get_ibis_type_from_duckdb(str(duckdb_type)) @@ -192,12 +180,10 @@ def fix_data_type(cls, data: dict[str, Any]) -> dict[str, Any]: return data if isinstance(dtype, str): - val = _COLUMN_TYPES.get(dtype) - if val is None: - options = sorted(_COLUMN_TYPES.keys()) - msg = f"{dtype=} must be one of {options}" - raise InvalidValue(msg) - data["dtype"] = val() + try: + data["dtype"] = get_ibis_type_from_string(dtype) + except ValueError as err: + raise InvalidValue(str(err)) from err else: msg = ( f"dtype is an unsupported type: {type(dtype)}. It must be a str or ibis DataType." diff --git a/src/chronify/schema_manager.py b/src/chronify/schema_manager.py index c0682da..ce60c68 100644 --- a/src/chronify/schema_manager.py +++ b/src/chronify/schema_manager.py @@ -3,7 +3,7 @@ import pandas as pd from loguru import logger -from chronify.exceptions import TableNotStored +from chronify.exceptions import InvalidParameter, TableNotStored from chronify.ibis.base import IbisBackend from chronify.models import TableSchema @@ -27,15 +27,16 @@ def __init__(self, backend: IbisBackend) -> None: def _create_schemas_table(self) -> None: import ibis + # Uniqueness of `name` is enforced in `add_schema` rather than via a + # unique index, since Spark SQL does not support CREATE UNIQUE INDEX. schema = ibis.schema({"name": "string", "schema": "string"}) self._backend.create_table(self.SCHEMAS_TABLE, schema=schema) - self._backend.execute_sql( - f"CREATE UNIQUE INDEX idx_{self.SCHEMAS_TABLE}_name " - f"ON {self.SCHEMAS_TABLE} (name)" - ) def add_schema(self, schema: TableSchema) -> None: """Add the schema to the store.""" + if schema.name in self._cache: + msg = f"A schema with name={schema.name!r} is already registered" + raise InvalidParameter(msg) df = pd.DataFrame({"name": [schema.name], "schema": [schema.model_dump_json()]}) self._backend.insert(self.SCHEMAS_TABLE, df) self._cache[schema.name] = schema @@ -56,10 +57,7 @@ def get_schema(self, name: str) -> TableSchema: def remove_schema(self, name: str) -> None: """Remove the schema from the store.""" - safe_name = name.replace("'", "''") - self._backend.execute_sql( - f"DELETE FROM {self.SCHEMAS_TABLE} WHERE name = '{safe_name}'" - ) + self._backend.delete_rows(self.SCHEMAS_TABLE, {"name": name}) self._cache.pop(name, None) def rebuild_cache(self) -> None: diff --git a/src/chronify/store.py b/src/chronify/store.py index 43f7c7e..e6357a5 100644 --- a/src/chronify/store.py +++ b/src/chronify/store.py @@ -136,16 +136,19 @@ def try_get_table(self, name: str) -> ir.Table | None: def backup(self, dst: Path | str, overwrite: bool = False) -> None: """Copy the database to a new location. Not yet supported for in-memory databases.""" - self._backend.dispose() - path = to_path(dst) - check_overwrite(path, overwrite) if self._backend.database is None: msg = "backup is only supported with a database backed by a file" raise InvalidOperation(msg) + path = to_path(dst) + check_overwrite(path, overwrite) src_file = Path(self._backend.database) - shutil.copyfile(src_file, path) + + self._backend.dispose() + try: + shutil.copyfile(src_file, path) + finally: + self._backend.reconnect() logger.info("Copied database to {}", path) - self._backend.reconnect() @property def backend(self) -> IbisBackend: @@ -212,10 +215,11 @@ def ingest_from_csvs( dst_schema: TableSchema, ) -> bool: """Ingest data from multiple CSV files into the table specified by schema.""" + table_existed = self._backend.has_table(dst_schema.name) try: created_table = self._ingest_from_csvs(paths, src_schema, dst_schema) except Exception: - if self._backend.has_table(dst_schema.name): + if not table_existed and self._backend.has_table(dst_schema.name): self._backend.drop_table(dst_schema.name) self._schema_mgr.remove_schema(dst_schema.name) raise @@ -276,10 +280,11 @@ def ingest_pivoted_tables( dst_schema: TableSchema, ) -> bool: """Ingest pivoted data from multiple tables. Unpivot before ingesting.""" + table_existed = self._backend.has_table(dst_schema.name) try: created_table = self._ingest_pivoted_tables(data, src_schema, dst_schema) except Exception: - if self._backend.has_table(dst_schema.name): + if not table_existed and self._backend.has_table(dst_schema.name): self._backend.drop_table(dst_schema.name) self._schema_mgr.remove_schema(dst_schema.name) raise @@ -358,10 +363,11 @@ def ingest_tables( if not data: return created_table + table_existed = self._backend.has_table(schema.name) try: created_table = self._ingest_tables(data, schema, **kwargs) except Exception: - if self._backend.has_table(schema.name): + if not table_existed and self._backend.has_table(schema.name): self._backend.drop_table(schema.name) self._schema_mgr.remove_schema(schema.name) raise @@ -637,24 +643,16 @@ def delete_rows( ) raise InvalidParameter(msg) - # Count rows before delete - where_clauses = [] - for column, value in time_array_id_values.items(): - if isinstance(value, str): - where_clauses.append(f"{column} = '{value}'") - else: - where_clauses.append(f"{column} = {value}") - where_str = " AND ".join(where_clauses) - - count_df = self._backend.execute_sql_to_df( - f"SELECT COUNT(*) as cnt FROM {name} WHERE {where_str}" - ) - num_to_delete = int(count_df.iloc[0, 0]) + # Build the predicate using ibis (safe -- no string interpolation). + table = self._backend.table(name) + predicates = [table[column] == value for column, value in time_array_id_values.items()] + filtered = table.filter(predicates) + num_to_delete = int(filtered.count().execute()) - self._backend.execute_sql(f"DELETE FROM {name} WHERE {where_str}") + self._backend.delete_rows(name, time_array_id_values) if num_to_delete < 1: - msg = f"Failed to delete rows: {where_str} {num_to_delete=}" + msg = f"Failed to delete rows: {time_array_id_values} {num_to_delete=}" raise InvalidParameter(msg) logger.info( @@ -664,8 +662,8 @@ def delete_rows( ) # Check if table is now empty - remaining = self._backend.execute_sql_to_df(f"SELECT COUNT(*) as cnt FROM {name}") - if int(remaining.iloc[0, 0]) == 0: + remaining = int(self._backend.table(name).count().execute()) + if remaining == 0: logger.info("Delete empty table {}", name) self.drop_table(name) From 0eede6f05652780c62a9a387a58e41b97044eb3d Mon Sep 17 00:00:00 2001 From: Daniel Thom Date: Fri, 10 Apr 2026 12:18:53 -0600 Subject: [PATCH 04/48] Fix spark functionality --- src/chronify/ibis/functions.py | 10 ++ src/chronify/ibis/spark_backend.py | 33 +++++-- src/chronify/ibis/sqlite_backend.py | 5 +- src/chronify/models.py | 6 +- src/chronify/store.py | 5 +- src/chronify/time_series_mapper_base.py | 6 +- tests/test_spark_backend.py | 123 ++++++++++++++++++++++++ 7 files changed, 172 insertions(+), 16 deletions(-) create mode 100644 tests/test_spark_backend.py diff --git a/src/chronify/ibis/functions.py b/src/chronify/ibis/functions.py index 7c11cf4..482d33b 100644 --- a/src/chronify/ibis/functions.py +++ b/src/chronify/ibis/functions.py @@ -83,6 +83,7 @@ def write_parquet( output_file: Path, overwrite: bool = False, partition_columns: list[str] | None = None, + config: TimeBaseModel | None = None, ) -> None: """Write query results to a Parquet file.""" check_overwrite(output_file, overwrite) @@ -92,6 +93,15 @@ def write_parquet( else: expr = query + if backend.name == "spark" and isinstance(config, _DATETIME_RANGES): + df = backend.execute(expr) + _convert_spark_output_for_datetime(df, config) + if partition_columns: + df.to_parquet(output_file, partition_cols=partition_columns) + else: + df.to_parquet(output_file) + return + backend.write_parquet(expr, str(output_file), partition_by=partition_columns) diff --git a/src/chronify/ibis/spark_backend.py b/src/chronify/ibis/spark_backend.py index a339676..0fb01de 100644 --- a/src/chronify/ibis/spark_backend.py +++ b/src/chronify/ibis/spark_backend.py @@ -7,7 +7,9 @@ import ibis.expr.types as ir import pandas as pd from loguru import logger +from pandas import DatetimeTZDtype +from chronify.exceptions import InvalidParameter from chronify.ibis.base import IbisBackend, ObjectType @@ -24,6 +26,7 @@ def __init__(self, session: Any = None) -> None: msg = "pyspark is required for SparkBackend. Install with: pip install chronify[spark]" raise ImportError(msg) from e + self._owns_session = session is None if session is None: session = ( SparkSession.builder.master("local") @@ -31,6 +34,7 @@ def __init__(self, session: Any = None) -> None: .config("spark.sql.parquet.outputTimestampType", "TIMESTAMP_MICROS") .getOrCreate() ) + self._validate_session(session) self._session = session self._connection = ibis.pyspark.connect(session) @@ -88,9 +92,7 @@ def delete_rows(self, name: str, values: dict[str, Any]) -> None: # Spark 3.4+ supports parameterized SQL via the ``args`` keyword. quoted_name = _quote_identifier(name) param_names = [f"p{i}" for i in range(len(values))] - where = " AND ".join( - f"{_quote_identifier(c)} = :{p}" for c, p in zip(values, param_names) - ) + where = " AND ".join(f"{_quote_identifier(c)} = :{p}" for c, p in zip(values, param_names)) sql = f"DELETE FROM {quoted_name} WHERE {where}" args = dict(zip(param_names, values.values())) self._session.sql(sql, args=args) @@ -129,19 +131,36 @@ def execute_sql_to_df(self, query: str) -> pd.DataFrame: return cast(pd.DataFrame, self._session.sql(query).toPandas()) def dispose(self) -> None: - pass # Don't stop the Spark session -- it may be shared + if self._owns_session: + self._session.stop() def reconnect(self) -> None: pass # Spark sessions are long-lived @staticmethod def _prepare_data_for_spark(df: pd.DataFrame) -> pd.DataFrame: - """Convert datetime columns to strings to avoid Spark DST issues.""" + """Normalize tz-aware pandas timestamps for Spark ingestion. + + Spark timestamps are timezone-naive and interpreted in the session time + zone. We require UTC sessions, so convert tz-aware columns to tz-naive + UTC timestamps before handing them to Spark. + """ df = df.copy() - for col in df.select_dtypes(include=["datetime64[ns, UTC]", "datetimetz"]).columns: - df[col] = df[col].dt.strftime("%Y-%m-%d %H:%M:%S%z") + for col in df.columns: + if isinstance(df[col].dtype, DatetimeTZDtype): + df[col] = df[col].dt.tz_convert("UTC").dt.tz_localize(None) return df + @staticmethod + def _validate_session(session: Any) -> None: + time_zone = session.conf.get("spark.sql.session.timeZone", None) or "UTC" + if time_zone != "UTC": + msg = ( + "SparkBackend requires spark.sql.session.timeZone=UTC to preserve " + f"timestamp semantics, got {time_zone!r}." + ) + raise InvalidParameter(msg) + def _quote_identifier(identifier: str) -> str: """Quote a SQL identifier for Spark SQL, escaping embedded backticks.""" diff --git a/src/chronify/ibis/sqlite_backend.py b/src/chronify/ibis/sqlite_backend.py index fd1399b..c9ae845 100644 --- a/src/chronify/ibis/sqlite_backend.py +++ b/src/chronify/ibis/sqlite_backend.py @@ -73,7 +73,10 @@ def insert(self, name: str, data: pd.DataFrame) -> None: arrow_table = pa.Table.from_pandas(data.reindex(columns=columns)) cursor = con.cursor() for batch in arrow_table.to_batches(): - rows = [tuple(row[col].as_py() for col in range(batch.num_columns)) for row in zip(*[batch.column(i) for i in range(batch.num_columns)])] + rows = [ + tuple(row[col].as_py() for col in range(batch.num_columns)) + for row in zip(*[batch.column(i) for i in range(batch.num_columns)]) + ] cursor.executemany(sql, rows) con.commit() logger.trace("Inserted {} rows into {}", len(data), name) diff --git a/src/chronify/models.py b/src/chronify/models.py index a04c7b1..9f0fcca 100644 --- a/src/chronify/models.py +++ b/src/chronify/models.py @@ -10,7 +10,11 @@ from chronify.base_models import ChronifyBaseModel from chronify.exceptions import InvalidValue -from chronify.ibis.types import get_ibis_type_from_duckdb, get_ibis_type_from_string, get_duckdb_type_from_ibis +from chronify.ibis.types import ( + get_ibis_type_from_duckdb, + get_ibis_type_from_string, + get_duckdb_type_from_ibis, +) from chronify.time_configs import TimeConfig diff --git a/src/chronify/store.py b/src/chronify/store.py index e6357a5..3a95644 100644 --- a/src/chronify/store.py +++ b/src/chronify/store.py @@ -191,9 +191,7 @@ def create_view_from_parquet( self._backend.drop_view(schema.name) raise - def _create_view_from_parquet( - self, path: Path | str, schema: TableSchema - ) -> "ObjectType": + def _create_view_from_parquet(self, path: Path | str, schema: TableSchema) -> "ObjectType": """Create a view in the database from a Parquet file.""" obj_type = create_view_from_parquet(self._backend, to_path(path), schema.name) self._schema_mgr.add_schema(schema) @@ -605,6 +603,7 @@ def write_table_to_parquet( to_path(file_path), overwrite=overwrite, partition_columns=partition_columns, + config=self._schema_mgr.get_schema(name).time_config, ) logger.info("Wrote table or view to {}", file_path) diff --git a/src/chronify/time_series_mapper_base.py b/src/chronify/time_series_mapper_base.py index 15ae672..97511c0 100644 --- a/src/chronify/time_series_mapper_base.py +++ b/src/chronify/time_series_mapper_base.py @@ -107,9 +107,7 @@ def apply_mapping( if check_mapped_timestamps: if output_file is not None: output_file = to_path(output_file) - created_tmp_obj = create_view_from_parquet( - backend, output_file, to_schema.name - ) + created_tmp_obj = create_view_from_parquet(backend, output_file, to_schema.name) try: check_timestamps( backend, @@ -215,7 +213,7 @@ def _right_col(col: str) -> Any: if output_file is not None: output_file = to_path(output_file) - write_parquet(backend, result, output_file, overwrite=True) + write_parquet(backend, result, output_file, overwrite=True, config=to_schema.time_config) return backend.create_table(to_schema.name, result) diff --git a/tests/test_spark_backend.py b/tests/test_spark_backend.py new file mode 100644 index 0000000..56b3879 --- /dev/null +++ b/tests/test_spark_backend.py @@ -0,0 +1,123 @@ +from datetime import datetime, timedelta +from pathlib import Path +from zoneinfo import ZoneInfo +import os + +import pandas as pd +import pytest + +from chronify.exceptions import InvalidParameter +from chronify.ibis.spark_backend import SparkBackend +from chronify.models import TableSchema +from chronify.store import Store +from chronify.time import TimeIntervalType +from chronify.time_configs import DatetimeRange + + +def _require_java_home() -> None: + if not os.environ.get("JAVA_HOME"): + pytest.skip("Spark tests require JAVA_HOME to be set") + + +@pytest.fixture +def spark_store(tmp_path: Path) -> Store: + _require_java_home() + pyspark = pytest.importorskip("pyspark.sql") + warehouse_dir = tmp_path / "spark-warehouse" + session = ( + pyspark.SparkSession.builder.master("local") + .config("spark.sql.session.timeZone", "UTC") + .config("spark.sql.parquet.outputTimestampType", "TIMESTAMP_MICROS") + .config("spark.sql.warehouse.dir", str(warehouse_dir)) + .getOrCreate() + ) + store = Store(backend=SparkBackend(session=session)) + yield store + store.dispose() + + +def test_spark_round_trip_timestamp_tz_preserves_fractional_seconds(spark_store: Store) -> None: + schema = TableSchema( + name="spark_ts_data", + value_column="value", + time_config=DatetimeRange( + time_column="timestamp", + start=datetime(2020, 1, 1, tzinfo=ZoneInfo("UTC")), + length=2, + resolution=timedelta(hours=1), + interval_type=TimeIntervalType.PERIOD_BEGINNING, + ), + time_array_id_columns=["id"], + ) + df = pd.DataFrame( + { + "id": [1, 1], + "timestamp": pd.to_datetime( + [ + "2020-01-01 00:00:00.123456-05:00", + "2020-01-01 01:00:00.654321-05:00", + ], + utc=True, + ), + "value": [1.0, 2.0], + } + ) + + spark_store.ingest_table(df, schema, skip_time_checks=True) + out = spark_store.read_table(schema.name).sort_values("timestamp").reset_index(drop=True) + + expected = pd.to_datetime( + [ + "2020-01-01 05:00:00.123456+00:00", + "2020-01-01 06:00:00.654321+00:00", + ], + utc=True, + ) + assert list(out["timestamp"]) == list(expected) + + +def test_spark_write_table_to_parquet_preserves_timestamp_type( + spark_store: Store, tmp_path: Path +) -> None: + schema = TableSchema( + name="spark_parquet_data", + value_column="value", + time_config=DatetimeRange( + time_column="timestamp", + start=datetime(2020, 1, 1, tzinfo=ZoneInfo("UTC")), + length=1, + resolution=timedelta(hours=1), + interval_type=TimeIntervalType.PERIOD_BEGINNING, + ), + time_array_id_columns=["id"], + ) + df = pd.DataFrame( + { + "id": [1], + "timestamp": pd.to_datetime(["2020-01-01 00:00:00.123456+00:00"], utc=True), + "value": [1.0], + } + ) + + spark_store.ingest_table(df, schema, skip_time_checks=True) + outfile = tmp_path / "spark_ts.parquet" + spark_store.write_table_to_parquet(schema.name, outfile, overwrite=True) + + out = pd.read_parquet(outfile) + assert pd.api.types.is_datetime64_any_dtype(out["timestamp"]) + assert out["timestamp"].iloc[0] == pd.Timestamp("2020-01-01 00:00:00.123456+00:00") + + +def test_spark_backend_rejects_non_utc_session() -> None: + _require_java_home() + pyspark = pytest.importorskip("pyspark.sql") + session = ( + pyspark.SparkSession.builder.master("local") + .config("spark.sql.session.timeZone", "America/Denver") + .getOrCreate() + ) + try: + with pytest.raises(InvalidParameter, match="spark.sql.session.timeZone=UTC"): + SparkBackend(session=session) + finally: + session.stop() From afd1a8fb14dbe18bd1a5de02a76aeb3d182feb04 Mon Sep 17 00:00:00 2001 From: Daniel Thom Date: Fri, 10 Apr 2026 13:09:18 -0600 Subject: [PATCH 05/48] Code cleanup --- src/chronify/ibis/duckdb_backend.py | 10 +++++++--- src/chronify/ibis/spark_backend.py | 7 ++++++- src/chronify/ibis/sqlite_backend.py | 25 +++++++++++++++++-------- src/chronify/ibis/types.py | 20 ++++++-------------- src/chronify/models.py | 15 --------------- src/chronify/schema_manager.py | 2 +- src/chronify/time_series_mapper_base.py | 11 ++++------- 7 files changed, 41 insertions(+), 49 deletions(-) diff --git a/src/chronify/ibis/duckdb_backend.py b/src/chronify/ibis/duckdb_backend.py index c4c7872..0cca830 100644 --- a/src/chronify/ibis/duckdb_backend.py +++ b/src/chronify/ibis/duckdb_backend.py @@ -93,10 +93,12 @@ def write_parquet( partition_by: list[str] | None = None, ) -> None: if partition_by: - partition_clause = ", ".join(partition_by) + partition_clause = ", ".join(_quote_identifier(c) for c in partition_by) + escaped_path = path.replace("'", "''") sql = self._connection.compile(expr) self._connection.raw_sql( - f"COPY ({sql}) TO '{path}' (FORMAT PARQUET, PARTITION_BY ({partition_clause}))" + f"COPY ({sql}) TO '{escaped_path}' " + f"(FORMAT PARQUET, PARTITION_BY ({partition_clause}))" ) else: expr.to_parquet(path) @@ -107,8 +109,10 @@ def create_view_from_parquet(self, path: str, name: str) -> tuple[ir.Table, Obje read_path = str(parquet_path / "**" / "*.parquet").replace("\\", "/") else: read_path = str(parquet_path).replace("\\", "/") + quoted_name = _quote_identifier(name) + escaped_path = read_path.replace("'", "''") self._connection.raw_sql( - f"CREATE VIEW {name} AS SELECT * FROM read_parquet('{read_path}')" + f"CREATE VIEW {quoted_name} AS SELECT * FROM read_parquet('{escaped_path}')" ) return self.table(name), ObjectType.VIEW diff --git a/src/chronify/ibis/spark_backend.py b/src/chronify/ibis/spark_backend.py index 0fb01de..4b3ca6d 100644 --- a/src/chronify/ibis/spark_backend.py +++ b/src/chronify/ibis/spark_backend.py @@ -77,13 +77,18 @@ def table(self, name: str) -> ir.Table: def insert(self, name: str, data: pd.DataFrame) -> None: # Spark doesn't support INSERT directly -- create a temp view and insert via SQL + target_columns = list(self.table(name).columns) + data = data.reindex(columns=target_columns) data = self._prepare_data_for_spark(data) spark_df = self._session.createDataFrame(data) tmp_view = f"__insert_tmp_{uuid.uuid4().hex}" spark_df.createOrReplaceTempView(tmp_view) quoted_name = _quote_identifier(name) + col_list = ", ".join(_quote_identifier(c) for c in target_columns) try: - self._session.sql(f"INSERT INTO {quoted_name} SELECT * FROM {tmp_view}") + self._session.sql( + f"INSERT INTO {quoted_name} ({col_list}) SELECT {col_list} FROM {tmp_view}" + ) finally: self._session.catalog.dropTempView(tmp_view) logger.trace("Inserted {} rows into {}", len(data), name) diff --git a/src/chronify/ibis/sqlite_backend.py b/src/chronify/ibis/sqlite_backend.py index c9ae845..dcaa3b8 100644 --- a/src/chronify/ibis/sqlite_backend.py +++ b/src/chronify/ibis/sqlite_backend.py @@ -1,17 +1,30 @@ """SQLite backend implementation for Ibis.""" +from datetime import datetime from pathlib import Path from typing import Any, cast import ibis import ibis.expr.types as ir import pandas as pd -import pyarrow as pa from loguru import logger from chronify.ibis.base import IbisBackend, ObjectType +def _adapt_value(v: Any) -> Any: + """Convert a value for SQLite parameterized insertion. + + Converts datetime/Timestamp objects to ISO-format strings to avoid the + Python 3.12+ DeprecationWarning about the default datetime adapter. + """ + if isinstance(v, datetime): + return v.isoformat() + if hasattr(v, "isoformat"): + return v.isoformat() + return v + + class SQLiteBackend(IbisBackend): """Ibis backend for SQLite databases.""" @@ -70,14 +83,10 @@ def insert(self, name: str, data: pd.DataFrame) -> None: quoted_name = _quote_identifier(name) sql = f"INSERT INTO {quoted_name} ({col_list}) VALUES ({placeholders})" - arrow_table = pa.Table.from_pandas(data.reindex(columns=columns)) + ordered = data.reindex(columns=columns) + rows = [tuple(_adapt_value(v) for v in row) for row in ordered.itertuples(index=False)] cursor = con.cursor() - for batch in arrow_table.to_batches(): - rows = [ - tuple(row[col].as_py() for col in range(batch.num_columns)) - for row in zip(*[batch.column(i) for i in range(batch.num_columns)]) - ] - cursor.executemany(sql, rows) + cursor.executemany(sql, rows) con.commit() logger.trace("Inserted {} rows into {}", len(data), name) diff --git a/src/chronify/ibis/types.py b/src/chronify/ibis/types.py index d6b06f7..396bf1b 100644 --- a/src/chronify/ibis/types.py +++ b/src/chronify/ibis/types.py @@ -1,5 +1,6 @@ """Type conversion utilities for Ibis backends.""" +import ibis import ibis.expr.datatypes as dt import pandas as pd import pyarrow as pa @@ -84,21 +85,12 @@ def get_duckdb_type_from_ibis(ibis_type: dt.DataType) -> str: def get_ibis_types_from_dataframe(df: pd.DataFrame) -> dict[str, dt.DataType]: - """Infer Ibis types from a pandas DataFrame's columns.""" - import duckdb + """Infer Ibis types from a pandas DataFrame's columns. - con = duckdb.connect() - rel = con.from_df(df) - types = {} - for name, dtype in zip(rel.columns, rel.types, strict=True): - types[name] = get_ibis_type_from_duckdb(str(dtype)) - con.close() - return types - - -def get_ibis_schema_from_dataframe(df: pd.DataFrame) -> dict[str, dt.DataType]: - """Get an ibis schema dict from a pandas DataFrame.""" - return get_ibis_types_from_dataframe(df) + Note: This uses ibis schema inference and does not require a DuckDB connection. + """ + schema = ibis.Schema.from_pandas(df) + return dict(schema.items()) def pyarrow_to_ibis_type(arrow_type: pa.DataType) -> dt.DataType: diff --git a/src/chronify/models.py b/src/chronify/models.py index 9f0fcca..69333be 100644 --- a/src/chronify/models.py +++ b/src/chronify/models.py @@ -1,17 +1,13 @@ import re from typing import Any, Optional -import duckdb import ibis.expr.datatypes as dt -import pandas as pd -from duckdb.typing import DuckDBPyType from pydantic import Field, field_validator, model_validator from typing_extensions import Annotated from chronify.base_models import ChronifyBaseModel from chronify.exceptions import InvalidValue from chronify.ibis.types import ( - get_ibis_type_from_duckdb, get_ibis_type_from_string, get_duckdb_type_from_ibis, ) @@ -147,22 +143,11 @@ def list_columns(self) -> list[str]: return time_columns -def get_ibis_type_from_duckdb_pytype(duckdb_type: DuckDBPyType) -> dt.DataType: - """Return the ibis type for a duckdb type.""" - return get_ibis_type_from_duckdb(str(duckdb_type)) - - def get_duckdb_type_from_ibis_type(ibis_type: dt.DataType) -> str: """Return the duckdb type string for an ibis type.""" return get_duckdb_type_from_ibis(ibis_type) -def get_duckdb_types_from_pandas(df: pd.DataFrame) -> list[DuckDBPyType]: - """Return a list of DuckDB types from a pandas dataframe.""" - short_df = df.head(1) # noqa: F841 - return duckdb.sql("select * from short_df").dtypes - - class ColumnDType(ChronifyBaseModel): """Defines the dtype of a column.""" diff --git a/src/chronify/schema_manager.py b/src/chronify/schema_manager.py index ce60c68..27fd460 100644 --- a/src/chronify/schema_manager.py +++ b/src/chronify/schema_manager.py @@ -66,7 +66,7 @@ def rebuild_cache(self) -> None: self._rebuild_cache() def _rebuild_cache(self) -> None: - df = self._backend.execute_sql_to_df(f"SELECT * FROM {self.SCHEMAS_TABLE}") + df = self._backend.execute(self._backend.table(self.SCHEMAS_TABLE)) for _, row in df.iterrows(): name = row["name"] schema = TableSchema(**json.loads(row["schema"])) diff --git a/src/chronify/time_series_mapper_base.py b/src/chronify/time_series_mapper_base.py index 97511c0..2f7fb66 100644 --- a/src/chronify/time_series_mapper_base.py +++ b/src/chronify/time_series_mapper_base.py @@ -183,11 +183,8 @@ def _right_col(col: str) -> Any: return joined[col + "_right"] return joined[col] - # Build value expression - val_expr: Any = _left_col(val_col) if val_col not in right_columns else _left_col(val_col) - if val_col in right_columns and val_col in left_columns: - # val_col exists in both; we want the left (source) value - val_expr = _left_col(val_col) + # Build value expression (always from the left/source table) + val_expr: Any = _left_col(val_col) if "factor" in right_columns: val_expr = val_expr * _right_col("factor") @@ -205,11 +202,11 @@ def _right_col(col: str) -> Any: group_exprs = select_exprs.copy() match resampling_operation: case AggregationType.SUM: - select_exprs.append(val_expr.sum().name(val_col)) + agg_expr = val_expr.sum().name(val_col) case _: msg = f"Unsupported {resampling_operation=}" raise InvalidOperation(msg) - result = joined.group_by(group_exprs).aggregate(val_expr.sum().name(val_col)) + result = joined.group_by(group_exprs).aggregate(agg_expr) if output_file is not None: output_file = to_path(output_file) From 0e71e9a7855dc79e47f2af5c8226eacd9d4e1cc0 Mon Sep 17 00:00:00 2001 From: Daniel Thom Date: Fri, 10 Apr 2026 13:18:32 -0600 Subject: [PATCH 06/48] Code cleanup --- src/chronify/ibis/sqlite_backend.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/chronify/ibis/sqlite_backend.py b/src/chronify/ibis/sqlite_backend.py index dcaa3b8..6ea2921 100644 --- a/src/chronify/ibis/sqlite_backend.py +++ b/src/chronify/ibis/sqlite_backend.py @@ -17,7 +17,10 @@ def _adapt_value(v: Any) -> Any: Converts datetime/Timestamp objects to ISO-format strings to avoid the Python 3.12+ DeprecationWarning about the default datetime adapter. + Returns None for pd.NaT and other missing-value sentinels. """ + if v is pd.NaT or v is None: + return None if isinstance(v, datetime): return v.isoformat() if hasattr(v, "isoformat"): From cc44ef327f8689b3b98220503c734a280213f563 Mon Sep 17 00:00:00 2001 From: Daniel Thom Date: Fri, 10 Apr 2026 13:33:43 -0600 Subject: [PATCH 07/48] Code cleanup --- src/chronify/ibis/functions.py | 64 ++++++++++++++++++----- tests/test_mapper_datetime_to_datetime.py | 5 +- 2 files changed, 52 insertions(+), 17 deletions(-) diff --git a/src/chronify/ibis/functions.py b/src/chronify/ibis/functions.py index 482d33b..da4b14c 100644 --- a/src/chronify/ibis/functions.py +++ b/src/chronify/ibis/functions.py @@ -57,6 +57,45 @@ def read_query( return df +def _normalize_timestamps( + df: pd.DataFrame, + configs: Sequence[TimeBaseModel], +) -> pd.DataFrame: + """Normalize datetime columns so their pandas dtype matches the schema config. + + - TIMESTAMP_NTZ + tz-aware input → convert to UTC, then strip timezone + - TIMESTAMP_TZ + tz-naive input → localize as UTC + - matching dtype → no change + + This runs before any backend-specific handling so that all backends receive + consistently typed data. + """ + copied = False + for config in configs: + if not isinstance(config, _DATETIME_RANGES): + continue + col = config.time_column + if col not in df.columns: + continue + if not pd.api.types.is_datetime64_any_dtype(df[col]): + continue + + is_tz_aware = isinstance(df[col].dtype, DatetimeTZDtype) + + if config.dtype == TimeDataType.TIMESTAMP_NTZ and is_tz_aware: + if not copied: + df = df.copy() + copied = True + df[col] = df[col].dt.tz_convert("UTC").dt.tz_localize(None) + elif config.dtype == TimeDataType.TIMESTAMP_TZ and not is_tz_aware: + if not copied: + df = df.copy() + copied = True + df[col] = df[col].dt.tz_localize("UTC") + + return df + + def write_table( backend: IbisBackend, df: pd.DataFrame | pa.Table, @@ -65,6 +104,12 @@ def write_table( if_exists: str = "append", ) -> None: """Write a DataFrame to the database.""" + if isinstance(df, pa.Table): + df = df.to_pandas() + + _check_one_config_per_datetime_column(configs) + df = _normalize_timestamps(df, configs) + match backend.name: case "duckdb": _write_to_duckdb(backend, df, table_name, if_exists) @@ -185,12 +230,10 @@ def _convert_spark_output_for_datetime(df: pd.DataFrame, config: DatetimeRanges) def _write_to_duckdb( backend: IbisBackend, - df: pd.DataFrame | pa.Table, + df: pd.DataFrame, table_name: str, if_exists: str, ) -> None: - if isinstance(df, pa.Table): - df = df.to_pandas() match if_exists: case "append": backend.insert(table_name, df) @@ -206,16 +249,14 @@ def _write_to_duckdb( def _write_to_sqlite( backend: IbisBackend, - df: pd.DataFrame | pa.Table, + df: pd.DataFrame, table_name: str, configs: Sequence[TimeBaseModel], if_exists: str, ) -> None: - _check_one_config_per_datetime_column(configs) - - if isinstance(df, pa.Table): - df = df.to_pandas() - + # SQLite-specific: ensure TZ timestamps are stored as UTC text. + # _normalize_timestamps already ran, so NTZ columns are tz-naive and + # TZ columns are tz-aware UTC. This step converts TZ to UTC for storage. copied = False for config in configs: if isinstance(config, _DATETIME_RANGES): @@ -236,13 +277,10 @@ def _write_to_sqlite( def _write_to_spark( backend: IbisBackend, - df: pd.DataFrame | pa.Table, + df: pd.DataFrame, table_name: str, if_exists: str, ) -> None: - if isinstance(df, pa.Table): - df = df.to_pandas() - match if_exists: case "append": backend.insert(table_name, df) diff --git a/tests/test_mapper_datetime_to_datetime.py b/tests/test_mapper_datetime_to_datetime.py index a46fe60..d3b895e 100644 --- a/tests/test_mapper_datetime_to_datetime.py +++ b/tests/test_mapper_datetime_to_datetime.py @@ -247,8 +247,5 @@ def test_duplicated_configs_in_write_table( df = generate_datetime_dataframe(schema) configs = [schema.time_config, schema.time_config] - if iter_backends.name == "sqlite": - with pytest.raises(InvalidParameter, match="More than one datetime config found"): - write_table(iter_backends, df, schema.name, configs, if_exists="replace") - else: + with pytest.raises(InvalidParameter, match="More than one datetime config found"): write_table(iter_backends, df, schema.name, configs, if_exists="replace") From 177ca7ee271a0c0c58fcda9b7ad1fa20b220c0af Mon Sep 17 00:00:00 2001 From: Daniel Thom Date: Fri, 10 Apr 2026 13:47:07 -0600 Subject: [PATCH 08/48] Code cleanup --- src/chronify/schema_manager.py | 10 ++- src/chronify/store.py | 20 +++++- tests/conftest.py | 40 ++++++++++- tests/test_spark_backend.py | 127 +++++++++++++++++++++++++++++++++ 4 files changed, 194 insertions(+), 3 deletions(-) diff --git a/src/chronify/schema_manager.py b/src/chronify/schema_manager.py index 27fd460..b48ab07 100644 --- a/src/chronify/schema_manager.py +++ b/src/chronify/schema_manager.py @@ -30,7 +30,15 @@ def _create_schemas_table(self) -> None: # Uniqueness of `name` is enforced in `add_schema` rather than via a # unique index, since Spark SQL does not support CREATE UNIQUE INDEX. schema = ibis.schema({"name": "string", "schema": "string"}) - self._backend.create_table(self.SCHEMAS_TABLE, schema=schema) + try: + self._backend.create_table(self.SCHEMAS_TABLE, schema=schema) + except Exception: + # On Spark, a stale warehouse directory can cause + # LOCATION_ALREADY_EXISTS even though list_tables() didn't find + # the table. Drop the stale remnant and retry. + logger.debug("Retrying schemas table creation after dropping stale remnant.") + self._backend.drop_table(self.SCHEMAS_TABLE) + self._backend.create_table(self.SCHEMAS_TABLE, schema=schema) def add_schema(self, schema: TableSchema) -> None: """Add the schema to the store.""" diff --git a/src/chronify/store.py b/src/chronify/store.py index 3a95644..388a6f4 100644 --- a/src/chronify/store.py +++ b/src/chronify/store.py @@ -574,14 +574,32 @@ def write_query_to_parquet( file_path: Path | str, overwrite: bool = False, partition_columns: Optional[list[str]] = None, + name: Optional[str] = None, ) -> None: - """Write the result of a query to a Parquet file.""" + """Write the result of a query to a Parquet file. + + Parameters + ---------- + stmt + SQL query or ibis Table expression. + file_path + Output Parquet file path. + overwrite + Whether to overwrite an existing file. + partition_columns + Optional columns to partition by. + name + Optional table/view name used to look up the time config for + backend-specific timestamp normalization (e.g. Spark). + """ + config = self._schema_mgr.get_schema(name).time_config if name else None write_parquet( self._backend, stmt, to_path(file_path), overwrite=overwrite, partition_columns=partition_columns, + config=config, ) def write_table_to_parquet( diff --git a/tests/conftest.py b/tests/conftest.py index 7d4d277..ab400b5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,5 @@ -from typing import Generator +import os +from typing import Any, Generator from pathlib import Path from tempfile import NamedTemporaryFile import numpy as np @@ -13,6 +14,8 @@ BACKEND_NAMES = ["duckdb", "sqlite"] +_SPARK_AVAILABLE = bool(os.environ.get("JAVA_HOME")) +ALL_BACKEND_NAMES = [*BACKEND_NAMES, "spark"] if _SPARK_AVAILABLE else BACKEND_NAMES @pytest.fixture @@ -21,6 +24,24 @@ def create_duckdb_backend() -> IbisBackend: return make_backend("duckdb") +def _make_backend(name: str, tmp_path: Path | None = None, **kwargs: Any) -> IbisBackend: + """Create a backend, handling Spark's SparkSession requirement.""" + if name == "spark": + from chronify.ibis.spark_backend import SparkBackend + from pyspark.sql import SparkSession + + warehouse_dir = (tmp_path or Path("/tmp")) / "spark-warehouse" # noqa: S108 + session = ( + SparkSession.builder.master("local") + .config("spark.sql.session.timeZone", "UTC") + .config("spark.sql.parquet.outputTimestampType", "TIMESTAMP_MICROS") + .config("spark.sql.warehouse.dir", str(warehouse_dir)) + .getOrCreate() + ) + return SparkBackend(session=session, **kwargs) + return make_backend(name, **kwargs) + + @pytest.fixture(params=BACKEND_NAMES) def iter_backends(request) -> Generator[IbisBackend, None, None]: """Return an iterable of in-memory backends to test.""" @@ -29,6 +50,14 @@ def iter_backends(request) -> Generator[IbisBackend, None, None]: backend.dispose() +@pytest.fixture(params=ALL_BACKEND_NAMES) +def iter_all_backends(request, tmp_path) -> Generator[IbisBackend, None, None]: + """Return an iterable of in-memory backends including Spark when available.""" + backend = _make_backend(request.param, tmp_path=tmp_path) + yield backend + backend.dispose() + + @pytest.fixture(params=BACKEND_NAMES) def iter_stores_by_engine(request) -> Generator[Store, None, None]: """Return an iterable of stores with different backends to test.""" @@ -38,6 +67,15 @@ def iter_stores_by_engine(request) -> Generator[Store, None, None]: store.dispose() +@pytest.fixture(params=ALL_BACKEND_NAMES) +def iter_all_stores(request, tmp_path) -> Generator[Store, None, None]: + """Return an iterable of stores including Spark when available.""" + backend = _make_backend(request.param, tmp_path=tmp_path) + store = Store(backend=backend) + yield store + store.dispose() + + @pytest.fixture(params=BACKEND_NAMES) def iter_stores_by_engine_no_data_ingestion(request) -> Generator[Store, None, None]: """Return an iterable of stores with different backends to test.""" diff --git a/tests/test_spark_backend.py b/tests/test_spark_backend.py index 56b3879..a7eebbe 100644 --- a/tests/test_spark_backend.py +++ b/tests/test_spark_backend.py @@ -108,6 +108,133 @@ def test_spark_write_table_to_parquet_preserves_timestamp_type( assert out["timestamp"].iloc[0] == pd.Timestamp("2020-01-01 00:00:00.123456+00:00") +def test_spark_write_query_to_parquet_preserves_timestamp_tz( + spark_store: Store, tmp_path: Path +) -> None: + """write_query_to_parquet must preserve tz semantics when name is supplied.""" + schema = TableSchema( + name="spark_query_tz", + value_column="value", + time_config=DatetimeRange( + time_column="timestamp", + start=datetime(2020, 1, 1, tzinfo=ZoneInfo("UTC")), + length=2, + resolution=timedelta(hours=1), + interval_type=TimeIntervalType.PERIOD_BEGINNING, + ), + time_array_id_columns=["id"], + ) + df = pd.DataFrame( + { + "id": [1, 1], + "timestamp": pd.to_datetime( + ["2020-01-01 00:00:00+00:00", "2020-01-01 01:00:00+00:00"], utc=True + ), + "value": [1.0, 2.0], + } + ) + spark_store.ingest_table(df, schema, skip_time_checks=True) + + outfile = tmp_path / "query_tz.parquet" + expr = spark_store.get_table(schema.name) + spark_store.write_query_to_parquet(expr, outfile, overwrite=True, name=schema.name) + + out = pd.read_parquet(outfile) + assert pd.api.types.is_datetime64_any_dtype(out["timestamp"]) + assert out["timestamp"].iloc[0] == pd.Timestamp("2020-01-01 00:00:00+00:00") + + +def test_spark_ingest_normalizes_tz_aware_to_ntz(spark_store: Store) -> None: + """A TIMESTAMP_NTZ schema with tz-aware input should strip tz on all backends.""" + schema = TableSchema( + name="spark_ntz", + value_column="value", + time_config=DatetimeRange( + time_column="timestamp", + start=datetime(2020, 1, 1), + length=2, + resolution=timedelta(hours=1), + interval_type=TimeIntervalType.PERIOD_BEGINNING, + ), + time_array_id_columns=["id"], + ) + df = pd.DataFrame( + { + "id": [1, 1], + "timestamp": pd.to_datetime( + ["2020-01-01 00:00:00+00:00", "2020-01-01 01:00:00+00:00"], utc=True + ), + "value": [1.0, 2.0], + } + ) + spark_store.ingest_table(df, schema, skip_time_checks=True) + out = spark_store.read_table(schema.name) + # Should be tz-naive after round-trip + assert not isinstance(out["timestamp"].dtype, pd.DatetimeTZDtype) + assert out["timestamp"].iloc[0] == pd.Timestamp("2020-01-01 00:00:00") + + +def test_spark_ingest_normalizes_tz_naive_to_tz(spark_store: Store) -> None: + """A TIMESTAMP_TZ schema with tz-naive input should localize to UTC on all backends.""" + schema = TableSchema( + name="spark_tz", + value_column="value", + time_config=DatetimeRange( + time_column="timestamp", + start=datetime(2020, 1, 1, tzinfo=ZoneInfo("UTC")), + length=2, + resolution=timedelta(hours=1), + interval_type=TimeIntervalType.PERIOD_BEGINNING, + ), + time_array_id_columns=["id"], + ) + df = pd.DataFrame( + { + "id": [1, 1], + "timestamp": pd.to_datetime(["2020-01-01 00:00:00", "2020-01-01 01:00:00"]), + "value": [1.0, 2.0], + } + ) + spark_store.ingest_table(df, schema, skip_time_checks=True) + out = spark_store.read_table(schema.name) + assert isinstance(out["timestamp"].dtype, pd.DatetimeTZDtype) + assert out["timestamp"].iloc[0] == pd.Timestamp("2020-01-01 00:00:00+00:00") + + +def test_spark_time_zone_conversion(spark_store: Store) -> None: + """Time zone conversion should work on Spark the same as other backends.""" + schema = TableSchema( + name="spark_tzconv", + value_column="value", + time_config=DatetimeRange( + time_column="timestamp", + start=datetime(2020, 1, 1, tzinfo=ZoneInfo("UTC")), + length=24, + resolution=timedelta(hours=1), + interval_type=TimeIntervalType.PERIOD_BEGINNING, + ), + time_array_id_columns=["id"], + ) + from chronify.datetime_range_generator import DatetimeRangeGenerator + + timestamps = DatetimeRangeGenerator(schema.time_config).list_timestamps() + df = pd.DataFrame( + { + "id": [1] * len(timestamps), + "timestamp": pd.to_datetime(timestamps), + "value": range(len(timestamps)), + } + ) + spark_store.ingest_table(df, schema, skip_time_checks=True) + + to_tz = ZoneInfo("US/Eastern") + dst_schema = spark_store.convert_time_zone(schema.name, to_tz) + out = spark_store.read_table(dst_schema.name) + expected = df["timestamp"].dt.tz_convert(to_tz).dt.tz_localize(None) + out_sorted = out.sort_values("timestamp").reset_index(drop=True) + assert list(out_sorted["timestamp"]) == list(expected) + + def test_spark_backend_rejects_non_utc_session() -> None: _require_java_home() pyspark = pytest.importorskip("pyspark.sql") From a6f710bd9dc3ea05f05b7f49bb57957c6570ff6b Mon Sep 17 00:00:00 2001 From: Daniel Thom Date: Fri, 10 Apr 2026 13:58:54 -0600 Subject: [PATCH 09/48] Fix CI --- .github/workflows/ci.yml | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index f5c7d64..c571767 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -30,14 +30,9 @@ jobs: run: | python -m pip install --upgrade pip python -m pip install ".[dev,spark]" - wget https://dlcdn.apache.org/spark/spark-4.0.1/spark-4.0.1-bin-hadoop3.tgz - tar -xzf spark-4.0.1-bin-hadoop3.tgz - export SPARK_HOME=$(pwd)/spark-4.0.1-bin-hadoop3 - export PATH=$SPARK_HOME/sbin:$PATH - start-thriftserver.sh - name: Run pytest with coverage run: | - CHRONIFY_HIVE_URL=hive://localhost:10000/default pytest -v --cov --cov-report=xml + pytest -v --cov --cov-report=xml - name: codecov uses: codecov/codecov-action@v4.2.0 if: ${{ matrix.os == env.DEFAULT_OS && matrix.python-version == env.DEFAULT_PYTHON }} From 184317bbbe27252f40aff2236770b99543e3cd5b Mon Sep 17 00:00:00 2001 From: Daniel Thom Date: Fri, 10 Apr 2026 14:17:28 -0600 Subject: [PATCH 10/48] Address PR comments --- pyproject.toml | 2 +- src/chronify/ibis/spark_backend.py | 1 + src/chronify/time_series_checker.py | 12 ++--- src/chronify/time_series_mapper_base.py | 4 +- ...apper_column_representative_to_datetime.py | 45 ++++++++++--------- tests/conftest.py | 4 +- tests/test_spark_backend.py | 1 + 7 files changed, 38 insertions(+), 31 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index a77da99..b1d3638 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,7 +40,7 @@ dependencies = [ [project.optional-dependencies] spark = [ "ibis-framework[pyspark]", - "pyspark == 4.0.0", + "pyspark >= 4.0", ] dev = [ diff --git a/src/chronify/ibis/spark_backend.py b/src/chronify/ibis/spark_backend.py index 4b3ca6d..62bec63 100644 --- a/src/chronify/ibis/spark_backend.py +++ b/src/chronify/ibis/spark_backend.py @@ -136,6 +136,7 @@ def execute_sql_to_df(self, query: str) -> pd.DataFrame: return cast(pd.DataFrame, self._session.sql(query).toPandas()) def dispose(self) -> None: + self._connection.disconnect() if self._owns_session: self._session.stop() diff --git a/src/chronify/time_series_checker.py b/src/chronify/time_series_checker.py index e00f233..8bd444c 100644 --- a/src/chronify/time_series_checker.py +++ b/src/chronify/time_series_checker.py @@ -121,8 +121,8 @@ def _check_null_consistency(self) -> None: all_are_null = " AND ".join((f"{x} IS NULL" for x in time_columns)) any_are_null = " OR ".join((f"{x} IS NULL" for x in time_columns)) - query_all = f"SELECT COUNT(*) FROM {self._schema.name} WHERE {all_are_null}" - query_any = f"SELECT COUNT(*) FROM {self._schema.name} WHERE {any_are_null}" + query_all = f"SELECT COUNT(*) FROM {self._table_name} WHERE {all_are_null}" + query_any = f"SELECT COUNT(*) FROM {self._table_name} WHERE {any_are_null}" df_all = self._backend.execute_sql_to_df(query_all) df_any = self._backend.execute_sql_to_df(query_any) count_all = df_all.iloc[0, 0] @@ -154,7 +154,7 @@ def _check_expected_timestamps_by_time_array(self, count: int) -> None: query = f""" WITH distinct_time_values_by_array AS ( SELECT DISTINCT {time_cols} - FROM {self._schema.name} + FROM {self._table_name} WHERE {where_clause} ), t1 AS ( @@ -163,7 +163,7 @@ def _check_expected_timestamps_by_time_array(self, count: int) -> None: ), t2 AS ( SELECT COUNT(*) AS count_by_ta - FROM {self._schema.name} + FROM {self._table_name} WHERE {where_clause} ) SELECT @@ -176,7 +176,7 @@ def _check_expected_timestamps_by_time_array(self, count: int) -> None: query = f""" WITH distinct_time_values_by_array AS ( SELECT DISTINCT {id_cols}, {time_cols} - FROM {self._schema.name} + FROM {self._table_name} WHERE {where_clause} ), t1 AS ( @@ -186,7 +186,7 @@ def _check_expected_timestamps_by_time_array(self, count: int) -> None: ), t2 AS ( SELECT {id_cols}, COUNT(*) AS count_by_ta - FROM {self._schema.name} + FROM {self._table_name} WHERE {where_clause} GROUP BY {id_cols} ) diff --git a/src/chronify/time_series_mapper_base.py b/src/chronify/time_series_mapper_base.py index 2f7fb66..b8ff946 100644 --- a/src/chronify/time_series_mapper_base.py +++ b/src/chronify/time_series_mapper_base.py @@ -158,7 +158,9 @@ def _apply_mapping( # noqa: C901 # Build join predicates from_keys = [x for x in right_columns if x.startswith("from_")] keys = [x.removeprefix("from_") for x in from_keys] - assert set(keys).issubset(set(left_columns)), f"Keys {keys} not in table={from_schema.name}" + if not set(keys).issubset(set(left_columns)): + msg = f"Mapping keys {keys} not found in source table {from_schema.name}" + raise ConflictingInputsError(msg) predicates = [] for k in keys: left_col = left[k] diff --git a/src/chronify/time_series_mapper_column_representative_to_datetime.py b/src/chronify/time_series_mapper_column_representative_to_datetime.py index cecde15..6f4debb 100644 --- a/src/chronify/time_series_mapper_column_representative_to_datetime.py +++ b/src/chronify/time_series_mapper_column_representative_to_datetime.py @@ -1,5 +1,6 @@ from typing import Optional, Generator import re +import uuid from pathlib import Path import pandas as pd from datetime import datetime @@ -125,7 +126,9 @@ def _validate_mdh_time_config(self) -> None: def _intermediate_mapping_ymdp_to_ymdh(self) -> TableSchema: """Convert ymdp to ymdh for intermediate mapping.""" - mapping_table_name = "intermediate_ymdp_to_ymdh" + uid = uuid.uuid4().hex[:8] + mapping_table_name = f"_int_ymdp_to_ymdh_{uid}" + intermediate_ymdh_table_name = f"_int_ymdh_{uid}" period_col = self._from_time_config.hour_columns[0] # Get distinct periods @@ -141,28 +144,28 @@ def _intermediate_mapping_ymdp_to_ymdh(self) -> TableSchema: if_exists="fail", ) - # Build the join query using ibis - ymdp_table = self._backend.table(self._from_schema.name) - mapping_table = self._backend.table(mapping_table_name) + try: + # Build the join query using ibis + ymdp_table = self._backend.table(self._from_schema.name) + mapping_table = self._backend.table(mapping_table_name) - # Select all columns from ymdp except the period column, plus the hour column from mapping - ymdp_cols = [c for c in ymdp_table.columns if c != period_col] - select_exprs = [ymdp_table[c] for c in ymdp_cols] + [mapping_table["hour"]] + # Select all columns from ymdp except the period column, plus hour from mapping + ymdp_cols = [c for c in ymdp_table.columns if c != period_col] + select_exprs = [ymdp_table[c] for c in ymdp_cols] + [mapping_table["hour"]] - joined = ymdp_table.join( - mapping_table, ymdp_table[period_col] == mapping_table["from_period"] - ) - result = joined.select(select_exprs) - - intermediate_ymdh_table_name = "intermediate_Ymdh" - self._backend.create_table(intermediate_ymdh_table_name, result) - - # Clean up mapping table - self._backend.drop_table(mapping_table_name) - - assert isinstance( - self._from_time_config, YearMonthDayPeriodTimeNTZ - ), "Intermediate mapping only valid for YearMonthDayPeriodNTZ time config" + joined = ymdp_table.join( + mapping_table, ymdp_table[period_col] == mapping_table["from_period"] + ) + result = joined.select(select_exprs) + self._backend.create_table(intermediate_ymdh_table_name, result) + finally: + # Always clean up the mapping table + if self._backend.has_table(mapping_table_name): + self._backend.drop_table(mapping_table_name) + + if not isinstance(self._from_time_config, YearMonthDayPeriodTimeNTZ): + msg = "Intermediate mapping only valid for YearMonthDayPeriodNTZ time config" + raise InvalidParameter(msg) return self._create_intermediate_ymdh_schema( intermediate_ymdh_table_name, self._from_schema, self._from_time_config ) diff --git a/tests/conftest.py b/tests/conftest.py index ab400b5..4485a28 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -27,12 +27,12 @@ def create_duckdb_backend() -> IbisBackend: def _make_backend(name: str, tmp_path: Path | None = None, **kwargs: Any) -> IbisBackend: """Create a backend, handling Spark's SparkSession requirement.""" if name == "spark": + pyspark = pytest.importorskip("pyspark.sql") from chronify.ibis.spark_backend import SparkBackend - from pyspark.sql import SparkSession warehouse_dir = (tmp_path or Path("/tmp")) / "spark-warehouse" # noqa: S108 session = ( - SparkSession.builder.master("local") + pyspark.SparkSession.builder.master("local") .config("spark.sql.session.timeZone", "UTC") .config("spark.sql.parquet.outputTimestampType", "TIMESTAMP_MICROS") .config("spark.sql.warehouse.dir", str(warehouse_dir)) diff --git a/tests/test_spark_backend.py b/tests/test_spark_backend.py index a7eebbe..fe33efb 100644 --- a/tests/test_spark_backend.py +++ b/tests/test_spark_backend.py @@ -34,6 +34,7 @@ def spark_store(tmp_path: Path) -> Store: store = Store(backend=SparkBackend(session=session)) yield store store.dispose() + session.stop() def test_spark_round_trip_timestamp_tz_preserves_fractional_seconds(spark_store: Store) -> None: From 874ebd0a2e8e528ac58a7bfae584de856832394a Mon Sep 17 00:00:00 2001 From: Daniel Thom Date: Fri, 10 Apr 2026 14:26:49 -0600 Subject: [PATCH 11/48] Add test coverage --- tests/test_time_series_checker.py | 32 +++++++++--------- tests/test_time_zone_converter.py | 12 +++---- tests/test_time_zone_localizer.py | 54 +++++++++++++++++-------------- 3 files changed, 52 insertions(+), 46 deletions(-) diff --git a/tests/test_time_series_checker.py b/tests/test_time_series_checker.py index f48d95d..5beef1e 100644 --- a/tests/test_time_series_checker.py +++ b/tests/test_time_series_checker.py @@ -14,44 +14,44 @@ from chronify.time_series_checker import check_timestamps -def test_valid_datetimes_with_tz(iter_backends: IbisBackend) -> None: +def test_valid_datetimes_with_tz(iter_all_backends: IbisBackend) -> None: """Valid timestamps with time zones.""" - _run_test(iter_backends, *_get_inputs_for_valid_datetimes_with_tz()) + _run_test(iter_all_backends, *_get_inputs_for_valid_datetimes_with_tz()) -def test_valid_datetimes_without_tz(iter_backends: IbisBackend) -> None: +def test_valid_datetimes_without_tz(iter_all_backends: IbisBackend) -> None: """Valid timestamps without time zones.""" - _run_test(iter_backends, *_get_inputs_for_valid_datetimes_without_tz()) + _run_test(iter_all_backends, *_get_inputs_for_valid_datetimes_without_tz()) -def test_invalid_datetimes(iter_backends: IbisBackend) -> None: +def test_invalid_datetimes(iter_all_backends: IbisBackend) -> None: """Timestamps do not match the schema.""" - _run_test(iter_backends, *_get_inputs_for_incorrect_datetimes()) + _run_test(iter_all_backends, *_get_inputs_for_incorrect_datetimes()) -def test_invalid_datetime_length(iter_backends: IbisBackend) -> None: +def test_invalid_datetime_length(iter_all_backends: IbisBackend) -> None: """Timestamps do not match the schema.""" - _run_test(iter_backends, *_get_inputs_for_incorrect_datetime_length()) + _run_test(iter_all_backends, *_get_inputs_for_incorrect_datetime_length()) -def test_mismatched_time_array_lengths(iter_backends: IbisBackend) -> None: +def test_mismatched_time_array_lengths(iter_all_backends: IbisBackend) -> None: """Some time arrays have different lengths.""" - _run_test(iter_backends, *_get_inputs_for_mismatched_time_array_lengths()) + _run_test(iter_all_backends, *_get_inputs_for_mismatched_time_array_lengths()) -def test_incorrect_lengths(iter_backends: IbisBackend) -> None: +def test_incorrect_lengths(iter_all_backends: IbisBackend) -> None: """All time arrays are consistent but have the wrong length.""" - _run_test(iter_backends, *_get_inputs_for_incorrect_lengths()) + _run_test(iter_all_backends, *_get_inputs_for_incorrect_lengths()) -def test_incorrect_time_arrays(iter_backends: IbisBackend) -> None: +def test_incorrect_time_arrays(iter_all_backends: IbisBackend) -> None: """The time arrays form a complete set but are individually incorrect.""" - _run_test(iter_backends, *_get_inputs_for_incorrect_time_arrays()) + _run_test(iter_all_backends, *_get_inputs_for_incorrect_time_arrays()) -def test_incorrect_time_arrays_with_duplicates(iter_backends: IbisBackend) -> None: +def test_incorrect_time_arrays_with_duplicates(iter_all_backends: IbisBackend) -> None: """The time arrays form a complete set but are individually incorrect.""" - _run_test(iter_backends, *_get_inputs_for_incorrect_time_arrays_with_duplicates()) + _run_test(iter_all_backends, *_get_inputs_for_incorrect_time_arrays_with_duplicates()) def _run_test( diff --git a/tests/test_time_zone_converter.py b/tests/test_time_zone_converter.py index e76e37b..96d322e 100644 --- a/tests/test_time_zone_converter.py +++ b/tests/test_time_zone_converter.py @@ -172,27 +172,27 @@ def run_conversion_with_error( tzc2.convert_time_zone(check_mapped_timestamps=True) -def test_src_table_no_time_zone(iter_backends: IbisBackend) -> None: +def test_src_table_no_time_zone(iter_all_backends: IbisBackend) -> None: from_schema = get_datetime_schema(2018, None, TimeIntervalType.PERIOD_BEGINNING, "base_table") df = generate_datetime_dataframe(from_schema) error = (InvalidParameter, "Source schema time config start time must be timezone-aware") - run_conversion_with_error(iter_backends, df, from_schema, False, error) + run_conversion_with_error(iter_all_backends, df, from_schema, False, error) @pytest.mark.parametrize( "to_time_zone", [None, ZoneInfo("US/Central"), ZoneInfo("America/Los_Angeles")] ) -def test_time_conversion(iter_backends: IbisBackend, to_time_zone: tzinfo | None) -> None: +def test_time_conversion(iter_all_backends: IbisBackend, to_time_zone: tzinfo | None) -> None: from_schema = get_datetime_schema( 2018, ZoneInfo("US/Mountain"), TimeIntervalType.PERIOD_BEGINNING, "base_table" ) df = generate_datetime_dataframe(from_schema) - run_conversion(iter_backends, df, from_schema, to_time_zone) + run_conversion(iter_all_backends, df, from_schema, to_time_zone) @pytest.mark.parametrize("wrap_time_allowed", [False, True]) def test_time_conversion_to_column_time_zones( - iter_backends: IbisBackend, wrap_time_allowed: bool + iter_all_backends: IbisBackend, wrap_time_allowed: bool ) -> None: from_schema = get_datetime_schema( 2018, @@ -202,4 +202,4 @@ def test_time_conversion_to_column_time_zones( has_tz_col=True, ) df = generate_dataframe_with_tz_col(from_schema) - run_conversion_to_column_time_zones(iter_backends, df, from_schema, wrap_time_allowed) + run_conversion_to_column_time_zones(iter_all_backends, df, from_schema, wrap_time_allowed) diff --git a/tests/test_time_zone_localizer.py b/tests/test_time_zone_localizer.py index 755f83b..843ff8f 100644 --- a/tests/test_time_zone_localizer.py +++ b/tests/test_time_zone_localizer.py @@ -225,15 +225,15 @@ def run_localization_by_column_with_error( @pytest.mark.parametrize("to_time_zone", [None, ZoneInfo("Etc/GMT+5")]) -def test_time_localization(iter_backends: IbisBackend, to_time_zone: tzinfo | None) -> None: +def test_time_localization(iter_all_backends: IbisBackend, to_time_zone: tzinfo | None) -> None: from_schema = get_datetime_schema(2018, None, TimeIntervalType.PERIOD_BEGINNING, "base_table") df = generate_datetime_dataframe(from_schema) - run_localization(iter_backends, df, from_schema, to_time_zone) + run_localization(iter_all_backends, df, from_schema, to_time_zone) @pytest.mark.parametrize("from_time_tz", [None, ZoneInfo("US/Mountain"), ZoneInfo("MST")]) def test_time_localization_by_column( - iter_backends: IbisBackend, from_time_tz: tzinfo | None + iter_all_backends: IbisBackend, from_time_tz: tzinfo | None ) -> None: from_schema = get_datetime_with_tz_col_schema( 2018, @@ -243,33 +243,35 @@ def test_time_localization_by_column( standard_tz=True, ) df = generate_dataframe_with_tz_col(from_schema) - run_localization_to_column_time_zones(iter_backends, df, from_schema) + run_localization_to_column_time_zones(iter_all_backends, df, from_schema) # Error tests for TimeZoneLocalizer -def test_time_localizer_to_dst_time_error(iter_backends: IbisBackend) -> None: +def test_time_localizer_to_dst_time_error(iter_all_backends: IbisBackend) -> None: """Test that TimeZoneLocalizer raises error when to_time_zone is a non standard time zone""" from_schema = get_datetime_schema(2018, None, TimeIntervalType.PERIOD_BEGINNING, "base_table") df = generate_datetime_dataframe(from_schema) to_time_zone = ZoneInfo("US/Mountain") # has DST - ingest_data(iter_backends, df, from_schema) + ingest_data(iter_all_backends, df, from_schema) with pytest.raises( InvalidParameter, match="TimeZoneLocalizer only supports standard time zones" ): - localize_time_zone(iter_backends, from_schema, to_time_zone, check_mapped_timestamps=True) + localize_time_zone( + iter_all_backends, from_schema, to_time_zone, check_mapped_timestamps=True + ) -def test_time_localizer_with_tz_aware_config_error(iter_backends: IbisBackend) -> None: +def test_time_localizer_with_tz_aware_config_error(iter_all_backends: IbisBackend) -> None: """Test that TimeZoneLocalizer raises error when start time is tz-aware""" from_schema = get_datetime_schema( 2018, ZoneInfo("US/Mountain"), TimeIntervalType.PERIOD_BEGINNING, "base_table" ) df = generate_datetime_dataframe(from_schema) error = (InvalidParameter, "Source schema time config start time must be tz-naive") - run_localization_with_error(iter_backends, df, from_schema, error) + run_localization_with_error(iter_all_backends, df, from_schema, error) -def test_time_localizer_with_wrong_dtype_error(iter_backends: IbisBackend) -> None: +def test_time_localizer_with_wrong_dtype_error(iter_all_backends: IbisBackend) -> None: """Test that TimeZoneLocalizer raises error when dtype is not TIMESTAMP_NTZ""" from_schema = get_datetime_schema(2018, None, TimeIntervalType.PERIOD_BEGINNING, "base_table") # Manually change dtype to TIMESTAMP_TZ to trigger error @@ -278,21 +280,23 @@ def test_time_localizer_with_wrong_dtype_error(iter_backends: IbisBackend) -> No ) df = generate_datetime_dataframe(from_schema) error = (InvalidParameter, "Source schema time config dtype must be TIMESTAMP_NTZ") - run_localization_with_error(iter_backends, df, from_schema, error) + run_localization_with_error(iter_all_backends, df, from_schema, error) -def test_time_localizer_with_datetime_range_with_tz_col_error(iter_backends: IbisBackend) -> None: +def test_time_localizer_with_datetime_range_with_tz_col_error( + iter_all_backends: IbisBackend, +) -> None: """Test that TimeZoneLocalizer raises error when time config is DatetimeRangeWithTZColumn""" from_schema = get_datetime_with_tz_col_schema( 2018, None, TimeIntervalType.PERIOD_BEGINNING, "base_table", standard_tz=True ) df = generate_dataframe_with_tz_col(from_schema) error = (InvalidParameter, "try using TimeZoneLocalizerByColumn") - run_localization_with_error(iter_backends, df, from_schema, error) + run_localization_with_error(iter_all_backends, df, from_schema, error) # Error tests for TimeZoneLocalizerByColumn -def test_time_localizer_by_column_to_dst_time_error(iter_backends: IbisBackend) -> None: +def test_time_localizer_by_column_to_dst_time_error(iter_all_backends: IbisBackend) -> None: """Test that TimeZoneLocalizerByColumn raises error when to_time_zone is a non standard time zone""" from_schema = get_datetime_with_tz_col_schema( 2018, @@ -302,22 +306,22 @@ def test_time_localizer_by_column_to_dst_time_error(iter_backends: IbisBackend) standard_tz=False, ) df = generate_dataframe_with_tz_col(from_schema) - ingest_data(iter_backends, df, from_schema) + ingest_data(iter_all_backends, df, from_schema) with pytest.raises( InvalidParameter, match="TimeZoneLocalizerByColumn only supports standard time zones" ): - localize_time_zone_by_column(iter_backends, from_schema, check_mapped_timestamps=True) + localize_time_zone_by_column(iter_all_backends, from_schema, check_mapped_timestamps=True) -def test_time_localizer_by_column_missing_tz_column_error(iter_backends: IbisBackend) -> None: +def test_time_localizer_by_column_missing_tz_column_error(iter_all_backends: IbisBackend) -> None: """Test that TimeZoneLocalizerByColumn raises error when time_zone_column is missing for DatetimeRange""" from_schema = get_datetime_schema(2018, None, TimeIntervalType.PERIOD_BEGINNING, "base_table") df = generate_datetime_dataframe(from_schema) error = (MissingValue, "time_zone_column must be provided") - run_localization_by_column_with_error(iter_backends, df, from_schema, error) + run_localization_by_column_with_error(iter_all_backends, df, from_schema, error) -def test_time_localizer_by_column_wrong_dtype_error(iter_backends: IbisBackend) -> None: +def test_time_localizer_by_column_wrong_dtype_error(iter_all_backends: IbisBackend) -> None: """Test that TimeZoneLocalizerByColumn raises error when dtype is not TIMESTAMP_NTZ""" from_schema = get_datetime_with_tz_col_schema( 2018, None, TimeIntervalType.PERIOD_BEGINNING, "base_table", standard_tz=True @@ -328,20 +332,22 @@ def test_time_localizer_by_column_wrong_dtype_error(iter_backends: IbisBackend) ) df = generate_dataframe_with_tz_col(from_schema) error = (InvalidParameter, "Source schema time config dtype must be TIMESTAMP_NTZ") - run_localization_by_column_with_error(iter_backends, df, from_schema, error) + run_localization_by_column_with_error(iter_all_backends, df, from_schema, error) -def test_time_localizer_by_column_non_standard_tz_error(iter_backends: IbisBackend) -> None: +def test_time_localizer_by_column_non_standard_tz_error(iter_all_backends: IbisBackend) -> None: """Test that TimeZoneLocalizerByColumn raises error when time zones are not standard""" from_schema = get_datetime_with_tz_col_schema( 2018, None, TimeIntervalType.PERIOD_BEGINNING, "base_table", standard_tz=False ) df = generate_dataframe_with_tz_col(from_schema) error = (InvalidParameter, "is not a standard time zone") - run_localization_by_column_with_error(iter_backends, df, from_schema, error) + run_localization_by_column_with_error(iter_all_backends, df, from_schema, error) -def test_localize_time_zone_by_column_missing_tz_column_error(iter_backends: IbisBackend) -> None: +def test_localize_time_zone_by_column_missing_tz_column_error( + iter_all_backends: IbisBackend, +) -> None: """Test that localize_time_zone_by_column raises error when time_zone_column is None for DatetimeRange""" from_schema = get_datetime_schema(2018, None, TimeIntervalType.PERIOD_BEGINNING, "base_table") df = generate_datetime_dataframe(from_schema) @@ -350,5 +356,5 @@ def test_localize_time_zone_by_column_missing_tz_column_error(iter_backends: Ibi "time_zone_column must be provided when source schema time config is of type DatetimeRange", ) run_localization_by_column_with_error( - iter_backends, df, from_schema, error, time_zone_column=None + iter_all_backends, df, from_schema, error, time_zone_column=None ) From 7ca977b102fbcd2db54685b78548e41b9023fddd Mon Sep 17 00:00:00 2001 From: Daniel Thom Date: Fri, 10 Apr 2026 14:30:54 -0600 Subject: [PATCH 12/48] Fix dispose in owned session --- src/chronify/ibis/spark_backend.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/chronify/ibis/spark_backend.py b/src/chronify/ibis/spark_backend.py index 62bec63..f7fdeb0 100644 --- a/src/chronify/ibis/spark_backend.py +++ b/src/chronify/ibis/spark_backend.py @@ -136,8 +136,8 @@ def execute_sql_to_df(self, query: str) -> pd.DataFrame: return cast(pd.DataFrame, self._session.sql(query).toPandas()) def dispose(self) -> None: - self._connection.disconnect() if self._owns_session: + self._connection.disconnect() self._session.stop() def reconnect(self) -> None: From 650a506e65999888a76e916c2c708a4625cfba97 Mon Sep 17 00:00:00 2001 From: Daniel Thom Date: Sat, 11 Apr 2026 13:20:31 -0600 Subject: [PATCH 13/48] Fix tests --- pyproject.toml | 2 +- src/chronify/ibis/base.py | 3 +++ src/chronify/ibis/duckdb_backend.py | 3 ++- src/chronify/ibis/spark_backend.py | 4 ++-- src/chronify/ibis/sqlite_backend.py | 5 +++-- src/chronify/time_series_mapper_base.py | 2 +- 6 files changed, 12 insertions(+), 7 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index b1d3638..a77da99 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,7 +40,7 @@ dependencies = [ [project.optional-dependencies] spark = [ "ibis-framework[pyspark]", - "pyspark >= 4.0", + "pyspark == 4.0.0", ] dev = [ diff --git a/src/chronify/ibis/base.py b/src/chronify/ibis/base.py index b3aedae..839c4ea 100644 --- a/src/chronify/ibis/base.py +++ b/src/chronify/ibis/base.py @@ -40,6 +40,7 @@ def create_table( name: str, obj: pd.DataFrame | ir.Table | None = None, schema: ibis.Schema | None = None, + overwrite: bool = False, ) -> ir.Table: """Create a table in the database. @@ -51,6 +52,8 @@ def create_table( Data to populate the table with. schema Schema to use if obj is None. + overwrite + If True, replace the table if it already exists. Returns ------- diff --git a/src/chronify/ibis/duckdb_backend.py b/src/chronify/ibis/duckdb_backend.py index 0cca830..fa8dc04 100644 --- a/src/chronify/ibis/duckdb_backend.py +++ b/src/chronify/ibis/duckdb_backend.py @@ -36,8 +36,9 @@ def create_table( name: str, obj: pd.DataFrame | ir.Table | None = None, schema: ibis.Schema | None = None, + overwrite: bool = False, ) -> ir.Table: - return self._connection.create_table(name, obj=obj, schema=schema, overwrite=False) + return self._connection.create_table(name, obj=obj, schema=schema, overwrite=overwrite) def create_view(self, name: str, expr: ir.Table) -> ir.Table: return self._connection.create_view(name, expr, overwrite=False) diff --git a/src/chronify/ibis/spark_backend.py b/src/chronify/ibis/spark_backend.py index f7fdeb0..5cff565 100644 --- a/src/chronify/ibis/spark_backend.py +++ b/src/chronify/ibis/spark_backend.py @@ -55,10 +55,11 @@ def create_table( name: str, obj: pd.DataFrame | ir.Table | None = None, schema: ibis.Schema | None = None, + overwrite: bool = False, ) -> ir.Table: if isinstance(obj, pd.DataFrame): obj = self._prepare_data_for_spark(obj) - return self._connection.create_table(name, obj=obj, schema=schema, overwrite=False) + return self._connection.create_table(name, obj=obj, schema=schema, overwrite=overwrite) def create_view(self, name: str, expr: ir.Table) -> ir.Table: return self._connection.create_view(name, expr, overwrite=False) @@ -138,7 +139,6 @@ def execute_sql_to_df(self, query: str) -> pd.DataFrame: def dispose(self) -> None: if self._owns_session: self._connection.disconnect() - self._session.stop() def reconnect(self) -> None: pass # Spark sessions are long-lived diff --git a/src/chronify/ibis/sqlite_backend.py b/src/chronify/ibis/sqlite_backend.py index 6ea2921..383ba0d 100644 --- a/src/chronify/ibis/sqlite_backend.py +++ b/src/chronify/ibis/sqlite_backend.py @@ -53,13 +53,14 @@ def create_table( name: str, obj: pd.DataFrame | ir.Table | None = None, schema: ibis.Schema | None = None, + overwrite: bool = False, ) -> ir.Table: if isinstance(obj, ir.Table): # SQLite CREATE TABLE AS SELECT loses datetime type info. # Execute the expression first, then create from the DataFrame. df = self._connection.execute(obj) - return self._connection.create_table(name, obj=df, overwrite=False) - return self._connection.create_table(name, obj=obj, schema=schema, overwrite=False) + return self._connection.create_table(name, obj=df, overwrite=overwrite) + return self._connection.create_table(name, obj=obj, schema=schema, overwrite=overwrite) def create_view(self, name: str, expr: ir.Table) -> ir.Table: return self._connection.create_view(name, expr, overwrite=False) diff --git a/src/chronify/time_series_mapper_base.py b/src/chronify/time_series_mapper_base.py index b8ff946..3059c85 100644 --- a/src/chronify/time_series_mapper_base.py +++ b/src/chronify/time_series_mapper_base.py @@ -215,4 +215,4 @@ def _right_col(col: str) -> Any: write_parquet(backend, result, output_file, overwrite=True, config=to_schema.time_config) return - backend.create_table(to_schema.name, result) + backend.create_table(to_schema.name, result, overwrite=True) From b4f2d6ae2728f984a40e6cf8fafa1cca0bc1b4d6 Mon Sep 17 00:00:00 2001 From: Daniel Thom Date: Sat, 11 Apr 2026 14:49:32 -0600 Subject: [PATCH 14/48] Spark fixes --- src/chronify/ibis/functions.py | 17 ++++++++++------ src/chronify/ibis/spark_backend.py | 32 ++++++++++++++++++++++++++++-- 2 files changed, 41 insertions(+), 8 deletions(-) diff --git a/src/chronify/ibis/functions.py b/src/chronify/ibis/functions.py index da4b14c..8696147 100644 --- a/src/chronify/ibis/functions.py +++ b/src/chronify/ibis/functions.py @@ -216,16 +216,21 @@ def _convert_spark_output_for_datetime(df: pd.DataFrame, config: DatetimeRanges) return col = df[config.time_column] - if not pd.api.types.is_datetime64_any_dtype(col): - df[config.time_column] = pd.to_datetime(col, utc=True) - col = df[config.time_column] if config.dtype == TimeDataType.TIMESTAMP_TZ: - if not isinstance(col.dtype, DatetimeTZDtype): - df[config.time_column] = col.dt.tz_localize("UTC") + if not pd.api.types.is_datetime64_any_dtype(col): + col = pd.to_datetime(col, utc=True) + elif isinstance(col.dtype, DatetimeTZDtype): + col = col.dt.tz_convert("UTC") + else: + col = col.dt.tz_localize("UTC") + df[config.time_column] = col.dt.tz_localize(None).astype("datetime64[us]") else: + if not pd.api.types.is_datetime64_any_dtype(col): + col = pd.to_datetime(col, utc=False) + df[config.time_column] = col.astype("datetime64[us]") if isinstance(col.dtype, DatetimeTZDtype): - df[config.time_column] = col.dt.tz_convert(None) + df[config.time_column] = col.dt.tz_convert(None).astype("datetime64[us]") def _write_to_duckdb( diff --git a/src/chronify/ibis/spark_backend.py b/src/chronify/ibis/spark_backend.py index 5cff565..b664d47 100644 --- a/src/chronify/ibis/spark_backend.py +++ b/src/chronify/ibis/spark_backend.py @@ -1,7 +1,10 @@ """Spark backend implementation for Ibis.""" import uuid +import shutil from typing import Any, cast +from pathlib import Path +from urllib.parse import urlparse, unquote import ibis import ibis.expr.types as ir @@ -59,7 +62,13 @@ def create_table( ) -> ir.Table: if isinstance(obj, pd.DataFrame): obj = self._prepare_data_for_spark(obj) - return self._connection.create_table(name, obj=obj, schema=schema, overwrite=overwrite) + try: + return self._connection.create_table(name, obj=obj, schema=schema, overwrite=overwrite) + except Exception as exc: + if "LOCATION_ALREADY_EXISTS" not in str(exc): + raise + self._remove_managed_table_location(name) + return self._connection.create_table(name, obj=obj, schema=schema, overwrite=overwrite) def create_view(self, name: str, expr: ir.Table) -> ir.Table: return self._connection.create_view(name, expr, overwrite=False) @@ -101,7 +110,15 @@ def delete_rows(self, name: str, values: dict[str, Any]) -> None: where = " AND ".join(f"{_quote_identifier(c)} = :{p}" for c, p in zip(values, param_names)) sql = f"DELETE FROM {quoted_name} WHERE {where}" args = dict(zip(param_names, values.values())) - self._session.sql(sql, args=args) + try: + self._session.sql(sql, args=args) + except Exception as exc: + if "does not support DELETE" not in str(exc): + raise + df = self._connection.execute(self.table(name)) + for column, value in values.items(): + df = df[df[column] != value] + self.create_table(name, obj=df, overwrite=True) logger.trace("Deleted rows from {} matching {}", name, values) def execute(self, expr: ir.Expr) -> pd.DataFrame: @@ -143,6 +160,17 @@ def dispose(self) -> None: def reconnect(self) -> None: pass # Spark sessions are long-lived + def _remove_managed_table_location(self, name: str) -> None: + location = self._session.conf.get("spark.sql.warehouse.dir", "spark-warehouse") + parsed = urlparse(location) + if parsed.scheme == "file": + warehouse = Path(unquote(parsed.path)) + else: + warehouse = Path(location) + path = warehouse / name + if path.exists(): + shutil.rmtree(path) + @staticmethod def _prepare_data_for_spark(df: pd.DataFrame) -> pd.DataFrame: """Normalize tz-aware pandas timestamps for Spark ingestion. From bfc26842d1a525eec54ab6e7af3d569d7c6db538 Mon Sep 17 00:00:00 2001 From: Daniel Thom Date: Sat, 11 Apr 2026 16:00:08 -0600 Subject: [PATCH 15/48] Fix CI failures --- pyproject.toml | 2 ++ src/chronify/ibis/functions.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index a77da99..6b8450f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -76,6 +76,8 @@ module = [ "ibis.*", "pyarrow", "pyarrow.*", + "pyspark", + "pyspark.*", ] ignore_missing_imports = true diff --git a/src/chronify/ibis/functions.py b/src/chronify/ibis/functions.py index 8696147..df02ea5 100644 --- a/src/chronify/ibis/functions.py +++ b/src/chronify/ibis/functions.py @@ -224,7 +224,7 @@ def _convert_spark_output_for_datetime(df: pd.DataFrame, config: DatetimeRanges) col = col.dt.tz_convert("UTC") else: col = col.dt.tz_localize("UTC") - df[config.time_column] = col.dt.tz_localize(None).astype("datetime64[us]") + df[config.time_column] = col.dt.as_unit("us") else: if not pd.api.types.is_datetime64_any_dtype(col): col = pd.to_datetime(col, utc=False) From ed8c2a27433643057c53a21f87fca011172a88fe Mon Sep 17 00:00:00 2001 From: Daniel Thom Date: Sat, 11 Apr 2026 17:11:56 -0600 Subject: [PATCH 16/48] Increase test coverage --- src/chronify/ibis/types.py | 11 - tests/test_annual_time_range_generator.py | 37 ++++ tests/test_column_representative_period.py | 88 ++++++++ tests/test_ibis_base.py | 75 +++++++ tests/test_ibis_functions.py | 225 +++++++++++++++++++++ tests/test_ibis_types.py | 80 ++++++++ tests/test_spark_backend.py | 167 +++++++++++++++ tests/test_store_errors.py | 196 ++++++++++++++++++ 8 files changed, 868 insertions(+), 11 deletions(-) create mode 100644 tests/test_annual_time_range_generator.py create mode 100644 tests/test_column_representative_period.py create mode 100644 tests/test_ibis_base.py create mode 100644 tests/test_ibis_functions.py create mode 100644 tests/test_ibis_types.py create mode 100644 tests/test_store_errors.py diff --git a/src/chronify/ibis/types.py b/src/chronify/ibis/types.py index 396bf1b..45e607d 100644 --- a/src/chronify/ibis/types.py +++ b/src/chronify/ibis/types.py @@ -1,8 +1,6 @@ """Type conversion utilities for Ibis backends.""" -import ibis import ibis.expr.datatypes as dt -import pandas as pd import pyarrow as pa # Mapping from user-facing string type names to Ibis data types @@ -84,15 +82,6 @@ def get_duckdb_type_from_ibis(ibis_type: dt.DataType) -> str: raise ValueError(msg) -def get_ibis_types_from_dataframe(df: pd.DataFrame) -> dict[str, dt.DataType]: - """Infer Ibis types from a pandas DataFrame's columns. - - Note: This uses ibis schema inference and does not require a DuckDB connection. - """ - schema = ibis.Schema.from_pandas(df) - return dict(schema.items()) - - def pyarrow_to_ibis_type(arrow_type: pa.DataType) -> dt.DataType: """Convert a PyArrow type to an Ibis DataType.""" return dt.DataType.from_pyarrow(arrow_type) diff --git a/tests/test_annual_time_range_generator.py b/tests/test_annual_time_range_generator.py new file mode 100644 index 0000000..9fedce9 --- /dev/null +++ b/tests/test_annual_time_range_generator.py @@ -0,0 +1,37 @@ +"""Tests for AnnualTimeRangeGenerator.""" + +import pytest + +from chronify.annual_time_range_generator import AnnualTimeRangeGenerator +from chronify.time_configs import AnnualTimeRange + + +def _make_config(start: int = 1, length: int = 5) -> AnnualTimeRange: + return AnnualTimeRange( + time_column="year", + start=start, + length=length, + ) + + +def test_list_timestamps(): + gen = AnnualTimeRangeGenerator(_make_config(start=1, length=5)) + assert gen.list_timestamps() == [1, 2, 3, 4, 5] + + +def test_list_timestamps_single(): + gen = AnnualTimeRangeGenerator(_make_config(start=1, length=1)) + assert gen.list_timestamps() == [1] + + +def test_list_time_columns(): + gen = AnnualTimeRangeGenerator(_make_config()) + assert gen.list_time_columns() == ["year"] + + +def test_list_distinct_timestamps_from_dataframe_not_implemented(): + import pandas as pd + + gen = AnnualTimeRangeGenerator(_make_config()) + with pytest.raises(NotImplementedError): + gen.list_distinct_timestamps_from_dataframe(pd.DataFrame()) diff --git a/tests/test_column_representative_period.py b/tests/test_column_representative_period.py new file mode 100644 index 0000000..89d7125 --- /dev/null +++ b/tests/test_column_representative_period.py @@ -0,0 +1,88 @@ +"""Tests for ColumnRepresentativeTimeGenerator with Period handler.""" + +import pandas as pd +import pytest + +from chronify.column_representative_time_range_generator import ( + ColumnRepresentativeHandlerPeriod, + ColumnRepresentativeTimeGenerator, +) +from chronify.exceptions import InvalidValue +from chronify.time_configs import ( + MonthDayHourTimeNTZ, + YearMonthDayPeriodTimeNTZ, +) + + +def _make_period_config(year: int = 2024, length: int = 8760) -> YearMonthDayPeriodTimeNTZ: + return YearMonthDayPeriodTimeNTZ( + hour_columns=["period"], + day_column="day", + month_column="month", + year_column="year", + year=year, + length=length, + ) + + +class TestColumnRepresentativeTimeGeneratorPeriod: + def test_list_timestamps(self): + config = _make_period_config(year=2024, length=8784) # 366 days * 24 hours + gen = ColumnRepresentativeTimeGenerator(config) + timestamps = gen.list_timestamps() + # 366 days in 2024 (leap year) + assert len(timestamps) == 366 + assert timestamps[0] == (2024, 1, 1) + assert timestamps[-1] == (2024, 12, 31) + + def test_list_time_columns(self): + config = _make_period_config() + gen = ColumnRepresentativeTimeGenerator(config) + assert gen.list_time_columns() == ["year", "month", "day", "period"] + + +class TestColumnRepresentativeHandlerPeriod: + def test_iter_timestamps(self): + config = _make_period_config(year=2023, length=24 * 3) # 3 days + handler = ColumnRepresentativeHandlerPeriod(config, 2023) + timestamps = list(handler._iter_timestamps()) + assert len(timestamps) == 3 + assert timestamps[0] == (2023, 1, 1) + assert timestamps[1] == (2023, 1, 2) + assert timestamps[2] == (2023, 1, 3) + + def test_list_distinct_timestamps_from_dataframe(self): + config = _make_period_config(year=2023, length=24 * 3) + handler = ColumnRepresentativeHandlerPeriod(config, 2023) + df = pd.DataFrame( + { + "year": [2023, 2023, 2023, 2023], + "month": [1, 1, 1, 1], + "day": [1, 1, 2, 2], + "period": ["H1-5", "H6-12", "H1-5", "H6-12"], + "value": [1.0, 2.0, 3.0, 4.0], + } + ) + result = handler.list_distinct_timestamps_from_dataframe(df) + assert result == [(2023, 1, 1), (2023, 1, 2)] + + +class TestColumnRepresentativeErrors: + def test_no_year_raises(self): + config = MonthDayHourTimeNTZ( + day_column="day", + month_column="month", + hour_columns=["hour"], + year=None, + length=8760, + ) + with pytest.raises(InvalidValue, match="without year"): + ColumnRepresentativeTimeGenerator(config) + + def test_unsupported_config_raises(self): + """ColumnRepresentativeBase subclasses not matching known handlers should raise.""" + config = _make_period_config() + gen = ColumnRepresentativeTimeGenerator(config) + # The generator was created successfully with a period config. + # Verify it works correctly. + assert len(gen.list_timestamps()) > 0 diff --git a/tests/test_ibis_base.py b/tests/test_ibis_base.py new file mode 100644 index 0000000..ca8004f --- /dev/null +++ b/tests/test_ibis_base.py @@ -0,0 +1,75 @@ +"""Tests for the IbisBackend base class (transaction, execute_sql, etc.).""" + +import pytest + +from chronify.ibis import make_backend +from chronify.ibis.base import ObjectType + + +def test_execute_sql(create_duckdb_backend): + backend = create_duckdb_backend + backend.execute_sql("CREATE TABLE test_exec_sql (id INTEGER, val DOUBLE)") + assert backend.has_table("test_exec_sql") + + +def test_execute_sql_to_df(create_duckdb_backend): + backend = create_duckdb_backend + backend.execute_sql("CREATE TABLE test_sql_df (id INTEGER, val DOUBLE)") + backend.execute_sql("INSERT INTO test_sql_df VALUES (1, 2.5)") + df = backend.execute_sql_to_df("SELECT * FROM test_sql_df") + assert len(df) == 1 + assert df["id"].iloc[0] == 1 + + +def test_dispose(): + backend = make_backend("duckdb") + backend.dispose() + + +def test_transaction_success(create_duckdb_backend): + backend = create_duckdb_backend + with backend.transaction() as created: + backend.create_table( + "txn_table", + obj=None, + schema={"id": "int64", "val": "float64"}, + ) + created.append(("txn_table", ObjectType.TABLE)) + + # Table should still exist after successful transaction + assert backend.has_table("txn_table") + + +def test_transaction_rollback_on_exception(create_duckdb_backend): + import pandas as pd + + backend = create_duckdb_backend + df = pd.DataFrame({"id": [1], "val": [2.0]}) + + with pytest.raises(ValueError, match="test error"): + with backend.transaction() as created: + backend.create_table("txn_rollback", obj=df) + created.append(("txn_rollback", ObjectType.TABLE)) + msg = "test error" + raise ValueError(msg) + + # Table should have been cleaned up + assert not backend.has_table("txn_rollback") + + +def test_transaction_rollback_view(create_duckdb_backend): + import pandas as pd + + backend = create_duckdb_backend + df = pd.DataFrame({"id": [1], "val": [2.0]}) + backend.create_table("base_for_view", obj=df) + expr = backend.table("base_for_view") + + with pytest.raises(ValueError, match="test error"): + with backend.transaction() as created: + backend.create_view("txn_view", expr) + created.append(("txn_view", ObjectType.VIEW)) + msg = "test error" + raise ValueError(msg) + + assert not backend.has_table("txn_view") diff --git a/tests/test_ibis_functions.py b/tests/test_ibis_functions.py new file mode 100644 index 0000000..551800a --- /dev/null +++ b/tests/test_ibis_functions.py @@ -0,0 +1,225 @@ +"""Tests for ibis/functions.py edge cases and uncovered branches.""" + +from datetime import datetime, timedelta +from zoneinfo import ZoneInfo + +import pandas as pd +import pyarrow as pa +import pytest + +from chronify.exceptions import InvalidOperation, InvalidParameter +from chronify.ibis import make_backend +from chronify.ibis.functions import ( + _check_one_config_per_datetime_column, + _convert_database_output_for_datetime, + _convert_spark_output_for_datetime, + _normalize_timestamps, + write_table, +) +from chronify.time import TimeIntervalType +from chronify.time_configs import DatetimeRange + + +def _make_tz_config(col: str = "timestamp") -> DatetimeRange: + return DatetimeRange( + time_column=col, + start=datetime(2020, 1, 1, tzinfo=ZoneInfo("UTC")), + length=2, + resolution=timedelta(hours=1), + interval_type=TimeIntervalType.PERIOD_BEGINNING, + ) + + +def _make_ntz_config(col: str = "timestamp") -> DatetimeRange: + return DatetimeRange( + time_column=col, + start=datetime(2020, 1, 1), + length=2, + resolution=timedelta(hours=1), + interval_type=TimeIntervalType.PERIOD_BEGINNING, + ) + + +class TestNormalizeTimestamps: + def test_tz_aware_to_ntz(self): + config = _make_ntz_config() + df = pd.DataFrame( + { + "timestamp": pd.to_datetime( + ["2020-01-01 00:00:00+00:00", "2020-01-01 01:00:00+00:00"], utc=True + ), + } + ) + result = _normalize_timestamps(df, [config]) + assert not isinstance(result["timestamp"].dtype, pd.DatetimeTZDtype) + + def test_tz_naive_to_tz(self): + config = _make_tz_config() + df = pd.DataFrame( + { + "timestamp": pd.to_datetime(["2020-01-01 00:00:00", "2020-01-01 01:00:00"]), + } + ) + result = _normalize_timestamps(df, [config]) + assert isinstance(result["timestamp"].dtype, pd.DatetimeTZDtype) + + def test_column_not_in_df(self): + config = _make_tz_config(col="missing_col") + df = pd.DataFrame({"other": [1, 2]}) + result = _normalize_timestamps(df, [config]) + assert list(result.columns) == ["other"] + + def test_non_datetime_column_skipped(self): + config = _make_tz_config() + df = pd.DataFrame({"timestamp": ["not", "datetime"]}) + result = _normalize_timestamps(df, [config]) + assert list(result["timestamp"]) == ["not", "datetime"] + + +class TestCheckOneConfigPerDatetimeColumn: + def test_duplicate_config_raises(self): + configs = [_make_tz_config("timestamp"), _make_tz_config("timestamp")] + with pytest.raises(InvalidParameter, match="More than one datetime config"): + _check_one_config_per_datetime_column(configs) + + +class TestConvertDatabaseOutputForDatetime: + def test_tz_with_object_dtype(self): + config = _make_tz_config() + df = pd.DataFrame({"timestamp": ["2020-01-01 00:00:00", "2020-01-01 01:00:00"]}) + _convert_database_output_for_datetime(df, config) + assert isinstance(df["timestamp"].dtype, pd.DatetimeTZDtype) + + def test_tz_with_tz_aware_dtype(self): + config = _make_tz_config() + df = pd.DataFrame( + { + "timestamp": pd.to_datetime( + ["2020-01-01 00:00:00+05:00", "2020-01-01 01:00:00+05:00"] + ), + } + ) + _convert_database_output_for_datetime(df, config) + assert str(df["timestamp"].dt.tz) == "UTC" + + def test_tz_with_naive_dtype(self): + config = _make_tz_config() + df = pd.DataFrame( + { + "timestamp": pd.to_datetime(["2020-01-01 00:00:00", "2020-01-01 01:00:00"]), + } + ) + _convert_database_output_for_datetime(df, config) + assert isinstance(df["timestamp"].dtype, pd.DatetimeTZDtype) + + def test_ntz_with_object_dtype(self): + config = _make_ntz_config() + df = pd.DataFrame({"timestamp": ["2020-01-01 00:00:00", "2020-01-01 01:00:00"]}) + _convert_database_output_for_datetime(df, config) + assert pd.api.types.is_datetime64_any_dtype(df["timestamp"]) + + def test_missing_column_is_noop(self): + config = _make_tz_config() + df = pd.DataFrame({"other": [1, 2]}) + _convert_database_output_for_datetime(df, config) + assert list(df.columns) == ["other"] + + +class TestConvertSparkOutputForDatetime: + def test_tz_with_object_dtype(self): + config = _make_tz_config() + df = pd.DataFrame({"timestamp": ["2020-01-01 00:00:00", "2020-01-01 01:00:00"]}) + _convert_spark_output_for_datetime(df, config) + assert isinstance(df["timestamp"].dtype, pd.DatetimeTZDtype) + + def test_tz_with_tz_aware_dtype(self): + config = _make_tz_config() + df = pd.DataFrame( + { + "timestamp": pd.to_datetime( + ["2020-01-01 00:00:00+05:00", "2020-01-01 01:00:00+05:00"] + ), + } + ) + _convert_spark_output_for_datetime(df, config) + assert str(df["timestamp"].dt.tz) == "UTC" + + def test_tz_with_naive_dtype(self): + config = _make_tz_config() + df = pd.DataFrame( + { + "timestamp": pd.to_datetime(["2020-01-01 00:00:00", "2020-01-01 01:00:00"]), + } + ) + _convert_spark_output_for_datetime(df, config) + assert isinstance(df["timestamp"].dtype, pd.DatetimeTZDtype) + + def test_ntz_with_object_dtype(self): + config = _make_ntz_config() + df = pd.DataFrame({"timestamp": ["2020-01-01 00:00:00", "2020-01-01 01:00:00"]}) + _convert_spark_output_for_datetime(df, config) + assert pd.api.types.is_datetime64_any_dtype(df["timestamp"]) + + def test_ntz_strips_tz_from_aware(self): + config = _make_ntz_config() + df = pd.DataFrame( + { + "timestamp": pd.to_datetime( + ["2020-01-01 00:00:00+00:00", "2020-01-01 01:00:00+00:00"], utc=True + ), + } + ) + _convert_spark_output_for_datetime(df, config) + assert not isinstance(df["timestamp"].dtype, pd.DatetimeTZDtype) + + def test_missing_column_is_noop(self): + config = _make_tz_config() + df = pd.DataFrame({"other": [1, 2]}) + _convert_spark_output_for_datetime(df, config) + assert list(df.columns) == ["other"] + + +class TestWriteTable: + def test_pyarrow_table_input(self): + backend = make_backend("duckdb") + config = _make_ntz_config() + pa_table = pa.table( + { + "timestamp": pd.to_datetime(["2020-01-01 00:00:00", "2020-01-01 01:00:00"]), + "value": [1.0, 2.0], + } + ) + write_table(backend, pa_table, "pa_test", [config], if_exists="fail") + assert backend.has_table("pa_test") + df = backend.execute(backend.table("pa_test")) + assert len(df) == 2 + backend.dispose() + + def test_invalid_if_exists_duckdb(self): + backend = make_backend("duckdb") + config = _make_ntz_config() + df = pd.DataFrame( + { + "timestamp": pd.to_datetime(["2020-01-01 00:00:00", "2020-01-01 01:00:00"]), + "value": [1.0, 2.0], + } + ) + write_table(backend, df, "test_tbl", [config], if_exists="fail") + with pytest.raises(InvalidOperation, match="Invalid if_exists"): + write_table(backend, df, "test_tbl", [config], if_exists="invalid") + backend.dispose() + + def test_unsupported_backend(self): + """A backend with an unknown name should raise NotImplementedError.""" + from unittest.mock import MagicMock + + backend = MagicMock() + backend.name = "unknown_db" + config = _make_ntz_config() + df = pd.DataFrame( + { + "timestamp": pd.to_datetime(["2020-01-01 00:00:00", "2020-01-01 01:00:00"]), + } + ) + with pytest.raises(NotImplementedError, match="Unsupported backend"): + write_table(backend, df, "test", [config], if_exists="fail") diff --git a/tests/test_ibis_types.py b/tests/test_ibis_types.py new file mode 100644 index 0000000..4ebb72d --- /dev/null +++ b/tests/test_ibis_types.py @@ -0,0 +1,80 @@ +"""Tests for ibis type conversion utilities.""" + +import ibis.expr.datatypes as dt +import pyarrow as pa +import pytest + +from chronify.ibis.types import ( + get_duckdb_type_from_ibis, + get_ibis_type_from_duckdb, + get_ibis_type_from_string, + ibis_to_pyarrow_type, + pyarrow_to_ibis_type, +) + + +class TestGetIbisTypeFromString: + def test_valid_types(self): + assert get_ibis_type_from_string("int") == dt.Int64() + assert get_ibis_type_from_string("float") == dt.Float64() + assert get_ibis_type_from_string("str") == dt.String() + assert get_ibis_type_from_string("bool") == dt.Boolean() + assert get_ibis_type_from_string("datetime") == dt.Timestamp(timezone=None) + assert get_ibis_type_from_string("datetime_tz") == dt.Timestamp(timezone="UTC") + + def test_invalid_type(self): + with pytest.raises(ValueError, match="Unsupported type name"): + get_ibis_type_from_string("invalid") + + +class TestGetIbisTypeFromDuckdb: + def test_common_types(self): + assert get_ibis_type_from_duckdb("BOOLEAN") == dt.Boolean() + assert get_ibis_type_from_duckdb("INTEGER") == dt.Int32() + assert get_ibis_type_from_duckdb("BIGINT") == dt.Int64() + assert get_ibis_type_from_duckdb("DOUBLE") == dt.Float64() + assert get_ibis_type_from_duckdb("VARCHAR") == dt.String() + assert get_ibis_type_from_duckdb("TIMESTAMP") == dt.Timestamp(timezone=None) + assert get_ibis_type_from_duckdb("TIMESTAMPTZ") == dt.Timestamp(timezone="UTC") + + def test_case_insensitive(self): + assert get_ibis_type_from_duckdb("boolean") == dt.Boolean() + assert get_ibis_type_from_duckdb("integer") == dt.Int32() + + def test_unsupported_type(self): + with pytest.raises(ValueError, match="Unsupported DuckDB type"): + get_ibis_type_from_duckdb("BLOB") + + +class TestGetDuckdbTypeFromIbis: + def test_common_types(self): + assert get_duckdb_type_from_ibis(dt.Boolean()) == "BOOLEAN" + assert get_duckdb_type_from_ibis(dt.Int64()) == "BIGINT" + assert get_duckdb_type_from_ibis(dt.Float64()) == "DOUBLE" + assert get_duckdb_type_from_ibis(dt.String()) == "VARCHAR" + + def test_timestamp_with_timezone(self): + assert get_duckdb_type_from_ibis(dt.Timestamp(timezone="UTC")) == "TIMESTAMPTZ" + + def test_timestamp_without_timezone(self): + assert get_duckdb_type_from_ibis(dt.Timestamp(timezone=None)) == "TIMESTAMP" + + def test_unsupported_type(self): + with pytest.raises(ValueError, match="Unsupported Ibis type for DuckDB"): + get_duckdb_type_from_ibis(dt.Binary()) + + +class TestPyarrowConversion: + def test_pyarrow_to_ibis(self): + result = pyarrow_to_ibis_type(pa.int64()) + assert isinstance(result, dt.Int64) + + def test_ibis_to_pyarrow(self): + result = ibis_to_pyarrow_type(dt.Int64()) + assert result == pa.int64() + + def test_roundtrip(self): + original = dt.Float64() + arrow_type = ibis_to_pyarrow_type(original) + back = pyarrow_to_ibis_type(arrow_type) + assert isinstance(back, dt.Float64) diff --git a/tests/test_spark_backend.py b/tests/test_spark_backend.py index fe33efb..1c680fe 100644 --- a/tests/test_spark_backend.py +++ b/tests/test_spark_backend.py @@ -236,6 +236,173 @@ def test_spark_time_zone_conversion(spark_store: Store) -> None: assert list(out_sorted["timestamp"]) == list(expected) +def test_spark_delete_rows(spark_store: Store) -> None: + """delete_rows should remove matching rows and return the count.""" + schema = TableSchema( + name="spark_del", + value_column="value", + time_config=DatetimeRange( + time_column="timestamp", + start=datetime(2020, 1, 1, tzinfo=ZoneInfo("UTC")), + length=2, + resolution=timedelta(hours=1), + interval_type=TimeIntervalType.PERIOD_BEGINNING, + ), + time_array_id_columns=["id"], + ) + df = pd.DataFrame( + { + "id": [1, 1, 2, 2], + "timestamp": pd.to_datetime( + [ + "2020-01-01 00:00:00+00:00", + "2020-01-01 01:00:00+00:00", + "2020-01-01 00:00:00+00:00", + "2020-01-01 01:00:00+00:00", + ], + utc=True, + ), + "value": [1.0, 2.0, 3.0, 4.0], + } + ) + spark_store.ingest_table(df, schema, skip_time_checks=True) + count = spark_store.delete_rows(schema.name, {"id": 1}) + assert count == 2 + out = spark_store.read_table(schema.name) + assert len(out) == 2 + assert set(out["id"]) == {2} + + +def test_spark_write_parquet_partitioned(spark_store: Store, tmp_path: Path) -> None: + """write_parquet with partition_by should produce partitioned output.""" + schema = TableSchema( + name="spark_part", + value_column="value", + time_config=DatetimeRange( + time_column="timestamp", + start=datetime(2020, 1, 1, tzinfo=ZoneInfo("UTC")), + length=2, + resolution=timedelta(hours=1), + interval_type=TimeIntervalType.PERIOD_BEGINNING, + ), + time_array_id_columns=["id"], + ) + df = pd.DataFrame( + { + "id": [1, 1, 2, 2], + "timestamp": pd.to_datetime( + [ + "2020-01-01 00:00:00+00:00", + "2020-01-01 01:00:00+00:00", + "2020-01-01 00:00:00+00:00", + "2020-01-01 01:00:00+00:00", + ], + utc=True, + ), + "value": [1.0, 2.0, 3.0, 4.0], + } + ) + spark_store.ingest_table(df, schema, skip_time_checks=True) + outdir = tmp_path / "partitioned_output" + spark_store.backend.write_parquet( + spark_store.backend.table(schema.name), + str(outdir), + partition_by=["id"], + ) + # Partitioned parquet creates subdirectories + assert outdir.exists() + subdirs = [p for p in outdir.iterdir() if p.is_dir() and p.name.startswith("id=")] + assert len(subdirs) == 2 + + +def test_spark_create_view_from_parquet(spark_store: Store, tmp_path: Path) -> None: + """create_view_from_parquet should create a readable view.""" + schema = TableSchema( + name="spark_pq_src", + value_column="value", + time_config=DatetimeRange( + time_column="timestamp", + start=datetime(2020, 1, 1, tzinfo=ZoneInfo("UTC")), + length=2, + resolution=timedelta(hours=1), + interval_type=TimeIntervalType.PERIOD_BEGINNING, + ), + time_array_id_columns=["id"], + ) + df = pd.DataFrame( + { + "id": [1, 1], + "timestamp": pd.to_datetime( + ["2020-01-01 00:00:00+00:00", "2020-01-01 01:00:00+00:00"], + utc=True, + ), + "value": [1.0, 2.0], + } + ) + spark_store.ingest_table(df, schema, skip_time_checks=True) + + # Write to parquet then create a view from it + outfile = tmp_path / "view_src.parquet" + spark_store.write_table_to_parquet(schema.name, outfile, overwrite=True) + + from chronify.ibis.base import ObjectType + + table_expr, obj_type = spark_store.backend.create_view_from_parquet(str(outfile), "pq_view") + assert obj_type == ObjectType.VIEW + result = spark_store.backend.execute(table_expr) + assert len(result) == 2 + + +def test_spark_create_and_drop_view(spark_store: Store) -> None: + """create_view and drop_view should work correctly.""" + schema = TableSchema( + name="spark_view_src", + value_column="value", + time_config=DatetimeRange( + time_column="timestamp", + start=datetime(2020, 1, 1, tzinfo=ZoneInfo("UTC")), + length=2, + resolution=timedelta(hours=1), + interval_type=TimeIntervalType.PERIOD_BEGINNING, + ), + time_array_id_columns=["id"], + ) + df = pd.DataFrame( + { + "id": [1, 1], + "timestamp": pd.to_datetime( + ["2020-01-01 00:00:00+00:00", "2020-01-01 01:00:00+00:00"], + utc=True, + ), + "value": [1.0, 2.0], + } + ) + spark_store.ingest_table(df, schema, skip_time_checks=True) + + expr = spark_store.backend.table(schema.name) + spark_store.backend.create_view("test_view", expr) + assert spark_store.backend.has_table("test_view") + + spark_store.backend.drop_view("test_view") + assert not spark_store.backend.has_table("test_view") + + +def test_spark_dispose(tmp_path: Path) -> None: + """dispose should not raise on an owned session.""" + _require_java_home() + pyspark = pytest.importorskip("pyspark.sql") + warehouse_dir = tmp_path / "spark-warehouse-dispose" + session = ( + pyspark.SparkSession.builder.master("local") + .config("spark.sql.session.timeZone", "UTC") + .config("spark.sql.parquet.outputTimestampType", "TIMESTAMP_MICROS") + .config("spark.sql.warehouse.dir", str(warehouse_dir)) + .getOrCreate() + ) + backend = SparkBackend(session=session) + backend.dispose() + + def test_spark_backend_rejects_non_utc_session() -> None: _require_java_home() pyspark = pytest.importorskip("pyspark.sql") diff --git a/tests/test_store_errors.py b/tests/test_store_errors.py new file mode 100644 index 0000000..8410930 --- /dev/null +++ b/tests/test_store_errors.py @@ -0,0 +1,196 @@ +"""Tests for Store error paths and edge cases.""" + +from datetime import datetime, timedelta +from zoneinfo import ZoneInfo + +import pandas as pd +import pytest + +from chronify.exceptions import ( + InvalidParameter, + TableAlreadyExists, + TableNotStored, +) +from chronify.ibis import make_backend +from chronify.models import TableSchema +from chronify.store import Store +from chronify.time import TimeIntervalType +from chronify.time_configs import DatetimeRange + + +def _make_tz_schema(name: str = "generators") -> TableSchema: + return TableSchema( + name=name, + value_column="value", + time_config=DatetimeRange( + time_column="timestamp", + start=datetime(2020, 1, 1, tzinfo=ZoneInfo("UTC")), + length=3, + resolution=timedelta(hours=1), + interval_type=TimeIntervalType.PERIOD_BEGINNING, + ), + time_array_id_columns=["id"], + ) + + +def _make_ntz_schema(name: str = "generators") -> TableSchema: + return TableSchema( + name=name, + value_column="value", + time_config=DatetimeRange( + time_column="timestamp", + start=datetime(2020, 1, 1), + length=3, + resolution=timedelta(hours=1), + interval_type=TimeIntervalType.PERIOD_BEGINNING, + ), + time_array_id_columns=["id"], + ) + + +def _make_store() -> Store: + return Store(backend=make_backend("duckdb")) + + +def _make_tz_df() -> pd.DataFrame: + return pd.DataFrame( + { + "id": [1, 1, 1], + "timestamp": pd.to_datetime( + ["2020-01-01 00:00:00", "2020-01-01 01:00:00", "2020-01-01 02:00:00"], + utc=True, + ), + "value": [1.0, 2.0, 3.0], + } + ) + + +class TestGetTable: + def test_get_table_not_stored(self): + store = _make_store() + with pytest.raises(TableNotStored, match="nonexistent"): + store.get_table("nonexistent") + store.dispose() + + def test_try_get_table_returns_none(self): + store = _make_store() + result = store.try_get_table("nonexistent") + assert result is None + store.dispose() + + def test_try_get_table_returns_table(self): + store = _make_store() + schema = _make_tz_schema() + store.ingest_table(_make_tz_df(), schema) + result = store.try_get_table(schema.name) + assert result is not None + store.dispose() + + +class TestDropTableErrors: + def test_drop_table_not_stored(self): + store = _make_store() + with pytest.raises(TableNotStored, match="nonexistent"): + store.drop_table("nonexistent") + store.dispose() + + def test_drop_table_if_exists_no_error(self): + store = _make_store() + store.drop_table("nonexistent", if_exists=True) + store.dispose() + + def test_drop_view_not_stored(self): + store = _make_store() + with pytest.raises(TableNotStored, match="nonexistent"): + store.drop_view("nonexistent") + store.dispose() + + def test_drop_view_if_exists_no_error(self): + store = _make_store() + store.drop_view("nonexistent", if_exists=True) + store.dispose() + + +class TestDeleteRows: + def test_delete_rows_table_not_stored(self): + store = _make_store() + with pytest.raises(TableNotStored, match="nonexistent"): + store.delete_rows("nonexistent", {"id": 1}) + store.dispose() + + def test_delete_rows_empty_values(self): + store = _make_store() + schema = _make_tz_schema() + store.ingest_table(_make_tz_df(), schema) + with pytest.raises(InvalidParameter, match="cannot be empty"): + store.delete_rows(schema.name, {}) + store.dispose() + + def test_delete_rows_wrong_columns(self): + store = _make_store() + schema = _make_tz_schema() + store.ingest_table(_make_tz_df(), schema) + with pytest.raises(InvalidParameter, match="must match the schema columns"): + store.delete_rows(schema.name, {"wrong_column": 1}) + store.dispose() + + def test_delete_rows_no_matching_rows(self): + store = _make_store() + schema = _make_tz_schema() + store.ingest_table(_make_tz_df(), schema) + with pytest.raises(InvalidParameter, match="Failed to delete rows"): + store.delete_rows(schema.name, {"id": 999}) + store.dispose() + + +class TestWriteParquetErrors: + def test_write_table_to_parquet_not_stored(self, tmp_path): + store = _make_store() + with pytest.raises(TableNotStored, match="nonexistent"): + store.write_table_to_parquet("nonexistent", tmp_path / "out.parquet") + store.dispose() + + +class TestConvertTimeZoneAlreadyExists: + def test_convert_time_zone_dst_exists(self): + store = _make_store() + schema = _make_tz_schema() + store.ingest_table(_make_tz_df(), schema) + + # First conversion creates the destination table + store.convert_time_zone(schema.name, ZoneInfo("US/Eastern")) + + # Second conversion to the same tz should fail because dst table exists + with pytest.raises(TableAlreadyExists): + store.convert_time_zone(schema.name, ZoneInfo("US/Eastern")) + store.dispose() + + def test_localize_time_zone_dst_exists(self): + store = _make_store() + schema = _make_ntz_schema() + df = pd.DataFrame( + { + "id": [1, 1, 1], + "timestamp": pd.to_datetime( + ["2020-01-01 00:00:00", "2020-01-01 01:00:00", "2020-01-01 02:00:00"], + ), + "value": [1.0, 2.0, 3.0], + } + ) + store.ingest_table(df, schema) + + # First localization (must use standard tz without DST) + store.localize_time_zone(schema.name, ZoneInfo("EST")) + + # Second localization to the same tz should fail + with pytest.raises(TableAlreadyExists): + store.localize_time_zone(schema.name, ZoneInfo("EST")) + store.dispose() + + +class TestSchemaManager: + def test_schema_manager_property(self): + store = _make_store() + mgr = store.schema_manager + assert mgr is not None + store.dispose() From 472d7bc09a178c7c7e97f102c8857b7679d1a59e Mon Sep 17 00:00:00 2001 From: Daniel Thom Date: Sun, 12 Apr 2026 09:31:17 -0600 Subject: [PATCH 17/48] Address PR review feedback - Allow injecting an existing ibis connection into DuckDB/SQLite backends, mirroring SparkBackend(session=...); external connections are not disposed by the backend. - Remove IbisBackend.reconnect; replace with backend-native backup(): DuckDB uses disconnect+copy+reconnect internally, SQLite uses the sqlite3 online backup API, Spark raises InvalidOperation. - Validate insert column sets on DuckDB/SQLite/Spark and raise InvalidParameter on mismatch instead of silent reindex. - DuckDB write_parquet uses COPY (FORMAT PARQUET) for the unpartitioned path instead of materializing to pandas. - SparkBackend.dispose always disconnects the ibis connection and stops the session only when owned; factor temp-view handling into a helper. - Collapse duplicated _write_to_{duckdb,sqlite,spark} into a single _apply_if_exists helper. - ObjectType is a StrEnum; make_backend raises InvalidParameter for unknown names; schema_manager imports moved to top. - Restore docstrings on convert/localize time zone methods and read_raw_query. - Relax pyspark pin to >=4.0,<5. - CI: install package editable and pass --cov=chronify so Codecov can map coverage paths back to repo sources. Co-Authored-By: Claude Opus 4.6 --- .github/workflows/ci.yml | 4 +- pyproject.toml | 2 +- src/chronify/ibis/__init__.py | 3 +- src/chronify/ibis/base.py | 13 +- src/chronify/ibis/duckdb_backend.py | 96 +++++++++++-- src/chronify/ibis/functions.py | 73 ++-------- src/chronify/ibis/spark_backend.py | 43 ++++-- src/chronify/ibis/sqlite_backend.py | 86 +++++++++-- src/chronify/schema_manager.py | 3 +- src/chronify/store.py | 215 ++++++++++++++++++++++++++-- tests/test_store.py | 2 +- 11 files changed, 421 insertions(+), 119 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index c571767..80e2984 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -29,10 +29,10 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - python -m pip install ".[dev,spark]" + python -m pip install -e ".[dev,spark]" - name: Run pytest with coverage run: | - pytest -v --cov --cov-report=xml + pytest -v --cov=chronify --cov-report=xml:coverage.xml - name: codecov uses: codecov/codecov-action@v4.2.0 if: ${{ matrix.os == env.DEFAULT_OS && matrix.python-version == env.DEFAULT_PYTHON }} diff --git a/pyproject.toml b/pyproject.toml index 6b8450f..91aa00d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,7 +40,7 @@ dependencies = [ [project.optional-dependencies] spark = [ "ibis-framework[pyspark]", - "pyspark == 4.0.0", + "pyspark >= 4.0, < 5", ] dev = [ diff --git a/src/chronify/ibis/__init__.py b/src/chronify/ibis/__init__.py index 926012c..607780a 100644 --- a/src/chronify/ibis/__init__.py +++ b/src/chronify/ibis/__init__.py @@ -1,5 +1,6 @@ """Ibis backend abstraction layer for Chronify.""" +from chronify.exceptions import InvalidParameter from chronify.ibis.base import IbisBackend, ObjectType from chronify.ibis.duckdb_backend import DuckDBBackend from chronify.ibis.sqlite_backend import SQLiteBackend @@ -40,4 +41,4 @@ def make_backend( return SparkBackend(**kwargs) case _: msg = f"Unsupported backend: {name}. Choose from: duckdb, sqlite, spark" - raise ValueError(msg) + raise InvalidParameter(msg) diff --git a/src/chronify/ibis/base.py b/src/chronify/ibis/base.py index 839c4ea..d75b2f5 100644 --- a/src/chronify/ibis/base.py +++ b/src/chronify/ibis/base.py @@ -2,7 +2,7 @@ from abc import ABC, abstractmethod from contextlib import contextmanager -from enum import Enum +from enum import StrEnum from typing import Any, Generator, cast import ibis @@ -11,7 +11,7 @@ from loguru import logger -class ObjectType(Enum): +class ObjectType(StrEnum): TABLE = "table" VIEW = "view" @@ -135,8 +135,13 @@ def dispose(self) -> None: """Dispose of the backend connection.""" self.connection.disconnect() - def reconnect(self) -> None: - """Reconnect to the database. Subclasses should override if needed.""" + @abstractmethod + def backup(self, dst: str) -> None: + """Copy the database to a new location. + + Not supported for in-memory databases or backends without persistent + file storage (e.g., Spark). + """ @contextmanager def transaction(self) -> Generator[list[tuple[str, ObjectType]], None, None]: diff --git a/src/chronify/ibis/duckdb_backend.py b/src/chronify/ibis/duckdb_backend.py index fa8dc04..6463f64 100644 --- a/src/chronify/ibis/duckdb_backend.py +++ b/src/chronify/ibis/duckdb_backend.py @@ -1,5 +1,6 @@ """DuckDB backend implementation for Ibis.""" +import shutil from pathlib import Path from typing import Any, cast @@ -8,16 +9,46 @@ import pandas as pd from loguru import logger +from chronify.exceptions import ConflictingInputsError, InvalidOperation, InvalidParameter from chronify.ibis.base import IbisBackend, ObjectType class DuckDBBackend(IbisBackend): """Ibis backend for DuckDB databases.""" - def __init__(self, database: str | Path = ":memory:") -> None: - db = str(database) - self._database = None if db == ":memory:" else db - self._connection = ibis.duckdb.connect(db) + def __init__( + self, + database: str | Path = ":memory:", + connection: ibis.BaseBackend | None = None, + ) -> None: + """Construct a DuckDBBackend. + + Parameters + ---------- + database + Path to a DuckDB database file, or ``":memory:"`` for an in-memory + database. Ignored when ``connection`` is provided. + connection + Optional pre-existing ibis DuckDB connection. When provided, the + backend does not own the connection and will not disconnect it on + ``dispose()``. ``database`` is inferred from the connection when + possible; otherwise ``backup()`` is unavailable. + """ + if connection is not None and str(database) != ":memory:": + msg = f"{database=} and {connection=} cannot both be set" + raise ConflictingInputsError(msg) + + self._owns_connection = connection is None + if connection is None: + db = str(database) + self._database = None if db == ":memory:" else db + self._connection = ibis.duckdb.connect(db) + else: + if connection.name != "duckdb": + msg = f"DuckDBBackend requires a DuckDB ibis connection, got {connection.name!r}" + raise InvalidParameter(msg) + self._connection = connection + self._database = _infer_duckdb_path(connection) @property def name(self) -> str: @@ -60,7 +91,8 @@ def table(self, name: str) -> ir.Table: def insert(self, name: str, data: pd.DataFrame) -> None: con = self._connection.con # raw duckdb connection target_columns = list(self.table(name).columns) - ordered_data = data.reindex(columns=target_columns) + _validate_insert_columns(name, target_columns, list(data.columns)) + ordered_data = data.loc[:, target_columns] quoted_columns = ", ".join(f'"{col}"' for col in target_columns) quoted_name = _quote_identifier(name) con.register("__insert_df", ordered_data) @@ -93,16 +125,16 @@ def write_parquet( path: str, partition_by: list[str] | None = None, ) -> None: + escaped_path = path.replace("'", "''") + sql = self._connection.compile(expr) if partition_by: partition_clause = ", ".join(_quote_identifier(c) for c in partition_by) - escaped_path = path.replace("'", "''") - sql = self._connection.compile(expr) self._connection.raw_sql( f"COPY ({sql}) TO '{escaped_path}' " f"(FORMAT PARQUET, PARTITION_BY ({partition_clause}))" ) else: - expr.to_parquet(path) + self._connection.raw_sql(f"COPY ({sql}) TO '{escaped_path}' (FORMAT PARQUET)") def create_view_from_parquet(self, path: str, name: str) -> tuple[ir.Table, ObjectType]: parquet_path = Path(path) @@ -127,13 +159,49 @@ def execute_sql_to_df(self, query: str) -> pd.DataFrame: return cast(pd.DataFrame, result.fetch_df()) def dispose(self) -> None: + if self._owns_connection: + self._connection.disconnect() + + def backup(self, dst: str) -> None: + if self._database is None: + msg = "backup is only supported with a database backed by a file" + raise InvalidOperation(msg) + if not self._owns_connection: + msg = "backup is not supported for externally-provided DuckDB connections" + raise InvalidOperation(msg) + src = self._database self._connection.disconnect() - - def reconnect(self) -> None: - if self._database is not None: - self._connection = ibis.duckdb.connect(self._database) - else: - logger.warning("Cannot reconnect to an in-memory DuckDB database.") + try: + shutil.copyfile(src, dst) + finally: + self._connection = ibis.duckdb.connect(src) + + +def _infer_duckdb_path(connection: ibis.BaseBackend) -> str | None: + """Return the database file path for an ibis DuckDB connection, or None for in-memory.""" + try: + result = connection.con.execute( + "SELECT path FROM duckdb_databases() WHERE database_name = current_database()" + ).fetchone() + except Exception: + return None + if not result: + return None + path = result[0] + return None if not path else str(path) + + +def _validate_insert_columns( + table_name: str, target_columns: list[str], data_columns: list[str] +) -> None: + missing = [c for c in target_columns if c not in data_columns] + extra = [c for c in data_columns if c not in target_columns] + if missing or extra: + msg = ( + f"Insert data columns do not match table {table_name!r}. " + f"Missing: {missing}. Extra: {extra}." + ) + raise InvalidParameter(msg) def _quote_identifier(identifier: str) -> str: diff --git a/src/chronify/ibis/functions.py b/src/chronify/ibis/functions.py index df02ea5..80aba58 100644 --- a/src/chronify/ibis/functions.py +++ b/src/chronify/ibis/functions.py @@ -110,16 +110,20 @@ def write_table( _check_one_config_per_datetime_column(configs) df = _normalize_timestamps(df, configs) - match backend.name: - case "duckdb": - _write_to_duckdb(backend, df, table_name, if_exists) - case "sqlite": - _write_to_sqlite(backend, df, table_name, configs, if_exists) - case "spark": - _write_to_spark(backend, df, table_name, if_exists) - case _: - msg = f"Unsupported backend: {backend.name}" - raise NotImplementedError(msg) + if backend.name not in {"duckdb", "sqlite", "spark"}: + msg = f"Unsupported backend: {backend.name}" + raise NotImplementedError(msg) + + if backend.name == "sqlite": + # SQLite-specific: ensure TZ timestamps are stored as UTC text. + # _normalize_timestamps already ran, so NTZ columns are tz-naive and + # TZ columns are tz-aware UTC. This step converts TZ to UTC for storage. + copied = False + for config in configs: + if isinstance(config, _DATETIME_RANGES): + df, copied = _convert_database_input_for_datetime(df, config, copied) + + _apply_if_exists(backend, df, table_name, if_exists) def write_parquet( @@ -233,54 +237,7 @@ def _convert_spark_output_for_datetime(df: pd.DataFrame, config: DatetimeRanges) df[config.time_column] = col.dt.tz_convert(None).astype("datetime64[us]") -def _write_to_duckdb( - backend: IbisBackend, - df: pd.DataFrame, - table_name: str, - if_exists: str, -) -> None: - match if_exists: - case "append": - backend.insert(table_name, df) - case "replace": - backend.drop_table(table_name) - backend.create_table(table_name, df) - case "fail": - backend.create_table(table_name, df) - case _: - msg = f"Invalid if_exists value: {if_exists}" - raise InvalidOperation(msg) - - -def _write_to_sqlite( - backend: IbisBackend, - df: pd.DataFrame, - table_name: str, - configs: Sequence[TimeBaseModel], - if_exists: str, -) -> None: - # SQLite-specific: ensure TZ timestamps are stored as UTC text. - # _normalize_timestamps already ran, so NTZ columns are tz-naive and - # TZ columns are tz-aware UTC. This step converts TZ to UTC for storage. - copied = False - for config in configs: - if isinstance(config, _DATETIME_RANGES): - df, copied = _convert_database_input_for_datetime(df, config, copied) - - match if_exists: - case "append": - backend.insert(table_name, df) - case "replace": - backend.drop_table(table_name) - backend.create_table(table_name, df) - case "fail": - backend.create_table(table_name, df) - case _: - msg = f"Invalid if_exists value: {if_exists}" - raise InvalidOperation(msg) - - -def _write_to_spark( +def _apply_if_exists( backend: IbisBackend, df: pd.DataFrame, table_name: str, diff --git a/src/chronify/ibis/spark_backend.py b/src/chronify/ibis/spark_backend.py index b664d47..8e16535 100644 --- a/src/chronify/ibis/spark_backend.py +++ b/src/chronify/ibis/spark_backend.py @@ -2,7 +2,8 @@ import uuid import shutil -from typing import Any, cast +from contextlib import contextmanager +from typing import Any, Generator, cast from pathlib import Path from urllib.parse import urlparse, unquote @@ -12,7 +13,7 @@ from loguru import logger from pandas import DatetimeTZDtype -from chronify.exceptions import InvalidParameter +from chronify.exceptions import InvalidOperation, InvalidParameter from chronify.ibis.base import IbisBackend, ObjectType @@ -88,20 +89,27 @@ def table(self, name: str) -> ir.Table: def insert(self, name: str, data: pd.DataFrame) -> None: # Spark doesn't support INSERT directly -- create a temp view and insert via SQL target_columns = list(self.table(name).columns) - data = data.reindex(columns=target_columns) + _validate_insert_columns(name, target_columns, list(data.columns)) + data = data.loc[:, target_columns] data = self._prepare_data_for_spark(data) spark_df = self._session.createDataFrame(data) - tmp_view = f"__insert_tmp_{uuid.uuid4().hex}" - spark_df.createOrReplaceTempView(tmp_view) quoted_name = _quote_identifier(name) col_list = ", ".join(_quote_identifier(c) for c in target_columns) - try: + with self._temp_view(spark_df) as tmp_view: self._session.sql( f"INSERT INTO {quoted_name} ({col_list}) SELECT {col_list} FROM {tmp_view}" ) + logger.trace("Inserted {} rows into {}", len(data), name) + + @contextmanager + def _temp_view(self, spark_df: Any) -> Generator[str, None, None]: + """Register ``spark_df`` as a uniquely-named temp view; drop on exit.""" + tmp_view = f"__chronify_tmp_{uuid.uuid4().hex}" + spark_df.createOrReplaceTempView(tmp_view) + try: + yield tmp_view finally: self._session.catalog.dropTempView(tmp_view) - logger.trace("Inserted {} rows into {}", len(data), name) def delete_rows(self, name: str, values: dict[str, Any]) -> None: # Spark 3.4+ supports parameterized SQL via the ``args`` keyword. @@ -154,11 +162,13 @@ def execute_sql_to_df(self, query: str) -> pd.DataFrame: return cast(pd.DataFrame, self._session.sql(query).toPandas()) def dispose(self) -> None: + self._connection.disconnect() if self._owns_session: - self._connection.disconnect() + self._session.stop() - def reconnect(self) -> None: - pass # Spark sessions are long-lived + def backup(self, dst: str) -> None: + msg = "backup is not supported for the Spark backend" + raise InvalidOperation(msg) def _remove_managed_table_location(self, name: str) -> None: location = self._session.conf.get("spark.sql.warehouse.dir", "spark-warehouse") @@ -196,6 +206,19 @@ def _validate_session(session: Any) -> None: raise InvalidParameter(msg) +def _validate_insert_columns( + table_name: str, target_columns: list[str], data_columns: list[str] +) -> None: + missing = [c for c in target_columns if c not in data_columns] + extra = [c for c in data_columns if c not in target_columns] + if missing or extra: + msg = ( + f"Insert data columns do not match table {table_name!r}. " + f"Missing: {missing}. Extra: {extra}." + ) + raise InvalidParameter(msg) + + def _quote_identifier(identifier: str) -> str: """Quote a SQL identifier for Spark SQL, escaping embedded backticks.""" escaped = identifier.replace("`", "``") diff --git a/src/chronify/ibis/sqlite_backend.py b/src/chronify/ibis/sqlite_backend.py index 383ba0d..468f760 100644 --- a/src/chronify/ibis/sqlite_backend.py +++ b/src/chronify/ibis/sqlite_backend.py @@ -1,5 +1,6 @@ """SQLite backend implementation for Ibis.""" +import sqlite3 from datetime import datetime from pathlib import Path from typing import Any, cast @@ -9,6 +10,7 @@ import pandas as pd from loguru import logger +from chronify.exceptions import ConflictingInputsError, InvalidOperation, InvalidParameter from chronify.ibis.base import IbisBackend, ObjectType @@ -31,10 +33,38 @@ def _adapt_value(v: Any) -> Any: class SQLiteBackend(IbisBackend): """Ibis backend for SQLite databases.""" - def __init__(self, database: str | Path = ":memory:") -> None: - db = str(database) - self._database = None if db == ":memory:" else db - self._connection = ibis.sqlite.connect(db) + def __init__( + self, + database: str | Path = ":memory:", + connection: ibis.BaseBackend | None = None, + ) -> None: + """Construct a SQLiteBackend. + + Parameters + ---------- + database + Path to a SQLite database file, or ``":memory:"`` for an in-memory + database. Ignored when ``connection`` is provided. + connection + Optional pre-existing ibis SQLite connection. When provided, the + backend does not own the connection and will not disconnect it on + ``dispose()``. + """ + if connection is not None and str(database) != ":memory:": + msg = f"{database=} and {connection=} cannot both be set" + raise ConflictingInputsError(msg) + + self._owns_connection = connection is None + if connection is None: + db = str(database) + self._database = None if db == ":memory:" else db + self._connection = ibis.sqlite.connect(db) + else: + if connection.name != "sqlite": + msg = f"SQLiteBackend requires a SQLite ibis connection, got {connection.name!r}" + raise InvalidParameter(msg) + self._connection = connection + self._database = _infer_sqlite_path(connection) @property def name(self) -> str: @@ -81,13 +111,14 @@ def insert(self, name: str, data: pd.DataFrame) -> None: # Use raw SQLite cursor for parameterized inserts con = self._connection.con # raw sqlite3 connection table = self._connection.table(name) - columns = table.columns + columns = list(table.columns) + _validate_insert_columns(name, columns, list(data.columns)) placeholders = ", ".join(["?"] * len(columns)) col_list = ", ".join(_quote_identifier(c) for c in columns) quoted_name = _quote_identifier(name) sql = f"INSERT INTO {quoted_name} ({col_list}) VALUES ({placeholders})" - ordered = data.reindex(columns=columns) + ordered = data.loc[:, columns] rows = [tuple(_adapt_value(v) for v in row) for row in ordered.itertuples(index=False)] cursor = con.cursor() cursor.executemany(sql, rows) @@ -141,11 +172,44 @@ def execute_sql_to_df(self, query: str) -> pd.DataFrame: return pd.DataFrame(rows, columns=columns) def dispose(self) -> None: - self._connection.disconnect() - - def reconnect(self) -> None: - db = self._database if self._database else ":memory:" - self._connection = ibis.sqlite.connect(db) + if self._owns_connection: + self._connection.disconnect() + + def backup(self, dst: str) -> None: + if self._database is None: + msg = "backup is only supported with a database backed by a file" + raise InvalidOperation(msg) + dst_con = sqlite3.connect(dst) + try: + self._connection.con.backup(dst_con) + finally: + dst_con.close() + + +def _infer_sqlite_path(connection: ibis.BaseBackend) -> str | None: + """Return the database file path for an ibis SQLite connection, or None for in-memory.""" + try: + row = connection.con.execute("PRAGMA database_list").fetchone() + except Exception: + return None + if not row: + return None + # PRAGMA database_list returns (seq, name, file); empty string => in-memory. + path = row[2] + return None if not path else str(path) + + +def _validate_insert_columns( + table_name: str, target_columns: list[str], data_columns: list[str] +) -> None: + missing = [c for c in target_columns if c not in data_columns] + extra = [c for c in data_columns if c not in target_columns] + if missing or extra: + msg = ( + f"Insert data columns do not match table {table_name!r}. " + f"Missing: {missing}. Extra: {extra}." + ) + raise InvalidParameter(msg) def _quote_identifier(identifier: str) -> str: diff --git a/src/chronify/schema_manager.py b/src/chronify/schema_manager.py index b48ab07..f4c9bdf 100644 --- a/src/chronify/schema_manager.py +++ b/src/chronify/schema_manager.py @@ -1,5 +1,6 @@ import json +import ibis import pandas as pd from loguru import logger @@ -25,8 +26,6 @@ def __init__(self, backend: IbisBackend) -> None: logger.info("Initialized new database: {}", self._backend.database) def _create_schemas_table(self) -> None: - import ibis - # Uniqueness of `name` is enforced in `add_schema` rather than via a # unique index, since Spark SQL does not support CREATE UNIQUE INDEX. schema = ibis.schema({"name": "string", "schema": "string"}) diff --git a/src/chronify/store.py b/src/chronify/store.py index 388a6f4..2b14c7d 100644 --- a/src/chronify/store.py +++ b/src/chronify/store.py @@ -1,6 +1,5 @@ from collections.abc import Iterable from pathlib import Path -import shutil from typing import Any, Optional from datetime import tzinfo @@ -13,7 +12,6 @@ import chronify.duckdb.functions as ddbf from chronify.exceptions import ( ConflictingInputsError, - InvalidOperation, InvalidParameter, InvalidTable, TableAlreadyExists, @@ -136,18 +134,9 @@ def try_get_table(self, name: str) -> ir.Table | None: def backup(self, dst: Path | str, overwrite: bool = False) -> None: """Copy the database to a new location. Not yet supported for in-memory databases.""" - if self._backend.database is None: - msg = "backup is only supported with a database backed by a file" - raise InvalidOperation(msg) path = to_path(dst) check_overwrite(path, overwrite) - src_file = Path(self._backend.database) - - self._backend.dispose() - try: - shutil.copyfile(src_file, path) - finally: - self._backend.reconnect() + self._backend.backup(str(path)) logger.info("Copied database to {}", path) @property @@ -471,7 +460,68 @@ def convert_time_zone_by_column( output_file: Optional[Path] = None, check_mapped_timestamps: bool = False, ) -> TableSchema: - """Convert the time zone of the existing table to time zone(s) defined by a column.""" + """ + Convert the time zone of the existing table represented by src_name to new time zone(s) defined by a column + + Parameters + ---------- + src_name + Refers to the table name of the source data. + time_zone_column + Name of the time zone column for conversion. + wrap_time_allowed + Defines whether the time column is allowed to be wrapped to reflect the same time + range as the src_name schema in tz-naive clock time + output_file + If set, write the mapped table to this Parquet file. + check_mapped_timestamps + Perform time checks on the result of the mapping operation. This can be slow and + is not required. + + Raises + ------ + TableAlreadyExists + Raised if the dst_schema name already exists. + + Examples + -------- + >>> store = Store() + >>> start = datetime(year=2018, month=1, day=1, tzinfo=ZoneInfo("Etc/GMT+5")) + >>> freq = timedelta(hours=1) + >>> hours_per_year = 8760 + >>> num_time_arrays = 3 + >>> df = pd.DataFrame( + ... { + ... "id": np.concatenate( + ... [np.repeat(i, hours_per_year) for i in range(1, 1 + num_time_arrays)] + ... ), + ... "timestamp": np.tile( + ... pd.date_range(start, periods=hours_per_year, freq="h"), num_time_arrays + ... ), + ... "time_zone": np.repeat(["US/Eastern", "US/Mountain", "None"], hours_per_year), + ... "value": np.random.random(hours_per_year * num_time_arrays), + ... } + ... ) + >>> schema = TableSchema( + ... name="some_data", + ... time_config=DatetimeRange( + ... time_column="timestamp", + ... start=start, + ... length=hours_per_year, + ... resolution=freq, + ... ), + ... time_array_id_columns=["id"], + ... value_column="value", + ... ) + >>> store.ingest_table(df, schema) + >>> time_zone_column = "time_zone" + >>> dst_schema = store.convert_time_zone_by_column( + ... schema.name, + ... time_zone_column, + ... wrap_time_allowed=False, + ... check_mapped_timestamps=True, + ... ) + """ src_schema = self._schema_mgr.get_schema(src_name) tzc = TimeZoneConverterByColumn( self._backend, src_schema, time_zone_column, wrap_time_allowed @@ -496,7 +546,66 @@ def localize_time_zone( output_file: Optional[Path] = None, check_mapped_timestamps: bool = False, ) -> TableSchema: - """Localize the time zone of the existing table to a specified time zone.""" + """ + Localize the time zone of the existing table represented by src_name to a specified time zone + + Parameters + ---------- + src_name + Refers to the table name of the source data. + time_zone + Standard time zone to localize to. If None, keep as tz-naive. + output_file + If set, write the mapped table to this Parquet file. + check_mapped_timestamps + Perform time checks on the result of the mapping operation. This can be slow and + is not required. + + Raises + ------ + TableAlreadyExists + Raised if the dst_schema name already exists. + + Returns + ------- + TableSchema + The schema of the newly created table. + + Examples + -------- + >>> store = Store() + >>> start = datetime(year=2018, month=1, day=1) # tz-naive + >>> freq = timedelta(hours=1) + >>> hours_per_year = 8760 + >>> num_time_arrays = 1 + >>> df = pd.DataFrame( + ... { + ... "id": np.concatenate( + ... [np.repeat(i, hours_per_year) for i in range(1, 1 + num_time_arrays)] + ... ), + ... "timestamp": np.tile( + ... pd.date_range(start, periods=hours_per_year, freq="h"), num_time_arrays + ... ), + ... "value": np.random.random(hours_per_year * num_time_arrays), + ... } + ... ) + >>> schema = TableSchema( + ... name="some_data", + ... time_config=DatetimeRange( + ... time_column="timestamp", + ... start=start, + ... length=hours_per_year, + ... resolution=freq, + ... ), + ... time_array_id_columns=["id"], + ... value_column="value", + ... ) + >>> store.ingest_table(df, schema) + >>> to_time_zone = ZoneInfo("Etc/GMT+5") + >>> dst_schema = store.localize_time_zone( + ... schema.name, to_time_zone, check_mapped_timestamps=True + ... ) + """ src_schema = self._schema_mgr.get_schema(src_name) tzl = TimeZoneLocalizer(self._backend, src_schema, time_zone) @@ -519,7 +628,71 @@ def localize_time_zone_by_column( output_file: Optional[Path] = None, check_mapped_timestamps: bool = False, ) -> TableSchema: - """Localize the time zone of the existing table to time zones defined by a column.""" + """ + Localize the time zone of the existing table represented by src_name to time zones defined by a column + + Parameters + ---------- + src_name + Refers to the table name of the source data. + time_zone_column + Name of the time zone column for localization, default to None + output_file + If set, write the mapped table to this Parquet file. + check_mapped_timestamps + Perform time checks on the result of the mapping operation. This can be slow and + is not required. + + Raises + ------ + TableAlreadyExists + Raised if the dst_schema name already exists. + + Returns + ------- + TableSchema + The schema of the newly created table. + + Examples + -------- + >>> store = Store() + >>> start = datetime(year=2018, month=1, day=1) # tz-naive + >>> freq = timedelta(hours=1) + >>> hours_per_year = 8760 + >>> num_time_arrays = 3 + >>> df = pd.DataFrame( + ... { + ... "id": np.concatenate( + ... [np.repeat(i, hours_per_year) for i in range(1, 1 + num_time_arrays)] + ... ), + ... "timestamp": np.tile( + ... pd.date_range(start, periods=hours_per_year, freq="h"), num_time_arrays + ... ), + ... "time_zone": np.repeat( + ... ["Etc/GMT+5", "Etc/GMT+6", "Etc/GMT+7"], hours_per_year + ... ), # EST, CST, MST + ... "value": np.random.random(hours_per_year * num_time_arrays), + ... } + ... ) + >>> schema = TableSchema( + ... name="some_data", + ... time_config=DatetimeRange( + ... time_column="timestamp", + ... start=start, + ... length=hours_per_year, + ... resolution=freq, + ... ), + ... time_array_id_columns=["id"], + ... value_column="value", + ... ) + >>> store.ingest_table(df, schema) + >>> time_zone_column = "time_zone" + >>> dst_schema = store.localize_time_zone_by_column( + ... schema.name, + ... time_zone_column, + ... check_mapped_timestamps=True, + ... ) + """ src_schema = self._schema_mgr.get_schema(src_name) tzl = TimeZoneLocalizerByColumn(self._backend, src_schema, time_zone_column) @@ -565,6 +738,18 @@ def read_raw_query(self, query: str) -> pd.DataFrame: """Execute a query directly on the backend and return the results as a DataFrame. Note: Unlike :meth:`read_query`, no conversion of timestamps is performed. + Timestamps will be in the format of the underlying database. SQLite backends will return + strings instead of datetime. + + Parameters + ---------- + query + SQL query to execute. + + Examples + -------- + >>> store = Store() + >>> df = store.read_raw_query("SELECT * from my_table WHERE column = 'value1'") """ return self._backend.execute_sql_to_df(query) diff --git a/tests/test_store.py b/tests/test_store.py index 8fb4f7c..ec9d512 100644 --- a/tests/test_store.py +++ b/tests/test_store.py @@ -595,7 +595,7 @@ def test_create_methods(iter_backend_names, tmp_path): def test_invalid_backend(): - with pytest.raises(ValueError): + with pytest.raises(InvalidParameter): Store(backend_name="hive") From cdcef840fd99be9f806a8c95c0209ba6c5fadc29 Mon Sep 17 00:00:00 2001 From: Daniel Thom Date: Sun, 12 Apr 2026 10:44:48 -0600 Subject: [PATCH 18/48] Update lab name --- README.md | 8 ++++---- docs/conf.py | 4 ++-- docs/how_tos/spark_backend.md | 4 ++-- pyproject.toml | 6 +++--- 4 files changed, 11 insertions(+), 11 deletions(-) diff --git a/README.md b/README.md index ac08053..8dbb507 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@ # chronify -[![Documentation](https://img.shields.io/badge/docs-ready-blue.svg)](https://nrel.github.io/chronify) -[![codecov](https://codecov.io/gh/nrel/chronify/graph/badge.svg?token=WIY2KAOX63)](https://codecov.io/gh/nrel/chronify) +[![Documentation](https://img.shields.io/badge/docs-ready-blue.svg)](https://natlabrockies.github.io/chronify) +[![codecov](https://codecov.io/gh/natlabrockies/chronify/graph/badge.svg?token=WIY2KAOX63)](https://codecov.io/gh/natlabrockies/chronify) This package implements a store for time series data in support of Python-based @@ -32,5 +32,5 @@ $ pre-commit install ``` ## License -chronify is developed under NREL Software Record SWR-21-52, "demand-side grid model". -[License](https://github.com/NREL/chronify/blob/main/LICENSE). +chronify is developed under NLR Software Record SWR-21-52, "demand-side grid model". +[License](https://github.com/NatLabRockies/chronify/blob/main/LICENSE). diff --git a/docs/conf.py b/docs/conf.py index a5708ac..867b2e6 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -7,8 +7,8 @@ # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information project = "Chronify" -copyright = "2024, NREL" -author = "NREL" +copyright = "2026, Alliance for Energy Innovation" +author = "NLR" release = "0.1.0" # -- General configuration --------------------------------------------------- diff --git a/docs/how_tos/spark_backend.md b/docs/how_tos/spark_backend.md index 5e893bf..5da0ba3 100644 --- a/docs/how_tos/spark_backend.md +++ b/docs/how_tos/spark_backend.md @@ -23,8 +23,8 @@ $ $SPARK_HOME/sbin/start-thriftserver.sh --master=spark://$(hostname):7077 The URL to connect to this server is `hive://localhost:10000/default` ## Installation on an HPC -The chronify development team uses these -[scripts](https://github.com/NREL/HPC/tree/master/applications/spark) to run Spark on NREL's HPC. +The chronify development team uses this +[package](https://github.com/NatLabRockies/sparkctl) to run Spark on NLR's HPC. ## Chronify Usage This example creates a chronify Store with Spark as the backend and then adds a view to a Parquet diff --git a/pyproject.toml b/pyproject.toml index 91aa00d..f6d1b9d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -60,9 +60,9 @@ dev = [ ] [project.urls] -Documentation = "https://github.com/NREL/chronify#readme" -Issues = "https://github.com/NREL/chronify/issues" -Source = "https://github.com/NREL/chronify" +Documentation = "https://github.com/NatLabRockies/chronify#readme" +Issues = "https://github.com/NatLabRockies/chronify/issues" +Source = "https://github.com/NatLabRockies/chronify" [tool.mypy] files = [ From 5ec2b2e9ba5cfba0af3705cb5100abe2a1d406d1 Mon Sep 17 00:00:00 2001 From: Daniel Thom Date: Sun, 12 Apr 2026 13:02:30 -0600 Subject: [PATCH 19/48] Spark fixes --- src/chronify/ibis/base.py | 21 ++++++++++- src/chronify/ibis/duckdb_backend.py | 17 +++++++-- src/chronify/ibis/functions.py | 9 ----- src/chronify/ibis/spark_backend.py | 57 ++++++++++++++++++++++------- src/chronify/ibis/sqlite_backend.py | 21 +++++++++-- src/chronify/store.py | 17 +++++---- 6 files changed, 105 insertions(+), 37 deletions(-) diff --git a/src/chronify/ibis/base.py b/src/chronify/ibis/base.py index d75b2f5..c692c54 100644 --- a/src/chronify/ibis/base.py +++ b/src/chronify/ibis/base.py @@ -19,6 +19,8 @@ class ObjectType(StrEnum): class IbisBackend(ABC): """Abstract base class defining the interface for Ibis database backends.""" + _table_cache: set[str] | None + @property @abstractmethod def name(self) -> str: @@ -119,7 +121,24 @@ def create_view_from_parquet(self, path: str, name: str) -> tuple[ir.Table, Obje def has_table(self, name: str) -> bool: """Check whether a table or view exists.""" - return name in self.list_tables() + if self._table_cache is None: + self._refresh_table_cache() + assert self._table_cache is not None + return name in self._table_cache + + def _refresh_table_cache(self) -> None: + self._table_cache = set(self.list_tables()) + + def _mark_table_created(self, name: str) -> None: + if self._table_cache is not None: + self._table_cache.add(name) + + def _mark_table_dropped(self, name: str) -> None: + if self._table_cache is not None: + self._table_cache.discard(name) + + def _invalidate_table_cache(self) -> None: + self._table_cache = None def execute_sql(self, query: str) -> Any: """Execute a raw SQL statement (no result expected).""" diff --git a/src/chronify/ibis/duckdb_backend.py b/src/chronify/ibis/duckdb_backend.py index 6463f64..656636b 100644 --- a/src/chronify/ibis/duckdb_backend.py +++ b/src/chronify/ibis/duckdb_backend.py @@ -38,6 +38,7 @@ def __init__( msg = f"{database=} and {connection=} cannot both be set" raise ConflictingInputsError(msg) + self._table_cache = None self._owns_connection = connection is None if connection is None: db = str(database) @@ -69,21 +70,29 @@ def create_table( schema: ibis.Schema | None = None, overwrite: bool = False, ) -> ir.Table: - return self._connection.create_table(name, obj=obj, schema=schema, overwrite=overwrite) + table = self._connection.create_table(name, obj=obj, schema=schema, overwrite=overwrite) + self._mark_table_created(name) + return table def create_view(self, name: str, expr: ir.Table) -> ir.Table: - return self._connection.create_view(name, expr, overwrite=False) + view = self._connection.create_view(name, expr, overwrite=False) + self._mark_table_created(name) + return view def drop_table(self, name: str) -> None: self._connection.drop_table(name, force=True) + self._mark_table_dropped(name) def drop_view(self, name: str) -> None: self._connection.drop_view(name, force=True) + self._mark_table_dropped(name) def list_tables(self) -> list[str]: tables = self._connection.list_tables() # Filter out internal ibis memtables - return [t for t in tables if not t.startswith("ibis_pandas_memtable_")] + tables = [t for t in tables if not t.startswith("ibis_pandas_memtable_")] + self._table_cache = set(tables) + return tables def table(self, name: str) -> ir.Table: return self._connection.table(name) @@ -147,11 +156,13 @@ def create_view_from_parquet(self, path: str, name: str) -> tuple[ir.Table, Obje self._connection.raw_sql( f"CREATE VIEW {quoted_name} AS SELECT * FROM read_parquet('{escaped_path}')" ) + self._mark_table_created(name) return self.table(name), ObjectType.VIEW def execute_sql(self, query: str) -> None: logger.trace("execute_sql: {}", query) self._connection.raw_sql(query) + self._invalidate_table_cache() def execute_sql_to_df(self, query: str) -> pd.DataFrame: logger.trace("execute_sql_to_df: {}", query) diff --git a/src/chronify/ibis/functions.py b/src/chronify/ibis/functions.py index 80aba58..77363b8 100644 --- a/src/chronify/ibis/functions.py +++ b/src/chronify/ibis/functions.py @@ -142,15 +142,6 @@ def write_parquet( else: expr = query - if backend.name == "spark" and isinstance(config, _DATETIME_RANGES): - df = backend.execute(expr) - _convert_spark_output_for_datetime(df, config) - if partition_columns: - df.to_parquet(output_file, partition_cols=partition_columns) - else: - df.to_parquet(output_file) - return - backend.write_parquet(expr, str(output_file), partition_by=partition_columns) diff --git a/src/chronify/ibis/spark_backend.py b/src/chronify/ibis/spark_backend.py index 8e16535..630718d 100644 --- a/src/chronify/ibis/spark_backend.py +++ b/src/chronify/ibis/spark_backend.py @@ -39,6 +39,7 @@ def __init__(self, session: Any = None) -> None: .getOrCreate() ) self._validate_session(session) + self._table_cache = None self._session = session self._connection = ibis.pyspark.connect(session) @@ -64,24 +65,36 @@ def create_table( if isinstance(obj, pd.DataFrame): obj = self._prepare_data_for_spark(obj) try: - return self._connection.create_table(name, obj=obj, schema=schema, overwrite=overwrite) + table = self._connection.create_table( + name, obj=obj, schema=schema, overwrite=overwrite + ) except Exception as exc: if "LOCATION_ALREADY_EXISTS" not in str(exc): raise self._remove_managed_table_location(name) - return self._connection.create_table(name, obj=obj, schema=schema, overwrite=overwrite) + table = self._connection.create_table( + name, obj=obj, schema=schema, overwrite=overwrite + ) + self._mark_table_created(name) + return table def create_view(self, name: str, expr: ir.Table) -> ir.Table: - return self._connection.create_view(name, expr, overwrite=False) + view = self._connection.create_view(name, expr, overwrite=False) + self._mark_table_created(name) + return view def drop_table(self, name: str) -> None: self._connection.drop_table(name, force=True) + self._mark_table_dropped(name) def drop_view(self, name: str) -> None: self._connection.drop_view(name, force=True) + self._mark_table_dropped(name) def list_tables(self) -> list[str]: - return cast(list[str], self._connection.list_tables()) + tables = cast(list[str], self._connection.list_tables()) + self._table_cache = set(tables) + return tables def table(self, name: str) -> ir.Table: return self._connection.table(name) @@ -123,12 +136,24 @@ def delete_rows(self, name: str, values: dict[str, Any]) -> None: except Exception as exc: if "does not support DELETE" not in str(exc): raise - df = self._connection.execute(self.table(name)) - for column, value in values.items(): - df = df[df[column] != value] - self.create_table(name, obj=df, overwrite=True) + self._overwrite_without_deleted_rows(name, where, args) logger.trace("Deleted rows from {} matching {}", name, values) + def _overwrite_without_deleted_rows(self, name: str, where: str, args: dict[str, Any]) -> None: + quoted_name = _quote_identifier(name) + tmp_name = f"__chronify_delete_{uuid.uuid4().hex}" + quoted_tmp = _quote_identifier(tmp_name) + try: + self._session.sql( + f"CREATE TABLE {quoted_tmp} AS " + f"SELECT * FROM {quoted_name} WHERE NOT ({where})", + args=args, + ) + self._session.sql(f"INSERT OVERWRITE TABLE {quoted_name} SELECT * FROM {quoted_tmp}") + finally: + self._session.sql(f"DROP TABLE IF EXISTS {quoted_tmp}") + self._remove_managed_table_location(tmp_name) + def execute(self, expr: ir.Expr) -> pd.DataFrame: return cast(pd.DataFrame, self._connection.execute(expr)) @@ -141,21 +166,27 @@ def write_parquet( path: str, partition_by: list[str] | None = None, ) -> None: - df = self._connection.execute(expr) + df = self._to_spark_dataframe(expr) + writer = df.write.mode("errorifexists") if partition_by: - spark_df = self._session.createDataFrame(df) - spark_df.write.partitionBy(*partition_by).parquet(path) + writer.partitionBy(*partition_by).parquet(path) else: - df.to_parquet(path) + writer.parquet(path) + + def _to_spark_dataframe(self, expr: ir.Table) -> Any: + sql = self._connection.compile(expr) + return self._session.sql(sql) def create_view_from_parquet(self, path: str, name: str) -> tuple[ir.Table, ObjectType]: spark_df = self._session.read.parquet(path) spark_df.createOrReplaceTempView(name) + self._mark_table_created(name) return self.table(name), ObjectType.VIEW def execute_sql(self, query: str) -> None: logger.trace("execute_sql: {}", query) self._session.sql(query) + self._invalidate_table_cache() def execute_sql_to_df(self, query: str) -> pd.DataFrame: logger.trace("execute_sql_to_df: {}", query) @@ -171,7 +202,7 @@ def backup(self, dst: str) -> None: raise InvalidOperation(msg) def _remove_managed_table_location(self, name: str) -> None: - location = self._session.conf.get("spark.sql.warehouse.dir", "spark-warehouse") + location = str(self._session.conf.get("spark.sql.warehouse.dir", "spark-warehouse")) parsed = urlparse(location) if parsed.scheme == "file": warehouse = Path(unquote(parsed.path)) diff --git a/src/chronify/ibis/sqlite_backend.py b/src/chronify/ibis/sqlite_backend.py index 468f760..87b32ad 100644 --- a/src/chronify/ibis/sqlite_backend.py +++ b/src/chronify/ibis/sqlite_backend.py @@ -54,6 +54,7 @@ def __init__( msg = f"{database=} and {connection=} cannot both be set" raise ConflictingInputsError(msg) + self._table_cache = None self._owns_connection = connection is None if connection is None: db = str(database) @@ -89,20 +90,31 @@ def create_table( # SQLite CREATE TABLE AS SELECT loses datetime type info. # Execute the expression first, then create from the DataFrame. df = self._connection.execute(obj) - return self._connection.create_table(name, obj=df, overwrite=overwrite) - return self._connection.create_table(name, obj=obj, schema=schema, overwrite=overwrite) + table = self._connection.create_table(name, obj=df, overwrite=overwrite) + else: + table = self._connection.create_table( + name, obj=obj, schema=schema, overwrite=overwrite + ) + self._mark_table_created(name) + return table def create_view(self, name: str, expr: ir.Table) -> ir.Table: - return self._connection.create_view(name, expr, overwrite=False) + view = self._connection.create_view(name, expr, overwrite=False) + self._mark_table_created(name) + return view def drop_table(self, name: str) -> None: self._connection.drop_table(name, force=True) + self._mark_table_dropped(name) def drop_view(self, name: str) -> None: self._connection.drop_view(name, force=True) + self._mark_table_dropped(name) def list_tables(self) -> list[str]: - return cast(list[str], self._connection.list_tables()) + tables = cast(list[str], self._connection.list_tables()) + self._table_cache = set(tables) + return tables def table(self, name: str) -> ir.Table: return self._connection.table(name) @@ -162,6 +174,7 @@ def execute_sql(self, query: str) -> None: con = self._connection.con con.execute(query) con.commit() + self._invalidate_table_cache() def execute_sql_to_df(self, query: str) -> pd.DataFrame: logger.trace("execute_sql_to_df: {}", query) diff --git a/src/chronify/store.py b/src/chronify/store.py index 2b14c7d..483e6a7 100644 --- a/src/chronify/store.py +++ b/src/chronify/store.py @@ -1,6 +1,6 @@ from collections.abc import Iterable from pathlib import Path -from typing import Any, Optional +from typing import Any, Optional, cast from datetime import tzinfo import duckdb @@ -849,7 +849,7 @@ def delete_rows( table = self._backend.table(name) predicates = [table[column] == value for column, value in time_array_id_values.items()] filtered = table.filter(predicates) - num_to_delete = int(filtered.count().execute()) + num_to_delete = int(cast(Any, filtered.count().execute())) self._backend.delete_rows(name, time_array_id_values) @@ -863,11 +863,14 @@ def delete_rows( time_array_id_values, ) - # Check if table is now empty - remaining = int(self._backend.table(name).count().execute()) - if remaining == 0: - logger.info("Delete empty table {}", name) - self.drop_table(name) + # Avoid an additional full distributed count on Spark. Local backends keep + # the historical behavior of removing the table after its final row group + # is deleted. + if self._backend.name != "spark": + remaining = int(cast(Any, self._backend.table(name).count().execute())) + if remaining == 0: + logger.info("Delete empty table {}", name) + self.drop_table(name) return num_to_delete From 827d849cd245493b0db91ca091d94724eabae9fe Mon Sep 17 00:00:00 2001 From: Daniel Thom Date: Sun, 12 Apr 2026 13:09:38 -0600 Subject: [PATCH 20/48] Restore DuckDB native fetch_df path and limit cache invalidation to DDL Route DuckDBBackend.execute() through the native duckdb cursor's fetch_df() to recover the zero-copy Arrow materialization used on main, and only invalidate the table-name cache on DDL statements. Co-Authored-By: Claude Opus 4.6 --- src/chronify/ibis/base.py | 16 +++++++++++++++- src/chronify/ibis/duckdb_backend.py | 10 ++++++++-- src/chronify/ibis/spark_backend.py | 5 +++-- src/chronify/ibis/sqlite_backend.py | 5 +++-- 4 files changed, 29 insertions(+), 7 deletions(-) diff --git a/src/chronify/ibis/base.py b/src/chronify/ibis/base.py index c692c54..b5c61b9 100644 --- a/src/chronify/ibis/base.py +++ b/src/chronify/ibis/base.py @@ -1,5 +1,6 @@ """Abstract base class for Ibis database backends.""" +import re from abc import ABC, abstractmethod from contextlib import contextmanager from enum import StrEnum @@ -10,6 +11,16 @@ import pandas as pd from loguru import logger +_DDL_RE = re.compile( + r"^\s*(?:WITH\s+.+?\s+)?(CREATE|DROP|ALTER|TRUNCATE|RENAME)\b", + re.IGNORECASE | re.DOTALL, +) + + +def _is_ddl(query: str) -> bool: + """Return True if the SQL statement changes the set of tables/views.""" + return _DDL_RE.match(query) is not None + class ObjectType(StrEnum): TABLE = "table" @@ -143,7 +154,10 @@ def _invalidate_table_cache(self) -> None: def execute_sql(self, query: str) -> Any: """Execute a raw SQL statement (no result expected).""" logger.trace("execute_sql: {}", query) - return self.connection.raw_sql(query) + result = self.connection.raw_sql(query) + if _is_ddl(query): + self._invalidate_table_cache() + return result def execute_sql_to_df(self, query: str) -> pd.DataFrame: """Execute a raw SQL query and return a DataFrame.""" diff --git a/src/chronify/ibis/duckdb_backend.py b/src/chronify/ibis/duckdb_backend.py index 656636b..79658c7 100644 --- a/src/chronify/ibis/duckdb_backend.py +++ b/src/chronify/ibis/duckdb_backend.py @@ -10,7 +10,7 @@ from loguru import logger from chronify.exceptions import ConflictingInputsError, InvalidOperation, InvalidParameter -from chronify.ibis.base import IbisBackend, ObjectType +from chronify.ibis.base import IbisBackend, ObjectType, _is_ddl class DuckDBBackend(IbisBackend): @@ -123,6 +123,11 @@ def delete_rows(self, name: str, values: dict[str, Any]) -> None: logger.trace("Deleted rows from {} matching {}", name, values) def execute(self, expr: ir.Expr) -> pd.DataFrame: + # Bypass Ibis's generic pandas materialization and use DuckDB's native + # cursor.fetch_df(), which is zero-copy from Arrow. + if isinstance(expr, ir.Table): + sql = self._connection.compile(expr) + return cast(pd.DataFrame, self._connection.con.execute(sql).fetch_df()) return cast(pd.DataFrame, self._connection.execute(expr)) def sql(self, query: str) -> ir.Table: @@ -162,7 +167,8 @@ def create_view_from_parquet(self, path: str, name: str) -> tuple[ir.Table, Obje def execute_sql(self, query: str) -> None: logger.trace("execute_sql: {}", query) self._connection.raw_sql(query) - self._invalidate_table_cache() + if _is_ddl(query): + self._invalidate_table_cache() def execute_sql_to_df(self, query: str) -> pd.DataFrame: logger.trace("execute_sql_to_df: {}", query) diff --git a/src/chronify/ibis/spark_backend.py b/src/chronify/ibis/spark_backend.py index 630718d..df6462a 100644 --- a/src/chronify/ibis/spark_backend.py +++ b/src/chronify/ibis/spark_backend.py @@ -14,7 +14,7 @@ from pandas import DatetimeTZDtype from chronify.exceptions import InvalidOperation, InvalidParameter -from chronify.ibis.base import IbisBackend, ObjectType +from chronify.ibis.base import IbisBackend, ObjectType, _is_ddl class SparkBackend(IbisBackend): @@ -186,7 +186,8 @@ def create_view_from_parquet(self, path: str, name: str) -> tuple[ir.Table, Obje def execute_sql(self, query: str) -> None: logger.trace("execute_sql: {}", query) self._session.sql(query) - self._invalidate_table_cache() + if _is_ddl(query): + self._invalidate_table_cache() def execute_sql_to_df(self, query: str) -> pd.DataFrame: logger.trace("execute_sql_to_df: {}", query) diff --git a/src/chronify/ibis/sqlite_backend.py b/src/chronify/ibis/sqlite_backend.py index 87b32ad..60ac2f0 100644 --- a/src/chronify/ibis/sqlite_backend.py +++ b/src/chronify/ibis/sqlite_backend.py @@ -11,7 +11,7 @@ from loguru import logger from chronify.exceptions import ConflictingInputsError, InvalidOperation, InvalidParameter -from chronify.ibis.base import IbisBackend, ObjectType +from chronify.ibis.base import IbisBackend, ObjectType, _is_ddl def _adapt_value(v: Any) -> Any: @@ -174,7 +174,8 @@ def execute_sql(self, query: str) -> None: con = self._connection.con con.execute(query) con.commit() - self._invalidate_table_cache() + if _is_ddl(query): + self._invalidate_table_cache() def execute_sql_to_df(self, query: str) -> pd.DataFrame: logger.trace("execute_sql_to_df: {}", query) From c2e8d06c49f683ec2ca9457f6e3b4dbd3e77368a Mon Sep 17 00:00:00 2001 From: Daniel Thom Date: Sun, 12 Apr 2026 13:25:43 -0600 Subject: [PATCH 21/48] Use transactions --- src/chronify/ibis/base.py | 18 +++++++++++-- src/chronify/ibis/duckdb_backend.py | 39 +++++++++++++++++++++++++---- src/chronify/ibis/functions.py | 32 +++++++++++++++++++---- src/chronify/ibis/spark_backend.py | 7 ++++-- src/chronify/ibis/sqlite_backend.py | 31 +++++++++++++++++++---- src/chronify/store.py | 12 ++++++--- tests/test_ibis_functions.py | 23 +++++++++++++++++ 7 files changed, 140 insertions(+), 22 deletions(-) diff --git a/src/chronify/ibis/base.py b/src/chronify/ibis/base.py index b5c61b9..c0475df 100644 --- a/src/chronify/ibis/base.py +++ b/src/chronify/ibis/base.py @@ -9,6 +9,7 @@ import ibis import ibis.expr.types as ir import pandas as pd +import pyarrow as pa from loguru import logger _DDL_RE = re.compile( @@ -51,7 +52,7 @@ def connection(self) -> ibis.BaseBackend: def create_table( self, name: str, - obj: pd.DataFrame | ir.Table | None = None, + obj: pd.DataFrame | pa.Table | ir.Table | None = None, schema: ibis.Schema | None = None, overwrite: bool = False, ) -> ir.Table: @@ -94,7 +95,7 @@ def table(self, name: str) -> ir.Table: """Return an ibis table expression for the named table.""" @abstractmethod - def insert(self, name: str, data: pd.DataFrame) -> None: + def insert(self, name: str, data: pd.DataFrame | pa.Table) -> None: """Insert data into an existing table.""" @abstractmethod @@ -176,6 +177,15 @@ def backup(self, dst: str) -> None: file storage (e.g., Spark). """ + def _begin_transaction(self) -> None: + """Start a real database transaction, if the backend supports one.""" + + def _commit_transaction(self) -> None: + """Commit a real database transaction, if one was started.""" + + def _rollback_transaction(self) -> None: + """Roll back a real database transaction, if one was started.""" + @contextmanager def transaction(self) -> Generator[list[tuple[str, ObjectType]], None, None]: """Context manager for pseudo-transactions. @@ -186,9 +196,11 @@ def transaction(self) -> Generator[list[tuple[str, ObjectType]], None, None]: Yields a list to which callers should append (name, ObjectType) tuples. """ created: list[tuple[str, ObjectType]] = [] + self._begin_transaction() try: yield created except Exception: + self._rollback_transaction() for obj_name, obj_type in reversed(created): try: if obj_type == ObjectType.TABLE: @@ -199,3 +211,5 @@ def transaction(self) -> Generator[list[tuple[str, ObjectType]], None, None]: except Exception: logger.warning("Failed to roll back {} {}", obj_type.value, obj_name) raise + else: + self._commit_transaction() diff --git a/src/chronify/ibis/duckdb_backend.py b/src/chronify/ibis/duckdb_backend.py index 79658c7..d9702e9 100644 --- a/src/chronify/ibis/duckdb_backend.py +++ b/src/chronify/ibis/duckdb_backend.py @@ -7,6 +7,7 @@ import ibis import ibis.expr.types as ir import pandas as pd +import pyarrow as pa from loguru import logger from chronify.exceptions import ConflictingInputsError, InvalidOperation, InvalidParameter @@ -66,7 +67,7 @@ def connection(self) -> ibis.BaseBackend: def create_table( self, name: str, - obj: pd.DataFrame | ir.Table | None = None, + obj: pd.DataFrame | pa.Table | ir.Table | None = None, schema: ibis.Schema | None = None, overwrite: bool = False, ) -> ir.Table: @@ -97,11 +98,11 @@ def list_tables(self) -> list[str]: def table(self, name: str) -> ir.Table: return self._connection.table(name) - def insert(self, name: str, data: pd.DataFrame) -> None: + def insert(self, name: str, data: pd.DataFrame | pa.Table) -> None: con = self._connection.con # raw duckdb connection target_columns = list(self.table(name).columns) - _validate_insert_columns(name, target_columns, list(data.columns)) - ordered_data = data.loc[:, target_columns] + _validate_insert_columns(name, target_columns, _get_columns(data)) + ordered_data = _select_columns(data, target_columns) quoted_columns = ", ".join(f'"{col}"' for col in target_columns) quoted_name = _quote_identifier(name) con.register("__insert_df", ordered_data) @@ -112,7 +113,7 @@ def insert(self, name: str, data: pd.DataFrame) -> None: ) finally: con.unregister("__insert_df") - logger.trace("Inserted {} rows into {}", len(data), name) + logger.trace("Inserted {} rows into {}", _row_count(data), name) def delete_rows(self, name: str, values: dict[str, Any]) -> None: con = self._connection.con @@ -193,6 +194,16 @@ def backup(self, dst: str) -> None: finally: self._connection = ibis.duckdb.connect(src) + def _begin_transaction(self) -> None: + self._connection.con.execute("BEGIN TRANSACTION") + + def _commit_transaction(self) -> None: + self._connection.con.execute("COMMIT") + + def _rollback_transaction(self) -> None: + self._connection.con.execute("ROLLBACK") + self._invalidate_table_cache() + def _infer_duckdb_path(connection: ibis.BaseBackend) -> str | None: """Return the database file path for an ibis DuckDB connection, or None for in-memory.""" @@ -225,3 +236,21 @@ def _quote_identifier(identifier: str) -> str: """Quote a SQL identifier for DuckDB, escaping embedded double quotes.""" escaped = identifier.replace('"', '""') return f'"{escaped}"' + + +def _get_columns(data: pd.DataFrame | pa.Table) -> list[str]: + if isinstance(data, pa.Table): + return cast(list[str], data.column_names) + return list(data.columns) + + +def _select_columns(data: pd.DataFrame | pa.Table, columns: list[str]) -> pd.DataFrame | pa.Table: + if isinstance(data, pa.Table): + return data.select(columns) + return data.loc[:, columns] + + +def _row_count(data: pd.DataFrame | pa.Table) -> int: + if isinstance(data, pa.Table): + return cast(int, data.num_rows) + return len(data) diff --git a/src/chronify/ibis/functions.py b/src/chronify/ibis/functions.py index 77363b8..9fb6dfe 100644 --- a/src/chronify/ibis/functions.py +++ b/src/chronify/ibis/functions.py @@ -103,18 +103,21 @@ def write_table( configs: Sequence[TimeBaseModel], if_exists: str = "append", ) -> None: - """Write a DataFrame to the database.""" - if isinstance(df, pa.Table): + """Write tabular data to the database.""" + if isinstance(df, pa.Table) and ( + backend.name in {"sqlite", "spark"} or _arrow_needs_timestamp_normalization(df, configs) + ): df = df.to_pandas() _check_one_config_per_datetime_column(configs) - df = _normalize_timestamps(df, configs) + if isinstance(df, pd.DataFrame): + df = _normalize_timestamps(df, configs) if backend.name not in {"duckdb", "sqlite", "spark"}: msg = f"Unsupported backend: {backend.name}" raise NotImplementedError(msg) - if backend.name == "sqlite": + if backend.name == "sqlite" and isinstance(df, pd.DataFrame): # SQLite-specific: ensure TZ timestamps are stored as UTC text. # _normalize_timestamps already ran, so NTZ columns are tz-naive and # TZ columns are tz-aware UTC. This step converts TZ to UTC for storage. @@ -168,6 +171,25 @@ def _check_one_config_per_datetime_column(configs: Sequence[TimeBaseModel]) -> N raise InvalidParameter(msg) +def _arrow_needs_timestamp_normalization( + table: pa.Table, + configs: Sequence[TimeBaseModel], +) -> bool: + fields = {field.name: field.type for field in table.schema} + for config in configs: + if not isinstance(config, _DATETIME_RANGES): + continue + arrow_type = fields.get(config.time_column) + if arrow_type is None or not pa.types.is_timestamp(arrow_type): + continue + timezone = arrow_type.tz + if config.dtype == TimeDataType.TIMESTAMP_NTZ and timezone is not None: + return True + if config.dtype == TimeDataType.TIMESTAMP_TZ and timezone is None: + return True + return False + + def _convert_database_input_for_datetime( df: pd.DataFrame, config: DatetimeRanges, copied: bool ) -> tuple[pd.DataFrame, bool]: @@ -230,7 +252,7 @@ def _convert_spark_output_for_datetime(df: pd.DataFrame, config: DatetimeRanges) def _apply_if_exists( backend: IbisBackend, - df: pd.DataFrame, + df: pd.DataFrame | pa.Table, table_name: str, if_exists: str, ) -> None: diff --git a/src/chronify/ibis/spark_backend.py b/src/chronify/ibis/spark_backend.py index df6462a..1408b9e 100644 --- a/src/chronify/ibis/spark_backend.py +++ b/src/chronify/ibis/spark_backend.py @@ -10,6 +10,7 @@ import ibis import ibis.expr.types as ir import pandas as pd +import pyarrow as pa from loguru import logger from pandas import DatetimeTZDtype @@ -58,7 +59,7 @@ def connection(self) -> ibis.BaseBackend: def create_table( self, name: str, - obj: pd.DataFrame | ir.Table | None = None, + obj: pd.DataFrame | pa.Table | ir.Table | None = None, schema: ibis.Schema | None = None, overwrite: bool = False, ) -> ir.Table: @@ -99,7 +100,9 @@ def list_tables(self) -> list[str]: def table(self, name: str) -> ir.Table: return self._connection.table(name) - def insert(self, name: str, data: pd.DataFrame) -> None: + def insert(self, name: str, data: pd.DataFrame | pa.Table) -> None: + if isinstance(data, pa.Table): + data = data.to_pandas() # Spark doesn't support INSERT directly -- create a temp view and insert via SQL target_columns = list(self.table(name).columns) _validate_insert_columns(name, target_columns, list(data.columns)) diff --git a/src/chronify/ibis/sqlite_backend.py b/src/chronify/ibis/sqlite_backend.py index 60ac2f0..a629e1a 100644 --- a/src/chronify/ibis/sqlite_backend.py +++ b/src/chronify/ibis/sqlite_backend.py @@ -8,6 +8,7 @@ import ibis import ibis.expr.types as ir import pandas as pd +import pyarrow as pa from loguru import logger from chronify.exceptions import ConflictingInputsError, InvalidOperation, InvalidParameter @@ -55,6 +56,7 @@ def __init__( raise ConflictingInputsError(msg) self._table_cache = None + self._in_transaction = False self._owns_connection = connection is None if connection is None: db = str(database) @@ -82,7 +84,7 @@ def connection(self) -> ibis.BaseBackend: def create_table( self, name: str, - obj: pd.DataFrame | ir.Table | None = None, + obj: pd.DataFrame | pa.Table | ir.Table | None = None, schema: ibis.Schema | None = None, overwrite: bool = False, ) -> ir.Table: @@ -119,7 +121,9 @@ def list_tables(self) -> list[str]: def table(self, name: str) -> ir.Table: return self._connection.table(name) - def insert(self, name: str, data: pd.DataFrame) -> None: + def insert(self, name: str, data: pd.DataFrame | pa.Table) -> None: + if isinstance(data, pa.Table): + data = data.to_pandas() # Use raw SQLite cursor for parameterized inserts con = self._connection.con # raw sqlite3 connection table = self._connection.table(name) @@ -134,7 +138,7 @@ def insert(self, name: str, data: pd.DataFrame) -> None: rows = [tuple(_adapt_value(v) for v in row) for row in ordered.itertuples(index=False)] cursor = con.cursor() cursor.executemany(sql, rows) - con.commit() + self._commit_if_needed() logger.trace("Inserted {} rows into {}", len(data), name) def delete_rows(self, name: str, values: dict[str, Any]) -> None: @@ -143,7 +147,7 @@ def delete_rows(self, name: str, values: dict[str, Any]) -> None: where = " AND ".join(f"{_quote_identifier(c)} = ?" for c in values) sql = f"DELETE FROM {quoted_name} WHERE {where}" con.execute(sql, list(values.values())) - con.commit() + self._commit_if_needed() logger.trace("Deleted rows from {} matching {}", name, values) def execute(self, expr: ir.Expr) -> pd.DataFrame: @@ -173,7 +177,7 @@ def execute_sql(self, query: str) -> None: logger.trace("execute_sql: {}", query) con = self._connection.con con.execute(query) - con.commit() + self._commit_if_needed() if _is_ddl(query): self._invalidate_table_cache() @@ -199,6 +203,23 @@ def backup(self, dst: str) -> None: finally: dst_con.close() + def _begin_transaction(self) -> None: + self._connection.con.execute("BEGIN") + self._in_transaction = True + + def _commit_transaction(self) -> None: + self._connection.con.commit() + self._in_transaction = False + + def _rollback_transaction(self) -> None: + self._connection.con.rollback() + self._in_transaction = False + self._invalidate_table_cache() + + def _commit_if_needed(self) -> None: + if not self._in_transaction: + self._connection.con.commit() + def _infer_sqlite_path(connection: ibis.BaseBackend) -> str | None: """Return the database file path for an ibis SQLite connection, or None for in-memory.""" diff --git a/src/chronify/store.py b/src/chronify/store.py index 483e6a7..c5ec792 100644 --- a/src/chronify/store.py +++ b/src/chronify/store.py @@ -204,10 +204,12 @@ def ingest_from_csvs( """Ingest data from multiple CSV files into the table specified by schema.""" table_existed = self._backend.has_table(dst_schema.name) try: - created_table = self._ingest_from_csvs(paths, src_schema, dst_schema) + with self._backend.transaction(): + created_table = self._ingest_from_csvs(paths, src_schema, dst_schema) except Exception: if not table_existed and self._backend.has_table(dst_schema.name): self._backend.drop_table(dst_schema.name) + if not table_existed: self._schema_mgr.remove_schema(dst_schema.name) raise return created_table @@ -269,10 +271,12 @@ def ingest_pivoted_tables( """Ingest pivoted data from multiple tables. Unpivot before ingesting.""" table_existed = self._backend.has_table(dst_schema.name) try: - created_table = self._ingest_pivoted_tables(data, src_schema, dst_schema) + with self._backend.transaction(): + created_table = self._ingest_pivoted_tables(data, src_schema, dst_schema) except Exception: if not table_existed and self._backend.has_table(dst_schema.name): self._backend.drop_table(dst_schema.name) + if not table_existed: self._schema_mgr.remove_schema(dst_schema.name) raise return created_table @@ -352,10 +356,12 @@ def ingest_tables( table_existed = self._backend.has_table(schema.name) try: - created_table = self._ingest_tables(data, schema, **kwargs) + with self._backend.transaction(): + created_table = self._ingest_tables(data, schema, **kwargs) except Exception: if not table_existed and self._backend.has_table(schema.name): self._backend.drop_table(schema.name) + if not table_existed: self._schema_mgr.remove_schema(schema.name) raise return created_table diff --git a/tests/test_ibis_functions.py b/tests/test_ibis_functions.py index 551800a..6a101c7 100644 --- a/tests/test_ibis_functions.py +++ b/tests/test_ibis_functions.py @@ -195,6 +195,29 @@ def test_pyarrow_table_input(self): assert len(df) == 2 backend.dispose() + def test_pyarrow_table_input_stays_arrow_for_duckdb(self, monkeypatch): + backend = make_backend("duckdb") + config = _make_ntz_config() + pa_table = pa.table( + { + "timestamp": pd.to_datetime(["2020-01-01 00:00:00", "2020-01-01 01:00:00"]), + "value": [1.0, 2.0], + } + ) + seen_arrow = False + + def create_table(name, obj=None, schema=None, overwrite=False): + nonlocal seen_arrow + seen_arrow = isinstance(obj, pa.Table) + return backend.connection.create_table( + name, obj=obj, schema=schema, overwrite=overwrite + ) + + monkeypatch.setattr(backend, "create_table", create_table) + write_table(backend, pa_table, "pa_test_arrow", [config], if_exists="fail") + assert seen_arrow + backend.dispose() + def test_invalid_if_exists_duckdb(self): backend = make_backend("duckdb") config = _make_ntz_config() From 63ea56af491b72e6bf74d04133546434fb80050a Mon Sep 17 00:00:00 2001 From: Daniel Thom Date: Sun, 12 Apr 2026 16:35:57 -0600 Subject: [PATCH 22/48] Move per-backend I/O logic from functions.py into IbisBackend The free functions in chronify/ibis/functions.py branched on backend.name to decide how to normalize timestamps and whether to convert Arrow to pandas. That feature envy blocked per-backend fast paths (e.g., Arrow-native ingest, Spark-native execution) because the dispatcher owned the format choice. Push read_table, read_query, and write_table onto IbisBackend as concrete methods with _post_read_normalize and _prepare_write_data hooks that SQLiteBackend and SparkBackend override. Delete functions.py and update all callers to invoke backend methods directly. Co-Authored-By: Claude Opus 4.6 --- src/chronify/ibis/base.py | 144 +++++++++- src/chronify/ibis/functions.py | 269 ------------------ src/chronify/ibis/spark_backend.py | 49 +++- src/chronify/ibis/sqlite_backend.py | 70 ++++- src/chronify/store.py | 53 +--- src/chronify/time_series_checker.py | 5 +- src/chronify/time_series_mapper_base.py | 13 +- ...apper_column_representative_to_datetime.py | 4 +- src/chronify/time_series_mapper_index_time.py | 5 +- .../time_series_mapper_representative.py | 3 +- src/chronify/time_zone_converter.py | 3 +- src/chronify/time_zone_localizer.py | 3 +- tests/test_checker_representative_time.py | 3 +- tests/test_ibis_functions.py | 32 +-- ...apper_column_representative_to_datetime.py | 9 +- tests/test_mapper_datetime_to_datetime.py | 7 +- tests/test_mapper_index_time_to_datetime.py | 5 +- ..._mapper_representative_time_to_datetime.py | 5 +- tests/test_time_series_checker.py | 3 +- tests/test_time_zone_converter.py | 5 +- tests/test_time_zone_localizer.py | 5 +- 21 files changed, 309 insertions(+), 386 deletions(-) delete mode 100644 src/chronify/ibis/functions.py diff --git a/src/chronify/ibis/base.py b/src/chronify/ibis/base.py index c0475df..43eaa8f 100644 --- a/src/chronify/ibis/base.py +++ b/src/chronify/ibis/base.py @@ -2,27 +2,101 @@ import re from abc import ABC, abstractmethod +from collections import Counter from contextlib import contextmanager from enum import StrEnum -from typing import Any, Generator, cast +from typing import Any, Generator, Sequence, cast import ibis import ibis.expr.types as ir import pandas as pd import pyarrow as pa from loguru import logger +from pandas import DatetimeTZDtype + +from chronify.exceptions import InvalidOperation, InvalidParameter +from chronify.time import TimeDataType +from chronify.time_configs import ( + DatetimeRange, + DatetimeRangeBase, + DatetimeRangeWithTZColumn, + TimeBaseModel, +) _DDL_RE = re.compile( r"^\s*(?:WITH\s+.+?\s+)?(CREATE|DROP|ALTER|TRUNCATE|RENAME)\b", re.IGNORECASE | re.DOTALL, ) +_DATETIME_RANGES: tuple[type, ...] = (DatetimeRange, DatetimeRangeWithTZColumn) +DatetimeRanges = DatetimeRange | DatetimeRangeWithTZColumn + def _is_ddl(query: str) -> bool: """Return True if the SQL statement changes the set of tables/views.""" return _DDL_RE.match(query) is not None +def _check_one_config_per_datetime_column(configs: Sequence[TimeBaseModel]) -> None: + time_col_count = Counter( + config.time_column for config in configs if isinstance(config, DatetimeRangeBase) + ) + time_col_dup = {k: v for k, v in time_col_count.items() if v > 1} + if time_col_dup: + msg = f"More than one datetime config found for: {time_col_dup}" + raise InvalidParameter(msg) + + +def _normalize_timestamps( + df: pd.DataFrame, + configs: Sequence[TimeBaseModel], +) -> pd.DataFrame: + """Normalize datetime columns so their pandas dtype matches the schema config.""" + copied = False + for config in configs: + if not isinstance(config, _DATETIME_RANGES): + continue + col = config.time_column + if col not in df.columns: + continue + if not pd.api.types.is_datetime64_any_dtype(df[col]): + continue + + is_tz_aware = isinstance(df[col].dtype, DatetimeTZDtype) + + if config.dtype == TimeDataType.TIMESTAMP_NTZ and is_tz_aware: + if not copied: + df = df.copy() + copied = True + df[col] = df[col].dt.tz_convert("UTC").dt.tz_localize(None) + elif config.dtype == TimeDataType.TIMESTAMP_TZ and not is_tz_aware: + if not copied: + df = df.copy() + copied = True + df[col] = df[col].dt.tz_localize("UTC") + + return df + + +def _arrow_needs_timestamp_normalization( + table: pa.Table, + configs: Sequence[TimeBaseModel], +) -> bool: + fields = {field.name: field.type for field in table.schema} + for config in configs: + if not isinstance(config, _DATETIME_RANGES): + continue + arrow_type = fields.get(config.time_column) + if arrow_type is None or not pa.types.is_timestamp(arrow_type): + continue + timezone = arrow_type.tz + if config.dtype == TimeDataType.TIMESTAMP_NTZ and timezone is not None: + return True + if config.dtype == TimeDataType.TIMESTAMP_TZ and timezone is None: + return True + return False + + class ObjectType(StrEnum): TABLE = "table" VIEW = "view" @@ -108,7 +182,8 @@ def delete_rows(self, name: str, values: dict[str, Any]) -> None: @abstractmethod def execute(self, expr: ir.Expr) -> pd.DataFrame: - """Execute an ibis expression and return a DataFrame.""" + """Execute an ibis expression and return a DataFrame. Must not be called + for large tables.""" @abstractmethod def sql(self, query: str) -> ir.Table: @@ -165,6 +240,71 @@ def execute_sql_to_df(self, query: str) -> pd.DataFrame: logger.trace("execute_sql_to_df: {}", query) return cast(pd.DataFrame, self.connection.raw_sql(query).fetch_df()) + def read_table(self, name: str, config: TimeBaseModel) -> pd.DataFrame: + """Return the named table as a pandas DataFrame, normalized for this backend.""" + return self.read_query(self.table(name), config) + + def read_query(self, expr: ir.Table, config: TimeBaseModel) -> pd.DataFrame: + """Execute an Ibis expression and return a normalized pandas DataFrame.""" + df = self.execute(expr) + if isinstance(config, _DATETIME_RANGES): + self._post_read_normalize(df, config) + return df + + def write_table( + self, + data: pd.DataFrame | pa.Table, + name: str, + configs: Sequence[TimeBaseModel], + if_exists: str = "append", + ) -> None: + """Write tabular data to the database, applying backend-specific normalization.""" + _check_one_config_per_datetime_column(configs) + prepared = self._prepare_write_data(data, configs) + self._apply_if_exists(prepared, name, if_exists) + + def _post_read_normalize(self, df: pd.DataFrame, config: DatetimeRanges) -> None: + """Backend-specific in-place normalization of a read DataFrame. + + Default: no-op. Backends whose drivers return non-canonical datetime + types should override. + """ + + def _prepare_write_data( + self, + data: pd.DataFrame | pa.Table, + configs: Sequence[TimeBaseModel], + ) -> pd.DataFrame | pa.Table: + """Normalize data before insert/create_table. + + Default behavior is the DuckDB path: accept Arrow natively when possible, + otherwise convert to pandas to normalize tz-sensitive columns. Subclasses + that cannot ingest Arrow directly should convert here. + """ + if isinstance(data, pa.Table) and _arrow_needs_timestamp_normalization(data, configs): + data = data.to_pandas() + if isinstance(data, pd.DataFrame): + data = _normalize_timestamps(data, configs) + return data + + def _apply_if_exists( + self, + data: pd.DataFrame | pa.Table, + name: str, + if_exists: str, + ) -> None: + match if_exists: + case "append": + self.insert(name, data) + case "replace": + self.drop_table(name) + self.create_table(name, data) + case "fail": + self.create_table(name, data) + case _: + msg = f"Invalid if_exists value: {if_exists}" + raise InvalidOperation(msg) + def dispose(self) -> None: """Dispose of the backend connection.""" self.connection.disconnect() diff --git a/src/chronify/ibis/functions.py b/src/chronify/ibis/functions.py deleted file mode 100644 index 9fb6dfe..0000000 --- a/src/chronify/ibis/functions.py +++ /dev/null @@ -1,269 +0,0 @@ -"""Database I/O functions using Ibis backends.""" - -from collections import Counter -from pathlib import Path -from typing import Sequence - -import ibis.expr.types as ir -import pandas as pd -import pyarrow as pa -from pandas import DatetimeTZDtype - -from chronify.exceptions import InvalidOperation, InvalidParameter -from chronify.ibis.base import IbisBackend, ObjectType -from chronify.time import TimeDataType -from chronify.time_configs import ( - DatetimeRange, - DatetimeRangeBase, - DatetimeRangeWithTZColumn, - TimeBaseModel, -) -from chronify.utils.path_utils import check_overwrite - -DatetimeRanges = DatetimeRange | DatetimeRangeWithTZColumn -_DATETIME_RANGES = (DatetimeRange, DatetimeRangeWithTZColumn) - - -def read_table( - backend: IbisBackend, - table_name: str, - config: TimeBaseModel, -) -> pd.DataFrame: - """Read a table from the database.""" - table = backend.table(table_name) - df = backend.execute(table) - - if backend.name == "sqlite" and isinstance(config, _DATETIME_RANGES): - _convert_database_output_for_datetime(df, config) - elif backend.name == "spark" and isinstance(config, _DATETIME_RANGES): - _convert_spark_output_for_datetime(df, config) - - return df - - -def read_query( - backend: IbisBackend, - expr: ir.Table, - config: TimeBaseModel, -) -> pd.DataFrame: - """Execute an Ibis expression and return results.""" - df = backend.execute(expr) - - if backend.name == "sqlite" and isinstance(config, _DATETIME_RANGES): - _convert_database_output_for_datetime(df, config) - elif backend.name == "spark" and isinstance(config, _DATETIME_RANGES): - _convert_spark_output_for_datetime(df, config) - - return df - - -def _normalize_timestamps( - df: pd.DataFrame, - configs: Sequence[TimeBaseModel], -) -> pd.DataFrame: - """Normalize datetime columns so their pandas dtype matches the schema config. - - - TIMESTAMP_NTZ + tz-aware input → convert to UTC, then strip timezone - - TIMESTAMP_TZ + tz-naive input → localize as UTC - - matching dtype → no change - - This runs before any backend-specific handling so that all backends receive - consistently typed data. - """ - copied = False - for config in configs: - if not isinstance(config, _DATETIME_RANGES): - continue - col = config.time_column - if col not in df.columns: - continue - if not pd.api.types.is_datetime64_any_dtype(df[col]): - continue - - is_tz_aware = isinstance(df[col].dtype, DatetimeTZDtype) - - if config.dtype == TimeDataType.TIMESTAMP_NTZ and is_tz_aware: - if not copied: - df = df.copy() - copied = True - df[col] = df[col].dt.tz_convert("UTC").dt.tz_localize(None) - elif config.dtype == TimeDataType.TIMESTAMP_TZ and not is_tz_aware: - if not copied: - df = df.copy() - copied = True - df[col] = df[col].dt.tz_localize("UTC") - - return df - - -def write_table( - backend: IbisBackend, - df: pd.DataFrame | pa.Table, - table_name: str, - configs: Sequence[TimeBaseModel], - if_exists: str = "append", -) -> None: - """Write tabular data to the database.""" - if isinstance(df, pa.Table) and ( - backend.name in {"sqlite", "spark"} or _arrow_needs_timestamp_normalization(df, configs) - ): - df = df.to_pandas() - - _check_one_config_per_datetime_column(configs) - if isinstance(df, pd.DataFrame): - df = _normalize_timestamps(df, configs) - - if backend.name not in {"duckdb", "sqlite", "spark"}: - msg = f"Unsupported backend: {backend.name}" - raise NotImplementedError(msg) - - if backend.name == "sqlite" and isinstance(df, pd.DataFrame): - # SQLite-specific: ensure TZ timestamps are stored as UTC text. - # _normalize_timestamps already ran, so NTZ columns are tz-naive and - # TZ columns are tz-aware UTC. This step converts TZ to UTC for storage. - copied = False - for config in configs: - if isinstance(config, _DATETIME_RANGES): - df, copied = _convert_database_input_for_datetime(df, config, copied) - - _apply_if_exists(backend, df, table_name, if_exists) - - -def write_parquet( - backend: IbisBackend, - query: str | ir.Table, - output_file: Path, - overwrite: bool = False, - partition_columns: list[str] | None = None, - config: TimeBaseModel | None = None, -) -> None: - """Write query results to a Parquet file.""" - check_overwrite(output_file, overwrite) - - if isinstance(query, str): - expr = backend.sql(query) - else: - expr = query - - backend.write_parquet(expr, str(output_file), partition_by=partition_columns) - - -def create_view_from_parquet( - backend: IbisBackend, - filename: Path, - view_name: str, -) -> ObjectType: - """Create a view (or table for SQLite) from a Parquet file. - - Returns the ObjectType created so callers can clean up correctly. - """ - _, obj_type = backend.create_view_from_parquet(str(filename), view_name) - return obj_type - - -def _check_one_config_per_datetime_column(configs: Sequence[TimeBaseModel]) -> None: - time_col_count = Counter( - config.time_column for config in configs if isinstance(config, DatetimeRangeBase) - ) - time_col_dup = {k: v for k, v in time_col_count.items() if v > 1} - if time_col_dup: - msg = f"More than one datetime config found for: {time_col_dup}" - raise InvalidParameter(msg) - - -def _arrow_needs_timestamp_normalization( - table: pa.Table, - configs: Sequence[TimeBaseModel], -) -> bool: - fields = {field.name: field.type for field in table.schema} - for config in configs: - if not isinstance(config, _DATETIME_RANGES): - continue - arrow_type = fields.get(config.time_column) - if arrow_type is None or not pa.types.is_timestamp(arrow_type): - continue - timezone = arrow_type.tz - if config.dtype == TimeDataType.TIMESTAMP_NTZ and timezone is not None: - return True - if config.dtype == TimeDataType.TIMESTAMP_TZ and timezone is None: - return True - return False - - -def _convert_database_input_for_datetime( - df: pd.DataFrame, config: DatetimeRanges, copied: bool -) -> tuple[pd.DataFrame, bool]: - """Convert DataFrame datetime columns for SQLite input (store as UTC).""" - if config.dtype == TimeDataType.TIMESTAMP_NTZ: - return df, copied - - if not copied: - df = df.copy() - copied = True - - if isinstance(df[config.time_column].dtype, DatetimeTZDtype): - df[config.time_column] = df[config.time_column].dt.tz_convert("UTC") - else: - df[config.time_column] = df[config.time_column].dt.tz_localize("UTC") - - return df, copied - - -def _convert_database_output_for_datetime(df: pd.DataFrame, config: DatetimeRanges) -> None: - """Convert DataFrame datetime columns after SQLite output.""" - if config.time_column not in df.columns: - return - - col = df[config.time_column] - if config.dtype == TimeDataType.TIMESTAMP_TZ: - if col.dtype == object: - df[config.time_column] = pd.to_datetime(col, utc=True) - elif isinstance(col.dtype, DatetimeTZDtype): - df[config.time_column] = col.dt.tz_convert("UTC") - else: - df[config.time_column] = col.dt.tz_localize("UTC") - else: - if col.dtype == object: - df[config.time_column] = pd.to_datetime(col, utc=False) - - -def _convert_spark_output_for_datetime(df: pd.DataFrame, config: DatetimeRanges) -> None: - """Convert DataFrame datetime columns after Spark output.""" - if config.time_column not in df.columns: - return - - col = df[config.time_column] - - if config.dtype == TimeDataType.TIMESTAMP_TZ: - if not pd.api.types.is_datetime64_any_dtype(col): - col = pd.to_datetime(col, utc=True) - elif isinstance(col.dtype, DatetimeTZDtype): - col = col.dt.tz_convert("UTC") - else: - col = col.dt.tz_localize("UTC") - df[config.time_column] = col.dt.as_unit("us") - else: - if not pd.api.types.is_datetime64_any_dtype(col): - col = pd.to_datetime(col, utc=False) - df[config.time_column] = col.astype("datetime64[us]") - if isinstance(col.dtype, DatetimeTZDtype): - df[config.time_column] = col.dt.tz_convert(None).astype("datetime64[us]") - - -def _apply_if_exists( - backend: IbisBackend, - df: pd.DataFrame | pa.Table, - table_name: str, - if_exists: str, -) -> None: - match if_exists: - case "append": - backend.insert(table_name, df) - case "replace": - backend.drop_table(table_name) - backend.create_table(table_name, df) - case "fail": - backend.create_table(table_name, df) - case _: - msg = f"Invalid if_exists value: {if_exists}" - raise InvalidOperation(msg) diff --git a/src/chronify/ibis/spark_backend.py b/src/chronify/ibis/spark_backend.py index 1408b9e..16724fb 100644 --- a/src/chronify/ibis/spark_backend.py +++ b/src/chronify/ibis/spark_backend.py @@ -3,7 +3,7 @@ import uuid import shutil from contextlib import contextmanager -from typing import Any, Generator, cast +from typing import Any, Generator, Sequence, cast from pathlib import Path from urllib.parse import urlparse, unquote @@ -15,7 +15,15 @@ from pandas import DatetimeTZDtype from chronify.exceptions import InvalidOperation, InvalidParameter -from chronify.ibis.base import IbisBackend, ObjectType, _is_ddl +from chronify.ibis.base import ( + DatetimeRanges, + IbisBackend, + ObjectType, + TimeBaseModel, + TimeDataType, + _is_ddl, + _normalize_timestamps, +) class SparkBackend(IbisBackend): @@ -216,6 +224,20 @@ def _remove_managed_table_location(self, name: str) -> None: if path.exists(): shutil.rmtree(path) + def _post_read_normalize(self, df: pd.DataFrame, config: DatetimeRanges) -> None: + """Spark returns tz-naive nanosecond timestamps; coerce to schema dtype + µs unit.""" + _convert_spark_output_for_datetime(df, config) + + def _prepare_write_data( + self, + data: pd.DataFrame | pa.Table, + configs: Sequence[TimeBaseModel], + ) -> pd.DataFrame: + """Spark ingestion goes through createDataFrame(pandas); Arrow must be converted.""" + if isinstance(data, pa.Table): + data = data.to_pandas() + return _normalize_timestamps(data, configs) + @staticmethod def _prepare_data_for_spark(df: pd.DataFrame) -> pd.DataFrame: """Normalize tz-aware pandas timestamps for Spark ingestion. @@ -241,6 +263,29 @@ def _validate_session(session: Any) -> None: raise InvalidParameter(msg) +def _convert_spark_output_for_datetime(df: pd.DataFrame, config: DatetimeRanges) -> None: + """Convert DataFrame datetime columns after Spark output.""" + if config.time_column not in df.columns: + return + + col = df[config.time_column] + + if config.dtype == TimeDataType.TIMESTAMP_TZ: + if not pd.api.types.is_datetime64_any_dtype(col): + col = pd.to_datetime(col, utc=True) + elif isinstance(col.dtype, DatetimeTZDtype): + col = col.dt.tz_convert("UTC") + else: + col = col.dt.tz_localize("UTC") + df[config.time_column] = col.dt.as_unit("us") + else: + if not pd.api.types.is_datetime64_any_dtype(col): + col = pd.to_datetime(col, utc=False) + df[config.time_column] = col.astype("datetime64[us]") + if isinstance(col.dtype, DatetimeTZDtype): + df[config.time_column] = col.dt.tz_convert(None).astype("datetime64[us]") + + def _validate_insert_columns( table_name: str, target_columns: list[str], data_columns: list[str] ) -> None: diff --git a/src/chronify/ibis/sqlite_backend.py b/src/chronify/ibis/sqlite_backend.py index a629e1a..0088728 100644 --- a/src/chronify/ibis/sqlite_backend.py +++ b/src/chronify/ibis/sqlite_backend.py @@ -3,16 +3,26 @@ import sqlite3 from datetime import datetime from pathlib import Path -from typing import Any, cast +from typing import Any, Sequence, cast import ibis import ibis.expr.types as ir import pandas as pd import pyarrow as pa from loguru import logger +from pandas import DatetimeTZDtype from chronify.exceptions import ConflictingInputsError, InvalidOperation, InvalidParameter -from chronify.ibis.base import IbisBackend, ObjectType, _is_ddl +from chronify.ibis.base import ( + DatetimeRanges, + IbisBackend, + ObjectType, + TimeBaseModel, + TimeDataType, + _DATETIME_RANGES, + _is_ddl, + _normalize_timestamps, +) def _adapt_value(v: Any) -> Any: @@ -220,6 +230,62 @@ def _commit_if_needed(self) -> None: if not self._in_transaction: self._connection.con.commit() + def _post_read_normalize(self, df: pd.DataFrame, config: DatetimeRanges) -> None: + """SQLite stores timestamps as text; coerce back to datetime dtypes.""" + _convert_database_output_for_datetime(df, config) + + def _prepare_write_data( + self, + data: pd.DataFrame | pa.Table, + configs: Sequence[TimeBaseModel], + ) -> pd.DataFrame: + """SQLite cannot ingest Arrow directly; convert to pandas and coerce TZ→UTC text.""" + if isinstance(data, pa.Table): + data = data.to_pandas() + data = _normalize_timestamps(data, configs) + copied = False + for config in configs: + if isinstance(config, _DATETIME_RANGES): + data, copied = _convert_database_input_for_datetime(data, config, copied) + return data + + +def _convert_database_output_for_datetime(df: pd.DataFrame, config: DatetimeRanges) -> None: + """Convert DataFrame datetime columns after SQLite output.""" + if config.time_column not in df.columns: + return + + col = df[config.time_column] + if config.dtype == TimeDataType.TIMESTAMP_TZ: + if col.dtype == object: + df[config.time_column] = pd.to_datetime(col, utc=True) + elif isinstance(col.dtype, DatetimeTZDtype): + df[config.time_column] = col.dt.tz_convert("UTC") + else: + df[config.time_column] = col.dt.tz_localize("UTC") + else: + if col.dtype == object: + df[config.time_column] = pd.to_datetime(col, utc=False) + + +def _convert_database_input_for_datetime( + df: pd.DataFrame, config: DatetimeRanges, copied: bool +) -> tuple[pd.DataFrame, bool]: + """Convert DataFrame datetime columns for SQLite input (store as UTC).""" + if config.dtype == TimeDataType.TIMESTAMP_NTZ: + return df, copied + + if not copied: + df = df.copy() + copied = True + + if isinstance(df[config.time_column].dtype, DatetimeTZDtype): + df[config.time_column] = df[config.time_column].dt.tz_convert("UTC") + else: + df[config.time_column] = df[config.time_column].dt.tz_localize("UTC") + + return df, copied + def _infer_sqlite_path(connection: ibis.BaseBackend) -> str | None: """Return the database file path for an ibis SQLite connection, or None for in-memory.""" diff --git a/src/chronify/store.py b/src/chronify/store.py index c5ec792..a191beb 100644 --- a/src/chronify/store.py +++ b/src/chronify/store.py @@ -19,13 +19,6 @@ ) from chronify.csv_io import read_csv from chronify.ibis import IbisBackend, ObjectType, make_backend -from chronify.ibis.functions import ( - create_view_from_parquet, - read_query, - read_table, - write_parquet, - write_table, -) from chronify.models import ( CsvTableSchema, PivotedTableSchema, @@ -182,7 +175,7 @@ def create_view_from_parquet( def _create_view_from_parquet(self, path: Path | str, schema: TableSchema) -> "ObjectType": """Create a view in the database from a Parquet file.""" - obj_type = create_view_from_parquet(self._backend, to_path(path), schema.name) + _, obj_type = self._backend.create_view_from_parquet(str(to_path(path)), schema.name) self._schema_mgr.add_schema(schema) return obj_type @@ -389,23 +382,11 @@ def _ingest_table( check_columns(df.columns, schema.list_columns()) if not self._backend.has_table(schema.name): - write_table( - self._backend, - df, - schema.name, - [schema.time_config], - if_exists="fail", - ) + self._backend.write_table(df, schema.name, [schema.time_config], if_exists="fail") self._schema_mgr.add_schema(schema) return True else: - write_table( - self._backend, - df, - schema.name, - [schema.time_config], - if_exists="append", - ) + self._backend.write_table(df, schema.name, [schema.time_config], if_exists="append") return False def map_table_time_config( @@ -733,12 +714,12 @@ def read_query( expr = self._backend.sql(query) else: expr = query - return read_query(self._backend, expr, schema.time_config) + return self._backend.read_query(expr, schema.time_config) def read_table(self, name: str) -> pd.DataFrame: """Return the table as a pandas DataFrame.""" schema = self._schema_mgr.get_schema(name) - return read_table(self._backend, name, schema.time_config) + return self._backend.read_table(name, schema.time_config) def read_raw_query(self, query: str) -> pd.DataFrame: """Execute a query directly on the backend and return the results as a DataFrame. @@ -783,15 +764,10 @@ def write_query_to_parquet( Optional table/view name used to look up the time config for backend-specific timestamp normalization (e.g. Spark). """ - config = self._schema_mgr.get_schema(name).time_config if name else None - write_parquet( - self._backend, - stmt, - to_path(file_path), - overwrite=overwrite, - partition_columns=partition_columns, - config=config, - ) + output_file = to_path(file_path) + check_overwrite(output_file, overwrite) + expr = self._backend.sql(stmt) if isinstance(stmt, str) else stmt + self._backend.write_parquet(expr, str(output_file), partition_by=partition_columns) def write_table_to_parquet( self, @@ -806,14 +782,9 @@ def write_table_to_parquet( raise TableNotStored(msg) expr = self._backend.table(name) - write_parquet( - self._backend, - expr, - to_path(file_path), - overwrite=overwrite, - partition_columns=partition_columns, - config=self._schema_mgr.get_schema(name).time_config, - ) + output_file = to_path(file_path) + check_overwrite(output_file, overwrite) + self._backend.write_parquet(expr, str(output_file), partition_by=partition_columns) logger.info("Wrote table or view to {}", file_path) def delete_rows( diff --git a/src/chronify/time_series_checker.py b/src/chronify/time_series_checker.py index 8bd444c..6f1d471 100644 --- a/src/chronify/time_series_checker.py +++ b/src/chronify/time_series_checker.py @@ -5,7 +5,6 @@ from chronify.exceptions import InvalidTable from chronify.ibis.base import IbisBackend -from chronify.ibis.functions import read_query from chronify.models import TableSchema from chronify.time_configs import DatetimeRangeWithTZColumn from chronify.time_range_generator_factory import make_time_range_generator @@ -72,7 +71,7 @@ def _check_expected_timestamps_datetime(self) -> int: expr = table.select(time_columns).distinct() for col in time_columns: expr = expr.filter(table[col].notnull()) - df = read_query(self._backend, expr, self._schema.time_config) + df = self._backend.read_query(expr, self._schema.time_config) actual = self._time_generator.list_distinct_timestamps_from_dataframe(df) expected = sorted(set(expected)) # drop duplicates for tz-naive prevailing time check_timestamp_lists(actual, expected) @@ -91,7 +90,7 @@ def _check_expected_timestamps_with_external_time_zone(self) -> int: expr = table.select(time_columns).distinct() for col in time_columns: expr = expr.filter(table[col].notnull()) - df = read_query(self._backend, expr, self._schema.time_config) + df = self._backend.read_query(expr, self._schema.time_config) actual_dct = self._time_generator.list_distinct_timestamps_by_time_zone_from_dataframe(df) if sorted(expected_dct.keys()) != sorted(actual_dct.keys()): msg = ( diff --git a/src/chronify/time_series_mapper_base.py b/src/chronify/time_series_mapper_base.py index 3059c85..022be4f 100644 --- a/src/chronify/time_series_mapper_base.py +++ b/src/chronify/time_series_mapper_base.py @@ -6,13 +6,12 @@ from loguru import logger from chronify.ibis.base import IbisBackend, ObjectType -from chronify.ibis.functions import write_parquet, write_table, create_view_from_parquet from chronify.models import TableSchema, MappingTableSchema from chronify.exceptions import ConflictingInputsError, InvalidOperation from chronify.time_series_checker import check_timestamps from chronify.time import TimeIntervalType, ResamplingOperationType, AggregationType from chronify.time_configs import TimeBasedDataAdjustment -from chronify.utils.path_utils import to_path +from chronify.utils.path_utils import check_overwrite, to_path class TimeSeriesMapperBase(abc.ABC): @@ -87,8 +86,7 @@ def apply_mapping( check_mapped_timestamps: bool = False, ) -> None: """Apply mapping to create result table with process to clean up and roll back if checks fail.""" - write_table( - backend, + backend.write_table( df_mapping, mapping_schema.name, mapping_schema.time_configs, @@ -107,7 +105,9 @@ def apply_mapping( if check_mapped_timestamps: if output_file is not None: output_file = to_path(output_file) - created_tmp_obj = create_view_from_parquet(backend, output_file, to_schema.name) + _, created_tmp_obj = backend.create_view_from_parquet( + str(output_file), to_schema.name + ) try: check_timestamps( backend, @@ -212,7 +212,8 @@ def _right_col(col: str) -> Any: if output_file is not None: output_file = to_path(output_file) - write_parquet(backend, result, output_file, overwrite=True, config=to_schema.time_config) + check_overwrite(output_file, overwrite=True) + backend.write_parquet(result, str(output_file)) return backend.create_table(to_schema.name, result, overwrite=True) diff --git a/src/chronify/time_series_mapper_column_representative_to_datetime.py b/src/chronify/time_series_mapper_column_representative_to_datetime.py index 6f4debb..fbf46e4 100644 --- a/src/chronify/time_series_mapper_column_representative_to_datetime.py +++ b/src/chronify/time_series_mapper_column_representative_to_datetime.py @@ -7,7 +7,6 @@ from chronify.exceptions import InvalidParameter, InvalidValue from chronify.ibis.base import IbisBackend -from chronify.ibis.functions import write_table from chronify.time_series_mapper_base import TimeSeriesMapperBase, apply_mapping from chronify.time_configs import ( YearMonthDayHourTimeNTZ, @@ -136,8 +135,7 @@ def _intermediate_mapping_ymdp_to_ymdh(self) -> TableSchema: f"SELECT DISTINCT {period_col} FROM {self._from_schema.name}" ) df_mapping = generate_period_mapping(df_periods.iloc[:, 0]) - write_table( - self._backend, + self._backend.write_table( df_mapping, mapping_table_name, [self._from_time_config], diff --git a/src/chronify/time_series_mapper_index_time.py b/src/chronify/time_series_mapper_index_time.py index 6812db7..f528735 100644 --- a/src/chronify/time_series_mapper_index_time.py +++ b/src/chronify/time_series_mapper_index_time.py @@ -7,7 +7,6 @@ import pandas as pd from chronify.ibis.base import IbisBackend -from chronify.ibis.functions import read_query from chronify.models import TableSchema, MappingTableSchema from chronify.exceptions import InvalidParameter, ConflictingInputsError from chronify.time_series_mapper_base import TimeSeriesMapperBase, apply_mapping @@ -208,7 +207,7 @@ def _create_interm_map_with_time_zone( table = self._backend.table(self._from_schema.name) expr = table.select(tz_col).distinct().filter(table[tz_col].notnull()) - time_zones = read_query(self._backend, expr, self._from_time_config)[tz_col].to_list() + time_zones = self._backend.read_query(expr, self._from_time_config)[tz_col].to_list() from_time_config = self._from_time_config.model_copy( update={"time_column": from_time_col, "time_zone_column": from_tz_col} @@ -272,7 +271,7 @@ def _create_interm_map_with_time_zone_and_dst_adjustment( table = self._backend.table(self._from_schema.name) expr = table.select(tz_col).distinct().filter(table[tz_col].notnull()) - time_zones = read_query(self._backend, expr, self._from_time_config)[tz_col].to_list() + time_zones = self._backend.read_query(expr, self._from_time_config)[tz_col].to_list() from_time_config = self._from_time_config.model_copy( update={"time_column": from_time_col, "time_zone_column": from_tz_col} diff --git a/src/chronify/time_series_mapper_representative.py b/src/chronify/time_series_mapper_representative.py index 93674bf..106d780 100644 --- a/src/chronify/time_series_mapper_representative.py +++ b/src/chronify/time_series_mapper_representative.py @@ -5,7 +5,6 @@ import pandas as pd from chronify.ibis.base import IbisBackend -from chronify.ibis.functions import read_query from chronify.models import TableSchema, MappingTableSchema from chronify.exceptions import InvalidParameter from chronify.time_range_generator_factory import make_time_range_generator @@ -96,7 +95,7 @@ def _create_mapping(self, is_tz_naive: bool) -> tuple[pd.DataFrame, MappingTable assert tz_col is not None, "Expecting a time zone column for REPRESENTATIVE time" table = self._backend.table(self._from_schema.name) expr = table.select(tz_col).distinct().filter(table[tz_col].notnull()) - time_zones = read_query(self._backend, expr, self._from_time_config)[tz_col].to_list() + time_zones = self._backend.read_query(expr, self._from_time_config)[tz_col].to_list() df = self._generator.create_tz_aware_mapping_dataframe( dft, time_col, time_zones, tz_col ) diff --git a/src/chronify/time_zone_converter.py b/src/chronify/time_zone_converter.py index 4e0447e..59312aa 100644 --- a/src/chronify/time_zone_converter.py +++ b/src/chronify/time_zone_converter.py @@ -6,7 +6,6 @@ import pandas as pd from chronify.ibis.base import IbisBackend -from chronify.ibis.functions import read_query from chronify.models import TableSchema, MappingTableSchema from chronify.time_configs import ( DatetimeRangeBase, @@ -280,7 +279,7 @@ def _get_time_zones(self) -> list[tzinfo | None]: .distinct() .filter(table[self.time_zone_column].notnull()) ) - time_zones = read_query(self._backend, expr, self._from_schema.time_config)[ + time_zones = self._backend.read_query(expr, self._from_schema.time_config)[ self.time_zone_column ].to_list() time_zones = [None if tz == "None" else ZoneInfo(tz) for tz in time_zones] diff --git a/src/chronify/time_zone_localizer.py b/src/chronify/time_zone_localizer.py index 48e8338..165a7c8 100644 --- a/src/chronify/time_zone_localizer.py +++ b/src/chronify/time_zone_localizer.py @@ -8,7 +8,6 @@ from pandas import DatetimeTZDtype from chronify.ibis.base import IbisBackend -from chronify.ibis.functions import read_query from chronify.models import TableSchema, MappingTableSchema from chronify.time_configs import ( DatetimeRangeBase, @@ -343,7 +342,7 @@ def _get_time_zones(self) -> list[tzinfo | None]: .distinct() .filter(table[self.time_zone_column].notnull()) ) - time_zones = read_query(self._backend, expr, self._from_schema.time_config)[ + time_zones = self._backend.read_query(expr, self._from_schema.time_config)[ self.time_zone_column ].to_list() diff --git a/tests/test_checker_representative_time.py b/tests/test_checker_representative_time.py index 452cce3..d5dc297 100644 --- a/tests/test_checker_representative_time.py +++ b/tests/test_checker_representative_time.py @@ -3,7 +3,6 @@ import pandas as pd from chronify.ibis import IbisBackend -from chronify.ibis.functions import write_table from chronify.models import TableSchema from chronify.time_series_checker import check_timestamps from chronify.exceptions import InvalidTable @@ -12,7 +11,7 @@ def ingest_data_and_check( backend: IbisBackend, df: pd.DataFrame, schema: TableSchema, error: tuple[any, str] ) -> None: - write_table(backend, df, schema.name, [schema.time_config], if_exists="replace") + backend.write_table(df, schema.name, [schema.time_config], if_exists="replace") if error: with pytest.raises(error[0], match=error[1]): diff --git a/tests/test_ibis_functions.py b/tests/test_ibis_functions.py index 6a101c7..9dce67e 100644 --- a/tests/test_ibis_functions.py +++ b/tests/test_ibis_functions.py @@ -1,4 +1,4 @@ -"""Tests for ibis/functions.py edge cases and uncovered branches.""" +"""Tests for IbisBackend read/write helpers and their normalization hooks.""" from datetime import datetime, timedelta from zoneinfo import ZoneInfo @@ -9,13 +9,12 @@ from chronify.exceptions import InvalidOperation, InvalidParameter from chronify.ibis import make_backend -from chronify.ibis.functions import ( +from chronify.ibis.base import ( _check_one_config_per_datetime_column, - _convert_database_output_for_datetime, - _convert_spark_output_for_datetime, _normalize_timestamps, - write_table, ) +from chronify.ibis.sqlite_backend import _convert_database_output_for_datetime +from chronify.ibis.spark_backend import _convert_spark_output_for_datetime from chronify.time import TimeIntervalType from chronify.time_configs import DatetimeRange @@ -189,7 +188,7 @@ def test_pyarrow_table_input(self): "value": [1.0, 2.0], } ) - write_table(backend, pa_table, "pa_test", [config], if_exists="fail") + backend.write_table(pa_table, "pa_test", [config], if_exists="fail") assert backend.has_table("pa_test") df = backend.execute(backend.table("pa_test")) assert len(df) == 2 @@ -214,7 +213,7 @@ def create_table(name, obj=None, schema=None, overwrite=False): ) monkeypatch.setattr(backend, "create_table", create_table) - write_table(backend, pa_table, "pa_test_arrow", [config], if_exists="fail") + backend.write_table(pa_table, "pa_test_arrow", [config], if_exists="fail") assert seen_arrow backend.dispose() @@ -227,22 +226,7 @@ def test_invalid_if_exists_duckdb(self): "value": [1.0, 2.0], } ) - write_table(backend, df, "test_tbl", [config], if_exists="fail") + backend.write_table(df, "test_tbl", [config], if_exists="fail") with pytest.raises(InvalidOperation, match="Invalid if_exists"): - write_table(backend, df, "test_tbl", [config], if_exists="invalid") + backend.write_table(df, "test_tbl", [config], if_exists="invalid") backend.dispose() - - def test_unsupported_backend(self): - """A backend with an unknown name should raise NotImplementedError.""" - from unittest.mock import MagicMock - - backend = MagicMock() - backend.name = "unknown_db" - config = _make_ntz_config() - df = pd.DataFrame( - { - "timestamp": pd.to_datetime(["2020-01-01 00:00:00", "2020-01-01 01:00:00"]), - } - ) - with pytest.raises(NotImplementedError, match="Unsupported backend"): - write_table(backend, df, "test", [config], if_exists="fail") diff --git a/tests/test_mapper_column_representative_to_datetime.py b/tests/test_mapper_column_representative_to_datetime.py index 8c469d8..80c9058 100644 --- a/tests/test_mapper_column_representative_to_datetime.py +++ b/tests/test_mapper_column_representative_to_datetime.py @@ -13,7 +13,6 @@ ) from chronify.models import TableSchema, PivotedTableSchema from chronify.store import Store -from chronify.ibis.functions import write_table, read_query from chronify.time_series_mapper import map_time @@ -24,7 +23,7 @@ def iter_store(iter_backends): def ingest_csv(backend, csv_file: Path, name: str, time_configs: list[TimeConfig]): data = pd.read_csv(csv_file) - write_table(backend, data, name, time_configs, if_exists="replace") + backend.write_table(data, name, time_configs, if_exists="replace") def test_MDH_mapper(time_series_NMDH, iter_store: Store): @@ -68,7 +67,7 @@ def test_MDH_mapper(time_series_NMDH, iter_store: Store): map_time(iter_store.backend, from_schema, to_schema, check_mapped_timestamps=True) expr = iter_store.backend.sql(f"SELECT * FROM {to_schema.name}") - mapped_table = read_query(iter_store.backend, expr, to_schema.time_config).sort_values( + mapped_table = iter_store.backend.read_query(expr, to_schema.time_config).sort_values( "timestamp" ) assert np.array_equal(mapped_table["value"].to_numpy(), np.arange(25, 73)) @@ -117,7 +116,7 @@ def test_YMDH_mapper(time_series_NYMDH, iter_store): map_time(iter_store.backend, from_schema, to_schema, check_mapped_timestamps=True) expr = iter_store.backend.sql(f"SELECT * FROM {to_schema.name}") - mapped_table = read_query(iter_store.backend, expr, to_schema.time_config).sort_values( + mapped_table = iter_store.backend.read_query(expr, to_schema.time_config).sort_values( "timestamp" ) assert np.array_equal(mapped_table["value"].to_numpy(), np.arange(25, 73)) @@ -156,7 +155,7 @@ def test_NYMDPV_mapper(time_series_NYMDPV, iter_store: Store): map_time(iter_store.backend, from_schema, to_schema, check_mapped_timestamps=True) expr = iter_store.backend.sql(f"SELECT * FROM {to_schema.name}") - mapped_table = read_query(iter_store.backend, expr, to_schema.time_config).sort_values( + mapped_table = iter_store.backend.read_query(expr, to_schema.time_config).sort_values( "timestamp" ) values = np.concatenate( diff --git a/tests/test_mapper_datetime_to_datetime.py b/tests/test_mapper_datetime_to_datetime.py index d3b895e..26ae6ec 100644 --- a/tests/test_mapper_datetime_to_datetime.py +++ b/tests/test_mapper_datetime_to_datetime.py @@ -8,7 +8,6 @@ import pandas as pd from chronify.ibis import IbisBackend -from chronify.ibis.functions import read_query, write_table from chronify.time_series_mapper import map_time from chronify.time_configs import DatetimeRange from chronify.models import TableSchema @@ -62,7 +61,7 @@ def ingest_data( df: pd.DataFrame, schema: TableSchema, ) -> None: - write_table(backend, df, schema.name, [schema.time_config], if_exists="replace") + backend.write_table(df, schema.name, [schema.time_config], if_exists="replace") def run_test_with_error( @@ -87,7 +86,7 @@ def get_mapped_results( map_time(backend, from_schema, to_schema, check_mapped_timestamps=True) expr = backend.sql(f"select * from {to_schema.name}") - queried = read_query(backend, expr, to_schema.time_config) + queried = backend.read_query(expr, to_schema.time_config) queried = queried.sort_values(by=["id", "timestamp"]).reset_index(drop=True)[df.columns] return queried @@ -248,4 +247,4 @@ def test_duplicated_configs_in_write_table( configs = [schema.time_config, schema.time_config] with pytest.raises(InvalidParameter, match="More than one datetime config found"): - write_table(iter_backends, df, schema.name, configs, if_exists="replace") + iter_backends.write_table(df, schema.name, configs, if_exists="replace") diff --git a/tests/test_mapper_index_time_to_datetime.py b/tests/test_mapper_index_time_to_datetime.py index d0a6e9b..ead681e 100644 --- a/tests/test_mapper_index_time_to_datetime.py +++ b/tests/test_mapper_index_time_to_datetime.py @@ -5,7 +5,6 @@ from typing import Any, Optional from chronify.ibis import IbisBackend -from chronify.ibis.functions import read_query, write_table from chronify.time_series_mapper import map_time from chronify.time_configs import ( DatetimeRange, @@ -148,7 +147,7 @@ def run_test( wrap_time_allowed: bool = False, ) -> None: # Ingest - write_table(backend, df, from_schema.name, [from_schema.time_config], if_exists="replace") + backend.write_table(df, from_schema.name, [from_schema.time_config], if_exists="replace") # Map if error: @@ -174,7 +173,7 @@ def run_test( def get_output_table(backend: IbisBackend, to_schema: TableSchema) -> pd.DataFrame: expr = backend.sql(f"select * from {to_schema.name}") - queried = read_query(backend, expr, to_schema.time_config) + queried = backend.read_query(expr, to_schema.time_config) return queried diff --git a/tests/test_mapper_representative_time_to_datetime.py b/tests/test_mapper_representative_time_to_datetime.py index f496f93..c7d726c 100644 --- a/tests/test_mapper_representative_time_to_datetime.py +++ b/tests/test_mapper_representative_time_to_datetime.py @@ -6,7 +6,6 @@ import pandas as pd from chronify.ibis import IbisBackend -from chronify.ibis.functions import read_query, write_table from chronify.time_series_mapper import map_time from chronify.time_configs import DatetimeRange from chronify.models import TableSchema @@ -46,7 +45,7 @@ def run_test( error: Optional[tuple[Any, str]], ) -> None: # Ingest - write_table(backend, df, from_schema.name, [from_schema.time_config], if_exists="replace") + backend.write_table(df, from_schema.name, [from_schema.time_config], if_exists="replace") # Map if error: @@ -57,7 +56,7 @@ def run_test( # Check mapped table expr = backend.sql(f"select * from {to_schema.name}") - queried = read_query(backend, expr, to_schema.time_config) + queried = backend.read_query(expr, to_schema.time_config) queried = queried.sort_values(by=["id", "timestamp"]).reset_index(drop=True) truth = generate_datetime_data(to_schema.time_config) diff --git a/tests/test_time_series_checker.py b/tests/test_time_series_checker.py index 5beef1e..cb164ca 100644 --- a/tests/test_time_series_checker.py +++ b/tests/test_time_series_checker.py @@ -6,7 +6,6 @@ import pytest from chronify.ibis import IbisBackend -from chronify.ibis.functions import write_table from chronify.exceptions import InvalidTable from chronify.models import TableSchema from chronify.time import TimeIntervalType @@ -73,7 +72,7 @@ def _run_test( time_array_id_columns=["generator"], value_column="value", ) - write_table(backend, df, schema.name, [schema.time_config], if_exists="replace") + backend.write_table(df, schema.name, [schema.time_config], if_exists="replace") if message is None: check_timestamps(backend, schema.name, schema) diff --git a/tests/test_time_zone_converter.py b/tests/test_time_zone_converter.py index 96d322e..59b9744 100644 --- a/tests/test_time_zone_converter.py +++ b/tests/test_time_zone_converter.py @@ -7,7 +7,6 @@ import pandas as pd from chronify.ibis import IbisBackend -from chronify.ibis.functions import read_query, write_table from chronify.time_zone_converter import ( TimeZoneConverter, TimeZoneConverterByColumn, @@ -89,7 +88,7 @@ def ingest_data( df: pd.DataFrame, schema: TableSchema, ) -> None: - write_table(backend, df, schema.name, [schema.time_config], if_exists="replace") + backend.write_table(df, schema.name, [schema.time_config], if_exists="replace") def get_mapped_dataframe( @@ -98,7 +97,7 @@ def get_mapped_dataframe( time_config: DatetimeRange, ) -> pd.DataFrame: expr = backend.sql(f"select * from {table_name}") - queried = read_query(backend, expr, time_config) + queried = backend.read_query(expr, time_config) queried = queried.sort_values(by=["id", "timestamp"]).reset_index(drop=True) return queried diff --git a/tests/test_time_zone_localizer.py b/tests/test_time_zone_localizer.py index 843ff8f..905565f 100644 --- a/tests/test_time_zone_localizer.py +++ b/tests/test_time_zone_localizer.py @@ -7,7 +7,6 @@ import pandas as pd from chronify.ibis import IbisBackend -from chronify.ibis.functions import read_query, write_table from chronify.time_utils import get_standard_time_zone from chronify.time_zone_localizer import ( TimeZoneLocalizer, @@ -132,7 +131,7 @@ def ingest_data( df: pd.DataFrame, schema: TableSchema, ) -> None: - write_table(backend, df, schema.name, [schema.time_config], if_exists="replace") + backend.write_table(df, schema.name, [schema.time_config], if_exists="replace") def get_mapped_dataframe( @@ -141,7 +140,7 @@ def get_mapped_dataframe( time_config: DatetimeRangeBase, ) -> pd.DataFrame: expr = backend.sql(f"select * from {table_name}") - queried = read_query(backend, expr, time_config) + queried = backend.read_query(expr, time_config) queried = queried.sort_values(by=["id", "timestamp"]).reset_index(drop=True) return queried From 04cffe00203cf3377d3dc638a3644f5cfd830b0d Mon Sep 17 00:00:00 2001 From: Daniel Thom Date: Sun, 12 Apr 2026 17:19:36 -0600 Subject: [PATCH 23/48] Return ibis.Table from Store.read_query/read_table MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Expose Ibis natively through the Store API: read_query and read_table now return ibis.Table expressions. Callers that want a DataFrame call .execute() on the result. read_query drops the name parameter since no schema-driven pandas normalization happens on return. Remove the SQLite read-side timestamp normalization hook; Ibis's native execute() coerces stored text to datetime via the table's schema. Simplify the write-side canonicalization to an inline tz_convert("UTC") — SQLite stores timestamps as text, so all tz-aware columns must share a single offset representation for text-based joins to align. Drop the now-unused IbisBackend.read_table helper. Co-Authored-By: Claude Opus 4.6 --- docs/how_tos/getting_started/quick_start.md | 4 +- docs/how_tos/map_time_config.md | 4 +- docs/how_tos/spark_backend.md | 2 +- src/chronify/ibis/base.py | 4 -- src/chronify/ibis/sqlite_backend.py | 65 ++++++--------------- src/chronify/store.py | 34 +++++------ tests/test_ibis_functions.py | 43 -------------- tests/test_spark_backend.py | 15 +++-- tests/test_store.py | 60 +++++++++---------- 9 files changed, 79 insertions(+), 152 deletions(-) diff --git a/docs/how_tos/getting_started/quick_start.md b/docs/how_tos/getting_started/quick_start.md index 3cab29f..1773c91 100644 --- a/docs/how_tos/getting_started/quick_start.md +++ b/docs/how_tos/getting_started/quick_start.md @@ -28,8 +28,8 @@ store.ingest_tables( time_array_id_columns=["id"], ) ) -query = "SELECT timestamp, value FROM devices WHERE id = ?" -df = store.read_query("devices", query, params=(2,)) +query = "SELECT timestamp, value FROM devices WHERE id = 2" +df = store.read_query(query).execute() df.head() ``` diff --git a/docs/how_tos/map_time_config.md b/docs/how_tos/map_time_config.md index c45f823..265514d 100644 --- a/docs/how_tos/map_time_config.md +++ b/docs/how_tos/map_time_config.md @@ -51,7 +51,7 @@ schema = TableSchema( ) store = Store.create_in_memory_db() store.ingest_table(df, schema) -store.read_query(src_table_name, f"SELECT * FROM {src_table_name} LIMIT 5").head() +store.read_query(f"SELECT * FROM {src_table_name} LIMIT 5").execute().head() ``` ``` @@ -77,7 +77,7 @@ dst_schema = TableSchema( ) ) store.map_table_time_config(src_table_name, dst_schema) -store.read_query(dst_table_name, f"SELECT * FROM {dst_table_name} LIMIT 5").head() +store.read_query(f"SELECT * FROM {dst_table_name} LIMIT 5").execute().head() ``` ``` diff --git a/docs/how_tos/spark_backend.md b/docs/how_tos/spark_backend.md index 5da0ba3..096c47f 100644 --- a/docs/how_tos/spark_backend.md +++ b/docs/how_tos/spark_backend.md @@ -77,7 +77,7 @@ store.create_view_from_parquet("data.parquet") Verify the data: ```python -store.read_table(schema.name).head() +store.read_table(schema.name).execute().head() ``` ``` timestamp id value diff --git a/src/chronify/ibis/base.py b/src/chronify/ibis/base.py index 43eaa8f..450824a 100644 --- a/src/chronify/ibis/base.py +++ b/src/chronify/ibis/base.py @@ -240,10 +240,6 @@ def execute_sql_to_df(self, query: str) -> pd.DataFrame: logger.trace("execute_sql_to_df: {}", query) return cast(pd.DataFrame, self.connection.raw_sql(query).fetch_df()) - def read_table(self, name: str, config: TimeBaseModel) -> pd.DataFrame: - """Return the named table as a pandas DataFrame, normalized for this backend.""" - return self.read_query(self.table(name), config) - def read_query(self, expr: ir.Table, config: TimeBaseModel) -> pd.DataFrame: """Execute an Ibis expression and return a normalized pandas DataFrame.""" df = self.execute(expr) diff --git a/src/chronify/ibis/sqlite_backend.py b/src/chronify/ibis/sqlite_backend.py index 0088728..d296188 100644 --- a/src/chronify/ibis/sqlite_backend.py +++ b/src/chronify/ibis/sqlite_backend.py @@ -14,15 +14,14 @@ from chronify.exceptions import ConflictingInputsError, InvalidOperation, InvalidParameter from chronify.ibis.base import ( - DatetimeRanges, IbisBackend, ObjectType, TimeBaseModel, - TimeDataType, _DATETIME_RANGES, _is_ddl, _normalize_timestamps, ) +from chronify.time import TimeDataType def _adapt_value(v: Any) -> Any: @@ -230,63 +229,37 @@ def _commit_if_needed(self) -> None: if not self._in_transaction: self._connection.con.commit() - def _post_read_normalize(self, df: pd.DataFrame, config: DatetimeRanges) -> None: - """SQLite stores timestamps as text; coerce back to datetime dtypes.""" - _convert_database_output_for_datetime(df, config) - def _prepare_write_data( self, data: pd.DataFrame | pa.Table, configs: Sequence[TimeBaseModel], ) -> pd.DataFrame: - """SQLite cannot ingest Arrow directly; convert to pandas and coerce TZ→UTC text.""" + """SQLite stores timestamps as text, so joins compare raw strings. + + Canonicalize all tz-aware columns to UTC on write so joins between columns + written from different source zones (e.g., source table in ``Etc/GMT+5`` + vs. a mapping table localized from tz-naive input) align. + """ if isinstance(data, pa.Table): data = data.to_pandas() data = _normalize_timestamps(data, configs) copied = False for config in configs: - if isinstance(config, _DATETIME_RANGES): - data, copied = _convert_database_input_for_datetime(data, config, copied) + if not isinstance(config, _DATETIME_RANGES): + continue + if config.dtype != TimeDataType.TIMESTAMP_TZ: + continue + if config.time_column not in data.columns: + continue + if not isinstance(data[config.time_column].dtype, DatetimeTZDtype): + continue + if not copied: + data = data.copy() + copied = True + data[config.time_column] = data[config.time_column].dt.tz_convert("UTC") return data -def _convert_database_output_for_datetime(df: pd.DataFrame, config: DatetimeRanges) -> None: - """Convert DataFrame datetime columns after SQLite output.""" - if config.time_column not in df.columns: - return - - col = df[config.time_column] - if config.dtype == TimeDataType.TIMESTAMP_TZ: - if col.dtype == object: - df[config.time_column] = pd.to_datetime(col, utc=True) - elif isinstance(col.dtype, DatetimeTZDtype): - df[config.time_column] = col.dt.tz_convert("UTC") - else: - df[config.time_column] = col.dt.tz_localize("UTC") - else: - if col.dtype == object: - df[config.time_column] = pd.to_datetime(col, utc=False) - - -def _convert_database_input_for_datetime( - df: pd.DataFrame, config: DatetimeRanges, copied: bool -) -> tuple[pd.DataFrame, bool]: - """Convert DataFrame datetime columns for SQLite input (store as UTC).""" - if config.dtype == TimeDataType.TIMESTAMP_NTZ: - return df, copied - - if not copied: - df = df.copy() - copied = True - - if isinstance(df[config.time_column].dtype, DatetimeTZDtype): - df[config.time_column] = df[config.time_column].dt.tz_convert("UTC") - else: - df[config.time_column] = df[config.time_column].dt.tz_localize("UTC") - - return df, copied - - def _infer_sqlite_path(connection: ibis.BaseBackend) -> str | None: """Return the database file path for an ibis SQLite connection, or None for in-memory.""" try: diff --git a/src/chronify/store.py b/src/chronify/store.py index a191beb..f1628a4 100644 --- a/src/chronify/store.py +++ b/src/chronify/store.py @@ -695,31 +695,29 @@ def localize_time_zone_by_column( self._schema_mgr.add_schema(dst_schema) return dst_schema - def read_query( - self, - name: str, - query: ir.Table | str, - ) -> pd.DataFrame: - """Return the query result as a pandas DataFrame. + def read_query(self, query: ir.Table | str) -> ir.Table: + """Return the query result as an Ibis Table expression. + + Call ``.execute()`` on the returned expression to materialize a pandas DataFrame. Parameters ---------- - name - Table or view name query - SQL query as a string or ibis Table expression + SQL query as a string or ibis Table expression. """ - schema = self._schema_mgr.get_schema(name) if isinstance(query, str): - expr = self._backend.sql(query) - else: - expr = query - return self._backend.read_query(expr, schema.time_config) + return self._backend.sql(query) + return query - def read_table(self, name: str) -> pd.DataFrame: - """Return the table as a pandas DataFrame.""" - schema = self._schema_mgr.get_schema(name) - return self._backend.read_table(name, schema.time_config) + def read_table(self, name: str) -> ir.Table: + """Return the table as an Ibis Table expression. + + Call ``.execute()`` on the returned expression to materialize a pandas DataFrame. + """ + if not self.has_table(name): + msg = f"{name=}" + raise TableNotStored(msg) + return self._backend.table(name) def read_raw_query(self, query: str) -> pd.DataFrame: """Execute a query directly on the backend and return the results as a DataFrame. diff --git a/tests/test_ibis_functions.py b/tests/test_ibis_functions.py index 9dce67e..fac558d 100644 --- a/tests/test_ibis_functions.py +++ b/tests/test_ibis_functions.py @@ -13,7 +13,6 @@ _check_one_config_per_datetime_column, _normalize_timestamps, ) -from chronify.ibis.sqlite_backend import _convert_database_output_for_datetime from chronify.ibis.spark_backend import _convert_spark_output_for_datetime from chronify.time import TimeIntervalType from chronify.time_configs import DatetimeRange @@ -82,48 +81,6 @@ def test_duplicate_config_raises(self): _check_one_config_per_datetime_column(configs) -class TestConvertDatabaseOutputForDatetime: - def test_tz_with_object_dtype(self): - config = _make_tz_config() - df = pd.DataFrame({"timestamp": ["2020-01-01 00:00:00", "2020-01-01 01:00:00"]}) - _convert_database_output_for_datetime(df, config) - assert isinstance(df["timestamp"].dtype, pd.DatetimeTZDtype) - - def test_tz_with_tz_aware_dtype(self): - config = _make_tz_config() - df = pd.DataFrame( - { - "timestamp": pd.to_datetime( - ["2020-01-01 00:00:00+05:00", "2020-01-01 01:00:00+05:00"] - ), - } - ) - _convert_database_output_for_datetime(df, config) - assert str(df["timestamp"].dt.tz) == "UTC" - - def test_tz_with_naive_dtype(self): - config = _make_tz_config() - df = pd.DataFrame( - { - "timestamp": pd.to_datetime(["2020-01-01 00:00:00", "2020-01-01 01:00:00"]), - } - ) - _convert_database_output_for_datetime(df, config) - assert isinstance(df["timestamp"].dtype, pd.DatetimeTZDtype) - - def test_ntz_with_object_dtype(self): - config = _make_ntz_config() - df = pd.DataFrame({"timestamp": ["2020-01-01 00:00:00", "2020-01-01 01:00:00"]}) - _convert_database_output_for_datetime(df, config) - assert pd.api.types.is_datetime64_any_dtype(df["timestamp"]) - - def test_missing_column_is_noop(self): - config = _make_tz_config() - df = pd.DataFrame({"other": [1, 2]}) - _convert_database_output_for_datetime(df, config) - assert list(df.columns) == ["other"] - - class TestConvertSparkOutputForDatetime: def test_tz_with_object_dtype(self): config = _make_tz_config() diff --git a/tests/test_spark_backend.py b/tests/test_spark_backend.py index 1c680fe..6f7c9b9 100644 --- a/tests/test_spark_backend.py +++ b/tests/test_spark_backend.py @@ -65,7 +65,12 @@ def test_spark_round_trip_timestamp_tz_preserves_fractional_seconds(spark_store: ) spark_store.ingest_table(df, schema, skip_time_checks=True) - out = spark_store.read_table(schema.name).sort_values("timestamp").reset_index(drop=True) + out = ( + spark_store.read_table(schema.name) + .execute() + .sort_values("timestamp") + .reset_index(drop=True) + ) expected = pd.to_datetime( [ @@ -169,7 +174,7 @@ def test_spark_ingest_normalizes_tz_aware_to_ntz(spark_store: Store) -> None: } ) spark_store.ingest_table(df, schema, skip_time_checks=True) - out = spark_store.read_table(schema.name) + out = spark_store.read_table(schema.name).execute() # Should be tz-naive after round-trip assert not isinstance(out["timestamp"].dtype, pd.DatetimeTZDtype) assert out["timestamp"].iloc[0] == pd.Timestamp("2020-01-01 00:00:00") @@ -197,7 +202,7 @@ def test_spark_ingest_normalizes_tz_naive_to_tz(spark_store: Store) -> None: } ) spark_store.ingest_table(df, schema, skip_time_checks=True) - out = spark_store.read_table(schema.name) + out = spark_store.read_table(schema.name).execute() assert isinstance(out["timestamp"].dtype, pd.DatetimeTZDtype) assert out["timestamp"].iloc[0] == pd.Timestamp("2020-01-01 00:00:00+00:00") @@ -230,7 +235,7 @@ def test_spark_time_zone_conversion(spark_store: Store) -> None: to_tz = ZoneInfo("US/Eastern") dst_schema = spark_store.convert_time_zone(schema.name, to_tz) - out = spark_store.read_table(dst_schema.name) + out = spark_store.read_table(dst_schema.name).execute() expected = df["timestamp"].dt.tz_convert(to_tz).dt.tz_localize(None) out_sorted = out.sort_values("timestamp").reset_index(drop=True) assert list(out_sorted["timestamp"]) == list(expected) @@ -268,7 +273,7 @@ def test_spark_delete_rows(spark_store: Store) -> None: spark_store.ingest_table(df, schema, skip_time_checks=True) count = spark_store.delete_rows(schema.name, {"id": 1}) assert count == 2 - out = spark_store.read_table(schema.name) + out = spark_store.read_table(schema.name).execute() assert len(out) == 2 assert set(out["id"]) == {2} diff --git a/tests/test_store.py b/tests/test_store.py index ec9d512..e9d9cd1 100644 --- a/tests/test_store.py +++ b/tests/test_store.py @@ -116,7 +116,7 @@ def test_ingest_csv(iter_stores_by_engine: Store, tmp_path, generators_schema, u name="timestamp", dtype=dt.Timestamp(timezone="Etc/GMT+5") ) store.ingest_from_csv(src_file, src_schema, dst_schema) - df = store.read_table(dst_schema.name) + df = store.read_table(dst_schema.name).execute() assert len(df) == 8784 * 3 new_file = tmp_path / "gen2.csv" @@ -144,14 +144,12 @@ def test_ingest_csv(iter_stores_by_engine: Store, tmp_path, generators_schema, u time_array_id_columns=[], ) store.ingest_from_csv(new_file, src_schema2, dst_schema) - df = store.read_table(dst_schema.name) + df = store.read_table(dst_schema.name).execute() assert len(df) == 8784 * 3 * 2 all(df.timestamp.unique() == expected_timestamps) # Read a subset of the table. - df2 = store.read_query( - dst_schema.name, f"SELECT * FROM {dst_schema.name} WHERE generator = 'gen2'" - ) + df2 = store.read_query(f"SELECT * FROM {dst_schema.name} WHERE generator = 'gen2'").execute() assert len(df2) == 8784 df_gen2 = df[df["generator"] == "gen2"] assert all((df2.values == df_gen2.values)[0]) @@ -182,7 +180,7 @@ def test_ingest_csvs_with_rollback(tmp_path, multiple_tables): ) store.ingest_from_csvs((src_file1, src_file2), src_schema, dst_schema) - df = store.read_table(dst_schema.name) + df = store.read_table(dst_schema.name).execute() assert len(df) == len(tables[0]) + len(tables[1]) assert len(df.id.unique()) == 2 @@ -191,7 +189,7 @@ def test_ingest_multiple_tables(iter_stores_by_engine: Store, multiple_tables): store = iter_stores_by_engine tables, schema = multiple_tables store.ingest_tables(tables, schema) - df = store.read_query("devices", "SELECT * FROM devices WHERE id = 2") + df = store.read_query("SELECT * FROM devices WHERE id = 2").execute() df["timestamp"] = df["timestamp"].astype("datetime64[ns]") assert df.equals(tables[1]) @@ -207,7 +205,7 @@ def test_ingest_multiple_tables_error(iter_stores_by_engine: Store, multiple_tab tables[1].loc[8783] = (tables[1].loc[8783]["timestamp"], 0.1, orig_value) store.ingest_tables(tables, schema) - df = store.read_query(schema.name, f"select * from {schema.name} where id=2") + df = store.read_query(f"select * from {schema.name} where id=2").execute() df["timestamp"] = df["timestamp"].astype("datetime64[ns]") assert df.equals(tables[1]) @@ -222,7 +220,7 @@ def test_ingest_pivoted_table(iter_stores_by_engine: Store, generators_schema, u store.ingest_pivoted_table(input_table, pivoted_schema, dst_schema) table = store.get_table(dst_schema.name) stmt = table.filter(table.generator == "gen1") - df = store.read_query(dst_schema.name, stmt) + df = store.read_query(stmt).execute() assert len(df) == 8784 @@ -239,7 +237,7 @@ def test_ingest_invalid_csv(iter_stores_by_engine: Store, tmp_path, generators_s with pytest.raises(InvalidTable): store.ingest_from_csv(new_file, src_schema, dst_schema) with pytest.raises(TableNotStored): - store.read_table(dst_schema.name) + store.read_table(dst_schema.name).execute() def test_invalid_schema(iter_stores_by_engine: Store, generators_schema): @@ -257,7 +255,7 @@ def test_ingest_one_week_per_month_by_hour( df, num_time_arrays, schema = one_week_per_month_by_hour_table store.ingest_table(df, schema) - df2 = store.read_table(schema.name) + df2 = store.read_table(schema.name).execute() assert len(df2["id"].unique()) == num_time_arrays assert len(df2) == 24 * 7 * 12 * num_time_arrays columns = schema.time_config.list_time_columns() @@ -314,7 +312,7 @@ def test_load_parquet(iter_stores_by_engine_no_data_ingestion: Store, tmp_path): out_file = tmp_path / "gen2.parquet" rel2.to_parquet(str(out_file)) store.create_view_from_parquet(out_file, dst_schema) - df = store.read_table(dst_schema.name) + df = store.read_table(dst_schema.name).execute() assert len(df) == 8784 * 3 timestamp_generator = make_time_range_generator(time_config) expected_timestamps = timestamp_generator.list_timestamps() @@ -325,7 +323,7 @@ def test_load_parquet(iter_stores_by_engine_no_data_ingestion: Store, tmp_path): as_dict["name"] = "test_view" schema2 = TableSchema(**as_dict) store.create_view_from_parquet(out_file, schema2) - df2 = store.read_table(schema2.name) + df2 = store.read_table(schema2.name).execute() assert schema2.name in store.list_tables() assert len(df2) == 8784 * 3 timestamp_generator = make_time_range_generator(time_config) @@ -334,7 +332,7 @@ def test_load_parquet(iter_stores_by_engine_no_data_ingestion: Store, tmp_path): store.drop_view(schema2.name) assert schema2.name not in store.list_tables() assert dst_schema.name in store.list_tables() - df3 = store.read_table(dst_schema.name) + df3 = store.read_table(dst_schema.name).execute() assert len(df3) == 8784 * 3 @@ -375,7 +373,7 @@ def test_map_one_week_per_month_by_hour_to_datetime( ) store.ingest_table(df, src_schema) store.map_table_time_config(src_schema.name, dst_schema, check_mapped_timestamps=True) - df2 = store.read_table(dst_schema.name) + df2 = store.read_table(dst_schema.name).execute() assert len(df2) == time_array_len * num_time_arrays actual = sorted(df2["timestamp"].unique()) expected = make_time_range_generator(dst_schema.time_config).list_timestamps() @@ -452,7 +450,7 @@ def test_map_datetime_to_datetime( src_schema.name, dst_schema, output_file=output_file, check_mapped_timestamps=True ) if output_file is None or store.backend.name == "sqlite": - df2 = store.read_table(dst_schema.name) + df2 = store.read_table(dst_schema.name).execute() else: df2 = pd.read_parquet(output_file) assert len(df2) == time_array_len * 3 @@ -526,7 +524,7 @@ def test_map_index_time_to_datetime( ), ) if output_file is None or store.backend.name == "sqlite": - result = store.read_table(dst_schema.name) + result = store.read_table(dst_schema.name).execute() else: result = pd.read_parquet(output_file) @@ -571,12 +569,12 @@ def test_load_existing_store(iter_backends_file, one_week_per_month_by_hour_tabl df, _, schema = one_week_per_month_by_hour_table store = Store(backend=backend) store.ingest_table(df, schema) - df2 = store.read_table(schema.name) + df2 = store.read_table(schema.name).execute() assert df2.equals(df) file_path = Path(backend.database) assert file_path.exists() store2 = Store.load_from_file(backend_name=backend_name, file_path=file_path) - df3 = store2.read_table(schema.name) + df3 = store2.read_table(schema.name).execute() assert df3.equals(df2) with pytest.raises(FileNotFoundError): Store.load_from_file(backend_name=backend_name, file_path="./invalid/path") @@ -624,7 +622,7 @@ def test_backup(iter_backends_file, one_week_per_month_by_hour_table, tmp_path): store.backup(dst_file) assert dst_file.exists() store2 = Store(backend_name=backend_name, file_path=dst_file) - df2 = store2.read_table(schema.name) + df2 = store2.read_table(schema.name).execute() assert df2.equals(df) with pytest.raises(InvalidOperation): @@ -634,7 +632,7 @@ def test_backup(iter_backends_file, one_week_per_month_by_hour_table, tmp_path): store.backup(dst_file2, overwrite=True) # Make sure the original still works. - df3 = store.read_table(schema.name) + df3 = store.read_table(schema.name).execute() assert df3.equals(df) @@ -654,20 +652,20 @@ def test_delete_rows(iter_stores_by_engine: Store, one_week_per_month_by_hour_ta store = iter_stores_by_engine df, _, schema = one_week_per_month_by_hour_table store.ingest_table(df, schema) - df2 = store.read_table(schema.name) + df2 = store.read_table(schema.name).execute() assert df2.equals(df) assert sorted(df2["id"].unique()) == [1, 2, 3] with pytest.raises(InvalidParameter): store.delete_rows(schema.name, {}) store.delete_rows(schema.name, {"id": 2}) - df3 = store.read_table(schema.name) + df3 = store.read_table(schema.name).execute() assert sorted(df3["id"].unique()) == [1, 3] store.delete_rows(schema.name, {"id": 1}) - df4 = store.read_table(schema.name) + df4 = store.read_table(schema.name).execute() assert sorted(df4["id"].unique()) == [3] store.delete_rows(schema.name, {"id": 3}) with pytest.raises(TableNotStored): - store.read_table(schema.name) + store.read_table(schema.name).execute() with pytest.raises(TableNotStored): store.delete_rows(schema.name, {"id": 3}) @@ -677,11 +675,11 @@ def test_drop_table(iter_stores_by_engine: Store, one_week_per_month_by_hour_tab df, _, schema = one_week_per_month_by_hour_table assert not store.list_tables() store.ingest_table(df, schema) - assert store.read_table(schema.name).equals(df) + assert store.read_table(schema.name).execute().equals(df) assert store.list_tables() == [schema.name] store.drop_table(schema.name) with pytest.raises(TableNotStored): - store.read_table(schema.name) + store.read_table(schema.name).execute() assert not store.list_tables() with pytest.raises(TableNotStored): store.drop_table(schema.name) @@ -772,7 +770,7 @@ def test_convert_time_zone( src_schema.name, to_time_zone, output_file=output_file, check_mapped_timestamps=True ) if output_file is None or store.backend.name == "sqlite": - df2 = store.read_table(dst_schema.name) + df2 = store.read_table(dst_schema.name).execute() else: df2 = pd.read_parquet(output_file) df2["timestamp"] = pd.to_datetime(df2["timestamp"]) @@ -848,7 +846,7 @@ def test_convert_time_zone_by_column( check_mapped_timestamps=True, ) if output_file is None or store.backend.name == "sqlite": - df2 = store.read_table(dst_schema.name) + df2 = store.read_table(dst_schema.name).execute() else: df2 = pd.read_parquet(output_file) df2["timestamp"] = pd.to_datetime(df2["timestamp"]) @@ -921,7 +919,7 @@ def test_localize_time_zone( check_mapped_timestamps=True, ) if output_file is None or store.backend.name == "sqlite": - df2 = store.read_table(dst_schema.name) + df2 = store.read_table(dst_schema.name).execute() else: df2 = pd.read_parquet(output_file) df2["timestamp"] = pd.to_datetime(df2["timestamp"]) @@ -995,7 +993,7 @@ def test_localize_time_zone_by_column(tmp_path, iter_stores_by_engine_no_data_in check_mapped_timestamps=True, ) if output_file is None or store.backend.name == "sqlite": - df2 = store.read_table(dst_schema.name) + df2 = store.read_table(dst_schema.name).execute() else: df2 = pd.read_parquet(output_file) df2["timestamp"] = pd.to_datetime(df2["timestamp"]) From 652a7a27fc1432ebc6358398850cc2760295316a Mon Sep 17 00:00:00 2001 From: Daniel Thom Date: Sun, 12 Apr 2026 18:11:15 -0600 Subject: [PATCH 24/48] Remove table cache and collapse pure-delegation backend methods The _table_cache added complexity (DDL-detection invalidation, per-method mark-created/dropped hooks) for a has_table() shortcut that's now a single list_tables() call. Moved create_table/create_view/drop_table/drop_view/ list_tables/table/execute/sql default implementations into IbisBackend so subclasses only override when they have real backend-specific logic. Co-Authored-By: Claude Opus 4.6 --- src/chronify/ibis/base.py | 71 +++++------------------------ src/chronify/ibis/duckdb_backend.py | 45 +----------------- src/chronify/ibis/spark_backend.py | 42 +---------------- src/chronify/ibis/sqlite_backend.py | 43 ++--------------- 4 files changed, 18 insertions(+), 183 deletions(-) diff --git a/src/chronify/ibis/base.py b/src/chronify/ibis/base.py index 450824a..3c8570b 100644 --- a/src/chronify/ibis/base.py +++ b/src/chronify/ibis/base.py @@ -1,6 +1,5 @@ """Abstract base class for Ibis database backends.""" -import re from abc import ABC, abstractmethod from collections import Counter from contextlib import contextmanager @@ -23,20 +22,10 @@ TimeBaseModel, ) -_DDL_RE = re.compile( - r"^\s*(?:WITH\s+.+?\s+)?(CREATE|DROP|ALTER|TRUNCATE|RENAME)\b", - re.IGNORECASE | re.DOTALL, -) - _DATETIME_RANGES: tuple[type, ...] = (DatetimeRange, DatetimeRangeWithTZColumn) DatetimeRanges = DatetimeRange | DatetimeRangeWithTZColumn -def _is_ddl(query: str) -> bool: - """Return True if the SQL statement changes the set of tables/views.""" - return _DDL_RE.match(query) is not None - - def _check_one_config_per_datetime_column(configs: Sequence[TimeBaseModel]) -> None: time_col_count = Counter( config.time_column for config in configs if isinstance(config, DatetimeRangeBase) @@ -105,8 +94,6 @@ class ObjectType(StrEnum): class IbisBackend(ABC): """Abstract base class defining the interface for Ibis database backends.""" - _table_cache: set[str] | None - @property @abstractmethod def name(self) -> str: @@ -122,7 +109,6 @@ def database(self) -> str | None: def connection(self) -> ibis.BaseBackend: """Return the underlying ibis connection.""" - @abstractmethod def create_table( self, name: str, @@ -130,43 +116,28 @@ def create_table( schema: ibis.Schema | None = None, overwrite: bool = False, ) -> ir.Table: - """Create a table in the database. - - Parameters - ---------- - name - Table name. - obj - Data to populate the table with. - schema - Schema to use if obj is None. - overwrite - If True, replace the table if it already exists. - - Returns - ------- - ir.Table - """ + """Create a table in the database.""" + return self.connection.create_table(name, obj=obj, schema=schema, overwrite=overwrite) - @abstractmethod def create_view(self, name: str, expr: ir.Table) -> ir.Table: """Create a view in the database.""" + return self.connection.create_view(name, expr, overwrite=False) - @abstractmethod def drop_table(self, name: str) -> None: """Drop a table from the database.""" + self.connection.drop_table(name, force=True) - @abstractmethod def drop_view(self, name: str) -> None: """Drop a view from the database.""" + self.connection.drop_view(name, force=True) - @abstractmethod def list_tables(self) -> list[str]: """List all user tables in the database.""" + return cast(list[str], self.connection.list_tables()) - @abstractmethod def table(self, name: str) -> ir.Table: """Return an ibis table expression for the named table.""" + return self.connection.table(name) @abstractmethod def insert(self, name: str, data: pd.DataFrame | pa.Table) -> None: @@ -180,14 +151,14 @@ def delete_rows(self, name: str, values: dict[str, Any]) -> None: SQL injection and to handle values containing quote characters. """ - @abstractmethod def execute(self, expr: ir.Expr) -> pd.DataFrame: """Execute an ibis expression and return a DataFrame. Must not be called for large tables.""" + return cast(pd.DataFrame, self.connection.execute(expr)) - @abstractmethod def sql(self, query: str) -> ir.Table: """Create an ibis table expression from a raw SQL string.""" + return self.connection.sql(query) @abstractmethod def write_parquet( @@ -208,32 +179,12 @@ def create_view_from_parquet(self, path: str, name: str) -> tuple[ir.Table, Obje def has_table(self, name: str) -> bool: """Check whether a table or view exists.""" - if self._table_cache is None: - self._refresh_table_cache() - assert self._table_cache is not None - return name in self._table_cache - - def _refresh_table_cache(self) -> None: - self._table_cache = set(self.list_tables()) - - def _mark_table_created(self, name: str) -> None: - if self._table_cache is not None: - self._table_cache.add(name) - - def _mark_table_dropped(self, name: str) -> None: - if self._table_cache is not None: - self._table_cache.discard(name) - - def _invalidate_table_cache(self) -> None: - self._table_cache = None + return name in self.list_tables() def execute_sql(self, query: str) -> Any: """Execute a raw SQL statement (no result expected).""" logger.trace("execute_sql: {}", query) - result = self.connection.raw_sql(query) - if _is_ddl(query): - self._invalidate_table_cache() - return result + return self.connection.raw_sql(query) def execute_sql_to_df(self, query: str) -> pd.DataFrame: """Execute a raw SQL query and return a DataFrame.""" diff --git a/src/chronify/ibis/duckdb_backend.py b/src/chronify/ibis/duckdb_backend.py index d9702e9..d6dd0c4 100644 --- a/src/chronify/ibis/duckdb_backend.py +++ b/src/chronify/ibis/duckdb_backend.py @@ -11,7 +11,7 @@ from loguru import logger from chronify.exceptions import ConflictingInputsError, InvalidOperation, InvalidParameter -from chronify.ibis.base import IbisBackend, ObjectType, _is_ddl +from chronify.ibis.base import IbisBackend, ObjectType class DuckDBBackend(IbisBackend): @@ -39,7 +39,6 @@ def __init__( msg = f"{database=} and {connection=} cannot both be set" raise ConflictingInputsError(msg) - self._table_cache = None self._owns_connection = connection is None if connection is None: db = str(database) @@ -64,39 +63,10 @@ def database(self) -> str | None: def connection(self) -> ibis.BaseBackend: return self._connection - def create_table( - self, - name: str, - obj: pd.DataFrame | pa.Table | ir.Table | None = None, - schema: ibis.Schema | None = None, - overwrite: bool = False, - ) -> ir.Table: - table = self._connection.create_table(name, obj=obj, schema=schema, overwrite=overwrite) - self._mark_table_created(name) - return table - - def create_view(self, name: str, expr: ir.Table) -> ir.Table: - view = self._connection.create_view(name, expr, overwrite=False) - self._mark_table_created(name) - return view - - def drop_table(self, name: str) -> None: - self._connection.drop_table(name, force=True) - self._mark_table_dropped(name) - - def drop_view(self, name: str) -> None: - self._connection.drop_view(name, force=True) - self._mark_table_dropped(name) - def list_tables(self) -> list[str]: tables = self._connection.list_tables() # Filter out internal ibis memtables - tables = [t for t in tables if not t.startswith("ibis_pandas_memtable_")] - self._table_cache = set(tables) - return tables - - def table(self, name: str) -> ir.Table: - return self._connection.table(name) + return [t for t in tables if not t.startswith("ibis_pandas_memtable_")] def insert(self, name: str, data: pd.DataFrame | pa.Table) -> None: con = self._connection.con # raw duckdb connection @@ -131,9 +101,6 @@ def execute(self, expr: ir.Expr) -> pd.DataFrame: return cast(pd.DataFrame, self._connection.con.execute(sql).fetch_df()) return cast(pd.DataFrame, self._connection.execute(expr)) - def sql(self, query: str) -> ir.Table: - return self._connection.sql(query) - def write_parquet( self, expr: ir.Table, @@ -162,15 +129,8 @@ def create_view_from_parquet(self, path: str, name: str) -> tuple[ir.Table, Obje self._connection.raw_sql( f"CREATE VIEW {quoted_name} AS SELECT * FROM read_parquet('{escaped_path}')" ) - self._mark_table_created(name) return self.table(name), ObjectType.VIEW - def execute_sql(self, query: str) -> None: - logger.trace("execute_sql: {}", query) - self._connection.raw_sql(query) - if _is_ddl(query): - self._invalidate_table_cache() - def execute_sql_to_df(self, query: str) -> pd.DataFrame: logger.trace("execute_sql_to_df: {}", query) result = self._connection.raw_sql(query) @@ -202,7 +162,6 @@ def _commit_transaction(self) -> None: def _rollback_transaction(self) -> None: self._connection.con.execute("ROLLBACK") - self._invalidate_table_cache() def _infer_duckdb_path(connection: ibis.BaseBackend) -> str | None: diff --git a/src/chronify/ibis/spark_backend.py b/src/chronify/ibis/spark_backend.py index 16724fb..0868de8 100644 --- a/src/chronify/ibis/spark_backend.py +++ b/src/chronify/ibis/spark_backend.py @@ -21,7 +21,6 @@ ObjectType, TimeBaseModel, TimeDataType, - _is_ddl, _normalize_timestamps, ) @@ -48,7 +47,6 @@ def __init__(self, session: Any = None) -> None: .getOrCreate() ) self._validate_session(session) - self._table_cache = None self._session = session self._connection = ibis.pyspark.connect(session) @@ -74,39 +72,12 @@ def create_table( if isinstance(obj, pd.DataFrame): obj = self._prepare_data_for_spark(obj) try: - table = self._connection.create_table( - name, obj=obj, schema=schema, overwrite=overwrite - ) + return self._connection.create_table(name, obj=obj, schema=schema, overwrite=overwrite) except Exception as exc: if "LOCATION_ALREADY_EXISTS" not in str(exc): raise self._remove_managed_table_location(name) - table = self._connection.create_table( - name, obj=obj, schema=schema, overwrite=overwrite - ) - self._mark_table_created(name) - return table - - def create_view(self, name: str, expr: ir.Table) -> ir.Table: - view = self._connection.create_view(name, expr, overwrite=False) - self._mark_table_created(name) - return view - - def drop_table(self, name: str) -> None: - self._connection.drop_table(name, force=True) - self._mark_table_dropped(name) - - def drop_view(self, name: str) -> None: - self._connection.drop_view(name, force=True) - self._mark_table_dropped(name) - - def list_tables(self) -> list[str]: - tables = cast(list[str], self._connection.list_tables()) - self._table_cache = set(tables) - return tables - - def table(self, name: str) -> ir.Table: - return self._connection.table(name) + return self._connection.create_table(name, obj=obj, schema=schema, overwrite=overwrite) def insert(self, name: str, data: pd.DataFrame | pa.Table) -> None: if isinstance(data, pa.Table): @@ -165,12 +136,6 @@ def _overwrite_without_deleted_rows(self, name: str, where: str, args: dict[str, self._session.sql(f"DROP TABLE IF EXISTS {quoted_tmp}") self._remove_managed_table_location(tmp_name) - def execute(self, expr: ir.Expr) -> pd.DataFrame: - return cast(pd.DataFrame, self._connection.execute(expr)) - - def sql(self, query: str) -> ir.Table: - return self._connection.sql(query) - def write_parquet( self, expr: ir.Table, @@ -191,14 +156,11 @@ def _to_spark_dataframe(self, expr: ir.Table) -> Any: def create_view_from_parquet(self, path: str, name: str) -> tuple[ir.Table, ObjectType]: spark_df = self._session.read.parquet(path) spark_df.createOrReplaceTempView(name) - self._mark_table_created(name) return self.table(name), ObjectType.VIEW def execute_sql(self, query: str) -> None: logger.trace("execute_sql: {}", query) self._session.sql(query) - if _is_ddl(query): - self._invalidate_table_cache() def execute_sql_to_df(self, query: str) -> pd.DataFrame: logger.trace("execute_sql_to_df: {}", query) diff --git a/src/chronify/ibis/sqlite_backend.py b/src/chronify/ibis/sqlite_backend.py index d296188..a8e3b4e 100644 --- a/src/chronify/ibis/sqlite_backend.py +++ b/src/chronify/ibis/sqlite_backend.py @@ -3,7 +3,7 @@ import sqlite3 from datetime import datetime from pathlib import Path -from typing import Any, Sequence, cast +from typing import Any, Sequence import ibis import ibis.expr.types as ir @@ -18,7 +18,6 @@ ObjectType, TimeBaseModel, _DATETIME_RANGES, - _is_ddl, _normalize_timestamps, ) from chronify.time import TimeDataType @@ -64,7 +63,6 @@ def __init__( msg = f"{database=} and {connection=} cannot both be set" raise ConflictingInputsError(msg) - self._table_cache = None self._in_transaction = False self._owns_connection = connection is None if connection is None: @@ -101,34 +99,8 @@ def create_table( # SQLite CREATE TABLE AS SELECT loses datetime type info. # Execute the expression first, then create from the DataFrame. df = self._connection.execute(obj) - table = self._connection.create_table(name, obj=df, overwrite=overwrite) - else: - table = self._connection.create_table( - name, obj=obj, schema=schema, overwrite=overwrite - ) - self._mark_table_created(name) - return table - - def create_view(self, name: str, expr: ir.Table) -> ir.Table: - view = self._connection.create_view(name, expr, overwrite=False) - self._mark_table_created(name) - return view - - def drop_table(self, name: str) -> None: - self._connection.drop_table(name, force=True) - self._mark_table_dropped(name) - - def drop_view(self, name: str) -> None: - self._connection.drop_view(name, force=True) - self._mark_table_dropped(name) - - def list_tables(self) -> list[str]: - tables = cast(list[str], self._connection.list_tables()) - self._table_cache = set(tables) - return tables - - def table(self, name: str) -> ir.Table: - return self._connection.table(name) + return self._connection.create_table(name, obj=df, overwrite=overwrite) + return self._connection.create_table(name, obj=obj, schema=schema, overwrite=overwrite) def insert(self, name: str, data: pd.DataFrame | pa.Table) -> None: if isinstance(data, pa.Table): @@ -159,12 +131,6 @@ def delete_rows(self, name: str, values: dict[str, Any]) -> None: self._commit_if_needed() logger.trace("Deleted rows from {} matching {}", name, values) - def execute(self, expr: ir.Expr) -> pd.DataFrame: - return cast(pd.DataFrame, self._connection.execute(expr)) - - def sql(self, query: str) -> ir.Table: - return self._connection.sql(query) - def write_parquet( self, expr: ir.Table, @@ -187,8 +153,6 @@ def execute_sql(self, query: str) -> None: con = self._connection.con con.execute(query) self._commit_if_needed() - if _is_ddl(query): - self._invalidate_table_cache() def execute_sql_to_df(self, query: str) -> pd.DataFrame: logger.trace("execute_sql_to_df: {}", query) @@ -223,7 +187,6 @@ def _commit_transaction(self) -> None: def _rollback_transaction(self) -> None: self._connection.con.rollback() self._in_transaction = False - self._invalidate_table_cache() def _commit_if_needed(self) -> None: if not self._in_transaction: From 4cd5c4c4dc3f27cfb85fb53046b35474fd51e020 Mon Sep 17 00:00:00 2001 From: Daniel Thom Date: Sun, 12 Apr 2026 18:30:33 -0600 Subject: [PATCH 25/48] Fix mypy strict-mode errors in ibis backends Drop the tuple[type, ...] annotation on _DATETIME_RANGES so mypy infers the concrete tuple type and narrows through isinstance() checks. Import TimeBaseModel and TimeDataType from their original modules in the sqlite and spark backends; the indirect re-export through chronify.ibis.base violates strict --no-implicit-reexport. Co-Authored-By: Claude Opus 4.6 --- src/chronify/ibis/base.py | 2 +- src/chronify/ibis/spark_backend.py | 4 ++-- src/chronify/ibis/sqlite_backend.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/chronify/ibis/base.py b/src/chronify/ibis/base.py index 3c8570b..636ae44 100644 --- a/src/chronify/ibis/base.py +++ b/src/chronify/ibis/base.py @@ -22,7 +22,7 @@ TimeBaseModel, ) -_DATETIME_RANGES: tuple[type, ...] = (DatetimeRange, DatetimeRangeWithTZColumn) +_DATETIME_RANGES = (DatetimeRange, DatetimeRangeWithTZColumn) DatetimeRanges = DatetimeRange | DatetimeRangeWithTZColumn diff --git a/src/chronify/ibis/spark_backend.py b/src/chronify/ibis/spark_backend.py index 0868de8..f48c6c4 100644 --- a/src/chronify/ibis/spark_backend.py +++ b/src/chronify/ibis/spark_backend.py @@ -19,10 +19,10 @@ DatetimeRanges, IbisBackend, ObjectType, - TimeBaseModel, - TimeDataType, _normalize_timestamps, ) +from chronify.time import TimeDataType +from chronify.time_configs import TimeBaseModel class SparkBackend(IbisBackend): diff --git a/src/chronify/ibis/sqlite_backend.py b/src/chronify/ibis/sqlite_backend.py index a8e3b4e..ac731ab 100644 --- a/src/chronify/ibis/sqlite_backend.py +++ b/src/chronify/ibis/sqlite_backend.py @@ -16,11 +16,11 @@ from chronify.ibis.base import ( IbisBackend, ObjectType, - TimeBaseModel, _DATETIME_RANGES, _normalize_timestamps, ) from chronify.time import TimeDataType +from chronify.time_configs import TimeBaseModel def _adapt_value(v: Any) -> Any: From beaabc7112104f08da21bf101898cc92b880f620 Mon Sep 17 00:00:00 2001 From: Daniel Thom Date: Sun, 12 Apr 2026 18:45:09 -0600 Subject: [PATCH 26/48] Use ibis to_parquet in backend write_parquet implementations Each ibis backend already provides a native to_parquet: - DuckDB generates COPY (...) TO ... with PARTITION_BY - PySpark calls df.write.format("parquet").save() - Base (SQLite) streams via pyarrow.ParquetWriter Move write_parquet into IbisBackend as a concrete default that delegates to self.connection.to_parquet, raising NotImplementedError when partition_by is requested. DuckDB and Spark keep thin overrides to forward partitioning kwargs (partition_by / partitionBy). SQLite's override is removed entirely. Co-Authored-By: Claude Opus 4.6 --- src/chronify/ibis/base.py | 5 ++++- src/chronify/ibis/duckdb_backend.py | 10 ++-------- src/chronify/ibis/spark_backend.py | 10 ++-------- src/chronify/ibis/sqlite_backend.py | 12 ------------ 4 files changed, 8 insertions(+), 29 deletions(-) diff --git a/src/chronify/ibis/base.py b/src/chronify/ibis/base.py index 636ae44..f166932 100644 --- a/src/chronify/ibis/base.py +++ b/src/chronify/ibis/base.py @@ -160,7 +160,6 @@ def sql(self, query: str) -> ir.Table: """Create an ibis table expression from a raw SQL string.""" return self.connection.sql(query) - @abstractmethod def write_parquet( self, expr: ir.Table, @@ -168,6 +167,10 @@ def write_parquet( partition_by: list[str] | None = None, ) -> None: """Write an ibis expression result to a Parquet file.""" + if partition_by: + msg = f"{self.name} backend does not support partitioned Parquet writes." + raise NotImplementedError(msg) + self.connection.to_parquet(expr, path) @abstractmethod def create_view_from_parquet(self, path: str, name: str) -> tuple[ir.Table, ObjectType]: diff --git a/src/chronify/ibis/duckdb_backend.py b/src/chronify/ibis/duckdb_backend.py index d6dd0c4..f436372 100644 --- a/src/chronify/ibis/duckdb_backend.py +++ b/src/chronify/ibis/duckdb_backend.py @@ -107,16 +107,10 @@ def write_parquet( path: str, partition_by: list[str] | None = None, ) -> None: - escaped_path = path.replace("'", "''") - sql = self._connection.compile(expr) if partition_by: - partition_clause = ", ".join(_quote_identifier(c) for c in partition_by) - self._connection.raw_sql( - f"COPY ({sql}) TO '{escaped_path}' " - f"(FORMAT PARQUET, PARTITION_BY ({partition_clause}))" - ) + self._connection.to_parquet(expr, path, partition_by=partition_by) else: - self._connection.raw_sql(f"COPY ({sql}) TO '{escaped_path}' (FORMAT PARQUET)") + self._connection.to_parquet(expr, path) def create_view_from_parquet(self, path: str, name: str) -> tuple[ir.Table, ObjectType]: parquet_path = Path(path) diff --git a/src/chronify/ibis/spark_backend.py b/src/chronify/ibis/spark_backend.py index f48c6c4..e6769b0 100644 --- a/src/chronify/ibis/spark_backend.py +++ b/src/chronify/ibis/spark_backend.py @@ -142,16 +142,10 @@ def write_parquet( path: str, partition_by: list[str] | None = None, ) -> None: - df = self._to_spark_dataframe(expr) - writer = df.write.mode("errorifexists") if partition_by: - writer.partitionBy(*partition_by).parquet(path) + self._connection.to_parquet(expr, path, partitionBy=partition_by) else: - writer.parquet(path) - - def _to_spark_dataframe(self, expr: ir.Table) -> Any: - sql = self._connection.compile(expr) - return self._session.sql(sql) + self._connection.to_parquet(expr, path) def create_view_from_parquet(self, path: str, name: str) -> tuple[ir.Table, ObjectType]: spark_df = self._session.read.parquet(path) diff --git a/src/chronify/ibis/sqlite_backend.py b/src/chronify/ibis/sqlite_backend.py index ac731ab..ba93869 100644 --- a/src/chronify/ibis/sqlite_backend.py +++ b/src/chronify/ibis/sqlite_backend.py @@ -131,18 +131,6 @@ def delete_rows(self, name: str, values: dict[str, Any]) -> None: self._commit_if_needed() logger.trace("Deleted rows from {} matching {}", name, values) - def write_parquet( - self, - expr: ir.Table, - path: str, - partition_by: list[str] | None = None, - ) -> None: - if partition_by: - msg = "SQLite backend does not support partitioned Parquet writes." - raise NotImplementedError(msg) - df = self._connection.execute(expr) - df.to_parquet(path) - def create_view_from_parquet(self, path: str, name: str) -> tuple[ir.Table, ObjectType]: # SQLite can't read Parquet natively. Load into a table instead. df = pd.read_parquet(path) From aaa08c329fa23614c600d892e6ac72d52246dca4 Mon Sep 17 00:00:00 2001 From: Daniel Thom Date: Sun, 12 Apr 2026 19:00:30 -0600 Subject: [PATCH 27/48] Use Ibis natives in Store and drop chronify.duckdb helpers - _ingest_pivoted_table: use ibis.Table.pivot_longer instead of DuckDB UNPIVOT - Accept pandas / ibis.Table throughout (drop DuckDBPyRelation from Store API) - Delete chronify.duckdb.functions (unpivot + unused add_datetime_column) - Migrate tests off unpivot() helper to pandas.melt Co-Authored-By: Claude Opus 4.6 --- src/chronify/duckdb/__init__.py | 0 src/chronify/duckdb/functions.py | 46 --------------- src/chronify/ibis/base.py | 6 +- src/chronify/ibis/duckdb_backend.py | 5 -- src/chronify/ibis/spark_backend.py | 6 +- src/chronify/ibis/sqlite_backend.py | 8 --- src/chronify/store.py | 56 ++++++++---------- tests/test_store.py | 89 +++++++++++++++++++---------- 8 files changed, 86 insertions(+), 130 deletions(-) delete mode 100644 src/chronify/duckdb/__init__.py delete mode 100644 src/chronify/duckdb/functions.py diff --git a/src/chronify/duckdb/__init__.py b/src/chronify/duckdb/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/src/chronify/duckdb/functions.py b/src/chronify/duckdb/functions.py deleted file mode 100644 index 1762392..0000000 --- a/src/chronify/duckdb/functions.py +++ /dev/null @@ -1,46 +0,0 @@ -from collections.abc import Iterable -from datetime import datetime, timedelta - -import duckdb -from duckdb import DuckDBPyRelation - - -def add_datetime_column( - rel: DuckDBPyRelation, - start: datetime, - resolution: timedelta, - length: int, - time_array_id_columns: Iterable[str], - time_column: str, - timestamps: list[datetime], -) -> DuckDBPyRelation: - """Add a datetime column to the relation.""" - raise NotImplementedError - # values = [] - # columns = ",".join(rel.columns) - # return duckdb.sql( - # f""" - # SELECT - # AS {time_column} - # ,{columns} - # FROM rel - # """ - # ) - - -def unpivot( - rel: DuckDBPyRelation, - pivoted_columns: Iterable[str], - name_column: str, - value_column: str, -) -> DuckDBPyRelation: - pivoted_str = ",".join(pivoted_columns) - - query = f""" - SELECT * FROM rel - UNPIVOT INCLUDE NULLS ( - {value_column} - FOR {name_column} in ({pivoted_str}) - ) - """ - return duckdb.sql(query) diff --git a/src/chronify/ibis/base.py b/src/chronify/ibis/base.py index f166932..ae9fe34 100644 --- a/src/chronify/ibis/base.py +++ b/src/chronify/ibis/base.py @@ -184,15 +184,15 @@ def has_table(self, name: str) -> bool: """Check whether a table or view exists.""" return name in self.list_tables() - def execute_sql(self, query: str) -> Any: + def execute_sql(self, query: str) -> None: """Execute a raw SQL statement (no result expected).""" logger.trace("execute_sql: {}", query) - return self.connection.raw_sql(query) + self.connection.raw_sql(query) def execute_sql_to_df(self, query: str) -> pd.DataFrame: """Execute a raw SQL query and return a DataFrame.""" logger.trace("execute_sql_to_df: {}", query) - return cast(pd.DataFrame, self.connection.raw_sql(query).fetch_df()) + return cast(pd.DataFrame, self.sql(query).execute()) def read_query(self, expr: ir.Table, config: TimeBaseModel) -> pd.DataFrame: """Execute an Ibis expression and return a normalized pandas DataFrame.""" diff --git a/src/chronify/ibis/duckdb_backend.py b/src/chronify/ibis/duckdb_backend.py index f436372..4d58c9b 100644 --- a/src/chronify/ibis/duckdb_backend.py +++ b/src/chronify/ibis/duckdb_backend.py @@ -125,11 +125,6 @@ def create_view_from_parquet(self, path: str, name: str) -> tuple[ir.Table, Obje ) return self.table(name), ObjectType.VIEW - def execute_sql_to_df(self, query: str) -> pd.DataFrame: - logger.trace("execute_sql_to_df: {}", query) - result = self._connection.raw_sql(query) - return cast(pd.DataFrame, result.fetch_df()) - def dispose(self) -> None: if self._owns_connection: self._connection.disconnect() diff --git a/src/chronify/ibis/spark_backend.py b/src/chronify/ibis/spark_backend.py index e6769b0..7f1c321 100644 --- a/src/chronify/ibis/spark_backend.py +++ b/src/chronify/ibis/spark_backend.py @@ -3,7 +3,7 @@ import uuid import shutil from contextlib import contextmanager -from typing import Any, Generator, Sequence, cast +from typing import Any, Generator, Sequence from pathlib import Path from urllib.parse import urlparse, unquote @@ -156,10 +156,6 @@ def execute_sql(self, query: str) -> None: logger.trace("execute_sql: {}", query) self._session.sql(query) - def execute_sql_to_df(self, query: str) -> pd.DataFrame: - logger.trace("execute_sql_to_df: {}", query) - return cast(pd.DataFrame, self._session.sql(query).toPandas()) - def dispose(self) -> None: self._connection.disconnect() if self._owns_session: diff --git a/src/chronify/ibis/sqlite_backend.py b/src/chronify/ibis/sqlite_backend.py index ba93869..72889e0 100644 --- a/src/chronify/ibis/sqlite_backend.py +++ b/src/chronify/ibis/sqlite_backend.py @@ -142,14 +142,6 @@ def execute_sql(self, query: str) -> None: con.execute(query) self._commit_if_needed() - def execute_sql_to_df(self, query: str) -> pd.DataFrame: - logger.trace("execute_sql_to_df: {}", query) - con = self._connection.con - cursor = con.execute(query) - rows = cursor.fetchall() - columns = [desc[0] for desc in cursor.description] if cursor.description else [] - return pd.DataFrame(rows, columns=columns) - def dispose(self) -> None: if self._owns_connection: self._connection.disconnect() diff --git a/src/chronify/store.py b/src/chronify/store.py index f1628a4..093b188 100644 --- a/src/chronify/store.py +++ b/src/chronify/store.py @@ -3,13 +3,11 @@ from typing import Any, Optional, cast from datetime import tzinfo -import duckdb +import ibis import ibis.expr.types as ir import pandas as pd -from duckdb import DuckDBPyRelation from loguru import logger -import chronify.duckdb.functions as ddbf from chronify.exceptions import ( ConflictingInputsError, InvalidParameter, @@ -229,9 +227,9 @@ def _ingest_from_csv( src_schema: CsvTableSchema, dst_schema: TableSchema, ) -> bool: - rel = read_csv(path, src_schema) + df = read_csv(path, src_schema).to_df() columns = set(src_schema.list_columns()) - check_columns(rel.columns, columns) + check_columns(df.columns, columns) if isinstance(src_schema.time_config, IndexTimeRangeBase): if isinstance(dst_schema.time_config, DatetimeRange): @@ -242,13 +240,13 @@ def _ingest_from_csv( raise NotImplementedError(msg) if src_schema.pivoted_dimension_name is not None: - return self._ingest_pivoted_table(rel, src_schema, dst_schema) + return self._ingest_pivoted_table(df, src_schema, dst_schema) - return self._ingest_table(rel, dst_schema) + return self._ingest_table(df, dst_schema) def ingest_pivoted_table( self, - data: pd.DataFrame | DuckDBPyRelation, + data: pd.DataFrame | ir.Table, src_schema: PivotedTableSchema | CsvTableSchema, dst_schema: TableSchema, ) -> bool: @@ -257,7 +255,7 @@ def ingest_pivoted_table( def ingest_pivoted_tables( self, - data: Iterable[pd.DataFrame | DuckDBPyRelation], + data: Iterable[pd.DataFrame | ir.Table], src_schema: PivotedTableSchema | CsvTableSchema, dst_schema: TableSchema, ) -> bool: @@ -276,7 +274,7 @@ def ingest_pivoted_tables( def _ingest_pivoted_tables( self, - data: Iterable[pd.DataFrame | DuckDBPyRelation], + data: Iterable[pd.DataFrame | ir.Table], src_schema: PivotedTableSchema | CsvTableSchema, dst_schema: TableSchema, ) -> bool: @@ -289,28 +287,22 @@ def _ingest_pivoted_tables( def _ingest_pivoted_table( self, - data: pd.DataFrame | DuckDBPyRelation, + data: pd.DataFrame | ir.Table, src_schema: PivotedTableSchema | CsvTableSchema, dst_schema: TableSchema, ) -> bool: - if isinstance(data, pd.DataFrame): - tmp_df = data # noqa: F841 - rel = duckdb.sql("SELECT * from tmp_df") - else: - rel = data - assert src_schema.pivoted_dimension_name is not None - rel2 = ddbf.unpivot( - rel, - src_schema.value_columns, - src_schema.pivoted_dimension_name, - dst_schema.value_column, + expr = data if isinstance(data, ir.Table) else ibis.memtable(data) + unpivoted = expr.pivot_longer( + list(src_schema.value_columns), + names_to=src_schema.pivoted_dimension_name, + values_to=dst_schema.value_column, ) - return self._ingest_table(rel2, dst_schema) + return self._ingest_table(unpivoted, dst_schema) def ingest_table( self, - data: pd.DataFrame | DuckDBPyRelation, + data: pd.DataFrame | ir.Table, schema: TableSchema, **kwargs: Any, ) -> bool: @@ -338,7 +330,7 @@ def ingest_table( def ingest_tables( self, - data: Iterable[pd.DataFrame | DuckDBPyRelation], + data: Iterable[pd.DataFrame | ir.Table], schema: TableSchema, **kwargs: Any, ) -> bool: @@ -361,7 +353,7 @@ def ingest_tables( def _ingest_tables( self, - data: Iterable[pd.DataFrame | DuckDBPyRelation], + data: Iterable[pd.DataFrame | ir.Table], schema: TableSchema, skip_time_checks: bool = False, ) -> bool: @@ -375,10 +367,10 @@ def _ingest_tables( def _ingest_table( self, - data: pd.DataFrame | DuckDBPyRelation, + data: pd.DataFrame | ir.Table, schema: TableSchema, ) -> bool: - df = data.to_df() if isinstance(data, DuckDBPyRelation) else data + df = data.execute() if isinstance(data, ir.Table) else data check_columns(df.columns, schema.list_columns()) if not self._backend.has_table(schema.name): @@ -720,11 +712,11 @@ def read_table(self, name: str) -> ir.Table: return self._backend.table(name) def read_raw_query(self, query: str) -> pd.DataFrame: - """Execute a query directly on the backend and return the results as a DataFrame. + """Execute a raw SQL query on the backend and return the results as a DataFrame. - Note: Unlike :meth:`read_query`, no conversion of timestamps is performed. - Timestamps will be in the format of the underlying database. SQLite backends will return - strings instead of datetime. + This is an escape hatch for executing backend-specific SQL that Ibis cannot + express. For portable queries, prefer :meth:`read_query`, which returns an + Ibis Table expression with consistent cross-backend typing. Parameters ---------- diff --git a/tests/test_store.py b/tests/test_store.py index e9d9cd1..d608350 100644 --- a/tests/test_store.py +++ b/tests/test_store.py @@ -12,7 +12,6 @@ import pytest from chronify.csv_io import read_csv -from chronify.duckdb.functions import unpivot from chronify.exceptions import ( ConflictingInputsError, InvalidOperation, @@ -212,11 +211,13 @@ def test_ingest_multiple_tables_error(iter_stores_by_engine: Store, multiple_tab @pytest.mark.parametrize("use_pandas", [False, True]) def test_ingest_pivoted_table(iter_stores_by_engine: Store, generators_schema, use_pandas: bool): + import ibis + store = iter_stores_by_engine src_file, src_schema, dst_schema = generators_schema pivoted_schema = PivotedTableSchema(**src_schema.model_dump(exclude={"column_dtypes"})) - rel = read_csv(src_file, src_schema) - input_table = rel.to_df() if use_pandas else rel + df = read_csv(src_file, src_schema).to_df() + input_table = df if use_pandas else ibis.memtable(df) store.ingest_pivoted_table(input_table, pivoted_schema, dst_schema) table = store.get_table(dst_schema.name) stmt = table.filter(table.generator == "gen1") @@ -307,10 +308,15 @@ def test_load_parquet(iter_stores_by_engine_no_data_ingestion: Store, tmp_path): time_array_id_columns=["generator"], value_column="value", ) - rel = read_csv(GENERATOR_TIME_SERIES_FILE, src_schema) - rel2 = unpivot(rel, ("gen1", "gen2", "gen3"), "generator", "value") # noqa: F841 + df = read_csv(GENERATOR_TIME_SERIES_FILE, src_schema).to_df() + df2 = df.melt( + id_vars=["timestamp"], + value_vars=["gen1", "gen2", "gen3"], + var_name="generator", + value_name="value", + ) out_file = tmp_path / "gen2.parquet" - rel2.to_parquet(str(out_file)) + df2.to_parquet(str(out_file)) store.create_view_from_parquet(out_file, dst_schema) df = store.read_table(dst_schema.name).execute() assert len(df) == 8784 * 3 @@ -431,8 +437,13 @@ def test_map_datetime_to_datetime( time_array_id_columns=["generator"], value_column="value", ) - rel = read_csv(GENERATOR_TIME_SERIES_FILE, src_csv_schema) - rel2 = unpivot(rel, ("gen1", "gen2", "gen3"), "generator", "value") # noqa: F841 + df = read_csv(GENERATOR_TIME_SERIES_FILE, src_csv_schema).to_df() + df2 = df.melt( + id_vars=["timestamp"], + value_vars=["gen1", "gen2", "gen3"], + var_name="generator", + value_name="value", + ) src_schema = TableSchema( name="generators_pb", @@ -440,7 +451,7 @@ def test_map_datetime_to_datetime( time_array_id_columns=["generator"], value_column="value", ) - store.ingest_table(rel2, src_schema) + store.ingest_table(df2, src_schema) if tzinfo is None and store.backend.name != "sqlite": output_file = tmp_path / "mapped_data" @@ -750,8 +761,13 @@ def test_convert_time_zone( pivoted_dimension_name="generator", time_array_id_columns=[], ) - rel = read_csv(GENERATOR_TIME_SERIES_FILE, src_csv_schema) - rel2 = unpivot(rel, ("gen1", "gen2", "gen3"), "generator", "value") # noqa: F841 + df = read_csv(GENERATOR_TIME_SERIES_FILE, src_csv_schema).to_df() + df2 = df.melt( + id_vars=["timestamp"], + value_vars=["gen1", "gen2", "gen3"], + var_name="generator", + value_name="value", + ) src_schema = TableSchema( name="generators_pb", @@ -759,7 +775,7 @@ def test_convert_time_zone( time_array_id_columns=["generator"], value_column="value", ) - store.ingest_table(rel2, src_schema) + store.ingest_table(df2, src_schema) if tzinfo is None and store.backend.name != "sqlite": output_file = tmp_path / "mapped_data" @@ -817,13 +833,16 @@ def test_convert_time_zone_by_column( pivoted_dimension_name="generator", time_array_id_columns=[], ) - rel = read_csv(GENERATOR_TIME_SERIES_FILE, src_csv_schema) - rel2 = unpivot(rel, ("gen1", "gen2", "gen3"), "generator", "value") # noqa: F841 - # add time_zone column - stmt = ", ".join(rel2.columns) - tz_col_stmt = "CASE WHEN generator='gen1' THEN 'US/Eastern' WHEN generator='gen2' THEN 'US/Central' ELSE 'None' END AS time_zone" - stmt += f", {tz_col_stmt}" - rel2 = rel2.project(stmt) + df = read_csv(GENERATOR_TIME_SERIES_FILE, src_csv_schema).to_df() + df2 = df.melt( + id_vars=["timestamp"], + value_vars=["gen1", "gen2", "gen3"], + var_name="generator", + value_name="value", + ) + df2["time_zone"] = ( + df2["generator"].map({"gen1": "US/Eastern", "gen2": "US/Central"}).fillna("None") + ) src_schema = TableSchema( name="generators_pb", @@ -831,7 +850,7 @@ def test_convert_time_zone_by_column( time_array_id_columns=["generator", "time_zone"], value_column="value", ) - store.ingest_table(rel2, src_schema) + store.ingest_table(df2, src_schema) if tzinfo is None and store.backend.name != "sqlite": output_file = tmp_path / "mapped_data" @@ -896,8 +915,13 @@ def test_localize_time_zone( pivoted_dimension_name="generator", time_array_id_columns=[], ) - rel = read_csv(GENERATOR_TIME_SERIES_FILE, src_csv_schema) - rel2 = unpivot(rel, ("gen1", "gen2", "gen3"), "generator", "value") # noqa: F841 + df = read_csv(GENERATOR_TIME_SERIES_FILE, src_csv_schema).to_df() + df2 = df.melt( + id_vars=["timestamp"], + value_vars=["gen1", "gen2", "gen3"], + var_name="generator", + value_name="value", + ) src_schema = TableSchema( name="generators_pb", @@ -905,7 +929,7 @@ def test_localize_time_zone( time_array_id_columns=["generator"], value_column="value", ) - store.ingest_table(rel2, src_schema) + store.ingest_table(df2, src_schema) if to_time_zone is None and store.backend.name != "sqlite": output_file = tmp_path / "mapped_data" @@ -965,13 +989,16 @@ def test_localize_time_zone_by_column(tmp_path, iter_stores_by_engine_no_data_in pivoted_dimension_name="generator", time_array_id_columns=[], ) - rel = read_csv(GENERATOR_TIME_SERIES_FILE, src_csv_schema) - rel2 = unpivot(rel, ("gen1", "gen2", "gen3"), "generator", "value") # noqa: F841 - # add time_zone column with standard time zones (not DST) - stmt = ", ".join(rel2.columns) - tz_col_stmt = "CASE WHEN generator='gen1' THEN 'Etc/GMT+5' WHEN generator='gen2' THEN 'Etc/GMT+6' ELSE 'Etc/GMT+7' END AS time_zone" - stmt += f", {tz_col_stmt}" - rel2 = rel2.project(stmt) + df = read_csv(GENERATOR_TIME_SERIES_FILE, src_csv_schema).to_df() + df2 = df.melt( + id_vars=["timestamp"], + value_vars=["gen1", "gen2", "gen3"], + var_name="generator", + value_name="value", + ) + df2["time_zone"] = ( + df2["generator"].map({"gen1": "Etc/GMT+5", "gen2": "Etc/GMT+6"}).fillna("Etc/GMT+7") + ) src_schema = TableSchema( name="generators_pb", @@ -979,7 +1006,7 @@ def test_localize_time_zone_by_column(tmp_path, iter_stores_by_engine_no_data_in time_array_id_columns=["generator", "time_zone"], value_column="value", ) - store.ingest_table(rel2, src_schema) + store.ingest_table(df2, src_schema) if store.backend.name != "sqlite": output_file = tmp_path / "mapped_data" From 42db94afa8e4c0eb1bf10c5c572bdebfe29147d4 Mon Sep 17 00:00:00 2001 From: Daniel Thom Date: Sun, 12 Apr 2026 19:25:39 -0600 Subject: [PATCH 28/48] Apply schema-driven type casts to Store.read_table Spark stores all timestamps as session-local (no per-column tz metadata), so a TIMESTAMP_TZ column reads back tz-naive from Ibis.execute(). The old _post_read_normalize hook fixed that in pandas after collect(), which is fine for small queries but unsound for a backend-agnostic API over arbitrarily large tables. Replace the hook with apply_schema_types(expr, config), a lazy Ibis-level cast applied before execution. SparkBackend casts TIMESTAMP_TZ columns to Timestamp(timezone="UTC") in the expression, so downstream .execute(), .to_parquet(), or streaming paths all produce correctly-typed output without materializing. Co-Authored-By: Claude Opus 4.6 --- src/chronify/ibis/base.py | 25 +++++++------- src/chronify/ibis/spark_backend.py | 46 +++++++++++-------------- src/chronify/store.py | 4 ++- tests/test_ibis_functions.py | 55 ------------------------------ 4 files changed, 35 insertions(+), 95 deletions(-) diff --git a/src/chronify/ibis/base.py b/src/chronify/ibis/base.py index ae9fe34..d7600b5 100644 --- a/src/chronify/ibis/base.py +++ b/src/chronify/ibis/base.py @@ -195,11 +195,19 @@ def execute_sql_to_df(self, query: str) -> pd.DataFrame: return cast(pd.DataFrame, self.sql(query).execute()) def read_query(self, expr: ir.Table, config: TimeBaseModel) -> pd.DataFrame: - """Execute an Ibis expression and return a normalized pandas DataFrame.""" - df = self.execute(expr) - if isinstance(config, _DATETIME_RANGES): - self._post_read_normalize(df, config) - return df + """Execute an Ibis expression and return a pandas DataFrame.""" + return self.execute(self.apply_schema_types(expr, config)) + + def apply_schema_types(self, expr: ir.Table, config: TimeBaseModel) -> ir.Table: + """Return ``expr`` with backend-specific casts applied so its Ibis type + matches ``config``. + + Default: no-op. Backends whose storage loses type information (e.g. + Spark, which stores all timestamps as session-local) should override to + add lazy ``cast`` expressions, so callers get correctly-typed pandas + output without forcing a materialize-then-normalize round-trip. + """ + return expr def write_table( self, @@ -213,13 +221,6 @@ def write_table( prepared = self._prepare_write_data(data, configs) self._apply_if_exists(prepared, name, if_exists) - def _post_read_normalize(self, df: pd.DataFrame, config: DatetimeRanges) -> None: - """Backend-specific in-place normalization of a read DataFrame. - - Default: no-op. Backends whose drivers return non-canonical datetime - types should override. - """ - def _prepare_write_data( self, data: pd.DataFrame | pa.Table, diff --git a/src/chronify/ibis/spark_backend.py b/src/chronify/ibis/spark_backend.py index 7f1c321..5876748 100644 --- a/src/chronify/ibis/spark_backend.py +++ b/src/chronify/ibis/spark_backend.py @@ -8,6 +8,7 @@ from urllib.parse import urlparse, unquote import ibis +import ibis.expr.datatypes as dt import ibis.expr.types as ir import pandas as pd import pyarrow as pa @@ -16,9 +17,9 @@ from chronify.exceptions import InvalidOperation, InvalidParameter from chronify.ibis.base import ( - DatetimeRanges, IbisBackend, ObjectType, + _DATETIME_RANGES, _normalize_timestamps, ) from chronify.time import TimeDataType @@ -176,9 +177,23 @@ def _remove_managed_table_location(self, name: str) -> None: if path.exists(): shutil.rmtree(path) - def _post_read_normalize(self, df: pd.DataFrame, config: DatetimeRanges) -> None: - """Spark returns tz-naive nanosecond timestamps; coerce to schema dtype + µs unit.""" - _convert_spark_output_for_datetime(df, config) + def apply_schema_types(self, expr: ir.Table, config: TimeBaseModel) -> ir.Table: + """Cast the time column to the Ibis type implied by ``config``. + + Spark stores all timestamps as session-local (no per-column tz + metadata), so a column declared ``TIMESTAMP_TZ`` reads back tz-naive + unless we re-type it in the expression. The session is pinned to UTC + (see :meth:`_validate_session`), so the cast is lossless. + """ + if not isinstance(config, _DATETIME_RANGES): + return expr + if config.dtype != TimeDataType.TIMESTAMP_TZ: + return expr + if config.time_column not in expr.columns: + return expr + return expr.mutate( + **{config.time_column: expr[config.time_column].cast(dt.Timestamp(timezone="UTC"))} + ) def _prepare_write_data( self, @@ -215,29 +230,6 @@ def _validate_session(session: Any) -> None: raise InvalidParameter(msg) -def _convert_spark_output_for_datetime(df: pd.DataFrame, config: DatetimeRanges) -> None: - """Convert DataFrame datetime columns after Spark output.""" - if config.time_column not in df.columns: - return - - col = df[config.time_column] - - if config.dtype == TimeDataType.TIMESTAMP_TZ: - if not pd.api.types.is_datetime64_any_dtype(col): - col = pd.to_datetime(col, utc=True) - elif isinstance(col.dtype, DatetimeTZDtype): - col = col.dt.tz_convert("UTC") - else: - col = col.dt.tz_localize("UTC") - df[config.time_column] = col.dt.as_unit("us") - else: - if not pd.api.types.is_datetime64_any_dtype(col): - col = pd.to_datetime(col, utc=False) - df[config.time_column] = col.astype("datetime64[us]") - if isinstance(col.dtype, DatetimeTZDtype): - df[config.time_column] = col.dt.tz_convert(None).astype("datetime64[us]") - - def _validate_insert_columns( table_name: str, target_columns: list[str], data_columns: list[str] ) -> None: diff --git a/src/chronify/store.py b/src/chronify/store.py index 093b188..58652f8 100644 --- a/src/chronify/store.py +++ b/src/chronify/store.py @@ -709,7 +709,9 @@ def read_table(self, name: str) -> ir.Table: if not self.has_table(name): msg = f"{name=}" raise TableNotStored(msg) - return self._backend.table(name) + expr = self._backend.table(name) + schema = self._schema_mgr.get_schema(name) + return self._backend.apply_schema_types(expr, schema.time_config) def read_raw_query(self, query: str) -> pd.DataFrame: """Execute a raw SQL query on the backend and return the results as a DataFrame. diff --git a/tests/test_ibis_functions.py b/tests/test_ibis_functions.py index fac558d..08b8b1b 100644 --- a/tests/test_ibis_functions.py +++ b/tests/test_ibis_functions.py @@ -13,7 +13,6 @@ _check_one_config_per_datetime_column, _normalize_timestamps, ) -from chronify.ibis.spark_backend import _convert_spark_output_for_datetime from chronify.time import TimeIntervalType from chronify.time_configs import DatetimeRange @@ -81,60 +80,6 @@ def test_duplicate_config_raises(self): _check_one_config_per_datetime_column(configs) -class TestConvertSparkOutputForDatetime: - def test_tz_with_object_dtype(self): - config = _make_tz_config() - df = pd.DataFrame({"timestamp": ["2020-01-01 00:00:00", "2020-01-01 01:00:00"]}) - _convert_spark_output_for_datetime(df, config) - assert isinstance(df["timestamp"].dtype, pd.DatetimeTZDtype) - - def test_tz_with_tz_aware_dtype(self): - config = _make_tz_config() - df = pd.DataFrame( - { - "timestamp": pd.to_datetime( - ["2020-01-01 00:00:00+05:00", "2020-01-01 01:00:00+05:00"] - ), - } - ) - _convert_spark_output_for_datetime(df, config) - assert str(df["timestamp"].dt.tz) == "UTC" - - def test_tz_with_naive_dtype(self): - config = _make_tz_config() - df = pd.DataFrame( - { - "timestamp": pd.to_datetime(["2020-01-01 00:00:00", "2020-01-01 01:00:00"]), - } - ) - _convert_spark_output_for_datetime(df, config) - assert isinstance(df["timestamp"].dtype, pd.DatetimeTZDtype) - - def test_ntz_with_object_dtype(self): - config = _make_ntz_config() - df = pd.DataFrame({"timestamp": ["2020-01-01 00:00:00", "2020-01-01 01:00:00"]}) - _convert_spark_output_for_datetime(df, config) - assert pd.api.types.is_datetime64_any_dtype(df["timestamp"]) - - def test_ntz_strips_tz_from_aware(self): - config = _make_ntz_config() - df = pd.DataFrame( - { - "timestamp": pd.to_datetime( - ["2020-01-01 00:00:00+00:00", "2020-01-01 01:00:00+00:00"], utc=True - ), - } - ) - _convert_spark_output_for_datetime(df, config) - assert not isinstance(df["timestamp"].dtype, pd.DatetimeTZDtype) - - def test_missing_column_is_noop(self): - config = _make_tz_config() - df = pd.DataFrame({"other": [1, 2]}) - _convert_spark_output_for_datetime(df, config) - assert list(df.columns) == ["other"] - - class TestWriteTable: def test_pyarrow_table_input(self): backend = make_backend("duckdb") From fc3a69f652127bfd81ac305fc789669c50d49209 Mon Sep 17 00:00:00 2001 From: Daniel Thom Date: Sun, 12 Apr 2026 20:26:33 -0600 Subject: [PATCH 29/48] Delegate backend insert/view creation to Ibis natives - Move default insert() to IbisBackend: validates columns, reorders, and calls connection.insert(). DuckDB and Spark now use the default; SQLite keeps its raw-cursor override so inserts honor our BEGIN/ROLLBACK transaction state (ibis.sqlite.insert auto-commits). - Drop SparkBackend's temp-view-and-SQL insert, _prepare_data_for_spark, and _prepare_write_data: ibis.pyspark.create_table / insert already convert tz-aware pandas to UTC-naive and accept pyarrow directly. - Simplify create_view_from_parquet on DuckDB and Spark to connection.create_view(name, connection.read_parquet(path)). - Hoist shared _validate_insert_columns / _get_columns / _select_columns / _row_count helpers from the per-backend files into base.py. Co-Authored-By: Claude Opus 4.6 --- src/chronify/ibis/base.py | 45 +++++++++++++++- src/chronify/ibis/duckdb_backend.py | 67 +++--------------------- src/chronify/ibis/spark_backend.py | 80 ++--------------------------- 3 files changed, 53 insertions(+), 139 deletions(-) diff --git a/src/chronify/ibis/base.py b/src/chronify/ibis/base.py index d7600b5..af0271b 100644 --- a/src/chronify/ibis/base.py +++ b/src/chronify/ibis/base.py @@ -86,6 +86,37 @@ def _arrow_needs_timestamp_normalization( return False +def _get_columns(data: pd.DataFrame | pa.Table) -> list[str]: + if isinstance(data, pa.Table): + return cast(list[str], data.column_names) + return list(data.columns) + + +def _select_columns(data: pd.DataFrame | pa.Table, columns: list[str]) -> pd.DataFrame | pa.Table: + if isinstance(data, pa.Table): + return data.select(columns) + return data.loc[:, columns] + + +def _row_count(data: pd.DataFrame | pa.Table) -> int: + if isinstance(data, pa.Table): + return cast(int, data.num_rows) + return len(data) + + +def _validate_insert_columns( + table_name: str, target_columns: list[str], data_columns: list[str] +) -> None: + missing = [c for c in target_columns if c not in data_columns] + extra = [c for c in data_columns if c not in target_columns] + if missing or extra: + msg = ( + f"Insert data columns do not match table {table_name!r}. " + f"Missing: {missing}. Extra: {extra}." + ) + raise InvalidParameter(msg) + + class ObjectType(StrEnum): TABLE = "table" VIEW = "view" @@ -139,9 +170,19 @@ def table(self, name: str) -> ir.Table: """Return an ibis table expression for the named table.""" return self.connection.table(name) - @abstractmethod def insert(self, name: str, data: pd.DataFrame | pa.Table) -> None: - """Insert data into an existing table.""" + """Insert data into an existing table. + + Validates that the data columns match the target table, reorders them, + and delegates to the underlying Ibis connection. Subclasses should + override when the default does not cooperate with backend-specific + transaction semantics. + """ + target_columns = list(self.table(name).columns) + _validate_insert_columns(name, target_columns, _get_columns(data)) + ordered = _select_columns(data, target_columns) + self.connection.insert(name, ordered) + logger.trace("Inserted {} rows into {}", _row_count(ordered), name) @abstractmethod def delete_rows(self, name: str, values: dict[str, Any]) -> None: diff --git a/src/chronify/ibis/duckdb_backend.py b/src/chronify/ibis/duckdb_backend.py index 4d58c9b..f91c16f 100644 --- a/src/chronify/ibis/duckdb_backend.py +++ b/src/chronify/ibis/duckdb_backend.py @@ -7,7 +7,6 @@ import ibis import ibis.expr.types as ir import pandas as pd -import pyarrow as pa from loguru import logger from chronify.exceptions import ConflictingInputsError, InvalidOperation, InvalidParameter @@ -68,23 +67,6 @@ def list_tables(self) -> list[str]: # Filter out internal ibis memtables return [t for t in tables if not t.startswith("ibis_pandas_memtable_")] - def insert(self, name: str, data: pd.DataFrame | pa.Table) -> None: - con = self._connection.con # raw duckdb connection - target_columns = list(self.table(name).columns) - _validate_insert_columns(name, target_columns, _get_columns(data)) - ordered_data = _select_columns(data, target_columns) - quoted_columns = ", ".join(f'"{col}"' for col in target_columns) - quoted_name = _quote_identifier(name) - con.register("__insert_df", ordered_data) - try: - con.execute( - f"INSERT INTO {quoted_name} ({quoted_columns}) " - f"SELECT {quoted_columns} FROM __insert_df" - ) - finally: - con.unregister("__insert_df") - logger.trace("Inserted {} rows into {}", _row_count(data), name) - def delete_rows(self, name: str, values: dict[str, Any]) -> None: con = self._connection.con quoted_name = _quote_identifier(name) @@ -118,11 +100,7 @@ def create_view_from_parquet(self, path: str, name: str) -> tuple[ir.Table, Obje read_path = str(parquet_path / "**" / "*.parquet").replace("\\", "/") else: read_path = str(parquet_path).replace("\\", "/") - quoted_name = _quote_identifier(name) - escaped_path = read_path.replace("'", "''") - self._connection.raw_sql( - f"CREATE VIEW {quoted_name} AS SELECT * FROM read_parquet('{escaped_path}')" - ) + self._connection.create_view(name, self._connection.read_parquet(read_path)) return self.table(name), ObjectType.VIEW def dispose(self) -> None: @@ -153,6 +131,12 @@ def _rollback_transaction(self) -> None: self._connection.con.execute("ROLLBACK") +def _quote_identifier(identifier: str) -> str: + """Quote a SQL identifier for DuckDB, escaping embedded double quotes.""" + escaped = identifier.replace('"', '""') + return f'"{escaped}"' + + def _infer_duckdb_path(connection: ibis.BaseBackend) -> str | None: """Return the database file path for an ibis DuckDB connection, or None for in-memory.""" try: @@ -165,40 +149,3 @@ def _infer_duckdb_path(connection: ibis.BaseBackend) -> str | None: return None path = result[0] return None if not path else str(path) - - -def _validate_insert_columns( - table_name: str, target_columns: list[str], data_columns: list[str] -) -> None: - missing = [c for c in target_columns if c not in data_columns] - extra = [c for c in data_columns if c not in target_columns] - if missing or extra: - msg = ( - f"Insert data columns do not match table {table_name!r}. " - f"Missing: {missing}. Extra: {extra}." - ) - raise InvalidParameter(msg) - - -def _quote_identifier(identifier: str) -> str: - """Quote a SQL identifier for DuckDB, escaping embedded double quotes.""" - escaped = identifier.replace('"', '""') - return f'"{escaped}"' - - -def _get_columns(data: pd.DataFrame | pa.Table) -> list[str]: - if isinstance(data, pa.Table): - return cast(list[str], data.column_names) - return list(data.columns) - - -def _select_columns(data: pd.DataFrame | pa.Table, columns: list[str]) -> pd.DataFrame | pa.Table: - if isinstance(data, pa.Table): - return data.select(columns) - return data.loc[:, columns] - - -def _row_count(data: pd.DataFrame | pa.Table) -> int: - if isinstance(data, pa.Table): - return cast(int, data.num_rows) - return len(data) diff --git a/src/chronify/ibis/spark_backend.py b/src/chronify/ibis/spark_backend.py index 5876748..c953e9a 100644 --- a/src/chronify/ibis/spark_backend.py +++ b/src/chronify/ibis/spark_backend.py @@ -2,8 +2,7 @@ import uuid import shutil -from contextlib import contextmanager -from typing import Any, Generator, Sequence +from typing import Any from pathlib import Path from urllib.parse import urlparse, unquote @@ -13,15 +12,9 @@ import pandas as pd import pyarrow as pa from loguru import logger -from pandas import DatetimeTZDtype from chronify.exceptions import InvalidOperation, InvalidParameter -from chronify.ibis.base import ( - IbisBackend, - ObjectType, - _DATETIME_RANGES, - _normalize_timestamps, -) +from chronify.ibis.base import IbisBackend, ObjectType, _DATETIME_RANGES from chronify.time import TimeDataType from chronify.time_configs import TimeBaseModel @@ -70,8 +63,6 @@ def create_table( schema: ibis.Schema | None = None, overwrite: bool = False, ) -> ir.Table: - if isinstance(obj, pd.DataFrame): - obj = self._prepare_data_for_spark(obj) try: return self._connection.create_table(name, obj=obj, schema=schema, overwrite=overwrite) except Exception as exc: @@ -80,33 +71,6 @@ def create_table( self._remove_managed_table_location(name) return self._connection.create_table(name, obj=obj, schema=schema, overwrite=overwrite) - def insert(self, name: str, data: pd.DataFrame | pa.Table) -> None: - if isinstance(data, pa.Table): - data = data.to_pandas() - # Spark doesn't support INSERT directly -- create a temp view and insert via SQL - target_columns = list(self.table(name).columns) - _validate_insert_columns(name, target_columns, list(data.columns)) - data = data.loc[:, target_columns] - data = self._prepare_data_for_spark(data) - spark_df = self._session.createDataFrame(data) - quoted_name = _quote_identifier(name) - col_list = ", ".join(_quote_identifier(c) for c in target_columns) - with self._temp_view(spark_df) as tmp_view: - self._session.sql( - f"INSERT INTO {quoted_name} ({col_list}) SELECT {col_list} FROM {tmp_view}" - ) - logger.trace("Inserted {} rows into {}", len(data), name) - - @contextmanager - def _temp_view(self, spark_df: Any) -> Generator[str, None, None]: - """Register ``spark_df`` as a uniquely-named temp view; drop on exit.""" - tmp_view = f"__chronify_tmp_{uuid.uuid4().hex}" - spark_df.createOrReplaceTempView(tmp_view) - try: - yield tmp_view - finally: - self._session.catalog.dropTempView(tmp_view) - def delete_rows(self, name: str, values: dict[str, Any]) -> None: # Spark 3.4+ supports parameterized SQL via the ``args`` keyword. quoted_name = _quote_identifier(name) @@ -149,8 +113,7 @@ def write_parquet( self._connection.to_parquet(expr, path) def create_view_from_parquet(self, path: str, name: str) -> tuple[ir.Table, ObjectType]: - spark_df = self._session.read.parquet(path) - spark_df.createOrReplaceTempView(name) + self._connection.create_view(name, self._connection.read_parquet(path)) return self.table(name), ObjectType.VIEW def execute_sql(self, query: str) -> None: @@ -195,30 +158,6 @@ def apply_schema_types(self, expr: ir.Table, config: TimeBaseModel) -> ir.Table: **{config.time_column: expr[config.time_column].cast(dt.Timestamp(timezone="UTC"))} ) - def _prepare_write_data( - self, - data: pd.DataFrame | pa.Table, - configs: Sequence[TimeBaseModel], - ) -> pd.DataFrame: - """Spark ingestion goes through createDataFrame(pandas); Arrow must be converted.""" - if isinstance(data, pa.Table): - data = data.to_pandas() - return _normalize_timestamps(data, configs) - - @staticmethod - def _prepare_data_for_spark(df: pd.DataFrame) -> pd.DataFrame: - """Normalize tz-aware pandas timestamps for Spark ingestion. - - Spark timestamps are timezone-naive and interpreted in the session time - zone. We require UTC sessions, so convert tz-aware columns to tz-naive - UTC timestamps before handing them to Spark. - """ - df = df.copy() - for col in df.columns: - if isinstance(df[col].dtype, DatetimeTZDtype): - df[col] = df[col].dt.tz_convert("UTC").dt.tz_localize(None) - return df - @staticmethod def _validate_session(session: Any) -> None: time_zone = session.conf.get("spark.sql.session.timeZone", None) or "UTC" @@ -230,19 +169,6 @@ def _validate_session(session: Any) -> None: raise InvalidParameter(msg) -def _validate_insert_columns( - table_name: str, target_columns: list[str], data_columns: list[str] -) -> None: - missing = [c for c in target_columns if c not in data_columns] - extra = [c for c in data_columns if c not in target_columns] - if missing or extra: - msg = ( - f"Insert data columns do not match table {table_name!r}. " - f"Missing: {missing}. Extra: {extra}." - ) - raise InvalidParameter(msg) - - def _quote_identifier(identifier: str) -> str: """Quote a SQL identifier for Spark SQL, escaping embedded backticks.""" escaped = identifier.replace("`", "``") From d43760fcf808384bd96aac096dc73fff1f5e7e15 Mon Sep 17 00:00:00 2001 From: Daniel Thom Date: Mon, 13 Apr 2026 08:35:10 -0600 Subject: [PATCH 30/48] Defer ibis.Table materialization through the write path Propagate ibis.Table inputs through Store._ingest_table and IbisBackend.write_table so backends that can ingest expressions natively (DuckDB, Spark) skip a premature pandas round-trip. SQLite retains the materialization, but does so via the expression's own backend. Replace isinstance-based type branching in helper functions with functools.singledispatch(method), and swap ir.Table for the public ibis.Table name. Co-Authored-By: Claude Opus 4.6 --- src/chronify/ibis/base.py | 117 +++++++++++++++++++--------- src/chronify/ibis/duckdb_backend.py | 6 +- src/chronify/ibis/spark_backend.py | 27 ++++--- src/chronify/ibis/sqlite_backend.py | 34 +++++--- src/chronify/store.py | 38 +++++---- 5 files changed, 141 insertions(+), 81 deletions(-) diff --git a/src/chronify/ibis/base.py b/src/chronify/ibis/base.py index af0271b..1fafc0e 100644 --- a/src/chronify/ibis/base.py +++ b/src/chronify/ibis/base.py @@ -4,6 +4,7 @@ from collections import Counter from contextlib import contextmanager from enum import StrEnum +from functools import singledispatch, singledispatchmethod from typing import Any, Generator, Sequence, cast import ibis @@ -86,22 +87,46 @@ def _arrow_needs_timestamp_normalization( return False -def _get_columns(data: pd.DataFrame | pa.Table) -> list[str]: - if isinstance(data, pa.Table): - return cast(list[str], data.column_names) +@singledispatch +def _get_columns(data: Any) -> list[str]: + msg = f"Unsupported data type: {type(data)}" + raise TypeError(msg) + + +@_get_columns.register +def _(data: pd.DataFrame) -> list[str]: return list(data.columns) -def _select_columns(data: pd.DataFrame | pa.Table, columns: list[str]) -> pd.DataFrame | pa.Table: - if isinstance(data, pa.Table): - return data.select(columns) +@_get_columns.register +def _(data: pa.Table) -> list[str]: + return cast(list[str], data.column_names) + + +@_get_columns.register +def _(data: ibis.Table) -> list[str]: + return list(data.columns) + + +@singledispatch +def _select_columns(data: Any, columns: list[str]) -> Any: + msg = f"Unsupported data type: {type(data)}" + raise TypeError(msg) + + +@_select_columns.register +def _(data: pd.DataFrame, columns: list[str]) -> pd.DataFrame: return data.loc[:, columns] -def _row_count(data: pd.DataFrame | pa.Table) -> int: - if isinstance(data, pa.Table): - return cast(int, data.num_rows) - return len(data) +@_select_columns.register +def _(data: pa.Table, columns: list[str]) -> pa.Table: + return data.select(columns) + + +@_select_columns.register +def _(data: ibis.Table, columns: list[str]) -> ibis.Table: + return data.select(columns) def _validate_insert_columns( @@ -143,14 +168,14 @@ def connection(self) -> ibis.BaseBackend: def create_table( self, name: str, - obj: pd.DataFrame | pa.Table | ir.Table | None = None, + obj: pd.DataFrame | pa.Table | ibis.Table | None = None, schema: ibis.Schema | None = None, overwrite: bool = False, - ) -> ir.Table: + ) -> ibis.Table: """Create a table in the database.""" return self.connection.create_table(name, obj=obj, schema=schema, overwrite=overwrite) - def create_view(self, name: str, expr: ir.Table) -> ir.Table: + def create_view(self, name: str, expr: ibis.Table) -> ibis.Table: """Create a view in the database.""" return self.connection.create_view(name, expr, overwrite=False) @@ -166,11 +191,11 @@ def list_tables(self) -> list[str]: """List all user tables in the database.""" return cast(list[str], self.connection.list_tables()) - def table(self, name: str) -> ir.Table: + def table(self, name: str) -> ibis.Table: """Return an ibis table expression for the named table.""" return self.connection.table(name) - def insert(self, name: str, data: pd.DataFrame | pa.Table) -> None: + def insert(self, name: str, data: pd.DataFrame | pa.Table | ibis.Table) -> None: """Insert data into an existing table. Validates that the data columns match the target table, reorders them, @@ -182,7 +207,7 @@ def insert(self, name: str, data: pd.DataFrame | pa.Table) -> None: _validate_insert_columns(name, target_columns, _get_columns(data)) ordered = _select_columns(data, target_columns) self.connection.insert(name, ordered) - logger.trace("Inserted {} rows into {}", _row_count(ordered), name) + logger.trace("Inserted data into {}", name) @abstractmethod def delete_rows(self, name: str, values: dict[str, Any]) -> None: @@ -197,13 +222,13 @@ def execute(self, expr: ir.Expr) -> pd.DataFrame: for large tables.""" return cast(pd.DataFrame, self.connection.execute(expr)) - def sql(self, query: str) -> ir.Table: + def sql(self, query: str) -> ibis.Table: """Create an ibis table expression from a raw SQL string.""" return self.connection.sql(query) def write_parquet( self, - expr: ir.Table, + expr: ibis.Table, path: str, partition_by: list[str] | None = None, ) -> None: @@ -214,7 +239,7 @@ def write_parquet( self.connection.to_parquet(expr, path) @abstractmethod - def create_view_from_parquet(self, path: str, name: str) -> tuple[ir.Table, ObjectType]: + def create_view_from_parquet(self, path: str, name: str) -> tuple[ibis.Table, ObjectType]: """Create a view or table backed by a Parquet file. Returns the table expression and the type of object created, since some @@ -235,53 +260,74 @@ def execute_sql_to_df(self, query: str) -> pd.DataFrame: logger.trace("execute_sql_to_df: {}", query) return cast(pd.DataFrame, self.sql(query).execute()) - def read_query(self, expr: ir.Table, config: TimeBaseModel) -> pd.DataFrame: + def read_query(self, expr: ibis.Table, config: TimeBaseModel) -> pd.DataFrame: """Execute an Ibis expression and return a pandas DataFrame.""" return self.execute(self.apply_schema_types(expr, config)) - def apply_schema_types(self, expr: ir.Table, config: TimeBaseModel) -> ir.Table: + def apply_schema_types(self, expr: ibis.Table, config: TimeBaseModel) -> ibis.Table: """Return ``expr`` with backend-specific casts applied so its Ibis type matches ``config``. - Default: no-op. Backends whose storage loses type information (e.g. - Spark, which stores all timestamps as session-local) should override to - add lazy ``cast`` expressions, so callers get correctly-typed pandas - output without forcing a materialize-then-normalize round-trip. + Default: no-op. Backends whose schema cannot express the full type + (e.g. Spark, which has no per-column timezone — only a session-wide + ``spark.sql.session.timeZone`` — so ``TIMESTAMP`` columns are reported + as tz-naive even when the values are true UTC instants) should + override to add lazy ``cast`` expressions. This lets callers get + correctly-typed pandas output without forcing a + materialize-then-normalize round-trip. """ return expr def write_table( self, - data: pd.DataFrame | pa.Table, + data: pd.DataFrame | pa.Table | ibis.Table, name: str, configs: Sequence[TimeBaseModel], if_exists: str = "append", ) -> None: - """Write tabular data to the database, applying backend-specific normalization.""" + """Write tabular data to the database, applying backend-specific normalization. + + ``ibis.Table`` inputs are passed through to the underlying connection so + materialization can be deferred; backends that cannot ingest an ibis + expression directly should override :meth:`_prepare_write_data` to + materialize it. + """ _check_one_config_per_datetime_column(configs) prepared = self._prepare_write_data(data, configs) self._apply_if_exists(prepared, name, if_exists) + @singledispatchmethod def _prepare_write_data( self, - data: pd.DataFrame | pa.Table, + data: Any, configs: Sequence[TimeBaseModel], - ) -> pd.DataFrame | pa.Table: + ) -> pd.DataFrame | pa.Table | ibis.Table: """Normalize data before insert/create_table. Default behavior is the DuckDB path: accept Arrow natively when possible, otherwise convert to pandas to normalize tz-sensitive columns. Subclasses that cannot ingest Arrow directly should convert here. """ - if isinstance(data, pa.Table) and _arrow_needs_timestamp_normalization(data, configs): - data = data.to_pandas() - if isinstance(data, pd.DataFrame): - data = _normalize_timestamps(data, configs) + msg = f"Unsupported data type: {type(data)}" + raise TypeError(msg) + + @_prepare_write_data.register + def _(self, data: pd.DataFrame, configs: Sequence[TimeBaseModel]) -> pd.DataFrame: + return _normalize_timestamps(data, configs) + + @_prepare_write_data.register + def _(self, data: pa.Table, configs: Sequence[TimeBaseModel]) -> pd.DataFrame | pa.Table: + if _arrow_needs_timestamp_normalization(data, configs): + return self._prepare_write_data(data.to_pandas(), configs) + return data + + @_prepare_write_data.register + def _(self, data: ibis.Table, configs: Sequence[TimeBaseModel]) -> ibis.Table: return data def _apply_if_exists( self, - data: pd.DataFrame | pa.Table, + data: pd.DataFrame | pa.Table | ibis.Table, name: str, if_exists: str, ) -> None: @@ -289,8 +335,7 @@ def _apply_if_exists( case "append": self.insert(name, data) case "replace": - self.drop_table(name) - self.create_table(name, data) + self.create_table(name, data, overwrite=True) case "fail": self.create_table(name, data) case _: diff --git a/src/chronify/ibis/duckdb_backend.py b/src/chronify/ibis/duckdb_backend.py index f91c16f..fd743a7 100644 --- a/src/chronify/ibis/duckdb_backend.py +++ b/src/chronify/ibis/duckdb_backend.py @@ -78,14 +78,14 @@ def delete_rows(self, name: str, values: dict[str, Any]) -> None: def execute(self, expr: ir.Expr) -> pd.DataFrame: # Bypass Ibis's generic pandas materialization and use DuckDB's native # cursor.fetch_df(), which is zero-copy from Arrow. - if isinstance(expr, ir.Table): + if isinstance(expr, ibis.Table): sql = self._connection.compile(expr) return cast(pd.DataFrame, self._connection.con.execute(sql).fetch_df()) return cast(pd.DataFrame, self._connection.execute(expr)) def write_parquet( self, - expr: ir.Table, + expr: ibis.Table, path: str, partition_by: list[str] | None = None, ) -> None: @@ -94,7 +94,7 @@ def write_parquet( else: self._connection.to_parquet(expr, path) - def create_view_from_parquet(self, path: str, name: str) -> tuple[ir.Table, ObjectType]: + def create_view_from_parquet(self, path: str, name: str) -> tuple[ibis.Table, ObjectType]: parquet_path = Path(path) if parquet_path.is_dir(): read_path = str(parquet_path / "**" / "*.parquet").replace("\\", "/") diff --git a/src/chronify/ibis/spark_backend.py b/src/chronify/ibis/spark_backend.py index c953e9a..be0108a 100644 --- a/src/chronify/ibis/spark_backend.py +++ b/src/chronify/ibis/spark_backend.py @@ -8,7 +8,6 @@ import ibis import ibis.expr.datatypes as dt -import ibis.expr.types as ir import pandas as pd import pyarrow as pa from loguru import logger @@ -59,10 +58,10 @@ def connection(self) -> ibis.BaseBackend: def create_table( self, name: str, - obj: pd.DataFrame | pa.Table | ir.Table | None = None, + obj: pd.DataFrame | pa.Table | ibis.Table | None = None, schema: ibis.Schema | None = None, overwrite: bool = False, - ) -> ir.Table: + ) -> ibis.Table: try: return self._connection.create_table(name, obj=obj, schema=schema, overwrite=overwrite) except Exception as exc: @@ -103,7 +102,7 @@ def _overwrite_without_deleted_rows(self, name: str, where: str, args: dict[str, def write_parquet( self, - expr: ir.Table, + expr: ibis.Table, path: str, partition_by: list[str] | None = None, ) -> None: @@ -112,7 +111,7 @@ def write_parquet( else: self._connection.to_parquet(expr, path) - def create_view_from_parquet(self, path: str, name: str) -> tuple[ir.Table, ObjectType]: + def create_view_from_parquet(self, path: str, name: str) -> tuple[ibis.Table, ObjectType]: self._connection.create_view(name, self._connection.read_parquet(path)) return self.table(name), ObjectType.VIEW @@ -140,13 +139,17 @@ def _remove_managed_table_location(self, name: str) -> None: if path.exists(): shutil.rmtree(path) - def apply_schema_types(self, expr: ir.Table, config: TimeBaseModel) -> ir.Table: - """Cast the time column to the Ibis type implied by ``config``. - - Spark stores all timestamps as session-local (no per-column tz - metadata), so a column declared ``TIMESTAMP_TZ`` reads back tz-naive - unless we re-type it in the expression. The session is pinned to UTC - (see :meth:`_validate_session`), so the cast is lossless. + def apply_schema_types(self, expr: ibis.Table, config: TimeBaseModel) -> ibis.Table: + """Re-attach the timezone annotation to the time column's Ibis type. + + Spark's schema has no per-column timezone — only a session-wide + ``spark.sql.session.timeZone`` — so Ibis reports any Spark + ``TIMESTAMP`` column as ``Timestamp(timezone=None)`` even though the + underlying values are true UTC instants. The session is pinned to UTC + (see :meth:`_validate_session`), so casting to + ``Timestamp(timezone="UTC")`` only rewrites the type annotation; no + value conversion occurs. Downstream consumers (pandas, Arrow) then + materialize a tz-aware column. """ if not isinstance(config, _DATETIME_RANGES): return expr diff --git a/src/chronify/ibis/sqlite_backend.py b/src/chronify/ibis/sqlite_backend.py index 72889e0..da8307b 100644 --- a/src/chronify/ibis/sqlite_backend.py +++ b/src/chronify/ibis/sqlite_backend.py @@ -3,10 +3,10 @@ import sqlite3 from datetime import datetime from pathlib import Path +from functools import singledispatchmethod from typing import Any, Sequence import ibis -import ibis.expr.types as ir import pandas as pd import pyarrow as pa from loguru import logger @@ -91,19 +91,21 @@ def connection(self) -> ibis.BaseBackend: def create_table( self, name: str, - obj: pd.DataFrame | pa.Table | ir.Table | None = None, + obj: pd.DataFrame | pa.Table | ibis.Table | None = None, schema: ibis.Schema | None = None, overwrite: bool = False, - ) -> ir.Table: - if isinstance(obj, ir.Table): + ) -> ibis.Table: + if isinstance(obj, ibis.Table): # SQLite CREATE TABLE AS SELECT loses datetime type info. # Execute the expression first, then create from the DataFrame. df = self._connection.execute(obj) return self._connection.create_table(name, obj=df, overwrite=overwrite) return self._connection.create_table(name, obj=obj, schema=schema, overwrite=overwrite) - def insert(self, name: str, data: pd.DataFrame | pa.Table) -> None: - if isinstance(data, pa.Table): + def insert(self, name: str, data: pd.DataFrame | pa.Table | ibis.Table) -> None: + if isinstance(data, ibis.Table): + data = data.execute() + elif isinstance(data, pa.Table): data = data.to_pandas() # Use raw SQLite cursor for parameterized inserts con = self._connection.con # raw sqlite3 connection @@ -131,7 +133,7 @@ def delete_rows(self, name: str, values: dict[str, Any]) -> None: self._commit_if_needed() logger.trace("Deleted rows from {} matching {}", name, values) - def create_view_from_parquet(self, path: str, name: str) -> tuple[ir.Table, ObjectType]: + def create_view_from_parquet(self, path: str, name: str) -> tuple[ibis.Table, ObjectType]: # SQLite can't read Parquet natively. Load into a table instead. df = pd.read_parquet(path) return self.create_table(name, obj=df), ObjectType.TABLE @@ -172,9 +174,10 @@ def _commit_if_needed(self) -> None: if not self._in_transaction: self._connection.con.commit() + @singledispatchmethod def _prepare_write_data( self, - data: pd.DataFrame | pa.Table, + data: Any, configs: Sequence[TimeBaseModel], ) -> pd.DataFrame: """SQLite stores timestamps as text, so joins compare raw strings. @@ -183,8 +186,19 @@ def _prepare_write_data( written from different source zones (e.g., source table in ``Etc/GMT+5`` vs. a mapping table localized from tz-naive input) align. """ - if isinstance(data, pa.Table): - data = data.to_pandas() + msg = f"Unsupported data type: {type(data)}" + raise TypeError(msg) + + @_prepare_write_data.register + def _(self, data: pa.Table, configs: Sequence[TimeBaseModel]) -> pd.DataFrame: + return self._prepare_write_data(data.to_pandas(), configs) + + @_prepare_write_data.register + def _(self, data: ibis.Table, configs: Sequence[TimeBaseModel]) -> pd.DataFrame: + return self._prepare_write_data(data.execute(), configs) + + @_prepare_write_data.register + def _(self, data: pd.DataFrame, configs: Sequence[TimeBaseModel]) -> pd.DataFrame: data = _normalize_timestamps(data, configs) copied = False for config in configs: diff --git a/src/chronify/store.py b/src/chronify/store.py index 58652f8..8d7c579 100644 --- a/src/chronify/store.py +++ b/src/chronify/store.py @@ -4,7 +4,6 @@ from datetime import tzinfo import ibis -import ibis.expr.types as ir import pandas as pd from loguru import logger @@ -102,7 +101,7 @@ def dispose(self) -> None: """Dispose of the current connections.""" self._backend.dispose() - def get_table(self, name: str) -> ir.Table: + def get_table(self, name: str) -> ibis.Table: """Return the ibis Table expression.""" if not self.has_table(name): msg = f"{name=}" @@ -117,7 +116,7 @@ def list_tables(self) -> list[str]: """Return a list of user tables in the database.""" return [x for x in self._backend.list_tables() if x != SchemaManager.SCHEMAS_TABLE] - def try_get_table(self, name: str) -> ir.Table | None: + def try_get_table(self, name: str) -> ibis.Table | None: """Return the ibis Table expression or None if it is not stored.""" if not self.has_table(name): return None @@ -246,7 +245,7 @@ def _ingest_from_csv( def ingest_pivoted_table( self, - data: pd.DataFrame | ir.Table, + data: pd.DataFrame | ibis.Table, src_schema: PivotedTableSchema | CsvTableSchema, dst_schema: TableSchema, ) -> bool: @@ -255,7 +254,7 @@ def ingest_pivoted_table( def ingest_pivoted_tables( self, - data: Iterable[pd.DataFrame | ir.Table], + data: Iterable[pd.DataFrame | ibis.Table], src_schema: PivotedTableSchema | CsvTableSchema, dst_schema: TableSchema, ) -> bool: @@ -274,7 +273,7 @@ def ingest_pivoted_tables( def _ingest_pivoted_tables( self, - data: Iterable[pd.DataFrame | ir.Table], + data: Iterable[pd.DataFrame | ibis.Table], src_schema: PivotedTableSchema | CsvTableSchema, dst_schema: TableSchema, ) -> bool: @@ -287,12 +286,12 @@ def _ingest_pivoted_tables( def _ingest_pivoted_table( self, - data: pd.DataFrame | ir.Table, + data: pd.DataFrame | ibis.Table, src_schema: PivotedTableSchema | CsvTableSchema, dst_schema: TableSchema, ) -> bool: assert src_schema.pivoted_dimension_name is not None - expr = data if isinstance(data, ir.Table) else ibis.memtable(data) + expr = data if isinstance(data, ibis.Table) else ibis.memtable(data) unpivoted = expr.pivot_longer( list(src_schema.value_columns), names_to=src_schema.pivoted_dimension_name, @@ -302,7 +301,7 @@ def _ingest_pivoted_table( def ingest_table( self, - data: pd.DataFrame | ir.Table, + data: pd.DataFrame | ibis.Table, schema: TableSchema, **kwargs: Any, ) -> bool: @@ -330,7 +329,7 @@ def ingest_table( def ingest_tables( self, - data: Iterable[pd.DataFrame | ir.Table], + data: Iterable[pd.DataFrame | ibis.Table], schema: TableSchema, **kwargs: Any, ) -> bool: @@ -353,7 +352,7 @@ def ingest_tables( def _ingest_tables( self, - data: Iterable[pd.DataFrame | ir.Table], + data: Iterable[pd.DataFrame | ibis.Table], schema: TableSchema, skip_time_checks: bool = False, ) -> bool: @@ -367,18 +366,17 @@ def _ingest_tables( def _ingest_table( self, - data: pd.DataFrame | ir.Table, + data: pd.DataFrame | ibis.Table, schema: TableSchema, ) -> bool: - df = data.execute() if isinstance(data, ir.Table) else data - check_columns(df.columns, schema.list_columns()) + check_columns(list(data.columns), schema.list_columns()) if not self._backend.has_table(schema.name): - self._backend.write_table(df, schema.name, [schema.time_config], if_exists="fail") + self._backend.write_table(data, schema.name, [schema.time_config], if_exists="fail") self._schema_mgr.add_schema(schema) return True else: - self._backend.write_table(df, schema.name, [schema.time_config], if_exists="append") + self._backend.write_table(data, schema.name, [schema.time_config], if_exists="append") return False def map_table_time_config( @@ -687,7 +685,7 @@ def localize_time_zone_by_column( self._schema_mgr.add_schema(dst_schema) return dst_schema - def read_query(self, query: ir.Table | str) -> ir.Table: + def read_query(self, query: ibis.Table | str) -> ibis.Table: """Return the query result as an Ibis Table expression. Call ``.execute()`` on the returned expression to materialize a pandas DataFrame. @@ -701,7 +699,7 @@ def read_query(self, query: ir.Table | str) -> ir.Table: return self._backend.sql(query) return query - def read_table(self, name: str) -> ir.Table: + def read_table(self, name: str) -> ibis.Table: """Return the table as an Ibis Table expression. Call ``.execute()`` on the returned expression to materialize a pandas DataFrame. @@ -734,7 +732,7 @@ def read_raw_query(self, query: str) -> pd.DataFrame: def write_query_to_parquet( self, - stmt: ir.Table | str, + stmt: ibis.Table | str, file_path: Path | str, overwrite: bool = False, partition_columns: Optional[list[str]] = None, @@ -852,7 +850,7 @@ def drop_table(self, name: str, if_exists: bool = False) -> None: self._schema_mgr.remove_schema(name) logger.info("Dropped table {}", name) - def create_view(self, schema: TableSchema, stmt: ir.Table) -> None: + def create_view(self, schema: TableSchema, stmt: ibis.Table) -> None: """Create a view in the database.""" self._backend.create_view(schema.name, stmt) self._schema_mgr.add_schema(schema) From 96bf82221e0b8f91b354e2db1676736a170d1f4c Mon Sep 17 00:00:00 2001 From: Daniel Thom Date: Mon, 13 Apr 2026 09:17:41 -0600 Subject: [PATCH 31/48] Drop intermediate ymdh table even when apply_mapping fails Wrap the apply_mapping call in try/finally so the intermediate table created by _intermediate_mapping_ymdp_to_ymdh is cleaned up on exceptions, preventing stale tables that would block subsequent runs. Co-Authored-By: Claude Opus 4.6 --- ...apper_column_representative_to_datetime.py | 27 ++++++++++--------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/src/chronify/time_series_mapper_column_representative_to_datetime.py b/src/chronify/time_series_mapper_column_representative_to_datetime.py index fbf46e4..1e37ab0 100644 --- a/src/chronify/time_series_mapper_column_representative_to_datetime.py +++ b/src/chronify/time_series_mapper_column_representative_to_datetime.py @@ -92,19 +92,20 @@ def map_time( msg = f"No mapping available for {type(self._from_time_config)}" raise InvalidParameter(msg) - apply_mapping( - df_mapping, - mapping_schema, - from_schema, - self._to_schema, - self._backend, - self._data_adjustment, - output_file=output_file, - check_mapped_timestamps=check_mapped_timestamps, - ) - - if drop_table: - self._backend.drop_table(drop_table) + try: + apply_mapping( + df_mapping, + mapping_schema, + from_schema, + self._to_schema, + self._backend, + self._data_adjustment, + output_file=output_file, + check_mapped_timestamps=check_mapped_timestamps, + ) + finally: + if drop_table and self._backend.has_table(drop_table): + self._backend.drop_table(drop_table) def check_schema_consistency(self) -> None: if isinstance(self._from_time_config, MonthDayHourTimeNTZ): From d176f59f08a33797d8387b1236e76ef3e9766638 Mon Sep 17 00:00:00 2001 From: Daniel Thom Date: Mon, 13 Apr 2026 09:41:48 -0600 Subject: [PATCH 32/48] Fix Spark test cleanup and period lookup --- pyproject.toml | 2 +- src/chronify/ibis/spark_backend.py | 16 ++++++++++++++-- ...s_mapper_column_representative_to_datetime.py | 5 ++--- tests/conftest.py | 2 +- tests/test_checker_representative_time.py | 5 +++-- tests/test_spark_backend.py | 2 +- 6 files changed, 22 insertions(+), 10 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index f6d1b9d..68346c0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,7 +40,7 @@ dependencies = [ [project.optional-dependencies] spark = [ "ibis-framework[pyspark]", - "pyspark >= 4.0, < 5", + "pyspark == 4.0.0", ] dev = [ diff --git a/src/chronify/ibis/spark_backend.py b/src/chronify/ibis/spark_backend.py index be0108a..9e0da31 100644 --- a/src/chronify/ibis/spark_backend.py +++ b/src/chronify/ibis/spark_backend.py @@ -24,14 +24,26 @@ class SparkBackend(IbisBackend): Requires pyspark to be installed (pip install chronify[spark]). """ - def __init__(self, session: Any = None) -> None: + def __init__(self, session: Any = None, *, owns_session: bool | None = None) -> None: + """Construct a SparkBackend. + + Parameters + ---------- + session + Optional pre-existing PySpark session. When provided, the backend + does not own the session by default and will not stop it on + ``dispose()``. + owns_session + Override whether ``dispose()`` stops ``session``. Defaults to + ``True`` only when ``session`` is not provided. + """ try: from pyspark.sql import SparkSession except ImportError as e: msg = "pyspark is required for SparkBackend. Install with: pip install chronify[spark]" raise ImportError(msg) from e - self._owns_session = session is None + self._owns_session = session is None if owns_session is None else owns_session if session is None: session = ( SparkSession.builder.master("local") diff --git a/src/chronify/time_series_mapper_column_representative_to_datetime.py b/src/chronify/time_series_mapper_column_representative_to_datetime.py index 1e37ab0..36d686c 100644 --- a/src/chronify/time_series_mapper_column_representative_to_datetime.py +++ b/src/chronify/time_series_mapper_column_representative_to_datetime.py @@ -132,9 +132,8 @@ def _intermediate_mapping_ymdp_to_ymdh(self) -> TableSchema: period_col = self._from_time_config.hour_columns[0] # Get distinct periods - df_periods = self._backend.execute_sql_to_df( - f"SELECT DISTINCT {period_col} FROM {self._from_schema.name}" - ) + table = self._backend.table(self._from_schema.name) + df_periods = self._backend.execute(table.select(period_col).distinct()) df_mapping = generate_period_mapping(df_periods.iloc[:, 0]) self._backend.write_table( df_mapping, diff --git a/tests/conftest.py b/tests/conftest.py index 4485a28..a2d82cc 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -38,7 +38,7 @@ def _make_backend(name: str, tmp_path: Path | None = None, **kwargs: Any) -> Ibi .config("spark.sql.warehouse.dir", str(warehouse_dir)) .getOrCreate() ) - return SparkBackend(session=session, **kwargs) + return SparkBackend(session=session, owns_session=True, **kwargs) return make_backend(name, **kwargs) diff --git a/tests/test_checker_representative_time.py b/tests/test_checker_representative_time.py index d5dc297..e6da8e0 100644 --- a/tests/test_checker_representative_time.py +++ b/tests/test_checker_representative_time.py @@ -1,5 +1,6 @@ -import pytest +from typing import Any +import pytest import pandas as pd from chronify.ibis import IbisBackend @@ -9,7 +10,7 @@ def ingest_data_and_check( - backend: IbisBackend, df: pd.DataFrame, schema: TableSchema, error: tuple[any, str] + backend: IbisBackend, df: pd.DataFrame, schema: TableSchema, error: tuple[Any, str] ) -> None: backend.write_table(df, schema.name, [schema.time_config], if_exists="replace") diff --git a/tests/test_spark_backend.py b/tests/test_spark_backend.py index 6f7c9b9..baec5ee 100644 --- a/tests/test_spark_backend.py +++ b/tests/test_spark_backend.py @@ -404,7 +404,7 @@ def test_spark_dispose(tmp_path: Path) -> None: .config("spark.sql.warehouse.dir", str(warehouse_dir)) .getOrCreate() ) - backend = SparkBackend(session=session) + backend = SparkBackend(session=session, owns_session=True) backend.dispose() From 85423a15106892dcc138586566ef655b98438a7b Mon Sep 17 00:00:00 2001 From: Daniel Thom Date: Mon, 13 Apr 2026 09:50:00 -0600 Subject: [PATCH 33/48] Fix mypy error --- src/chronify/ibis/__init__.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/chronify/ibis/__init__.py b/src/chronify/ibis/__init__.py index 607780a..8166a9d 100644 --- a/src/chronify/ibis/__init__.py +++ b/src/chronify/ibis/__init__.py @@ -1,5 +1,7 @@ """Ibis backend abstraction layer for Chronify.""" +from typing import Any + from chronify.exceptions import InvalidParameter from chronify.ibis.base import IbisBackend, ObjectType from chronify.ibis.duckdb_backend import DuckDBBackend @@ -17,7 +19,7 @@ def make_backend( name: str, database: str | None = None, - **kwargs: object, + **kwargs: Any, ) -> IbisBackend: """Create an IbisBackend instance. From cb1fe37b303cba6fd19f5ab8e88ffc087506c41b Mon Sep 17 00:00:00 2001 From: Daniel Thom Date: Mon, 13 Apr 2026 10:53:51 -0600 Subject: [PATCH 34/48] Restore docstrings on Store ingest and time-config methods MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Reinstates the full Parameters / Raises / Examples docstrings that were reduced to one-liners during the SQLAlchemy→ibis migration. Drops references to the removed connection and scratch_dir parameters. Co-Authored-By: Claude Opus 4.6 --- src/chronify/store.py | 351 +++++++++++++++++++++++++++- src/chronify/time_zone_converter.py | 41 +++- src/chronify/time_zone_localizer.py | 34 ++- 3 files changed, 415 insertions(+), 11 deletions(-) diff --git a/src/chronify/store.py b/src/chronify/store.py index 8d7c579..70eb24d 100644 --- a/src/chronify/store.py +++ b/src/chronify/store.py @@ -182,7 +182,56 @@ def ingest_from_csv( src_schema: CsvTableSchema, dst_schema: TableSchema, ) -> bool: - """Ingest data from a CSV file.""" + """Ingest data from a CSV file. + + Parameters + ---------- + path + Source data file + src_schema + Defines the schema of the source file. + dst_schema + Defines the destination table in the database. + + Returns + ------- + bool + Return True if a table was created. + + Raises + ------ + InvalidTable + Raised if the data does not match the schema. + + Examples + -------- + >>> resolution = timedelta(hours=1) + >>> time_config = DatetimeRange( + ... time_column="timestamp", + ... start=datetime(2020, 1, 1, 0), + ... length=8784, + ... resolution=timedelta(hours=1), + ... ) + >>> store = Store() + >>> store.ingest_from_csv( + ... "data.csv", + ... CsvTableSchema( + ... time_config=time_config, + ... pivoted_dimension_name="device", + ... value_columns=["device1", "device2", "device3"], + ... ), + ... TableSchema( + ... name="devices", + ... value_column="value", + ... time_config=time_config, + ... time_array_id_columns=["device"], + ... ), + ... ) + + See Also + -------- + ingest_from_csvs + """ return self.ingest_from_csvs((path,), src_schema, dst_schema) def ingest_from_csvs( @@ -191,7 +240,35 @@ def ingest_from_csvs( src_schema: CsvTableSchema, dst_schema: TableSchema, ) -> bool: - """Ingest data from multiple CSV files into the table specified by schema.""" + """Ingest data into the table specified by schema. If the table does not exist, + create it. This is faster than calling :meth:`ingest_from_csv` many times. + Each file is loaded into memory one at a time. + If any error occurs, all added data will be removed and the state of the database will + be the same as the original state. + + Parameters + ---------- + paths + Source data files + src_schema + Defines the schema of the source files. + dst_schema + Defines the destination table in the database. + + Returns + ------- + bool + Return True if a table was created. + + Raises + ------ + InvalidTable + Raised if the data does not match the schema. + + See Also + -------- + ingest_from_csv + """ table_existed = self._backend.has_table(dst_schema.name) try: with self._backend.transaction(): @@ -249,7 +326,67 @@ def ingest_pivoted_table( src_schema: PivotedTableSchema | CsvTableSchema, dst_schema: TableSchema, ) -> bool: - """Ingest pivoted data into the table specified by schema.""" + """Ingest pivoted data into the table specified by schema. If the table does not exist, + create it. Chronify will unpivot the data before ingesting it. + + Parameters + ---------- + data + Input data to ingest into the database. + src_schema + Defines the schema of the input data. + dst_schema + Defines the destination table in the database. + + Returns + ------- + bool + Return True if a table was created. + + Raises + ------ + InvalidTable + Raised if the data does not match the schema. + + Examples + -------- + >>> resolution = timedelta(hours=1) + >>> df = pd.DataFrame( + ... { + ... "timestamp": pd.date_range( + ... "2020-01-01", "2020-12-31 23:00:00", freq=resolution + ... ), + ... "device1": np.random.random(8784), + ... "device2": np.random.random(8784), + ... "device3": np.random.random(8784), + ... } + ... ) + >>> time_config = DatetimeRange( + ... time_column="timestamp", + ... start=datetime(2020, 1, 1, 0), + ... length=8784, + ... resolution=timedelta(hours=1), + ... ) + >>> store = Store() + >>> store.ingest_pivoted_table( + ... df, + ... PivotedTableSchema( + ... time_config=time_config, + ... pivoted_dimension_name="device", + ... value_columns=["device1", "device2", "device3"], + ... ), + ... TableSchema( + ... name="devices", + ... value_column="value", + ... time_config=time_config, + ... time_array_id_columns=["device"], + ... ), + ... ) + + See Also + -------- + ingest_pivoted_tables + """ return self.ingest_pivoted_tables((data,), src_schema, dst_schema) def ingest_pivoted_tables( @@ -258,7 +395,31 @@ def ingest_pivoted_tables( src_schema: PivotedTableSchema | CsvTableSchema, dst_schema: TableSchema, ) -> bool: - """Ingest pivoted data from multiple tables. Unpivot before ingesting.""" + """Ingest pivoted data into the table specified by schema. + + If the table does not exist, create it. Unpivot the data before ingesting it. + This is faster than calling :meth:`ingest_pivoted_table` many times. + If any error occurs, all added data will be removed and the state of the database will + be the same as the original state. + + Parameters + ---------- + data + Data to ingest into the database. + src_schema + Defines the schema of all input tables. + dst_schema + Defines the destination table in the database. + + Returns + ------- + bool + Return True if a table was created. + + See Also + -------- + ingest_pivoted_table + """ table_existed = self._backend.has_table(dst_schema.name) try: with self._backend.transaction(): @@ -324,6 +485,38 @@ def ingest_table( ------ InvalidTable Raised if the data does not match the schema. + + Examples + -------- + >>> store = Store() + >>> resolution = timedelta(hours=1) + >>> df = pd.DataFrame( + ... { + ... "timestamp": pd.date_range( + ... "2020-01-01", "2020-12-31 23:00:00", freq=resolution + ... ), + ... "value": np.random.random(8784), + ... } + ... ) + >>> df["id"] = 1 + >>> store.ingest_table( + ... df, + ... TableSchema( + ... name="devices", + ... value_column="value", + ... time_config=DatetimeRange( + ... time_column="timestamp", + ... start=datetime(2020, 1, 1, 0), + ... length=8784, + ... resolution=timedelta(hours=1), + ... ), + ... time_array_id_columns=["id"], + ... ), + ... ) + + See Also + -------- + ingest_tables """ return self.ingest_tables((data,), schema, **kwargs) @@ -333,7 +526,32 @@ def ingest_tables( schema: TableSchema, **kwargs: Any, ) -> bool: - """Ingest multiple input tables to the same database table.""" + """Ingest multiple input tables to the same database table. + All tables must have the same schema. + This offers significant performance advantages over calling :meth:`ingest_table` many + times. + + Parameters + ---------- + data + Input tables to ingest into one database table. + schema + Defines the destination table. + + Returns + ------- + bool + Return True if a table was created. + + Raises + ------ + InvalidTable + Raised if the data does not match the schema. + + See Also + -------- + ingest_table + """ created_table = False if not data: return created_table @@ -389,7 +607,73 @@ def map_table_time_config( check_mapped_timestamps: bool = False, ) -> None: """Map the existing table represented by src_name to a new table represented by - dst_schema with a different time configuration.""" + dst_schema with a different time configuration. + + Parameters + ---------- + src_name + Refers to the table name of the source data. + dst_schema + Defines the table to create in the database. Must not already exist. + data_adjustment + Defines how the dataframe may need to be adjusted with respect to time. + Data is only adjusted when the conditions apply. + wrap_time_allowed + Defines whether the time column is allowed to be wrapped according to the time + config in dst_schema when it does not line up with the time config + output_file + If set, write the mapped table to this Parquet file. + check_mapped_timestamps + Perform time checks on the result of the mapping operation. This can be slow and + is not required. + + Raises + ------ + InvalidTable + Raised if the schemas are incompatible. + TableAlreadyExists + Raised if the dst_schema name already exists. + + Examples + -------- + >>> store = Store() + >>> hours_per_year = 12 * 7 * 24 + >>> num_time_arrays = 3 + >>> df = pd.DataFrame( + ... { + ... "id": np.concatenate( + ... [np.repeat(i, hours_per_year) for i in range(1, 1 + num_time_arrays)] + ... ), + ... "month": np.tile(np.repeat(range(1, 13), 7 * 24), num_time_arrays), + ... "day_of_week": np.tile(np.tile(np.repeat(range(7), 24), 12), num_time_arrays), + ... "hour": np.tile(np.tile(range(24), 12 * 7), num_time_arrays), + ... "value": np.random.random(hours_per_year * num_time_arrays), + ... } + ... ) + >>> schema = TableSchema( + ... name="devices_by_representative_time", + ... value_column="value", + ... time_config=RepresentativePeriodTimeNTZ( + ... time_format=RepresentativePeriodFormat.ONE_WEEK_PER_MONTH_BY_HOUR, + ... ), + ... time_array_id_columns=["id"], + ... ) + >>> store.ingest_table(df, schema) + >>> store.map_table_time_config( + ... "devices_by_representative_time", + ... TableSchema( + ... name="devices_by_datetime", + ... value_column="value", + ... time_config=DatetimeRange( + ... time_column="timestamp", + ... start=datetime(2020, 1, 1, 0), + ... length=8784, + ... resolution=timedelta(hours=1), + ... ), + ... time_array_id_columns=["id"], + ... ), + ... ) + """ if self.has_table(dst_schema.name): msg = dst_schema.name raise TableAlreadyExists(msg) @@ -413,7 +697,60 @@ def convert_time_zone( output_file: Optional[Path] = None, check_mapped_timestamps: bool = False, ) -> TableSchema: - """Convert the time zone of the existing table represented by src_name to a new time zone.""" + """Convert the time zone of the existing table represented by src_name to a new time zone. + + Parameters + ---------- + src_name + Refers to the table name of the source data. + time_zone + Time zone to convert to. + output_file + If set, write the mapped table to this Parquet file. + check_mapped_timestamps + Perform time checks on the result of the mapping operation. This can be slow and + is not required. + + Raises + ------ + TableAlreadyExists + Raised if the dst_schema name already exists. + + Examples + -------- + >>> store = Store() + >>> start = datetime(year=2018, month=1, day=1, tzinfo=ZoneInfo("Etc/GMT+5")) + >>> freq = timedelta(hours=1) + >>> hours_per_year = 8760 + >>> num_time_arrays = 1 + >>> df = pd.DataFrame( + ... { + ... "id": np.concatenate( + ... [np.repeat(i, hours_per_year) for i in range(1, 1 + num_time_arrays)] + ... ), + ... "timestamp": np.tile( + ... pd.date_range(start, periods=hours_per_year, freq="h"), num_time_arrays + ... ), + ... "value": np.random.random(hours_per_year * num_time_arrays), + ... } + ... ) + >>> schema = TableSchema( + ... name="some_data", + ... time_config=DatetimeRange( + ... time_column="timestamp", + ... start=start, + ... length=hours_per_year, + ... resolution=freq, + ... ), + ... time_array_id_columns=["id"], + ... value_column="value", + ... ) + >>> store.ingest_table(df, schema) + >>> to_time_zone = ZoneInfo("US/Mountain") + >>> dst_schema = store.convert_time_zone( + ... schema.name, to_time_zone, check_mapped_timestamps=True + ... ) + """ src_schema = self._schema_mgr.get_schema(src_name) tzc = TimeZoneConverter(self._backend, src_schema, time_zone) diff --git a/src/chronify/time_zone_converter.py b/src/chronify/time_zone_converter.py index 59312aa..0ba8285 100644 --- a/src/chronify/time_zone_converter.py +++ b/src/chronify/time_zone_converter.py @@ -110,7 +110,17 @@ def convert_time_zone( class TimeZoneConverter(TimeZoneConverterBase): - """Convert tz-aware timestamps to a specified time zone (tz-naive output).""" + """Class for time zone conversion of tz-aware, aligned_in_absolute_time + time series data to a specified time zone. + + Input data table must contain tz-aware timestamps. + Input time config must be of type DatetimeRange with Timestamp_TZ dtype and tz-aware start time. + Output data table will contain tz-naive timestamps with time zone recorded in a column + Output time config will be of type DatetimeRange with Timestamp_NTZ dtype and tz-naive start time. + + # TODO: support DatetimeRangeWithTZColumn as input time config - Issue #64 + # TODO: support wrap_time_allowed option - Issue #64 + """ def __init__( self, @@ -207,7 +217,34 @@ def _create_mapping(self) -> tuple[pd.DataFrame, MappingTableSchema]: class TimeZoneConverterByColumn(TimeZoneConverterBase): - """Convert tz-aware timestamps to multiple time zones specified by a column.""" + """Class for time zone conversion of tz-aware, aligned_in_absolute_time + time series data based on a time zone column. + + Input data table must contain tz-aware timestamps and a time zone column. + Input time config must be of type DatetimeRangeWithTZColumn or DatetimeRange with Timestamp_TZ dtype. + - If DatetimeRange is used, time_zone_column must be provided. + - If DatetimeRangeWithTZColumn is used, it is converted to DatetimeRange internally. + time_zone_column, if provided, is ignored and instead taken from the time_config. + Output data table will contain tz-naive timestamps and the original time zone column. + Output time config will be of type DatetimeRangeWithTZColumn with Timestamp_NTZ dtype (see scenarios). + + I/O Time config scenarios: + -------------------------------- + To convert tz-aware timestamps aligned_in_absolute_time to multiple time zones specified in a column: + - wrap_time_allowed = False + - Input time config: DatetimeRange with tz-aware start time, Timestamp_TZ dtype + - Output time config: DatetimeRangeWithTZColumn with tz-aware start time, Timestamp_NTZ dtype + + To convert tz-aware timestamps aligned_in_absolute_time to multiple time zones specified in a column + and aligned_in_local_standard_time: + - wrap_time_allowed = True + - Input time config: DatetimeRange with tz-aware start time, Timestamp_TZ dtype + - Output time config: DatetimeRangeWithTZColumn with tz-naive start time, Timestamp_NTZ dtype + Note: converted time is wrapped within the local time range of the original timestamps. + -------------------------------- + + # TODO: support DatetimeRangeWithTZColumn as input time config - Issue #64 + """ def __init__( self, diff --git a/src/chronify/time_zone_localizer.py b/src/chronify/time_zone_localizer.py index 165a7c8..bb41f32 100644 --- a/src/chronify/time_zone_localizer.py +++ b/src/chronify/time_zone_localizer.py @@ -96,7 +96,14 @@ def localize_time_zone( class TimeZoneLocalizer(TimeZoneLocalizerBase): - """Localize tz-naive timestamps to a specified standard time zone.""" + """Class for time zone localization of tz-naive time series data to a specified time zone. + + Input data table must contain tz-naive timestamps. + Input time config must be of type DatetimeRange with Timestamp_NTZ dtype and tz-naive start time. + to_time_zone must be a standard time zone (without DST) or None. + Output data table will contain tz-aware timestamps. + Output time config will be of type DatetimeRange with Timestamp_TZ dtype and tz-aware start time. + """ def __init__( self, @@ -188,7 +195,30 @@ def localize_time_zone( class TimeZoneLocalizerByColumn(TimeZoneLocalizerBase): - """Localize tz-naive timestamps to multiple time zones specified by a column.""" + """Class for time zone localization of tz-naive time series data based on a time zone column. + + Input data table must contain tz-naive timestamps and a time zone column. + Time zones in the time zone column must be standard time zones (without DST). + Input time config must be of type DatetimeRangeWithTZColumn or DatetimeRange with Timestamp_NTZ dtype. + - If DatetimeRangeWithTZColumn is used, time_zone_column, if provided, is ignored. + - If DatetimeRange is used, time_zone_column must be provided. It is then converted to + DatetimeRangeWithTZColumn internally. + Output data table will contain tz-aware timestamps and the original time zone column. + Output time config can be of type DatetimeRange or DatetimeRangeWithTZColumn with Timestamp_TZ dtype (see scenarios). + + I/O Time config scenarios: + -------------------------------- + To localize tz-naive timestamps aligned_in_local_standard_time to multiple time zones specified in a column: + - Input time config: DatetimeRangeWithTZColumn with tz-naive start time, Timestamp_NTZ dtype + - Output time config: DatetimeRangeWithTZColumn with tz-naive start time, Timestamp_TZ dtype + + To localize tz-naive timestamps aligned_in_absolute_time to multiple time zones specified in a column: + - Input time config: DatetimeRangeWithTZColumn with tz-aware start time, Timestamp_NTZ dtype + - Output time config: DatetimeRange with tz-aware start time, Timestamp_TZ dtype + Note: output time config is reduced to DatetimeRange (from DatetimeRangeWithTZColumn) + since all timestamps are tz-aware and aligned in absolute time. + -------------------------------- + """ time_zone_column: str From 4eedfa27b4d8a8d7688da1e44b0b53f248aad949 Mon Sep 17 00:00:00 2001 From: Daniel Thom Date: Mon, 13 Apr 2026 13:39:53 -0600 Subject: [PATCH 35/48] Replace raw SQL with ibis expressions in checker and CSV reader Convert null-consistency and time-array count checks in TimeSeriesChecker from f-string SQL to ibis filter/group_by/join expressions, and replace duckdb.sql() in csv_io with rel.project(). Co-Authored-By: Claude Opus 4.6 --- src/chronify/csv_io.py | 3 +- src/chronify/time_series_checker.py | 106 +++++++++------------------- 2 files changed, 33 insertions(+), 76 deletions(-) diff --git a/src/chronify/csv_io.py b/src/chronify/csv_io.py index f6310c8..3dcef59 100644 --- a/src/chronify/csv_io.py +++ b/src/chronify/csv_io.py @@ -26,5 +26,4 @@ def read_csv(path: Path | str, schema: CsvTableSchema, **kwargs: Any) -> DuckDBP expr = f"timezone('{time_config.start.tzinfo.key}', {column}) AS {column}" # type: ignore exprs.append(expr) - expr = ",".join(exprs) - return duckdb.sql(f"SELECT {expr} FROM rel") + return rel.project(",".join(exprs)) diff --git a/src/chronify/time_series_checker.py b/src/chronify/time_series_checker.py index 6f1d471..a3a9a31 100644 --- a/src/chronify/time_series_checker.py +++ b/src/chronify/time_series_checker.py @@ -1,6 +1,7 @@ -from typing import Optional +from typing import Any, Optional, cast from datetime import datetime, tzinfo +import ibis import pandas as pd from chronify.exceptions import InvalidTable @@ -118,14 +119,10 @@ def _check_null_consistency(self) -> None: if len(time_columns) == 1: return - all_are_null = " AND ".join((f"{x} IS NULL" for x in time_columns)) - any_are_null = " OR ".join((f"{x} IS NULL" for x in time_columns)) - query_all = f"SELECT COUNT(*) FROM {self._table_name} WHERE {all_are_null}" - query_any = f"SELECT COUNT(*) FROM {self._table_name} WHERE {any_are_null}" - df_all = self._backend.execute_sql_to_df(query_all) - df_any = self._backend.execute_sql_to_df(query_any) - count_all = df_all.iloc[0, 0] - count_any = df_any.iloc[0, 0] + table = self._backend.table(self._table_name) + null_exprs = [table[col].isnull() for col in time_columns] + count_all = int(cast(Any, table.filter(ibis.and_(*null_exprs)).count().execute())) + count_any = int(cast(Any, table.filter(ibis.or_(*null_exprs)).count().execute())) if count_all != count_any: msg = ( "If any time columns have a NULL value for a row, all time columns in that " @@ -143,71 +140,35 @@ def _check_expected_timestamps_by_time_array(self, count: int) -> None: else: has_tz_naive_prevailing = False - id_cols = ",".join(self._schema.time_array_id_columns) - time_cols = ",".join(self._schema.time_config.list_time_columns()) - where_clause = f"{self._time_generator.list_time_columns()[0]} IS NOT NULL" - on_expr = " AND ".join([f"t1.{x} = t2.{x}" for x in self._schema.time_array_id_columns]) - t1_id_cols = ",".join((f"t1.{x}" for x in self._schema.time_array_id_columns)) - - if not self._schema.time_array_id_columns: - query = f""" - WITH distinct_time_values_by_array AS ( - SELECT DISTINCT {time_cols} - FROM {self._table_name} - WHERE {where_clause} - ), - t1 AS ( - SELECT COUNT(*) AS distinct_count_by_ta - FROM distinct_time_values_by_array - ), - t2 AS ( - SELECT COUNT(*) AS count_by_ta - FROM {self._table_name} - WHERE {where_clause} - ) - SELECT - t1.distinct_count_by_ta - ,t2.count_by_ta - FROM t1 - CROSS JOIN t2 - """ + id_cols = self._schema.time_array_id_columns + time_cols = self._schema.time_config.list_time_columns() + first_time_col = self._time_generator.list_time_columns()[0] + + table = self._backend.table(self._table_name) + filtered = table.filter(table[first_time_col].notnull()) + + if not id_cols: + distinct_count_by_ta = int( + cast(Any, filtered.select(time_cols).distinct().count().execute()) + ) + count_by_ta = int(cast(Any, filtered.count().execute())) + df = pd.DataFrame( + [{"distinct_count_by_ta": distinct_count_by_ta, "count_by_ta": count_by_ta}] + ) else: - query = f""" - WITH distinct_time_values_by_array AS ( - SELECT DISTINCT {id_cols}, {time_cols} - FROM {self._table_name} - WHERE {where_clause} - ), - t1 AS ( - SELECT {id_cols}, COUNT(*) AS distinct_count_by_ta - FROM distinct_time_values_by_array - GROUP BY {id_cols} - ), - t2 AS ( - SELECT {id_cols}, COUNT(*) AS count_by_ta - FROM {self._table_name} - WHERE {where_clause} - GROUP BY {id_cols} - ) - SELECT - t1.distinct_count_by_ta - ,t2.count_by_ta - ,{t1_id_cols} - FROM t1 - JOIN t2 - ON {on_expr} - """ - - df = self._backend.execute_sql_to_df(query) + counts = filtered.group_by(id_cols).aggregate(count_by_ta=filtered.count()) + distinct_rows = filtered.select(id_cols + time_cols).distinct() + distinct = distinct_rows.group_by(id_cols).aggregate( + distinct_count_by_ta=distinct_rows.count() + ) + df = counts.join(distinct, id_cols).execute() + for _, result in df.iterrows(): - distinct_count_by_ta = result.iloc[0] - count_by_ta = result.iloc[1] + distinct_count_by_ta = result["distinct_count_by_ta"] + count_by_ta = result["count_by_ta"] if has_tz_naive_prevailing and not count_by_ta == count: - id_vals = result.iloc[2:] - values = ", ".join( - f"{x}={y}" for x, y in zip(self._schema.time_array_id_columns, id_vals) - ) + values = ", ".join(f"{x}={result[x]}" for x in id_cols) msg = ( f"The count of time values in each time array must be {count}." f"Time array identifiers: {values}. " @@ -216,10 +177,7 @@ def _check_expected_timestamps_by_time_array(self, count: int) -> None: raise InvalidTable(msg) if not has_tz_naive_prevailing and not count_by_ta == count == distinct_count_by_ta: - id_vals = result.iloc[2:] - values = ", ".join( - f"{x}={y}" for x, y in zip(self._schema.time_array_id_columns, id_vals) - ) + values = ", ".join(f"{x}={result[x]}" for x in id_cols) msg = ( f"The count of time values in each time array must be {count}, and each " "value must be distinct. " From d7a4ce85e02c4a408f811be25a914380c191b79f Mon Sep 17 00:00:00 2001 From: Daniel Thom Date: Tue, 14 Apr 2026 16:33:02 -0600 Subject: [PATCH 36/48] Fix UTC handling for Spark --- .github/workflows/ci.yml | 4 +- docs/how_tos/spark_backend.md | 8 ---- src/chronify/ibis/spark_backend.py | 69 ++++++++++++++++++++++-------- tests/test_spark_backend.py | 25 ++++++----- 4 files changed, 69 insertions(+), 37 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 80e2984..966410d 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -7,7 +7,7 @@ on: pull_request: env: - DEFAULT_PYTHON: "3.12" + DEFAULT_PYTHON: "3.13" DEFAULT_OS: ubuntu-latest jobs: @@ -48,7 +48,7 @@ jobs: - name: Set up Python uses: actions/setup-python@v5 with: - python-version: 3.12 + python-version: 3.13 - name: Install dependencies run: | python -m pip install --upgrade pip diff --git a/docs/how_tos/spark_backend.md b/docs/how_tos/spark_backend.md index 096c47f..4cf69b3 100644 --- a/docs/how_tos/spark_backend.md +++ b/docs/how_tos/spark_backend.md @@ -14,14 +14,6 @@ $ tar -xzf spark-4.0.1-bin-hadoop3.tgz $ export SPARK_HOME=$(pwd)/spark-4.0.1-bin-hadoop3 ``` -Start a Thrift server. This allows JDBC clients to send SQL queries to an in-process Spark cluster -running in local mode. -``` -$ $SPARK_HOME/sbin/start-thriftserver.sh --master=spark://$(hostname):7077 -``` - -The URL to connect to this server is `hive://localhost:10000/default` - ## Installation on an HPC The chronify development team uses this [package](https://github.com/NatLabRockies/sparkctl) to run Spark on NLR's HPC. diff --git a/src/chronify/ibis/spark_backend.py b/src/chronify/ibis/spark_backend.py index 9e0da31..a941e24 100644 --- a/src/chronify/ibis/spark_backend.py +++ b/src/chronify/ibis/spark_backend.py @@ -2,7 +2,8 @@ import uuid import shutil -from typing import Any +from contextlib import contextmanager +from typing import Any, Generator from pathlib import Path from urllib.parse import urlparse, unquote @@ -12,7 +13,7 @@ import pyarrow as pa from loguru import logger -from chronify.exceptions import InvalidOperation, InvalidParameter +from chronify.exceptions import InvalidOperation from chronify.ibis.base import IbisBackend, ObjectType, _DATETIME_RANGES from chronify.time import TimeDataType from chronify.time_configs import TimeBaseModel @@ -49,11 +50,24 @@ def __init__(self, session: Any = None, *, owns_session: bool | None = None) -> SparkSession.builder.master("local") .config("spark.sql.session.timeZone", "UTC") .config("spark.sql.parquet.outputTimestampType", "TIMESTAMP_MICROS") + .config("spark.sql.execution.arrow.pyspark.enabled", "true") .getOrCreate() ) - self._validate_session(session) self._session = session + # Arrow preserves TIMESTAMP instants across pandas boundaries; the + # non-Arrow path converts through JVM Calendar and mis-resolves DST + # fall-back values (e.g. 2012-11-04 UTC 07:00 collapses into 08:00). + session.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true") + # ibis.pyspark.connect() forces spark.sql.session.timeZone=UTC as a + # side effect. Preserve the caller's tz; UTC is re-pinned locally by + # _pinned_utc_session() around operations that require it. + tz_key = "spark.sql.session.timeZone" + prev_tz = session.conf.get(tz_key, None) self._connection = ibis.pyspark.connect(session) + if prev_tz is None: + session.conf.unset(tz_key) + elif session.conf.get(tz_key, None) != prev_tz: + session.conf.set(tz_key, prev_tz) @property def name(self) -> str: @@ -157,11 +171,11 @@ def apply_schema_types(self, expr: ibis.Table, config: TimeBaseModel) -> ibis.Ta Spark's schema has no per-column timezone — only a session-wide ``spark.sql.session.timeZone`` — so Ibis reports any Spark ``TIMESTAMP`` column as ``Timestamp(timezone=None)`` even though the - underlying values are true UTC instants. The session is pinned to UTC - (see :meth:`_validate_session`), so casting to - ``Timestamp(timezone="UTC")`` only rewrites the type annotation; no - value conversion occurs. Downstream consumers (pandas, Arrow) then - materialize a tz-aware column. + underlying values are true UTC instants. The cast to + ``Timestamp(timezone="UTC")`` is intended as a metadata-only + re-annotation; this only holds while ``session.timeZone=UTC``, so + callers that materialize the result must do so under + :meth:`_pinned_utc_session` (as :meth:`read_query` does). """ if not isinstance(config, _DATETIME_RANGES): return expr @@ -173,15 +187,36 @@ def apply_schema_types(self, expr: ibis.Table, config: TimeBaseModel) -> ibis.Ta **{config.time_column: expr[config.time_column].cast(dt.Timestamp(timezone="UTC"))} ) - @staticmethod - def _validate_session(session: Any) -> None: - time_zone = session.conf.get("spark.sql.session.timeZone", None) or "UTC" - if time_zone != "UTC": - msg = ( - "SparkBackend requires spark.sql.session.timeZone=UTC to preserve " - f"timestamp semantics, got {time_zone!r}." - ) - raise InvalidParameter(msg) + def read_query(self, expr: ibis.Table, config: TimeBaseModel) -> pd.DataFrame: + """Execute ``expr`` with the session pinned to UTC for materialization. + + The UTC pin ensures the metadata-only cast applied by + :meth:`apply_schema_types` is not reinterpreted as a value conversion + through the caller's ``session.timeZone``. + """ + with self._pinned_utc_session(): + return self.execute(self.apply_schema_types(expr, config)) + + @contextmanager + def _pinned_utc_session(self) -> Generator[None, None, None]: + """Temporarily set ``spark.sql.session.timeZone=UTC`` for the block. + + Restores the previous value on exit. Not thread-safe with concurrent + users of the same Spark session. + """ + key = "spark.sql.session.timeZone" + prev = self._session.conf.get(key, None) + if prev == "UTC": + yield + return + self._session.conf.set(key, "UTC") + try: + yield + finally: + if prev is None: + self._session.conf.unset(key) + else: + self._session.conf.set(key, prev) def _quote_identifier(identifier: str) -> str: diff --git a/tests/test_spark_backend.py b/tests/test_spark_backend.py index baec5ee..5a2e9ac 100644 --- a/tests/test_spark_backend.py +++ b/tests/test_spark_backend.py @@ -6,7 +6,6 @@ import pandas as pd import pytest -from chronify.exceptions import InvalidParameter from chronify.ibis.spark_backend import SparkBackend from chronify.models import TableSchema from chronify.store import Store @@ -408,16 +407,22 @@ def test_spark_dispose(tmp_path: Path) -> None: backend.dispose() -def test_spark_backend_rejects_non_utc_session() -> None: +def test_spark_backend_accepts_non_utc_session() -> None: + """Non-UTC session tz is allowed; UTC is pinned only for read_query.""" _require_java_home() pyspark = pytest.importorskip("pyspark.sql") - session = ( - pyspark.SparkSession.builder.master("local") - .config("spark.sql.session.timeZone", "America/Denver") - .getOrCreate() - ) + session = pyspark.SparkSession.builder.master("local").getOrCreate() + prev_tz = session.conf.get("spark.sql.session.timeZone", None) + session.conf.set("spark.sql.session.timeZone", "America/Denver") try: - with pytest.raises(InvalidParameter, match="spark.sql.session.timeZone=UTC"): - SparkBackend(session=session) + backend = SparkBackend(session=session, owns_session=False) + assert session.conf.get("spark.sql.session.timeZone") == "America/Denver" + with backend._pinned_utc_session(): + assert session.conf.get("spark.sql.session.timeZone") == "UTC" + assert session.conf.get("spark.sql.session.timeZone") == "America/Denver" + backend.dispose() finally: - session.stop() + if prev_tz is None: + session.conf.unset("spark.sql.session.timeZone") + else: + session.conf.set("spark.sql.session.timeZone", prev_tz) From 280c95088c6057b23d9005e25052f1edceed404f Mon Sep 17 00:00:00 2001 From: Daniel Thom Date: Fri, 24 Apr 2026 15:07:20 -0600 Subject: [PATCH 37/48] Update docs --- README.md | 2 +- docs/how_tos/getting_started/installation.md | 4 ++-- docs/how_tos/ingest_multiple_tables.md | 21 +++++++++----------- docs/how_tos/spark_backend.md | 13 ++++++++++-- docs/index.md | 20 ++++--------------- 5 files changed, 27 insertions(+), 33 deletions(-) diff --git a/README.md b/README.md index 8dbb507..63f4089 100644 --- a/README.md +++ b/README.md @@ -16,7 +16,7 @@ To use DuckDB or SQLite as the backend: $ pip install chronify ``` -To use Apache Spark via Apache Thrift Server as the backend: +To use Apache Spark as the backend: ``` $ pip install "chronify[spark]" ``` diff --git a/docs/how_tos/getting_started/installation.md b/docs/how_tos/getting_started/installation.md index 398dce1..e5ec15d 100644 --- a/docs/how_tos/getting_started/installation.md +++ b/docs/how_tos/getting_started/installation.md @@ -37,8 +37,8 @@ To use DuckDB or SQLite as the backend: $ pip install chronify ``` -To use Apache Spark via Apache Thrift Server as the backend, you must install pyhive. -This command will install the necessary dependencies. +To use Apache Spark as the backend, install chronify with the ``spark`` extra, +which pulls in PySpark: ```{eval-rst} .. code-block:: console diff --git a/docs/how_tos/ingest_multiple_tables.md b/docs/how_tos/ingest_multiple_tables.md index 336b2c4..149d5e5 100644 --- a/docs/how_tos/ingest_multiple_tables.md +++ b/docs/how_tos/ingest_multiple_tables.md @@ -1,7 +1,6 @@ # How to Ingest Multiple Tables Efficiently There are a few important considerations when ingesting many tables: -- Use one database connection. - Avoid loading all tables into memory at once, if possible. - Ensure additions are atomic. If anything fails, the final state should be the same as the initial state. @@ -48,25 +47,23 @@ dst_schema = TableSchema( Chronify will manage the database connection and errors. ```python store.ingest_from_csvs( - src_schema, - dst_schema, ( "/path/to/file1.csv", "/path/to/file2.csv", "/path/to/file3.csv", ), - ) + src_schema, + dst_schema, +) ``` ## Self-Managed -Open one connection to the database for the duration of your additions. Handle errors. +Wrap the additions in a backend transaction. Any tables or views created within the block are +automatically dropped if an exception is raised. ```python -with store.engine.connect() as conn: - try: - store.ingest_from_csv(src_schema, dst_schema, "/path/to/file1.csv") - store.ingest_from_csv(src_schema, dst_schema, "/path/to/file2.csv") - store.ingest_from_csv(src_schema, dst_schema, "/path/to/file3.csv") - except Exception: - conn.rollback() +with store.backend.transaction(): + store.ingest_from_csv("/path/to/file1.csv", src_schema, dst_schema) + store.ingest_from_csv("/path/to/file2.csv", src_schema, dst_schema) + store.ingest_from_csv("/path/to/file3.csv", src_schema, dst_schema) ``` diff --git a/docs/how_tos/spark_backend.md b/docs/how_tos/spark_backend.md index 4cf69b3..46153c8 100644 --- a/docs/how_tos/spark_backend.md +++ b/docs/how_tos/spark_backend.md @@ -62,9 +62,18 @@ schema = TableSchema( ```python from chronify import Store +from chronify.ibis.spark_backend import SparkBackend -store = Store.create_new_hive_store("hive://localhost:10000/default") -store.create_view_from_parquet("data.parquet") +store = Store(backend=SparkBackend()) +store.create_view_from_parquet("data.parquet", schema) +``` + +Alternatively, pass a pre-configured PySpark session: +```python +from pyspark.sql import SparkSession + +session = SparkSession.builder.master("local").getOrCreate() +store = Store(backend=SparkBackend(session=session)) ``` Verify the data: diff --git a/docs/index.md b/docs/index.md index 964ff86..0cede77 100644 --- a/docs/index.md +++ b/docs/index.md @@ -4,7 +4,7 @@ This package implements validation, mapping, and storage of time series data in Python-based modeling packages. ## Features -- Stores time series data in any database supported by SQLAlchemy. +- Stores time series data in any database supported by Ibis (DuckDB, SQLite, and Spark). - Supports data ingestion in a variety of file formats and configurations. - Supports efficient retrieval of time series through SQL queries. - Validates consistency of timestamps and resolution. @@ -23,24 +23,12 @@ Python-based modeling packages. ``` ## Supported Backends -While chronify should work with any database supported by SQLAlchemy, it has been tested with -the following: +Chronify uses [Ibis](https://ibis-project.org) for all database operations. The following +backends are supported: - DuckDB (default) - SQLite -- Apache Spark through Apache Thrift Server - -DuckDB and SQLite are fully supported. - -Because of limitations in the backend software, chronify functionality with Spark is limited to -the following: - -- Create a view into an existing Parquet file (or directory). -- Perform time series checks. -- Map between time configurations. -- Write output data to Parquet files. - -There is no support for creating tables and ingesting data with Spark. +- Apache Spark (via PySpark) ## How to use this guide - Refer to [How Tos](#how-tos-page) for step-by-step instructions for creating store and ingesting data. From c6c7a6a17bd16c18481e25a35563922262da2001 Mon Sep 17 00:00:00 2001 From: Daniel Thom Date: Mon, 27 Apr 2026 09:17:01 -0600 Subject: [PATCH 38/48] Refactor code --- docs/how_tos/getting_started/quick_start.md | 16 ++--- docs/how_tos/ingest_multiple_tables.md | 33 ++++++--- docs/how_tos/map_time_config.md | 4 +- src/chronify/ibis/base.py | 76 +++++++++++++++----- src/chronify/ibis/duckdb_backend.py | 3 + src/chronify/ibis/spark_backend.py | 3 + src/chronify/ibis/sqlite_backend.py | 13 ++-- src/chronify/schema_manager.py | 29 ++++++-- src/chronify/store.py | 30 ++++---- src/chronify/time_series_checker.py | 78 +++++++++++++-------- tests/test_column_representative_period.py | 30 ++++++-- 11 files changed, 219 insertions(+), 96 deletions(-) diff --git a/docs/how_tos/getting_started/quick_start.md b/docs/how_tos/getting_started/quick_start.md index 1773c91..0e7ec78 100644 --- a/docs/how_tos/getting_started/quick_start.md +++ b/docs/how_tos/getting_started/quick_start.md @@ -28,16 +28,16 @@ store.ingest_tables( time_array_id_columns=["id"], ) ) -query = "SELECT timestamp, value FROM devices WHERE id = 2" -df = store.read_query(query).execute() +devices = store.read_table("devices") +df = devices[devices.id == 2]["timestamp", "value"].execute() df.head() ``` ``` - timestamp value id -0 2020-01-01 00:00:00 0.594748 2 -1 2020-01-01 01:00:00 0.608295 2 -2 2020-01-01 02:00:00 0.297535 2 -3 2020-01-01 03:00:00 0.870238 2 -4 2020-01-01 04:00:00 0.376144 2 + timestamp value +0 2020-01-01 00:00:00 0.594748 +1 2020-01-01 01:00:00 0.608295 +2 2020-01-01 02:00:00 0.297535 +3 2020-01-01 03:00:00 0.870238 +4 2020-01-01 04:00:00 0.376144 ``` diff --git a/docs/how_tos/ingest_multiple_tables.md b/docs/how_tos/ingest_multiple_tables.md index 149d5e5..264b1e8 100644 --- a/docs/how_tos/ingest_multiple_tables.md +++ b/docs/how_tos/ingest_multiple_tables.md @@ -13,12 +13,15 @@ device. ```python from datetime import datetime, timedelta -import numpy as np -import pandas as pd -from chronify import DatetimeRange, Store, TableSchema, CsvTableSchema +from chronify import ( + ColumnDType, + CsvTableSchema, + DatetimeRange, + Store, + TableSchema, +) store = Store.create_in_memory_db() -resolution = timedelta(hours=1) time_config = DatetimeRange( time_column="timestamp", start=datetime(2020, 1, 1, 0), @@ -28,22 +31,23 @@ time_config = DatetimeRange( src_schema = CsvTableSchema( time_config=time_config, column_dtypes=[ - ColumnDType(name="timestamp", dtype=DateTime(timezone=False)), - ColumnDType(name="device1", dtype=Double()), - ColumnDType(name="device2", dtype=Double()), - ColumnDType(name="device3", dtype=Double()), + ColumnDType(name="timestamp", dtype="datetime"), + ColumnDType(name="device1", dtype="float"), + ColumnDType(name="device2", dtype="float"), + ColumnDType(name="device3", dtype="float"), ], value_columns=["device1", "device2", "device3"], pivoted_dimension_name="device", ) dst_schema = TableSchema( name="devices", + time_config=time_config, value_column="value", - time_array_id_columns=["id"], + time_array_id_columns=["device"], ) ``` -## Automated through chronfiy +## Automated through chronify Chronify will manage the database connection and errors. ```python store.ingest_from_csvs( @@ -67,3 +71,12 @@ with store.backend.transaction(): store.ingest_from_csv("/path/to/file2.csv", src_schema, dst_schema) store.ingest_from_csv("/path/to/file3.csv", src_schema, dst_schema) ``` + +```{note} +Real database transaction semantics depend on the backend. The DuckDB and SQLite backends issue +a real `BEGIN` / `COMMIT` / `ROLLBACK` around the block, so partial inserts to existing tables +are rolled back on failure. The Spark backend does not support transactions; the context +manager falls back to best-effort cleanup that drops any tables or views created inside the +block when an exception is raised, but rows appended to pre-existing tables cannot be +rolled back. +``` diff --git a/docs/how_tos/map_time_config.md b/docs/how_tos/map_time_config.md index 265514d..1050082 100644 --- a/docs/how_tos/map_time_config.md +++ b/docs/how_tos/map_time_config.md @@ -51,7 +51,7 @@ schema = TableSchema( ) store = Store.create_in_memory_db() store.ingest_table(df, schema) -store.read_query(f"SELECT * FROM {src_table_name} LIMIT 5").execute().head() +store.read_table(src_table_name).limit(5).execute() ``` ``` @@ -77,7 +77,7 @@ dst_schema = TableSchema( ) ) store.map_table_time_config(src_table_name, dst_schema) -store.read_query(f"SELECT * FROM {dst_table_name} LIMIT 5").execute().head() +store.read_table(dst_table_name).limit(5).execute() ``` ``` diff --git a/src/chronify/ibis/base.py b/src/chronify/ibis/base.py index 1fafc0e..7ee1a51 100644 --- a/src/chronify/ibis/base.py +++ b/src/chronify/ibis/base.py @@ -5,12 +5,13 @@ from contextlib import contextmanager from enum import StrEnum from functools import singledispatch, singledispatchmethod -from typing import Any, Generator, Sequence, cast +from typing import Any, Generator, Iterable, cast import ibis import ibis.expr.types as ir import pandas as pd import pyarrow as pa +import pyarrow.compute as pc from loguru import logger from pandas import DatetimeTZDtype @@ -27,7 +28,7 @@ DatetimeRanges = DatetimeRange | DatetimeRangeWithTZColumn -def _check_one_config_per_datetime_column(configs: Sequence[TimeBaseModel]) -> None: +def _check_one_config_per_datetime_column(configs: Iterable[TimeBaseModel]) -> None: time_col_count = Counter( config.time_column for config in configs if isinstance(config, DatetimeRangeBase) ) @@ -39,15 +40,18 @@ def _check_one_config_per_datetime_column(configs: Sequence[TimeBaseModel]) -> N def _normalize_timestamps( df: pd.DataFrame, - configs: Sequence[TimeBaseModel], + configs: Iterable[TimeBaseModel], ) -> pd.DataFrame: - """Normalize datetime columns so their pandas dtype matches the schema config.""" + """Normalize datetime columns so their pandas dtype matches the schema config. + Does not change the caller's DataFrame. + """ copied = False + columns = set(df.columns) for config in configs: if not isinstance(config, _DATETIME_RANGES): continue col = config.time_column - if col not in df.columns: + if col not in columns: continue if not pd.api.types.is_datetime64_any_dtype(df[col]): continue @@ -70,7 +74,7 @@ def _normalize_timestamps( def _arrow_needs_timestamp_normalization( table: pa.Table, - configs: Sequence[TimeBaseModel], + configs: Iterable[TimeBaseModel], ) -> bool: fields = {field.name: field.type for field in table.schema} for config in configs: @@ -87,6 +91,37 @@ def _arrow_needs_timestamp_normalization( return False +def _normalize_arrow_timestamps( + table: pa.Table, + configs: Iterable[TimeBaseModel], +) -> pa.Table: + """Normalize timestamp columns of an Arrow table to match the schema configs. + + Casts tz-aware → tz-naive (preserving UTC instants) and localizes tz-naive → + UTC, matching the semantics of :func:`_normalize_timestamps` for pandas. Stays + in Arrow so backends that ingest Arrow natively avoid a pandas round-trip. + """ + indices = {name: i for i, name in enumerate(table.column_names)} + for config in configs: + if not isinstance(config, _DATETIME_RANGES): + continue + idx = indices.get(config.time_column) + if idx is None: + continue + arr = table.column(idx) + if not pa.types.is_timestamp(arr.type): + continue + is_tz_aware = arr.type.tz is not None + if config.dtype == TimeDataType.TIMESTAMP_NTZ and is_tz_aware: + new_arr = arr.cast(pa.timestamp(arr.type.unit)) + elif config.dtype == TimeDataType.TIMESTAMP_TZ and not is_tz_aware: + new_arr = pc.assume_timezone(arr, "UTC") + else: + continue + table = table.set_column(idx, table.column_names[idx], new_arr) + return table + + @singledispatch def _get_columns(data: Any) -> list[str]: msg = f"Unsupported data type: {type(data)}" @@ -116,7 +151,7 @@ def _select_columns(data: Any, columns: list[str]) -> Any: @_select_columns.register def _(data: pd.DataFrame, columns: list[str]) -> pd.DataFrame: - return data.loc[:, columns] + return data[columns] @_select_columns.register @@ -165,6 +200,10 @@ def database(self) -> str | None: def connection(self) -> ibis.BaseBackend: """Return the underlying ibis connection.""" + @abstractmethod + def _supports_parquet_partitioning(self) -> bool: + """Return True if the backend supports Hive partitioning of Parquet files.""" + def create_table( self, name: str, @@ -234,9 +273,12 @@ def write_parquet( ) -> None: """Write an ibis expression result to a Parquet file.""" if partition_by: - msg = f"{self.name} backend does not support partitioned Parquet writes." - raise NotImplementedError(msg) - self.connection.to_parquet(expr, path) + if not self._supports_parquet_partitioning(): + msg = f"{self.name} backend does not support partitioned Parquet writes." + raise NotImplementedError(msg) + self.connection.to_parquet(expr, path, partition_by=partition_by) + else: + self.connection.to_parquet(expr, path) @abstractmethod def create_view_from_parquet(self, path: str, name: str) -> tuple[ibis.Table, ObjectType]: @@ -258,7 +300,7 @@ def execute_sql(self, query: str) -> None: def execute_sql_to_df(self, query: str) -> pd.DataFrame: """Execute a raw SQL query and return a DataFrame.""" logger.trace("execute_sql_to_df: {}", query) - return cast(pd.DataFrame, self.sql(query).execute()) + return self.execute(self.sql(query)) def read_query(self, expr: ibis.Table, config: TimeBaseModel) -> pd.DataFrame: """Execute an Ibis expression and return a pandas DataFrame.""" @@ -282,7 +324,7 @@ def write_table( self, data: pd.DataFrame | pa.Table | ibis.Table, name: str, - configs: Sequence[TimeBaseModel], + configs: Iterable[TimeBaseModel], if_exists: str = "append", ) -> None: """Write tabular data to the database, applying backend-specific normalization. @@ -300,7 +342,7 @@ def write_table( def _prepare_write_data( self, data: Any, - configs: Sequence[TimeBaseModel], + configs: Iterable[TimeBaseModel], ) -> pd.DataFrame | pa.Table | ibis.Table: """Normalize data before insert/create_table. @@ -312,17 +354,17 @@ def _prepare_write_data( raise TypeError(msg) @_prepare_write_data.register - def _(self, data: pd.DataFrame, configs: Sequence[TimeBaseModel]) -> pd.DataFrame: + def _(self, data: pd.DataFrame, configs: Iterable[TimeBaseModel]) -> pd.DataFrame: return _normalize_timestamps(data, configs) @_prepare_write_data.register - def _(self, data: pa.Table, configs: Sequence[TimeBaseModel]) -> pd.DataFrame | pa.Table: + def _(self, data: pa.Table, configs: Iterable[TimeBaseModel]) -> pa.Table: if _arrow_needs_timestamp_normalization(data, configs): - return self._prepare_write_data(data.to_pandas(), configs) + return _normalize_arrow_timestamps(data, configs) return data @_prepare_write_data.register - def _(self, data: ibis.Table, configs: Sequence[TimeBaseModel]) -> ibis.Table: + def _(self, data: ibis.Table, configs: Iterable[TimeBaseModel]) -> ibis.Table: return data def _apply_if_exists( diff --git a/src/chronify/ibis/duckdb_backend.py b/src/chronify/ibis/duckdb_backend.py index fd743a7..d6f6ffb 100644 --- a/src/chronify/ibis/duckdb_backend.py +++ b/src/chronify/ibis/duckdb_backend.py @@ -130,6 +130,9 @@ def _commit_transaction(self) -> None: def _rollback_transaction(self) -> None: self._connection.con.execute("ROLLBACK") + def _supports_parquet_partitioning(self) -> bool: + return True + def _quote_identifier(identifier: str) -> str: """Quote a SQL identifier for DuckDB, escaping embedded double quotes.""" diff --git a/src/chronify/ibis/spark_backend.py b/src/chronify/ibis/spark_backend.py index a941e24..fca7c64 100644 --- a/src/chronify/ibis/spark_backend.py +++ b/src/chronify/ibis/spark_backend.py @@ -218,6 +218,9 @@ def _pinned_utc_session(self) -> Generator[None, None, None]: else: self._session.conf.set(key, prev) + def _supports_parquet_partitioning(self) -> bool: + return True + def _quote_identifier(identifier: str) -> str: """Quote a SQL identifier for Spark SQL, escaping embedded backticks.""" diff --git a/src/chronify/ibis/sqlite_backend.py b/src/chronify/ibis/sqlite_backend.py index da8307b..82b5c41 100644 --- a/src/chronify/ibis/sqlite_backend.py +++ b/src/chronify/ibis/sqlite_backend.py @@ -4,7 +4,7 @@ from datetime import datetime from pathlib import Path from functools import singledispatchmethod -from typing import Any, Sequence +from typing import Any, Iterable import ibis import pandas as pd @@ -178,7 +178,7 @@ def _commit_if_needed(self) -> None: def _prepare_write_data( self, data: Any, - configs: Sequence[TimeBaseModel], + configs: Iterable[TimeBaseModel], ) -> pd.DataFrame: """SQLite stores timestamps as text, so joins compare raw strings. @@ -190,15 +190,15 @@ def _prepare_write_data( raise TypeError(msg) @_prepare_write_data.register - def _(self, data: pa.Table, configs: Sequence[TimeBaseModel]) -> pd.DataFrame: + def _(self, data: pa.Table, configs: Iterable[TimeBaseModel]) -> pd.DataFrame: return self._prepare_write_data(data.to_pandas(), configs) @_prepare_write_data.register - def _(self, data: ibis.Table, configs: Sequence[TimeBaseModel]) -> pd.DataFrame: + def _(self, data: ibis.Table, configs: Iterable[TimeBaseModel]) -> pd.DataFrame: return self._prepare_write_data(data.execute(), configs) @_prepare_write_data.register - def _(self, data: pd.DataFrame, configs: Sequence[TimeBaseModel]) -> pd.DataFrame: + def _(self, data: pd.DataFrame, configs: Iterable[TimeBaseModel]) -> pd.DataFrame: data = _normalize_timestamps(data, configs) copied = False for config in configs: @@ -216,6 +216,9 @@ def _(self, data: pd.DataFrame, configs: Sequence[TimeBaseModel]) -> pd.DataFram data[config.time_column] = data[config.time_column].dt.tz_convert("UTC") return data + def _supports_parquet_partitioning(self) -> bool: + return False + def _infer_sqlite_path(connection: ibis.BaseBackend) -> str | None: """Return the database file path for an ibis SQLite connection, or None for in-memory.""" diff --git a/src/chronify/schema_manager.py b/src/chronify/schema_manager.py index f4c9bdf..ce3ed20 100644 --- a/src/chronify/schema_manager.py +++ b/src/chronify/schema_manager.py @@ -4,7 +4,7 @@ import pandas as pd from loguru import logger -from chronify.exceptions import InvalidParameter, TableNotStored +from chronify.exceptions import InvalidOperation, InvalidParameter, TableNotStored from chronify.ibis.base import IbisBackend from chronify.models import TableSchema @@ -31,17 +31,20 @@ def _create_schemas_table(self) -> None: schema = ibis.schema({"name": "string", "schema": "string"}) try: self._backend.create_table(self.SCHEMAS_TABLE, schema=schema) - except Exception: + except Exception as exc: # On Spark, a stale warehouse directory can cause # LOCATION_ALREADY_EXISTS even though list_tables() didn't find - # the table. Drop the stale remnant and retry. - logger.debug("Retrying schemas table creation after dropping stale remnant.") + # the table. Drop the stale remnant and retry. Other failures + # (permissions, connectivity, etc.) must propagate. + if self._backend.name != "spark" or "LOCATION_ALREADY_EXISTS" not in str(exc): + raise + logger.debug("Retrying schemas table creation after dropping stale Spark remnant.") self._backend.drop_table(self.SCHEMAS_TABLE) self._backend.create_table(self.SCHEMAS_TABLE, schema=schema) def add_schema(self, schema: TableSchema) -> None: """Add the schema to the store.""" - if schema.name in self._cache: + if schema.name in self._cache or self._schema_row_exists(schema.name): msg = f"A schema with name={schema.name!r} is already registered" raise InvalidParameter(msg) df = pd.DataFrame({"name": [schema.name], "schema": [schema.model_dump_json()]}) @@ -49,6 +52,11 @@ def add_schema(self, schema: TableSchema) -> None: self._cache[schema.name] = schema logger.trace("Added schema for table {}", schema.name) + def _schema_row_exists(self, name: str) -> bool: + table = self._backend.table(self.SCHEMAS_TABLE) + df = self._backend.execute(table.filter(table["name"] == name).limit(1)) + return not df.empty + def get_schema(self, name: str) -> TableSchema: """Retrieve the schema for the table with name.""" schema = self._cache.get(name) @@ -77,6 +85,13 @@ def _rebuild_cache(self) -> None: for _, row in df.iterrows(): name = row["name"] schema = TableSchema(**json.loads(row["schema"])) - assert name == schema.name - assert name not in self._cache + if name != schema.name: + msg = ( + f"schemas table is corrupt: row name={name!r} does not match " + f"schema.name={schema.name!r}" + ) + raise InvalidOperation(msg) + if name in self._cache: + msg = f"schemas table is corrupt: duplicate entry for name={name!r}" + raise InvalidOperation(msg) self._cache[name] = schema diff --git a/src/chronify/store.py b/src/chronify/store.py index 70eb24d..9c64d4d 100644 --- a/src/chronify/store.py +++ b/src/chronify/store.py @@ -5,6 +5,7 @@ import ibis import pandas as pd +import pyarrow as pa from loguru import logger from chronify.exceptions import ( @@ -303,22 +304,24 @@ def _ingest_from_csv( src_schema: CsvTableSchema, dst_schema: TableSchema, ) -> bool: - df = read_csv(path, src_schema).to_df() - columns = set(src_schema.list_columns()) - check_columns(df.columns, columns) - if isinstance(src_schema.time_config, IndexTimeRangeBase): if isinstance(dst_schema.time_config, DatetimeRange): raise NotImplementedError - else: - cls_name = dst_schema.time_config.__class__.__name__ - msg = f"{src_schema.time_config.__class__.__name__} cannot be converted to {cls_name}" - raise NotImplementedError(msg) + cls_name = dst_schema.time_config.__class__.__name__ + msg = f"{src_schema.time_config.__class__.__name__} cannot be converted to {cls_name}" + raise NotImplementedError(msg) + + rel = read_csv(path, src_schema) + check_columns(list(rel.columns), set(src_schema.list_columns())) + # Hand the CSV through as Arrow rather than pandas so DuckDB can ingest + # zero-copy and other backends only pay the materialization cost they + # would have paid anyway. + data = rel.to_arrow_table() if src_schema.pivoted_dimension_name is not None: - return self._ingest_pivoted_table(df, src_schema, dst_schema) + return self._ingest_pivoted_table(data, src_schema, dst_schema) - return self._ingest_table(df, dst_schema) + return self._ingest_table(data, dst_schema) def ingest_pivoted_table( self, @@ -447,7 +450,7 @@ def _ingest_pivoted_tables( def _ingest_pivoted_table( self, - data: pd.DataFrame | ibis.Table, + data: pd.DataFrame | pa.Table | ibis.Table, src_schema: PivotedTableSchema | CsvTableSchema, dst_schema: TableSchema, ) -> bool: @@ -584,10 +587,11 @@ def _ingest_tables( def _ingest_table( self, - data: pd.DataFrame | ibis.Table, + data: pd.DataFrame | pa.Table | ibis.Table, schema: TableSchema, ) -> bool: - check_columns(list(data.columns), schema.list_columns()) + cols = data.column_names if isinstance(data, pa.Table) else list(data.columns) + check_columns(cols, schema.list_columns()) if not self._backend.has_table(schema.name): self._backend.write_table(data, schema.name, [schema.time_config], if_exists="fail") diff --git a/src/chronify/time_series_checker.py b/src/chronify/time_series_checker.py index a3a9a31..b5c99e8 100644 --- a/src/chronify/time_series_checker.py +++ b/src/chronify/time_series_checker.py @@ -148,43 +148,65 @@ def _check_expected_timestamps_by_time_array(self, count: int) -> None: filtered = table.filter(table[first_time_col].notnull()) if not id_cols: + # Single time array: scalar count comparisons, no per-row materialization. + count_by_ta = int(cast(Any, filtered.count().execute())) + if has_tz_naive_prevailing: + if count_by_ta != count: + msg = ( + f"The count of time values in each time array must be {count}. " + f"count = {count_by_ta}" + ) + raise InvalidTable(msg) + return distinct_count_by_ta = int( cast(Any, filtered.select(time_cols).distinct().count().execute()) ) - count_by_ta = int(cast(Any, filtered.count().execute())) - df = pd.DataFrame( - [{"distinct_count_by_ta": distinct_count_by_ta, "count_by_ta": count_by_ta}] - ) - else: - counts = filtered.group_by(id_cols).aggregate(count_by_ta=filtered.count()) - distinct_rows = filtered.select(id_cols + time_cols).distinct() - distinct = distinct_rows.group_by(id_cols).aggregate( - distinct_count_by_ta=distinct_rows.count() - ) - df = counts.join(distinct, id_cols).execute() - - for _, result in df.iterrows(): - distinct_count_by_ta = result["distinct_count_by_ta"] - count_by_ta = result["count_by_ta"] - - if has_tz_naive_prevailing and not count_by_ta == count: - values = ", ".join(f"{x}={result[x]}" for x in id_cols) - msg = ( - f"The count of time values in each time array must be {count}." - f"Time array identifiers: {values}. " - f"count = {count_by_ta}" - ) - raise InvalidTable(msg) - - if not has_tz_naive_prevailing and not count_by_ta == count == distinct_count_by_ta: - values = ", ".join(f"{x}={result[x]}" for x in id_cols) + if not count_by_ta == count == distinct_count_by_ta: msg = ( f"The count of time values in each time array must be {count}, and each " "value must be distinct. " - f"Time array identifiers: {values}. " f"count = {count_by_ta}, distinct count = {distinct_count_by_ta}. " ) raise InvalidTable(msg) + return + + # Multiple time arrays: aggregate per id, then push the count comparison into + # SQL so we only materialize the first offending row (if any) for the error. + counts = filtered.group_by(id_cols).aggregate(count_by_ta=filtered.count()) + distinct_rows = filtered.select(id_cols + time_cols).distinct() + distinct = distinct_rows.group_by(id_cols).aggregate( + distinct_count_by_ta=distinct_rows.count() + ) + joined = counts.join(distinct, id_cols) + + if has_tz_naive_prevailing: + invalid = joined.filter(joined["count_by_ta"] != count) + else: + invalid = joined.filter( + (joined["count_by_ta"] != count) | (joined["distinct_count_by_ta"] != count) + ) + + bad_rows = self._backend.execute(invalid.limit(1)) + if bad_rows.empty: + return + + result = bad_rows.iloc[0] + values = ", ".join(f"{x}={result[x]}" for x in id_cols) + if has_tz_naive_prevailing: + msg = ( + f"The count of time values in each time array must be {count}." + f"Time array identifiers: {values}. " + f"count = {result['count_by_ta']}" + ) + else: + msg = ( + f"The count of time values in each time array must be {count}, and each " + "value must be distinct. " + f"Time array identifiers: {values}. " + f"count = {result['count_by_ta']}, " + f"distinct count = {result['distinct_count_by_ta']}. " + ) + raise InvalidTable(msg) def check_timestamp_lists( diff --git a/tests/test_column_representative_period.py b/tests/test_column_representative_period.py index 89d7125..957083e 100644 --- a/tests/test_column_representative_period.py +++ b/tests/test_column_representative_period.py @@ -1,5 +1,7 @@ """Tests for ColumnRepresentativeTimeGenerator with Period handler.""" +from datetime import tzinfo + import pandas as pd import pytest @@ -7,8 +9,9 @@ ColumnRepresentativeHandlerPeriod, ColumnRepresentativeTimeGenerator, ) -from chronify.exceptions import InvalidValue +from chronify.exceptions import InvalidOperation, InvalidValue from chronify.time_configs import ( + ColumnRepresentativeBase, MonthDayHourTimeNTZ, YearMonthDayPeriodTimeNTZ, ) @@ -25,6 +28,23 @@ def _make_period_config(year: int = 2024, length: int = 8760) -> YearMonthDayPer ) +class _UnsupportedColumnRepresentative(ColumnRepresentativeBase): + """ColumnRepresentativeBase subclass with no registered handler.""" + + @classmethod + def default_config(cls, length: int, year: int) -> "_UnsupportedColumnRepresentative": + return cls(year=year, length=length, month_column="month", day_column="day") + + def list_time_columns(self) -> list[str]: + return [self.month_column, self.day_column, *self.hour_columns] + + def get_time_zone_column(self) -> None: + return None + + def get_time_zones(self) -> list[tzinfo | None]: + return [] + + class TestColumnRepresentativeTimeGeneratorPeriod: def test_list_timestamps(self): config = _make_period_config(year=2024, length=8784) # 366 days * 24 hours @@ -81,8 +101,6 @@ def test_no_year_raises(self): def test_unsupported_config_raises(self): """ColumnRepresentativeBase subclasses not matching known handlers should raise.""" - config = _make_period_config() - gen = ColumnRepresentativeTimeGenerator(config) - # The generator was created successfully with a period config. - # Verify it works correctly. - assert len(gen.list_timestamps()) > 0 + config = _UnsupportedColumnRepresentative.default_config(length=8760, year=2024) + with pytest.raises(InvalidOperation, match="No time generator"): + ColumnRepresentativeTimeGenerator(config) From f7a52b8451ee36678c87b904c53f8b510c40311b Mon Sep 17 00:00:00 2001 From: Daniel Thom Date: Mon, 27 Apr 2026 11:53:54 -0600 Subject: [PATCH 39/48] Fix SQLite DDL atomicity and simplify transaction handling MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ibis-sqlite's create_table calls con.commit() internally, which silently terminated any outer BEGIN — DDL inside transaction() blocks survived rollback. Wrap the underlying sqlite3.Connection in a no-commit proxy for the duration of the transaction so DDL is now genuinely covered by rollback. Drop the unused `created` cleanup list from transaction(); no caller populated it. transaction() is now a thin DB BEGIN/COMMIT/ROLLBACK wrapper, with semantics documented per backend. Wrap apply_mapping and _intermediate_mapping_ymdp_to_ymdh in transaction() so DuckDB and SQLite get atomic rollback of the multi-step DDL for free. Spark falls back to manual cleanup. The output_file parquet (not transactional) is unlinked on failure. Drop the redundant DuckDB and Spark write_parquet overrides — the base class already delegates to ibis's to_parquet, which uses server-side COPY/native partitioning. Switch the quick_start doc example from pandas-style ibis indexing to filter()/select(), matching the codebase idiom. Add backend-parametric regression tests for transactional DDL on both DuckDB and SQLite. Co-Authored-By: Claude Opus 4.7 --- docs/how_tos/getting_started/quick_start.md | 2 +- src/chronify/ibis/base.py | 54 ++++--- src/chronify/ibis/duckdb_backend.py | 11 -- src/chronify/ibis/spark_backend.py | 11 -- src/chronify/ibis/sqlite_backend.py | 67 ++++++++- src/chronify/time_series_mapper_base.py | 110 +++++++++----- ...apper_column_representative_to_datetime.py | 51 ++++--- tests/test_ibis_base.py | 141 ++++++++++++++++-- tests/test_store.py | 134 +++++++++++++++++ 9 files changed, 463 insertions(+), 118 deletions(-) diff --git a/docs/how_tos/getting_started/quick_start.md b/docs/how_tos/getting_started/quick_start.md index 0e7ec78..2b915a7 100644 --- a/docs/how_tos/getting_started/quick_start.md +++ b/docs/how_tos/getting_started/quick_start.md @@ -29,7 +29,7 @@ store.ingest_tables( ) ) devices = store.read_table("devices") -df = devices[devices.id == 2]["timestamp", "value"].execute() +df = devices.filter(devices.id == 2).select("timestamp", "value").execute() df.head() ``` diff --git a/src/chronify/ibis/base.py b/src/chronify/ibis/base.py index 7ee1a51..112587f 100644 --- a/src/chronify/ibis/base.py +++ b/src/chronify/ibis/base.py @@ -185,6 +185,12 @@ class ObjectType(StrEnum): class IbisBackend(ABC): """Abstract base class defining the interface for Ibis database backends.""" + # Set while inside transaction(). Used for nesting (inner transaction() + # calls become passthroughs) and, on SQLite, by ``_commit_if_needed`` to + # decide whether DML auto-commits. Subclasses may shadow with an + # instance attribute. + _in_transaction: bool = False + @property @abstractmethod def name(self) -> str: @@ -406,29 +412,37 @@ def _rollback_transaction(self) -> None: """Roll back a real database transaction, if one was started.""" @contextmanager - def transaction(self) -> Generator[list[tuple[str, ObjectType]], None, None]: - """Context manager for pseudo-transactions. - - Tracks created objects (tables/views) so they can be cleaned up on failure. - On success, created objects are kept. On exception, they are dropped. - - Yields a list to which callers should append (name, ObjectType) tuples. + def transaction(self) -> Generator[None, None, None]: + """Context manager for a database transaction. + + On DuckDB and SQLite this issues a real ``BEGIN`` / ``COMMIT`` / + ``ROLLBACK`` and covers both DML and DDL — work inside the block is + atomic. Spark has no transaction support, so partial writes inside the + block are not rolled back; callers that need to clean up after a + Spark failure must do so themselves. + + Nesting is supported as a passthrough: a ``transaction()`` block + opened while one is already active does not start a new transaction + and does not commit or roll back on its own. The outermost block + controls the lifecycle, so callers can wrap operations that + themselves use ``transaction()`` (e.g. mapping helpers) without + savepoints. """ - created: list[tuple[str, ObjectType]] = [] + if self._in_transaction: + yield + return self._begin_transaction() + self._in_transaction = True try: - yield created + yield except Exception: - self._rollback_transaction() - for obj_name, obj_type in reversed(created): - try: - if obj_type == ObjectType.TABLE: - self.drop_table(obj_name) - else: - self.drop_view(obj_name) - logger.debug("Rolled back {} {}", obj_type.value, obj_name) - except Exception: - logger.warning("Failed to roll back {} {}", obj_type.value, obj_name) + try: + self._rollback_transaction() + finally: + self._in_transaction = False raise else: - self._commit_transaction() + try: + self._commit_transaction() + finally: + self._in_transaction = False diff --git a/src/chronify/ibis/duckdb_backend.py b/src/chronify/ibis/duckdb_backend.py index d6f6ffb..bd8a2a8 100644 --- a/src/chronify/ibis/duckdb_backend.py +++ b/src/chronify/ibis/duckdb_backend.py @@ -83,17 +83,6 @@ def execute(self, expr: ir.Expr) -> pd.DataFrame: return cast(pd.DataFrame, self._connection.con.execute(sql).fetch_df()) return cast(pd.DataFrame, self._connection.execute(expr)) - def write_parquet( - self, - expr: ibis.Table, - path: str, - partition_by: list[str] | None = None, - ) -> None: - if partition_by: - self._connection.to_parquet(expr, path, partition_by=partition_by) - else: - self._connection.to_parquet(expr, path) - def create_view_from_parquet(self, path: str, name: str) -> tuple[ibis.Table, ObjectType]: parquet_path = Path(path) if parquet_path.is_dir(): diff --git a/src/chronify/ibis/spark_backend.py b/src/chronify/ibis/spark_backend.py index fca7c64..249c6fe 100644 --- a/src/chronify/ibis/spark_backend.py +++ b/src/chronify/ibis/spark_backend.py @@ -126,17 +126,6 @@ def _overwrite_without_deleted_rows(self, name: str, where: str, args: dict[str, self._session.sql(f"DROP TABLE IF EXISTS {quoted_tmp}") self._remove_managed_table_location(tmp_name) - def write_parquet( - self, - expr: ibis.Table, - path: str, - partition_by: list[str] | None = None, - ) -> None: - if partition_by: - self._connection.to_parquet(expr, path, partitionBy=partition_by) - else: - self._connection.to_parquet(expr, path) - def create_view_from_parquet(self, path: str, name: str) -> tuple[ibis.Table, ObjectType]: self._connection.create_view(name, self._connection.read_parquet(path)) return self.table(name), ObjectType.VIEW diff --git a/src/chronify/ibis/sqlite_backend.py b/src/chronify/ibis/sqlite_backend.py index 82b5c41..ed8efd0 100644 --- a/src/chronify/ibis/sqlite_backend.py +++ b/src/chronify/ibis/sqlite_backend.py @@ -39,6 +39,36 @@ def _adapt_value(v: Any) -> Any: return v +class _NoCommitProxy: + """Connection proxy that suppresses commit/rollback. + + ibis-sqlite's ``create_table`` (and a few other DDL paths) wraps its work in + ``with self.begin() as cur:``, which calls ``con.commit()`` on success. + That commit terminates whatever outer transaction we started in + :meth:`SQLiteBackend._begin_transaction`, breaking atomicity for DDL. + + Swapping ``connection.con`` for a proxy that no-ops ``commit``/``rollback`` + while still forwarding cursor/execute calls to the real connection lets + chronify own the transaction lifecycle for the duration of + :meth:`SQLiteBackend.transaction`. We restore the real connection on + commit/rollback. + """ + + __slots__ = ("_real",) + + def __init__(self, real: sqlite3.Connection) -> None: + self._real = real + + def commit(self) -> None: # noqa: D401 + return None + + def rollback(self) -> None: # noqa: D401 + return None + + def __getattr__(self, name: str) -> Any: + return getattr(self._real, name) + + class SQLiteBackend(IbisBackend): """Ibis backend for SQLite databases.""" @@ -64,6 +94,7 @@ def __init__( raise ConflictingInputsError(msg) self._in_transaction = False + self._real_con: sqlite3.Connection | None = None self._owns_connection = connection is None if connection is None: db = str(database) @@ -140,8 +171,7 @@ def create_view_from_parquet(self, path: str, name: str) -> tuple[ibis.Table, Ob def execute_sql(self, query: str) -> None: logger.trace("execute_sql: {}", query) - con = self._connection.con - con.execute(query) + self._connection.con.execute(query) self._commit_if_needed() def dispose(self) -> None: @@ -159,16 +189,43 @@ def backup(self, dst: str) -> None: dst_con.close() def _begin_transaction(self) -> None: - self._connection.con.execute("BEGIN") + # Issue BEGIN before swapping in the proxy so a failure (e.g. nested + # BEGIN, or an outer caller already in a transaction) leaves the + # backend in its original state rather than permanently pointed at + # _NoCommitProxy. + real = self._connection.con + real.execute("BEGIN") + # Swap in a proxy so ibis's internal con.commit() during DDL (notably + # in create_table) is suppressed for the duration of the transaction. + # The real connection is held aside and used directly for BEGIN / + # COMMIT / ROLLBACK. + self._real_con = real + self._connection.con = _NoCommitProxy(real) # type: ignore[assignment] self._in_transaction = True def _commit_transaction(self) -> None: - self._connection.con.commit() + assert self._real_con is not None + real = self._real_con + # Restore connection state up front so a failure on COMMIT (e.g. the + # user finalized the transaction inside the block via + # ``execute_sql("COMMIT")``) doesn't leak the no-commit proxy into + # the backend's steady state. The ``in_transaction`` guard makes the + # wrapper a no-op when the underlying transaction has already been + # closed. + self._connection.con = real + self._real_con = None self._in_transaction = False + if real.in_transaction: + real.execute("COMMIT") def _rollback_transaction(self) -> None: - self._connection.con.rollback() + assert self._real_con is not None + real = self._real_con + self._connection.con = real + self._real_con = None self._in_transaction = False + if real.in_transaction: + real.execute("ROLLBACK") def _commit_if_needed(self) -> None: if not self._in_transaction: diff --git a/src/chronify/time_series_mapper_base.py b/src/chronify/time_series_mapper_base.py index 022be4f..1d7e50b 100644 --- a/src/chronify/time_series_mapper_base.py +++ b/src/chronify/time_series_mapper_base.py @@ -1,4 +1,5 @@ import abc +import uuid from pathlib import Path from typing import Any, Optional @@ -11,7 +12,7 @@ from chronify.time_series_checker import check_timestamps from chronify.time import TimeIntervalType, ResamplingOperationType, AggregationType from chronify.time_configs import TimeBasedDataAdjustment -from chronify.utils.path_utils import check_overwrite, to_path +from chronify.utils.path_utils import check_overwrite, delete_if_exists, to_path class TimeSeriesMapperBase(abc.ABC): @@ -74,7 +75,7 @@ def map_time(self) -> None: """Convert time columns with from_schema to to_schema configuration.""" -def apply_mapping( +def apply_mapping( # noqa: C901 df_mapping: pd.DataFrame, mapping_schema: MappingTableSchema, from_schema: TableSchema, @@ -85,52 +86,89 @@ def apply_mapping( output_file: Optional[Path] = None, check_mapped_timestamps: bool = False, ) -> None: - """Apply mapping to create result table with process to clean up and roll back if checks fail.""" - backend.write_table( - df_mapping, - mapping_schema.name, - mapping_schema.time_configs, - if_exists="fail", + """Apply mapping to create result table. + + The whole multi-step DDL — write the mapping table, create the result + table (or write the parquet file), optionally create a temp view from + the parquet, run timestamp checks — runs inside a single + ``backend.transaction()``. On DuckDB and SQLite, a failure rolls back + every DB-side artifact atomically. Spark has no rollback, so the except + path also handles cleanup there. + + When ``output_file`` is set, the parquet write goes to a uniquely-named + staging path inside the transaction, then atomically renames over the + target only after the transaction commits. A failure leaves any + pre-existing target file untouched. + """ + output_path = to_path(output_file) if output_file is not None else None + staging_path = ( + output_path.with_name(f".{output_path.name}.staging.{uuid.uuid4().hex[:8]}") + if output_path is not None + else None ) created_tmp_obj: Optional[ObjectType] = None try: - _apply_mapping( - mapping_schema.name, - from_schema, - to_schema, - backend, - resampling_operation=resampling_operation, - output_file=output_file, - ) - if check_mapped_timestamps: - if output_file is not None: - output_file = to_path(output_file) - _, created_tmp_obj = backend.create_view_from_parquet( - str(output_file), to_schema.name - ) - try: + with backend.transaction(): + backend.write_table( + df_mapping, + mapping_schema.name, + mapping_schema.time_configs, + if_exists="fail", + ) + _apply_mapping( + mapping_schema.name, + from_schema, + to_schema, + backend, + resampling_operation=resampling_operation, + output_file=staging_path, + ) + if check_mapped_timestamps: + if staging_path is not None: + _, created_tmp_obj = backend.create_view_from_parquet( + str(staging_path), to_schema.name + ) check_timestamps( backend, to_schema.name, to_schema, leap_day_adjustment=data_adjustment.leap_day_adjustment, ) - except Exception: - logger.exception( - "check_timestamps failed on mapped table {}. Drop it", - to_schema.name, - ) - if output_file is None: - backend.drop_table(to_schema.name) - raise - finally: - if backend.has_table(mapping_schema.name): + # Drop temp artifacts inside the transaction so the commit doesn't + # retain them. The mapping table is always temp; the parquet view + # is temp only when we created it for the timestamp check. backend.drop_table(mapping_schema.name) - if created_tmp_obj is not None: - if created_tmp_obj == ObjectType.TABLE: + if created_tmp_obj is ObjectType.TABLE: backend.drop_table(to_schema.name) - else: + elif created_tmp_obj is ObjectType.VIEW: backend.drop_view(to_schema.name) + # Promote the staged parquet only after the transaction commits. + # The staging output may be a file (DuckDB) or a directory (Spark + # parquet writes are always directories), so explicitly remove any + # existing target first — ``Path.replace`` cannot overwrite a + # non-empty directory. + if staging_path is not None: + assert output_path is not None + delete_if_exists(output_path) + staging_path.replace(output_path) + except Exception: + logger.exception( + "Mapping failed for {} -> {}. Cleaning up.", from_schema.name, to_schema.name + ) + # Idempotent cleanup. On DuckDB/SQLite the rollback already dropped + # these objects (has_table returns False); on Spark it didn't. + if backend.has_table(mapping_schema.name): + backend.drop_table(mapping_schema.name) + if backend.has_table(to_schema.name): + if created_tmp_obj is ObjectType.VIEW: + backend.drop_view(to_schema.name) + else: + backend.drop_table(to_schema.name) + # Remove the staging output (file or directory); the original target + # is untouched because the rename never ran. + if staging_path is not None: + delete_if_exists(staging_path) + raise def _apply_mapping( # noqa: C901 diff --git a/src/chronify/time_series_mapper_column_representative_to_datetime.py b/src/chronify/time_series_mapper_column_representative_to_datetime.py index 36d686c..eab56e3 100644 --- a/src/chronify/time_series_mapper_column_representative_to_datetime.py +++ b/src/chronify/time_series_mapper_column_representative_to_datetime.py @@ -135,31 +135,40 @@ def _intermediate_mapping_ymdp_to_ymdh(self) -> TableSchema: table = self._backend.table(self._from_schema.name) df_periods = self._backend.execute(table.select(period_col).distinct()) df_mapping = generate_period_mapping(df_periods.iloc[:, 0]) - self._backend.write_table( - df_mapping, - mapping_table_name, - [self._from_time_config], - if_exists="fail", - ) try: - # Build the join query using ibis - ymdp_table = self._backend.table(self._from_schema.name) - mapping_table = self._backend.table(mapping_table_name) - - # Select all columns from ymdp except the period column, plus hour from mapping - ymdp_cols = [c for c in ymdp_table.columns if c != period_col] - select_exprs = [ymdp_table[c] for c in ymdp_cols] + [mapping_table["hour"]] - - joined = ymdp_table.join( - mapping_table, ymdp_table[period_col] == mapping_table["from_period"] - ) - result = joined.select(select_exprs) - self._backend.create_table(intermediate_ymdh_table_name, result) - finally: - # Always clean up the mapping table + with self._backend.transaction(): + self._backend.write_table( + df_mapping, + mapping_table_name, + [self._from_time_config], + if_exists="fail", + ) + + # Build the join query using ibis + ymdp_table = self._backend.table(self._from_schema.name) + mapping_table = self._backend.table(mapping_table_name) + + # Select all columns from ymdp except the period column, plus hour from mapping + ymdp_cols = [c for c in ymdp_table.columns if c != period_col] + select_exprs = [ymdp_table[c] for c in ymdp_cols] + [mapping_table["hour"]] + + joined = ymdp_table.join( + mapping_table, ymdp_table[period_col] == mapping_table["from_period"] + ) + result = joined.select(select_exprs) + self._backend.create_table(intermediate_ymdh_table_name, result) + # Drop the helper mapping table inside the transaction so commit + # doesn't retain it. + self._backend.drop_table(mapping_table_name) + except Exception: + # Spark fallback: rollback is a no-op, so clean up manually. + # Idempotent on DuckDB/SQLite where the rollback already dropped these. if self._backend.has_table(mapping_table_name): self._backend.drop_table(mapping_table_name) + if self._backend.has_table(intermediate_ymdh_table_name): + self._backend.drop_table(intermediate_ymdh_table_name) + raise if not isinstance(self._from_time_config, YearMonthDayPeriodTimeNTZ): msg = "Intermediate mapping only valid for YearMonthDayPeriodNTZ time config" diff --git a/tests/test_ibis_base.py b/tests/test_ibis_base.py index ca8004f..3ce8d18 100644 --- a/tests/test_ibis_base.py +++ b/tests/test_ibis_base.py @@ -1,9 +1,10 @@ """Tests for the IbisBackend base class (transaction, execute_sql, etc.).""" +import ibis +import pandas as pd import pytest from chronify.ibis import make_backend -from chronify.ibis.base import ObjectType def test_execute_sql(create_duckdb_backend): @@ -28,48 +29,162 @@ def test_dispose(): def test_transaction_success(create_duckdb_backend): backend = create_duckdb_backend - with backend.transaction() as created: + with backend.transaction(): backend.create_table( "txn_table", obj=None, schema={"id": "int64", "val": "float64"}, ) - created.append(("txn_table", ObjectType.TABLE)) - # Table should still exist after successful transaction assert backend.has_table("txn_table") def test_transaction_rollback_on_exception(create_duckdb_backend): - import pandas as pd - backend = create_duckdb_backend df = pd.DataFrame({"id": [1], "val": [2.0]}) with pytest.raises(ValueError, match="test error"): - with backend.transaction() as created: + with backend.transaction(): backend.create_table("txn_rollback", obj=df) - created.append(("txn_rollback", ObjectType.TABLE)) msg = "test error" raise ValueError(msg) - # Table should have been cleaned up assert not backend.has_table("txn_rollback") def test_transaction_rollback_view(create_duckdb_backend): - import pandas as pd - backend = create_duckdb_backend df = pd.DataFrame({"id": [1], "val": [2.0]}) backend.create_table("base_for_view", obj=df) expr = backend.table("base_for_view") with pytest.raises(ValueError, match="test error"): - with backend.transaction() as created: + with backend.transaction(): backend.create_view("txn_view", expr) - created.append(("txn_view", ObjectType.VIEW)) msg = "test error" raise ValueError(msg) assert not backend.has_table("txn_view") + + +def test_transaction_rolls_back_ddl_and_dml(iter_backends): + """A failing transaction must roll back both CREATE TABLE and INSERTs. + + Regression for an ibis-sqlite quirk: ``Backend.create_table`` runs its work + inside ``with self.begin() as cur:`` which calls ``con.commit()`` on + success, terminating any outer BEGIN. The chronify SQLiteBackend swaps in + a no-commit connection proxy for the duration of ``transaction()`` so the + commit is suppressed and the rollback covers DDL. + """ + backend = iter_backends + df = pd.DataFrame({"x": [1, 2, 3]}) + with pytest.raises(ValueError, match="boom"): + with backend.transaction(): + backend.create_table("ddl_rollback", obj=df) + assert backend.has_table("ddl_rollback") + backend.insert("ddl_rollback", df) + msg = "boom" + raise ValueError(msg) + assert not backend.has_table("ddl_rollback") + + +def test_transaction_commit_persists_ddl_and_dml(iter_backends): + backend = iter_backends + df = pd.DataFrame({"x": [1, 2, 3]}) + with backend.transaction(): + backend.create_table("ddl_commit", schema=ibis.schema({"x": "int64"})) + backend.insert("ddl_commit", df) + assert backend.has_table("ddl_commit") + rows = backend.execute(backend.table("ddl_commit")) + assert len(rows) == 3 + + +def test_sqlite_inner_commit_does_not_leak_proxy(tmp_path): + """Reviewer regression: if user code finalizes the SQLite transaction + inside the block (e.g. ``execute_sql("COMMIT")``), the wrapper's exit + must still restore the connection. Previously the wrapper's COMMIT raised + before the proxy was put back, leaving the backend silently pointed at + ``_NoCommitProxy`` so later writes never committed. + """ + import sqlite3 + + from chronify.ibis.sqlite_backend import SQLiteBackend, _NoCommitProxy + + db_path = tmp_path / "leak.db" + backend = SQLiteBackend(database=str(db_path)) + + with backend.transaction(): + backend.create_table("leak_t", schema=ibis.schema({"x": "int64"})) + backend.execute_sql("COMMIT") + + # Backend must be back on the real connection. + assert not isinstance(backend._connection.con, _NoCommitProxy) + assert backend._real_con is None + assert not backend._in_transaction + + # And later writes must actually commit — visible from a fresh connection. + backend.insert("leak_t", pd.DataFrame({"x": [1, 2, 3]})) + other = sqlite3.connect(str(db_path)) + try: + rows = other.execute("SELECT COUNT(*) FROM leak_t").fetchone() + assert rows == (3,) + finally: + other.close() + + +def test_sqlite_failed_begin_does_not_corrupt_backend(): + """A failed BEGIN must not leave SQLite pointed at the no-commit proxy. + + Regression for a reviewer-reported issue: the proxy was installed before + BEGIN, so any BEGIN failure (e.g. an outer transaction already active) + would leak the proxy and silently swallow all subsequent commits. + """ + from chronify.ibis.sqlite_backend import SQLiteBackend, _NoCommitProxy + + backend = SQLiteBackend(database=":memory:") + # Simulate an outer transaction already in progress. + backend._connection.con.execute("BEGIN") + with pytest.raises(Exception, match="transaction"): + backend._begin_transaction() + # Backend must still be operating on the real sqlite3 connection. + assert not isinstance(backend._connection.con, _NoCommitProxy) + assert not backend._in_transaction + + # Recovery: clean up the stranded outer BEGIN, then a normal transaction + # block should work end-to-end. + backend._connection.con.rollback() + with backend.transaction(): + backend.create_table("recover", schema=ibis.schema({"x": "int64"})) + assert backend.has_table("recover") + + +def test_transaction_nesting_commits_as_a_unit(iter_backends): + """Nested transaction() calls compose: the outermost block governs.""" + backend = iter_backends + with backend.transaction(): + backend.create_table("nest_outer", schema=ibis.schema({"x": "int64"})) + with backend.transaction(): + backend.create_table("nest_inner", schema=ibis.schema({"x": "int64"})) + # Inner block exited successfully, but the outer transaction is + # still in progress — both tables remain pending until the outer + # commit. + assert backend._in_transaction + assert backend.has_table("nest_outer") + assert backend.has_table("nest_inner") + assert not backend._in_transaction + + +def test_transaction_nesting_outer_rollback_includes_inner(iter_backends): + """An exception in the outer block must roll back work done inside an + inner (nested) transaction() block too.""" + backend = iter_backends + with pytest.raises(ValueError, match="boom"): + with backend.transaction(): + backend.create_table("nest_outer_r", schema=ibis.schema({"x": "int64"})) + with backend.transaction(): + backend.create_table("nest_inner_r", schema=ibis.schema({"x": "int64"})) + msg = "boom" + raise ValueError(msg) + assert not backend.has_table("nest_outer_r") + assert not backend.has_table("nest_inner_r") + assert not backend._in_transaction diff --git a/tests/test_store.py b/tests/test_store.py index d608350..f61cc58 100644 --- a/tests/test_store.py +++ b/tests/test_store.py @@ -11,6 +11,7 @@ import pandas as pd import pytest +from chronify import time_series_mapper_base from chronify.csv_io import read_csv from chronify.exceptions import ( ConflictingInputsError, @@ -1033,3 +1034,136 @@ def test_localize_time_zone_by_column(tmp_path, iter_stores_by_engine_no_data_in for tz, expected in expected_dct.items(): actual = sorted(df2.loc[df2["time_zone"] == tz, "timestamp"]) check_timestamp_lists(actual, expected) + + +def test_map_table_preserves_existing_parquet_on_failure(tmp_path, monkeypatch): + """A failing remap to an existing parquet must not destroy the original.""" + store = Store.create_in_memory_db() + year = 2020 + length = 24 + src_schema = TableSchema( + name="src_preserve", + value_column="value", + time_array_id_columns=["id"], + time_config=DatetimeRange( + start=datetime(year, 1, 1), + resolution=timedelta(hours=1), + length=length, + interval_type=TimeIntervalType.PERIOD_BEGINNING, + time_column="timestamp", + ), + ) + df = pd.DataFrame( + { + "timestamp": pd.date_range(datetime(year, 1, 1), periods=length, freq="h"), + "id": 1, + "value": list(range(length)), + } + ) + store.ingest_table(df, src_schema) + + output_file = tmp_path / "out.parquet" + sentinel = b"PRE-EXISTING-CONTENT" + output_file.write_bytes(sentinel) + + dst_schema = TableSchema( + name="dst_preserve", + value_column="value", + time_array_id_columns=["id"], + time_config=DatetimeRange( + start=datetime(year, 1, 1, 1), + resolution=timedelta(hours=1), + length=length, + interval_type=TimeIntervalType.PERIOD_ENDING, + time_column="timestamp", + ), + ) + + def _fail(*args, **kwargs): + msg = "forced failure" + raise InvalidTable(msg) + + monkeypatch.setattr(time_series_mapper_base, "check_timestamps", _fail) + + with pytest.raises(InvalidTable, match="forced"): + store.map_table_time_config( + src_schema.name, + dst_schema, + output_file=output_file, + check_mapped_timestamps=True, + ) + + assert output_file.read_bytes() == sentinel + # No staging files left behind in tmp_path either. + leftover = [p for p in tmp_path.iterdir() if p.name != output_file.name] + assert leftover == [] + + +def test_map_table_cleanup_handles_directory_staging(tmp_path, monkeypatch): + """Spark writes parquet as a directory, not a file. Cleanup of a failed + map must use a directory-safe delete instead of unlink().""" + from chronify.ibis.base import IbisBackend + + # Override write_parquet so the staging output is a directory, mirroring + # how pyspark's Backend.to_parquet writes (a directory of part-* files). + def _write_dir(self, expr, path, partition_by=None): # noqa: ARG001 + out = Path(path) + out.mkdir(parents=True, exist_ok=True) + (out / "part-0.parquet").write_bytes(b"not a real parquet") + + monkeypatch.setattr(IbisBackend, "write_parquet", _write_dir) + + store = Store.create_in_memory_db() + year = 2020 + length = 24 + src_schema = TableSchema( + name="src_dir_cleanup", + value_column="value", + time_array_id_columns=["id"], + time_config=DatetimeRange( + start=datetime(year, 1, 1), + resolution=timedelta(hours=1), + length=length, + interval_type=TimeIntervalType.PERIOD_BEGINNING, + time_column="timestamp", + ), + ) + df = pd.DataFrame( + { + "timestamp": pd.date_range(datetime(year, 1, 1), periods=length, freq="h"), + "id": 1, + "value": list(range(length)), + } + ) + store.ingest_table(df, src_schema) + + dst_schema = TableSchema( + name="dst_dir_cleanup", + value_column="value", + time_array_id_columns=["id"], + time_config=DatetimeRange( + start=datetime(year, 1, 1, 1), + resolution=timedelta(hours=1), + length=length, + interval_type=TimeIntervalType.PERIOD_ENDING, + time_column="timestamp", + ), + ) + + output_file = tmp_path / "spark_style_out" + + # The fake parquet content fails downstream (in create_view_from_parquet + # or check_timestamps), exercising the cleanup path on a directory + # staging output. Without the directory-safe delete this raises + # IsADirectoryError and masks the original error. + with pytest.raises(Exception): # noqa: B017 + store.map_table_time_config( + src_schema.name, + dst_schema, + output_file=output_file, + check_mapped_timestamps=True, + ) + + assert not output_file.exists() + leftover = sorted(p.name for p in tmp_path.iterdir()) + assert leftover == [] From 17ca716a7cc1feab191739ebcf3f3b9a3b3ecc66 Mon Sep 17 00:00:00 2001 From: Daniel Thom Date: Mon, 27 Apr 2026 13:04:08 -0600 Subject: [PATCH 40/48] Make parquet promotion crash-safe MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Reviewer-flagged: ``apply_mapping`` was deleting the existing ``output_path`` before renaming the staged output into place, which opened a window where a failure or process crash left the user with neither the old output nor the new one — contradicting the docstring's "prior on-disk state is restored" claim. Replace the delete-then-rename sequence with a backup-rename-replace pattern: rename the existing target aside to a sibling ``..backup.`` path, rename staging into the target, then delete the backup. If the second rename fails, restore the backup so the original is preserved. Each ``Path.replace`` call is atomic, and the user's data is never simultaneously absent from both target and backup. Add ``test_map_table_preserves_existing_parquet_on_promotion_failure``, which monkey-patches ``Path.replace`` to fail on the staging→target call and asserts the pre-existing parquet is preserved with no debris left over. Co-Authored-By: Claude Opus 4.7 --- src/chronify/time_series_mapper_base.py | 29 ++++++++-- tests/test_store.py | 77 +++++++++++++++++++++++++ 2 files changed, 100 insertions(+), 6 deletions(-) diff --git a/src/chronify/time_series_mapper_base.py b/src/chronify/time_series_mapper_base.py index 1d7e50b..83409a5 100644 --- a/src/chronify/time_series_mapper_base.py +++ b/src/chronify/time_series_mapper_base.py @@ -143,14 +143,31 @@ def apply_mapping( # noqa: C901 elif created_tmp_obj is ObjectType.VIEW: backend.drop_view(to_schema.name) # Promote the staged parquet only after the transaction commits. - # The staging output may be a file (DuckDB) or a directory (Spark - # parquet writes are always directories), so explicitly remove any - # existing target first — ``Path.replace`` cannot overwrite a - # non-empty directory. + # When the target already exists, do a backup-rename-replace dance + # rather than a delete-then-rename so the original is preserved if + # anything goes wrong (or the process crashes) between the two + # renames. ``Path.replace`` is atomic per call but cannot overwrite + # a non-empty directory, so we always rename the existing target + # aside first regardless of file vs. directory shape. if staging_path is not None: assert output_path is not None - delete_if_exists(output_path) - staging_path.replace(output_path) + if output_path.exists(): + backup_path = output_path.with_name( + f".{output_path.name}.backup.{uuid.uuid4().hex[:8]}" + ) + output_path.replace(backup_path) + try: + staging_path.replace(output_path) + except Exception: + # Restore the original; the user keeps their pre-existing + # output. If this also fails, the chained exception + # surfaces both errors and the backup remains on disk + # for manual recovery. + backup_path.replace(output_path) + raise + delete_if_exists(backup_path) + else: + staging_path.replace(output_path) except Exception: logger.exception( "Mapping failed for {} -> {}. Cleaning up.", from_schema.name, to_schema.name diff --git a/tests/test_store.py b/tests/test_store.py index f61cc58..21eff50 100644 --- a/tests/test_store.py +++ b/tests/test_store.py @@ -1099,6 +1099,83 @@ def _fail(*args, **kwargs): assert leftover == [] +def test_map_table_preserves_existing_parquet_on_promotion_failure(tmp_path, monkeypatch): + """If the staging→target rename fails after the transaction commits, the + pre-existing target must survive. Regression for an earlier + delete-then-rename window where a crash between the two operations + would leave the user with neither the old nor the new output. + """ + store = Store.create_in_memory_db() + year = 2020 + length = 24 + src_schema = TableSchema( + name="src_promo_fail", + value_column="value", + time_array_id_columns=["id"], + time_config=DatetimeRange( + start=datetime(year, 1, 1), + resolution=timedelta(hours=1), + length=length, + interval_type=TimeIntervalType.PERIOD_BEGINNING, + time_column="timestamp", + ), + ) + df = pd.DataFrame( + { + "timestamp": pd.date_range(datetime(year, 1, 1), periods=length, freq="h"), + "id": 1, + "value": list(range(length)), + } + ) + store.ingest_table(df, src_schema) + + output_file = tmp_path / "out.parquet" + sentinel = b"PRE-EXISTING-DO-NOT-DESTROY" + output_file.write_bytes(sentinel) + + dst_schema = TableSchema( + name="dst_promo_fail", + value_column="value", + time_array_id_columns=["id"], + time_config=DatetimeRange( + start=datetime(year, 1, 1, 1), + resolution=timedelta(hours=1), + length=length, + interval_type=TimeIntervalType.PERIOD_ENDING, + time_column="timestamp", + ), + ) + + # The promotion sequence is: target→backup, then staging→target, then + # delete backup. Inject a failure on the second Path.replace so we + # exercise the restore path. + real_replace = Path.replace + n = [0] + + def _flaky_replace(self, target): + n[0] += 1 + if n[0] == 2: + msg = "simulated promotion failure" + raise OSError(msg) + return real_replace(self, target) + + monkeypatch.setattr(Path, "replace", _flaky_replace) + + with pytest.raises(OSError, match="simulated"): + store.map_table_time_config( + src_schema.name, + dst_schema, + output_file=output_file, + ) + + # Original content must survive — restored from the backup after the + # second rename failed. + assert output_file.read_bytes() == sentinel + # And no debris left in the directory. + leftover = sorted(p.name for p in tmp_path.iterdir() if p.name != output_file.name) + assert leftover == [] + + def test_map_table_cleanup_handles_directory_staging(tmp_path, monkeypatch): """Spark writes parquet as a directory, not a file. Cleanup of a failed map must use a directory-safe delete instead of unlink().""" From 511cf7bf6a0de84828cac3864a79350bd96a271c Mon Sep 17 00:00:00 2001 From: Daniel Thom Date: Mon, 27 Apr 2026 13:29:54 -0600 Subject: [PATCH 41/48] Make post-promotion backup cleanup non-fatal MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Reviewer-flagged: ``delete_if_exists(backup_path)`` ran inside the same try block as the mapping work, so any failure cleaning up the backup caused ``apply_mapping`` to re-raise. ``Store.map_table_time_config`` then skipped its ``add_schema`` call and the user was left with the new parquet on disk but no schema registration — the store believed the mapping had failed. Wrap the backup cleanup in try/except + warning. The promotion already succeeded by this point and the new output is observable; a leftover backup is cosmetic debris that the user can remove manually. Add ``test_map_table_succeeds_when_backup_cleanup_fails`` which monkey-patches ``delete_if_exists`` to fail for backup paths only and asserts the mapping completes, the new content is in place, and the schema is registered — matching the reviewer's exact reproducer. Co-Authored-By: Claude Opus 4.7 --- src/chronify/time_series_mapper_base.py | 16 +++++- tests/test_store.py | 69 +++++++++++++++++++++++++ 2 files changed, 84 insertions(+), 1 deletion(-) diff --git a/src/chronify/time_series_mapper_base.py b/src/chronify/time_series_mapper_base.py index 83409a5..c235103 100644 --- a/src/chronify/time_series_mapper_base.py +++ b/src/chronify/time_series_mapper_base.py @@ -165,7 +165,21 @@ def apply_mapping( # noqa: C901 # for manual recovery. backup_path.replace(output_path) raise - delete_if_exists(backup_path) + # Promotion succeeded — the new output is observable. A + # failure to remove the backup at this point is non-fatal + # debris; surfacing it would cause the caller (e.g. + # ``Store.map_table_time_config``) to skip the post-success + # schema registration and leave the store metadata + # inconsistent with the on-disk parquet. + try: + delete_if_exists(backup_path) + except Exception: + logger.warning( + "Promoted output to {} but failed to remove backup at {}; " + "this is cosmetic debris and may be deleted manually.", + output_path, + backup_path, + ) else: staging_path.replace(output_path) except Exception: diff --git a/tests/test_store.py b/tests/test_store.py index 21eff50..a8148a5 100644 --- a/tests/test_store.py +++ b/tests/test_store.py @@ -1176,6 +1176,75 @@ def _flaky_replace(self, target): assert leftover == [] +def test_map_table_succeeds_when_backup_cleanup_fails(tmp_path, monkeypatch): + """A failure to remove the post-promotion backup is cosmetic debris; the + mapping must still succeed and the schema must still be registered. + + Reviewer regression: previously a backup-cleanup OSError caused + ``apply_mapping`` to re-raise, so ``map_table_time_config`` skipped its + ``add_schema`` call and the user was left with a new parquet on disk + that the store didn't know about. + """ + store = Store.create_in_memory_db() + year = 2020 + length = 24 + src_schema = TableSchema( + name="src_backup_fail", + value_column="value", + time_array_id_columns=["id"], + time_config=DatetimeRange( + start=datetime(year, 1, 1), + resolution=timedelta(hours=1), + length=length, + interval_type=TimeIntervalType.PERIOD_BEGINNING, + time_column="timestamp", + ), + ) + df = pd.DataFrame( + { + "timestamp": pd.date_range(datetime(year, 1, 1), periods=length, freq="h"), + "id": 1, + "value": list(range(length)), + } + ) + store.ingest_table(df, src_schema) + + output_file = tmp_path / "out.parquet" + output_file.write_bytes(b"ORIGINAL") + + dst_schema = TableSchema( + name="dst_backup_fail", + value_column="value", + time_array_id_columns=["id"], + time_config=DatetimeRange( + start=datetime(year, 1, 1, 1), + resolution=timedelta(hours=1), + length=length, + interval_type=TimeIntervalType.PERIOD_ENDING, + time_column="timestamp", + ), + ) + + real_delete = time_series_mapper_base.delete_if_exists + + def _fail_on_backup(path): + if ".backup." in path.name: + msg = "simulated backup cleanup failure" + raise OSError(msg) + return real_delete(path) + + monkeypatch.setattr(time_series_mapper_base, "delete_if_exists", _fail_on_backup) + + # Mapping must succeed despite the backup-cleanup OSError. + store.map_table_time_config(src_schema.name, dst_schema, output_file=output_file) + + # New parquet content is in place. + new = pd.read_parquet(output_file) + assert len(new) == length + # And the schema is registered, so the store and on-disk state agree. + assert store._schema_mgr.get_schema(dst_schema.name).name == dst_schema.name + + def test_map_table_cleanup_handles_directory_staging(tmp_path, monkeypatch): """Spark writes parquet as a directory, not a file. Cleanup of a failed map must use a directory-safe delete instead of unlink().""" From 6af444cc5485686f5bb2ae346008c25562e14f05 Mon Sep 17 00:00:00 2001 From: Daniel Thom Date: Mon, 27 Apr 2026 13:47:25 -0600 Subject: [PATCH 42/48] Harden post-failure cleanup and document Spark rollback limits Three correctness improvements to transaction handling, motivated by a review of the branch: - Cleanup paths (Store.ingest_*, apply_mapping, _intermediate_mapping_ymdp_to_ymdh) no longer mask the original error. A connectivity failure on Spark inside drop_table during rollback previously replaced the user-visible InvalidTable, leaving callers debugging the wrong thing. Each cleanup step is now individually guarded; failures are logged via logger.exception and swallowed so the bare raise propagates the cause. Adds test_ingest_cleanup_failure_does_not_mask_original_error as a regression test. - DuckDBBackend and SparkBackend now initialize self._in_transaction in __init__ to match SQLiteBackend. The class-level default on IbisBackend is dropped in favor of a docstring contract that subclasses must initialize the instance attribute, eliminating the attribute-shadowing trick. - Docstrings on ingest_tables, ingest_from_csvs, and ingest_pivoted_tables now spell out that Spark cannot roll back partial appends to a pre-existing table; only the new-table case is fully cleaned up on failure. Co-Authored-By: Claude Opus 4.7 --- src/chronify/ibis/base.py | 12 ++-- src/chronify/ibis/duckdb_backend.py | 1 + src/chronify/ibis/spark_backend.py | 1 + src/chronify/store.py | 67 ++++++++++++++----- src/chronify/time_series_mapper_base.py | 33 ++++++--- ...apper_column_representative_to_datetime.py | 14 ++-- tests/test_store.py | 26 +++++++ 7 files changed, 119 insertions(+), 35 deletions(-) diff --git a/src/chronify/ibis/base.py b/src/chronify/ibis/base.py index 112587f..90d12fa 100644 --- a/src/chronify/ibis/base.py +++ b/src/chronify/ibis/base.py @@ -183,13 +183,13 @@ class ObjectType(StrEnum): class IbisBackend(ABC): - """Abstract base class defining the interface for Ibis database backends.""" + """Abstract base class defining the interface for Ibis database backends. - # Set while inside transaction(). Used for nesting (inner transaction() - # calls become passthroughs) and, on SQLite, by ``_commit_if_needed`` to - # decide whether DML auto-commits. Subclasses may shadow with an - # instance attribute. - _in_transaction: bool = False + Subclasses must set ``self._in_transaction = False`` in ``__init__``. + The flag is read by :meth:`transaction` (to make inner blocks + passthroughs) and on SQLite by ``_commit_if_needed`` (to decide whether + DML auto-commits). + """ @property @abstractmethod diff --git a/src/chronify/ibis/duckdb_backend.py b/src/chronify/ibis/duckdb_backend.py index bd8a2a8..eb2c6e6 100644 --- a/src/chronify/ibis/duckdb_backend.py +++ b/src/chronify/ibis/duckdb_backend.py @@ -38,6 +38,7 @@ def __init__( msg = f"{database=} and {connection=} cannot both be set" raise ConflictingInputsError(msg) + self._in_transaction = False self._owns_connection = connection is None if connection is None: db = str(database) diff --git a/src/chronify/ibis/spark_backend.py b/src/chronify/ibis/spark_backend.py index 249c6fe..2eb2562 100644 --- a/src/chronify/ibis/spark_backend.py +++ b/src/chronify/ibis/spark_backend.py @@ -44,6 +44,7 @@ def __init__(self, session: Any = None, *, owns_session: bool | None = None) -> msg = "pyspark is required for SparkBackend. Install with: pip install chronify[spark]" raise ImportError(msg) from e + self._in_transaction = False self._owns_session = session is None if owns_session is None else owns_session if session is None: session = ( diff --git a/src/chronify/store.py b/src/chronify/store.py index 9c64d4d..1440ca8 100644 --- a/src/chronify/store.py +++ b/src/chronify/store.py @@ -244,8 +244,11 @@ def ingest_from_csvs( """Ingest data into the table specified by schema. If the table does not exist, create it. This is faster than calling :meth:`ingest_from_csv` many times. Each file is loaded into memory one at a time. - If any error occurs, all added data will be removed and the state of the database will - be the same as the original state. + + On DuckDB and SQLite, a failure rolls back every change made by this call + and the database state is restored. On Spark, partial appends to a + pre-existing table cannot be rolled back; only the case where this call + creates a new table is fully cleaned up on failure. Parameters ---------- @@ -275,10 +278,9 @@ def ingest_from_csvs( with self._backend.transaction(): created_table = self._ingest_from_csvs(paths, src_schema, dst_schema) except Exception: - if not table_existed and self._backend.has_table(dst_schema.name): - self._backend.drop_table(dst_schema.name) - if not table_existed: - self._schema_mgr.remove_schema(dst_schema.name) + _safe_cleanup_after_ingest_error( + self._backend, self._schema_mgr, dst_schema.name, table_existed + ) raise return created_table @@ -402,8 +404,11 @@ def ingest_pivoted_tables( If the table does not exist, create it. Unpivot the data before ingesting it. This is faster than calling :meth:`ingest_pivoted_table` many times. - If any error occurs, all added data will be removed and the state of the database will - be the same as the original state. + + On DuckDB and SQLite, a failure rolls back every change made by this call + and the database state is restored. On Spark, partial appends to a + pre-existing table cannot be rolled back; only the case where this call + creates a new table is fully cleaned up on failure. Parameters ---------- @@ -428,10 +433,9 @@ def ingest_pivoted_tables( with self._backend.transaction(): created_table = self._ingest_pivoted_tables(data, src_schema, dst_schema) except Exception: - if not table_existed and self._backend.has_table(dst_schema.name): - self._backend.drop_table(dst_schema.name) - if not table_existed: - self._schema_mgr.remove_schema(dst_schema.name) + _safe_cleanup_after_ingest_error( + self._backend, self._schema_mgr, dst_schema.name, table_existed + ) raise return created_table @@ -534,6 +538,11 @@ def ingest_tables( This offers significant performance advantages over calling :meth:`ingest_table` many times. + On DuckDB and SQLite, a failure rolls back every change made by this call + and the database state is restored. On Spark, partial appends to a + pre-existing table cannot be rolled back; only the case where this call + creates a new table is fully cleaned up on failure. + Parameters ---------- data @@ -564,10 +573,9 @@ def ingest_tables( with self._backend.transaction(): created_table = self._ingest_tables(data, schema, **kwargs) except Exception: - if not table_existed and self._backend.has_table(schema.name): - self._backend.drop_table(schema.name) - if not table_existed: - self._schema_mgr.remove_schema(schema.name) + _safe_cleanup_after_ingest_error( + self._backend, self._schema_mgr, schema.name, table_existed + ) raise return created_table @@ -1223,3 +1231,30 @@ def check_columns( cols = " ".join(sorted(diff)) msg = f"These columns are defined in the schema but not present in the table: {cols}" raise InvalidTable(msg) + + +def _safe_cleanup_after_ingest_error( + backend: IbisBackend, + schema_mgr: SchemaManager, + name: str, + table_existed: bool, +) -> None: + """Best-effort post-failure cleanup that never replaces the original error. + + Caller invokes this from inside an ``except:`` block and follows it with a + bare ``raise``. Any exception raised here is logged and swallowed so the + original exception is what propagates — otherwise a Spark connectivity + failure during ``drop_table`` would mask the real cause of the ingest + error and leave the user debugging the wrong thing. + """ + if table_existed: + return + try: + if backend.has_table(name): + backend.drop_table(name) + schema_mgr.remove_schema(name) + except Exception: + logger.exception( + "Cleanup after failed ingest of {} did not complete; original error follows.", + name, + ) diff --git a/src/chronify/time_series_mapper_base.py b/src/chronify/time_series_mapper_base.py index c235103..a2e36ff 100644 --- a/src/chronify/time_series_mapper_base.py +++ b/src/chronify/time_series_mapper_base.py @@ -187,18 +187,33 @@ def apply_mapping( # noqa: C901 "Mapping failed for {} -> {}. Cleaning up.", from_schema.name, to_schema.name ) # Idempotent cleanup. On DuckDB/SQLite the rollback already dropped - # these objects (has_table returns False); on Spark it didn't. - if backend.has_table(mapping_schema.name): - backend.drop_table(mapping_schema.name) - if backend.has_table(to_schema.name): - if created_tmp_obj is ObjectType.VIEW: - backend.drop_view(to_schema.name) - else: - backend.drop_table(to_schema.name) + # these objects (has_table returns False); on Spark it didn't. Each + # step is independently guarded so a Spark connectivity failure (or + # similar) here cannot mask the original mapping exception. + try: + if backend.has_table(mapping_schema.name): + backend.drop_table(mapping_schema.name) + except Exception: + logger.exception( + "Failed to drop mapping table {} during cleanup.", mapping_schema.name + ) + try: + if backend.has_table(to_schema.name): + if created_tmp_obj is ObjectType.VIEW: + backend.drop_view(to_schema.name) + else: + backend.drop_table(to_schema.name) + except Exception: + logger.exception("Failed to drop result table/view {} during cleanup.", to_schema.name) # Remove the staging output (file or directory); the original target # is untouched because the rename never ran. if staging_path is not None: - delete_if_exists(staging_path) + try: + delete_if_exists(staging_path) + except Exception: + logger.exception( + "Failed to remove staging output {} during cleanup.", staging_path + ) raise diff --git a/src/chronify/time_series_mapper_column_representative_to_datetime.py b/src/chronify/time_series_mapper_column_representative_to_datetime.py index eab56e3..4f66199 100644 --- a/src/chronify/time_series_mapper_column_representative_to_datetime.py +++ b/src/chronify/time_series_mapper_column_representative_to_datetime.py @@ -5,6 +5,8 @@ import pandas as pd from datetime import datetime +from loguru import logger + from chronify.exceptions import InvalidParameter, InvalidValue from chronify.ibis.base import IbisBackend from chronify.time_series_mapper_base import TimeSeriesMapperBase, apply_mapping @@ -164,10 +166,14 @@ def _intermediate_mapping_ymdp_to_ymdh(self) -> TableSchema: except Exception: # Spark fallback: rollback is a no-op, so clean up manually. # Idempotent on DuckDB/SQLite where the rollback already dropped these. - if self._backend.has_table(mapping_table_name): - self._backend.drop_table(mapping_table_name) - if self._backend.has_table(intermediate_ymdh_table_name): - self._backend.drop_table(intermediate_ymdh_table_name) + # Each step is independently guarded so a cleanup failure cannot + # mask the original error. + for tbl in (mapping_table_name, intermediate_ymdh_table_name): + try: + if self._backend.has_table(tbl): + self._backend.drop_table(tbl) + except Exception: + logger.exception("Failed to drop intermediate table {} during cleanup.", tbl) raise if not isinstance(self._from_time_config, YearMonthDayPeriodTimeNTZ): diff --git a/tests/test_store.py b/tests/test_store.py index a8148a5..e9f9960 100644 --- a/tests/test_store.py +++ b/tests/test_store.py @@ -210,6 +210,32 @@ def test_ingest_multiple_tables_error(iter_stores_by_engine: Store, multiple_tab assert df.equals(tables[1]) +def test_ingest_cleanup_failure_does_not_mask_original_error(monkeypatch, multiple_tables): + """A cleanup-side failure during post-ingest rollback must not replace + the user-visible ingest error. Regression: a Spark connectivity blip + inside ``drop_table`` would otherwise hide the InvalidTable that + originally caused the rollback. + """ + store = Store.create_in_memory_db() + tables, schema = multiple_tables + tables[1].loc[8783] = (tables[1].loc[8783]["timestamp"], 0.1, 99) + + real_drop = store._backend.drop_table + + def _exploding_drop(name): + if name == schema.name: + msg = "simulated cleanup failure" + raise RuntimeError(msg) + return real_drop(name) + + monkeypatch.setattr(store._backend, "drop_table", _exploding_drop) + + # The user must see the InvalidTable from validation, not the RuntimeError + # from cleanup. + with pytest.raises(InvalidTable): + store.ingest_tables(tables, schema) + + @pytest.mark.parametrize("use_pandas", [False, True]) def test_ingest_pivoted_table(iter_stores_by_engine: Store, generators_schema, use_pandas: bool): import ibis From 17bccf683bfc163b9edb0bc0363e39a266e7c3e4 Mon Sep 17 00:00:00 2001 From: Daniel Thom Date: Mon, 27 Apr 2026 13:51:42 -0600 Subject: [PATCH 43/48] Restore _in_transaction type annotation for mypy Removing the class-level default in 6af444c left mypy unable to infer the type of ``self._in_transaction`` when read from the base class's ``transaction()`` method. Add a bare ``_in_transaction: bool`` annotation on ``IbisBackend`` to declare the attribute's type without giving it a value (subclasses still must initialize it in ``__init__``), and drop the now-stale ``# type: ignore[assignment]`` on the SQLite proxy swap. Co-Authored-By: Claude Opus 4.7 --- src/chronify/ibis/base.py | 2 ++ src/chronify/ibis/sqlite_backend.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/chronify/ibis/base.py b/src/chronify/ibis/base.py index 90d12fa..beaf0cb 100644 --- a/src/chronify/ibis/base.py +++ b/src/chronify/ibis/base.py @@ -191,6 +191,8 @@ class IbisBackend(ABC): DML auto-commits). """ + _in_transaction: bool + @property @abstractmethod def name(self) -> str: diff --git a/src/chronify/ibis/sqlite_backend.py b/src/chronify/ibis/sqlite_backend.py index ed8efd0..dcfac8b 100644 --- a/src/chronify/ibis/sqlite_backend.py +++ b/src/chronify/ibis/sqlite_backend.py @@ -200,7 +200,7 @@ def _begin_transaction(self) -> None: # The real connection is held aside and used directly for BEGIN / # COMMIT / ROLLBACK. self._real_con = real - self._connection.con = _NoCommitProxy(real) # type: ignore[assignment] + self._connection.con = _NoCommitProxy(real) self._in_transaction = True def _commit_transaction(self) -> None: From b03b41a55cdb4d13d3cc6684d54a898d1e53f2c4 Mon Sep 17 00:00:00 2001 From: Daniel Thom Date: Mon, 27 Apr 2026 13:51:50 -0600 Subject: [PATCH 44/48] Fix Spark partitioned Parquet writes ibis-pyspark's ``to_parquet`` forwards ``partition_by`` as a kwarg to ``df.write.format('parquet').save(path, **kwargs)``. PySpark's writer expects camel-cased ``partitionBy``; the snake_case kwarg is silently dropped as an unknown option, so partitioned writes fell out as a single unpartitioned directory. (DuckDB's native ``to_parquet`` accepts ``partition_by`` directly, so this only broke on Spark.) Override ``write_parquet`` on ``SparkBackend`` to call ``df.write.partitionBy(*cols).parquet(path)`` directly when partition columns are given, falling back to the ibis path otherwise. Fixes ``test_spark_write_parquet_partitioned``. Co-Authored-By: Claude Opus 4.7 --- src/chronify/ibis/spark_backend.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/src/chronify/ibis/spark_backend.py b/src/chronify/ibis/spark_backend.py index 2eb2562..a1c4c6e 100644 --- a/src/chronify/ibis/spark_backend.py +++ b/src/chronify/ibis/spark_backend.py @@ -131,6 +131,27 @@ def create_view_from_parquet(self, path: str, name: str) -> tuple[ibis.Table, Ob self._connection.create_view(name, self._connection.read_parquet(path)) return self.table(name), ObjectType.VIEW + def write_parquet( + self, + expr: ibis.Table, + path: str, + partition_by: list[str] | None = None, + ) -> None: + """Write to Parquet, using PySpark's writer directly when partitioning. + + ibis-pyspark's ``to_parquet`` passes ``partition_by`` through as a + ``DataFrameWriter.save`` kwarg, but PySpark's writer expects the + camel-cased ``partitionBy`` (snake_case is silently ignored as an + unknown option). Fall back to the unpartitioned ibis path when no + partition columns are given. + """ + if not partition_by: + self._connection.to_parquet(expr, path) + return + sql = self._connection.compile(expr) + df = self._session.sql(sql) + df.write.partitionBy(*partition_by).parquet(path) + def execute_sql(self, query: str) -> None: logger.trace("execute_sql: {}", query) self._session.sql(query) From e6755fb0d89de90155c866b9cccd91a62aaf2907 Mon Sep 17 00:00:00 2001 From: Daniel Thom Date: Mon, 27 Apr 2026 14:03:45 -0600 Subject: [PATCH 45/48] Suppress pyarrow.compute attr-defined error for assume_timezone pyarrow 24 added a ``py.typed`` marker, so mypy now uses the bundled type info instead of treating pyarrow as ``Any``. The compute kernels (``assume_timezone``, etc.) are registered dynamically at import time and aren't declared in the stubs, so mypy reports them as missing attributes even though they exist at runtime. CI installs pyarrow 24 fresh and trips this; local mypy was clean only because the existing environment still had pyarrow 23. Add a targeted ``# type: ignore[attr-defined]`` on the single call site with a comment explaining the cause. Co-Authored-By: Claude Opus 4.7 --- src/chronify/ibis/base.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/chronify/ibis/base.py b/src/chronify/ibis/base.py index beaf0cb..51e29e4 100644 --- a/src/chronify/ibis/base.py +++ b/src/chronify/ibis/base.py @@ -115,7 +115,10 @@ def _normalize_arrow_timestamps( if config.dtype == TimeDataType.TIMESTAMP_NTZ and is_tz_aware: new_arr = arr.cast(pa.timestamp(arr.type.unit)) elif config.dtype == TimeDataType.TIMESTAMP_TZ and not is_tz_aware: - new_arr = pc.assume_timezone(arr, "UTC") + # pyarrow 24 ships incomplete type stubs that omit dynamically- + # registered compute kernels like assume_timezone, even though + # the function exists at runtime. + new_arr = pc.assume_timezone(arr, "UTC") # type: ignore[attr-defined] else: continue table = table.set_column(idx, table.column_names[idx], new_arr) From 0a8df8f071c4580528419b24aab1a8ca097834f7 Mon Sep 17 00:00:00 2001 From: Daniel Thom Date: Tue, 28 Apr 2026 09:54:38 -0600 Subject: [PATCH 46/48] Address Copilot review: 'None' tz sentinel and cursor cleanup MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - TimeZoneLocalizerByColumn._get_time_zones() previously rejected only the *mixed* case of the literal 'None' string (the get_tzname(None) sentinel) plus real time zones. The all-'None' case slipped through and crashed with ZoneInfoNotFoundError on ZoneInfo('None'). Tighten the check to reject 'None' whenever it appears, with a message pointing the user to localize_time_zone(None) for tz-naive rows. Add test_localize_time_zone_by_column_rejects_none_sentinel as a regression test. - SQLiteBackend.insert() now wraps con.cursor() in contextlib.closing so the cursor is deterministically closed even on executemany failure or non-CPython runtimes. (Cursor close is independent of transaction state — commit still happens on the connection afterwards via _commit_if_needed.) Co-Authored-By: Claude Opus 4.7 --- src/chronify/ibis/sqlite_backend.py | 5 +++-- src/chronify/time_zone_localizer.py | 21 +++++++++++++++------ tests/test_time_zone_localizer.py | 25 +++++++++++++++++++++++++ 3 files changed, 43 insertions(+), 8 deletions(-) diff --git a/src/chronify/ibis/sqlite_backend.py b/src/chronify/ibis/sqlite_backend.py index dcfac8b..f0976ef 100644 --- a/src/chronify/ibis/sqlite_backend.py +++ b/src/chronify/ibis/sqlite_backend.py @@ -1,6 +1,7 @@ """SQLite backend implementation for Ibis.""" import sqlite3 +from contextlib import closing from datetime import datetime from pathlib import Path from functools import singledispatchmethod @@ -150,8 +151,8 @@ def insert(self, name: str, data: pd.DataFrame | pa.Table | ibis.Table) -> None: ordered = data.loc[:, columns] rows = [tuple(_adapt_value(v) for v in row) for row in ordered.itertuples(index=False)] - cursor = con.cursor() - cursor.executemany(sql, rows) + with closing(con.cursor()) as cursor: + cursor.executemany(sql, rows) self._commit_if_needed() logger.trace("Inserted {} rows into {}", len(data), name) diff --git a/src/chronify/time_zone_localizer.py b/src/chronify/time_zone_localizer.py index bb41f32..ee2b722 100644 --- a/src/chronify/time_zone_localizer.py +++ b/src/chronify/time_zone_localizer.py @@ -376,16 +376,25 @@ def _get_time_zones(self) -> list[tzinfo | None]: self.time_zone_column ].to_list() - if "None" in time_zones and len(time_zones) > 1: + # ``"None"`` is the canonical string sentinel for tz-naive rows + # (``get_tzname(None) == "None"``). Reject it here regardless of + # whether it is alone or mixed: mixed because databases cannot + # store tz-aware and tz-naive timestamps in the same column; + # alone because the all-tz-naive case is degenerate for + # localize_time_zone_by_column — the caller should use + # localize_time_zone(None) instead. Without this check the + # all-"None" branch falls through to ``ZoneInfo("None")`` and + # raises ``ZoneInfoNotFoundError``. + if "None" in time_zones: msg = ( - "Chronify does not support mix of None and time zones in time_zone_column." - "This is because databases do not support tz-aware and tz-naive timestamps " - f"in the same column: {time_zones}" + "Chronify does not support the 'None' time zone in time_zone_column. " + "Use localize_time_zone(None) for tz-naive rows. Mixing 'None' with " + "real time zones is also unsupported because databases cannot store " + f"tz-aware and tz-naive timestamps in the same column: {time_zones}" ) raise InvalidParameter(msg) - time_zones = [ZoneInfo(tz) for tz in time_zones] - return time_zones + return [ZoneInfo(tz) for tz in time_zones] def _create_mapping(self) -> tuple[pd.DataFrame, MappingTableSchema]: assert isinstance(self._from_schema.time_config, DatetimeRangeWithTZColumn) diff --git a/tests/test_time_zone_localizer.py b/tests/test_time_zone_localizer.py index 905565f..2df745e 100644 --- a/tests/test_time_zone_localizer.py +++ b/tests/test_time_zone_localizer.py @@ -357,3 +357,28 @@ def test_localize_time_zone_by_column_missing_tz_column_error( run_localization_by_column_with_error( iter_all_backends, df, from_schema, error, time_zone_column=None ) + + +def test_localize_time_zone_by_column_rejects_none_sentinel( + iter_all_backends: IbisBackend, +) -> None: + """The literal string 'None' (the get_tzname(None) sentinel) in + time_zone_column must produce a clear InvalidParameter pointing the + user to localize_time_zone(None), not crash with + ZoneInfoNotFoundError on ZoneInfo('None'). + """ + # Ingest with a plain DatetimeRange so __init__ goes through + # _get_time_zones() (the path that scans the table for distinct + # zones), not the DatetimeRangeWithTZColumn path that trusts the + # schema's time_zones field. + from_schema = get_datetime_schema(2018, None, TimeIntervalType.PERIOD_BEGINNING, "base_table") + df = generate_datetime_dataframe(from_schema) + df["time_zone"] = "None" + ingest_data(iter_all_backends, df, from_schema) + + with pytest.raises(InvalidParameter, match="Use localize_time_zone\\(None\\)"): + TimeZoneLocalizerByColumn( + iter_all_backends, + from_schema, + time_zone_column="time_zone", + ) From b3a1c695f9abd22358bf439a66a6ac6bd102d74d Mon Sep 17 00:00:00 2001 From: Daniel Thom Date: Thu, 7 May 2026 13:03:46 -0600 Subject: [PATCH 47/48] Address PR #66 review: restore docstrings, hoist test import MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Restore Numpy-style Parameters/Returns docstrings on localize_time_zone() and localize_time_zone_by_column() (per Daniel's PR comments at time_zone_localizer.py:37 and :53), with scratch_dir / connection references stripped — those parameters were removed during the SQLAlchemy → ibis migration. - Move 'import pandas as pd' from inside test_list_distinct_timestamps_from_dataframe_not_implemented to the top of test_annual_time_range_generator.py (Daniel's comment at :33). - Drop unused 'DatetimeRanges' alias in ibis/base.py — it shadowed the one exported from time_configs and was never referenced. - Drop a stale comment in spark_backend.delete_rows; the parameterized SQL behavior is now self-evident from the surrounding code. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/chronify/ibis/base.py | 1 - src/chronify/ibis/spark_backend.py | 1 - src/chronify/time_zone_localizer.py | 53 ++++++++++++++++++++++- tests/test_annual_time_range_generator.py | 3 +- 4 files changed, 52 insertions(+), 6 deletions(-) diff --git a/src/chronify/ibis/base.py b/src/chronify/ibis/base.py index 51e29e4..442cbbb 100644 --- a/src/chronify/ibis/base.py +++ b/src/chronify/ibis/base.py @@ -25,7 +25,6 @@ ) _DATETIME_RANGES = (DatetimeRange, DatetimeRangeWithTZColumn) -DatetimeRanges = DatetimeRange | DatetimeRangeWithTZColumn def _check_one_config_per_datetime_column(configs: Iterable[TimeBaseModel]) -> None: diff --git a/src/chronify/ibis/spark_backend.py b/src/chronify/ibis/spark_backend.py index a1c4c6e..d1d876f 100644 --- a/src/chronify/ibis/spark_backend.py +++ b/src/chronify/ibis/spark_backend.py @@ -98,7 +98,6 @@ def create_table( return self._connection.create_table(name, obj=obj, schema=schema, overwrite=overwrite) def delete_rows(self, name: str, values: dict[str, Any]) -> None: - # Spark 3.4+ supports parameterized SQL via the ``args`` keyword. quoted_name = _quote_identifier(name) param_names = [f"p{i}" for i in range(len(values))] where = " AND ".join(f"{_quote_identifier(c)} = :{p}" for c, p in zip(values, param_names)) diff --git a/src/chronify/time_zone_localizer.py b/src/chronify/time_zone_localizer.py index ee2b722..5f1f8aa 100644 --- a/src/chronify/time_zone_localizer.py +++ b/src/chronify/time_zone_localizer.py @@ -34,7 +34,32 @@ def localize_time_zone( output_file: Optional[Path] = None, check_mapped_timestamps: bool = False, ) -> TableSchema: - """Localize TIMESTAMP_NTZ time column in a table to a specified standard time zone.""" + """Localize TIMESTAMP_NTZ time column in a table to a specified standard time zone. + + Input data must be in a standard time zone (without DST) because it's ambiguous to localize + tz-naive timestamps with skips and duplicates to a prevailing time zone. + + Updates table to TIMESTAMP_TZ time column and returns a new time config. + + Parameters + ---------- + backend : IbisBackend + Backend wrapping the database connection that holds the source table. + src_schema : TableSchema + Defines the source table in the database. + to_time_zone : tzinfo or None + Standard time zone to convert to. If None, convert to tz-naive. + output_file : pathlib.Path, optional + If set, write the mapped table to this Parquet file. + check_mapped_timestamps : bool, optional + Perform time checks on the result of the mapping operation. This can be slow and + is not required. + + Returns + ------- + TableSchema + Schema of output table with converted timestamps. + """ tzl = TimeZoneLocalizer(backend, src_schema, to_time_zone) tzl.localize_time_zone( output_file=output_file, @@ -50,7 +75,31 @@ def localize_time_zone_by_column( output_file: Optional[Path] = None, check_mapped_timestamps: bool = False, ) -> TableSchema: - """Localize TIMESTAMP_NTZ time column in a table to multiple time zones specified by a column.""" + """Localize TIMESTAMP_NTZ time column in a table to multiple time zones specified by a column. + + Updates table to TIMESTAMP_TZ time column and returns a new time config. + + Parameters + ---------- + backend : IbisBackend + Backend wrapping the database connection that holds the source table. + src_schema : TableSchema + Defines the source table in the database. + time_zone_column : Optional[str] + Column name in the source table that contains the time zone information. + - Required if src_schema.time_config is of type DatetimeRange. + - Ignored if src_schema.time_config is of type DatetimeRangeWithTZColumn. + output_file : pathlib.Path, optional + If set, write the mapped table to this Parquet file. + check_mapped_timestamps : bool, optional + Perform time checks on the result of the mapping operation. This can be slow and + is not required. + + Returns + ------- + dst_schema : TableSchema + schema of output table with converted timestamps + """ if isinstance(src_schema.time_config, DatetimeRange) and time_zone_column is None: msg = ( "time_zone_column must be provided when localizing time zones " diff --git a/tests/test_annual_time_range_generator.py b/tests/test_annual_time_range_generator.py index 9fedce9..de99d26 100644 --- a/tests/test_annual_time_range_generator.py +++ b/tests/test_annual_time_range_generator.py @@ -1,5 +1,6 @@ """Tests for AnnualTimeRangeGenerator.""" +import pandas as pd import pytest from chronify.annual_time_range_generator import AnnualTimeRangeGenerator @@ -30,8 +31,6 @@ def test_list_time_columns(): def test_list_distinct_timestamps_from_dataframe_not_implemented(): - import pandas as pd - gen = AnnualTimeRangeGenerator(_make_config()) with pytest.raises(NotImplementedError): gen.list_distinct_timestamps_from_dataframe(pd.DataFrame()) From 649b8012041c052a2c9875b857b917fb659cd118 Mon Sep 17 00:00:00 2001 From: Daniel Thom Date: Thu, 7 May 2026 13:57:52 -0600 Subject: [PATCH 48/48] Bump codecov-action v4.2.0 -> v5 and pin coverage.xml MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The previous run failed in CI because codecov-action v4.2.0 crashed during its uploader binary fetch (`Could not pull latest version information: SyntaxError: Unexpected token '<'` followed by a `write EPIPE` from the gpg verification subprocess). That failure mode is not caught by `fail_ci_if_error: false` on v4, and was fixed in the v5 series (which uses the Codecov CLI directly). Also pin `files: coverage.xml` to skip the auto-discovery walk — we already produce that exact path with `pytest --cov-report=xml:coverage.xml`. Sticking on v5 (not v6) so the action keeps running on Node.js 20; the runners don't default to Node 24 until 2026-06-02. Co-Authored-By: Claude Opus 4.7 (1M context) --- .github/workflows/ci.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 966410d..cbfcb24 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -34,11 +34,12 @@ jobs: run: | pytest -v --cov=chronify --cov-report=xml:coverage.xml - name: codecov - uses: codecov/codecov-action@v4.2.0 + uses: codecov/codecov-action@v5 if: ${{ matrix.os == env.DEFAULT_OS && matrix.python-version == env.DEFAULT_PYTHON }} with: token: ${{ secrets.CODECOV_TOKEN }} name: chronify-tests + files: coverage.xml fail_ci_if_error: false verbose: true mypy: