diff --git a/build.sh b/build.sh index 7e7a1c376..f7c66e19a 100755 --- a/build.sh +++ b/build.sh @@ -150,6 +150,11 @@ function write_e2e_env(){ } +## with_venv - runs a command with the venv activated +function with_venv() { + "$@" +} + ## help - prints the help details ## function help() { diff --git a/google/cloud/sql/connector/asyncpg.py b/google/cloud/sql/connector/asyncpg.py index 2fbc30273..2e28dbbaf 100644 --- a/google/cloud/sql/connector/asyncpg.py +++ b/google/cloud/sql/connector/asyncpg.py @@ -15,7 +15,7 @@ """ import ssl -from typing import Any, TYPE_CHECKING +from typing import Any, Optional, TYPE_CHECKING SERVER_PROXY_PORT = 3307 @@ -24,7 +24,7 @@ async def connect( - ip_address: str, ctx: ssl.SSLContext, **kwargs: Any + ip_address: str, ctx: Optional[ssl.SSLContext], **kwargs: Any ) -> "asyncpg.Connection": """Helper function to create an asyncpg DB-API connection object. @@ -32,8 +32,7 @@ async def connect( ip_address (str): A string containing an IP address for the Cloud SQL instance. ctx (ssl.SSLContext): An SSLContext object created from the Cloud SQL - server CA cert and ephemeral cert. - server CA cert and ephemeral cert. + server CA cert and ephemeral cert. Pass None to disable SSL. kwargs: Keyword arguments for establishing asyncpg connection object to Cloud SQL instance. @@ -53,14 +52,18 @@ async def connect( user = kwargs.pop("user") db = kwargs.pop("db") passwd = kwargs.pop("password", None) + port = kwargs.pop("port", SERVER_PROXY_PORT) - return await asyncpg.connect( - user=user, - database=db, - password=passwd, - host=ip_address, - port=SERVER_PROXY_PORT, - ssl=ctx, - direct_tls=True, + connect_args = { + "user": user, + "database": db, + "password": passwd, + "host": ip_address, + "port": port, **kwargs, - ) + } + if ctx is not None: + connect_args["ssl"] = ctx + connect_args["direct_tls"] = True + + return await asyncpg.connect(**connect_args) diff --git a/google/cloud/sql/connector/client.py b/google/cloud/sql/connector/client.py index 11508ce17..1befdb793 100644 --- a/google/cloud/sql/connector/client.py +++ b/google/cloud/sql/connector/client.py @@ -171,9 +171,13 @@ async def _get_metadata( if dns_name: ip_addresses["PSC"] = dns_name.rstrip(".") + server_ca_cert = None + if "serverCaCert" in ret_dict and "cert" in ret_dict["serverCaCert"]: + server_ca_cert = ret_dict["serverCaCert"]["cert"] + return { "ip_addresses": ip_addresses, - "server_ca_cert": ret_dict["serverCaCert"]["cert"], + "server_ca_cert": server_ca_cert, "database_version": ret_dict["databaseVersion"], } @@ -228,7 +232,11 @@ async def _get_ephemeral( finally: resp.raise_for_status() - ephemeral_cert: str = ret_dict["ephemeralCert"]["cert"] + try: + ephemeral_cert: str = ret_dict["ephemeralCert"]["cert"] + except KeyError as e: + logger.error(f"KeyError in _get_ephemeral parsing generateEphemeralCert: {e}. Response dict: {ret_dict}") + raise # decode cert to read expiration x509 = load_pem_x509_certificate( diff --git a/google/cloud/sql/connector/connection_info.py b/google/cloud/sql/connector/connection_info.py index c9e48935f..bf9330e1b 100644 --- a/google/cloud/sql/connector/connection_info.py +++ b/google/cloud/sql/connector/connection_info.py @@ -63,7 +63,7 @@ class ConnectionInfo: conn_name: ConnectionName client_cert: str - server_ca_cert: str + server_ca_cert: Optional[str] private_key: bytes ip_addrs: dict[str, Any] database_version: str @@ -79,6 +79,10 @@ async def create_ssl_context(self, enable_iam_auth: bool = False) -> ssl.SSLCont # if SSL context is cached, use it if self.context is not None: return self.context + + if self.server_ca_cert is None: + raise ValueError("Cannot create SSL context: server CA certificate is missing.") + context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) # update ssl.PROTOCOL_TLS_CLIENT default diff --git a/google/cloud/sql/connector/connector.py b/google/cloud/sql/connector/connector.py index 798969c2c..6d902b1e2 100755 --- a/google/cloud/sql/connector/connector.py +++ b/google/cloud/sql/connector/connector.py @@ -31,6 +31,7 @@ import google.cloud.sql.connector.asyncpg as asyncpg from google.cloud.sql.connector.client import CloudSQLClient +from google.cloud.sql.connector.connection_name import ConnectionName from google.cloud.sql.connector.enums import DriverMapping from google.cloud.sql.connector.enums import IPTypes from google.cloud.sql.connector.enums import RefreshStrategy @@ -44,6 +45,8 @@ import google.cloud.sql.connector.pytds as pytds from google.cloud.sql.connector.resolver import DefaultResolver from google.cloud.sql.connector.resolver import DnsResolver +from google.cloud.sql.connector.sqldata_client import FallbackSocket +from google.cloud.sql.connector.sqldata_client import SqlDataClient from google.cloud.sql.connector.utils import format_database_user from google.cloud.sql.connector.utils import generate_keys @@ -73,6 +76,8 @@ def __init__( refresh_strategy: str | RefreshStrategy = RefreshStrategy.BACKGROUND, resolver: type[DefaultResolver] | type[DnsResolver] = DefaultResolver, failover_period: int = 30, + sql_data_endpoint: str = "sqladmin.googleapis.com", + sql_data_stream_timeout: int = 7200, ) -> None: """Initializes a Connector instance. @@ -212,6 +217,12 @@ def __init__( "configured the universe domain explicitly, `googleapis.com` " "is the default." ) + self._sql_data_endpoint = sql_data_endpoint + self._sql_data_stream_timeout = sql_data_stream_timeout + self._sql_data_fallback_cache: set[str] = set() + self._sqldata_clients: list[SqlDataClient] = [] + + @property def universe_domain(self) -> str: @@ -258,6 +269,48 @@ def connect( ) return connect_future.result() + def _get_or_create_cache( + self, + conn_name: ConnectionName, + enable_iam_auth: bool, + ) -> MonitoredCache: + assert self._client is not None, "client must be initialized before creating cache" + assert self._keys is not None, "keys must be initialized before creating cache" + if (str(conn_name), enable_iam_auth) in self._cache and not self._cache[ + (str(conn_name), enable_iam_auth) + ].closed: + return self._cache[(str(conn_name), enable_iam_auth)] + + if self._refresh_strategy == RefreshStrategy.LAZY: + logger.debug( + f"['{conn_name}']: Refresh strategy is set to lazy refresh" + ) + cache: Union[LazyRefreshCache, RefreshAheadCache] = LazyRefreshCache( + conn_name, + self._client, + self._keys, + enable_iam_auth, + ) + else: + logger.debug( + f"['{conn_name}']: Refresh strategy is set to backgound refresh" + ) + cache = RefreshAheadCache( + conn_name, + self._client, + self._keys, + enable_iam_auth, + ) + # wrap cache as a MonitoredCache + monitored_cache = MonitoredCache( + cache, + self._failover_period, + self._resolver, + ) + logger.debug(f"['{conn_name}']: Connection info added to cache") + self._cache[(str(conn_name), enable_iam_auth)] = monitored_cache + return monitored_cache + async def connect_async( self, instance_connection_string: str, driver: str, **kwargs: Any ) -> Any: @@ -317,42 +370,14 @@ async def connect_async( driver=driver, ) enable_iam_auth = kwargs.pop("enable_iam_auth", self._enable_iam_auth) + ip_type = kwargs.pop("ip_type", self._ip_type) + if isinstance(ip_type, str): + ip_type = IPTypes._from_str(ip_type) conn_name = await self._resolver.resolve(instance_connection_string) - # Cache entry must exist and not be closed - if (str(conn_name), enable_iam_auth) in self._cache and not self._cache[ - (str(conn_name), enable_iam_auth) - ].closed: - monitored_cache = self._cache[(str(conn_name), enable_iam_auth)] - else: - if self._refresh_strategy == RefreshStrategy.LAZY: - logger.debug( - f"['{conn_name}']: Refresh strategy is set to lazy refresh" - ) - cache: Union[LazyRefreshCache, RefreshAheadCache] = LazyRefreshCache( - conn_name, - self._client, - self._keys, - enable_iam_auth, - ) - else: - logger.debug( - f"['{conn_name}']: Refresh strategy is set to backgound refresh" - ) - cache = RefreshAheadCache( - conn_name, - self._client, - self._keys, - enable_iam_auth, - ) - # wrap cache as a MonitoredCache - monitored_cache = MonitoredCache( - cache, - self._failover_period, - self._resolver, - ) - logger.debug(f"['{conn_name}']: Connection info added to cache") - self._cache[(str(conn_name), enable_iam_auth)] = monitored_cache + + if ip_type != IPTypes.SQL_DATA: + monitored_cache = self._get_or_create_cache(conn_name, enable_iam_auth) connect_func = { "pymysql": pymysql.connect, @@ -366,11 +391,6 @@ async def connect_async( connector: Callable = connect_func[driver] # type: ignore except KeyError: raise KeyError(f"Driver '{driver}' is not supported.") - - ip_type = kwargs.pop("ip_type", self._ip_type) - # if ip_type is str, convert to IPTypes enum - if isinstance(ip_type, str): - ip_type = IPTypes._from_str(ip_type) kwargs["timeout"] = kwargs.get("timeout", self._timeout) # Host and ssl options come from the certificates and metadata, so we don't @@ -379,85 +399,149 @@ async def connect_async( kwargs.pop("ssl", None) kwargs.pop("port", None) - # attempt to get connection info for Cloud SQL instance + # attempt to establish connection try: - conn_info = await monitored_cache.connect_info() - # validate driver matches intended database engine - DriverMapping.validate_engine(driver, conn_info.database_version) - ip_address = conn_info.get_preferred_ip(ip_type) - except Exception: - # with an error from Cloud SQL Admin API call or IP type, invalidate - # the cache and re-raise the error - await self._remove_cached(str(conn_name), enable_iam_auth) - raise + if ip_type == IPTypes.SQL_DATA: + logger.debug(f"['{conn_name}']: Connecting via SQL Data Service tunnel") + if enable_iam_auth: + engine = DriverMapping[driver].value + formatted_user = format_database_user( + engine, kwargs["user"] + ) + if formatted_user != kwargs["user"]: + logger.debug( + f"['{instance_connection_string}']: Truncated IAM database username from {kwargs['user']} to {formatted_user}" + ) + kwargs["user"] = formatted_user + + sqldata_client = SqlDataClient( + endpoint=self._sql_data_endpoint, + credentials=self._credentials, + quota_project=self._quota_project, + timeout=self._sql_data_stream_timeout, + ) + self._sqldata_clients.append(sqldata_client) + + def on_fallback(name): + self._sql_data_fallback_cache.add(name) + + def is_fallback_cached(name): + return name in self._sql_data_fallback_cache + + # Defer cache creation and connect_info call + async def get_conn_info(): + cache = self._get_or_create_cache(conn_name, enable_iam_auth) + return await cache.connect_info() + + tunnel_port = await sqldata_client.connect_tunnel( + instance_connection_name=str(conn_name), + region=conn_name.region, + project=conn_name.project, + get_conn_info=get_conn_info, + enable_iam_auth=enable_iam_auth, + on_fallback=on_fallback, + is_fallback_cached=is_fallback_cached, + ) - # If the connector is configured with a custom DNS name, attempt to use - # that DNS name to connect to the instance. Fall back to the metadata IP - # address if the DNS name does not resolve to an IP address. - if conn_info.conn_name.domain_name and isinstance(self._resolver, DnsResolver): - try: - ips = await self._resolver.resolve_a_record(conn_info.conn_name.domain_name) - if ips: - ip_address = ips[0] - logger.debug( - f"['{instance_connection_string}']: Custom DNS name " - f"'{conn_info.conn_name.domain_name}' resolved to '{ip_address}', " - "using it to connect" + if driver in ASYNC_DRIVERS: + return await connector( + "127.0.0.1", + None, + port=tunnel_port, + **kwargs, ) else: - logger.debug( - f"['{instance_connection_string}']: Custom DNS name " - f"'{conn_info.conn_name.domain_name}' resolved but returned no " - f"entries, using '{ip_address}' from instance metadata" + raw_sock = socket.create_connection(("127.0.0.1", tunnel_port)) + fd = raw_sock.detach() + fallback_sock = FallbackSocket(fileno=fd) + + if conn_name.domain_name: + monitored_cache.sockets.append(fallback_sock) + + connect_partial = partial( + connector, + "127.0.0.1", + fallback_sock, + **kwargs, ) - except Exception as e: - logger.debug( - f"['{instance_connection_string}']: Custom DNS name " - f"'{conn_info.conn_name.domain_name}' did not resolve to an IP " - f"address: {e}, using '{ip_address}' from instance metadata" - ) - - logger.debug(f"['{conn_info.conn_name}']: Connecting to {ip_address}:3307") - # format `user` param for automatic IAM database authn - if enable_iam_auth: - formatted_user = format_database_user( - conn_info.database_version, kwargs["user"] - ) - if formatted_user != kwargs["user"]: - logger.debug( - f"['{instance_connection_string}']: Truncated IAM database username from {kwargs['user']} to {formatted_user}" + return await self._loop.run_in_executor(None, connect_partial) + else: + # Standard path (requires metadata and certs) + try: + conn_info = await monitored_cache.connect_info() + # validate driver matches intended database engine + DriverMapping.validate_engine(driver, conn_info.database_version) + ip_address = conn_info.get_preferred_ip(ip_type) + except Exception: + # with an error from Cloud SQL Admin API call or IP type, invalidate + # the cache and re-raise the error + await self._remove_cached(str(conn_name), enable_iam_auth) + raise + + # If the connector is configured with a custom DNS name, attempt to use + # that DNS name to connect to the instance. Fall back to the metadata IP + # address if the DNS name does not resolve to an IP address. + if conn_info.conn_name.domain_name and isinstance(self._resolver, DnsResolver): + try: + ips = await self._resolver.resolve_a_record(conn_info.conn_name.domain_name) + if ips: + ip_address = ips[0] + logger.debug( + f"['{instance_connection_string}']: Custom DNS name " + f"'{conn_info.conn_name.domain_name}' resolved to '{ip_address}', " + "using it to connect" + ) + else: + logger.debug( + f"['{instance_connection_string}']: Custom DNS name " + f"'{conn_info.conn_name.domain_name}' resolved but returned no " + f"entries, using '{ip_address}' from instance metadata" + ) + except Exception as e: + logger.debug( + f"['{instance_connection_string}']: Custom DNS name " + f"'{conn_info.conn_name.domain_name}' did not resolve to an IP " + f"address: {e}, using '{ip_address}' from instance metadata" + ) + + logger.debug(f"['{conn_info.conn_name}']: Connecting to {ip_address}:3307") + # format `user` param for automatic IAM database authn + if enable_iam_auth: + formatted_user = format_database_user( + conn_info.database_version, kwargs["user"] + ) + if formatted_user != kwargs["user"]: + logger.debug( + f"['{instance_connection_string}']: Truncated IAM database username from {kwargs['user']} to {formatted_user}" + ) + kwargs["user"] = formatted_user + + if driver in ASYNC_DRIVERS: + return await connector( + ip_address, + await conn_info.create_ssl_context(enable_iam_auth), + **kwargs, + ) + ctx = await conn_info.create_ssl_context(enable_iam_auth) + ssl_sock = ctx.wrap_socket( + socket.create_connection((ip_address, SERVER_PROXY_PORT)), + server_hostname=ip_address, ) - kwargs["user"] = formatted_user - try: - # async drivers are unblocking and can be awaited directly - if driver in ASYNC_DRIVERS: - return await connector( + if conn_info.conn_name.domain_name: + monitored_cache.sockets.append(ssl_sock) + connect_partial = partial( + connector, ip_address, - await conn_info.create_ssl_context(enable_iam_auth), + ssl_sock, **kwargs, ) - # Create socket with SSLContext for sync drivers - ctx = await conn_info.create_ssl_context(enable_iam_auth) - sock = ctx.wrap_socket( - socket.create_connection((ip_address, SERVER_PROXY_PORT)), - server_hostname=ip_address, - ) - # If this connection was opened using a domain name, then store it - # for later in case we need to forcibly close it on failover. - if conn_info.conn_name.domain_name: - monitored_cache.sockets.append(sock) - # Synchronous drivers are blocking and run using executor - connect_partial = partial( - connector, - ip_address, - sock, - **kwargs, - ) - return await self._loop.run_in_executor(None, connect_partial) + return await self._loop.run_in_executor(None, connect_partial) except Exception: # with any exception, we attempt a force refresh, then throw the error - await monitored_cache.force_refresh() + cache = self._cache.get((str(conn_name), enable_iam_auth)) + if cache: + await cache.force_refresh() raise async def _remove_cached( @@ -505,8 +589,11 @@ def close(self) -> None: close_future = asyncio.run_coroutine_threadsafe( self.close_async(), loop=self._loop ) - # Will attempt to safely shut down tasks for 3s - close_future.result(timeout=3) + try: + # Will attempt to safely shut down tasks for 3s + close_future.result(timeout=3) + except Exception as e: + logger.error(f"Error during close_async: {e}") # if background thread exists for Connector, clean it up if self._thread: if self._loop.is_running(): @@ -521,7 +608,10 @@ async def close_async(self) -> None: self._closed = True if self._client: await self._client.close() - await asyncio.gather(*[cache.close() for cache in self._cache.values()]) + await asyncio.gather( + *[cache.close() for cache in self._cache.values()], + *[client.close() for client in self._sqldata_clients], + ) async def create_async_connector( diff --git a/google/cloud/sql/connector/enums.py b/google/cloud/sql/connector/enums.py index e6b56af0e..f936dba84 100644 --- a/google/cloud/sql/connector/enums.py +++ b/google/cloud/sql/connector/enums.py @@ -41,6 +41,7 @@ class IPTypes(Enum): PUBLIC = "PRIMARY" PRIVATE = "PRIVATE" PSC = "PSC" + SQL_DATA = "SQL_DATA" @classmethod def _missing_(cls, value: object) -> None: diff --git a/google/cloud/sql/connector/monitored_cache.py b/google/cloud/sql/connector/monitored_cache.py index 0c3fc4d03..79a77aeda 100644 --- a/google/cloud/sql/connector/monitored_cache.py +++ b/google/cloud/sql/connector/monitored_cache.py @@ -14,7 +14,7 @@ import asyncio import logging -import ssl +import socket from typing import Any, Callable, Optional, Union from google.cloud.sql.connector.connection_info import ConnectionInfo @@ -38,7 +38,7 @@ def __init__( self.resolver = resolver self.cache = cache self.domain_name_ticker: Optional[asyncio.Task] = None - self.sockets: list[ssl.SSLSocket] = [] + self.sockets: list[socket.socket] = [] # If domain name is configured for instance and failover period is set, # poll for DNS record changes. @@ -62,11 +62,11 @@ def _purge_closed_sockets(self) -> None: list of sockets. """ open_sockets = [] - for socket in self.sockets: + for sock in self.sockets: # Check fileno for if socket is closed. Will return # -1 on failure, which will be used to signal socket closed. - if socket.fileno() != -1: - open_sockets.append(socket) + if sock.fileno() != -1: + open_sockets.append(sock) self.sockets = open_sockets async def _check_domain_name(self) -> None: @@ -128,11 +128,11 @@ async def close(self) -> None: await self.cache.close() # Close any still open sockets - for socket in self.sockets: + for sock in self.sockets: # Check fileno for if socket is closed. Will return # -1 on failure, which will be used to signal socket closed. - if socket.fileno() != -1: - socket.close() + if sock.fileno() != -1: + sock.close() async def ticker(interval: int, function: Callable, *args: Any, **kwargs: Any) -> None: diff --git a/google/cloud/sql/connector/proto/google/rpc/code.proto b/google/cloud/sql/connector/proto/google/rpc/code.proto new file mode 100644 index 000000000..8fef41170 --- /dev/null +++ b/google/cloud/sql/connector/proto/google/rpc/code.proto @@ -0,0 +1,186 @@ +// Copyright 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto3"; + +package google.rpc; + +option go_package = "google.golang.org/genproto/googleapis/rpc/code;code"; +option java_multiple_files = true; +option java_outer_classname = "CodeProto"; +option java_package = "com.google.rpc"; +option objc_class_prefix = "RPC"; + + +// The canonical error codes for Google APIs. +// +// +// Sometimes multiple error codes may apply. Services should return +// the most specific error code that applies. For example, prefer +// `OUT_OF_RANGE` over `FAILED_PRECONDITION` if both codes apply. +// Similarly prefer `NOT_FOUND` or `ALREADY_EXISTS` over `FAILED_PRECONDITION`. +enum Code { + // Not an error; returned on success + // + // HTTP Mapping: 200 OK + OK = 0; + + // The operation was cancelled, typically by the caller. + // + // HTTP Mapping: 499 Client Closed Request + CANCELLED = 1; + + // Unknown error. For example, this error may be returned when + // a `Status` value received from another address space belongs to + // an error space that is not known in this address space. Also + // errors raised by APIs that do not return enough error information + // may be converted to this error. + // + // HTTP Mapping: 500 Internal Server Error + UNKNOWN = 2; + + // The client specified an invalid argument. Note that this differs + // from `FAILED_PRECONDITION`. `INVALID_ARGUMENT` indicates arguments + // that are problematic regardless of the state of the system + // (e.g., a malformed file name). + // + // HTTP Mapping: 400 Bad Request + INVALID_ARGUMENT = 3; + + // The deadline expired before the operation could complete. For operations + // that change the state of the system, this error may be returned + // even if the operation has completed successfully. For example, a + // successful response from a server could have been delayed long + // enough for the deadline to expire. + // + // HTTP Mapping: 504 Gateway Timeout + DEADLINE_EXCEEDED = 4; + + // Some requested entity (e.g., file or directory) was not found. + // + // Note to server developers: if a request is denied for an entire class + // of users, such as gradual feature rollout or undocumented whitelist, + // `NOT_FOUND` may be used. If a request is denied for some users within + // a class of users, such as user-based access control, `PERMISSION_DENIED` + // must be used. + // + // HTTP Mapping: 404 Not Found + NOT_FOUND = 5; + + // The entity that a client attempted to create (e.g., file or directory) + // already exists. + // + // HTTP Mapping: 409 Conflict + ALREADY_EXISTS = 6; + + // The caller does not have permission to execute the specified + // operation. `PERMISSION_DENIED` must not be used for rejections + // caused by exhausting some resource (use `RESOURCE_EXHAUSTED` + // instead for those errors). `PERMISSION_DENIED` must not be + // used if the caller can not be identified (use `UNAUTHENTICATED` + // instead for those errors). This error code does not imply the + // request is valid or the requested entity exists or satisfies + // other pre-conditions. + // + // HTTP Mapping: 403 Forbidden + PERMISSION_DENIED = 7; + + // The request does not have valid authentication credentials for the + // operation. + // + // HTTP Mapping: 401 Unauthorized + UNAUTHENTICATED = 16; + + // Some resource has been exhausted, perhaps a per-user quota, or + // perhaps the entire file system is out of space. + // + // HTTP Mapping: 429 Too Many Requests + RESOURCE_EXHAUSTED = 8; + + // The operation was rejected because the system is not in a state + // required for the operation's execution. For example, the directory + // to be deleted is non-empty, an rmdir operation is applied to + // a non-directory, etc. + // + // Service implementors can use the following guidelines to decide + // between `FAILED_PRECONDITION`, `ABORTED`, and `UNAVAILABLE`: + // (a) Use `UNAVAILABLE` if the client can retry just the failing call. + // (b) Use `ABORTED` if the client should retry at a higher level + // (e.g., when a client-specified test-and-set fails, indicating the + // client should restart a read-modify-write sequence). + // (c) Use `FAILED_PRECONDITION` if the client should not retry until + // the system state has been explicitly fixed. E.g., if an "rmdir" + // fails because the directory is non-empty, `FAILED_PRECONDITION` + // should be returned since the client should not retry unless + // the files are deleted from the directory. + // + // HTTP Mapping: 400 Bad Request + FAILED_PRECONDITION = 9; + + // The operation was aborted, typically due to a concurrency issue such as + // a sequencer check failure or transaction abort. + // + // See the guidelines above for deciding between `FAILED_PRECONDITION`, + // `ABORTED`, and `UNAVAILABLE`. + // + // HTTP Mapping: 409 Conflict + ABORTED = 10; + + // The operation was attempted past the valid range. E.g., seeking or + // reading past end-of-file. + // + // Unlike `INVALID_ARGUMENT`, this error indicates a problem that may + // be fixed if the system state changes. For example, a 32-bit file + // system will generate `INVALID_ARGUMENT` if asked to read at an + // offset that is not in the range [0,2^32-1], but it will generate + // `OUT_OF_RANGE` if asked to read from an offset past the current + // file size. + // + // There is a fair bit of overlap between `FAILED_PRECONDITION` and + // `OUT_OF_RANGE`. We recommend using `OUT_OF_RANGE` (the more specific + // error) when it applies so that callers who are iterating through + // a space can easily look for an `OUT_OF_RANGE` error to detect when + // they are done. + // + // HTTP Mapping: 400 Bad Request + OUT_OF_RANGE = 11; + + // The operation is not implemented or is not supported/enabled in this + // service. + // + // HTTP Mapping: 501 Not Implemented + UNIMPLEMENTED = 12; + + // Internal errors. This means that some invariants expected by the + // underlying system have been broken. This error code is reserved + // for serious errors. + // + // HTTP Mapping: 500 Internal Server Error + INTERNAL = 13; + + // The service is currently unavailable. This is most likely a + // transient condition, which can be corrected by retrying with + // a backoff. + // + // See the guidelines above for deciding between `FAILED_PRECONDITION`, + // `ABORTED`, and `UNAVAILABLE`. + // + // HTTP Mapping: 503 Service Unavailable + UNAVAILABLE = 14; + + // Unrecoverable data loss or corruption. + // + // HTTP Mapping: 500 Internal Server Error + DATA_LOSS = 15; +} diff --git a/google/cloud/sql/connector/proto/google/rpc/error_details.proto b/google/cloud/sql/connector/proto/google/rpc/error_details.proto new file mode 100644 index 000000000..f24ae0099 --- /dev/null +++ b/google/cloud/sql/connector/proto/google/rpc/error_details.proto @@ -0,0 +1,200 @@ +// Copyright 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto3"; + +package google.rpc; + +import "google/protobuf/duration.proto"; + +option go_package = "google.golang.org/genproto/googleapis/rpc/errdetails;errdetails"; +option java_multiple_files = true; +option java_outer_classname = "ErrorDetailsProto"; +option java_package = "com.google.rpc"; +option objc_class_prefix = "RPC"; + + +// Describes when the clients can retry a failed request. Clients could ignore +// the recommendation here or retry when this information is missing from error +// responses. +// +// It's always recommended that clients should use exponential backoff when +// retrying. +// +// Clients should wait until `retry_delay` amount of time has passed since +// receiving the error response before retrying. If retrying requests also +// fail, clients should use an exponential backoff scheme to gradually increase +// the delay between retries based on `retry_delay`, until either a maximum +// number of retires have been reached or a maximum retry delay cap has been +// reached. +message RetryInfo { + // Clients should wait at least this long between retrying the same request. + google.protobuf.Duration retry_delay = 1; +} + +// Describes additional debugging info. +message DebugInfo { + // The stack trace entries indicating where the error occurred. + repeated string stack_entries = 1; + + // Additional debugging information provided by the server. + string detail = 2; +} + +// Describes how a quota check failed. +// +// For example if a daily limit was exceeded for the calling project, +// a service could respond with a QuotaFailure detail containing the project +// id and the description of the quota limit that was exceeded. If the +// calling project hasn't enabled the service in the developer console, then +// a service could respond with the project id and set `service_disabled` +// to true. +// +// Also see RetryDetail and Help types for other details about handling a +// quota failure. +message QuotaFailure { + // A message type used to describe a single quota violation. For example, a + // daily quota or a custom quota that was exceeded. + message Violation { + // The subject on which the quota check failed. + // For example, "clientip:" or "project:". + string subject = 1; + + // A description of how the quota check failed. Clients can use this + // description to find more about the quota configuration in the service's + // public documentation, or find the relevant quota limit to adjust through + // developer console. + // + // For example: "Service disabled" or "Daily Limit for read operations + // exceeded". + string description = 2; + } + + // Describes all quota violations. + repeated Violation violations = 1; +} + +// Describes what preconditions have failed. +// +// For example, if an RPC failed because it required the Terms of Service to be +// acknowledged, it could list the terms of service violation in the +// PreconditionFailure message. +message PreconditionFailure { + // A message type used to describe a single precondition failure. + message Violation { + // The type of PreconditionFailure. We recommend using a service-specific + // enum type to define the supported precondition violation types. For + // example, "TOS" for "Terms of Service violation". + string type = 1; + + // The subject, relative to the type, that failed. + // For example, "google.com/cloud" relative to the "TOS" type would + // indicate which terms of service is being referenced. + string subject = 2; + + // A description of how the precondition failed. Developers can use this + // description to understand how to fix the failure. + // + // For example: "Terms of service not accepted". + string description = 3; + } + + // Describes all precondition violations. + repeated Violation violations = 1; +} + +// Describes violations in a client request. This error type focuses on the +// syntactic aspects of the request. +message BadRequest { + // A message type used to describe a single bad request field. + message FieldViolation { + // A path leading to a field in the request body. The value will be a + // sequence of dot-separated identifiers that identify a protocol buffer + // field. E.g., "field_violations.field" would identify this field. + string field = 1; + + // A description of why the request element is bad. + string description = 2; + } + + // Describes all violations in a client request. + repeated FieldViolation field_violations = 1; +} + +// Contains metadata about the request that clients can attach when filing a bug +// or providing other forms of feedback. +message RequestInfo { + // An opaque string that should only be interpreted by the service generating + // it. For example, it can be used to identify requests in the service's logs. + string request_id = 1; + + // Any data that was used to serve this request. For example, an encrypted + // stack trace that can be sent back to the service provider for debugging. + string serving_data = 2; +} + +// Describes the resource that is being accessed. +message ResourceInfo { + // A name for the type of resource being accessed, e.g. "sql table", + // "cloud storage bucket", "file", "Google calendar"; or the type URL + // of the resource: e.g. "type.googleapis.com/google.pubsub.v1.Topic". + string resource_type = 1; + + // The name of the resource being accessed. For example, a shared calendar + // name: "example.com_4fghdhgsrgh@group.calendar.google.com", if the current + // error is [google.rpc.Code.PERMISSION_DENIED][google.rpc.Code.PERMISSION_DENIED]. + string resource_name = 2; + + // The owner of the resource (optional). + // For example, "user:" or "project:". + string owner = 3; + + // Describes what error is encountered when accessing this resource. + // For example, updating a cloud project may require the `writer` permission + // on the developer console project. + string description = 4; +} + +// Provides links to documentation or for performing an out of band action. +// +// For example, if a quota check failed with an error indicating the calling +// project hasn't enabled the accessed service, this can contain a URL pointing +// directly to the right place in the developer console to flip the bit. +message Help { + // Describes a URL link. + message Link { + // Describes what the link offers. + string description = 1; + + // The URL of the link. + string url = 2; + } + + // URL(s) pointing to additional information on handling the current error. + repeated Link links = 1; +} + +// Provides a localized error message that is safe to return to the user +// which can be attached to an RPC error. +message LocalizedMessage { + // The locale used following the specification defined at + // http://www.rfc-editor.org/rfc/bcp/bcp47.txt. + // Examples are: "en-US", "fr-CH", "es-MX" + string locale = 1; + + // The localized error message in the above locale. + string message = 2; +} diff --git a/google/cloud/sql/connector/proto/google/rpc/status.proto b/google/cloud/sql/connector/proto/google/rpc/status.proto new file mode 100644 index 000000000..0839ee966 --- /dev/null +++ b/google/cloud/sql/connector/proto/google/rpc/status.proto @@ -0,0 +1,92 @@ +// Copyright 2017 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto3"; + +package google.rpc; + +import "google/protobuf/any.proto"; + +option go_package = "google.golang.org/genproto/googleapis/rpc/status;status"; +option java_multiple_files = true; +option java_outer_classname = "StatusProto"; +option java_package = "com.google.rpc"; +option objc_class_prefix = "RPC"; + + +// The `Status` type defines a logical error model that is suitable for different +// programming environments, including REST APIs and RPC APIs. It is used by +// [gRPC](https://github.com/grpc). The error model is designed to be: +// +// - Simple to use and understand for most users +// - Flexible enough to meet unexpected needs +// +// # Overview +// +// The `Status` message contains three pieces of data: error code, error message, +// and error details. The error code should be an enum value of +// [google.rpc.Code][google.rpc.Code], but it may accept additional error codes if needed. The +// error message should be a developer-facing English message that helps +// developers *understand* and *resolve* the error. If a localized user-facing +// error message is needed, put the localized message in the error details or +// localize it in the client. The optional error details may contain arbitrary +// information about the error. There is a predefined set of error detail types +// in the package `google.rpc` that can be used for common error conditions. +// +// # Language mapping +// +// The `Status` message is the logical representation of the error model, but it +// is not necessarily the actual wire format. When the `Status` message is +// exposed in different client libraries and different wire protocols, it can be +// mapped differently. For example, it will likely be mapped to some exceptions +// in Java, but more likely mapped to some error codes in C. +// +// # Other uses +// +// The error model and the `Status` message can be used in a variety of +// environments, either with or without APIs, to provide a +// consistent developer experience across different environments. +// +// Example uses of this error model include: +// +// - Partial errors. If a service needs to return partial errors to the client, +// it may embed the `Status` in the normal response to indicate the partial +// errors. +// +// - Workflow errors. A typical workflow has multiple steps. Each step may +// have a `Status` message for error reporting. +// +// - Batch operations. If a client uses batch request and batch response, the +// `Status` message should be used directly inside batch response, one for +// each error sub-response. +// +// - Asynchronous operations. If an API call embeds asynchronous operation +// results in its response, the status of those operations should be +// represented directly using the `Status` message. +// +// - Logging. If some API errors are stored in logs, the message `Status` could +// be used directly after any stripping needed for security/privacy reasons. +message Status { + // The status code, which should be an enum value of [google.rpc.Code][google.rpc.Code]. + int32 code = 1; + + // A developer-facing error message, which should be in English. Any + // user-facing error message should be localized and sent in the + // [google.rpc.Status.details][google.rpc.Status.details] field, or localized by the client. + string message = 2; + + // A list of messages that carry the error details. There is a common set of + // message types for APIs to use. + repeated google.protobuf.Any details = 3; +} diff --git a/google/cloud/sql/connector/proto/sql_data_service.proto b/google/cloud/sql/connector/proto/sql_data_service.proto new file mode 100644 index 000000000..98d688cd3 --- /dev/null +++ b/google/cloud/sql/connector/proto/sql_data_service.proto @@ -0,0 +1,264 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto3"; + +package google.cloud.sql.v1beta4; + +option go_package = "internal/sqldata"; +option java_package = "com.google.cloud.sql.v1beta4"; +option java_outer_classname = "CloudSqlDataProto"; +option java_multiple_files = true; + +import "google/rpc/status.proto"; + +// Service for streaming data to and from Cloud SQL instances. +service SqlDataService { + // `StreamSqlData` establishes a bidirectional stream to a Cloud SQL instance, + // and then streams data to and from the instance. + // + // The first message from the client MUST be a `StreamSqlDataRequest` request + // with configuration settings, including required values for the + // `connection_settings` field. Subsequent messages from the client may + // contain the `payload` field. + // + // Messages from the server may contain the `payload` field. + // + // The `payload` fields of the request and response streams contain the raw + // data of the database's native wire protocol (e.g., PostgreSQL wire + // protocol). The database client is responsible for generating and parsing + // this data. + // + // Any errors on initial connection (e.g., connection failure, authorization + // issues, network problems) will result in the stream being terminated with + // an appropriate RPC status exception. + // + // After a successful connection is made, if an error occurs, then the server + // terminates connection and returns the appropriate RPC status exception. + rpc StreamSqlData(stream StreamSqlDataRequest) + returns (stream StreamSqlDataResponse) {} +} + +// Message sent from the client to `SqlDataService`. +message StreamSqlDataRequest { + // Deprecated: Use `StartSession.location_id` or `ContinueSession.location_id` + // instead. `location_id` is used to route the request to a specific region. + // Use the same region which was used to create the instance. Use the format + // `locations/{location}`, for example: `locations/us-central1`. + string location_id = 1; + + // Deprecated: Use the `message` oneof instead. The type of message sent + // within the stream. + oneof message_type { + // Deprecated: Use `start_session` or `continue_session` instead. + // Parameters for establishing the connection. MUST be sent as the first + // message on the stream. + ConnectionSettings connection_settings = 2; + + // Deprecated: Use `DataPacket` instead. + // Data to be forwarded to the database. + ClientPayload payload = 3; + } + + // Acknowledges data received by the client. + Ack ack = 4; + + // The message to the server. + oneof message { + // Starts a new session. When starting a new session, this is the first + // message the client sends. + StartSession start_session = 5; + // Continues an existing session. When starting a new session, this is the + // first message the client sends. + ContinueSession continue_session = 6; + // Database data. + DataPacket data = 7; + // Terminates the session. This closes the connection to the database. + TerminateSession terminate_session = 8; + } +} + +// Deprecated: New schema structure. Initial connection parameters. +message ConnectionSettings { + option deprecated = true; + + // The target of the connection. + oneof target { + // The identifier of the Cloud SQL instance. + InstanceId instance_id = 1; + } +} + +// Start a new session. The client must send this as the first message to the +// server to start a new session. The client may immediately send Data messages +// without waiting for a reply from the server. +message StartSession { + //`location_id` is used to route the + // request to a specific region. Use the same region which was used to create + // the instance. Use the format `locations/{location}`, for example: + // `locations/us-central1`. + string location_id = 1; + // The Cloud SQL instance resource name, for example: + // projects/example-project/instances/example-instance + string instance_id = 2; + // The session id, chosen by the client. This should be an unguessable string. + // If the client does not intend to reconnect to this session, the client may + // leave session_id unset. + string session_id = 3; +} + +// Reconnects to an existing session. The client must send this as the first +// message to the server to reconnect to an existing session. The client may +// immediately send Data messages without waiting for a reply from the server. +message ContinueSession { + //`location_id` is used to route the + // request to a specific region. Use the same region which was used to create + // the instance. Use the format `locations/{location}`, for example: + // `locations/us-central1`. + string location_id = 1; + + // The Cloud SQL instance resource name, for example: + // projects/example-project/instances/example-instance + string instance_id = 2; + + // The id of the session to reconnect. + string session_id = 3; +} + +// Deprecated: New schema structure. The identifier of the Cloud SQL instance. +message InstanceId { + option deprecated = true; + + // Full resource name of the Cloud SQL instance, in the form: + // `projects/{project}/instances/{instance}`, for example: + // `projects/foo-project/instances/bar-instance`. + string instance = 1; +} + +// Deprecated: New schema structure. Wrapper for data being sent to the +// database. +message ClientPayload { + option deprecated = true; + + // Raw data to be sent to the database. See the documentation for + // `StreamSqlData` for details on the expected wire format. + bytes data = 1; +} + +// Message sent from SqlDataService back to the client. +message StreamSqlDataResponse { + // Deprecated: New schema structure. The type of the message received from + // `SqlDataService`. + oneof type { + // Raw data received from the database. + ServerPayload payload = 1; + } + + // Acknowledges data received by the server. + Ack ack = 2; + // A message from the server to the client. + oneof message { + // The first message from the server to the client, containing metadata + // about this session. + SessionMetadata session_metadata = 3; + // Data from the database. + DataPacket data = 4; + // Terminates the session. This indicates that the database connection + // is closed. When the client receives this message, it should not + // attempt to reconnect. + TerminateSession terminate_session = 5; + } +} + +// Deprecated: New schema structure. Wrapper for data being received from the +// database. +message ServerPayload { + option deprecated = true; + + // Raw data received from the database. See the documentation for + // `StreamSqlData` for details on the expected wire format. + bytes data = 1; +} +// Metadata from the server to the client about the session. The server will +// always send this as the first message +message SessionMetadata { + // The features supported by the server for this session. This field is used + // by the client to determine which features are available on the server. + // The features supported by the server for this session. + repeated SqlDataFeature supported_features = 1; +} + +// Contains data being sent or received by the database. +message DataPacket { + // The absolute byte offset of the first byte in this payload. + // 0 for new connections or resumed connections that hasn't acked any bytes + // from server. Non-zero for resumed connections + int64 first_byte_offset = 1; + // Raw data being sent or received by the database. + bytes data = 2; +} +// Acknowledges data received by the client or server. +message Ack { + // The absolute number of bytes processed in the session. + int64 received_offset = 1; +} +// Indicates that the session is permanently ended. +message TerminateSession { + // The session termination status. + google.rpc.Status status = 1; +} + +// Error reasons for `StreamSqlData`. +// Typically used with standard error codes, with the error info/reason field +// set to the string representation of the enum value. +enum StreamSqlDataErrorReason { + // Indicates that the error reason is unknown. + STREAM_SQL_DATA_ERROR_REASON_UNKNOWN = 0; + + // Indicates that the operation is not supported for given instance type. + // Used with status code `google.rpc.Code.FAILED_PRECONDITION`. + STREAM_SQL_DATA_ERROR_REASON_UNSUPPORTED_INSTANCE_TYPE = 1; + + // Indicates that reconnect failed and should not be retried. + // Used with status code `google.rpc.Code.INTERNAL`. + STREAM_SQL_DATA_ERROR_REASON_RECONNECT_FAILED = 3; + + // Indicates that the database client closed its connection normally. + // Used with status code `google.rpc.Code.CANCELED`. + STREAM_SQL_DATA_ERROR_REASON_DB_CLIENT_CLOSED = 4; + + // Indicates that the database server closed its connection normally. + // Used with status code `google.rpc.Code.CANCELED`. + STREAM_SQL_DATA_ERROR_REASON_DB_SERVER_CLOSED = 5; + + // Indicates that the peer sent an ACK message that was not within an + // acceptable range. Used with the status code + // `google.rpc.Code.FAILED_PRECONDITION`. + STREAM_SQL_DATA_ERROR_REASON_INVALID_ACK = 6; + + // Indicates that the SqlDataService lost its connection to the + // database instance. This is a retryable error. + // Used with status code `google.rpc.Code.ABORTED`. + STREAM_SQL_DATA_ERROR_REASON_DISCONNECTED = 7; +} + +// The session features. The server must send the supported features in its +// first message to the client. +enum SqlDataFeature { + // The feature is not specified. This value should not be used. + SQL_DATA_FEATURE_UNSPECIFIED = 0; + // The server supports reconnecting to the session. If this feature is not + // present, the client should not attempt to reconnect to the session. + SQL_DATA_FEATURE_RECONNECT = 1; +} diff --git a/google/cloud/sql/connector/proto/sql_data_service_pb2.py b/google/cloud/sql/connector/proto/sql_data_service_pb2.py new file mode 100644 index 000000000..fb25be21d --- /dev/null +++ b/google/cloud/sql/connector/proto/sql_data_service_pb2.py @@ -0,0 +1,88 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# NO CHECKED-IN PROTOBUF GENCODE +# source: google/cloud/sql/connector/proto/sql_data_service.proto +# Protobuf Python Version: 6.33.5 +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import runtime_version as _runtime_version +from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder + +_runtime_version.ValidateProtobufRuntimeVersion( + _runtime_version.Domain.PUBLIC, + 6, + 33, + 5, + '', + 'google/cloud/sql/connector/proto/sql_data_service.proto' +) +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n7google/cloud/sql/connector/proto/sql_data_service.proto\x12\x18google.cloud.sql.v1beta4\x1a\x17google/rpc/status.proto\"\x82\x04\n\x14StreamSqlDataRequest\x12\x13\n\x0blocation_id\x18\x01 \x01(\t\x12K\n\x13\x63onnection_settings\x18\x02 \x01(\x0b\x32,.google.cloud.sql.v1beta4.ConnectionSettingsH\x00\x12:\n\x07payload\x18\x03 \x01(\x0b\x32\'.google.cloud.sql.v1beta4.ClientPayloadH\x00\x12*\n\x03\x61\x63k\x18\x04 \x01(\x0b\x32\x1d.google.cloud.sql.v1beta4.Ack\x12?\n\rstart_session\x18\x05 \x01(\x0b\x32&.google.cloud.sql.v1beta4.StartSessionH\x01\x12\x45\n\x10\x63ontinue_session\x18\x06 \x01(\x0b\x32).google.cloud.sql.v1beta4.ContinueSessionH\x01\x12\x34\n\x04\x64\x61ta\x18\x07 \x01(\x0b\x32$.google.cloud.sql.v1beta4.DataPacketH\x01\x12G\n\x11terminate_session\x18\x08 \x01(\x0b\x32*.google.cloud.sql.v1beta4.TerminateSessionH\x01\x42\x0e\n\x0cmessage_typeB\t\n\x07message\"_\n\x12\x43onnectionSettings\x12;\n\x0binstance_id\x18\x01 \x01(\x0b\x32$.google.cloud.sql.v1beta4.InstanceIdH\x00:\x02\x18\x01\x42\x08\n\x06target\"L\n\x0cStartSession\x12\x13\n\x0blocation_id\x18\x01 \x01(\t\x12\x13\n\x0binstance_id\x18\x02 \x01(\t\x12\x12\n\nsession_id\x18\x03 \x01(\t\"O\n\x0f\x43ontinueSession\x12\x13\n\x0blocation_id\x18\x01 \x01(\t\x12\x13\n\x0binstance_id\x18\x02 \x01(\t\x12\x12\n\nsession_id\x18\x03 \x01(\t\"\"\n\nInstanceId\x12\x10\n\x08instance\x18\x01 \x01(\t:\x02\x18\x01\"!\n\rClientPayload\x12\x0c\n\x04\x64\x61ta\x18\x01 \x01(\x0c:\x02\x18\x01\"\xd8\x02\n\x15StreamSqlDataResponse\x12:\n\x07payload\x18\x01 \x01(\x0b\x32\'.google.cloud.sql.v1beta4.ServerPayloadH\x00\x12*\n\x03\x61\x63k\x18\x02 \x01(\x0b\x32\x1d.google.cloud.sql.v1beta4.Ack\x12\x45\n\x10session_metadata\x18\x03 \x01(\x0b\x32).google.cloud.sql.v1beta4.SessionMetadataH\x01\x12\x34\n\x04\x64\x61ta\x18\x04 \x01(\x0b\x32$.google.cloud.sql.v1beta4.DataPacketH\x01\x12G\n\x11terminate_session\x18\x05 \x01(\x0b\x32*.google.cloud.sql.v1beta4.TerminateSessionH\x01\x42\x06\n\x04typeB\t\n\x07message\"!\n\rServerPayload\x12\x0c\n\x04\x64\x61ta\x18\x01 \x01(\x0c:\x02\x18\x01\"W\n\x0fSessionMetadata\x12\x44\n\x12supported_features\x18\x01 \x03(\x0e\x32(.google.cloud.sql.v1beta4.SqlDataFeature\"5\n\nDataPacket\x12\x19\n\x11\x66irst_byte_offset\x18\x01 \x01(\x03\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"\x1e\n\x03\x41\x63k\x12\x17\n\x0freceived_offset\x18\x01 \x01(\x03\"6\n\x10TerminateSession\x12\"\n\x06status\x18\x01 \x01(\x0b\x32\x12.google.rpc.Status*\xf6\x02\n\x18StreamSqlDataErrorReason\x12(\n$STREAM_SQL_DATA_ERROR_REASON_UNKNOWN\x10\x00\x12:\n6STREAM_SQL_DATA_ERROR_REASON_UNSUPPORTED_INSTANCE_TYPE\x10\x01\x12\x31\n-STREAM_SQL_DATA_ERROR_REASON_RECONNECT_FAILED\x10\x03\x12\x31\n-STREAM_SQL_DATA_ERROR_REASON_DB_CLIENT_CLOSED\x10\x04\x12\x31\n-STREAM_SQL_DATA_ERROR_REASON_DB_SERVER_CLOSED\x10\x05\x12,\n(STREAM_SQL_DATA_ERROR_REASON_INVALID_ACK\x10\x06\x12-\n)STREAM_SQL_DATA_ERROR_REASON_DISCONNECTED\x10\x07*R\n\x0eSqlDataFeature\x12 \n\x1cSQL_DATA_FEATURE_UNSPECIFIED\x10\x00\x12\x1e\n\x1aSQL_DATA_FEATURE_RECONNECT\x10\x01\x32\x88\x01\n\x0eSqlDataService\x12v\n\rStreamSqlData\x12..google.cloud.sql.v1beta4.StreamSqlDataRequest\x1a/.google.cloud.sql.v1beta4.StreamSqlDataResponse\"\x00(\x01\x30\x01\x42\x45\n\x1c\x63om.google.cloud.sql.v1beta4B\x11\x43loudSqlDataProtoP\x01Z\x10internal/sqldatab\x06proto3') + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'google.cloud.sql.connector.proto.sql_data_service_pb2', _globals) +if not _descriptor._USE_C_DESCRIPTORS: + _globals['DESCRIPTOR']._loaded_options = None + _globals['DESCRIPTOR']._serialized_options = b'\n\034com.google.cloud.sql.v1beta4B\021CloudSqlDataProtoP\001Z\020internal/sqldata' + _globals['_CONNECTIONSETTINGS']._loaded_options = None + _globals['_CONNECTIONSETTINGS']._serialized_options = b'\030\001' + _globals['_INSTANCEID']._loaded_options = None + _globals['_INSTANCEID']._serialized_options = b'\030\001' + _globals['_CLIENTPAYLOAD']._loaded_options = None + _globals['_CLIENTPAYLOAD']._serialized_options = b'\030\001' + _globals['_SERVERPAYLOAD']._loaded_options = None + _globals['_SERVERPAYLOAD']._serialized_options = b'\030\001' + _globals['_STREAMSQLDATAERRORREASON']._serialized_start=1569 + _globals['_STREAMSQLDATAERRORREASON']._serialized_end=1943 + _globals['_SQLDATAFEATURE']._serialized_start=1945 + _globals['_SQLDATAFEATURE']._serialized_end=2027 + _globals['_STREAMSQLDATAREQUEST']._serialized_start=111 + _globals['_STREAMSQLDATAREQUEST']._serialized_end=625 + _globals['_CONNECTIONSETTINGS']._serialized_start=627 + _globals['_CONNECTIONSETTINGS']._serialized_end=722 + _globals['_STARTSESSION']._serialized_start=724 + _globals['_STARTSESSION']._serialized_end=800 + _globals['_CONTINUESESSION']._serialized_start=802 + _globals['_CONTINUESESSION']._serialized_end=881 + _globals['_INSTANCEID']._serialized_start=883 + _globals['_INSTANCEID']._serialized_end=917 + _globals['_CLIENTPAYLOAD']._serialized_start=919 + _globals['_CLIENTPAYLOAD']._serialized_end=952 + _globals['_STREAMSQLDATARESPONSE']._serialized_start=955 + _globals['_STREAMSQLDATARESPONSE']._serialized_end=1299 + _globals['_SERVERPAYLOAD']._serialized_start=1301 + _globals['_SERVERPAYLOAD']._serialized_end=1334 + _globals['_SESSIONMETADATA']._serialized_start=1336 + _globals['_SESSIONMETADATA']._serialized_end=1423 + _globals['_DATAPACKET']._serialized_start=1425 + _globals['_DATAPACKET']._serialized_end=1478 + _globals['_ACK']._serialized_start=1480 + _globals['_ACK']._serialized_end=1510 + _globals['_TERMINATESESSION']._serialized_start=1512 + _globals['_TERMINATESESSION']._serialized_end=1566 + _globals['_SQLDATASERVICE']._serialized_start=2030 + _globals['_SQLDATASERVICE']._serialized_end=2166 +# @@protoc_insertion_point(module_scope) diff --git a/google/cloud/sql/connector/proto/sql_data_service_pb2_grpc.py b/google/cloud/sql/connector/proto/sql_data_service_pb2_grpc.py new file mode 100644 index 000000000..42240ecbd --- /dev/null +++ b/google/cloud/sql/connector/proto/sql_data_service_pb2_grpc.py @@ -0,0 +1,137 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! +"""Client and server classes corresponding to protobuf-defined services.""" + +import grpc + +from google.cloud.sql.connector.proto import ( + sql_data_service_pb2 as google_dot_cloud_dot_sql_dot_connector_dot_proto_dot_sql__data__service__pb2, +) + +GRPC_GENERATED_VERSION = '1.81.0' +GRPC_VERSION = grpc.__version__ +_version_not_supported = False + +try: + from grpc._utilities import first_version_is_lower + _version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION) +except ImportError: + _version_not_supported = True + +if _version_not_supported: + raise RuntimeError( + f'The grpc package installed is at version {GRPC_VERSION},' + + ' but the generated code in google/cloud/sql/connector/proto/sql_data_service_pb2_grpc.py depends on' + + f' grpcio>={GRPC_GENERATED_VERSION}.' + + f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}' + + f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.' + ) + + +class SqlDataServiceStub: + """Service for streaming data to and from Cloud SQL instances. + """ + + def __init__(self, channel): + """Constructor. + + Args: + channel: A grpc.Channel. + """ + self.StreamSqlData = channel.stream_stream( + '/google.cloud.sql.v1beta4.SqlDataService/StreamSqlData', + request_serializer=google_dot_cloud_dot_sql_dot_connector_dot_proto_dot_sql__data__service__pb2.StreamSqlDataRequest.SerializeToString, + response_deserializer=google_dot_cloud_dot_sql_dot_connector_dot_proto_dot_sql__data__service__pb2.StreamSqlDataResponse.FromString, + _registered_method=True) + + +class SqlDataServiceServicer: + """Service for streaming data to and from Cloud SQL instances. + """ + + def StreamSqlData(self, request_iterator, context): + """`StreamSqlData` establishes a bidirectional stream to a Cloud SQL instance, + and then streams data to and from the instance. + + The first message from the client MUST be a `StreamSqlDataRequest` request + with configuration settings, including required values for the + `connection_settings` field. Subsequent messages from the client may + contain the `payload` field. + + Messages from the server may contain the `payload` field. + + The `payload` fields of the request and response streams contain the raw + data of the database's native wire protocol (e.g., PostgreSQL wire + protocol). The database client is responsible for generating and parsing + this data. + + Any errors on initial connection (e.g., connection failure, authorization + issues, network problems) will result in the stream being terminated with + an appropriate RPC status exception. + + After a successful connection is made, if an error occurs, then the server + terminates connection and returns the appropriate RPC status exception. + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + +def add_SqlDataServiceServicer_to_server(servicer, server): + rpc_method_handlers = { + 'StreamSqlData': grpc.stream_stream_rpc_method_handler( + servicer.StreamSqlData, + request_deserializer=google_dot_cloud_dot_sql_dot_connector_dot_proto_dot_sql__data__service__pb2.StreamSqlDataRequest.FromString, + response_serializer=google_dot_cloud_dot_sql_dot_connector_dot_proto_dot_sql__data__service__pb2.StreamSqlDataResponse.SerializeToString, + ), + } + generic_handler = grpc.method_handlers_generic_handler( + 'google.cloud.sql.v1beta4.SqlDataService', rpc_method_handlers) + server.add_generic_rpc_handlers((generic_handler,)) + server.add_registered_method_handlers('google.cloud.sql.v1beta4.SqlDataService', rpc_method_handlers) + + + # This class is part of an EXPERIMENTAL API. +class SqlDataService: + """Service for streaming data to and from Cloud SQL instances. + """ + + @staticmethod + def StreamSqlData(request_iterator, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.stream_stream( + request_iterator, + target, + '/google.cloud.sql.v1beta4.SqlDataService/StreamSqlData', + google_dot_cloud_dot_sql_dot_connector_dot_proto_dot_sql__data__service__pb2.StreamSqlDataRequest.SerializeToString, + google_dot_cloud_dot_sql_dot_connector_dot_proto_dot_sql__data__service__pb2.StreamSqlDataResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) diff --git a/google/cloud/sql/connector/sqldata_client.py b/google/cloud/sql/connector/sqldata_client.py new file mode 100644 index 000000000..950373f4b --- /dev/null +++ b/google/cloud/sql/connector/sqldata_client.py @@ -0,0 +1,355 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import asyncio +import logging +import socket +from typing import Any, Callable, Optional + +from google.auth.credentials import Credentials +from google.auth.transport.grpc import AuthMetadataPlugin +from google.auth.transport.requests import Request +import grpc + +import google.rpc.status_pb2 # noqa: F401 # isort: skip +from google.cloud.sql.connector.proto import sql_data_service_pb2 # type: ignore +from google.cloud.sql.connector.proto import sql_data_service_pb2_grpc # type: ignore + +logger = logging.getLogger(__name__) + + +class SqlDataClient: + def __init__( + self, + endpoint: str, + credentials: Credentials, + quota_project: Optional[str] = None, + timeout: Optional[float] = None, + ): + self._endpoint = endpoint + self._credentials = credentials + self._quota_project = quota_project + self._timeout = timeout + + async def connect_tunnel( + self, + instance_connection_name: str, + region: str, + project: str, + get_conn_info: Callable[[], Any], + enable_iam_auth: bool, + on_fallback: Callable[[str], None], + is_fallback_cached: Callable[[str], bool], + ) -> int: + """Starts a local TCP tunnel and returns the local port. + + If the instance does not support SQL Data Service, it falls back + to a direct TLS connection. + """ + # Start local TCP server + server = await asyncio.start_server( + lambda r, w: self._handle_tunnel( + r, + w, + instance_connection_name, + region, + project, + get_conn_info, + enable_iam_auth, + on_fallback, + is_fallback_cached, + ), + "127.0.0.1", + 0, + ) + + port = server.sockets[0].getsockname()[1] + logger.debug(f"SQL Data tunnel listening on 127.0.0.1:{port}") + + # Keep reference to server to close it + self._server = server + return port + + async def close(self) -> None: + """Closes the local tunnel server if it is running.""" + if hasattr(self, "_server") and self._server: + self._server.close() + try: + await asyncio.wait_for(self._server.wait_closed(), timeout=2.0) + logger.debug("SQL Data tunnel server closed by client close()") + except asyncio.TimeoutError: + logger.warning("Timeout waiting for SQL Data tunnel server to close") + + async def _handle_tunnel( + self, + client_reader: asyncio.StreamReader, + client_writer: asyncio.StreamWriter, + instance_connection_name: str, + region: str, + project: str, + get_conn_info: Callable[[], Any], + enable_iam_auth: bool, + on_fallback: Callable[[str], None], + is_fallback_cached: Callable[[str], bool], + ): + logger.debug("Accepted local connection for SQL Data tunnel") + # Close the server so no more connections are accepted on this port + self._server.close() + + # Buffer to cache client writes for fallback replay + client_write_buffer = bytearray() + first_read_done = False + fallback_triggered = False + + # We need to share these streams between tasks + backend_reader: Optional[asyncio.StreamReader] = None + backend_writer: Optional[asyncio.StreamWriter] = None + grpc_stream: Optional[Any] = None + grpc_channel: Optional[grpc.aio.Channel] = None + + # Check if fallback is already cached + use_fallback = is_fallback_cached(instance_connection_name) + + async def connect_grpc() -> tuple[grpc.aio.Channel, Any]: + auth_request = Request() + plugin = AuthMetadataPlugin(self._credentials, auth_request) + call_creds = grpc.metadata_call_credentials(plugin) + channel_creds = grpc.composite_channel_credentials( + grpc.ssl_channel_credentials(), call_creds + ) + + endpoint = self._endpoint + if endpoint.startswith("https://"): + endpoint = endpoint[len("https://") :] + if endpoint.startswith("http://"): + endpoint = endpoint[len("http://") :] + + logger.debug(f"Creating secure channel to {endpoint}") + channel = grpc.aio.secure_channel(endpoint, channel_creds) + stub = sql_data_service_pb2_grpc.SqlDataServiceStub(channel) + + instance_id = f"projects/{project}/instances/{instance_connection_name.split(':')[-1]}" + location_id = f"locations/{region}" + + metadata = [] + quota_project_in_creds = getattr(self._credentials, "quota_project_id", None) + if self._quota_project and self._quota_project != quota_project_in_creds: + metadata.append(("x-goog-user-project", self._quota_project)) + metadata.append( + ( + "x-goog-request-params", + f"instance_id={instance_id}&location_id={location_id}", + ) + ) + + # Start stream + logger.debug(f"Starting StreamSqlData with metadata {metadata}") + stream = stub.StreamSqlData(metadata=metadata) + + # Send StartSession + start_session = sql_data_service_pb2.StartSession( # type: ignore[attr-defined] + instance_id=instance_id, location_id=location_id + ) + req = sql_data_service_pb2.StreamSqlDataRequest( # type: ignore[attr-defined] + start_session=start_session + ) + logger.debug("Writing StartSession to stream...") + await stream.write(req) + logger.debug("StartSession written successfully") + return channel, stream + + async def connect_direct() -> tuple[asyncio.StreamReader, asyncio.StreamWriter]: + logger.debug("Fallback triggered, fetching connection info...") + conn_info = await get_conn_info() + # Find a fallback IP address + fallback_ip = None + from google.cloud.sql.connector.enums import IPTypes + for t in [IPTypes.PUBLIC, IPTypes.PRIVATE, IPTypes.PSC]: + try: + fallback_ip = conn_info.get_preferred_ip(t) + break + except Exception: + continue + if not fallback_ip: + raise ValueError("Cannot fallback to direct connection: no IP address available.") + logger.debug(f"Connecting directly to {fallback_ip}:3307") + ssl_context = await conn_info.create_ssl_context(enable_iam_auth) + return await asyncio.open_connection( + fallback_ip, 3307, ssl=ssl_context, server_hostname=fallback_ip + ) + + # Initialize connection + if use_fallback: + logger.debug("Using cached fallback connection") + backend_reader, backend_writer = await connect_direct() + fallback_triggered = True + else: + try: + grpc_channel, grpc_stream = await connect_grpc() + except Exception as e: + logger.debug(f"Failed to initialize gRPC stream: {e}") + # Try fallback immediately + backend_reader, backend_writer = await connect_direct() + fallback_triggered = True + on_fallback(instance_connection_name) + + # Task to read from client and write to backend + async def client_to_backend(): + nonlocal first_read_done, fallback_triggered, backend_writer, grpc_stream + try: + while True: + data = await client_reader.read(4096) + if not data: + logger.debug("Client socket EOF") + break + + if not first_read_done and not fallback_triggered: + client_write_buffer.extend(data) + + if fallback_triggered: + if backend_writer: + backend_writer.write(data) + await backend_writer.drain() + else: + packet = sql_data_service_pb2.DataPacket(data=data) # type: ignore[attr-defined] + req = sql_data_service_pb2.StreamSqlDataRequest( # type: ignore[attr-defined] + data=packet + ) + if grpc_stream: + await grpc_stream.write(req) + except Exception as e: + logger.error(f"Error in client_to_backend: {e}") + raise + finally: + if fallback_triggered: + if backend_writer: + backend_writer.write_eof() + else: + if grpc_stream: + try: + await grpc_stream.done_writing() + except Exception: + pass + logger.debug("Client to backend task finished") + + # Task to read from backend and write to client + async def backend_to_client(): + nonlocal first_read_done, fallback_triggered, backend_reader, backend_writer, grpc_stream, grpc_channel + try: + if fallback_triggered: + # If we started with fallback, just copy + while True: + if not backend_reader: + break + data = await backend_reader.read(4096) + if not data: + break + client_writer.write(data) + await client_writer.drain() + else: + # gRPC read loop + try: + if not grpc_stream: + return + async for resp in grpc_stream: + first_read_done = True + msg_type = resp.WhichOneof("message") + if msg_type == "session_metadata": + logger.debug("Received SessionMetadata") + elif msg_type == "data": + data = resp.data.data + client_writer.write(data) + await client_writer.drain() + elif msg_type == "terminate_session": + logger.debug("Received TerminateSession") + break + except grpc.aio.AioRpcError as e: + logger.debug(f"gRPC stream error: {e}") + # Check for fallback condition + if ( + not first_read_done + and e.code() == grpc.StatusCode.FAILED_PRECONDITION + ): + logger.info( + f"SQL Data Service not supported for {instance_connection_name}. " + "Falling back to direct connection." + ) + fallback_triggered = True + on_fallback(instance_connection_name) + + # Clean up gRPC + if grpc_channel: + await grpc_channel.close() + + # Connect direct + backend_reader, backend_writer = await connect_direct() + + # Replay buffered client data + if client_write_buffer: + logger.debug(f"Replaying {len(client_write_buffer)} bytes to fallback connection") + backend_writer.write(bytes(client_write_buffer)) + await backend_writer.drain() + + # Start copying from direct connection + while True: + data = await backend_reader.read(4096) + if not data: + break + client_writer.write(data) + await client_writer.drain() + else: + # Other gRPC error, re-raise to close connection + raise + except Exception as e: + logger.error(f"Error in backend_to_client: {e}") + raise + finally: + client_writer.close() + try: + await client_writer.wait_closed() + except Exception: + pass + if fallback_triggered and backend_writer: + backend_writer.close() + try: + await backend_writer.wait_closed() + except Exception: + pass + elif grpc_channel: + await grpc_channel.close() + logger.debug("Backend to client task finished") + + # Run both tasks + try: + await asyncio.gather(client_to_backend(), backend_to_client()) + finally: + logger.debug("Closing client socket in _handle_tunnel finally") + try: + client_writer.close() + sock = client_writer.get_extra_info('socket') + if sock: + sock.close() + except Exception as e: + logger.debug(f"Error closing client writer: {e}") + logger.debug("SQL Data tunnel handler finished") + + +class FallbackSocket(socket.socket): + def connect(self, *args: Any, **kwargs: Any) -> None: + # Already connected, do nothing. + # This is needed because some drivers (like pymysql) try to call connect() + # internally even if passed an already connected socket. + pass diff --git a/pyproject.toml b/pyproject.toml index cbf0dd10f..4a2894a51 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,6 +47,9 @@ dependencies = [ "dnspython>=2.0.0", "Requests", "google-auth>=2.28.0", + "grpcio", + "protobuf", + "googleapis-common-protos", ] dynamic = ["version"] diff --git a/requirements-test.txt b/requirements-test.txt index 296878dd8..e1a150b21 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -10,4 +10,5 @@ pg8000==1.31.5 asyncpg==0.31.0 python-tds==1.17.1 aioresponses==0.7.8 -pytest-aiohttp==1.1.0 +pytest-aiohttp<1.1.0 +aiohttp==3.10.11 diff --git a/tests/system/test_sqldata_connection.py b/tests/system/test_sqldata_connection.py new file mode 100644 index 000000000..15a9fe425 --- /dev/null +++ b/tests/system/test_sqldata_connection.py @@ -0,0 +1,104 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import os + +import pytest + +from google.cloud.sql.connector import Connector +from google.cloud.sql.connector import IPTypes + +# AI Developer Edition instance connection details +DB_USER = "postgres" +DB_NAME = "postgres" + +# Sandbox endpoints +ADMIN_ENDPOINT = "https://coreltest-sqladmin.mtls.sandbox.googleapis.com" +SQL_DATA_ENDPOINT = "coreltest-sqladmin.mtls.sandbox.googleapis.com:443" + + +@pytest.fixture(name="config") +def config_fixture(): + conn_name = os.environ.get("SQL_DATA_CONNECTION_NAME") + quota_project = os.environ.get("SQL_DATA_PROJECT") + + if not conn_name: + pytest.skip("SQL_DATA_CONNECTION_NAME env var not set") + + password = os.environ.get("POSTGRES_CUSTOMER_CAS_PASS") + if not password: + pytest.skip("POSTGRES_CUSTOMER_CAS_PASS env var not set") + + return { + "conn_name": conn_name, + "quota_project": quota_project, + "password": password, + } + + +@pytest.mark.asyncio +async def test_asyncpg_sqldata_connect(config): + loop = asyncio.get_running_loop() + connector = Connector( + loop=loop, + sqladmin_api_endpoint=ADMIN_ENDPOINT, + sql_data_endpoint=SQL_DATA_ENDPOINT, + quota_project=config["quota_project"], + ) + + conn = None + try: + conn = await connector.connect_async( + config["conn_name"], + "asyncpg", + user=DB_USER, + password=config["password"], + db=DB_NAME, + ip_type=IPTypes.SQL_DATA, + ) + val = await conn.fetchval("SELECT NOW()") + assert val is not None + finally: + if conn: + await conn.close() + await connector.close_async() + + +def test_pg8000_sqldata_connect(config): + connector = Connector( + sqladmin_api_endpoint=ADMIN_ENDPOINT, + sql_data_endpoint=SQL_DATA_ENDPOINT, + quota_project=config["quota_project"], + ) + + conn = None + try: + conn = connector.connect( + config["conn_name"], + "pg8000", + user=DB_USER, + password=config["password"], + db=DB_NAME, + ip_type=IPTypes.SQL_DATA, + ) + cursor = conn.cursor() + cursor.execute("SELECT NOW()") + val = cursor.fetchone() + assert val is not None + cursor.close() + finally: + if conn: + conn.close() + connector.close() diff --git a/tests/unit/test_connector.py b/tests/unit/test_connector.py index a09b5b72f..ac1288da2 100644 --- a/tests/unit/test_connector.py +++ b/tests/unit/test_connector.py @@ -234,7 +234,7 @@ def test_Connector_Init_bad_ip_type(fake_credentials: Credentials) -> None: assert ( exc_info.value.args[0] == f"Incorrect value for ip_type, got '{bad_ip_type.upper()}'. " - "Want one of: 'PRIMARY', 'PRIVATE', 'PSC', 'PUBLIC'." + "Want one of: 'PRIMARY', 'PRIVATE', 'PSC', 'SQL_DATA', 'PUBLIC'." ) @@ -257,7 +257,7 @@ def test_Connector_connect_bad_ip_type( assert ( exc_info.value.args[0] == f"Incorrect value for ip_type, got '{bad_ip_type.upper()}'. " - "Want one of: 'PRIMARY', 'PRIVATE', 'PSC', 'PUBLIC'." + "Want one of: 'PRIMARY', 'PRIVATE', 'PSC', 'SQL_DATA', 'PUBLIC'." )