Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 118 additions & 1 deletion src/test/pytest/libpq/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,121 @@ class PqConnectionError(LibpqError):


class QueryError(LibpqError):
"""Raised by the *_safe query helpers when a statement fails."""
"""Raised by the *_safe query helpers when a statement fails.

``sqlstate`` carries the five-character SQLSTATE from libpq when available
(None otherwise); ``sqlstate_class`` is its first two characters. These are
the stable, locale-independent way to assert on a specific error condition,
rather than matching against the human-readable message text.
"""

def __init__(self, message, *, sqlstate=None):
super().__init__(message)
self.sqlstate = sqlstate

@property
def sqlstate_class(self):
"""The two-character SQLSTATE class, or None if no SQLSTATE is set."""
if self.sqlstate and len(self.sqlstate) >= 2:
return self.sqlstate[:2]
return None


# Named QueryError subclasses for the SQLSTATEs tests most often assert on, so a
# test can write ``with pytest.raises(QueryCanceled):`` instead of catching the
# generic QueryError and then checking ``.sqlstate``. Each maps to its
# five-character SQLSTATE; query_error_for() picks the right class when raising.
class SyntaxErrorState(QueryError):
"""42601 -- syntax_error."""


class UndefinedTable(QueryError):
"""42P01 -- undefined_table."""


class UndefinedColumn(QueryError):
"""42703 -- undefined_column."""


class InsufficientPrivilege(QueryError):
"""42501 -- insufficient_privilege."""


class UniqueViolation(QueryError):
"""23505 -- unique_violation."""


class ForeignKeyViolation(QueryError):
"""23503 -- foreign_key_violation."""


class NotNullViolation(QueryError):
"""23502 -- not_null_violation."""


class CheckViolation(QueryError):
"""23514 -- check_violation."""


class SerializationFailure(QueryError):
"""40001 -- serialization_failure."""


class DeadlockDetected(QueryError):
"""40P01 -- deadlock_detected."""


class QueryCanceled(QueryError):
"""57014 -- query_canceled."""


class AdminShutdown(QueryError):
"""57P01 -- admin_shutdown."""


class CrashShutdown(QueryError):
"""57P02 -- crash_shutdown."""


class CannotConnectNow(QueryError):
"""57P03 -- cannot_connect_now."""


class ReadOnlySqlTransaction(QueryError):
"""25006 -- read_only_sql_transaction."""


class ObjectInUse(QueryError):
"""55006 -- object_in_use."""


# SQLSTATE -> exception subclass. Anything not listed raises a plain QueryError.
_SQLSTATE_EXCEPTIONS = {
"42601": SyntaxErrorState,
"42P01": UndefinedTable,
"42703": UndefinedColumn,
"42501": InsufficientPrivilege,
"23505": UniqueViolation,
"23503": ForeignKeyViolation,
"23502": NotNullViolation,
"23514": CheckViolation,
"40001": SerializationFailure,
"40P01": DeadlockDetected,
"57014": QueryCanceled,
"57P01": AdminShutdown,
"57P02": CrashShutdown,
"57P03": CannotConnectNow,
"25006": ReadOnlySqlTransaction,
"55006": ObjectInUse,
}


def query_error_for(message, sqlstate):
"""Return a QueryError (or its SQLSTATE-specific subclass) for *sqlstate*.

Used when a statement fails so callers can match on the specific condition
(e.g. ``pytest.raises(QueryCanceled)``) while still catching the base
QueryError/LibpqError when they want any failure.
"""
cls = _SQLSTATE_EXCEPTIONS.get(sqlstate or "", QueryError)
return cls(message, sqlstate=sqlstate)
5 changes: 5 additions & 0 deletions src/test/pytest/libpq/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@

from .constants import ExecStatusType

# PG_DIAG_SQLSTATE field id from postgres_ext.h, for PQresultErrorField().
_PG_DIAG_SQLSTATE = ord("C")


def _decode(raw):
"""Decode a libpq C string (bytes or None) to str/None."""
Expand All @@ -26,6 +29,7 @@ class ResultData:

status: int
error_message: Optional[str] = None
sqlstate: Optional[str] = None
names: List[str] = field(default_factory=list)
types: List[int] = field(default_factory=list)
rows: List[List[Optional[str]]] = field(default_factory=list)
Expand All @@ -46,6 +50,7 @@ def extract_result_data(lib, result, conn):
res.error_message = _decode(lib.PQresultErrorMessage(result)) or _decode(
lib.PQerrorMessage(conn)
)
res.sqlstate = _decode(lib.PQresultErrorField(result, _PG_DIAG_SQLSTATE))
return res
if status == ExecStatusType.PGRES_COMMAND_OK:
return res
Expand Down
7 changes: 5 additions & 2 deletions src/test/pytest/libpq/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
PQTRANS_INERROR,
)
from .errors import PqConnectionError
from .errors import QueryError
from .errors import QueryError, query_error_for
from .pgnotify import read_notification
from .result import extract_result_data

Expand Down Expand Up @@ -366,7 +366,10 @@ def query_safe(self, sql):
res = self.query(sql)
if res.error_message is not None:
short = re.sub(r"\s+", " ", sql[:100])
raise QueryError(f"query_safe failed on [{short}...]: {res.error_message}")
raise query_error_for(
f"query_safe failed on [{short}...]: {res.error_message}",
res.sqlstate,
)
return res.psqlout

def query_oneval(self, sql, missing_ok=False):
Expand Down