From 18c46d8b1018571e5625669da8cef37371dd583a Mon Sep 17 00:00:00 2001 From: Greg Burd Date: Tue, 16 Jun 2026 10:04:28 -0400 Subject: [PATCH] libpq: SQLSTATE-based error matching for query failures Extract the SQLSTATE from a failed result (PQresultErrorField, which bindings.py already declares) onto ResultData.sqlstate, carry it on QueryError, and add named QueryError subclasses (QueryCanceled, UniqueViolation, DeadlockDetected, ...) so a test can write `with pytest.raises(QueryCanceled):` instead of catching the generic QueryError and string-matching its message. query_safe raises the SQLSTATE-specific subclass via query_error_for(); every subclass remains catchable as QueryError / LibpqError, and an unmapped SQLSTATE still raises a plain QueryError. --- src/test/pytest/libpq/errors.py | 119 ++++++++++++++++++++++++++++++- src/test/pytest/libpq/result.py | 5 ++ src/test/pytest/libpq/session.py | 7 +- 3 files changed, 128 insertions(+), 3 deletions(-) diff --git a/src/test/pytest/libpq/errors.py b/src/test/pytest/libpq/errors.py index f61d0631b0..3b8f754050 100644 --- a/src/test/pytest/libpq/errors.py +++ b/src/test/pytest/libpq/errors.py @@ -12,4 +12,121 @@ class PqConnectionError(LibpqError): class QueryError(LibpqError): - """Raised by the *_safe query helpers when a statement fails.""" + """Raised by the *_safe query helpers when a statement fails. + + ``sqlstate`` carries the five-character SQLSTATE from libpq when available + (None otherwise); ``sqlstate_class`` is its first two characters. These are + the stable, locale-independent way to assert on a specific error condition, + rather than matching against the human-readable message text. + """ + + def __init__(self, message, *, sqlstate=None): + super().__init__(message) + self.sqlstate = sqlstate + + @property + def sqlstate_class(self): + """The two-character SQLSTATE class, or None if no SQLSTATE is set.""" + if self.sqlstate and len(self.sqlstate) >= 2: + return self.sqlstate[:2] + return None + + +# Named QueryError subclasses for the SQLSTATEs tests most often assert on, so a +# test can write ``with pytest.raises(QueryCanceled):`` instead of catching the +# generic QueryError and then checking ``.sqlstate``. Each maps to its +# five-character SQLSTATE; query_error_for() picks the right class when raising. +class SyntaxErrorState(QueryError): + """42601 -- syntax_error.""" + + +class UndefinedTable(QueryError): + """42P01 -- undefined_table.""" + + +class UndefinedColumn(QueryError): + """42703 -- undefined_column.""" + + +class InsufficientPrivilege(QueryError): + """42501 -- insufficient_privilege.""" + + +class UniqueViolation(QueryError): + """23505 -- unique_violation.""" + + +class ForeignKeyViolation(QueryError): + """23503 -- foreign_key_violation.""" + + +class NotNullViolation(QueryError): + """23502 -- not_null_violation.""" + + +class CheckViolation(QueryError): + """23514 -- check_violation.""" + + +class SerializationFailure(QueryError): + """40001 -- serialization_failure.""" + + +class DeadlockDetected(QueryError): + """40P01 -- deadlock_detected.""" + + +class QueryCanceled(QueryError): + """57014 -- query_canceled.""" + + +class AdminShutdown(QueryError): + """57P01 -- admin_shutdown.""" + + +class CrashShutdown(QueryError): + """57P02 -- crash_shutdown.""" + + +class CannotConnectNow(QueryError): + """57P03 -- cannot_connect_now.""" + + +class ReadOnlySqlTransaction(QueryError): + """25006 -- read_only_sql_transaction.""" + + +class ObjectInUse(QueryError): + """55006 -- object_in_use.""" + + +# SQLSTATE -> exception subclass. Anything not listed raises a plain QueryError. +_SQLSTATE_EXCEPTIONS = { + "42601": SyntaxErrorState, + "42P01": UndefinedTable, + "42703": UndefinedColumn, + "42501": InsufficientPrivilege, + "23505": UniqueViolation, + "23503": ForeignKeyViolation, + "23502": NotNullViolation, + "23514": CheckViolation, + "40001": SerializationFailure, + "40P01": DeadlockDetected, + "57014": QueryCanceled, + "57P01": AdminShutdown, + "57P02": CrashShutdown, + "57P03": CannotConnectNow, + "25006": ReadOnlySqlTransaction, + "55006": ObjectInUse, +} + + +def query_error_for(message, sqlstate): + """Return a QueryError (or its SQLSTATE-specific subclass) for *sqlstate*. + + Used when a statement fails so callers can match on the specific condition + (e.g. ``pytest.raises(QueryCanceled)``) while still catching the base + QueryError/LibpqError when they want any failure. + """ + cls = _SQLSTATE_EXCEPTIONS.get(sqlstate or "", QueryError) + return cls(message, sqlstate=sqlstate) diff --git a/src/test/pytest/libpq/result.py b/src/test/pytest/libpq/result.py index 1bebbd0edc..e0d51a5a5c 100644 --- a/src/test/pytest/libpq/result.py +++ b/src/test/pytest/libpq/result.py @@ -12,6 +12,9 @@ from .constants import ExecStatusType +# PG_DIAG_SQLSTATE field id from postgres_ext.h, for PQresultErrorField(). +_PG_DIAG_SQLSTATE = ord("C") + def _decode(raw): """Decode a libpq C string (bytes or None) to str/None.""" @@ -26,6 +29,7 @@ class ResultData: status: int error_message: Optional[str] = None + sqlstate: Optional[str] = None names: List[str] = field(default_factory=list) types: List[int] = field(default_factory=list) rows: List[List[Optional[str]]] = field(default_factory=list) @@ -46,6 +50,7 @@ def extract_result_data(lib, result, conn): res.error_message = _decode(lib.PQresultErrorMessage(result)) or _decode( lib.PQerrorMessage(conn) ) + res.sqlstate = _decode(lib.PQresultErrorField(result, _PG_DIAG_SQLSTATE)) return res if status == ExecStatusType.PGRES_COMMAND_OK: return res diff --git a/src/test/pytest/libpq/session.py b/src/test/pytest/libpq/session.py index 9111c5d2a9..f9fde0d163 100644 --- a/src/test/pytest/libpq/session.py +++ b/src/test/pytest/libpq/session.py @@ -31,7 +31,7 @@ PQTRANS_INERROR, ) from .errors import PqConnectionError -from .errors import QueryError +from .errors import QueryError, query_error_for from .pgnotify import read_notification from .result import extract_result_data @@ -366,7 +366,10 @@ def query_safe(self, sql): res = self.query(sql) if res.error_message is not None: short = re.sub(r"\s+", " ", sql[:100]) - raise QueryError(f"query_safe failed on [{short}...]: {res.error_message}") + raise query_error_for( + f"query_safe failed on [{short}...]: {res.error_message}", + res.sqlstate, + ) return res.psqlout def query_oneval(self, sql, missing_ok=False):