Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,11 @@ function write_e2e_env(){

}

## with_venv - runs a command with the venv activated
function with_venv() {
"$@"
}

## help - prints the help details
##
function help() {
Expand Down
29 changes: 16 additions & 13 deletions google/cloud/sql/connector/asyncpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
"""

import ssl
from typing import Any, TYPE_CHECKING
from typing import Any, Optional, TYPE_CHECKING

SERVER_PROXY_PORT = 3307

Expand All @@ -24,16 +24,15 @@


async def connect(
ip_address: str, ctx: ssl.SSLContext, **kwargs: Any
ip_address: str, ctx: Optional[ssl.SSLContext], **kwargs: Any
) -> "asyncpg.Connection":
"""Helper function to create an asyncpg DB-API connection object.

Args:
ip_address (str): A string containing an IP address for the Cloud SQL
instance.
ctx (ssl.SSLContext): An SSLContext object created from the Cloud SQL
server CA cert and ephemeral cert.
server CA cert and ephemeral cert.
server CA cert and ephemeral cert. Pass None to disable SSL.
kwargs: Keyword arguments for establishing asyncpg connection
object to Cloud SQL instance.

Expand All @@ -53,14 +52,18 @@ async def connect(
user = kwargs.pop("user")
db = kwargs.pop("db")
passwd = kwargs.pop("password", None)
port = kwargs.pop("port", SERVER_PROXY_PORT)

return await asyncpg.connect(
user=user,
database=db,
password=passwd,
host=ip_address,
port=SERVER_PROXY_PORT,
ssl=ctx,
direct_tls=True,
connect_args = {
"user": user,
"database": db,
"password": passwd,
"host": ip_address,
"port": port,
**kwargs,
)
}
if ctx is not None:
connect_args["ssl"] = ctx
connect_args["direct_tls"] = True

return await asyncpg.connect(**connect_args)
12 changes: 10 additions & 2 deletions google/cloud/sql/connector/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,9 +171,13 @@ async def _get_metadata(
if dns_name:
ip_addresses["PSC"] = dns_name.rstrip(".")

server_ca_cert = None
if "serverCaCert" in ret_dict and "cert" in ret_dict["serverCaCert"]:
server_ca_cert = ret_dict["serverCaCert"]["cert"]

return {
"ip_addresses": ip_addresses,
"server_ca_cert": ret_dict["serverCaCert"]["cert"],
"server_ca_cert": server_ca_cert,
"database_version": ret_dict["databaseVersion"],
}

Expand Down Expand Up @@ -228,7 +232,11 @@ async def _get_ephemeral(
finally:
resp.raise_for_status()

ephemeral_cert: str = ret_dict["ephemeralCert"]["cert"]
try:
ephemeral_cert: str = ret_dict["ephemeralCert"]["cert"]
except KeyError as e:
logger.error(f"KeyError in _get_ephemeral parsing generateEphemeralCert: {e}. Response dict: {ret_dict}")
raise

# decode cert to read expiration
x509 = load_pem_x509_certificate(
Expand Down
6 changes: 5 additions & 1 deletion google/cloud/sql/connector/connection_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ class ConnectionInfo:

conn_name: ConnectionName
client_cert: str
server_ca_cert: str
server_ca_cert: Optional[str]
private_key: bytes
ip_addrs: dict[str, Any]
database_version: str
Expand All @@ -79,6 +79,10 @@ async def create_ssl_context(self, enable_iam_auth: bool = False) -> ssl.SSLCont
# if SSL context is cached, use it
if self.context is not None:
return self.context

if self.server_ca_cert is None:
raise ValueError("Cannot create SSL context: server CA certificate is missing.")

context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)

# update ssl.PROTOCOL_TLS_CLIENT default
Expand Down
Loading
Loading