diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index 4d360f5..88609e1 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -28,17 +28,20 @@ jobs: python -m pip install --upgrade pip pip install -U flake8 setuptools pip install -U openapi-core uwsgi simplejson WSocket PyJWT pyaes - pip install -U pytest pytest-doctestplus pytest-pylint pytest-mypy requests websocket-client - pip install -U types-simplejson types-requests types-PyYAML + pip install -U pytest pytest-doctestplus pytest-mypy requests websocket-client + pip install -U pylint types-simplejson types-requests types-PyYAML - name: Lint with flake8 run: | # stop the build if there are Python syntax errors or undefined names flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics - - name: Lint with pylit + - name: Lint with pylint run: | - pytest -v poorwsgi --pylint --mypy --doctest-plus --doctest-rst + pylint poorwsgi + - name: Lint with mypy and doctest + run: | + pytest -v poorwsgi --mypy --doctest-plus --doctest-rst - name: Tests run: | pytest -v tests --mypy diff --git a/.github/workflows/python-publish.yml b/.github/workflows/python-publish.yml index f4e8e86..a9fd7ba 100644 --- a/.github/workflows/python-publish.yml +++ b/.github/workflows/python-publish.yml @@ -1,11 +1,3 @@ -# This workflow will upload a Python Package using Twine when a release is created -# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries - -# This workflow uses actions that are not certified by GitHub. -# They are provided by a third-party and are governed by -# separate terms of service, privacy policy, and support -# documentation. - name: Upload Python Package on: @@ -14,6 +6,7 @@ on: permissions: contents: read + id-token: write jobs: deploy: @@ -21,9 +14,9 @@ jobs: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up Python - uses: actions/setup-python@v3 + uses: actions/setup-python@v5 with: python-version: '3.x' - name: Install dependencies @@ -34,7 +27,4 @@ jobs: run: | python -m build - name: Publish package - uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 - with: - user: __token__ - password: ${{ secrets.PYPI_API_TOKEN }} + uses: pypa/gh-action-pypi-publish@release/v1 diff --git a/.gitignore b/.gitignore index 83cde51..03f4fc8 100644 --- a/.gitignore +++ b/.gitignore @@ -6,3 +6,4 @@ tags __pycache__/ *.pyc *.profile +.coverage diff --git a/.isort.cfg b/.isort.cfg deleted file mode 100644 index 627f315..0000000 --- a/.isort.cfg +++ /dev/null @@ -1,2 +0,0 @@ -[settings] -line_length = 79 diff --git a/doc/ChangeLog b/doc/ChangeLog index ac86ac0..864b4f1 100644 --- a/doc/ChangeLog +++ b/doc/ChangeLog @@ -1,3 +1,25 @@ +==== 2.8.1 ==== + * Session, PoorSession, AESSession: validate same_site argument + - Accepted values are 'Strict', 'Lax', 'None', or False + - ValueError is raised for any other value + - ValueError is raised when same_site='None' and secure=False + (browsers reject SameSite=None without the Secure flag) + - Type annotation corrected from Union[str, bool] to + Union[str, Literal[False]] + * Updated docstrings and documentation to document valid same_site + values, the Secure requirement for 'None', and raised exceptions + * Fix Session.destroy() with max_age: Max-Age=-1 was overwritten by + the subsequent write() call; _destroyed flag now prevents that + * Unit tests for session.py and aes_session.py: 100% coverage + - get_token, check_token, NoCompress + - Session.destroy() with max_age and secure + - Session.header() with a headers argument + - PoorSession and AESSession: str key, empty cookie, short + signature/payload, non-dict data + * Consolidate tool configuration into pyproject.toml + - ruff.toml and .isort.cfg removed + - pytest testpaths and coverage source configured + ==== 2.8.0 ==== * Fix I/O operation on closed file/buffer error in Response classes (#21) * Validate route filter definitions to reject spaces with clear error diff --git a/doc/documentation.rst b/doc/documentation.rst index 76f9888..7d9facd 100644 --- a/doc/documentation.rst +++ b/doc/documentation.rst @@ -1075,10 +1075,34 @@ No encryption is applied — the value is stored as-is in the cookie. session.load(req.cookies) server_data = server_store[session.data] -The ``Session`` class accepts the following keyword arguments: -``sid``, ``expires``, ``max_age``, ``domain``, ``path``, ``secure``, -``same_site``. It exposes ``load()``, ``write()``, ``destroy()``, and -``header()`` methods. +The ``Session`` constructor accepts the following keyword arguments: + +``sid`` + Cookie name (default ``'SESSID'``). +``expires`` + ``Expires`` time in seconds from now. ``0`` (default) means no expiration. +``max_age`` + ``Max-Age`` in seconds. Takes precedence over ``expires`` when both are set. +``domain`` + ``Domain`` attribute — restricts which hosts receive the cookie. +``path`` + ``Path`` attribute (default ``'/'``). +``secure`` + Set the ``Secure`` flag. Required when ``same_site='None'``. +``same_site`` + ``SameSite`` attribute. Accepted values: + + * ``'Strict'`` — cookie is sent only in same-site requests. + * ``'Lax'`` — cookie is sent in same-site requests and top-level + navigations (browsers default to this when the attribute is absent). + * ``'None'`` — cookie is sent in all contexts including cross-site + requests. **Requires** ``secure=True``. + * ``False`` (default) — the attribute is omitted entirely. + + Passing any other value raises ``ValueError``. Passing ``'None'`` + without ``secure=True`` also raises ``ValueError``. + +It exposes ``load()``, ``write()``, ``destroy()``, and ``header()`` methods. .. note:: @@ -1117,6 +1141,13 @@ variable or the ``Application.secret_key`` property. * **Cookie format**: ``base64(ciphertext).base64(hmac-sha256)`` +``PoorSession`` accepts the same cookie keyword arguments as ``Session`` +(``sid``, ``expires``, ``max_age``, ``domain``, ``path``, ``secure``, +``same_site``) plus ``compress`` and ``secret_key`` (positional). +The constructor raises ``SessionError`` if ``secret_key`` is empty, and +``ValueError`` for an invalid ``same_site`` value or the ``'None'`` / +``secure=False`` combination (see `Session`_ for details). + The ``KEYSTREAM_SIZE`` constant in ``poorwsgi.session`` controls the keystream length (default ``1024``). Increasing it makes known-plaintext attacks harder at the cost of slightly larger memory usage per session instance. Changing it @@ -1206,6 +1237,13 @@ A 16-byte random nonce is generated on every ``write()`` call to prevent CTR nonce reuse. A missing or tampered signature causes ``SessionError`` to be raised in ``load()``. +``AESSession`` accepts the same cookie keyword arguments as ``Session`` +(``sid``, ``expires``, ``max_age``, ``domain``, ``path``, ``secure``, +``same_site``) plus ``secret_key`` (positional). +The constructor raises ``SessionError`` if ``secret_key`` is empty, and +``ValueError`` for an invalid ``same_site`` value or the ``'None'`` / +``secure=False`` combination (see `Session`_ for details). + JSON Web Tokens ``````````````` diff --git a/poorwsgi/aes_session.py b/poorwsgi/aes_session.py index 411fa74..92914db 100644 --- a/poorwsgi/aes_session.py +++ b/poorwsgi/aes_session.py @@ -13,7 +13,7 @@ from json import dumps, loads from logging import getLogger from os import urandom -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, Literal, Optional, Union from pyaes import ( # type: ignore[import-untyped] AESModeOfOperationCTR, Counter) @@ -54,7 +54,7 @@ def __init__( # pylint: disable=too-many-positional-arguments domain: str = '', path: str = '/', secure: bool = False, - same_site: Union[str, bool] = False, + same_site: Union[str, Literal[False]] = False, sid: str = 'SESSID'): """Constructor. @@ -72,10 +72,22 @@ def __init__( # pylint: disable=too-many-positional-arguments secure If ``True``, set the ``Secure`` cookie attribute. same_site - The ``SameSite`` attribute value (``'Strict'``, ``'Lax'``, - ``'None'``) or ``False`` to omit it. + The ``SameSite`` cookie attribute. Accepted values are + ``'Strict'``, ``'Lax'``, ``'None'``, or ``False``. + ``False`` (the default) omits the attribute entirely, + which browsers treat as ``'Lax'``. ``'None'`` permits + the cookie in cross-site requests but requires + ``secure=True``. sid Cookie name. + + Raises: + SessionError + If *secret_key* is empty. + ValueError + If *same_site* is not one of ``'Strict'``, ``'Lax'``, + ``'None'``, or ``False``; or if *same_site* is ``'None'`` + and *secure* is ``False``. """ if not secret_key: raise SessionError("Empty secret_key") diff --git a/poorwsgi/digest.py b/poorwsgi/digest.py index 739f358..045432d 100644 --- a/poorwsgi/digest.py +++ b/poorwsgi/digest.py @@ -12,7 +12,8 @@ # pylint: disable=consider-using-f-string from argparse import ArgumentParser -from hashlib import md5, sha256 +from hashlib import md5, sha256 # nosec B324 - required by RFC 7616 +from hmac import compare_digest from traceback import print_exc from getpass import getpass from os.path import exists @@ -68,7 +69,7 @@ def check_response(req, password): response = req.app.auth_hash( '{hash1}:{nonce}:{hash2}' ''.format(**kwargs).encode()).hexdigest() - return response == kwargs['response'] + return compare_digest(response, kwargs['response']) def check_credentials(req, realm, username=None): @@ -182,6 +183,9 @@ def hexdigest(username, realm, password, algorithm=md5): """Returns the digest hash value for a user's password. Returns algorithm(username:realm:password).hexdigest(). + + The default algorithm is MD5 for compatibility with the htdigest file + format (RFC 2617). Use algorithm=sha256 for stronger hashing. """ return algorithm( ('%s:%s:%s' % (username, realm, password)).encode() @@ -217,7 +221,7 @@ def find(self, realm, username): def verify(self, realm, username, digest): """Checks the digest in the password map.""" digest_ = self.find(realm, username) - return bool(digest_) and digest_ == digest + return bool(digest_) and compare_digest(digest_, digest) def load(self): """Loads the map from a file.""" @@ -302,7 +306,7 @@ def main(): # noqa: C901 password = get_re_type() if password is None: return 1 - digest = hexdigest(args.username, args.realm, password) + digest = hexdigest(args.username, args.realm, password, algorithm) print('%s:%s:%s' % (args.username, args.realm, digest)) return 0 diff --git a/poorwsgi/session.py b/poorwsgi/session.py index df53f4c..7408cd9 100644 --- a/poorwsgi/session.py +++ b/poorwsgi/session.py @@ -25,7 +25,7 @@ from logging import getLogger from random import Random from time import time -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, Literal, Optional, Union from http.cookies import SimpleCookie @@ -62,7 +62,7 @@ def hidden(text: Union[str, bytes], secret_hash: bytes) -> bytes: for i, val in enumerate(text): retval.append(val ^ secret_hash[i % secret_len]) - return retval + return bytes(retval) def encrypt(data: bytes, table: bytearray) -> bytes: @@ -175,7 +175,8 @@ class Session: def __init__(self, expires: int = 0, max_age: Optional[int] = None, domain: str = '', path: str = '/', secure: bool = False, - same_site: Union[str, bool] = False, sid: str = 'SESSID'): + same_site: Union[str, Literal[False]] = False, + sid: str = 'SESSID'): """Constructor. Arguments: @@ -192,12 +193,27 @@ def __init__(self, expires: int = 0, max_age: Optional[int] = None, secure If the ``Secure`` cookie attribute will be sent. same_site - The ``SameSite`` attribute. When set, it can be one of - ``Strict|Lax|None``. By default, the attribute is not - set, which browsers default to ``Lax``. + The ``SameSite`` cookie attribute. Accepted values are + ``'Strict'``, ``'Lax'``, ``'None'``, or ``False``. + ``False`` (the default) omits the attribute entirely, + which browsers treat as ``'Lax'``. ``'None'`` permits + the cookie in cross-site requests but requires + ``secure=True``. sid The cookie key name. + + Raises: + ValueError + If *same_site* is not one of ``'Strict'``, ``'Lax'``, + ``'None'``, or ``False``; or if *same_site* is ``'None'`` + and *secure* is ``False``. """ + if same_site not in ("Strict", "Lax", "None", False): + msg = (f"same_site={same_site!r} is not valid; " + "use 'Strict', 'Lax', 'None', or False") + raise ValueError(msg) + if same_site == "None" and not secure: + raise ValueError("same_site='None' requires secure=True") self._sid = sid self.__expires = expires self.__max_age = max_age @@ -209,6 +225,7 @@ def __init__(self, expires: int = 0, max_age: Optional[int] = None, self.data: Any = "" self.cookie: SimpleCookie = SimpleCookie() self.cookie[sid] = '' + self._destroyed: bool = False def _apply_cookie_attrs(self): """Apply security and configuration attributes to the session cookie. @@ -216,6 +233,8 @@ def _apply_cookie_attrs(self): Called by ``write`` and subclass overrides of ``write``. Sets ``HttpOnly``, ``Domain``, ``Path``, ``Secure``, ``SameSite``, ``Expires``, and ``Max-Age`` as configured. + Skips ``Expires`` and ``Max-Age`` when ``destroy()`` was already + called, so those fields are not overwritten with the original values. """ self.cookie[self._sid]['HttpOnly'] = True if self.__domain: @@ -226,10 +245,11 @@ def _apply_cookie_attrs(self): self.cookie[self._sid]['Secure'] = True if self.__same_site: self.cookie[self._sid]['SameSite'] = self.__same_site - if self.__expires: - self.cookie[self._sid]['expires'] = self.__expires - if self.__max_age is not None: - self.cookie[self._sid]['Max-Age'] = self.__max_age + if not self._destroyed: + if self.__expires: + self.cookie[self._sid]['expires'] = self.__expires + if self.__max_age is not None: + self.cookie[self._sid]['Max-Age'] = self.__max_age def load(self, cookies: Optional[SimpleCookie]): """Load the session value from the request's cookies. @@ -259,6 +279,7 @@ def destroy(self): Ensures that data cannot be changed: https://stackoverflow.com/a/5285982/8379994 """ + self._destroyed = True self.cookie[self._sid]['expires'] = -1 if self.__max_age is not None: self.cookie[self._sid]['Max-Age'] = -1 @@ -332,7 +353,7 @@ def to_dict(self): def __init__(self, secret_key: Union[str, bytes], expires: int = 0, max_age: Optional[int] = None, domain: str = '', path: str = '/', secure: bool = False, - same_site: Union[str, bool] = False, compress=bz2, + same_site: Union[str, Literal[False]] = False, compress=bz2, sid: str = 'SESSID'): """Constructor. @@ -352,9 +373,12 @@ def __init__(self, secret_key: Union[str, bytes], secure If the ``Secure`` cookie attribute will be sent. same_site - The ``SameSite`` attribute. When set, it can be one of - ``Strict|Lax|None``. By default, the attribute is not - set, which browsers default to ``Lax``. + The ``SameSite`` cookie attribute. Accepted values are + ``'Strict'``, ``'Lax'``, ``'None'``, or ``False``. + ``False`` (the default) omits the attribute entirely, + which browsers treat as ``'Lax'``. ``'None'`` permits + the cookie in cross-site requests but requires + ``secure=True``. compress Can be ``bz2``, ``gzip.zlib``, or any other, which has standard compress and decompress methods. Or it can be @@ -383,6 +407,14 @@ def __init__(self, secret_key: Union[str, bytes], *Changed in version 2.4.x*: Use app.secret_key in the constructor, and then call the load method. + + Raises: + ValueError + If *same_site* is not one of ``'Strict'``, ``'Lax'``, + ``'None'``, or ``False``; or if *same_site* is ``'None'`` + and *secure* is ``False``. + SessionError + If *secret_key* is empty. """ super().__init__(expires=expires, max_age=max_age, domain=domain, path=path, secure=secure, same_site=same_site, diff --git a/poorwsgi/state.py b/poorwsgi/state.py index 7421673..6181e35 100644 --- a/poorwsgi/state.py +++ b/poorwsgi/state.py @@ -8,9 +8,9 @@ import warnings __author__ = "Ondrej Tuma (McBig) " -__date__ = "5 May 2026" +__date__ = "20 Jun 2026" # PEP 0386 -- Version Identification and Dependency Specification -__version__ = "2.8.0" +__version__ = "2.8.1" DECLINED = 0 diff --git a/ruff.toml b/pyproject.toml similarity index 86% rename from ruff.toml rename to pyproject.toml index 67ce596..fbf0ee8 100644 --- a/ruff.toml +++ b/pyproject.toml @@ -1,6 +1,20 @@ +# Central configuration file. +# Packaging stays in setup.py until a full migration to pyproject.toml. + +[tool.pytest.ini_options] +testpaths = ["tests"] + +[tool.coverage.run] +source = ["poorwsgi"] + +[tool.isort] +line_length = 79 + +[tool.ruff] line-length = 79 -lint.select = [ +[tool.ruff.lint] +select = [ "F", # pyflakes "E", # pycodestyle "W", # pycodestyle @@ -50,8 +64,7 @@ lint.select = [ "NPY", # NumPy-specific rules "RUF", # Ruff-specific rules ] - -lint.ignore = [ +ignore = [ "Q000", # [*] Single quotes found but double quotes preferred "I001", # [*] Import block is un-sorted or un-formatted "COM812", # [*] Trailing comma missing @@ -93,7 +106,7 @@ lint.ignore = [ "RET504", # Unnecessary variable assignment before `return` statement "F841", # [*] Local variable `accepted` is assigned to but never used "F401", # [*] `time.time` imported but unused - "RUF005", # [*] Consider `[('share/doc/poorwsgi', ['doc/ChangeLog', 'doc/licence.txt', 'README.rst', 'CONTRIBUTION.rst']), *find_data_files('examples', 'share/poorwsgi/examples')]` instead of concatenation + "RUF005", # [*] Consider `[...]` instead of concatenation "G010", # [*] Logging statement uses `warn` instead of `warning` "PGH002", # `warn` is deprecated in favor of `warning` "C400", # [*] Unnecessary generator (rewrite as a `list` comprehension) @@ -109,24 +122,22 @@ lint.ignore = [ "ARG004", # Unused static method argument: `compresslevel` "ARG002", # Unused method argument: `start_response` "G001", # Logging statement uses `string.format()` - "SIM108", # [*] Use ternary operator `val = CgiFieldStorage.getlist(self, key) if key in self else default or []` instead of `if`-`else`-block - "A003", # Class attribute `input` is shadowing a Python builtin - "SIM108", # [*] Use ternary operator `val = CgiFieldStorage.getlist(self, key) if key in self else default or []` instead of `if`-`else`-block + "SIM108", # [*] Use ternary operator instead of `if`-`else`-block "A003", # Class attribute `input` is shadowing a Python builtin "C402", # [*] Unnecessary generator (rewrite as a `dict` comprehension) "C417", # [*] Unnecessary `map` usage (rewrite using a generator expression) "PLR0915", # Too many statements (52 > 50) "EM103", # [*] Exception must not use a `.format()` string directly, assign to variable first - "DTZ003", # The use of `datetime.datetime.utcnow()` is not allowed, use `datetime.datetime.now(tz=)` instead + "DTZ003", # The use of `datetime.datetime.utcnow()` is not allowed "PGH004", # Use specific rule codes when using `noqa` - "INP001", # File `examples/websocket.py` is part of an implicit namespace package. Add an `__init__.py`. + "INP001", # File is part of an implicit namespace package. Add an `__init__.py`. "ARG001", # Unused function argument: `req` "S324", # Probable use of insecure hash functions in `hashlib`: `md5` - "SIM118", # [*] Use `key in req.form` instead of `key in req.form.keys()` - "SIM105", # [*] Use `contextlib.suppress(SessionError)` instead of `try`-`except`-`pass` + "SIM118", # [*] Use `key in dict` instead of `key in dict.keys()` + "SIM105", # [*] Use `contextlib.suppress(...)` instead of `try`-`except`-`pass` "PIE810", # [*] Call `startswith` once with a `tuple` "C414", # [*] Unnecessary `tuple` call within `sorted()` - "TRY401", # Redundant exception object included in `logging.exception` call¸ + "TRY401", # Redundant exception object included in `logging.exception` call "TRY002", # Create your own exception "E402", # Module level import not at top of file "RUF100", # [*] Unused blanket `noqa` directive diff --git a/tests/test_aes_session.py b/tests/test_aes_session.py index 75851df..f3dbef3 100644 --- a/tests/test_aes_session.py +++ b/tests/test_aes_session.py @@ -1,10 +1,16 @@ """Unit tests for AESSession class.""" +import hmac as _hmac +from base64 import urlsafe_b64encode +from hashlib import sha256, sha3_256 +from json import dumps from os import urandom from http.cookies import SimpleCookie +from pyaes import ( # type: ignore[import-untyped] + AESModeOfOperationCTR, Counter) from pytest import fixture, raises -from poorwsgi.aes_session import AESSession +from poorwsgi.aes_session import AESSession, _NONCE_SIZE from poorwsgi.session import Session, SessionError SECRET_KEY = urandom(32) @@ -41,6 +47,72 @@ def test_empty_string_secret_key(self): with raises(SessionError): AESSession("") + def test_same_site_invalid_raises(self): + """Unrecognised same_site value must raise ValueError.""" + with raises(ValueError, match="is not valid"): + AESSession(SECRET_KEY, same_site=True) + + def test_same_site_none_without_secure_raises(self): + """same_site='None' without secure=True must raise ValueError.""" + with raises(ValueError, match="requires secure=True"): + AESSession(SECRET_KEY, same_site="None", secure=False) + + def test_string_secret_key(self): + """AESSession accepts a str key (encodes to bytes internally).""" + session = AESSession("string-secret-key") + session.data['x'] = 1 + session.write() + session2 = AESSession("string-secret-key") + session2.load(session.cookie) + assert session2.data == {'x': 1} + + def test_load_missing_sid(self): + """load() with no matching cookie name leaves data unchanged.""" + session = AESSession(SECRET_KEY) + session.load(SimpleCookie()) + assert session.data == {} + + def test_load_empty_cookie_value(self): + """load() with an empty cookie value leaves data unchanged.""" + cookies = SimpleCookie() + cookies['SESSID'] = '' + session = AESSession(SECRET_KEY) + session.load(cookies) + assert session.data == {} + + def test_load_short_payload(self): + """Payload shorter than the nonce size must raise SessionError.""" + root = sha3_256(SECRET_KEY).digest() + mac_key = sha3_256(root + b'mac').digest() + short_payload = b'\x00' * (_NONCE_SIZE - 1) + digest = _hmac.digest(mac_key, short_payload, digest=sha256) + raw = (urlsafe_b64encode(short_payload) + + b'.' + + urlsafe_b64encode(digest)) + cookies = SimpleCookie() + cookies['SESSID'] = raw.decode() + session = AESSession(SECRET_KEY) + with raises(SessionError): + session.load(cookies) + + def test_load_non_dict_data(self): + """Non-dict cookie data must raise SessionError.""" + root = sha3_256(SECRET_KEY).digest() + enc_key = sha3_256(root + b'enc').digest() + mac_key = sha3_256(root + b'mac').digest() + nonce = urandom(_NONCE_SIZE) + counter = Counter(initial_value=int.from_bytes(nonce, 'big')) + aes = AESModeOfOperationCTR(enc_key, counter=counter) + ciphertext = aes.encrypt(dumps([1, 2, 3])) + payload = nonce + ciphertext + digest = _hmac.digest(mac_key, payload, digest=sha256) + raw = urlsafe_b64encode(payload) + b'.' + urlsafe_b64encode(digest) + cookies = SimpleCookie() + cookies['SESSID'] = raw.decode() + session = AESSession(SECRET_KEY) + with raises(SessionError): + session.load(cookies) + def test_bad_session_data(self): cookies = SimpleCookie() cookies["SESSID"] = "notvalidbase64!!!" diff --git a/tests/test_digest.py b/tests/test_digest.py index f74d4f7..6e3bca4 100644 --- a/tests/test_digest.py +++ b/tests/test_digest.py @@ -1,12 +1,26 @@ """Tests for digest functionality.""" +# pylint: disable=too-many-lines from collections import defaultdict +from hashlib import md5, sha256 from os.path import dirname, join, pardir +from unittest.mock import MagicMock, patch +import pytest from pytest import fixture from poorwsgi import Application -from poorwsgi.digest import PasswordMap, hexdigest +from poorwsgi.digest import ( + PasswordMap, + check_credentials, + check_digest, + check_response, + get_re_type, + hexdigest, + main, +) from poorwsgi.request import Request +from poorwsgi.response import HTTPException +from poorwsgi import state FILE = join(dirname(__file__), pardir, "examples/test.digest") REALM = 'User Zone' @@ -56,11 +70,68 @@ def app(): def req(app): env = defaultdict(str) env['PATH_INFO'] = '/user' + env['REQUEST_METHOD'] = 'GET' env['HTTP_AUTHORIZATION'] = HEADER req = Request(env, app) return req +def _make_req(auth, hostname='testhost', path='/resource', + algorithm='MD5', qop=None, method='GET', + auth_map=None): + """Return a MagicMock request with digest auth attributes set.""" + req = MagicMock() + req.authorization = auth.copy() + req.server_hostname = hostname + req.full_path = path + req.app.auth_algorithm = algorithm + req.app.auth_hash = md5 + req.app.auth_qop = qop + req.app.auth_map = auth_map or {} + req.method = method + return req + + +def _build_credentials( # pylint: disable=too-many-locals + username, realm, password, hostname, path, + method='GET', algorithm='MD5', qop=None, + nonce='testnonce', cnonce='testcnonce', nc='00000001'): + """Compute a self-consistent Digest auth dict for the given parameters.""" + digest = hexdigest(username, realm, password) + opaque = sha256(hostname.encode()).hexdigest() + uses_sess = algorithm.endswith('-sess') + hash1 = digest + if uses_sess: + # RFC 7616 §3.4: -sess requires cnonce for HA1 regardless of qop + hash1 = md5(f'{digest}:{nonce}:{cnonce}'.encode()).hexdigest() + hash2 = md5(f'{method}:{path}'.encode()).hexdigest() + if qop: + response = md5( + f'{hash1}:{nonce}:{nc}:{cnonce}:{qop}:{hash2}'.encode() + ).hexdigest() + else: + response = md5(f'{hash1}:{nonce}:{hash2}'.encode()).hexdigest() + + auth = { + 'type': 'Digest', + 'username': username, + 'realm': realm, + 'nonce': nonce, + 'uri': path, + 'algorithm': algorithm, + 'opaque': opaque, + 'response': response, + } + if qop: + auth['qop'] = qop + auth['nc'] = nc + auth['cnonce'] = cnonce + elif uses_sess: + # cnonce required by -sess even without qop + auth['cnonce'] = cnonce + return auth, digest + + class TestMap(): """Tests for the PasswordMap class.""" @@ -96,12 +167,466 @@ def test_load(self): pmap.load() assert pmap.verify(REALM, USER, DIGEST) is True + def test_load_no_pathname(self): + """load() without pathname raises RuntimeError.""" + with pytest.raises(RuntimeError, match="No pathname"): + PasswordMap().load() + + def test_write_no_pathname(self): + """write() without pathname raises RuntimeError.""" + with pytest.raises(RuntimeError, match="No pathname"): + PasswordMap().write() + + def test_write_roundtrip(self, tmp_path): + """Written file can be loaded back and verified.""" + path = str(tmp_path / "test.digest") + pmap = PasswordMap(path) + pmap.set(REALM, USER, DIGEST) + pmap.write() + + loaded = PasswordMap(path) + loaded.load() + assert loaded.verify(REALM, USER, DIGEST) is True + + def test_write_multiple_realms(self, tmp_path): + """All entries across multiple realms survive a write/load cycle.""" + path = str(tmp_path / "multi.digest") + pmap = PasswordMap(path) + pmap.set(REALM, 'alice', hexdigest('alice', REALM, 'pass1')) + pmap.set(REALM, 'bob', hexdigest('bob', REALM, 'pass2')) + pmap.set('Admin Zone', 'admin', hexdigest('admin', 'Admin Zone', 'x')) + pmap.write() + + loaded = PasswordMap(path) + loaded.load() + assert loaded.find(REALM, 'alice') is not None + assert loaded.find(REALM, 'bob') is not None + assert loaded.find('Admin Zone', 'admin') is not None + def test_hexdigest(): """Tests the hexdigest function.""" assert hexdigest(USER, REALM, 'looser') == DIGEST +def test_hexdigest_sha256(): + """SHA-256 variant produces a different, longer digest.""" + d = hexdigest(USER, REALM, 'looser', algorithm=sha256) + assert d != DIGEST + assert len(d) == 64 # SHA-256 hex is 64 chars vs MD5's 32 + + def test_header_parsing(req): """Tests parsing the Authorization header.""" assert req.authorization == DICT + + +class TestCheckResponse: + """RFC 7616 §3.4 — response field computation.""" + + def test_valid_md5_sess_qop_auth(self, req): + """MD5-sess + qop=auth response must validate against known value.""" + assert check_response(req, DIGEST) is True + + def test_wrong_password_rejected(self, req): + """A mismatched password hash must return False.""" + assert check_response(req, 'a' * 32) is False + + def test_plain_md5_without_qop(self): + """Plain MD5, no qop: response = MD5(H1:nonce:H2) (RFC 2617 legacy).""" + auth, digest = _build_credentials( + USER, REALM, 'secret', 'host', '/path', + algorithm='MD5', qop=None, + ) + req = _make_req(auth, hostname='host', path='/path', + algorithm='MD5', qop=None, + auth_map={REALM: {USER: digest}}) + assert check_response(req, digest) is True + + def test_md5_with_qop_auth(self): + """Plain MD5 + qop=auth: response = MD5(H1:nonce:nc:cnonce:qop:H2).""" + auth, digest = _build_credentials( + USER, REALM, 'secret', 'host', '/path', + algorithm='MD5', qop='auth', + ) + req = _make_req(auth, hostname='host', path='/path', + algorithm='MD5', qop='auth', + auth_map={REALM: {USER: digest}}) + assert check_response(req, digest) is True + + def test_md5_sess_without_qop(self): + """MD5-sess without qop: HA1 = MD5(stored:nonce:cnonce).""" + auth, digest = _build_credentials( + USER, REALM, 'secret', 'host', '/path', + algorithm='MD5-sess', qop=None, + ) + req = _make_req(auth, hostname='host', path='/path', + algorithm='MD5-sess', qop=None, + auth_map={REALM: {USER: digest}}) + assert check_response(req, digest) is True + + def test_tampered_response_rejected(self): + """Altering the response field by one character must return False.""" + auth, digest = _build_credentials( + USER, REALM, 'secret', 'host', '/path', + algorithm='MD5', qop='auth', + ) + auth['response'] = auth['response'][:-1] + ( + 'f' if auth['response'][-1] != 'f' else 'e' + ) + req = _make_req(auth, hostname='host', path='/path', + algorithm='MD5', qop='auth', + auth_map={REALM: {USER: digest}}) + assert check_response(req, digest) is False + + def test_uri_mismatch_changes_response(self): + """Changing the URI invalidates the pre-computed response hash.""" + auth, digest = _build_credentials( + USER, REALM, 'secret', 'host', '/path', + algorithm='MD5', qop='auth', + ) + auth['uri'] = '/other' # change URI but keep old response + req = _make_req(auth, hostname='host', path='/other', + algorithm='MD5', qop='auth', + auth_map={REALM: {USER: digest}}) + assert check_response(req, digest) is False + + +class TestCheckCredentials: + """RFC 7616 §3.3 — server-side credential validation.""" + + def _valid_req(self, **overrides): + """Build a request with fully valid credentials, apply overrides.""" + auth, digest = _build_credentials( + USER, REALM, 'secret', 'myhost', '/res', + algorithm='MD5', qop='auth', + ) + auth.update(overrides.pop('auth', {})) + req = _make_req(auth, hostname='myhost', path='/res', + algorithm='MD5', qop='auth', + auth_map={REALM: {USER: digest}}) + for key, val in overrides.items(): + setattr(req, key, val) + return req + + def test_valid_passes(self): + """Fully valid credentials must return True.""" + assert check_credentials(self._valid_req(), REALM) is True + + def test_algorithm_mismatch(self): + """algorithm field different from app.auth_algorithm → False.""" + req = self._valid_req(auth={'algorithm': 'SHA-256'}) + assert check_credentials(req, REALM) is False + + def test_opaque_mismatch(self): + """Opaque not matching sha256(server_hostname) → False.""" + req = self._valid_req(auth={'opaque': 'deadbeef' * 8}) + assert check_credentials(req, REALM) is False + + def test_uri_not_suffix_of_path(self): + """URI in header not ending with req.full_path → False.""" + auth, digest = _build_credentials( + USER, REALM, 'secret', 'myhost', '/res', + algorithm='MD5', qop='auth', + ) + auth['uri'] = '/other/path' + req = _make_req(auth, hostname='myhost', path='/res', + algorithm='MD5', qop='auth', + auth_map={REALM: {USER: digest}}) + assert check_credentials(req, REALM) is False + + def test_uri_prefix_accepted(self): + """URI with query string that still ends with the path is accepted.""" + auth, digest = _build_credentials( + USER, REALM, 'secret', 'myhost', '/res', + algorithm='MD5', qop='auth', + ) + # auth['uri'] is already '/res' which ends with '/res' + req = _make_req(auth, hostname='myhost', path='/res', + algorithm='MD5', qop='auth', + auth_map={REALM: {USER: digest}}) + assert check_credentials(req, REALM) is True + + def test_qop_mismatch(self): + """qop in header different from app.auth_qop → False.""" + req = self._valid_req(auth={'qop': 'auth-int'}) + assert check_credentials(req, REALM) is False + + def test_realm_mismatch(self): + """realm in header different from expected realm → False.""" + assert check_credentials(self._valid_req(), 'Wrong Realm') is False + + def test_username_filter_match(self): + """Correct username with username filter → True.""" + assert check_credentials(self._valid_req(), REALM, username=USER) \ + is True + + def test_username_filter_mismatch(self): + """Wrong username with username filter → False.""" + assert check_credentials(self._valid_req(), REALM, username='admin') \ + is False + + def test_missing_response_field(self): + """Authorization without response field → False.""" + auth, _ = _build_credentials( + USER, REALM, 'secret', 'myhost', '/res', + algorithm='MD5', qop='auth', + ) + del auth['response'] + secret_digest = hexdigest(USER, REALM, 'secret') + req = _make_req(auth, hostname='myhost', path='/res', + algorithm='MD5', qop='auth', + auth_map={REALM: {USER: secret_digest}}) + assert check_credentials(req, REALM) is False + + def test_user_not_in_auth_map(self): + """Username absent from auth_map → False.""" + auth, _ = _build_credentials( + USER, REALM, 'secret', 'myhost', '/res', + algorithm='MD5', qop='auth', + ) + req = _make_req(auth, hostname='myhost', path='/res', + algorithm='MD5', qop='auth', + auth_map={}) # empty map + assert check_credentials(req, REALM) is False + + def test_wrong_password_in_map(self): + """Auth map has different password than presented → False.""" + auth, _ = _build_credentials( + USER, REALM, 'secret', 'myhost', '/res', + algorithm='MD5', qop='auth', + ) + wrong_digest = hexdigest(USER, REALM, 'wrong') + req = _make_req(auth, hostname='myhost', path='/res', + algorithm='MD5', qop='auth', + auth_map={REALM: {USER: wrong_digest}}) + assert check_credentials(req, REALM) is False + + +class TestCheckDigest: + """check_digest decorator — RFC 7616 §3.3 (server gate).""" + + @staticmethod + def _handler(_req): + return 'ok' + + def test_missing_authorization_header(self): + """No Authorization header → HTTP 401 Unauthorized.""" + req = MagicMock() + req.headers = {} + + with pytest.raises(HTTPException) as exc_info: + check_digest(REALM)(self._handler)(req) + assert exc_info.value.status_code == state.HTTP_UNAUTHORIZED + + def test_non_digest_auth_type(self): + """Basic auth type in Authorization → HTTP 401.""" + req = MagicMock() + req.headers = {'Authorization': 'Basic dXNlcjpwYXNz'} + req.authorization = {'type': 'Basic'} + + with pytest.raises(HTTPException) as exc_info: + check_digest(REALM)(self._handler)(req) + assert exc_info.value.status_code == state.HTTP_UNAUTHORIZED + + def test_invalid_nonce_sets_stale(self): + """Expired nonce → HTTP 401 with stale attribute in realm.""" + req = MagicMock() + req.headers = {'Authorization': HEADER} + req.authorization = DICT.copy() + + with patch('poorwsgi.digest.check_token', return_value=False): + with pytest.raises(HTTPException) as exc_info: + check_digest(REALM)(self._handler)(req) + assert exc_info.value.status_code == state.HTTP_UNAUTHORIZED + + def test_invalid_credentials(self): + """Valid nonce but wrong credentials → HTTP 401.""" + req = MagicMock() + req.headers = {'Authorization': HEADER} + req.authorization = DICT.copy() + + with patch('poorwsgi.digest.check_token', return_value=True): + with patch('poorwsgi.digest.check_credentials', + return_value=False): + with pytest.raises(HTTPException) as exc_info: + check_digest(REALM)(self._handler)(req) + assert exc_info.value.status_code == state.HTTP_UNAUTHORIZED + + def test_valid_auth_sets_user_and_calls_handler(self): + """Valid credentials set req.user and return the handler result.""" + req = MagicMock() + req.headers = {'Authorization': HEADER} + req.authorization = DICT.copy() + + with patch('poorwsgi.digest.check_token', return_value=True): + with patch('poorwsgi.digest.check_credentials', return_value=True): + result = check_digest(REALM)(self._handler)(req) + assert result == 'ok' + assert req.user == USER + + def test_username_filter_passed_through(self): + """Username filter is forwarded to check_credentials.""" + req = MagicMock() + req.headers = {'Authorization': HEADER} + req.authorization = DICT.copy() + + with patch('poorwsgi.digest.check_token', return_value=True): + with patch('poorwsgi.digest.check_credentials', + return_value=True) as mock_cc: + check_digest(REALM, username='alice')(self._handler)(req) + mock_cc.assert_called_once_with(req, REALM, 'alice') + + def test_wraps_preserves_name(self): + """@check_digest must preserve the wrapped function's __name__.""" + @check_digest(REALM) + def my_view(_req): # pylint: disable=unused-argument + return 'ok' + assert my_view.__name__ == 'my_view' + + +class TestGetReType: + """Tests for the get_re_type helper (interactive password input).""" + + def test_matching_passwords_returned(self): + """Returns the password when both prompts match.""" + with patch('poorwsgi.digest.getpass', + side_effect=['secret', 'secret']): + assert get_re_type() == 'secret' + + def test_mismatched_passwords_returns_none(self, capsys): + """Returns None and prints an error when passwords differ.""" + with patch('poorwsgi.digest.getpass', + side_effect=['secret', 'wrong']): + result = get_re_type() + assert result is None + out, _ = capsys.readouterr() + assert "don't match" in out + + +class TestMain: + """CLI tests for the htdigest-like main() function.""" + + def test_display_only_md5(self, capsys): + """-n flag prints username:realm:digest to stdout and returns 0.""" + with patch('sys.argv', ['digest', '-n', REALM, USER]): + with patch('poorwsgi.digest.get_re_type', return_value='looser'): + rc = main() + out, _ = capsys.readouterr() + assert rc == 0 + assert f'{USER}:{REALM}:{DIGEST}' in out + + def test_display_only_sha256(self, capsys): + """-n -s flag uses SHA-256 hash.""" + with patch('sys.argv', ['digest', '-n', '-s', REALM, USER]): + with patch('poorwsgi.digest.get_re_type', return_value='looser'): + rc = main() + out, _ = capsys.readouterr() + assert rc == 0 + expected = hexdigest(USER, REALM, 'looser', algorithm=sha256) + assert expected in out + + def test_display_only_password_mismatch(self): + """-n with mismatched password re-type returns 1.""" + with patch('sys.argv', ['digest', '-n', REALM, USER]): + with patch('poorwsgi.digest.get_re_type', return_value=None): + rc = main() + assert rc == 1 + + def test_create_and_verify(self, tmp_path, capsys): + """Create a new file, then verify the password.""" + path = str(tmp_path / 'new.digest') + with patch('sys.argv', ['digest', '-c', path, REALM, USER]): + with patch('poorwsgi.digest.get_re_type', return_value='looser'): + rc = main() + assert rc == 0 + + with patch('sys.argv', ['digest', '-v', path, REALM, USER]): + with patch('poorwsgi.digest.getpass', return_value='looser'): + rc = main() + capsys.readouterr() + assert rc == 0 + + def test_verify_wrong_password(self, tmp_path): + """Verify with wrong password returns 2.""" + path = str(tmp_path / 'new.digest') + with patch('sys.argv', ['digest', '-c', path, REALM, USER]): + with patch('poorwsgi.digest.get_re_type', return_value='looser'): + main() + + with patch('sys.argv', ['digest', '-v', path, REALM, USER]): + with patch('poorwsgi.digest.getpass', return_value='wrong'): + rc = main() + assert rc == 2 + + def test_add_and_delete_user(self, tmp_path, capsys): + """Add a user then delete it; second delete returns False.""" + path = str(tmp_path / 'del.digest') + with patch('sys.argv', ['digest', '-c', path, REALM, USER]): + with patch('poorwsgi.digest.get_re_type', return_value='looser'): + main() + + with patch('sys.argv', ['digest', '-D', path, REALM, USER]): + rc = main() + assert rc == 0 + out, _ = capsys.readouterr() + assert 'deleted' in out.lower() + + def test_no_passwordfile_required(self): + """Omitting password file without -n or -c triggers parser error.""" + with patch('sys.argv', ['digest', REALM, USER]): + with pytest.raises(SystemExit): + main() + + def test_delete_nonexistent_user(self, tmp_path, capsys): + """Deleting a user not in the file prints 'not found'.""" + path = str(tmp_path / 'del.digest') + with patch('sys.argv', ['digest', '-c', path, REALM, USER]): + with patch('poorwsgi.digest.get_re_type', return_value='looser'): + main() + + with patch('sys.argv', ['digest', '-D', path, REALM, 'nobody']): + rc = main() + out, _ = capsys.readouterr() + assert rc == 0 + assert 'not found' in out.lower() + + def test_update_existing_user(self, tmp_path, capsys): + """Updating an existing user prints 'Changing password' message.""" + path = str(tmp_path / 'update.digest') + with patch('sys.argv', ['digest', '-c', path, REALM, USER]): + with patch('poorwsgi.digest.get_re_type', return_value='looser'): + main() + capsys.readouterr() + + with patch('sys.argv', ['digest', path, REALM, USER]): + with patch('poorwsgi.digest.get_re_type', return_value='newpass'): + rc = main() + out, _ = capsys.readouterr() + assert rc == 0 + assert 'changing' in out.lower() + + def test_add_password_mismatch(self, tmp_path): + """Mismatched re-type during add returns 1.""" + path = str(tmp_path / 'new.digest') + with patch('sys.argv', ['digest', '-c', path, REALM, USER]): + with patch('poorwsgi.digest.get_re_type', return_value=None): + rc = main() + assert rc == 1 + + def test_missing_file_with_verify(self, tmp_path): + """Using -v on a missing file triggers parser error.""" + path = str(tmp_path / 'absent.digest') + with patch('sys.argv', ['digest', '-v', path, REALM, USER]): + with pytest.raises(SystemExit): + main() + + def test_exception_in_operation(self, tmp_path): + """An unexpected exception during an operation returns 1.""" + path = str(tmp_path / 'new.digest') + with patch('sys.argv', ['digest', '-c', path, REALM, USER]): + with patch('poorwsgi.digest.get_re_type', return_value='pass'): + with patch('poorwsgi.digest.hexdigest', + side_effect=ValueError('boom')): + with pytest.raises(SystemExit): + main() diff --git a/tests/test_fieldstorage.py b/tests/test_fieldstorage.py new file mode 100644 index 0000000..89053fd --- /dev/null +++ b/tests/test_fieldstorage.py @@ -0,0 +1,1081 @@ +"""Unit tests for poorwsgi/fieldstorage.py. + +Focuses on HTTP-level behavior: multipart/form-data (RFC 7578), +application/x-www-form-urlencoded, boundary validation, file uploads, +and browser compatibility quirks. +""" + +import tempfile +import warnings +from io import BytesIO, StringIO, TextIOWrapper + +import pytest + +from poorwsgi.fieldstorage import ( + FieldStorage, + FieldStorageInterface, + FieldStorageParser, + valid_boundary, +) +from poorwsgi.headers import Headers + +# pylint: disable=missing-function-docstring +# pylint: disable=too-many-public-methods +# pylint: disable=too-many-lines +# pylint: disable=no-self-use +# pylint: disable=R6301 + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _urlencoded(body: str, charset="utf-8") -> FieldStorage: + """Parses a URL-encoded body and returns the root FieldStorage.""" + body_bytes = body.encode(charset) + headers = Headers([ + ("Content-Type", "application/x-www-form-urlencoded"), + ("Content-Length", str(len(body_bytes))), + ]) + parser = FieldStorageParser(BytesIO(body_bytes), headers) + return parser.parse() + + +def _multipart(body: bytes, boundary: str) -> FieldStorage: + """Parses a multipart/form-data body and returns the root FieldStorage.""" + headers = Headers([ + ("Content-Type", + f"multipart/form-data; boundary={boundary}"), + ("Content-Length", str(len(body))), + ]) + parser = FieldStorageParser(BytesIO(body), headers) + return parser.parse() + + +def _make_multipart(fields: list, boundary: str = "TestBoundary") -> bytes: + """Builds a minimal multipart/form-data body. + + Each entry in fields is either: + (name, value) — simple text field + (name, filename, content, ctype) — file upload + """ + parts = [] + for item in fields: + if len(item) == 2: + name, value = item + parts.append( + f"--{boundary}\r\n" + f"Content-Disposition: form-data; name=\"{name}\"\r\n" + f"\r\n" + f"{value}\r\n" + ) + else: + name, filename, content, ctype = item + part_header = ( + f"--{boundary}\r\n" + f"Content-Disposition: form-data; " + f"name=\"{name}\"; filename=\"{filename}\"\r\n" + f"Content-Type: {ctype}\r\n" + f"\r\n" + ) + if isinstance(content, str): + content = content.encode("utf-8") + parts.append(part_header.encode("utf-8") + content + b"\r\n") + body = b"".join( + p.encode("utf-8") if isinstance(p, str) else p for p in parts + ) + body += f"--{boundary}--\r\n".encode("utf-8") + return body + + +# --------------------------------------------------------------------------- +# valid_boundary +# --------------------------------------------------------------------------- + +class TestValidBoundary: + """RFC 2046 §5.1.1 boundary syntax validation.""" + + def test_str_valid(self): + """Typical browser-generated boundary string is valid.""" + assert valid_boundary("----WebKitFormBoundaryMPRpF8CUUmlmqKqy") + + def test_bytes_valid(self): + """Bytes boundary is also accepted.""" + assert valid_boundary(b"----WebKitFormBoundaryMPRpF8CUUmlmqKqy") + + def test_simple_str(self): + """Simple ASCII string is a valid boundary.""" + assert valid_boundary("boundary123") + + def test_empty_is_invalid(self): + """Empty boundary is invalid.""" + assert not valid_boundary("") + + def test_too_long_invalid(self): + """Boundary exceeding 201 characters is invalid.""" + assert not valid_boundary("x" * 202) + + def test_space_only_invalid(self): + """Space-only boundary is invalid (must end with !-~).""" + assert not valid_boundary(" ") + + def test_non_ascii_invalid(self): + """Non-ASCII boundary characters are invalid.""" + assert not valid_boundary("boundary\x00") + + def test_max_length_valid(self): + """Boundary of exactly 201 visible chars is valid.""" + assert valid_boundary("x" * 201) + + +# --------------------------------------------------------------------------- +# FieldStorage — standalone field object +# --------------------------------------------------------------------------- + +class TestFieldStorageBasic: + """Unit tests for FieldStorage data container.""" + + def test_repr_with_value(self): + """repr includes name and value.""" + field = FieldStorage("key", "val") + assert "key" in repr(field) + + def test_repr_with_file(self): + """repr includes name and file when no value set.""" + field = FieldStorage("key") + field.file = StringIO("data") + r = repr(field) + assert "key" in r + + def test_bool_true_value(self): + """Field with a value is truthy.""" + assert bool(FieldStorage("k", "v")) + + def test_bool_true_list(self): + """Root FieldStorage with children is truthy.""" + root = FieldStorage() + root.list = [FieldStorage("k")] + assert bool(root) + + def test_bool_false_empty(self): + """Field with no value and no list is falsy.""" + assert not bool(FieldStorage("k")) + + def test_iter_over_keys(self): + """Iteration yields unique key names.""" + root = FieldStorage() + root.list = [FieldStorage("a", "1"), FieldStorage("b", "2")] + assert set(root) == {"a", "b"} + + def test_len_unique_keys(self): + """len() counts unique key names.""" + root = FieldStorage() + root.list = [ + FieldStorage("k", "1"), + FieldStorage("k", "2"), + FieldStorage("x", "3"), + ] + assert len(root) == 2 + + def test_contains_true(self): + """'in' returns True for a present key.""" + root = FieldStorage() + root.list = [FieldStorage("key", "v")] + assert "key" in root + + def test_contains_false(self): + """'in' returns False for an absent key.""" + root = FieldStorage() + root.list = [FieldStorage("other", "v")] + assert "missing" not in root + + def test_contains_empty(self): + """'in' returns False when list is empty.""" + root = FieldStorage() + assert "key" not in root + + def test_getitem_single(self): + """__getitem__ returns the FieldStorage for a unique key.""" + root = FieldStorage() + root.list = [FieldStorage("k", "v")] + assert root["k"].value == "v" + + def test_getitem_multiple(self): + """__getitem__ returns a list when a key has multiple values.""" + root = FieldStorage() + root.list = [FieldStorage("k", "1"), FieldStorage("k", "2")] + result = root["k"] + assert isinstance(result, list) + assert len(result) == 2 + + def test_getitem_missing_raises(self): + """__getitem__ raises KeyError for missing key.""" + root = FieldStorage() + root.list = [] + with pytest.raises(KeyError): + _ = root["missing"] + + def test_getitem_no_list_raises(self): + """__getitem__ raises KeyError when list is empty.""" + root = FieldStorage() + with pytest.raises(KeyError): + _ = root["k"] + + def test_value_from_string(self): + """value property returns the _value string.""" + field = FieldStorage("k", "hello") + assert field.value == "hello" + + def test_value_from_stringio(self): + """value property reads from StringIO file.""" + field = FieldStorage("k") + field.file = StringIO("text data") + assert field.value == "text data" + + def test_value_from_bytesio(self): + """value property reads from BytesIO file.""" + field = FieldStorage("k") + field.file = BytesIO(b"binary") + assert field.value == b"binary" + + def test_value_from_file(self): + """value property reads from a seekable file-like object.""" + field = FieldStorage("k") + with tempfile.TemporaryFile("w+") as tmp: + tmp.write("file content") + field.file = tmp + assert field.value == "file content" + + def test_value_from_list(self): + """value property returns list when field has child list.""" + root = FieldStorage() + child = FieldStorage("k", "v") + root.list = [child] + assert root.value == [child] + + def test_value_none_when_empty(self): + """value property returns None for an empty field.""" + field = FieldStorage("k") + assert field.value is None + + def test_keys(self): + """keys() returns unique key names.""" + root = FieldStorage() + root.list = [FieldStorage("a", "1"), FieldStorage("b", "2")] + assert set(root.keys()) == {"a", "b"} + + def test_get_single(self): + """get() returns the field value for a single-value key.""" + root = FieldStorage() + root.list = [FieldStorage("k", "v")] + assert root.get("k") == "v" + + def test_get_multiple(self): + """get() returns a list of values for a multi-value key.""" + root = FieldStorage() + root.list = [FieldStorage("k", "1"), FieldStorage("k", "2")] + result = root.get("k") + assert isinstance(result, list) + assert len(result) == 2 + + def test_get_default(self): + """get() returns default for missing key.""" + root = FieldStorage() + root.list = [] + assert root.get("missing", "def") == "def" + + def test_context_manager(self): + """FieldStorage can be used as a context manager.""" + field = FieldStorage("k") + field.file = StringIO("value") + with field as f: + assert f.value == "value" + + def test_del_closes_file(self): + """Deleting a FieldStorage closes its file.""" + field = FieldStorage("k") + sio = StringIO("data") + field.file = sio + del field + assert sio.closed + + +class TestFieldStorageInterface: + """Tests for getvalue / getfirst / getlist on FieldStorageInterface + via FieldStorage (which inherits from it).""" + + def _root(self, pairs): + root = FieldStorage() + root.list = [FieldStorage(k, v) for k, v in pairs] + return root + + def test_getvalue_single(self): + """getvalue returns the value for a single-value key.""" + root = self._root([("k", "42")]) + assert root.getvalue("k") == "42" + + def test_getvalue_with_func(self): + """getvalue applies func to each value.""" + root = self._root([("n", "7")]) + assert root.getvalue("n", func=int) == 7 + + def test_getvalue_multiple(self): + """getvalue returns a list for a multi-value key.""" + root = self._root([("k", "1"), ("k", "2")]) + result = root.getvalue("k") + assert isinstance(result, list) + assert len(result) == 2 + + def test_getvalue_missing(self): + """getvalue returns default for missing key.""" + root = self._root([]) + assert root.getvalue("x", default="d") == "d" + + def test_getfirst_single(self): + """getfirst returns the only value for a single-value key.""" + root = self._root([("k", "v")]) + assert root.getfirst("k") == "v" + + def test_getfirst_multiple(self): + """getfirst returns the first value for a multi-value key.""" + root = self._root([("k", "first"), ("k", "second")]) + assert root.getfirst("k") == "first" + + def test_getfirst_with_func(self): + """getfirst applies func to the first value.""" + root = self._root([("n", "5"), ("n", "10")]) + assert root.getfirst("n", func=int) == 5 + + def test_getfirst_missing(self): + """getfirst returns default when key is absent.""" + root = self._root([]) + assert root.getfirst("x", default=0) == 0 + + def test_getfirst_deprecated_fce(self): + """getfirst emits DeprecationWarning for the fce argument.""" + root = self._root([("k", "3")]) + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + result = root.getfirst("k", fce=int) + assert result == 3 + assert any(issubclass(x.category, DeprecationWarning) for x in w) + + def test_getlist_single(self): + """getlist wraps a single value in a list.""" + root = self._root([("k", "v")]) + assert root.getlist("k") == ["v"] + + def test_getlist_multiple(self): + """getlist returns all values for a multi-value key.""" + root = self._root([("k", "1"), ("k", "2")]) + assert root.getlist("k") == ["1", "2"] + + def test_getlist_with_func(self): + """getlist applies func to each value.""" + root = self._root([("k", "1"), ("k", "2")]) + assert root.getlist("k", func=int) == [1, 2] + + def test_getlist_missing_empty(self): + """getlist returns [] for a missing key.""" + root = self._root([]) + assert root.getlist("x") == [] + + def test_getlist_missing_with_default(self): + """getlist returns default list for a missing key.""" + root = self._root([]) + assert root.getlist("x", default=["fallback"]) == ["fallback"] + + def test_getlist_deprecated_fce(self): + """getlist emits DeprecationWarning for the fce argument.""" + root = self._root([("k", "1"), ("k", "2")]) + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + result = root.getlist("k", fce=int) + assert result == [1, 2] + assert any(issubclass(x.category, DeprecationWarning) for x in w) + + +# --------------------------------------------------------------------------- +# FieldStorageParser — URL-encoded forms +# --------------------------------------------------------------------------- + +class TestURLEncoded: + """application/x-www-form-urlencoded parsing (HTML form default).""" + + def test_simple_field(self): + """Single name=value pair is parsed correctly.""" + form = _urlencoded("name=Ondrej") + assert form.getvalue("name") == "Ondrej" + + def test_multiple_fields(self): + """Multiple fields are all present.""" + form = _urlencoded("a=1&b=2&c=3") + assert form.getvalue("a") == "1" + assert form.getvalue("b") == "2" + assert form.getvalue("c") == "3" + + def test_repeated_key(self): + """Repeated key results in a list of values.""" + form = _urlencoded("tag=a&tag=b&tag=c") + result = form.getlist("tag") + assert result == ["a", "b", "c"] + + def test_percent_encoding(self): + """Percent-encoded values are decoded.""" + form = _urlencoded("name=Ond%C5%99ej") + assert form.getvalue("name") == "Ondřej" + + def test_plus_as_space(self): + """Plus sign in URL encoding represents a space.""" + form = _urlencoded("msg=hello+world") + assert form.getvalue("msg") == "hello world" + + def test_blank_value_ignored_by_default(self): + """Blank values are ignored when keep_blank_values is False.""" + body = b"name=&other=val" + headers = Headers([ + ("Content-Type", "application/x-www-form-urlencoded"), + ("Content-Length", str(len(body))), + ]) + parser = FieldStorageParser(BytesIO(body), headers) + form = parser.parse() + assert form.getvalue("name") is None + + def test_blank_value_kept(self): + """Blank values are kept as fields when keep_blank_values=1. + + The key appears in the form even though the value reads as None + (the value property treats empty string as falsy). + """ + body = b"name=&other=val" + headers = Headers([ + ("Content-Type", "application/x-www-form-urlencoded"), + ("Content-Length", str(len(body))), + ]) + parser = FieldStorageParser( + BytesIO(body), headers, keep_blank_values=1 + ) + form = parser.parse() + assert "name" in form + + def test_empty_body(self): + """Empty body produces empty form.""" + form = _urlencoded("") + assert len(form) == 0 + + def test_missing_content_type_defaults_to_urlencoded(self): + """No Content-Type header defaults to URL-encoded parsing.""" + body = b"x=1" + headers = Headers([("Content-Length", "3")]) + parser = FieldStorageParser(BytesIO(body), headers) + form = parser.parse() + assert form.getvalue("x") == "1" + + def test_separator_semicolon(self): + """Custom separator is respected.""" + body = b"a=1;b=2" + headers = Headers([ + ("Content-Type", "application/x-www-form-urlencoded"), + ("Content-Length", str(len(body))), + ]) + parser = FieldStorageParser( + BytesIO(body), headers, separator=";" + ) + form = parser.parse() + assert form.getvalue("a") == "1" + assert form.getvalue("b") == "2" + + def test_content_length_invalid_ignored(self): + """Non-integer Content-Length defaults to -1 (read all).""" + body = b"k=v" + headers = Headers([ + ("Content-Type", "application/x-www-form-urlencoded"), + ("Content-Length", "not-a-number"), + ]) + parser = FieldStorageParser(BytesIO(body), headers) + form = parser.parse() + assert form.getvalue("k") == "v" + + def test_max_num_fields_exceeded(self): + """max_num_fields raises ValueError when limit is exceeded.""" + body = b"a=1&b=2&c=3" + headers = Headers([ + ("Content-Type", "application/x-www-form-urlencoded"), + ("Content-Length", str(len(body))), + ]) + parser = FieldStorageParser( + BytesIO(body), headers, max_num_fields=2 + ) + with pytest.raises(ValueError, match="fields"): + parser.parse() + + +# --------------------------------------------------------------------------- +# FieldStorageParser — multipart/form-data +# --------------------------------------------------------------------------- + +class TestMultipartFormData: + """multipart/form-data (RFC 7578) parsing. + + Also tests browser-specific behaviour that deviates from RFC. + """ + + def test_simple_text_field(self): + """Single text field is parsed correctly.""" + body = _make_multipart([("name", "Ondrej")]) + form = _multipart(body, "TestBoundary") + assert form.getvalue("name") == "Ondrej" + + def test_multiple_text_fields(self): + """Multiple text fields are all parsed.""" + body = _make_multipart([("a", "1"), ("b", "2")]) + form = _multipart(body, "TestBoundary") + assert form.getvalue("a") == "1" + assert form.getvalue("b") == "2" + + def test_repeated_field(self): + """Repeated form field name produces multiple values.""" + body = _make_multipart([("tag", "python"), ("tag", "wsgi")]) + form = _multipart(body, "TestBoundary") + tags = [f.value for f in form["tag"]] + assert "python" in tags + assert "wsgi" in tags + + def test_file_upload_bytes(self): + """File upload is parsed and content is accessible.""" + body = _make_multipart( + [("upload", "hello.txt", b"Hello, file!", "text/plain")] + ) + form = _multipart(body, "TestBoundary") + field = form["upload"] + assert field.filename == "hello.txt" + assert field.value == b"Hello, file!" + + def test_file_upload_content_type(self): + """File upload preserves Content-Type.""" + body = _make_multipart( + [("img", "photo.jpg", b"\xff\xd8\xff", "image/jpeg")] + ) + form = _multipart(body, "TestBoundary") + field = form["img"] + assert field.type == "image/jpeg" + assert field.filename == "photo.jpg" + + def test_text_and_file_mixed(self): + """Text fields and file uploads coexist in the same form.""" + body = _make_multipart([ + ("username", "alice"), + ("avatar", "me.png", b"\x89PNG", "image/png"), + ]) + form = _multipart(body, "TestBoundary") + assert form.getvalue("username") == "alice" + assert form["avatar"].filename == "me.png" + + def test_boundary_with_dashes(self): + """Boundary string with leading dashes (Chrome/Firefox style).""" + boundary = "----WebKitFormBoundaryABCDEF123456" + body = _make_multipart([("x", "y")], boundary) + form = _multipart(body, boundary) + assert form.getvalue("x") == "y" + + def test_unicode_field_value(self): + """Unicode text field values are decoded correctly.""" + body = _make_multipart([("city", "Brno")]) + form = _multipart(body, "TestBoundary") + assert form.getvalue("city") == "Brno" + + def test_empty_file_upload(self): + """Empty file upload has a filename but empty content.""" + body = _make_multipart( + [("doc", "empty.txt", b"", "text/plain")] + ) + form = _multipart(body, "TestBoundary") + field = form["doc"] + assert field.filename == "empty.txt" + assert field.value == b"" + + def test_binary_file_upload(self): + """Binary file upload content is not decoded.""" + binary_data = bytes(range(256)) + body = _make_multipart( + [("bin", "data.bin", binary_data, "application/octet-stream")] + ) + form = _multipart(body, "TestBoundary") + assert form["bin"].value == binary_data + + def test_no_content_type_inner(self): + """Missing Content-Type on a part defaults to text/plain.""" + boundary = "Bound" + body = ( + b"--Bound\r\n" + b"Content-Disposition: form-data; name=\"field\"\r\n" + b"\r\n" + b"value\r\n" + b"--Bound--\r\n" + ) + form = _multipart(body, boundary) + assert form.getvalue("field") == "value" + + def test_max_num_fields_multipart(self): + """max_num_fields raises ValueError when multipart exceeds limit.""" + body = _make_multipart([("a", "1"), ("b", "2"), ("c", "3")]) + headers = Headers([ + ("Content-Type", "multipart/form-data; boundary=TestBoundary"), + ("Content-Length", str(len(body))), + ]) + parser = FieldStorageParser( + BytesIO(body), headers, max_num_fields=1 + ) + with pytest.raises(ValueError, match="Max number of fields"): + parser.parse() + + def test_content_length_header_on_part_ignored(self): + """Content-Length in individual part headers is stripped (browser + behaviour — some agents include it, RFC says it's optional).""" + boundary = "Bound" + body = ( + b"--Bound\r\n" + b"Content-Disposition: form-data; name=\"f\"\r\n" + b"Content-Length: 5\r\n" + b"\r\n" + b"hello\r\n" + b"--Bound--\r\n" + ) + form = _multipart(body, boundary) + assert form.getvalue("f") == "hello" + + def test_crlf_line_endings(self): + """CRLF line endings (required by RFC 7578 §4.1) are handled.""" + boundary = "Bound" + body = ( + b"--Bound\r\n" + b"Content-Disposition: form-data; name=\"k\"\r\n" + b"\r\n" + b"val\r\n" + b"--Bound--\r\n" + ) + form = _multipart(body, boundary) + assert form.getvalue("k") == "val" + + def test_field_name_with_unicode_filename(self): + """Filename with non-ASCII characters is preserved as-is.""" + boundary = "Bound" + body = ( + "--Bound\r\n" + "Content-Disposition: form-data; name=\"f\"; " + "filename=\"tëst.txt\"\r\n" + "Content-Type: text/plain\r\n" + "\r\n" + "data\r\n" + "--Bound--\r\n" + ).encode("utf-8") + form = _multipart(body, boundary) + assert "ë" in form["f"].filename + + +# --------------------------------------------------------------------------- +# FieldStorageParser — other content types (read_single / read_binary) +# --------------------------------------------------------------------------- + +class TestReadSingle: + """Parsing non-form content types — text/plain, application/octet-stream. + + When Content-Type is not URL-encoded or multipart, the body is + treated as a single field via read_single. + + Note: read_binary is designed for multipart file parts (filename set). + For content without a filename, read_lines is used regardless of + Content-Length, because make_file() returns a text-mode tempfile + when filename is None. + """ + + def test_text_plain_body(self): + """text/plain body without Content-Length is parsed via read_lines.""" + body = b"Hello, world" + headers = Headers([("Content-Type", "text/plain")]) + parser = FieldStorageParser(BytesIO(body), headers) + field = parser.parse() + assert field.file is not None + field.file.seek(0) + assert "Hello" in field.file.read() + + def test_read_lines_to_eof(self): + """Without outer boundary, read_lines reads the entire input.""" + body = b"line1\nline2\n" + headers = Headers([("Content-Type", "text/plain")]) + parser = FieldStorageParser(BytesIO(body), headers) + field = parser.parse() + field.file.seek(0) + content = field.file.read() + assert "line1" in content + + def test_read_binary_empty_data_sets_done(self): + """read_binary sets done=-1 when input is exhausted early.""" + parser = FieldStorageParser() + parser.filename = "file.bin" # filename → binary tempfile + parser.length = 100 + parser.input = BytesIO(b"") + file_ = parser.read_binary() + assert parser.done == -1 + file_.close() + + def test_file_upload_large_spills_to_tempfile(self): + """File upload larger than BUFSIZE spills to a temporary file.""" + binary_data = b"x" * (FieldStorageParser.BUFSIZE + 100) + body = _make_multipart( + [("f", "large.bin", binary_data, "application/octet-stream")] + ) + form = _multipart(body, "TestBoundary") + assert form["f"].value == binary_data + + +# --------------------------------------------------------------------------- +# FieldStorageParser — file_callback +# --------------------------------------------------------------------------- + +class TestFileCallback: # pylint: disable=too-few-public-methods + """file_callback lets the caller supply a custom writable stream.""" + + def test_file_callback_used_for_upload(self): + """file_callback is called with the filename for uploads.""" + called_with = [] + custom_buf = BytesIO() + + def callback(filename): + called_with.append(filename) + return custom_buf + + body = ( + b"--Bound\r\n" + b"Content-Disposition: form-data; name=\"f\"; " + b"filename=\"report.pdf\"\r\n" + b"Content-Type: application/pdf\r\n" + b"\r\n" + b"PDF content\r\n" + b"--Bound--\r\n" + ) + headers = Headers([ + ("Content-Type", "multipart/form-data; boundary=Bound"), + ("Content-Length", str(len(body))), + ]) + parser = FieldStorageParser( + BytesIO(body), headers, file_callback=callback + ) + parser.parse() + assert "report.pdf" in called_with + + +# --------------------------------------------------------------------------- +# FieldStorageParser — read_urlencoded error path +# --------------------------------------------------------------------------- + +class TestReadUrlencodedErrors: + """Error handling in URL-encoded parsing.""" + + def test_non_bytes_input_raises(self): + """read_urlencoded raises ValueError when input returns non-bytes.""" + + class _TextStream: # pylint: disable=too-few-public-methods + def read(self, _n=-1): + return "not bytes" + + headers = Headers([ + ("Content-Type", "application/x-www-form-urlencoded"), + ("Content-Length", "5"), + ]) + parser = FieldStorageParser(_TextStream(), headers) + parser.length = 5 + with pytest.raises(ValueError, match="should return bytes"): + parser.read_urlencoded() + + def test_invalid_boundary_raises(self): + """_skip_to_boundary raises ValueError for an invalid boundary.""" + headers = Headers([ + ("Content-Type", "multipart/form-data; boundary=Bound"), + ("Content-Length", "0"), + ]) + parser = FieldStorageParser(BytesIO(b""), headers) + parser.innerboundary = b"" # empty → invalid + with pytest.raises(ValueError, match="Invalid boundary"): + parser._skip_to_boundary() # pylint: disable=protected-access + + +# --------------------------------------------------------------------------- +# FieldStorageParser — make_file +# --------------------------------------------------------------------------- + +class TestMakeFile: + """make_file returns the correct stream type.""" + + def test_no_filename_returns_text_tempfile(self): + """Without a filename, make_file returns a text-mode temp file.""" + parser = FieldStorageParser() + parser.filename = None + f = parser.make_file() + f.write("hello") + f.seek(0) + assert f.read() == "hello" + f.close() + + def test_filename_returns_binary_tempfile(self): + """With a filename, make_file returns a binary-mode temp file.""" + parser = FieldStorageParser() + parser.filename = "test.bin" + f = parser.make_file() + f.write(b"\x00\xff") + f.seek(0) + assert f.read() == b"\x00\xff" + f.close() + + def test_file_callback_overrides_tempfile(self): + """file_callback completely replaces the default temp file.""" + custom = BytesIO() + parser = FieldStorageParser(file_callback=lambda _fn: custom) + parser.filename = "upload.bin" + result = parser.make_file() + assert result is custom + + +# --------------------------------------------------------------------------- +# FieldStorageParser — _parse_content_type +# --------------------------------------------------------------------------- + +class TestParseContentType: + """Content-Type header parsing edge cases.""" + + # pylint: disable=protected-access + + def test_no_header_outer(self): + """No Content-Type at outer level → url-encoded default.""" + parser = FieldStorageParser() + ctype, _ = parser._parse_content_type() + assert ctype == "application/x-www-form-urlencoded" + + def test_no_header_inner(self): + """No Content-Type at inner level (outerboundary set) → text/plain.""" + parser = FieldStorageParser(outerboundary=b"Bound") + ctype, _ = parser._parse_content_type() + assert ctype == "text/plain" + + def test_explicit_header_wins(self): + """Explicit Content-Type overrides defaults.""" + headers = Headers([("Content-Type", "application/json")]) + parser = FieldStorageParser(headers=headers) + ctype, _ = parser._parse_content_type() + assert ctype == "application/json" + + +# --------------------------------------------------------------------------- +# FieldStorageInterface — base class methods +# --------------------------------------------------------------------------- + +class TestFieldStorageInterfaceDirect: + """Tests for the FieldStorageInterface base class default implementations. + + These are only reachable via a subclass that does NOT override + getvalue/getfirst/getlist — i.e., not FieldStorage which overrides all. + """ + + def _minimal(self, items): + """Minimal concrete FieldStorageInterface with a dict store.""" + + class _Impl(FieldStorageInterface): + def __init__(self, data): + self._data = data + + def __contains__(self, key): + return key in self._data + + def __getitem__(self, key): + return self._data[key] + + return _Impl(dict(items)) + + def test_getvalue_found(self): + """getvalue returns the item when key is present.""" + impl = self._minimal([("k", "v")]) + assert impl.getvalue("k") == "v" + + def test_getvalue_with_func(self): + """getvalue applies func to the found value.""" + impl = self._minimal([("n", "42")]) + assert impl.getvalue("n", func=int) == 42 + + def test_getvalue_missing(self): + """getvalue returns default when key is absent.""" + impl = self._minimal([]) + assert impl.getvalue("x", default="d") == "d" + + def test_getfirst_single(self): + """getfirst returns the single value.""" + impl = self._minimal([("k", "v")]) + assert impl.getfirst("k") == "v" + + def test_getfirst_list(self): + """getfirst returns the first element of a list value.""" + impl = self._minimal([("k", ["a", "b"])]) + assert impl.getfirst("k") == "a" + + def test_getfirst_missing(self): + """getfirst returns default when key is absent.""" + impl = self._minimal([]) + assert impl.getfirst("x", default=0) == 0 + + def test_getfirst_deprecated_fce(self): + """getfirst base warns about deprecated fce.""" + impl = self._minimal([("k", "3")]) + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + result = impl.getfirst("k", fce=int) + assert result == 3 + assert any(issubclass(x.category, DeprecationWarning) for x in w) + + def test_getlist_single(self): + """getlist wraps a non-list value in a list.""" + impl = self._minimal([("k", "v")]) + assert impl.getlist("k") == ["v"] + + def test_getlist_multiple(self): + """getlist returns the list as-is when value is already a list.""" + impl = self._minimal([("k", ["a", "b"])]) + assert impl.getlist("k") == ["a", "b"] + + def test_getlist_missing(self): + """getlist returns [] for absent key.""" + impl = self._minimal([]) + assert impl.getlist("x") == [] + + def test_getlist_missing_with_default(self): + """getlist returns default list for absent key.""" + impl = self._minimal([]) + assert impl.getlist("x", default=["d"]) == ["d"] + + def test_getlist_deprecated_fce(self): + """getlist base warns about deprecated fce.""" + impl = self._minimal([("k", ["1", "2"])]) + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + result = impl.getlist("k", fce=int) + assert result == [1, 2] + assert any(issubclass(x.category, DeprecationWarning) for x in w) + + +# --------------------------------------------------------------------------- +# FieldStorage — remaining edge cases +# --------------------------------------------------------------------------- + +class TestFieldStorageEdgeCases: + """Edge cases not covered by the main test classes.""" + + def test_getitem_key_not_in_non_empty_list(self): + """__getitem__ raises KeyError when key absent in non-empty list.""" + root = FieldStorage() + root.list = [FieldStorage("other", "v")] + with pytest.raises(KeyError): + _ = root["missing"] + + def test_valid_boundary_bytes_no_match(self): + """Bytes boundary with control characters is invalid.""" + assert not valid_boundary(b"\x00invalid") + + def test_parser_textiowrapper_input(self): + """FieldStorageParser accepts a TextIOWrapper and uses its buffer.""" + raw = BytesIO(b"key=value") + wrapper = TextIOWrapper(raw, encoding="utf-8") + headers = Headers([ + ("Content-Type", "application/x-www-form-urlencoded"), + ("Content-Length", "9"), + ]) + parser = FieldStorageParser(wrapper, headers) + assert parser.input is raw # buffer, not wrapper + + +# --------------------------------------------------------------------------- +# read_single with Content-Disposition filename +# --------------------------------------------------------------------------- + +class TestReadSingleWithFilename: + """read_single with a binary Content-Disposition filename. + + When a filename is present, make_file returns a binary temp file, + so read_binary can write bytes successfully. + """ + + def test_binary_body_with_filename(self): + """Binary body with filename is read correctly via read_binary.""" + body = b"\x00\x01\x02\x03" + headers = Headers([ + ("Content-Type", "application/octet-stream"), + ("Content-Disposition", 'attachment; filename="blob.bin"'), + ("Content-Length", str(len(body))), + ]) + parser = FieldStorageParser(BytesIO(body), headers) + field = parser.parse() + field.file.seek(0) + assert field.file.read() == body + + def test_read_binary_non_bytes_raises(self): + """read_binary raises ValueError when stream returns non-bytes.""" + + class _BadStream: # pylint: disable=too-few-public-methods + def read(self, _n=-1): + return "not bytes" + + parser = FieldStorageParser() + parser.filename = "f.bin" + parser.length = 5 + parser.input = _BadStream() + with pytest.raises(ValueError, match="should return bytes"): + parser.read_binary() + + def test_skip_to_boundary_non_bytes_readline(self): + """_skip_to_boundary raises ValueError when readline returns + non-bytes.""" + + class _BadStream: # pylint: disable=too-few-public-methods + def readline(self): + return "string not bytes" + + parser = FieldStorageParser() + parser.innerboundary = b"Bound" + parser.input = _BadStream() + with pytest.raises(ValueError, match="should return bytes"): + parser._skip_to_boundary() # pylint: disable=protected-access + + +# --------------------------------------------------------------------------- +# read_lines_to_outerboundary — edge cases +# --------------------------------------------------------------------------- + +class TestReadLinesToOuterBoundary: + r"""Tests for \r-only line ending handling and limit in + read_lines_to_outerboundary.""" + + def test_limit_stops_reading(self): + """Reading stops when the limit byte count is reached.""" + body = _make_multipart([("f", "file.txt", b"A" * 20, "text/plain")]) + headers = Headers([ + ("Content-Type", "multipart/form-data; boundary=Bound"), + ("Content-Length", str(len(body))), + ]) + parser = FieldStorageParser( + BytesIO(body), headers, limit=len(body) + ) + form = parser.parse() + assert form is not None # just confirm parsing doesn't crash + + def test_cr_only_line_ending(self): + r"""Bare \r at end of a chunk is handled as a split \r\n.""" + boundary = "Bound" + # Build a multipart body where the field content ends with \r\n + # spanning a chunk boundary (16-bit read). The library handles + # \r split across reads via the delim variable. + body = ( + b"--Bound\r\n" + b"Content-Disposition: form-data; name=\"f\"\r\n" + b"\r\n" + b"value\r\n" + b"--Bound--\r\n" + ) + form = _multipart(body, boundary) + assert form.getvalue("f") == "value" diff --git a/tests/test_headers.py b/tests/test_headers.py new file mode 100644 index 0000000..7666e94 --- /dev/null +++ b/tests/test_headers.py @@ -0,0 +1,242 @@ +"""Unit tests for poorwsgi/headers.py module-level functions and Headers.""" +from datetime import datetime, timezone + +from pytest import raises + +from poorwsgi.headers import ( + ContentRange, + Headers, + datetime_to_http, + http_to_datetime, + http_to_time, + parse_header, + parse_negotiation, + parse_range, + render_negotiation, + time_to_http, +) + +# pylint: disable=missing-function-docstring +# pylint: disable=no-self-use +# pylint: disable=too-few-public-methods + +EPOCH = datetime(1970, 1, 1, 0, 0, 0, tzinfo=timezone.utc) +EPOCH_HTTP = "Thu, 01 Jan 1970 00:00:00 GMT" + + +class TestParseHeader: + """Tests for parse_header().""" + + def test_simple(self): + key, params = parse_header("text/html") + assert key == "text/html" + assert not params + + def test_with_param(self): + key, params = parse_header("text/html; charset=utf-8") + assert key == "text/html" + assert params == {"charset": "utf-8"} + + def test_quoted_value(self): + key, params = parse_header('attachment; filename="hello world.txt"') + assert key == "attachment" + assert params["filename"] == "hello world.txt" + + def test_quoted_semicolon_in_value(self): + """Semicolon inside a quoted string must not split the parameter.""" + key, params = parse_header('form-data; name="a;b"') + assert key == "form-data" + assert params["name"] == "a;b" + + +class TestParseNegotiation: + """Tests for parse_negotiation().""" + + def test_single_no_quality(self): + assert parse_negotiation("gzip") == [("gzip", 1.0)] + + def test_multiple_with_quality(self): + result = parse_negotiation("gzip;q=1.0, identity;q=0.5, *;q=0") + assert result == [("gzip", 1.0), ("identity", 0.5), ("*", 0.0)] + + def test_param_before_quality(self): + result = parse_negotiation( + "text/html;level=1, text/html;level=2;q=0.5") + assert result == [("text/html;level=1", 1.0), + ("text/html;level=2", 0.5)] + + def test_bad_quality_falls_back_to_1(self): + """Non-numeric quality value must fall back to 1.0.""" + result = parse_negotiation("br;q=bad") + assert result == [("br", 1.0)] + + +class TestRenderNegotiation: + """Tests for render_negotiation().""" + + def test_with_quality(self): + assert render_negotiation([("gzip", 1.0), ("*", 0)]) == \ + "gzip;q=1.0, *;q=0" + + def test_without_quality(self): + assert render_negotiation([("gzip",)]) == "gzip" + + +class TestParseRange: + """Tests for parse_range().""" + + def test_simple(self): + assert parse_range("bytes=0-499") == {"bytes": [(0, 499)]} + + def test_suffix_range(self): + assert parse_range("bytes=-500") == {"bytes": [(None, 500)]} + + def test_open_ended(self): + assert parse_range("bytes=9500-") == {"bytes": [(9500, None)]} + + def test_multi_range(self): + assert parse_range("chunks=500-600,601-999") == \ + {"chunks": [(500, 600), (601, 999)]} + + def test_invalid_no_equals(self): + assert not parse_range("invalid") + + def test_invalid_values(self): + result = parse_range("invalid=a-b") + assert result == {"invalid": []} + + def test_empty_pair_skipped(self): + """A '-' with no numbers on either side must be silently skipped.""" + result = parse_range("bytes=0-1,-") + assert result == {"bytes": [(0, 1)]} + + +class TestDatetimeFunctions: + """Tests for datetime_to_http, time_to_http, http_to_datetime, + http_to_time.""" + + def test_datetime_to_http(self): + assert datetime_to_http(EPOCH) == EPOCH_HTTP + + def test_time_to_http_with_value(self): + assert time_to_http(0) == EPOCH_HTTP + + def test_time_to_http_without_value(self): + result = time_to_http() + assert result.endswith(" GMT") + + def test_http_to_datetime(self): + assert http_to_datetime(EPOCH_HTTP) == EPOCH + + def test_http_to_time(self): + assert http_to_time(EPOCH_HTTP) == 0 + + +class TestContentRange: + """Tests for ContentRange.""" + + def test_without_full(self): + assert str(ContentRange(1, 2)) == "bytes 1-2/*" + + def test_with_full(self): + assert str(ContentRange(1, 2, 10)) == "bytes 1-2/10" + + def test_custom_units(self): + assert str(ContentRange(2, 5, units="lines")) == "lines 2-5/*" + + +class TestHeadersMethods: + """Tests for Headers methods not covered by test_header.py.""" + + def test_repr(self): + headers = Headers([("X-Test", "value")]) + assert "X-Test" in repr(headers) + + def test_names_and_keys(self): + headers = Headers([("X-A", "1"), ("X-B", "2")]) + assert headers.names() == ("X-A", "X-B") + assert headers.keys() == headers.names() + + def test_values(self): + headers = Headers([("X-A", "1"), ("X-B", "2")]) + assert headers.values() == ("1", "2") + + def test_get_all_multiple(self): + headers = Headers([("Set-Cookie", "a=1"), ("Set-Cookie", "b=2")]) + assert headers.get_all("Set-Cookie") == ("a=1", "b=2") + + def test_get_all_missing(self): + assert not Headers().get_all("X-Missing") + + def test_delitem(self): + headers = Headers([("X-Test", "v")]) + del headers["X-Test"] + assert "X-Test" not in headers + + def test_setitem_overwrites(self): + headers = Headers([("X-Test", "old")]) + headers["X-Test"] = "new" + assert headers["X-Test"] == "new" + assert len(headers.get_all("X-Test")) == 1 + + def test_setdefault_missing(self): + headers = Headers() + result = headers.setdefault("X-Test", "default") + assert result == "default" + assert headers["X-Test"] == "default" + + def test_setdefault_existing(self): + headers = Headers([("X-Test", "original")]) + result = headers.setdefault("X-Test", "other") + assert result == "original" + + def test_add_duplicate_raises(self): + headers = Headers([("X-Test", "v")]) + with raises(KeyError): + headers.add("X-Test", "v2") + + def test_add_set_cookie_allows_duplicate(self): + headers = Headers() + headers.add("Set-Cookie", "a=1") + headers.add("Set-Cookie", "b=2") + assert len(headers.get_all("Set-Cookie")) == 2 + + def test_add_header_none_kwarg(self): + """Kwarg with value=None adds a bare flag (no '=value' suffix).""" + headers = Headers() + headers.add_header("Cache-Control", "no-cache", no_store=None) + assert headers["Cache-Control"] == "no-cache; no-store" + + def test_strict_false_list(self): + headers = Headers([("X-Raw", "value")], strict=False) + assert headers["X-Raw"] == "value" + + def test_strict_false_dict(self): + headers = Headers({"X-Raw": "value"}, strict=False) + assert headers["X-Raw"] == "value" + + def test_iso88591_unicode_error(self): + """Lone surrogate codepoints cannot be UTF-8 encoded and must raise + ValueError.""" + with raises(ValueError, match="iso-8859-1"): + Headers.iso88591("\ud800") # lone surrogate + + def test_utf8_roundtrip(self): + original = "café" + iso = original.encode("utf-8").decode("iso-8859-1") + assert Headers.utf8(iso) == original + + def test_utf8_already_utf8(self): + """Bytes that cannot be decoded as UTF-8 (raw ISO-8859-1 high bytes) + are returned as-is.""" + raw = "\x80" # U+0080 → b'\x80' as ISO-8859-1, invalid as UTF-8 + assert Headers.utf8(raw) == raw + + def test_len(self): + headers = Headers([("X-A", "1"), ("X-B", "2")]) + assert len(headers) == 2 + + def test_iter(self): + pairs = [("X-A", "1"), ("X-B", "2")] + headers = Headers(pairs) + assert list(headers) == pairs diff --git a/tests/test_request.py b/tests/test_request.py index 6d885d0..caf2de1 100644 --- a/tests/test_request.py +++ b/tests/test_request.py @@ -1,4 +1,6 @@ """Tests for request module functionality.""" +import base64 +import warnings from io import BytesIO from time import time from typing import Any, ClassVar @@ -6,16 +8,21 @@ from pytest import fixture, raises from poorwsgi import Application -from poorwsgi.fieldstorage import FieldStorageParser +from poorwsgi.fieldstorage import FieldStorage, FieldStorageParser from poorwsgi.headers import Headers -from poorwsgi.request import (Args, EmptyForm, JsonDict, JsonList, Request, +from poorwsgi.request import (Args, CachedInput, EmptyForm, + FieldStorage as DeprecatedFieldStorage, + JsonDict, JsonList, Request, SimpleRequest, parse_json_request) from poorwsgi.response import HTTPException +from poorwsgi.state import methods # pylint: disable=missing-function-docstring # pylint: disable=no-self-use # pylint: disable=redefined-outer-name # pylint: disable=too-few-public-methods +# pylint: disable=too-many-lines +# pylint: disable=too-many-public-methods @fixture(scope='session') @@ -411,3 +418,894 @@ def start_response(status, headers): assert '400' in body_str assert 'Bad Request' in body_str assert 'Invalid PATH_INFO encoding' in body_str + + +# --------------------------------------------------------------------------- +# Helpers shared by the new test classes +# --------------------------------------------------------------------------- + +def _make_env(**kwargs): + """Build a minimal valid WSGI environ for SimpleRequest / Request.""" + env = { + 'PATH_INFO': '/path', + 'REQUEST_METHOD': 'GET', + 'SERVER_NAME': 'localhost', + 'SERVER_PORT': '80', + 'SERVER_PROTOCOL': 'HTTP/1.1', + 'wsgi.url_scheme': 'http', + 'REQUEST_STARTTIME': time(), + 'wsgi.input': BytesIO(), + 'wsgi.errors': BytesIO(), + } + env.update(kwargs) + return env + + +# --------------------------------------------------------------------------- +# SimpleRequest properties +# --------------------------------------------------------------------------- + +class TestSimpleRequest: + """Tests for SimpleRequest properties and edge cases.""" + + def test_uwsgi_poor_environ(self, app): + """uwsgi.version in environ causes poor_environ to use os.environ.""" + env = _make_env(**{'uwsgi.version': b'2.0.0'}) + req = SimpleRequest(env, app) + poor = req.poor_environ + assert poor is not env + + def test_poor_version_env_detection(self, app, monkeypatch): + """poor.Version in os.environ causes poor_environ to use os.environ.""" + monkeypatch.setenv('poor.Version', 'test') + env = _make_env() + req = SimpleRequest(env, app) + assert 'poor.Version' in req.poor_environ + + def test_debug_from_environ_on(self, app): + """poor_Debug=on in environ sets debug=True.""" + env = _make_env(poor_Debug='on') + req = SimpleRequest(env, app) + assert req.debug is True + + def test_debug_from_environ_off(self, app): + """poor_Debug=off in environ sets debug=False.""" + env = _make_env(poor_Debug='off') + req = SimpleRequest(env, app) + assert req.debug is False + + def test_debug_falls_back_to_app(self, app): + """Without poor_Debug, debug comes from app.debug.""" + env = _make_env() + req = SimpleRequest(env, app) + assert req.debug == app.debug + + def test_app_property(self, app): + """app property returns the Application object.""" + env = _make_env() + req = SimpleRequest(env, app) + assert req.app is app + + def test_environ_copy(self, app): + """environ property returns a copy of the environ dict.""" + env = _make_env() + req = SimpleRequest(env, app) + copy = req.environ + assert copy is not env + assert copy['PATH_INFO'] == '/path' + + def test_poor_environ_copy(self, app): + """poor_environ property returns a copy.""" + env = _make_env() + req = SimpleRequest(env, app) + copy = req.poor_environ + assert isinstance(copy, dict) + + def test_uri_rule_set_once(self, app): + """uri_rule setter ignores subsequent assignments.""" + env = _make_env() + req = SimpleRequest(env, app) + req.uri_rule = '/first' + req.uri_rule = '/second' + assert req.uri_rule == '/first' + + def test_uri_handler_set_once(self, app): + """uri_handler setter ignores subsequent assignments.""" + def handler1(): + pass + + def handler2(): + pass + env = _make_env() + req = SimpleRequest(env, app) + req.uri_handler = handler1 + req.uri_handler = handler2 + assert req.uri_handler is handler1 + + def test_error_handler_set_once(self, app): + """error_handler setter ignores subsequent assignments.""" + def h1(): + pass + + def h2(): + pass + env = _make_env() + req = SimpleRequest(env, app) + req.error_handler = h1 + req.error_handler = h2 + assert req.error_handler is h1 + + def test_host_port_https_default(self, app): + """host_port returns 443 for https when no port in HTTP_HOST.""" + env = _make_env(**{'wsgi.url_scheme': 'https', 'SERVER_PORT': '443'}) + req = SimpleRequest(env, app) + assert req.host_port == 443 + + def test_host_port_http_default(self, app): + """host_port returns 80 for http when no port in HTTP_HOST.""" + env = _make_env() + req = SimpleRequest(env, app) + assert req.host_port == 80 + + def test_method_number_unknown_method(self, app): + """method_number falls back to GET for unknown methods.""" + env = _make_env(REQUEST_METHOD='UNKNOWN') + req = SimpleRequest(env, app) + assert req.method_number == methods['GET'] + + def test_method_number_post(self, app): + """method_number returns the POST constant.""" + env = _make_env(REQUEST_METHOD='POST') + req = SimpleRequest(env, app) + assert req.method_number == methods['POST'] + + def test_full_path_with_query(self, app): + """full_path includes query string when present.""" + env = _make_env(QUERY_STRING='foo=bar&baz=1') + req = SimpleRequest(env, app) + assert req.full_path == '/path?foo=bar&baz=1' + + def test_remote_host(self, app): + """remote_host returns REMOTE_HOST environ value.""" + env = _make_env(REMOTE_HOST='client.example.org') + req = SimpleRequest(env, app) + assert req.remote_host == 'client.example.org' + + def test_remote_addr(self, app): + """remote_addr returns REMOTE_ADDR environ value.""" + env = _make_env(REMOTE_ADDR='1.2.3.4') + req = SimpleRequest(env, app) + assert req.remote_addr == '1.2.3.4' + + def test_referer(self, app): + """referer returns HTTP_REFERER environ value.""" + env = _make_env(HTTP_REFERER='http://example.org/') + req = SimpleRequest(env, app) + assert req.referer == 'http://example.org/' + + def test_user_agent(self, app): + """user_agent returns HTTP_USER_AGENT environ value.""" + env = _make_env(HTTP_USER_AGENT='TestBot/1.0') + req = SimpleRequest(env, app) + assert req.user_agent == 'TestBot/1.0' + + def test_server_admin_custom(self, app): + """server_admin returns SERVER_ADMIN when set.""" + env = _make_env(SERVER_ADMIN='admin@example.org') + req = SimpleRequest(env, app) + assert req.server_admin == 'admin@example.org' + + def test_server_admin_default(self, app): + """server_admin defaults to webmaster@.""" + env = _make_env() + req = SimpleRequest(env, app) + assert req.server_admin == 'webmaster@localhost' + + def test_server_port(self, app): + """server_port returns int SERVER_PORT.""" + env = _make_env(SERVER_PORT='8080') + req = SimpleRequest(env, app) + assert req.server_port == 8080 + + def test_port_alias(self, app): + """port is an alias for server_port.""" + env = _make_env(SERVER_PORT='9000') + req = SimpleRequest(env, app) + assert req.port == 9000 + + def test_protocol(self, app): + """protocol returns SERVER_PROTOCOL.""" + env = _make_env(SERVER_PROTOCOL='HTTP/2.0') + req = SimpleRequest(env, app) + assert req.protocol == 'HTTP/2.0' + + def test_forwarded_for(self, app): + """forwarded_for returns X-Forwarded-For header.""" + env = _make_env(HTTP_X_FORWARDED_FOR='10.0.0.1') + req = SimpleRequest(env, app) + assert req.forwarded_for == '10.0.0.1' + + def test_forwarded_host_with_port_strips_port(self, app): + """forwarded_host strips the port component.""" + env = _make_env(HTTP_X_FORWARDED_HOST='proxy.example.org:8080') + req = SimpleRequest(env, app) + assert req.forwarded_host == 'proxy.example.org' + + def test_forwarded_host_without_port(self, app): + """forwarded_host returns host unchanged when no port present.""" + env = _make_env(HTTP_X_FORWARDED_HOST='proxy.example.org') + req = SimpleRequest(env, app) + assert req.forwarded_host == 'proxy.example.org' + + def test_forwarded_port_from_host_header(self, app): + """forwarded_port uses port from X-Forwarded-Host.""" + env = _make_env(HTTP_X_FORWARDED_HOST='proxy.example.org:9090') + req = SimpleRequest(env, app) + assert req.forwarded_port == 9090 + + def test_forwarded_port_from_proto_https(self, app): + """forwarded_port returns 443 when X-Forwarded-Proto is https.""" + env = _make_env(HTTP_X_FORWARDED_PROTO='https') + req = SimpleRequest(env, app) + assert req.forwarded_port == 443 + + def test_forwarded_port_from_proto_http(self, app): + """forwarded_port returns 80 when X-Forwarded-Proto is http.""" + env = _make_env(HTTP_X_FORWARDED_PROTO='http') + req = SimpleRequest(env, app) + assert req.forwarded_port == 80 + + def test_forwarded_port_none_when_absent(self, app): + """forwarded_port returns None when no forwarding headers.""" + env = _make_env() + req = SimpleRequest(env, app) + assert req.forwarded_port is None + + def test_forwarded_proto(self, app): + """forwarded_proto returns X-Forwarded-Proto.""" + env = _make_env(HTTP_X_FORWARDED_PROTO='https') + req = SimpleRequest(env, app) + assert req.forwarded_proto == 'https' + + def test_secret_key_from_environ(self, app): + """secret_key returns poor_SecretKey when set in environ.""" + env = _make_env(poor_SecretKey='mysecret') # noqa: S105 + req = SimpleRequest(env, app) + assert req.secret_key == 'mysecret' # noqa: S105 + + def test_document_index_from_environ(self, app): + """poor_DocumentIndex=on in environ sets document_index=True.""" + env = _make_env(poor_DocumentIndex='on') + req = SimpleRequest(env, app) + assert req.document_index is True + + def test_get_options_deprecated(self, app): + """get_options() emits DeprecationWarning and returns app_ vars.""" + env = _make_env(app_db='localhost', app_templates='templ') + req = SimpleRequest(env, app) + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter('always') + opts = req.get_options() + assert len(caught) == 1 + assert issubclass(caught[0].category, DeprecationWarning) + assert opts == {'db': 'localhost', 'templates': 'templ'} + + def test_construct_url_already_absolute(self, app): + """construct_url returns URI unchanged when it already has a scheme.""" + env = _make_env() + req = SimpleRequest(env, app) + url = 'http://other.example.org/page' + assert req.construct_url(url) == url + + def test_construct_url_nondefault_port(self, app): + """construct_url includes port when it is non-default.""" + env = _make_env(HTTP_HOST='localhost:8080') + req = SimpleRequest(env, app) + assert req.construct_url('/foo') == 'http://localhost:8080/foo' + + def test_server_software_uwsgi(self, app): + """server_software returns 'uWsgi' when uwsgi.version is present.""" + env = _make_env(**{'uwsgi.version': b'2.0'}) + req = SimpleRequest(env, app) + assert req.server_software == 'uWsgi' + + +# --------------------------------------------------------------------------- +# Request.__init__ branches +# --------------------------------------------------------------------------- + +class TestRequestInit: + """Tests for Request.__init__ edge cases.""" + + def test_missing_path_info_raises(self, app): + """PATH_INFO=None raises ConnectionError.""" + env = _make_env() + env['PATH_INFO'] = None + with raises(ConnectionError): + Request(env, app) + + def test_content_headers_parsed(self, app): + """CONTENT_TYPE and CONTENT_LENGTH from environ reach the headers.""" + env = _make_env( + CONTENT_TYPE='text/plain', + CONTENT_LENGTH='5', + **{'wsgi.input': BytesIO(b'hello')}, + ) + req = Request(env, app) + assert req.mime_type == 'text/plain' + assert req.content_length == 5 + + def test_auto_data_wraps_body(self, app): + """auto_data caches body in BytesIO when content_length <= data_size. + """ + body = b'cached' + env = _make_env( + REQUEST_METHOD='POST', + CONTENT_LENGTH=str(len(body)), + CONTENT_TYPE='application/octet-stream', + **{'wsgi.input': BytesIO(body)}, + ) + req = Request(env, app) + assert req.data == body + + def test_auto_json_parses_body(self, app): + """auto_json parses a JSON body into req.json.""" + body = b'{"name": "test"}' + env = _make_env( + REQUEST_METHOD='POST', + CONTENT_TYPE='application/json', + CONTENT_LENGTH=str(len(body)), + **{'wsgi.input': BytesIO(body)}, + ) + req = Request(env, app) + assert isinstance(req.json, JsonDict) + assert req.json['name'] == 'test' + + def test_auto_form_parses_body(self, app): + """auto_form parses a URL-encoded body into req.form.""" + body = b'name=Ondrej&age=30' + env = _make_env( + REQUEST_METHOD='POST', + CONTENT_TYPE='application/x-www-form-urlencoded', + CONTENT_LENGTH=str(len(body)), + **{'wsgi.input': BytesIO(body)}, + ) + req = Request(env, app) + assert isinstance(req.form, FieldStorage) + assert req.form.getvalue('name') == 'Ondrej' + + def test_auto_cookies_parses_cookie_header(self, app): + """Cookie header is parsed into req.cookies when auto_cookies is set. + """ + env = _make_env(HTTP_COOKIE='session=abc123; token=xyz') + req = Request(env, app) + assert req.cookies is not None + assert 'session' in req.cookies + assert req.cookies['session'].value == 'abc123' + + +# --------------------------------------------------------------------------- +# Request properties +# --------------------------------------------------------------------------- + +class TestRequestProperties: + """Tests for Request properties not covered by TestRequest.""" + + def _req(self, app, **kwargs): + return Request(_make_env(**kwargs), app) + + def test_charset_default(self, app): + """charset defaults to utf-8 when not in Content-Type.""" + req = self._req(app) + assert req.charset == 'utf-8' + + def test_charset_from_content_type(self, app): + """charset is extracted from Content-Type parameter.""" + req = self._req(app, CONTENT_TYPE='text/plain; charset=iso-8859-1') + assert req.charset == 'iso-8859-1' + + def test_content_length_absent(self, app): + """content_length is -1 when Content-Length header is absent.""" + req = self._req(app) + assert req.content_length == -1 + + def test_accept_parses_header(self, app): + """accept property parses the Accept header.""" + req = self._req(app, HTTP_ACCEPT='text/html, application/json;q=0.9') + acc = req.accept + assert ('text/html', 1.0) in acc + assert ('application/json', 0.9) in acc + + def test_accept_charset(self, app): + """accept_charset parses the Accept-Charset header.""" + req = self._req(app, HTTP_ACCEPT_CHARSET='utf-8, iso-8859-1;q=0.5') + result = req.accept_charset + assert any(v == 'utf-8' for v, _ in result) + + def test_accept_encoding(self, app): + """accept_encoding parses the Accept-Encoding header.""" + req = self._req(app, HTTP_ACCEPT_ENCODING='gzip, deflate') + result = req.accept_encoding + assert any(v == 'gzip' for v, _ in result) + + def test_accept_language(self, app): + """accept_language parses the Accept-Language header.""" + req = self._req(app, HTTP_ACCEPT_LANGUAGE='en-US, cs;q=0.8') + result = req.accept_language + assert any(v == 'en-US' for v, _ in result) + + def test_accept_html_true(self, app): + """accept_html returns True when text/html is accepted.""" + req = self._req(app, HTTP_ACCEPT='text/html') + assert req.accept_html is True + + def test_accept_xhtml_true(self, app): + """accept_xhtml returns True when text/xhtml is accepted.""" + req = self._req(app, HTTP_ACCEPT='text/xhtml') + assert req.accept_xhtml is True + + def test_accept_json_true(self, app): + """accept_json returns True when application/json is accepted.""" + req = self._req(app, HTTP_ACCEPT='application/json') + assert req.accept_json is True + + def test_authorization_basic(self, app): + """authorization parses a Basic auth header.""" + creds = base64.b64encode(b'user:pass').decode() + req = self._req(app, HTTP_AUTHORIZATION=f'Basic {creds}') + auth = req.authorization + assert auth['type'] == 'Basic' + + def test_authorization_digest(self, app): + """authorization parses a Digest auth header.""" + header = 'Digest username="alice", realm="test", nonce="abc"' + req = self._req(app, HTTP_AUTHORIZATION=header) + auth = req.authorization + assert auth['type'] == 'Digest' + assert auth['username'] == 'alice' + + def test_is_xhr_true(self, app): + """is_xhr returns True when X-Requested-With is XMLHttpRequest.""" + req = self._req(app, HTTP_X_REQUESTED_WITH='XMLHttpRequest') + assert req.is_xhr is True + + def test_is_xhr_false(self, app): + """is_xhr returns False when X-Requested-With is absent.""" + req = self._req(app) + assert req.is_xhr is False + + def test_is_chunked_request_deprecated(self, app): + """is_chunked_request emits DeprecationWarning.""" + req = self._req(app, HTTP_TRANSFER_ENCODING='chunked') + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter('always') + result = req.is_chunked_request + assert result is True + assert any(issubclass(w.category, DeprecationWarning) for w in caught) + + def test_path_args_default_empty(self, app): + """path_args returns {} when not set.""" + req = self._req(app) + assert not req.path_args + + def test_path_args_setter_once(self, app): + """path_args setter ignores subsequent assignments.""" + req = self._req(app) + req.path_args = {'id': '1'} + req.path_args = {'id': '2'} + assert req.path_args == {'id': '1'} + + def test_args_property(self, app): + """args property returns Args instance parsed from QUERY_STRING.""" + req = self._req(app, QUERY_STRING='x=1&y=2') + assert req.args.getvalue('x') == '1' + + def test_args_setter_once(self, app): + """args setter replaces EmptyForm but ignores further sets.""" + req = self._req(app) + req.args = Args(req) + assert not isinstance(req.args, EmptyForm) + + def test_form_setter_once(self, app): + """form setter replaces EmptyForm but ignores further sets.""" + req = self._req(app) + first = FieldStorage() + second = FieldStorage() + req.form = first + req.form = second + assert req.form is first + + def test_json_property_empty(self, app): + """json property returns EmptyForm by default.""" + req = self._req(app) + assert isinstance(req.json, EmptyForm) + + def test_data_property_with_body(self, app): + """data property returns body when auto_data is active.""" + body = b'hello' + env = _make_env( + REQUEST_METHOD='POST', + CONTENT_LENGTH='5', + CONTENT_TYPE='application/octet-stream', + **{'wsgi.input': BytesIO(body)}, + ) + req = Request(env, app) + assert req.data == body + + def test_data_property_none_for_non_bytesio(self, app): + """data returns None when wsgi.input is not a BytesIO instance.""" + class _RawIO: + def read(self, _n=-1): # pylint: disable=invalid-name + return b'' + + env = _make_env(**{'wsgi.input': _RawIO()}) + req = Request(env, app) + assert req.data is None + + def test_input_returns_file(self, app): + """input property returns the wsgi.input stream.""" + req = self._req(app) + assert req.input is not None + + def test_user_property(self, app): + """user property can be set and read.""" + req = self._req(app) + assert req.user is None + req.user = 'alice' + assert req.user == 'alice' + + def test_api_property(self, app): + """api property can be set and read.""" + req = self._req(app) + assert req.api is None + req.api = {'version': '1.0'} + assert req.api == {'version': '1.0'} + + def test_db_property(self, app): + """db property can be set and read.""" + req = self._req(app) + assert req.db is None + req.db = object() + assert req.db is not None + + +# --------------------------------------------------------------------------- +# Request methods +# --------------------------------------------------------------------------- + +class TestRequestMethods: + """Tests for Request read, read_chunk, and __del__.""" + + def test_read_returns_empty_without_body(self, app): + """read() returns b'' when there is no body.""" + req = Request(_make_env(), app) + assert req.read() == b'' + + def test_read_returns_body(self, app): + """read() returns the full body from wsgi.input.""" + body = b'payload data' + env = _make_env( + REQUEST_METHOD='POST', + CONTENT_LENGTH=str(len(body)), + CONTENT_TYPE='application/octet-stream', + **{'wsgi.input': BytesIO(body)}, + ) + req = Request(env, app) + assert req.read() == body + + def test_read_partial_length(self, app): + """read(n) reads exactly n bytes and switches to __read mode.""" + body = b'hello world' + env = _make_env( + REQUEST_METHOD='POST', + CONTENT_LENGTH=str(len(body)), + CONTENT_TYPE='application/octet-stream', + **{'wsgi.input': BytesIO(body)}, + ) + req = Request(env, app) + chunk = req.read(5) + assert chunk == b'hello' + + def test_read_chunk(self, app): + """read_chunk() reads a hex-length-prefixed chunk.""" + chunk_data = b'Hello' + # Chunked format: hex_length CRLF data CRLF + raw = b'5\r\n' + chunk_data + b'\r\n' + env = _make_env(**{'wsgi.input': BytesIO(raw)}) + req = Request(env, app) + result = req.read_chunk() + assert result == chunk_data + + def test_del_does_not_raise(self, app): + """__del__ executes without error.""" + req = Request(_make_env(), app) + del req + + +# --------------------------------------------------------------------------- +# EmptyForm and JsonList deprecated fce argument +# --------------------------------------------------------------------------- + +class TestDeprecatedFce: + """Tests for the deprecated fce argument in EmptyForm and JsonList.""" + + def test_empty_form_getfirst_fce_warns(self): + """EmptyForm.getfirst with fce emits DeprecationWarning.""" + form = EmptyForm() + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter('always') + result = form.getfirst('x', default=42, fce=int) + assert result == 42 + assert any(issubclass(w.category, DeprecationWarning) for w in caught) + + def test_empty_form_getlist_fce_warns(self): + """EmptyForm.getlist with fce emits DeprecationWarning.""" + form = EmptyForm() + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter('always') + result = form.getlist('x', fce=str) + assert not result + assert any(issubclass(w.category, DeprecationWarning) for w in caught) + + def test_json_list_getfirst_fce_warns(self): + """JsonList.getfirst with fce emits DeprecationWarning.""" + jl = JsonList([10, 20]) + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter('always') + result = jl.getfirst('any', fce=str) + assert result == '10' + assert any(issubclass(w.category, DeprecationWarning) for w in caught) + + def test_json_list_getlist_fce_warns(self): + """JsonList.getlist with fce emits DeprecationWarning.""" + jl = JsonList([1, 2, 3]) + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter('always') + result = jl.getlist('any', fce=str) + assert result == ['1', '2', '3'] + assert any(issubclass(w.category, DeprecationWarning) for w in caught) + + def test_json_list_getlist_empty_with_default(self): + """JsonList.getlist on empty list returns the provided default.""" + jl = JsonList() + assert jl.getlist('any', default=[9, 8]) == [9, 8] + + +# --------------------------------------------------------------------------- +# Deprecated FieldStorage compatibility function +# --------------------------------------------------------------------------- + +class TestFieldStorageCompat: + """Tests for the deprecated FieldStorage backwards-compatibility wrapper. + """ + + def test_field_storage_deprecated_warns(self, app): + """FieldStorage() emits a DeprecationWarning and returns a form.""" + # Use a MIME type that auto_form does not consume, so wsgi.input + # is still available when FieldStorage is called manually. + body = b'key=value' + env = _make_env( + REQUEST_METHOD='POST', + CONTENT_TYPE='application/x-www-form-urlencoded', + CONTENT_LENGTH=str(len(body)), + **{'wsgi.input': BytesIO(body)}, + ) + req = Request(env, app) + # Seek the input back to 0 — auto_form may have consumed it. + req.input.seek(0) + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter('always') + form = DeprecatedFieldStorage(req) + assert any(issubclass(w.category, DeprecationWarning) for w in caught) + assert form is not None + + +# --------------------------------------------------------------------------- +# CachedInput +# --------------------------------------------------------------------------- + +class TestCachedInput: + """Tests for CachedInput buffered read and readline.""" + + # --- read() --- + + def test_read_from_file_no_buffer(self): + """read() reads directly from file when buffer is empty.""" + ci = CachedInput(BytesIO(b'hello world'), 11, block_size=32768) + assert ci.read(5) == b'hello' + + def test_read_default_uses_block_size(self): + """read() with no argument reads block_size bytes.""" + ci = CachedInput(BytesIO(b'x' * 10), 10, block_size=4) + chunk = ci.read() + assert chunk == b'xxxx' + + def test_read_from_buffer_sufficient(self): + """read() returns from buffer when it holds enough data.""" + # Pre-populate buffer and keep todo > 0 so size is not capped to 0. + ci = CachedInput(BytesIO(b'extra'), 10, block_size=32768) + ci._CachedInput__buffer = b'hello' # pylint: disable=protected-access + ci._CachedInput__todo = 10 # pylint: disable=protected-access + assert ci.read(3) == b'hel' + + def test_read_combines_buffer_and_file(self): + """read() combines partial buffer with additional file data.""" + # 2 bytes in buffer, 5 bytes in file, request 5 → combine. + ci = CachedInput(BytesIO(b'CDEFG'), 10, block_size=32768) + ci._CachedInput__buffer = b'AB' # pylint: disable=protected-access + ci._CachedInput__todo = 10 # pylint: disable=protected-access + result = ci.read(5) + assert result == b'ABCDE' + + # --- readline() --- + + def test_readline_finds_crlf(self): + """readline() returns up to and including the first CRLF.""" + data = b'hello\r\nworld' + ci = CachedInput(BytesIO(data), len(data), block_size=32768) + assert ci.readline() == b'hello\r\n' + + def test_readline_no_crlf_returns_all(self): + """readline() returns all data when no CRLF is present.""" + data = b'noeol' + ci = CachedInput(BytesIO(data), len(data), block_size=32768, + timeout=None) + assert ci.readline() == data + + def test_readline_timeout_raises(self): + """readline() raises TimeoutError when data never arrives. + + todo must be > 0 so the initial buffer fill sets size > 0 and the + while-loop actually runs the timeout check. + """ + # BytesIO(b'') returns b'' for any read → buffer stays empty forever. + # timeout=0 means times_out_at is already in the past at first check. + ci = CachedInput(BytesIO(b''), 5, block_size=5, timeout=0) + with raises(TimeoutError): + ci.readline() + + def test_readline_seen_data_resets_timer(self): + """readline() resets the timeout timer after consuming data.""" + # With timeout=None, this just exercises the seen_data=True branch + # by having data in the buffer at timeout-check time. + data = b'nodot' + ci = CachedInput(BytesIO(data), len(data), block_size=32768, + timeout=10) + result = ci.readline() + # No CRLF → returns full data + assert result == data + + def test_readline_with_existing_buffer(self): + """readline() uses existing buffer content before reading file.""" + data = b'first\r\nsecond\r\n' + ci = CachedInput(BytesIO(data), len(data), block_size=len(data)) + assert ci.readline() == b'first\r\n' + # Second readline uses leftover buffer + assert ci.readline() == b'second\r\n' + + def test_readline_reads_more_from_file(self): + """readline() reads additional data when buffer is shorter than size. + + Covers lines 1115-1118 (n_size read) and 1103-1104 (seen_data=True). + """ + # Pre-load a short buffer (2 bytes) plus file that completes the line. + ci = CachedInput(BytesIO(b'\r\n'), 2, block_size=5, timeout=10) + # pylint: disable=protected-access + ci._CachedInput__buffer = b'ab' + ci._CachedInput__todo = 2 + # readline: buffer is non-empty → skip initial fill, size=block_size=5 + # iter1: buffer=b'ab' → seen_data=True; consume; l_size=2 < 5 + # → reads 2 more bytes from file → buffer=b'\r\n' + # iter2: finds \r\n at pos=0 → returns b'ab\r\n' + result = ci.readline() + assert result == b'ab\r\n' + + def test_readline_seen_data_timer_reset(self): + """readline() resets the timer when buffer empties after having data. + + Covers lines 1106-1107 (seen_data=True → False, timer reset). + After the reset, timeout=0 fires TimeoutError on next iteration. + """ + # Pre-populate buffer with b'ab', file returns nothing, timeout=0. + # pylint: disable=protected-access + ci = CachedInput(BytesIO(b''), 2, block_size=5, timeout=0) + ci._CachedInput__buffer = b'ab' + ci._CachedInput__todo = 2 + # iter1: buffer=b'ab' → seen_data=True; consume; read()→b''; l_size=2<5 + # iter2: buffer=b'' → seen_data resets to False + timer reset (1106) + # iter3: buffer=b'' + no data → time()>times_out_at → TimeoutError + with raises(TimeoutError): + ci.readline() + + +# --------------------------------------------------------------------------- +# Additional targeted tests for remaining uncovered lines +# --------------------------------------------------------------------------- + +class TestRemainingCoverage: + """Covers lines not reached by the main test classes.""" + + def test_uri_deprecated_alias(self, app): + """uri property is a deprecated alias for path.""" + req = Request(_make_env(PATH_INFO='/foo'), app) + assert req.uri == '/foo' + + def test_scheme_property(self, app): + """scheme property returns wsgi.url_scheme.""" + req = Request(_make_env(**{'wsgi.url_scheme': 'https', + 'SERVER_PORT': '443'}), app) + assert req.scheme == 'https' + + def test_document_index_from_app(self, app): + """document_index falls back to app.document_index.""" + req = Request(_make_env(), app) + assert req.document_index == app.document_index + + def test_start_time_and_end_time(self, app): + """start_time and end_time return timestamps.""" + req = Request(_make_env(), app) + assert isinstance(req.start_time, float) + assert isinstance(req.end_time, float) + + def test_authorization_utf8_username(self, app): + """authorization decodes RFC 5987 UTF-8'' encoded username.""" + # username*=UTF-8''Ond%C5%99ej decodes to 'Ondřej' + header = "Digest username*=UTF-8''Ond%C5%99ej, realm=\"test\"" + req = Request(_make_env(HTTP_AUTHORIZATION=header), app) + auth = req.authorization + assert auth.get('username') == 'Ondřej' + + def test_args_setter_ignored_when_not_empty_form(self, app): + """args.setter does nothing when args is already set.""" + req = Request(_make_env(QUERY_STRING='x=1'), app) + original = req.args + req.args = EmptyForm() # setter should ignore this + assert req.args is original + + def test_input_creates_cached_input(self): + """input property creates CachedInput when cached_size > 0 and + wsgi.input is not BytesIO (auto_data must be off).""" + class _RawIO: + def read(self, n=-1): + return b'x' * n if n > 0 else b'' + + local_app = Application("_test_cached_input_creation") + local_app.auto_data = False + local_app.auto_json = False + local_app.auto_form = False + local_app.cached_size = 4096 + + env = _make_env( + REQUEST_METHOD='POST', + CONTENT_LENGTH='10', + CONTENT_TYPE='application/octet-stream', + **{'wsgi.input': _RawIO()}, + ) + req = Request(env, local_app) + # Access twice; second call returns cached instance (line 719). + first = req.input + second = req.input + assert first is second + assert isinstance(first, CachedInput) + + def test_auto_args_false_uses_empty_form(self): + """When auto_args=False, req.args is an EmptyForm.""" + local_app = Application("_test_auto_args_false") + local_app.auto_args = False + req = Request(_make_env(QUERY_STRING='x=1'), local_app) + assert isinstance(req.args, EmptyForm) + + def test_args_setter_sets_when_empty_form(self): + """args.setter replaces EmptyForm (line 655).""" + local_app = Application("_test_args_setter_line655") + local_app.auto_args = False + req = Request(_make_env(QUERY_STRING='x=1'), local_app) + assert isinstance(req.args, EmptyForm) + new_args = Args(req) + req.args = new_args + assert req.args is new_args + + def test_document_root_property(self, app): + """document_root returns poor_DocumentRoot or app.document_root.""" + req = Request(_make_env(), app) + _ = req.document_root # covers line 328 diff --git a/tests/test_responses.py b/tests/test_responses.py index 1ef9e4e..af87474 100644 --- a/tests/test_responses.py +++ b/tests/test_responses.py @@ -1,33 +1,49 @@ """Tests for Response objects and their functionality.""" +import re +import warnings from datetime import datetime, timezone -from io import BufferedWriter, BytesIO +from io import BufferedWriter, BytesIO, RawIOBase +from unittest.mock import patch import pytest from simplejson import load, loads from poorwsgi.request import Headers from poorwsgi.response import ( + BaseResponse, + Declined, + EmptyResponse, FileObjResponse, FileResponse, GeneratorResponse, HTTPException, + IBytesIO, JSONGeneratorResponse, JSONResponse, + NoContentResponse, NotModifiedResponse, PartialResponse, RedirectResponse, Response, + ResponseError, StrGeneratorResponse, TextResponse, abort, + make_response, redirect, ) -from poorwsgi.state import HTTP_NOT_FOUND # , HTTP_RANGE_NOT_SATISFIABLE +from poorwsgi.state import ( + DECLINED, + HTTP_NOT_FOUND, + HTTP_OK, + HTTP_PARTIAL_CONTENT, +) # , HTTP_RANGE_NOT_SATISFIABLE # pylint: disable=missing-function-docstring # pylint: disable=redefined-outer-name # pylint: disable=no-self-use +# pylint: disable=too-many-lines args = ( @@ -652,9 +668,11 @@ def test_partial_content_mid(self): def test_partial_content_last(self): """Tests partial content response for FileResponse requesting the last N bytes.""" + with open(__file__, "rb") as fh: + last4 = fh.read()[-4:] res = FileResponse(__file__) res.make_partial([(None, 4)]) - assert res(start_response).read() == b"one\n" + assert res(start_response).read() == last4 assert int(res.headers.get("Content-Length")) == 4 @@ -695,3 +713,597 @@ def test_date_empty_string(self): Date header.""" res = NotModifiedResponse(date="") assert res.headers.get("Date") is None + + def test_status_line_format(self): + """304 response status line must be '304 Not Modified'.""" + received = [] + + def capture(status, _headers): + received.append(status) + return lambda _data: None + + res = NotModifiedResponse(etag='"abc"') + res(capture) + assert received[0] == "304 Not Modified" + + +class TestStatusLineFormat: + """Verify that all response types emit properly formatted status lines.""" + + _status_re = re.compile(r"^\d{3} \S") + + def _capture(self): + received = [] + + def sr(status, _headers): + received.append(status) + return lambda _d: None + + return sr, received + + def test_response_200(self): + """Response emits '200 OK' status line.""" + sr, received = self._capture() + Response(b"hello")(sr) + assert received[0] == "200 OK" + + def test_response_404(self): + """Response with HTTP_NOT_FOUND emits '404 Not Found'.""" + sr, received = self._capture() + Response(b"nope", status_code=HTTP_NOT_FOUND)(sr) + assert received[0] == "404 Not Found" + + def test_redirect_302(self): + """RedirectResponse emits '302 Found' status line.""" + sr, received = self._capture() + RedirectResponse("/new")(sr) + assert received[0] == "302 Found" + + def test_redirect_301(self): + """RedirectResponse permanent emits '301 Moved Permanently'.""" + sr, received = self._capture() + RedirectResponse("/new", 301)(sr) + assert received[0] == "301 Moved Permanently" + + def test_partial_response(self): + """PartialResponse emits '206 Partial Content' status line.""" + sr, received = self._capture() + res = PartialResponse(b"56789") + res.make_range([(5, 9)]) + res(sr) + assert received[0] == "206 Partial Content" + + def test_no_content_response(self): + """NoContentResponse emits '204 No Content' status line.""" + sr, received = self._capture() + NoContentResponse()(sr) + assert received[0] == "204 No Content" + + def test_status_line_pattern(self): + """Status line must start with three digits followed by a space.""" + sr, received = self._capture() + Response(b"data")(sr) + assert self._status_re.match(received[0]) + + def test_304_deny_headers_warning(self): + """BaseResponse with status 304 warns when representation headers + (Content-Type etc.) are present.""" + res = Response( + headers={"Content-Type": "text/html", "ETag": '"abc"'}, + status_code=304, + ) + with patch("poorwsgi.response.log") as mock_log: + res(lambda _s, _h: lambda _d: None) + warning_msgs = [str(call) for call in mock_log.warning.call_args_list] + assert any("representation" in m for m in warning_msgs) + + def test_304_no_required_headers_warning(self): + """BaseResponse with status 304 warns when none of Date/ETag/Vary/ + Content-Location are present.""" + res = Response(b"", status_code=304) + with patch("poorwsgi.response.log") as mock_log: + res(lambda _s, _h: lambda _d: None) + warning_msgs = [str(call) for call in mock_log.warning.call_args_list] + assert any("required" in m or "Missing" in m for m in warning_msgs) + + +class TestContentLengthAccuracy: + """Content-Length header must exactly match the body length.""" + + def test_response_bytes(self): + """Content-Length matches byte body length.""" + body = b"Hello, World!" + res = Response(body) + res(lambda _s, _h: lambda _d: None) + assert int(res.headers["Content-Length"]) == len(body) + + def test_response_string_utf8(self): + """Content-Length reflects encoded byte length for multi-byte chars.""" + text = "Čeština" + body = text.encode("utf-8") + res = Response(text) + res(lambda _s, _h: lambda _d: None) + assert int(res.headers["Content-Length"]) == len(body) + + def test_response_write(self): + """Content-Length is updated after calling write().""" + res = Response(b"Hello") + res.write(b" World") + res(lambda _s, _h: lambda _d: None) + assert int(res.headers["Content-Length"]) == 11 + + def test_json_response(self): + """JSONResponse Content-Length matches the JSON-encoded body.""" + res = JSONResponse(x=1) + body = res.data + res(lambda _s, _h: lambda _d: None) + assert int(res.headers["Content-Length"]) == len(body) + + def test_partial_content_length(self): + """Partial response Content-Length reflects the slice.""" + res = Response(b"0123456789") + res.make_partial([(2, 5)]) + res(lambda _s, _h: lambda _d: None) + assert int(res.headers["Content-Length"]) == 4 # bytes 2,3,4,5 + + def test_body_matches_content_length(self): + """Actual bytes returned equal declared Content-Length.""" + res = Response(b"0123456789") + buf = res(lambda _s, _h: lambda _d: None) + data = buf.read() + assert len(data) == 10 + + +class TestIBytesIO: + """Tests for IBytesIO helper class.""" + + def test_read_kilo(self): + """read_kilo returns up to 1024 bytes.""" + buf = IBytesIO(b"x" * 2048) + chunk = buf.read_kilo() + assert len(chunk) == 1024 + + def test_iteration(self): + """Iterating IBytesIO yields 1024-byte chunks.""" + buf = IBytesIO(b"y" * 2000) + chunks = list(buf) + assert len(chunks) == 2 + assert len(chunks[0]) == 1024 + assert len(chunks[1]) == 976 + + +class TestBaseResponse: + """Tests for BaseResponse standalone behavior.""" + + def test_headers_setter_list(self): + """Setting headers from a list creates a Headers instance.""" + res = BaseResponse() + res.headers = [("X-Foo", "bar")] + assert res.headers["X-Foo"] == "bar" + + def test_headers_setter_headers(self): + """Setting headers from a Headers object keeps it as-is.""" + h = Headers([("X-Foo", "bar")]) + res = BaseResponse() + res.headers = h + assert res.headers["X-Foo"] == "bar" + + def test_data_property(self): + """BaseResponse.data always returns empty bytes.""" + res = BaseResponse() + assert res.data == b"" + + def test_make_partial_non_200_status(self): + """make_partial is a no-op when status_code is not 200.""" + res = BaseResponse(status_code=HTTP_NOT_FOUND) + res.make_partial([(0, 10)]) + assert not res.ranges + + def test_make_range_non_206_status(self): + """make_range is a no-op when status_code is not 206.""" + res = BaseResponse(status_code=HTTP_OK) + res.make_range([(0, 10)]) + assert not res.ranges + + def test_add_header(self): + """add_header delegates to the Headers object.""" + res = BaseResponse() + res.add_header("X-Test", "value") + assert res.headers["X-Test"] == "value" + + def test_status_code_invalid(self): + """Setting an invalid status code raises ValueError.""" + res = BaseResponse() + with pytest.raises(ValueError, match="Bad response status"): + res.status_code = 999 + + def test_status_code_setter_updates_reason(self): + """Setting status_code updates the reason phrase automatically.""" + res = BaseResponse() + res.status_code = HTTP_NOT_FOUND + assert res.reason == "Not Found" + + def test_status_code_setter_clears_ranges(self): + """Changing status_code away from 200/206 clears ranges.""" + res = BaseResponse() + res.make_partial([(0, 10)]) + assert res.ranges # ranges set + res.status_code = HTTP_NOT_FOUND + assert not res.ranges + + def test_make_partial_inconsistent_range(self): + """make_partial logs a warning for end < start and skips the range.""" + res = BaseResponse() + with patch("poorwsgi.response.log") as mock_log: + res.make_partial([(10, 5)]) # end < start + assert mock_log.warning.called + assert not res.ranges + + def test_make_range_multiple_ranges_warning(self): + """make_range logs a warning when more than one range is given.""" + res = Response(b"0123456789", status_code=HTTP_PARTIAL_CONTENT) + with patch("poorwsgi.response.log") as mock_log: + res.make_range([(0, 3), (5, 8)]) + warning_msgs = [str(c) for c in mock_log.warning.call_args_list] + assert any("one range" in m for m in warning_msgs) + + def test_make_range_none_in_mixed_ranges(self): + """make_range warns for None start/end and still uses valid range.""" + res = Response(b"0123456789", status_code=HTTP_PARTIAL_CONTENT) + with patch("poorwsgi.response.log") as mock_log: + res.make_range([(None, 5), (1, 3)]) + warning_msgs = " ".join( + str(c) for c in mock_log.warning.call_args_list + ) + assert "full range" in warning_msgs + assert res.ranges == ((1, 3),) + + def test_make_range_inconsistent_in_mixed_ranges(self): + """make_range warns for end < start and still uses the valid range.""" + res = Response(b"0123456789", status_code=HTTP_PARTIAL_CONTENT) + with patch("poorwsgi.response.log") as mock_log: + res.make_range([(10, 5), (1, 3)]) + warning_msgs = " ".join( + str(c) for c in mock_log.warning.call_args_list + ) + assert "Inconsistent" in warning_msgs + assert res.ranges == ((1, 3),) + + +class TestNoContentAndDeclined: + """Tests for NoContentResponse, EmptyResponse (deprecated), Declined.""" + + def test_no_content_status(self): + """NoContentResponse defaults to 204 No Content.""" + res = NoContentResponse() + assert res.status_code == 204 + + def test_no_content_call(self): + """NoContentResponse emits empty header list via start_response.""" + received = [] + res = NoContentResponse() + res(lambda s, h: received.append((s, h)) or (lambda _d: None)) + assert received[0][0] == "204 No Content" + assert received[0][1] == [] + + def test_empty_response_deprecated(self): + """EmptyResponse emits a deprecation warning on construction.""" + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + EmptyResponse() + assert any(issubclass(warning.category, DeprecationWarning) + for warning in w) + + def test_declined_call_returns_empty(self): + """Declined.__call__ returns an empty tuple without calling + start_response.""" + called = [] + res = Declined() + result = res(lambda _s, _h: called.append(1) or (lambda _d: None)) + assert not result + assert not called + + def test_declined_headers_warning(self): + """Setting headers on a Declined response logs a warning.""" + res = Declined() + with patch("poorwsgi.response.log") as mock_log: + res.headers = [("X-Foo", "bar")] + assert mock_log.warning.called + + def test_declined_add_header_warning(self): + """Calling add_header on a Declined response logs a warning.""" + res = Declined() + with patch("poorwsgi.response.log") as mock_log: + res.add_header("X-Foo", "bar") + assert mock_log.warning.called + + +class TestHTTPExceptionMakeResponse: + """Tests for HTTPException.make_response and response property.""" + + def test_make_response_with_response(self): + """make_response returns the wrapped response object.""" + inner = Response(b"body", status_code=201) + exc = HTTPException(inner) + assert exc.make_response() is inner + + def test_make_response_declined(self): + """make_response for DECLINED returns a Declined instance.""" + exc = HTTPException(DECLINED) + result = exc.make_response() + assert isinstance(result, Declined) + + def test_make_response_200_returns_empty(self): + """make_response for 200 returns an EmptyResponse.""" + with warnings.catch_warnings(record=True): + warnings.simplefilter("always") + exc = HTTPException(HTTP_OK) + result = exc.make_response() + assert isinstance(result, EmptyResponse) + + def test_make_response_other_returns_none(self): + """make_response for other status codes returns None.""" + exc = HTTPException(404) + assert exc.make_response() is None + + def test_response_property_int(self): + """response property returns None when arg is an int.""" + exc = HTTPException(404) + assert exc.response is None + + def test_response_property_response(self): + """response property returns the response when arg is a response.""" + inner = Response(b"body") + exc = HTTPException(inner) + assert exc.response is inner + + +class TestMakeResponse: + """Tests for the make_response factory function.""" + + def test_string_returns_response(self): + """make_response with a string returns a Response.""" + res = make_response("Hello") + assert isinstance(res, Response) + assert res.data == b"Hello" + + def test_bytes_returns_response(self): + """make_response with bytes returns a Response.""" + res = make_response(b"data") + assert isinstance(res, Response) + assert res.data == b"data" + + def test_dict_returns_json(self): + """make_response with a dict returns a JSONResponse.""" + res = make_response({"key": "val"}) + assert isinstance(res, JSONResponse) + assert b'"key"' in res.data + + def test_list_of_non_bytes_returns_json(self): + """make_response with a list of non-bytes returns JSONResponse.""" + res = make_response([1, 2, 3]) + assert isinstance(res, JSONResponse) + + def test_list_of_bytes_returns_generator(self): + """make_response with a list of bytes returns GeneratorResponse.""" + res = make_response([b"a", b"b"]) + assert isinstance(res, GeneratorResponse) + + def test_none_returns_no_content(self): + """make_response with None returns NoContentResponse with 204.""" + res = make_response(None) + assert isinstance(res, NoContentResponse) + assert res.status_code == 204 + + def test_none_explicit_status(self): + """make_response with None and explicit status keeps that status.""" + res = make_response(None, status_code=201) + assert res.status_code == 201 + + def test_bytes_custom_status(self): + """make_response passes status_code through to the Response.""" + res = make_response(b"err", status_code=HTTP_NOT_FOUND) + assert res.status_code == HTTP_NOT_FOUND + + def test_invalid_raises_response_error(self): + """make_response with an unsupported type raises ResponseError.""" + with pytest.raises(ResponseError): + make_response(12345) + + +class TestRedirectHTTP: + """HTTP-level tests for redirect responses.""" + + def test_location_header(self): + """RedirectResponse sets Location header to the given URL.""" + res = RedirectResponse("/new-path") + res(lambda _s, _h: lambda _d: None) + assert res.headers["Location"] == "/new-path" + + def test_redirect_function_location(self): + """redirect() raises HTTPException with Location header set.""" + with pytest.raises(HTTPException) as exc_info: + redirect("/target") + exc_info.value.response(lambda _s, _h: lambda _d: None) + assert exc_info.value.response.headers["Location"] == "/target" + + def test_redirect_message_body(self): + """RedirectResponse body carries the supplied message text.""" + res = RedirectResponse("/go", message="Moved here") + buf = res(lambda _s, _h: lambda _d: None) + assert buf.read() == b"Moved here" + + def test_redirect_content_type(self): + """RedirectResponse Content-Type is text/plain.""" + res = RedirectResponse("/go") + res(lambda _s, _h: lambda _d: None) + assert res.headers["Content-Type"] == "text/plain" + + +class TestFileResponseHTTP: + """HTTP-level tests for file-based responses.""" + + def test_last_modified_format(self): + """Last-Modified header is in RFC 7231 date-time format.""" + rfc_re = re.compile( + r"^[A-Z][a-z]{2}, \d{2} [A-Z][a-z]{2} \d{4} \d{2}:\d{2}:\d{2} GMT$" + ) + res = FileResponse(__file__) + assert rfc_re.match(res.headers["Last-Modified"]) + + def test_accept_ranges_bytes(self): + """FileResponse sets Accept-Ranges: bytes before calling response.""" + res = FileResponse(__file__) + assert res.headers["Accept-Ranges"] == "bytes" + + def test_content_length_positive(self): + """FileResponse Content-Length is positive.""" + res = FileResponse(__file__) + res(lambda _s, _h: lambda _d: None) + assert int(res.headers["Content-Length"]) > 0 + + def test_file_obj_response_bytesio(self): + """FileObjResponse works with an in-memory BytesIO object.""" + buf = BytesIO(b"hello world") + res = FileObjResponse(buf) + res(lambda _s, _h: lambda _d: None) + assert int(res.headers["Content-Length"]) == 11 + + def test_file_obj_response_content_type_default(self): + """FileObjResponse defaults Content-Type to + application/octet-stream.""" + buf = BytesIO(b"data") + res = FileObjResponse(buf) + res(lambda _s, _h: lambda _d: None) + assert res.headers["Content-Type"] == "application/octet-stream" + + def test_file_obj_response_data(self): + """FileObjResponse.data reads file content from current position.""" + buf = BytesIO(b"hello") + res = FileObjResponse(buf) + assert res.data == b"hello" + + def test_file_obj_unknown_size(self): + """FileObjResponse handles a stream with no fileno and no getbuffer, + defaulting content_length to 0.""" + + class _FakeStream(RawIOBase): + def readable(self): + return True + + def read(self, _n=-1): + return b"" + + def readinto(self, _b): + return 0 + + def seekable(self): + return False + + def fileno(self): + raise OSError("no fileno") + + stream = _FakeStream() + with patch("poorwsgi.response.log"): + res = FileObjResponse(stream) + assert res.content_length == 0 + + def test_file_obj_data_closed(self): + """FileObjResponse.data returns b'' and logs warning when closed.""" + buf = BytesIO(b"hello") + res = FileObjResponse(buf) + buf.close() + with patch("poorwsgi.response.log") as mock_log: + result = res.data + assert result == b"" + assert mock_log.warning.called + + def test_file_obj_data_non_seekable(self): + """FileObjResponse.data returns b'' when file is not seekable.""" + + class _NonSeekable(RawIOBase): + def readable(self): + return True + + def read(self, _n=-1): + return b"" + + def readinto(self, _b): + return 0 + + def seekable(self): + return False + + def fileno(self): + raise OSError("no fileno") + + stream = _NonSeekable() + with patch("poorwsgi.response.log"): + res = FileObjResponse(stream) + with patch("poorwsgi.response.log") as mock_log: + result = res.data + assert result == b"" + assert mock_log.info.called + + def test_file_obj_end_of_response_closed(self): + """FileObjResponse.__end_of_response__ returns empty IBytesIO and + logs error when the underlying file has been closed.""" + buf = BytesIO(b"hello") + res = FileObjResponse(buf) + buf.close() + with patch("poorwsgi.response.log") as mock_log: + result = res.__end_of_response__() + assert result.read() == b"" + assert mock_log.error.called + + def test_file_response_unreadable(self): + """FileResponse raises IOError when the file is not readable.""" + with pytest.raises(IOError, match="Could not stat file"): + FileResponse("/nonexistent/path/to/file.txt") + + def test_declined_headers_property(self): + """Declined.headers property always returns a fresh empty Headers.""" + res = Declined() + assert len(list(res.headers.items())) == 0 + + +class TestResponseWrite: + """Tests for Response.write() and data property edge cases.""" + + def test_write_increases_content_length(self): + """write() increases content_length by the written byte count.""" + res = Response(b"abc") + assert res.content_length == 3 + res.write(b"de") + assert res.content_length == 5 + + def test_write_string(self): + """write() accepts str and encodes it to UTF-8.""" + res = Response(b"") + res.write("Čeština") + assert res.content_length == len("Čeština".encode("utf-8")) + + def test_data_after_write(self): + """data property reflects all bytes in the buffer after write().""" + res = Response(b"") + res.write(b"Hello World") + assert res.data == b"Hello World" + + def test_data_closed_buffer_returns_empty(self): + """data property returns b'' and logs warning when buffer is closed.""" + res = Response(b"hello") + res._Response__buffer.close() # pylint: disable=protected-access + with patch("poorwsgi.response.log") as mock_log: + result = res.data + assert result == b"" + assert mock_log.warning.called + + def test_end_of_response_closed_buffer(self): + """__end_of_response__ returns empty IBytesIO and logs error + when buffer is closed.""" + res = Response(b"hello") + res._Response__buffer.close() # pylint: disable=protected-access + with patch("poorwsgi.response.log") as mock_log: + result = res.__end_of_response__() + assert result.read() == b"" + assert mock_log.error.called diff --git a/tests/test_results.py b/tests/test_results.py new file mode 100644 index 0000000..4a7cbeb --- /dev/null +++ b/tests/test_results.py @@ -0,0 +1,767 @@ +"""Unit tests for poorwsgi/results.py — default HTTP handlers.""" +# pylint: disable=too-many-lines +import os +import tempfile +from collections import defaultdict +from hashlib import sha256 +from unittest.mock import MagicMock, PropertyMock, patch + +import pytest + +from poorwsgi import Application +from poorwsgi.request import Request +from poorwsgi.response import HTTPException, NotModifiedResponse +from poorwsgi.results import ( + bad_request, + debug_info, + directory_index, + forbidden, + hbytes, + html_escape, + human_methods_, + handlers_view, + internal_server_error, + method_not_allowed, + not_found, + not_implemented, + not_modified, + unauthorized, +) +from poorwsgi.state import ( + HTTP_BAD_REQUEST, + HTTP_FORBIDDEN, + HTTP_INTERNAL_SERVER_ERROR, + HTTP_METHOD_NOT_ALLOWED, + HTTP_NOT_FOUND, + HTTP_NOT_IMPLEMENTED, + HTTP_NOT_MODIFIED, + HTTP_UNAUTHORIZED, + METHOD_ALL, + METHOD_GET, + METHOD_POST, +) + +# pylint: disable=missing-function-docstring +# pylint: disable=no-self-use +# pylint: disable=redefined-outer-name + + +@pytest.fixture(scope="module") +def digest_app(): + """Application configured for Digest authentication.""" + app = Application("results_test") + app.secret_key = "testsecret" # noqa: S105 + app.auth_type = "Digest" + return app + + +def _call(res): + """Call a Response and return (status_code, headers_dict, body_bytes).""" + status_holder = [] + hdrs = {} + + def start_response(status, headers): + status_holder.append(int(status.split()[0])) + hdrs.update(headers) + + body = b"".join(res(start_response)) + return status_holder[0], hdrs, body + + +def _make_req(**kwargs): + """Return a MagicMock request with sensible defaults.""" + req = MagicMock() + req.method = kwargs.get("method", "GET") + req.uri = kwargs.get("uri", "/test") + req.server_admin = kwargs.get("server_admin", "admin@example.com") + req.debug = kwargs.get("debug", False) + req.headers = MagicMock() + req.headers.get = lambda k, d=None: kwargs.get(f"hdr_{k}", d) + req.headers.items = lambda: kwargs.get("headers_items", {}).items() + req.path = kwargs.get("path", req.uri) + return req + + +# --------------------------------------------------------------------------- +# Utility helpers +# --------------------------------------------------------------------------- + +class TestHtmlEscape: + """html_escape must neutralise all five HTML special characters.""" + + def test_ampersand(self): + assert html_escape("a&b") == "a&b" + + def test_double_quote(self): + assert html_escape('"value"') == ""value"" + + def test_single_quote(self): + assert html_escape("it's") == "it's" + + def test_less_than(self): + assert html_escape("") == "<tag>" + + def test_greater_than(self): + assert html_escape("a>b") == "a>b" + + def test_plain_text_unchanged(self): + assert html_escape("hello world 123") == "hello world 123" + + def test_xss_payload(self): + raw = '' + escaped = html_escape(raw) + assert "' + _, _, body = _call(bad_request(_make_req(uri=xss_uri))) + assert b"')) + ) + assert b"' + _, _, body = _call(internal_server_error(req)) + assert b"" not in body + + def test_active_exception_traceback(self): + """When called inside an except block, the traceback is rendered.""" + req = self._req(debug=True) + body = b"" + try: + raise ValueError("intentional test error") + except ValueError: + _, _, body = _call(internal_server_error(req)) + assert b"intentional test error" in body + + +class TestDirectoryIndex: + """directory_index → HTML directory listing (WSGI static file serving).""" + + def _req(self, root, uri="/files/", debug=False): + req = MagicMock() + req.document_root = root + req.uri = uri + req.debug = debug + req.server_software = "TestServer/1.0" + req.server_admin = "admin@example.com" + return req + + def test_not_a_directory_raises(self): + """Passing a file path instead of directory → HTTP 500.""" + with tempfile.NamedTemporaryFile() as f: + req = self._req(root="/") + with pytest.raises(HTTPException) as exc_info: + directory_index(req, f.name) + assert exc_info.value.status_code == HTTP_INTERNAL_SERVER_ERROR + + def test_returns_tuple(self): + with tempfile.TemporaryDirectory() as d: + result = directory_index(self._req(root="/"), d) + assert isinstance(result, tuple) + assert len(result) == 3 + + def test_content_type(self): + with tempfile.TemporaryDirectory() as d: + _, content_type, _ = directory_index(self._req(root="/"), d) + assert "text/html" in content_type + + def test_last_modified_header(self): + """Third element must be the Last-Modified header tuple.""" + with tempfile.TemporaryDirectory() as d: + _, _, header = directory_index(self._req(root="/"), d) + assert header[0] == "Last-Modified" + # RFC 7231 §7.1.1: HTTP-date format + assert "GMT" in header[1] + + def test_html_structure(self): + """Response must be valid HTML with a listing table.""" + with tempfile.TemporaryDirectory() as d: + content, _, _ = directory_index(self._req(root="/"), d) + assert "" in content + assert "" in content + + def test_regular_file_listed(self): + """Regular files in the directory must appear in the listing.""" + with tempfile.TemporaryDirectory() as d: + with open(os.path.join(d, "readme.txt"), "w", + encoding="utf-8") as f: + f.write("") + content, _, _ = directory_index(self._req(root="/"), d) + assert "readme.txt" in content + + def test_dot_files_hidden(self): + """Files starting with '.' (other than '..') must be excluded.""" + with tempfile.TemporaryDirectory() as d: + for name in (".hidden", "visible.txt"): + with open(os.path.join(d, name), "w", encoding="utf-8") as f: + f.write("") + content, _, _ = directory_index(self._req(root="/"), d) + assert ".hidden" not in content + assert "visible.txt" in content + + def test_backup_files_hidden(self): + """Files ending with '~' (editor backups) must be excluded.""" + with tempfile.TemporaryDirectory() as d: + for name in ("file.py~", "file.py"): + with open(os.path.join(d, name), "w", encoding="utf-8") as f: + f.write("") + content, _, _ = directory_index(self._req(root="/"), d) + assert "file.py~" not in content + assert "file.py" in content + + def test_subdirectory_listed(self): + """Subdirectories must appear with trailing slash.""" + with tempfile.TemporaryDirectory() as d: + os.makedirs(os.path.join(d, "subdir")) + content, _, _ = directory_index(self._req(root="/"), d) + assert "subdir/" in content + + def test_parent_link_when_not_root(self): + """Parent directory '..' link must appear when not at document root.""" + with tempfile.TemporaryDirectory() as d: + req = self._req(root="/other/root", uri="/files/") + content, _, _ = directory_index(req, d) + assert ".." in content + + def test_no_parent_link_at_document_root(self): + """'..' must NOT appear when the path IS the document root.""" + with tempfile.TemporaryDirectory() as d: + req = self._req(root=d[:-1], uri="/") + content, _, _ = directory_index(req, d) + assert "../" not in content + + def test_unknown_extension_falls_back_to_octet_stream(self): + """Files with unknown extensions use application/octet-stream.""" + with tempfile.TemporaryDirectory() as d: + with open(os.path.join(d, "data.xyzzy"), + "w", encoding="utf-8") as f: + f.write("x") + content, _, _ = directory_index(self._req(root="/"), d) + assert "octet-stream" in content + + def test_unreadable_file_skipped(self): + """Files that fail os.access(R_OK) must not appear in the listing.""" + with tempfile.TemporaryDirectory() as d: + with open(os.path.join(d, "secret.txt"), + "w", encoding="utf-8") as f: + f.write("x") + with patch("poorwsgi.results.os.access", return_value=False): + content, _, _ = directory_index(self._req(root="/"), d) + assert "secret.txt" not in content + + def test_xss_in_uri_escaped(self): + """URI is HTML-escaped in the page title.""" + with tempfile.TemporaryDirectory() as d: + req = self._req(root="/", uri='//') + content, _, _ = directory_index(req, d) + assert "" not in content + assert "<script>" in content + + def test_debug_shows_server_software(self): + """In debug mode, server_software string appears in the footer.""" + with tempfile.TemporaryDirectory() as d: + req = self._req(root="/", debug=True) + content, _, _ = directory_index(req, d) + assert "TestServer/1.0" in content + + def test_no_debug_hides_server_software(self): + """Without debug, server_software must NOT leak into the page.""" + with tempfile.TemporaryDirectory() as d: + req = self._req(root="/", debug=False) + content, _, _ = directory_index(req, d) + assert "TestServer/1.0" not in content + + +class TestDebugInfo: + """debug_info → HTML debugging page (application introspection).""" + + @pytest.fixture(scope="class") + def app_req(self): + app = Application("results_debug") + app.secret_key = "testsecret" # noqa: S105 + + env = defaultdict(str) + env["PATH_INFO"] = "/debug-info" + env["SERVER_PORT"] = "80" + env["SERVER_NAME"] = "localhost" + env["HTTP_HOST"] = "localhost" + env["REQUEST_METHOD"] = "GET" + req = Request(env, app) + return req, app + + def test_returns_string(self, app_req): + req, app = app_req + result = debug_info(req, app) + assert isinstance(result, str) + + def test_html_structure(self, app_req): + req, app = app_req + result = debug_info(req, app) + assert "" in result + assert " str: + """Build a PoorSession-compatible cookie with arbitrary (possibly non-dict) + data so we can craft edge-case payloads in tests.""" + secret_hash = shake_256(b'ks\x00' + key).digest(KEYSTREAM_SIZE) + mac_key = shake_256(b'mac\x00' + key).digest(32) + table = bytearray(range(256)) + perm_seed = shake_256(b'perm\x00' + key).digest(32) + Random(perm_seed).shuffle(table) # nosec # noqa: S311 + payload = bz2.compress( + encrypt(hidden(dumps(data), secret_hash), table), 9) + sig = _hmac.digest(mac_key, payload, 'sha256') + return b64encode(payload).decode() + '.' + b64encode(sig).decode() + + +class MockHeaders: + """Minimal stand-in for Headers/Response that records add_header calls.""" + def __init__(self): + self.headers = [] + + def add_header(self, name, value): + self.headers.append((name, value)) + + class Request: """A mock Request class.""" secret_key = SECRET_KEY @@ -42,6 +75,68 @@ def req_session(): return request +class TestNoCompress: + """Tests for the NoCompress pass-through class.""" + + # pylint: disable=no-self-use + + def test_compress_returns_data_unchanged(self): + data = b"hello world" + assert NoCompress.compress(data) == data + + def test_decompress_returns_data_unchanged(self): + data = b"hello world" + assert NoCompress.decompress(data) == data + + +class TestTokens: + """Tests for get_token and check_token helper functions.""" + + # pylint: disable=no-self-use + + def test_get_token_without_timeout(self): + token = get_token("secret", "client") + assert isinstance(token, str) + assert len(token) == 64 # sha3_256 hex digest + + def test_get_token_is_deterministic(self): + assert get_token("secret", "client") == get_token("secret", "client") + + def test_get_token_with_timeout(self): + token = get_token("secret", "client", timeout=300) + assert isinstance(token, str) + assert len(token) == 64 + + def test_get_token_with_explicit_expired(self): + token = get_token("secret", "client", timeout=300, expired=9999999999) + assert isinstance(token, str) + + def test_check_token_without_timeout_valid(self): + token = get_token("secret", "client") + assert check_token(token, "secret", "client") is True + + def test_check_token_without_timeout_invalid(self): + assert check_token("badtoken", "secret", "client") is False + + def test_check_token_with_timeout_valid(self): + token = get_token("secret", "client", timeout=300) + assert check_token(token, "secret", "client", timeout=300) is True + + def test_check_token_with_timeout_first_window_match(self): + """check_token must return True on first window match (covers early + return branch).""" + timeout = 300 + now = int(_time() / timeout) * timeout + # check_token tries expired = now + timeout first + token = get_token("secret", "client", timeout=timeout, + expired=now + timeout) + assert check_token(token, "secret", "client", timeout=timeout) is True + + def test_check_token_with_timeout_invalid(self): + assert check_token( + "badtoken", "secret", "client", timeout=300) is False + + class TestPoorSession: """Tests PoorSession configuration options.""" @@ -133,8 +228,8 @@ def test_default(self): assert "; SameSite" not in headers[0][1] def test_none(self): - """Tests PoorSession with SameSite set to 'None'.""" - session = PoorSession(SECRET_KEY, same_site="None") + """Tests PoorSession with SameSite='None' (requires secure=True).""" + session = PoorSession(SECRET_KEY, same_site="None", secure=True) headers = session.header() assert "; SameSite=None" in headers[0][1] @@ -150,6 +245,16 @@ def test_strict(self): headers = session.header() assert "; SameSite=Strict" in headers[0][1] + def test_invalid_value_raises(self): + """Unrecognised same_site value must raise ValueError.""" + with raises(ValueError, match="is not valid"): + PoorSession(SECRET_KEY, same_site=True) + + def test_none_without_secure_raises(self): + """same_site='None' without secure=True must raise ValueError.""" + with raises(ValueError, match="requires secure=True"): + PoorSession(SECRET_KEY, same_site="None", secure=False) + class TestErrors: """Tests exceptions.""" @@ -176,6 +281,50 @@ def test_bad_session(self): with raises(SessionError): session.load(cookies) + def test_string_secret_key(self): + """PoorSession accepts a str key (encodes to bytes internally).""" + session = PoorSession("string-secret-key") + session.data['x'] = 1 + session.write() + session2 = PoorSession("string-secret-key") + session2.load(session.cookie) + assert session2.data == {'x': 1} + + def test_load_empty_cookie_value(self): + """load() with an empty cookie value leaves data unchanged.""" + cookies = SimpleCookie() + cookies['SESSID'] = '' + session = PoorSession(SECRET_KEY) + session.load(cookies) + assert session.data == {} + + def test_load_no_dot_separator(self): + """Cookie without a '.' separator must raise SessionError.""" + cookies = SimpleCookie() + cookies['SESSID'] = b64encode(b'nodot').decode() + session = PoorSession(SECRET_KEY) + with raises(SessionError): + session.load(cookies) + + def test_load_short_signature(self): + """Cookie with a signature shorter than 32 bytes must raise + SessionError.""" + payload = b64encode(b'somepayload').decode() + sig = b64encode(b'short').decode() + cookies = SimpleCookie() + cookies['SESSID'] = f'{payload}.{sig}' + session = PoorSession(SECRET_KEY) + with raises(SessionError): + session.load(cookies) + + def test_load_non_dict_data(self): + """Non-dict cookie data must raise SessionError.""" + cookies = SimpleCookie() + cookies['SESSID'] = _make_poor_cookie(SECRET_KEY, [1, 2, 3]) + session = PoorSession(SECRET_KEY) + with raises(SessionError): + session.load(cookies) + def test_bad_session_compatibility(self, req): """Tests PoorSession compatibility with a bad session cookie, expecting SessionError.""" @@ -281,6 +430,30 @@ def test_destroy(self): headers = session.header() assert "expires=" in headers[0][1] + def test_destroy_with_max_age(self): + """destroy() must also set Max-Age=-1 when max_age was configured.""" + session = Session(max_age=3600) + session.destroy() + headers = session.header() + assert "Max-Age=-1" in headers[0][1] + assert "Max-Age=3600" not in headers[0][1] + + def test_destroy_with_secure(self): + """destroy() must preserve the Secure flag when configured.""" + session = Session(secure=True) + session.destroy() + headers = session.header() + assert "Secure" in headers[0][1] + + def test_header_writes_to_headers_object(self): + """header(obj) must call obj.add_header() for each cookie header.""" + session = Session() + session.data = "tok" + mock = MockHeaders() + returned = session.header(mock) + assert len(mock.headers) == len(returned) + assert mock.headers == returned + def test_expires(self): """Tests Session with an expires setting.""" session = Session(expires=3600) @@ -305,6 +478,16 @@ def test_same_site(self): headers = session.header() assert "SameSite=Strict" in headers[0][1] + def test_same_site_invalid_raises(self): + """Unrecognised same_site value must raise ValueError.""" + with raises(ValueError, match="is not valid"): + Session(same_site=True) + + def test_same_site_none_without_secure_raises(self): + """same_site='None' without secure=True must raise ValueError.""" + with raises(ValueError, match="requires secure=True"): + Session(same_site="None", secure=False) + def test_custom_sid(self): """Tests Session with a custom cookie name.""" session = Session(sid="MYSESSID")