From fd0dab5729f792bada0cbb3ab8c33196d1895be9 Mon Sep 17 00:00:00 2001 From: Andrzej Rybczak Date: Mon, 13 Apr 2026 16:53:43 +0200 Subject: [PATCH 1/2] Add support for (de)serialization of Integer to/from numeric This turned out to be way more annoying than anticipated, but the C code for (de)serialization of numeric was incomprehensible (and buggy, e.g. it didn't handle NaN properly), so I adjusted code from postgresql-binary package to serialize on Haskell side instead. --- CHANGELOG.md | 3 +- libpqtypes/src/libpqtypes.h | 13 + libpqtypes/src/numerics.c | 562 +----------------- src/Database/PostgreSQL/PQTypes/Format.hs | 3 + src/Database/PostgreSQL/PQTypes/FromSQL.hs | 5 + .../PostgreSQL/PQTypes/Internal/C/Types.hsc | 39 ++ .../PostgreSQL/PQTypes/Internal/Utils.hs | 51 ++ src/Database/PostgreSQL/PQTypes/ToSQL.hs | 4 + test/Main.hs | 25 +- 9 files changed, 149 insertions(+), 556 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 833b164..1d088b7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,6 @@ -# hpqtypes-1.14.1.0 +# hpqtypes-1.14.1.0 (????-??-??) * Introduce From/ToSQL instances for Word16, Word32 and Word64. +* Add support for (de)serialization of `Integer` to/from `numeric`. # hpqtypes-1.14.0.0 (2025-12-10) * Make `begin`, `commit` and `rollback` do nothing instead of throwing an error diff --git a/libpqtypes/src/libpqtypes.h b/libpqtypes/src/libpqtypes.h index 3529b04..145ccdf 100644 --- a/libpqtypes/src/libpqtypes.h +++ b/libpqtypes/src/libpqtypes.h @@ -180,6 +180,19 @@ typedef float PGfloat4; typedef double PGfloat8; typedef char *PGnumeric; +#define NUMERIC_POS 0x0000 +#define NUMERIC_NEG 0x4000 +#define NUMERIC_NAN 0xC000 + +typedef struct +{ + short ndigits; /* # of digits in digits[] - can be 0! */ + short weight; /* weight of first digit */ + short sign; /* NUMERIC_POS, NUMERIC_NEG, or NUMERIC_NAN */ + short dscale; /* display scale */ + const short *digits; /* base-NBASE digits in network byte order */ +} NumericVar; + /* Defined by an end-user if the system is missing long long. */ #ifdef PQT_LONG_LONG typedef PQT_LONG_LONG PGint8; diff --git a/libpqtypes/src/numerics.c b/libpqtypes/src/numerics.c index aa77fd8..97e1c62 100644 --- a/libpqtypes/src/numerics.c +++ b/libpqtypes/src/numerics.c @@ -20,50 +20,6 @@ # pragma warning (disable : 4244) #endif -/* - * Macros and structures for receiving numeric field in binary - */ -#define NBASE 10000 -#define HALF_NBASE 5000 -#define DEC_DIGITS 4 /* decimal digits per NBASE digit */ -#define MUL_GUARD_DIGITS 2 /* these are measured in NBASE digits */ -#define DIV_GUARD_DIGITS 4 - -/* - * Hardcoded precision limit - arbitrary, but must be small enough that - * dscale values will fit in 14 bits. - */ -#define NUMERIC_MAX_PRECISION 1000 - -/* - * Sign values and macros to deal with packing/unpacking n_sign_dscale - */ -#define NUMERIC_SIGN_MASK 0xC000 -#define NUMERIC_POS 0x0000 -#define NUMERIC_NEG 0x4000 -#define NUMERIC_NAN 0xC000 -#define NUMERIC_DSCALE_MASK 0x3FFF -#define NUMERIC_SIGN(n) ((n)->n_sign_dscale & NUMERIC_SIGN_MASK) -#define NUMERIC_DSCALE(n) ((n)->n_sign_dscale & NUMERIC_DSCALE_MASK) -#define NUMERIC_IS_NAN(n) (NUMERIC_SIGN(n) != NUMERIC_POS && \ - NUMERIC_SIGN(n) != NUMERIC_NEG) - -typedef short NumericDigit; -static const int round_powers[4] = {0, 1000, 100, 10}; - -typedef struct NumericVar -{ - int ndigits; /* # of digits in digits[] - can be 0! */ - int weight; /* weight of first digit */ - int sign; /* NUMERIC_POS, NUMERIC_NEG, or NUMERIC_NAN */ - int dscale; /* display scale */ - NumericDigit *buf; /* start of palloc'd space for digits[] */ - NumericDigit *digits; /* base-NBASE digits */ -} NumericVar; - -static int str2num(PGtypeArgs *args, const char *str, NumericVar *dest); -static int num2str(char *out, size_t outl, NumericVar *var, int dscale); - int pqt_put_int2(PGtypeArgs *args) { @@ -223,526 +179,46 @@ pqt_get_float8(PGtypeArgs *args) return 0; } -/* exposing a NumericVar struct to a libpq user, or something similar, - * doesn't seem useful w/o a library to operate on it. Instead, we - * always expose a numeric in text format and let the API user decide - * how to use it .. like strod or a 3rd party big number library. We - * always send a numeric in binary though. - */ int pqt_put_numeric(PGtypeArgs *args) { - int numlen; - NumericVar num = {0}; - short *out; - PGnumeric str = va_arg(args->ap, PGnumeric); - - PUTNULLCHK(args, str); - - if (str2num(args, str, &num)) - { - if (num.digits) - free(num.digits); - return -1; - } + NumericVar *num = va_arg(args->ap, NumericVar *); + PUTNULLCHK(args, num); /* variable length data type, grow args->put.out buffer if needed */ - numlen = (int) sizeof(short) * (4 + num.ndigits); + int numlen = (int) sizeof(short) * (4 + num->ndigits); if (args->put.expandBuffer(args, numlen) == -1) return -1; - out = (short *) args->put.out; - *out++ = htons((short) num.ndigits); - *out++ = htons((short) num.weight); - *out++ = htons((short) num.sign); - *out++ = htons((short) num.dscale); + short *out = (short *) args->put.out; + *out++ = htons(num->ndigits); + *out++ = htons(num->weight); + *out++ = htons(num->sign); + *out++ = htons(num->dscale); - if (num.digits) - { - int i; - for (i=0; i < num.ndigits; i++) - *out++ = htons(num.digits[i]); - free(num.digits); - } + for (int i = 0; i < num->ndigits; ++i) + *out++ = num->digits[i]; return numlen; } -/* exposing a NumericVar struct to a libpq user, or something similar, - * doesn't seem useful w/o a library to operate on it. Instead, we - * always expose a numeric in text format and let the API user decide - * how to use it .. like strod or a 3rd party big number library. - */ int pqt_get_numeric(PGtypeArgs *args) { - int i; - short *s; - NumericVar num; DECLVALUE(args); - char buf[4096]; - size_t len; - PGnumeric *str = va_arg(args->ap, PGnumeric *); + NumericVar *num = va_arg(args->ap, NumericVar *); - CHKGETVALS(args, str); + CHKGETVALS(args, num); if (args->format == TEXTFMT) - { - *str = value; - return 0; - } - - s = (short *) value; - num.ndigits = ntohs(*s); s++; - num.weight = ntohs(*s); s++; - num.sign = ntohs(*s); s++; - num.dscale = ntohs(*s); s++; - num.digits = (short *) malloc(num.ndigits * sizeof(short)); - if (!num.digits) - RERR_MEM(args); - - for (i=0; i < num.ndigits; i++) - { - num.digits[i] = ntohs(*s); - s++; - } - - i = num2str(buf, sizeof(buf), &num, num.dscale); - free(num.digits); - - /* num2str failed, only fails when 'str' is too small */ - if (i == -1) - RERR(args, "out buffer is too small"); - - len = strlen(buf)+1; - *str = PQresultAlloc(args->get.result, len); - if (!*str) - RERR_MEM(args); - - memcpy(*str, buf, len); - return 0; - -} - - -/* - * round_var - * - * Round the value of a variable to no more than rscale decimal digits - * after the decimal point. NOTE: we allow rscale < 0 here, implying - * rounding before the decimal point. - */ -static void -round_var(NumericVar *var, int rscale) -{ - NumericDigit *digits = var->digits; - int di; - int ndigits; - int carry; - - var->dscale = rscale; - - /* decimal digits wanted */ - di = (var->weight + 1) * DEC_DIGITS + rscale; - - /* - * If di = 0, the value loses all digits, but could round up to 1 if its - * first extra digit is >= 5. If di < 0 the result must be 0. - */ - if (di < 0) - { - var->ndigits = 0; - var->weight = 0; - var->sign = NUMERIC_POS; - } - else - { - /* NBASE digits wanted */ - ndigits = (di + DEC_DIGITS - 1) / DEC_DIGITS; - - /* 0, or number of decimal digits to keep in last NBASE digit */ - di %= DEC_DIGITS; - - if (ndigits < var->ndigits || - (ndigits == var->ndigits && di > 0)) - { - var->ndigits = ndigits; - - if (di == 0) - carry = (digits[ndigits] >= HALF_NBASE) ? 1 : 0; - else - { - /* Must round within last NBASE digit */ - int extra, - pow10; - - pow10 = round_powers[di]; - extra = digits[--ndigits] % pow10; - digits[ndigits] = digits[ndigits] - (NumericDigit) extra; - carry = 0; - if (extra >= pow10 / 2) - { - pow10 += digits[ndigits]; - if (pow10 >= NBASE) - { - pow10 -= NBASE; - carry = 1; - } - digits[ndigits] = (NumericDigit) pow10; - } - } - - /* Propagate carry if needed */ - while (carry) - { - carry += digits[--ndigits]; - if (carry >= NBASE) - { - digits[ndigits] = (NumericDigit) (carry - NBASE); - carry = 1; - } - else - { - digits[ndigits] = (NumericDigit) carry; - carry = 0; - } - } - - if (ndigits < 0) - { - var->digits--; - var->ndigits++; - var->weight++; - } - } - } -} - -/* - * strip_var - * - * Strip any leading and trailing zeroes from a numeric variable - */ -static void -strip_var(NumericVar *var) -{ - NumericDigit *digits = var->digits; - int ndigits = var->ndigits; - - /* Strip leading zeroes */ - while (ndigits > 0 && *digits == 0) - { - digits++; - var->weight--; - ndigits--; - } + return args->errorf(args, "text format is not supported"); - /* Strip trailing zeroes */ - while (ndigits > 0 && digits[ndigits - 1] == 0) - ndigits--; + short *s = (short *) value; + num->ndigits = ntohs(*s); s++; + num->weight = ntohs(*s); s++; + num->sign = ntohs(*s); s++; + num->dscale = ntohs(*s); s++; + num->digits = s; - /* If it's zero, normalize the sign and weight */ - if (ndigits == 0) - { - var->sign = NUMERIC_POS; - var->weight = 0; - } - - var->digits = digits; - var->ndigits = ndigits; -} - -/* - * str2num() - * - * Parse a string and put the number into a variable - * returns -1 on error and 0 for success. - */ -static int -str2num(PGtypeArgs *args, const char *str, NumericVar *dest) -{ - const char *cp = str; - int have_dp = FALSE; - int i; - unsigned char *decdigits; - int sign = NUMERIC_POS; - int dweight = -1; - int ddigits; - int dscale = 0; - int weight; - int ndigits; - int offset; - NumericDigit *digits; - - /* - * We first parse the string to extract decimal digits and determine the - * correct decimal weight. Then convert to NBASE representation. - */ - - /* skip leading spaces */ - while (*cp) - { - if (!isspace((unsigned char) *cp)) - break; - cp++; - } - - switch (*cp) - { - case '+': - sign = NUMERIC_POS; - cp++; - break; - - case '-': - sign = NUMERIC_NEG; - cp++; - break; - } - - if (*cp == '.') - { - have_dp = TRUE; - cp++; - } - - if (!isdigit((unsigned char) *cp)) - return args->errorf(args, - "invalid input syntax for type numeric: '%s'", str); - - decdigits = (unsigned char *) malloc(strlen(cp) + DEC_DIGITS * 2); - - /* leading padding for digit alignment later */ - memset(decdigits, 0, DEC_DIGITS); - i = DEC_DIGITS; - - while (*cp) - { - if (isdigit((unsigned char) *cp)) - { - decdigits[i++] = *cp++ - '0'; - if (!have_dp) - dweight++; - else - dscale++; - } - else if (*cp == '.') - { - if (have_dp) - { - free(decdigits); - return args->errorf(args, - "invalid input syntax for type numeric: '%s'", str); - } - - have_dp = TRUE; - cp++; - } - else - break; - } - - ddigits = i - DEC_DIGITS; - /* trailing padding for digit alignment later */ - memset(decdigits + i, 0, DEC_DIGITS - 1); - - /* Handle exponent, if any */ - if (*cp == 'e' || *cp == 'E') - { - long exponent; - char *endptr; - - cp++; - exponent = strtol(cp, &endptr, 10); - if (endptr == cp) - { - free(decdigits); - return args->errorf(args, - "invalid input syntax for type numeric: '%s'", str); - } - - cp = endptr; - if (exponent > NUMERIC_MAX_PRECISION || - exponent < -NUMERIC_MAX_PRECISION) - { - free(decdigits); - return args->errorf(args, - "invalid input syntax for type numeric: '%s'", str); - } - - dweight += (int) exponent; - dscale -= (int) exponent; - if (dscale < 0) - dscale = 0; - } - - /* Should be nothing left but spaces */ - while (*cp) - { - if (!isspace((unsigned char) *cp)) - { - free(decdigits); - return args->errorf(args, - "invalid input syntax for type numeric: '%s'", str); - } - cp++; - } - - /* - * Okay, convert pure-decimal representation to base NBASE. First we need - * to determine the converted weight and ndigits. offset is the number of - * decimal zeroes to insert before the first given digit to have a - * correctly aligned first NBASE digit. - */ - if (dweight >= 0) - weight = (dweight + 1 + DEC_DIGITS - 1) / DEC_DIGITS - 1; - else - weight = -((-dweight - 1) / DEC_DIGITS + 1); - offset = (weight + 1) * DEC_DIGITS - (dweight + 1); - ndigits = (ddigits + offset + DEC_DIGITS - 1) / DEC_DIGITS; - - dest->digits = (NumericDigit *) malloc((ndigits) * sizeof(NumericDigit)); - dest->ndigits = ndigits; - dest->sign = sign; - dest->weight = weight; - dest->dscale = dscale; - - i = DEC_DIGITS - offset; - digits = dest->digits; - - while (ndigits-- > 0) - { - *digits++ = ((decdigits[i] * 10 + decdigits[i + 1]) * 10 + - decdigits[i + 2]) * 10 + decdigits[i + 3]; - i += DEC_DIGITS; - } - - free(decdigits); - - /* Strip any leading/trailing zeroes, and normalize weight if zero */ - strip_var(dest); - return 0; -} - -/* - * num2str() - - * - * Convert a var to text representation (guts of numeric_out). - * CAUTION: var's contents may be modified by rounding! - * returns -1 on error and 0 for success. - */ -static int -num2str(char *out, size_t outl, NumericVar *var, int dscale) -{ - //char *str; - char *cp; - char *endcp; - int i; - int d; - NumericDigit dig; - NumericDigit d1; - - if (dscale < 0) - dscale = 0; - - /* - * Check if we must round up before printing the value and do so. - */ - round_var(var, dscale); - - /* - * Allocate space for the result. - * - * i is set to to # of decimal digits before decimal point. dscale is the - * # of decimal digits we will print after decimal point. We may generate - * as many as DEC_DIGITS-1 excess digits at the end, and in addition we - * need room for sign, decimal point, null terminator. - */ - i = (var->weight + 1) * DEC_DIGITS; - if (i <= 0) - i = 1; - - if (outl <= (size_t) (i + dscale + DEC_DIGITS + 2)) - return -1; - - //str = palloc(i + dscale + DEC_DIGITS + 2); - //cp = str - cp = out; - - /* - * Output a dash for negative values - */ - if (var->sign == NUMERIC_NEG) - *cp++ = '-'; - - /* - * Output all digits before the decimal point - */ - if (var->weight < 0) - { - d = var->weight + 1; - *cp++ = '0'; - } - else - { - for (d = 0; d <= var->weight; d++) - { - dig = (d < var->ndigits) ? var->digits[d] : 0; - /* In the first digit, suppress extra leading decimal zeroes */ - { - int putit = (d > 0); - - d1 = dig / 1000; - dig -= d1 * 1000; - putit |= (d1 > 0); - if (putit) - *cp++ = (char) (d1 + '0'); - d1 = dig / 100; - dig -= d1 * 100; - putit |= (d1 > 0); - if (putit) - *cp++ = (char) (d1 + '0'); - d1 = dig / 10; - dig -= d1 * 10; - putit |= (d1 > 0); - if (putit) - *cp++ = (char) (d1 + '0'); - *cp++ = (char) (dig + '0'); - } - } - } - - /* - * If requested, output a decimal point and all the digits that follow it. - * We initially put out a multiple of DEC_DIGITS digits, then truncate if - * needed. - */ - if (dscale > 0) - { - *cp++ = '.'; - endcp = cp + dscale; - for (i = 0; i < dscale; d++, i += DEC_DIGITS) - { - dig = (d >= 0 && d < var->ndigits) ? var->digits[d] : 0; - d1 = dig / 1000; - dig -= d1 * 1000; - *cp++ = (char) (d1 + '0'); - d1 = dig / 100; - dig -= d1 * 100; - *cp++ = (char) (d1 + '0'); - d1 = dig / 10; - dig -= d1 * 10; - *cp++ = (char) (d1 + '0'); - *cp++ = (char) (dig + '0'); - } - cp = endcp; - } - - /* - * terminate the string and return it - */ - *cp = '\0'; return 0; } - - diff --git a/src/Database/PostgreSQL/PQTypes/Format.hs b/src/Database/PostgreSQL/PQTypes/Format.hs index fa02bfd..5a4bbfb 100644 --- a/src/Database/PostgreSQL/PQTypes/Format.hs +++ b/src/Database/PostgreSQL/PQTypes/Format.hs @@ -94,6 +94,9 @@ instance PQFormat Word32 where instance PQFormat Word64 where pqFormat = BS.pack "%int8" +instance PQFormat Integer where + pqFormat = BS.pack "%numeric" + -- CHAR instance PQFormat Char where diff --git a/src/Database/PostgreSQL/PQTypes/FromSQL.hs b/src/Database/PostgreSQL/PQTypes/FromSQL.hs index 59dffd3..e7018a2 100644 --- a/src/Database/PostgreSQL/PQTypes/FromSQL.hs +++ b/src/Database/PostgreSQL/PQTypes/FromSQL.hs @@ -83,6 +83,11 @@ instance FromSQL Word64 where fromSQL Nothing = unexpectedNULL fromSQL (Just n) = pure . fromIntegral $ n +instance FromSQL Integer where + type PQBase Integer = NumericVar + fromSQL Nothing = unexpectedNULL + fromSQL (Just nv) = numericVarToInteger nv + -- CHAR instance FromSQL Char where diff --git a/src/Database/PostgreSQL/PQTypes/Internal/C/Types.hsc b/src/Database/PostgreSQL/PQTypes/Internal/C/Types.hsc index 0a15425..a438a1c 100644 --- a/src/Database/PostgreSQL/PQTypes/Internal/C/Types.hsc +++ b/src/Database/PostgreSQL/PQTypes/Internal/C/Types.hsc @@ -34,6 +34,10 @@ module Database.PostgreSQL.PQTypes.Internal.C.Types , PGdate(..) , PGtime(..) , PGtimestamp(..) + , c_NUMERIC_POS + , c_NUMERIC_NEG + , c_NUMERIC_NAN + , NumericVar(..) ) where import Data.Word @@ -355,3 +359,38 @@ instance Storable PGtimestamp where #{poke PGtimestamp, epoch} ptr pgTimestampEpoch #{poke PGtimestamp, date} ptr pgTimestampDate #{poke PGtimestamp, time} ptr pgTimestampTime + +---------------------------------------- + +c_NUMERIC_POS :: CShort +c_NUMERIC_POS = #{const NUMERIC_POS} + +c_NUMERIC_NEG :: CShort +c_NUMERIC_NEG = #{const NUMERIC_NEG} + +c_NUMERIC_NAN :: CShort +c_NUMERIC_NAN = #{const NUMERIC_NAN} + +data NumericVar = NumericVar + { numVarNdigits :: !CShort + , numVarWeight :: !CShort + , numVarSign :: !CShort + , numVarDscale :: !CShort + , numVarDigits :: !(Ptr CShort) -- elements in network byte order + } + +instance Storable NumericVar where + sizeOf _ = #{size NumericVar} + alignment _ = #{alignment NumericVar} + peek ptr = NumericVar + <$> #{peek NumericVar, ndigits} ptr + <*> #{peek NumericVar, weight} ptr + <*> #{peek NumericVar, sign} ptr + <*> #{peek NumericVar, dscale} ptr + <*> #{peek NumericVar, digits} ptr + poke ptr NumericVar{..} = do + #{poke NumericVar, ndigits} ptr numVarNdigits + #{poke NumericVar, weight} ptr numVarWeight + #{poke NumericVar, sign} ptr numVarSign + #{poke NumericVar, dscale} ptr numVarDscale + #{poke NumericVar, digits} ptr numVarDigits diff --git a/src/Database/PostgreSQL/PQTypes/Internal/Utils.hs b/src/Database/PostgreSQL/PQTypes/Internal/Utils.hs index d0b58e5..939587d 100644 --- a/src/Database/PostgreSQL/PQTypes/Internal/Utils.hs +++ b/src/Database/PostgreSQL/PQTypes/Internal/Utils.hs @@ -6,6 +6,8 @@ module Database.PostgreSQL.PQTypes.Internal.Utils , cStringLenToBytea , byteaToCStringLen , textToCString + , numericVarToInteger + , withIntegerAsNumericVar , verifyPQTRes , withPGparam , throwLibPQError @@ -22,6 +24,7 @@ import Data.Kind (Type) import Data.Maybe import Data.Text qualified as T import Data.Text.Encoding qualified as T +import Data.Vector.Storable qualified as V import Foreign.C import Foreign.ForeignPtr import Foreign.Marshal.Alloc @@ -81,6 +84,54 @@ textToCString bs = unsafeUseAsCStringLen (T.encodeUtf8 bs) $ \(cs, len) -> do pokeByteOff ptr len (0 :: CChar) pure fptr +---------------------------------------- + +-- Note: these can be generalized to convert from/to Scientific and support +-- arbitrary floating point precision (relevant code can be borrowed from the +-- postgresql-binary package), but while deserialization is easy either way, +-- serialization is significantly more annoying, so let's leave this until it's +-- actually needed. + +numericVarToInteger :: NumericVar -> IO Integer +numericVarToInteger NumericVar {..} + | numVarDscale /= 0 = hpqTypesError "not an integer" + | numVarSign == c_NUMERIC_NAN = hpqTypesError "not a number" + | numVarSign == c_NUMERIC_POS = mkInteger 0 numVarDigits numVarNdigits + | numVarSign == c_NUMERIC_NEG = negate <$> mkInteger 0 numVarDigits numVarNdigits + | otherwise = hpqTypesError $ "unexpected sign: " ++ show numVarSign + where + mkInteger :: Integer -> Ptr CShort -> CShort -> IO Integer + mkInteger acc ptr = \case + 0 -> pure acc + n -> do + v <- ntohs <$> peek ptr + mkInteger (acc * 10000 + fromIntegral v) (ptr `plusPtr` 2) (n - 1) + +withIntegerAsNumericVar :: Integer -> (NumericVar -> IO r) -> IO r +withIntegerAsNumericVar n k = V.unsafeWith digits $ \digitsPtr -> do + k $ + NumericVar + { numVarNdigits = digitsLen + , numVarWeight = max 0 (digitsLen - 1) + , numVarSign = if n < 0 then c_NUMERIC_NEG else c_NUMERIC_POS + , numVarDscale = 0 + , numVarDigits = digitsPtr + } + where + digitsLen :: CShort + digitsLen = fromIntegral $ V.length digits + + digits :: V.Vector CShort + digits = V.reverse . (`V.unfoldr` abs n) $ \case + 0 -> Nothing + x -> case x `quotRem` 10000 of + (d, m) -> Just (htons $ fromIntegral m, d) + +foreign import ccall unsafe "htons" htons :: CShort -> CShort +foreign import ccall unsafe "ntohs" ntohs :: CShort -> CShort + +---------------------------------------- + -- | Check return value of a function from libpqtypes -- and if it indicates an error, throw appropriate exception. verifyPQTRes :: HasCallStack => Ptr PGerror -> String -> CInt -> IO () diff --git a/src/Database/PostgreSQL/PQTypes/ToSQL.hs b/src/Database/PostgreSQL/PQTypes/ToSQL.hs index 376ba63..983fdc8 100644 --- a/src/Database/PostgreSQL/PQTypes/ToSQL.hs +++ b/src/Database/PostgreSQL/PQTypes/ToSQL.hs @@ -96,6 +96,10 @@ instance ToSQL Word64 where type PQDest Word64 = CULLong toSQL n _ = putAsPtr (fromIntegral n) +instance ToSQL Integer where + type PQDest Integer = NumericVar + toSQL n _ k = withIntegerAsNumericVar n $ \nv -> putAsPtr nv k + -- CHAR instance ToSQL Char where diff --git a/test/Main.hs b/test/Main.hs index f70bbaf..eaee80f 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -641,6 +641,7 @@ tests td = , nullTest td (u :: Word16) , nullTest td (u :: Word32) , nullTest td (u :: Word64) + , nullTest td (u :: Integer) , nullTest td (u :: String) , nullTest td (u :: BS.ByteString) , nullTest td (u :: T.Text) @@ -669,6 +670,7 @@ tests td = , putGetTest td 100 (u :: Word16) (==) , putGetTest td 100 (u :: Word32) (==) , putGetTest td 100 (u :: Word64) (==) + , putGetTest td 1000000000000 (u :: Integer) (==) , putGetTest td 1000 (u :: String0) (==) , putGetTest td 1000 (u :: BS.ByteString) (==) , putGetTest td 1000 (u :: T.Text) (==) @@ -731,18 +733,17 @@ tests td = , rowTest td (u :: (Int16, Int32, Int64, Float, Double, Bool, AsciiChar, Word8, String0, BS.ByteString, T.Text, BS.ByteString, Day, Array1 Int32, Composite Simple, CompositeArray1 Simple, Composite Nested, CompositeArray1 Nested, Int16, Int32, Int64, Float, Double, Bool, AsciiChar, Word8, String0, BS.ByteString, T.Text, BS.ByteString, Day, Array1 Int32, Composite Simple, CompositeArray1 Simple, Composite Nested, CompositeArray1 Nested)) , rowTest td (u :: (Int16, Int32, Int64, Float, Double, Bool, AsciiChar, Word8, String0, BS.ByteString, T.Text, BS.ByteString, Day, Array1 Int32, Composite Simple, CompositeArray1 Simple, Composite Nested, CompositeArray1 Nested, Int16, Int32, Int64, Float, Double, Bool, AsciiChar, Word8, String0, BS.ByteString, T.Text, BS.ByteString, Day, Array1 Int32, Composite Simple, CompositeArray1 Simple, Composite Nested, CompositeArray1 Nested, Word16)) , rowTest td (u :: (Int16, Int32, Int64, Float, Double, Bool, AsciiChar, Word8, String0, BS.ByteString, T.Text, BS.ByteString, Day, Array1 Int32, Composite Simple, CompositeArray1 Simple, Composite Nested, CompositeArray1 Nested, Int16, Int32, Int64, Float, Double, Bool, AsciiChar, Word8, String0, BS.ByteString, T.Text, BS.ByteString, Day, Array1 Int32, Composite Simple, CompositeArray1 Simple, Composite Nested, CompositeArray1 Nested, Word16, Word32)) - , rowTest td (u :: (Int16, Int32, Int64, Float, Double, Bool, AsciiChar, Word8, String0, BS.ByteString, T.Text, BS.ByteString, Day, Array1 Int32, Composite Simple, CompositeArray1 Simple, Composite Nested, CompositeArray1 Nested, Int16, Int32, Int64, Float, Double, Bool, AsciiChar, Word8, String0, BS.ByteString, T.Text, BS.ByteString, Day, Array1 Int32, Composite Simple, CompositeArray1 Simple, Composite Nested, CompositeArray1 Nested, Word16, Word32, Word64)) - , rowTest td (u :: (Int16, Int32, Int64, Float, Double, Bool, AsciiChar, Word8, String0, BS.ByteString, T.Text, BS.ByteString, Day, Array1 Int32, Composite Simple, CompositeArray1 Simple, Composite Nested, CompositeArray1 Nested, Word16, Word32, Word64, Float, Double, Bool, AsciiChar, Word8, String0, BS.ByteString, T.Text, BS.ByteString, Day, Array1 Int32, Composite Simple, CompositeArray1 Simple, Composite Nested, CompositeArray1 Nested, Int16, Int32, Int64, Float)) - , rowTest td (u :: (Int16, Int32, Int64, Float, Double, Bool, AsciiChar, Word8, String0, BS.ByteString, T.Text, BS.ByteString, Day, Array1 Int32, Composite Simple, CompositeArray1 Simple, Composite Nested, CompositeArray1 Nested, Word16, Word32, Word64, Float, Double, Bool, AsciiChar, Word8, String0, BS.ByteString, T.Text, BS.ByteString, Day, Array1 Int32, Composite Simple, CompositeArray1 Simple, Composite Nested, CompositeArray1 Nested, Int16, Int32, Int64, Float, Double)) - , rowTest td (u :: (Int16, Int32, Int64, Float, Double, Bool, AsciiChar, Word8, String0, BS.ByteString, T.Text, BS.ByteString, Day, Array1 Int32, Composite Simple, CompositeArray1 Simple, Composite Nested, CompositeArray1 Nested, Word16, Word32, Word64, Float, Double, Bool, AsciiChar, Word8, String0, BS.ByteString, T.Text, BS.ByteString, Day, Array1 Int32, Composite Simple, CompositeArray1 Simple, Composite Nested, CompositeArray1 Nested, Int16, Int32, Int64, Float, Double, Bool)) - , rowTest td (u :: (Int16, Int32, Int64, Float, Double, Bool, AsciiChar, Word8, String0, BS.ByteString, T.Text, BS.ByteString, Day, Array1 Int32, Composite Simple, CompositeArray1 Simple, Composite Nested, CompositeArray1 Nested, Word16, Word32, Word64, Float, Double, Bool, AsciiChar, Word8, String0, BS.ByteString, T.Text, BS.ByteString, Day, Array1 Int32, Composite Simple, CompositeArray1 Simple, Composite Nested, CompositeArray1 Nested, Int16, Int32, Int64, Float, Double, Bool, AsciiChar)) - , rowTest td (u :: (Int16, Int32, Int64, Float, Double, Bool, AsciiChar, Word8, String0, BS.ByteString, T.Text, BS.ByteString, Day, Array1 Int32, Composite Simple, CompositeArray1 Simple, Composite Nested, CompositeArray1 Nested, Word16, Word32, Word64, Float, Double, Bool, AsciiChar, Word8, String0, BS.ByteString, T.Text, BS.ByteString, Day, Array1 Int32, Composite Simple, CompositeArray1 Simple, Composite Nested, CompositeArray1 Nested, Int16, Int32, Int64, Float, Double, Bool, AsciiChar, Word8)) - , rowTest td (u :: (Int16, Int32, Int64, Float, Double, Bool, AsciiChar, Word8, String0, BS.ByteString, T.Text, BS.ByteString, Day, Array1 Int32, Composite Simple, CompositeArray1 Simple, Composite Nested, CompositeArray1 Nested, Word16, Word32, Word64, Float, Double, Bool, AsciiChar, Word8, String0, BS.ByteString, T.Text, BS.ByteString, Day, Array1 Int32, Composite Simple, CompositeArray1 Simple, Composite Nested, CompositeArray1 Nested, Int16, Int32, Int64, Float, Double, Bool, AsciiChar, Word8, String0)) - , rowTest td (u :: (Int16, Int32, Int64, Float, Double, Bool, AsciiChar, Word8, String0, BS.ByteString, T.Text, BS.ByteString, Day, Array1 Int32, Composite Simple, CompositeArray1 Simple, Composite Nested, CompositeArray1 Nested, Word16, Word32, Word64, Float, Double, Bool, AsciiChar, Word8, String0, BS.ByteString, T.Text, BS.ByteString, Day, Array1 Int32, Composite Simple, CompositeArray1 Simple, Composite Nested, CompositeArray1 Nested, Int16, Int32, Int64, Float, Double, Bool, AsciiChar, Word8, String0, BS.ByteString)) - , rowTest td (u :: (Int16, Int32, Int64, Float, Double, Bool, AsciiChar, Word8, String0, BS.ByteString, T.Text, BS.ByteString, Day, Array1 Int32, Composite Simple, CompositeArray1 Simple, Composite Nested, CompositeArray1 Nested, Word16, Word32, Word64, Float, Double, Bool, AsciiChar, Word8, String0, BS.ByteString, T.Text, BS.ByteString, Day, Array1 Int32, Composite Simple, CompositeArray1 Simple, Composite Nested, CompositeArray1 Nested, Int16, Int32, Int64, Float, Double, Bool, AsciiChar, Word8, String0, BS.ByteString, T.Text)) - , rowTest td (u :: (Int16, Int32, Int64, Float, Double, Bool, AsciiChar, Word8, String0, BS.ByteString, T.Text, BS.ByteString, Day, Array1 Int32, Composite Simple, CompositeArray1 Simple, Composite Nested, CompositeArray1 Nested, Word16, Word32, Word64, Float, Double, Bool, AsciiChar, Word8, String0, BS.ByteString, T.Text, BS.ByteString, Day, Array1 Int32, Composite Simple, CompositeArray1 Simple, Composite Nested, CompositeArray1 Nested, Int16, Int32, Int64, Float, Double, Bool, AsciiChar, Word8, String0, BS.ByteString, T.Text, BS.ByteString)) - , rowTest td (u :: (Int16, Int32, Int64, Float, Double, Bool, AsciiChar, Word8, String0, BS.ByteString, T.Text, BS.ByteString, Day, Array1 Int32, Composite Simple, CompositeArray1 Simple, Composite Nested, CompositeArray1 Nested, Word16, Word32, Word64, Float, Double, Bool, AsciiChar, Word8, String0, BS.ByteString, T.Text, BS.ByteString, Day, Array1 Int32, Composite Simple, CompositeArray1 Simple, Composite Nested, CompositeArray1 Nested, Int16, Int32, Int64, Float, Double, Bool, AsciiChar, Word8, String0, BS.ByteString, T.Text, BS.ByteString, U.UUID)) - , rowTest td (u :: (Int16, Int32, Int64, Float, Double, Bool, AsciiChar, Word8, String0, BS.ByteString, T.Text, BS.ByteString, Day, Array1 Int32, Composite Simple, CompositeArray1 Simple, Composite Nested, CompositeArray1 Nested, Word16, Word32, Word64, Float, Double, Bool, AsciiChar, Word8, String0, BS.ByteString, T.Text, BS.ByteString, Day, Array1 Int32, Composite Simple, CompositeArray1 Simple, Composite Nested, CompositeArray1 Nested, Int16, Int32, Int64, Float, Double, Bool, AsciiChar, Word8, String0, BS.ByteString, T.Text, BS.ByteString, U.UUID, Day)) + , rowTest td (u :: (Int16, Int32, Int64, Float, Double, Bool, AsciiChar, Word8, String0, BS.ByteString, T.Text, BS.ByteString, Day, Array1 Int32, Composite Simple, CompositeArray1 Simple, Composite Nested, CompositeArray1 Nested, Int16, Int32, Int64, Float, Double, Bool, AsciiChar, Word8, String0, BS.ByteString, T.Text, BS.ByteString, Day, Array1 Int32, Composite Simple, CompositeArray1 Simple, Composite Nested, CompositeArray1 Nested, Word16, Word32, Word64, Integer)) + , rowTest td (u :: (Int16, Int32, Int64, Float, Double, Bool, AsciiChar, Word8, String0, BS.ByteString, T.Text, BS.ByteString, Day, Array1 Int32, Composite Simple, CompositeArray1 Simple, Composite Nested, CompositeArray1 Nested, Word16, Word32, Word64, Integer, Float, Double, Bool, AsciiChar, Word8, String0, BS.ByteString, T.Text, BS.ByteString, Day, Array1 Int32, Composite Simple, CompositeArray1 Simple, Composite Nested, CompositeArray1 Nested, Int16, Int32, Int64, Float)) + , rowTest td (u :: (Int16, Int32, Int64, Float, Double, Bool, AsciiChar, Word8, String0, BS.ByteString, T.Text, BS.ByteString, Day, Array1 Int32, Composite Simple, CompositeArray1 Simple, Composite Nested, CompositeArray1 Nested, Word16, Word32, Word64, Integer, Float, Double, Bool, AsciiChar, Word8, String0, BS.ByteString, T.Text, BS.ByteString, Day, Array1 Int32, Composite Simple, CompositeArray1 Simple, Composite Nested, CompositeArray1 Nested, Int16, Int32, Int64, Float, Double)) + , rowTest td (u :: (Int16, Int32, Int64, Float, Double, Bool, AsciiChar, Word8, String0, BS.ByteString, T.Text, BS.ByteString, Day, Array1 Int32, Composite Simple, CompositeArray1 Simple, Composite Nested, CompositeArray1 Nested, Word16, Word32, Word64, Integer, Float, Double, Bool, AsciiChar, Word8, String0, BS.ByteString, T.Text, BS.ByteString, Day, Array1 Int32, Composite Simple, CompositeArray1 Simple, Composite Nested, CompositeArray1 Nested, Int16, Int32, Int64, Float, Double, Bool)) + , rowTest td (u :: (Int16, Int32, Int64, Float, Double, Bool, AsciiChar, Word8, String0, BS.ByteString, T.Text, BS.ByteString, Day, Array1 Int32, Composite Simple, CompositeArray1 Simple, Composite Nested, CompositeArray1 Nested, Word16, Word32, Word64, Integer, Float, Double, Bool, AsciiChar, Word8, String0, BS.ByteString, T.Text, BS.ByteString, Day, Array1 Int32, Composite Simple, CompositeArray1 Simple, Composite Nested, CompositeArray1 Nested, Int16, Int32, Int64, Float, Double, Bool, AsciiChar)) + , rowTest td (u :: (Int16, Int32, Int64, Float, Double, Bool, AsciiChar, Word8, String0, BS.ByteString, T.Text, BS.ByteString, Day, Array1 Int32, Composite Simple, CompositeArray1 Simple, Composite Nested, CompositeArray1 Nested, Word16, Word32, Word64, Integer, Float, Double, Bool, AsciiChar, Word8, String0, BS.ByteString, T.Text, BS.ByteString, Day, Array1 Int32, Composite Simple, CompositeArray1 Simple, Composite Nested, CompositeArray1 Nested, Int16, Int32, Int64, Float, Double, Bool, AsciiChar, Word8)) + , rowTest td (u :: (Int16, Int32, Int64, Float, Double, Bool, AsciiChar, Word8, String0, BS.ByteString, T.Text, BS.ByteString, Day, Array1 Int32, Composite Simple, CompositeArray1 Simple, Composite Nested, CompositeArray1 Nested, Word16, Word32, Word64, Integer, Float, Double, Bool, AsciiChar, Word8, String0, BS.ByteString, T.Text, BS.ByteString, Day, Array1 Int32, Composite Simple, CompositeArray1 Simple, Composite Nested, CompositeArray1 Nested, Int16, Int32, Int64, Float, Double, Bool, AsciiChar, Word8, String0)) + , rowTest td (u :: (Int16, Int32, Int64, Float, Double, Bool, AsciiChar, Word8, String0, BS.ByteString, T.Text, BS.ByteString, Day, Array1 Int32, Composite Simple, CompositeArray1 Simple, Composite Nested, CompositeArray1 Nested, Word16, Word32, Word64, Integer, Float, Double, Bool, AsciiChar, Word8, String0, BS.ByteString, T.Text, BS.ByteString, Day, Array1 Int32, Composite Simple, CompositeArray1 Simple, Composite Nested, CompositeArray1 Nested, Int16, Int32, Int64, Float, Double, Bool, AsciiChar, Word8, String0, BS.ByteString)) + , rowTest td (u :: (Int16, Int32, Int64, Float, Double, Bool, AsciiChar, Word8, String0, BS.ByteString, T.Text, BS.ByteString, Day, Array1 Int32, Composite Simple, CompositeArray1 Simple, Composite Nested, CompositeArray1 Nested, Word16, Word32, Word64, Integer, Float, Double, Bool, AsciiChar, Word8, String0, BS.ByteString, T.Text, BS.ByteString, Day, Array1 Int32, Composite Simple, CompositeArray1 Simple, Composite Nested, CompositeArray1 Nested, Int16, Int32, Int64, Float, Double, Bool, AsciiChar, Word8, String0, BS.ByteString, T.Text)) + , rowTest td (u :: (Int16, Int32, Int64, Float, Double, Bool, AsciiChar, Word8, String0, BS.ByteString, T.Text, BS.ByteString, Day, Array1 Int32, Composite Simple, CompositeArray1 Simple, Composite Nested, CompositeArray1 Nested, Word16, Word32, Word64, Integer, Float, Double, Bool, AsciiChar, Word8, String0, BS.ByteString, T.Text, BS.ByteString, Day, Array1 Int32, Composite Simple, CompositeArray1 Simple, Composite Nested, CompositeArray1 Nested, Int16, Int32, Int64, Float, Double, Bool, AsciiChar, Word8, String0, BS.ByteString, T.Text, BS.ByteString)) + , rowTest td (u :: (Int16, Int32, Int64, Float, Double, Bool, AsciiChar, Word8, String0, BS.ByteString, T.Text, BS.ByteString, Day, Array1 Int32, Composite Simple, CompositeArray1 Simple, Composite Nested, CompositeArray1 Nested, Word16, Word32, Word64, Integer, Float, Double, Bool, AsciiChar, Word8, String0, BS.ByteString, T.Text, BS.ByteString, Day, Array1 Int32, Composite Simple, CompositeArray1 Simple, Composite Nested, CompositeArray1 Nested, Int16, Int32, Int64, Float, Double, Bool, AsciiChar, Word8, String0, BS.ByteString, T.Text, BS.ByteString, U.UUID)) ] where u = undefined From f9fd77639e6339e6e87991abfded937d0dbc4c71 Mon Sep 17 00:00:00 2001 From: Andrzej Rybczak Date: Tue, 14 Apr 2026 13:03:42 +0200 Subject: [PATCH 2/2] Add a range check when deserializing array of digits --- src/Database/PostgreSQL/PQTypes/Internal/Utils.hs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/Database/PostgreSQL/PQTypes/Internal/Utils.hs b/src/Database/PostgreSQL/PQTypes/Internal/Utils.hs index 939587d..dae4b34 100644 --- a/src/Database/PostgreSQL/PQTypes/Internal/Utils.hs +++ b/src/Database/PostgreSQL/PQTypes/Internal/Utils.hs @@ -105,6 +105,8 @@ numericVarToInteger NumericVar {..} 0 -> pure acc n -> do v <- ntohs <$> peek ptr + when (v < 0 || v > 9999) $ do + hpqTypesError $ "invalid digit: " ++ show v mkInteger (acc * 10000 + fromIntegral v) (ptr `plusPtr` 2) (n - 1) withIntegerAsNumericVar :: Integer -> (NumericVar -> IO r) -> IO r